-- GRU.lua
require 'nn'
-- nn.AbstractRecurrent and the nn.rnn.recursive* helpers used below come from
-- the (already loaded) rnn package.

local GRU, parent = torch.class('nn.GRU', 'nn.AbstractRecurrent')

function GRU:__init(inputSize, outputSize, rho)
   parent.__init(self, rho or 9999)
   self.inputSize = inputSize
   self.outputSize = outputSize

   self.recurrentModule = self:buildModel()
   self.modules[1] = self.recurrentModule
   self.sharedClones[1] = self.recurrentModule

   self.zeroTensor = torch.Tensor()
   -- a GRU has no separate cell state; these tables are not used below
   self.cells = {}
   self.gradCells = {}
end
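
--[[ A minimal usage sketch (sizes are illustrative; forget() comes from the
AbstractRecurrent interface this class inherits):

   local gru = nn.GRU(100, 50)               -- inputSize = 100, outputSize = 50
   local h1 = gru:forward(torch.randn(100))  -- step 1, h[0] is all zeros
   local h2 = gru:forward(torch.randn(100))  -- step 2, uses h1 as h[t-1]
   gru:forget()                              -- reset the step counter / history
--]]

-- Both the reset and update gates share the same structure:
--   gate = sigmoid(W * x + U * h[t-1])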
function GRU:buildGate()
   local gate = nn.Sequential()
   local i2g = nn.Linear(self.inputSize, self.outputSize)
   local o2g = nn.Linear(self.outputSize, self.outputSize)
   local para = nn.ParallelTable()
   para:add(i2g):add(o2g)
   gate:add(para)
   gate:add(nn.CAddTable())
   gate:add(nn.Sigmoid())
   return gate
end
function GRU:buildResetGate()
   self.resetGate = self.resetGate or self:buildGate()
   return self.resetGate
end

function GRU:buildUpdateGate()
   self.updateGate = self.updateGate or self:buildGate()
   return self.updateGate
end
-- outputCandidate = tanh(W * x + U * (r . h[t-1]))
function GRU:buildOutputCandidate()
   local hiddenCandidate = nn.Sequential()

   local left = nn.Sequential()
   -- select x
   left:add(nn.SelectTable(1))
   left:add(nn.Linear(self.inputSize, self.outputSize))

   local right = nn.Sequential()
   -- select {h[t-1], r} and apply the reset gate: r . h[t-1]
   right:add(nn.NarrowTable(2, 2))
   right:add(nn.CMulTable())
   right:add(nn.Linear(self.outputSize, self.outputSize))

   local para = nn.ConcatTable()
   para:add(left):add(right)
   hiddenCandidate:add(para)
   hiddenCandidate:add(nn.CAddTable())
   hiddenCandidate:add(nn.Tanh())
   return hiddenCandidate
end
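
-- One step of the full graph assembled below computes:
--   r    = sigmoid(W_r * x + U_r * h[t-1])    (reset gate)
--   z    = sigmoid(W_z * x + U_z * h[t-1])    (update gate)
--   ~h   = tanh(W * x + U * (r . h[t-1]))     (output candidate)
--   h[t] = (1 - z) . h[t-1] + z . ~h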
-- input: {x, h[t-1]}, i.e. {input, output[t-1]}
function GRU:buildModel()
   self.resetGate = self:buildResetGate()
   self.updateGate = self:buildUpdateGate()
   self.outputCandidate = self:buildOutputCandidate()

   local cell = nn.Sequential()
   local concat = nn.ConcatTable()
   concat:add(nn.Identity()):add(self.resetGate):add(self.updateGate)
   cell:add(concat)
   -- {{x, h[t-1]}, r, z} -> {x, h[t-1], r, z}
   cell:add(nn.FlattenTable())

   local seq1 = nn.Sequential()
   -- select h[t-1]
   seq1:add(nn.SelectTable(2))

   local seq2 = nn.Sequential()
   -- 1 - z
   seq2:add(nn.SelectTable(4))
   seq2:add(nn.MulConstant(-1, false))
   seq2:add(nn.AddConstant(1, false))

   local seq3 = nn.Sequential()
   -- {x, h[t-1], r} -> ~h
   seq3:add(nn.NarrowTable(1, 3))
   seq3:add(self.outputCandidate)

   -- input:  {x, h[t-1], r, z}
   -- output: {h[t-1], (1 - z), z, ~h}
   local concat2 = nn.ConcatTable()
   concat2:add(seq1)
   concat2:add(seq2)
   concat2:add(nn.SelectTable(4))
   concat2:add(seq3)
   cell:add(concat2)

   local seq4 = nn.Sequential()
   seq4:add(nn.NarrowTable(1, 2))
   seq4:add(nn.CMulTable())

   local seq5 = nn.Sequential()
   seq5:add(nn.NarrowTable(3, 2))
   seq5:add(nn.CMulTable())

   -- input:  {h[t-1], (1 - z), z, ~h}
   -- output: {(1 - z) . h[t-1], z . ~h}
   local concat3 = nn.ConcatTable()
   concat3:add(seq4):add(seq5)
   cell:add(concat3)

   -- h[t] = (1 - z) . h[t-1] + z . ~h
   cell:add(nn.CAddTable())
   return cell
end
function GRU:updateOutput(input)
   local prevOutput
   if self.step == 1 then
      -- first step: the previous output h[0] is a zero vector of size outputSize
      assert(input:dim() == 1, "only support input with dimension == 1")
      self.zeroTensor:resize(self.outputSize):zero()
      prevOutput = self.zeroTensor
   else
      prevOutput = self.output
   end

   -- output[t] = gru{input[t], output[t-1]}
   local output
   if self.train ~= false then
      -- training: use a dedicated clone of recurrentModule for this step
      self:recycle()
      local recurrentModule = self:getStepModule(self.step)
      output = recurrentModule:updateOutput{input, prevOutput}
   else
      -- evaluation: a single recurrentModule is reused across steps
      output = self.recurrentModule:updateOutput{input, prevOutput}
   end

   if self.train ~= false then
      local input_ = self.inputs[self.step]
      self.inputs[self.step] = self.copyInputs
         and nn.rnn.recursiveCopy(input_, input)
         or nn.rnn.recursiveSet(input_, input)
   end

   self.outputs[self.step] = output
   self.output = output
   self.step = self.step + 1
   self.gradParametersAccumulated = false
   return self.output
end
function GRU:backwardThroughTime()
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {}
   local rho = math.min(self.rho, self.step - 1)
   local stop = self.step - rho

   if self.fastBackward then
      local gradInput, gradPrevOutput
      for step = self.step - 1, math.max(stop, 1), -1 do
         local recurrentModule = self:getStepModule(step)

         -- add the gradient flowing back from the next step to this step's gradOutput
         local gradOutput = self.gradOutputs[step]
         if gradPrevOutput then
            self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], gradPrevOutput)
            nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
            gradOutput = self._gradOutputs[step]
         end

         local scale = self.scales[step]
         local output = (step == 1) and self.zeroTensor or self.outputs[step - 1]
         local inputTable = {self.inputs[step], output}
         local gradInputTable = recurrentModule:backward(inputTable, gradOutput, scale)
         -- do not re-declare these as locals here: gradPrevOutput must survive
         -- into the next (earlier) iteration, and gradInput into the return below
         gradInput, gradPrevOutput = unpack(gradInputTable)
         table.insert(self.gradInputs, 1, gradInput)
      end
      return gradInput
   else
      local gradInput = self:updateGradInputThroughTime()
      self:accGradParametersThroughTime()
      return gradInput
   end
end
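
-- When fastBackward is false, the BPTT pass above is split into the two
-- functions below: updateGradInputThroughTime computes the gradients w.r.t.
-- the inputs, and accGradParametersThroughTime accumulates the parameter
-- gradients.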
function GRU:updateGradInputThroughTime()
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {}
   local gradInput, gradPrevOutput
   local rho = math.min(self.rho, self.step - 1)
   local stop = self.step - rho

   for step = self.step - 1, math.max(stop, 1), -1 do
      local recurrentModule = self:getStepModule(step)
      local gradOutput = self.gradOutputs[step]
      if gradPrevOutput then
         self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], gradPrevOutput)
         nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
         gradOutput = self._gradOutputs[step]
      end

      local output = (step == 1) and self.zeroTensor or self.outputs[step - 1]
      local inputTable = {self.inputs[step], output}
      local gradInputTable = recurrentModule:updateGradInput(inputTable, gradOutput)
      gradInput, gradPrevOutput = unpack(gradInputTable)
      table.insert(self.gradInputs, 1, gradInput)
   end
   return gradInput
end
function GRU:accGradParametersThroughTime()
   local rho = math.min(self.rho, self.step - 1)
   local stop = self.step - rho

   for step = self.step - 1, math.max(stop, 1), -1 do
      local recurrentModule = self:getStepModule(step)
      local scale = self.scales[step]
      local output = (step == 1) and self.zeroTensor or self.outputs[step - 1]
      local inputTable = {self.inputs[step], output}
      -- the last step uses the external gradOutput; earlier steps use the
      -- gradOutput accumulated during updateGradInputThroughTime
      local gradOutput = (step == self.step - 1) and self.gradOutputs[step] or self._gradOutputs[step]
      recurrentModule:accGradParameters(inputTable, gradOutput, scale)
   end

   self.gradParametersAccumulated = true
end
function GRU:accUpdateGradParametersThroughTime(lr)
   local rho = math.min(self.rho, self.step - 1)
   local stop = self.step - rho

   for step = self.step - 1, math.max(stop, 1), -1 do
      local recurrentModule = self:getStepModule(step)
      local scale = self.scales[step]
      local output = (step == 1) and self.zeroTensor or self.outputs[step - 1]
      local inputTable = {self.inputs[step], output}
      local gradOutput = (step == self.step - 1) and self.gradOutputs[step] or self._gradOutputs[step]
      recurrentModule:accUpdateGradParameters(inputTable, gradOutput, lr * scale)
   end
end
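
--[[ A rough training-loop sketch. This assumes the step-by-step (non-Sequencer)
conventions of the rnn package that AbstractRecurrent follows; `inputs`,
`targets`, `criterion`, `rho` and `lr` are placeholders, not part of this file:

   local gru = nn.GRU(inputSize, outputSize, rho)
   for step = 1, rho do
      local output = gru:forward(inputs[step])
      local err = criterion:forward(output, targets[step])
      local gradOutput = criterion:backward(output, targets[step])
      gru:backward(inputs[step], gradOutput)   -- stashes gradOutput for BPTT
   end
   gru:backwardThroughTime()    -- backpropagate through the stored steps
   gru:updateParameters(lr)
   gru:forget()                 -- reset state before the next sequence
--]]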