-- BLSTM.lua
require 'nn'
require 'nngraph'
require 'rnn'
require 'cudnn'
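-- nn.BLSTM wraps a forward LSTM and a backward LSTM (run over the reversed
-- sequence) and joins their outputs along the feature dimension. Input is
-- expected as seqLen x batch x inputDim, the default layout for both
-- nn.SeqLSTM and cudnn.LSTM; the output is seqLen x batch x 2 * hiddenDim.
-- Note that 'cudnn' is required unconditionally above, so a CUDA-capable
-- install is assumed even when the nn.SeqLSTM (isCUDNN = false) path is used.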
local layer, parent = torch.class('nn.BLSTM', 'nn.Container')

function layer:__init(inputDim, hiddenDim, isCUDNN)
   if isCUDNN then
      self.forwardModule = cudnn.LSTM(inputDim, hiddenDim, 1)
      self.backwardModule = cudnn.LSTM(inputDim, hiddenDim, 1)
   else
      self.forwardModule = nn.SeqLSTM(inputDim, hiddenDim)
      self.backwardModule = nn.SeqLSTM(inputDim, hiddenDim)
   end
   -- The backward stream reverses the sequence along the time dimension
   -- (dim 1), runs its LSTM, then reverses again so its output is
   -- time-aligned with the forward stream.
   local backward = nn.Sequential()
   backward:add(nn.SeqReverseSequence(1))
   backward:add(self.backwardModule)
   backward:add(nn.SeqReverseSequence(1))
   -- Feed the same input to both directions and join their hidden states
   -- along the feature dimension (dim 3), yielding 2 * hiddenDim features.
   local concat = nn.ConcatTable()
   concat:add(self.forwardModule):add(backward)
   local blstm = nn.Sequential()
   blstm:add(concat)
   blstm:add(nn.JoinTable(3))
   parent.__init(self)
   self.output = torch.Tensor()
   self.gradInput = torch.Tensor()
   self.module = blstm
   self.modules[1] = blstm
end
function layer:updateOutput(input)
   self.output = self.module:updateOutput(input)
   return self.output
end

function layer:updateGradInput(input, gradOutput)
   self.gradInput = self.module:updateGradInput(input, gradOutput)
   return self.gradInput
end

function layer:accGradParameters(input, gradOutput, scale)
   self.module:accGradParameters(input, gradOutput, scale)
end

function layer:accUpdateGradParameters(input, gradOutput, lr)
   self.module:accUpdateGradParameters(input, gradOutput, lr)
end

function layer:sharedAccUpdateGradParameters(input, gradOutput, lr)
   self.module:sharedAccUpdateGradParameters(input, gradOutput, lr)
end

function layer:__tostring__()
   if self.module.__tostring__ then
      return torch.type(self) .. ' @ ' .. self.module:__tostring__()
   else
      return torch.type(self) .. ' @ ' .. torch.type(self.module)
   end
end
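
-- Functional alternative: BLSTM.createBLSTM builds the same
-- forward/reverse/join graph with nngraph instead of containers, and
-- delegates to cudnn's fused bidirectional LSTM when isCUDNN is true.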
BLSTM = {}

function BLSTM.createBLSTM(inputDim, hiddenDim, isCUDNN)
   if isCUDNN then
      -- cudnn ships a fused bidirectional LSTM; use it directly.
      return cudnn.BLSTM(inputDim, hiddenDim, 1)
   end
   local forwardmodule = nn.SeqLSTM(inputDim, hiddenDim)
   local backwardmodule = nn.SeqLSTM(inputDim, hiddenDim)
   local input = nn.Identity()()
   local forward = forwardmodule(input)
   -- reverse -> LSTM -> reverse, as in nn.BLSTM above
   local backward = nn.SeqReverseSequence(1)(input)
   backward = backwardmodule(backward)
   backward = nn.SeqReverseSequence(1)(backward)
   local output = nn.JoinTable(3)({forward, backward})
   return nn.gModule({input}, {output})
end
return BLSTM
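
-- A minimal usage sketch (illustrative, not part of the module; the
-- dimensions below are arbitrary assumptions):
--
--   local BLSTM = require 'BLSTM'
--   local net = BLSTM.createBLSTM(32, 64, false)  -- CPU path via nn.SeqLSTM
--   local x = torch.randn(10, 4, 32)              -- seqLen x batch x inputDim
--   local y = net:forward(x)                      -- 10 x 4 x 128 (2 * hiddenDim)
--
-- The container form behaves the same way: nn.BLSTM(32, 64, false):forward(x).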