models.py
#!/usr/bin/env python3
import torch
import torch.nn as nn


class BiLSTM(nn.Module):
    """
    Bidirectional LSTM with a fully connected classification layer on top.
    """

    def __init__(self, embedding_dim, hidden_dim, label_size, batch_size, embedding_weights, dropout=0.2):
        super(BiLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        # Frozen, pre-trained word embeddings (e.g. GloVe or word2vec weights).
        self.word_embeddings = nn.Embedding.from_pretrained(embedding_weights, freeze=True)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,
                            bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        # The final forward and backward hidden states are concatenated, hence 2 * hidden_dim.
        self.fc = nn.Linear(2 * hidden_dim, label_size)
        self.act = nn.Softmax(dim=1)

    def forward(self, sentence, src_len, train=True):
        # `train` is unused here: dropout is toggled via model.train() / model.eval().
        embeds = self.word_embeddings(sentence)
        # The LSTM was built with batch_first=True, so the packed input must match;
        # src_len must be a CPU tensor (or list) of the unpadded sequence lengths.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embeds, src_len, batch_first=True, enforce_sorted=False)
        packed_outputs, (hidden, cell) = self.lstm(packed_embedded)
        # hidden has shape (2, batch, hidden_dim); concatenate the last forward
        # (hidden[-2]) and backward (hidden[-1]) states into (batch, 2 * hidden_dim).
        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        hidden = self.dropout(hidden)
        fc_output = self.fc(hidden)
        outputs = self.act(fc_output)
        return outputs

class LSTM(nn.Module):
    """
    Unidirectional LSTM with a fully connected classification layer on top.
    """

    def __init__(self, embedding_dim, hidden_dim, label_size, batch_size, embedding_weights, dropout=0.2):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.word_embeddings = nn.Embedding.from_pretrained(embedding_weights, freeze=True)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,
                            bidirectional=False,
                            batch_first=True)
        self.fc = nn.Linear(hidden_dim, label_size)
        self.act = nn.Softmax(dim=1)

    def forward(self, sentence, src_len, train=True):
        # `train` is unused here: dropout is toggled via model.train() / model.eval().
        embeds = self.word_embeddings(sentence)
        # As above, the packed input must be batch_first and src_len must live on the CPU.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embeds, src_len, batch_first=True, enforce_sorted=False)
        packed_outputs, (hidden, cell) = self.lstm(packed_embedded)
        # hidden has shape (1, batch, hidden_dim) for a single-layer unidirectional
        # LSTM; squeeze the layer dimension to get (batch, hidden_dim).
        hidden = hidden.squeeze(dim=0)
        hidden = self.dropout(hidden)
        dense_outputs = self.fc(hidden)
        outputs = self.act(dense_outputs)
        return outputs
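

# A minimal smoke test: a hedged sketch, not part of the original module. The
# vocabulary size, dimensions, and random embedding weights below are arbitrary
# stand-ins for real pre-trained vectors and real padded batches.
if __name__ == "__main__":
    vocab_size, embedding_dim, hidden_dim, label_size, batch_size = 100, 50, 32, 3, 4
    weights = torch.randn(vocab_size, embedding_dim)  # placeholder for pre-trained embeddings
    sentences = torch.randint(0, vocab_size, (batch_size, 10))  # (batch, seq_len) token ids
    lengths = torch.tensor([10, 8, 7, 5])  # unpadded lengths, kept on the CPU

    for model_cls in (BiLSTM, LSTM):
        model = model_cls(embedding_dim, hidden_dim, label_size, batch_size, weights)
        model.eval()  # disables dropout
        with torch.no_grad():
            probs = model(sentences, lengths)
        # Each row of probs is a softmax distribution over label_size classes.
        print(model_cls.__name__, probs.shape)  # torch.Size([4, 3])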