model.py
from torch import nn
from transformers import BertModel


class BertSentimentClassifier(nn.Module):
    def __init__(self, bert_model, num_classes):
        """
        @param bert_model (str): name or path of a pretrained BERT checkpoint,
            passed to BertModel.from_pretrained
        @param num_classes (int): number of target labels
        """
        super(BertSentimentClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, num_classes)
    def forward(self, input_ids, attention_mask):
        """
        Feed input to the model to compute output.
        @param input_ids (torch.Tensor): an input tensor with shape (batch_size, max_length)
        @param attention_mask (torch.Tensor): a tensor that holds attention mask information,
            with shape (batch_size, max_length)
        @return out (logits) (torch.Tensor): an output tensor with shape (batch_size, num_classes)
        """
        # Pooling strategy over the final embeddings: use BERT's pooler output,
        # i.e. the [CLS] hidden state passed through a dense layer with tanh activation
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        output = self.dropout(outputs.pooler_output)
        return self.out(output)


class BertSequentialSentimentClassifier(nn.Module):
    def __init__(self, bert_model, num_classes, freeze_bert=False):
        """
        @param bert_model (str): name or path of a pretrained BERT checkpoint,
            passed to BertModel.from_pretrained
        @param num_classes (int): number of target labels
        @param freeze_bert (bool): set `True` to freeze BERT's weights; leave `False`
            to fine-tune the BERT model
        """
        super(BertSequentialSentimentClassifier, self).__init__()
        LINEAR_HIDDEN_SIZE = 50
        self.bert = BertModel.from_pretrained(bert_model)
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, LINEAR_HIDDEN_SIZE),
            nn.ReLU(),
            # nn.Dropout(0.3),
            nn.Linear(LINEAR_HIDDEN_SIZE, num_classes)
        )
        # Freeze the BERT encoder so only the classifier head is trained
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
    def forward(self, input_ids, attention_mask):
        """
        Feed input to the model to compute output.
        @param input_ids (torch.Tensor): an input tensor with shape (batch_size, max_length)
        @param attention_mask (torch.Tensor): a tensor that holds attention mask information,
            with shape (batch_size, max_length)
        @return logits (torch.Tensor): an output tensor with shape (batch_size, num_classes)
        """
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Extract the last hidden state of the `[CLS]` token for the classification task.
        # On the output of the final (12th) transformer layer, only the first embedding
        # (corresponding to the [CLS] token) is used by the classifier.
        # "The first token of every sequence is always a special classification token ([CLS]).
        # The final hidden state corresponding to this token is used as the aggregate sequence
        # representation for classification tasks." (from the BERT paper)
        last_hidden_state_cls = output.last_hidden_state[:, 0, :]
        # Feed the [CLS] representation to the classifier head to compute logits
        logits = self.classifier(last_hidden_state_cls)
        return logits
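

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): it assumes the public
    # "bert-base-uncased" checkpoint and a binary sentiment task, and only checks that both
    # classifiers produce logits of the expected shape. Swap in whatever checkpoint and
    # number of classes your training setup actually uses.
    import torch
    from transformers import BertTokenizer

    MODEL_NAME = "bert-base-uncased"  # assumed checkpoint for illustration
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

    # Tokenize a toy batch of two sentences into input_ids and attention_mask
    encoding = tokenizer(
        ["I loved this movie!", "The plot was a complete mess."],
        padding=True,
        truncation=True,
        max_length=32,
        return_tensors="pt",
    )

    for model in (
        BertSentimentClassifier(MODEL_NAME, num_classes=2),
        BertSequentialSentimentClassifier(MODEL_NAME, num_classes=2, freeze_bert=True),
    ):
        model.eval()
        with torch.no_grad():
            logits = model(encoding["input_ids"], encoding["attention_mask"])
        print(type(model).__name__, logits.shape)  # expected: torch.Size([2, 2])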