edmundhui commited on
Commit
c786710
·
verified ·
1 Parent(s): 612ae5b

Create regression_models

Browse files
Files changed (1) hide show
  1. regression_models +17 -0
regression_models ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import BertModel, BertTokenizer
4
+
5
+ class BERTRegression(nn.Module):
6
+ def __init__(self):
7
+ super(BERTRegression, self).__init__()
8
+ self.bert = BertModel.from_pretrained("bert-base-uncased")
9
+ self.dropout = nn.Dropout(0.1)
10
+ self.linear = nn.Linear(self.bert.config.hidden_size, 1)
11
+
12
+ def forward(self, input_ids, attention_mask):
13
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
14
+ pooled_output = outputs.pooler_output
15
+ pooled_output = self.dropout(pooled_output)
16
+ logits = self.linear(pooled_output)
17
+ return logits.squeeze(-1)