-- gmodule.lua

local nesting = require('nngraph.nesting')
local utils = require('nngraph.utils')
local istensor = torch.isTensor
local istable = utils.istable
local istorchclass = utils.istorchclass
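
-- Returns the summed gradOutput for a node. If the node has a single
-- gradOutput, it is returned directly; otherwise the gradients are
-- accumulated into node.data.gradOutputBuffer (allocated on first use),
-- with a fast path that skips the allocation when all but one of the
-- gradients are one-element zero tensors.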
local function getTotalGradOutput(node)
   local gradOutput = node.data.gradOutput
   assert(istable(gradOutput), "expecting gradients to sum")
   if #gradOutput > 1 then
      -- Check if we can bypass the allocation, for the special case where all
      -- gradOutputs but one are zero tensors with an underlying one-element
      -- storage. Note that when we cannot bypass it, this check is only
      -- performed once.
      if not node.data.gradOutputBuffer then
         local count = 0
         local idx = 1
         -- Count how many gradOutputs are not one-element tensors filled with zero
         for i = 1, #gradOutput do
            local zero = torch.isTensor(gradOutput[i]) and
               gradOutput[i]:storage() ~= nil and
               gradOutput[i]:storage():size() == 1 and
               gradOutput[i]:storage()[1] == 0
            if not zero then
               idx = i
               count = count + 1
            end
         end
         if count < 2 then
            -- Return the only non-zero one, or the first one if they are all zero
            return gradOutput[idx]
         end
      end
      node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
      local gobuff = node.data.gradOutputBuffer
      nesting.resizeNestedAs(gobuff, gradOutput[1])
      nesting.copyNested(gobuff, gradOutput[1])
      for i = 2, #gradOutput do
         nesting.addNestedTo(gobuff, gradOutput[i])
      end
      gradOutput = gobuff
   else
      gradOutput = gradOutput[1]
   end
   return gradOutput
end

-- The gModule allows one to have a general non-cyclic graph of modules.
--
-- Each node of the graph can have multiple inputs.
-- The order of inputs is remembered in node.data.mapindex.
--
-- Each node has only one output.
-- The output can also be a table.
-- To route parts of the output table to different modules,
-- use the node:split(nOutputs) function.
-- The split will create subnodes with narrowed output.
--
-- Implementation details:
-- The node.data.input holds a list of inputs.
-- If a module expects only one input, node.data.input[1] is used.
--
-- The node.data.gradOutput holds the to-be-summed gradOutputs.
-- Each node has only one output, so we need only one gradOutput.
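
-- A minimal usage sketch (editor's example; assumes the standard nngraph
-- call syntax where calling a module, as in nn.Linear(20, 10)(parent),
-- produces an nngraph.Node):
--[[
require 'nngraph'
local x1 = nn.Identity()()                      -- first input node
local x2 = nn.Identity()()                      -- second input node
local h = nn.CAddTable()({nn.Linear(20, 10)(x1), nn.Linear(20, 10)(x2)})
local out = nn.Tanh()(h)
local net = nn.gModule({x1, x2}, {out})         -- the class defined below
local y = net:forward({torch.randn(20), torch.randn(20)})
--]]
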
local gModule, parent = torch.class('nn.gModule', 'nn.Container')

function gModule:__init(inputs, outputs)
   parent.__init(self)
   -- The graph is defined backwards; we have the output modules as input here.
   -- We will define a dummy output node that connects all output modules
   -- into itself. This will be the output for the forward graph and
   -- the input point for the backward graph.
   local node
   local outnode = nngraph.Node({input={}})
   for i = 1, utils.tableMaxN(outputs) do
      node = outputs[i]
      if torch.typename(node) ~= 'nngraph.Node' then
         error(utils.expectingNodeErrorMessage(node, 'outputs', i))
      end
      outnode:add(node, true)
   end
   for i = 1, utils.tableMaxN(inputs) do
      node = inputs[i]
      if torch.typename(node) ~= 'nngraph.Node' then
         error(utils.expectingNodeErrorMessage(node, 'inputs', i))
      end
   end
   -- We also add a dummy input node.
   -- The input node will be split to feed the passed input nodes.
   local innode = nngraph.Node({input={}})
   assert(#inputs > 0, "a gModule requires at least one input")
   if #inputs == 1 then
      inputs[1]:add(innode, true)
   else
      local splits = {innode:split(#inputs)}
      for i = 1, #inputs do
         assert(#inputs[i].children == 0, "an input should have no inputs")
      end
      for i = 1, #inputs do
         inputs[i]:add(splits[i], true)
      end
   end
   -- The backward graph (bg) is for gradients.
   -- The forward graph (fg) is for function evaluation.
   self.bg = outnode:graph()
   self.fg = self.bg:reverse()
   -- The complete graph is constructed;
   -- now regenerate the graphs with the additional nodes.
   local roots = self.fg:roots()
   -- If there is more than one root in the forward graph, then make sure that
   -- the extra roots are parameter nodes.
   if #roots > 1 then
      local innodeRoot = nil
      -- First find our innode.
      for _, root in ipairs(roots) do
         if root.data == innode.data then
            assert(innodeRoot == nil, 'more than one matching input node found in leaves')
            innodeRoot = root
         else
            assert(root.data.module, 'Expected nnop.Parameters node, module not found in node')
            assert(torch.typename(root.data.module) == 'nnop.Parameters',
               'Expected nnop.Parameters node, found : ' .. torch.typename(root.data.module))
         end
      end
      assert(innodeRoot ~= nil, 'input node not found among roots')
      self.innode = innodeRoot
   else
      assert(#roots == 1, "expecting only one start")
      self.innode = roots[1]
   end
   assert(self.innode.data == innode.data, "expecting the forward innode")
   self.outnode = outnode
   self.verbose = false
   self.nInputs = #inputs
   -- Computation on the graph is done through a topsort of the forward and backward graphs.
   self.forwardnodes = self.fg:topsort()
   self.backwardnodes = self.bg:topsort()
   -- Iterate over all nodes: check, tag and add to the container.
   for i, node in ipairs(self.forwardnodes) do
      -- Check for unused inputs or unused split() outputs.
      if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #node.children then
         local nUnused = node.data.nSplitOutputs - #node.children
         local debugLabel = node.data.annotations._debugLabel
         local errStr =
            "%s of split(%s) outputs from the node declared at %s are unused"
         error(string.format(errStr, nUnused, node.data.nSplitOutputs, debugLabel))
      end
      -- Check whether any nodes were defined as taking this node as an input,
      -- but then left dangling and don't connect to the output. If this is
      -- the case, then they won't be present in forwardnodes, so error out.
      for successor, _ in pairs(node.data.reverseMap) do
         local successorIsInGraph = false
         -- Only need to check the part of forwardnodes from i onwards; the
         -- topological sort guarantees it cannot be in the first part.
         for j = i + 1, #self.forwardnodes do
            -- Compare equality of data tables, as new Node objects have been
            -- created by processes such as the topological sort, but the
            -- underlying .data table is shared.
            if self.forwardnodes[j].data == successor.data then
               successorIsInGraph = true
               break
            end
         end
         local errStr =
            "node declared on %s does not connect to gmodule output"
         assert(successorIsInGraph,
            string.format(errStr, successor.data.annotations._debugLabel))
      end
      -- Set data.forwardNodeId for node:label() output.
      node.data.forwardNodeId = node.id
      -- Add the module to the container.
      if node.data.module then
         self:add(node.data.module)
      end
   end
   self.output = nil
   self.gradInput = nil
   if #self.outnode.children > 1 then
      self.output = self.outnode.data.input
   end
end

function gModule:replace(callback)
   local out = callback(self)
   local revmodules = {}
   for i, m in ipairs(self.modules) do
      revmodules[m] = i
   end
   for i, node in ipairs(self.forwardnodes) do
      if node.data.module then
         local m = node.data.module
         node.data.module = m:replace(callback)
         self.modules[revmodules[m]] = node.data.module
      end
   end
   return out
end

function gModule:map(gm, func)
   for i, node in ipairs(self.forwardnodes) do
      local gmnode = gm.forwardnodes[i]
      assert(gmnode, 'trying to map another gModule with a different structure')
      if node.data.module then
         assert(gmnode.data.module, 'trying to map another gModule with a different structure')
         func(node.data.module, gmnode.data.module)
      end
   end
end

--[[ Recursively applies type(type_str) to any tensors in the argument. If the
argument is a tensor, type(type_str) is applied; if the argument is an array,
this function recurses into it. ]]
local function recursiveType(param, type_str)
   if torch.type(param) == 'table' then
      for i = 1, #param do
         param[i] = recursiveType(param[i], type_str)
      end
   elseif torch.typename(param) and
      torch.typename(param):find('torch%..+Tensor') then
      param = param:type(type_str)
   end
   return param
end
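
-- For instance (editor's note), recursiveType({torch.DoubleTensor(2)}, 'torch.FloatTensor')
-- converts the tensor inside the table to a FloatTensor and returns the table.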
function gModule:type(type, tensorCache)
   if not type then
      return self._type
   end
   tensorCache = tensorCache or {}
   local function applyTypeToTable(tbl)
      for key, value in pairs(tbl) do
         tbl[key] = recursiveType(value, type)
      end
   end
   -- Convert any stored data in self, and in the in and out nodes.
   applyTypeToTable(self)
   if self.innode then applyTypeToTable(self.innode.data) end
   if self.outnode then applyTypeToTable(self.outnode.data) end
   -- Loop through modules and convert data.
   for _, m in ipairs(self.modules) do
      m:type(type, tensorCache)
   end
   for i, node in ipairs(self.backwardnodes) do
      if node.data.gradOutputBuffer ~= nil then
         node.data.gradOutputBuffer =
            recursiveType(node.data.gradOutputBuffer, type)
      end
      for k, child in ipairs(node.children) do
         applyTypeToTable(child.data)
      end
   end
   for i, node in ipairs(self.forwardnodes) do
      if node.data.input ~= nil then
         node.data.input = recursiveType(node.data.input, type)
      end
      for k, child in ipairs(node.children) do
         applyTypeToTable(child.data)
      end
   end
   self._type = type
   return self
end
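
-- Usage sketch (editor's example): converting a whole graph to another tensor type.
--[[
local net = nn.gModule({x}, {y})   -- built as in the example above
net:type('torch.FloatTensor')      -- same effect as net:float()
--]]
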
function gModule:updateOutput(input)
   return self:runForwardFunction('updateOutput', input)
end

function gModule:clearState()
   local ret = parent.clearState(self)
   for _, node in ipairs(self.backwardnodes) do
      node.data.gradOutput = nil
      node.data.gradOutputBuffer = nil
   end
   for _, node in ipairs(self.forwardnodes) do
      node.data.input = nil
   end
   return ret
end

function gModule:runForwardFunction(func, input)
   if type(func) == "string" then
      local func_name = func
      func = function(module, input) return module[func_name](module, input) end
   end
   -- For backward compatibility, we allow self.nInputs to be missing.
   local nInputs = self.nInputs or #self.innode.children
   -- We see the input as a list of inputs.
   if nInputs <= 1 then
      input = {input}
   elseif type(input) ~= "table" then
      error(string.format("expecting table of %s inputs", nInputs))
   end
   local function neteval(node)
      local function propagate(node, x)
         for i, child in ipairs(node.children) do
            child.data.input = child.data.input or {}
            local mapindex = child.data.mapindex[node.data]
            assert(not child.data.input[mapindex], "each input should have one source")
            child.data.input[mapindex] = x
         end
      end
      if node.data.selectindex then
         assert(not node.data.module, "the selectindex-handling nodes should have no module")
         local input = node.data.input
         assert(#input == 1, "only the split node should be the input")
         assert(istable(input[1]), "the input for a split should be a table")
         input = input[1][node.data.selectindex]
         propagate(node, input)
      else
         local input = node.data.input
         -- A parameter node is captured.
         if input == nil and node.data.module ~= nil then
            input = {}
         end
         if #input == 1 then
            input = input[1]
         end
         -- Forward through this node.
         -- If no module is present, the node behaves like nn.Identity.
         local output
         if not node.data.module then
            output = input
         else
            output = func(node.data.module, input)
         end
         if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #output then
            error(string.format("split(%s) cannot split %s outputs",
               node.data.nSplitOutputs, #output))
         end
         -- Propagate the output to children.
         propagate(node, output)
      end
      if self.verbose then
         print(' V : ' .. node:label())
      end
   end
   local innode = self.innode
   if #input ~= nInputs then
      error(string.format('Got %s inputs instead of %s', #input, nInputs))
   end
   -- First clear the input states.
   for _, node in ipairs(self.forwardnodes) do
      local input = node.data.input
      while input and #input > 0 do
         table.remove(input)
      end
   end
   -- Set the starting input.
   -- We copy instead of modifying the passed input.
   innode.data.input = innode.data.input or {}
   for i, item in ipairs(input) do
      innode.data.input[i] = item
   end
   -- Then run the forward pass.
   for i, node in ipairs(self.forwardnodes) do
      neteval(node)
   end
   self.output = self.outnode.data.input
   if #self.outnode.children == 1 then
      self.output = self.output[1]
   end
   return self.output
end
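
-- Usage sketch (editor's example): :forward() routes here with 'updateOutput',
-- but any per-module function can be evaluated over the graph.
--[[
-- assuming a single-input graph: net = nn.gModule({x}, {nn.Linear(10, 5)(x)})
local out = net:runForwardFunction('updateOutput', torch.randn(10))
--]]
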
function gModule:updateGradInput(input, gradOutput)
   local function neteval(node)
      if node.data.selectindex then
         assert(not node.data.module, "the selectindex-handling nodes should have no module")
         assert(#node.children == 1, "only the split node should be the input")
         local child = node.children[1]
         local go = getTotalGradOutput(node)
         child.data.gradOutput = child.data.gradOutput or {}
         assert(#child.data.gradOutput <= 1, "the split node should be used only once")
         -- The data.gradOutput holds the to-be-summed gradients.
         child.data.gradOutput[1] = child.data.gradOutput[1] or {}
         assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet")
         child.data.gradOutput[1][node.data.selectindex] = go
      else
         local gradOutput = getTotalGradOutput(node)
         -- updateGradInput through this node.
         -- If no module is present, the node behaves like nn.Identity.
         local gradInput
         if not node.data.module then
            gradInput = gradOutput
         else
            local input = node.data.input
            -- A parameter node is captured.
            if input == nil and node.data.module ~= nil then
               input = {}
            end
            if #input == 1 then
               input = input[1]
            end
            local module = node.data.module
            gradInput = module:updateGradInput(input, gradOutput)
         end
         -- Propagate the gradInput to children.
         for i, child in ipairs(node.children) do
            child.data.gradOutput = child.data.gradOutput or {}
            local mapindex = node.data.mapindex[child.data]
            local gi
            if #node.children == 1 then
               gi = gradInput
            else
               gi = gradInput[mapindex]
            end
            table.insert(child.data.gradOutput, gi)
         end
      end
      if self.verbose then
         print(' V : ' .. node:label())
      end
   end
   local outnode = self.outnode
   if #outnode.children > 1 and #gradOutput ~= #outnode.children then
      error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
   end
   -- First clear the gradOutput states.
   for _, node in ipairs(self.backwardnodes) do
      local gradOutput = node.data.gradOutput
      while gradOutput and #gradOutput > 0 do
         table.remove(gradOutput)
      end
   end
   -- Set the starting gradOutput.
   outnode.data.gradOutput = outnode.data.gradOutput or {}
   outnode.data.gradOutput[1] = gradOutput
   for i, node in ipairs(self.backwardnodes) do
      neteval(node)
   end
   assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once")
   self.gradInput = self.innode.data.gradOutput[1]
   return self.gradInput
end

function gModule:accGradParameters(input, gradOutput, lr)
   local function neteval(node)
      if node.data.module then
         local module = node.data.module
         local gradOutput = node.data.gradOutput[1]
         if #node.data.gradOutput > 1 then
            gradOutput = node.data.gradOutputBuffer
         end
         local input = node.data.input
         -- A parameter node is captured.
         if input == nil and node.data.module ~= nil then
            input = {}
         end
         if #input == 1 then
            input = input[1]
         end
         -- accGradParameters through this node.
         module:accGradParameters(input, gradOutput, lr)
      end
      if self.verbose then
         print(' V : ' .. node:label())
      end
   end
   local outnode = self.outnode
   if #outnode.children > 1 and #gradOutput ~= #outnode.children then
      error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
   end
   for i, node in ipairs(self.backwardnodes) do
      neteval(node)
   end
end
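
-- Usage sketch (editor's example): a full forward/backward pass; nn.Module's
-- :backward() calls the updateGradInput and accGradParameters defined above.
--[[
local x = nn.Identity()()
local net = nn.gModule({x}, {nn.Linear(10, 5)(x)})
local input = torch.randn(10)
local output = net:forward(input)
local gradInput = net:backward(input, torch.randn(5))
--]]
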
function gModule:read(file)
   local data = file:readObject()
   for k, v in pairs(data) do
      self[k] = v
   end
   -- Initialize the modules table if necessary.
   if not self.modules then
      self.modules = {}
      for _, node in ipairs(self.forwardnodes) do
         if node.data.module then
            table.insert(self.modules, node.data.module)
         end
      end
   end
end

function gModule:__tostring__()
   return self.name or torch.type(self)
end