forked from soumith/imagenet-multiGPU.torch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path util.lua
115 lines (108 loc) · 3.46 KB
/
util.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
require 'cunn'
local ffi=require 'ffi'
--- Wrap a model in nn.DataParallelTable when more than one GPU is requested.
-- With nGPU > 1, one clone of `model` is created per device 1..nGPU. Each
-- clone is converted with :cuda() while that device is current, and the
-- clones are collected in an nn.DataParallelTable(1) (the constructor argument
-- is the dimension the input batch is split along). With nGPU <= 1 the model
-- is returned untouched (no :cuda() conversion happens here).
-- In every case the current device is reset to the global opt.GPU on exit.
-- @param model  network to replicate
-- @param nGPU   number of GPUs to spread the model across
-- @return the DataParallelTable wrapper, or `model` itself when nGPU <= 1
function makeDataParallel(model, nGPU)
   local wrapped = model
   if nGPU > 1 then
      print('converting module to nn.DataParallelTable')
      assert(nGPU <= cutorch.getDeviceCount(), 'number of GPUs less than nGPU specified')
      wrapped = nn.DataParallelTable(1)
      for gpu = 1, nGPU do
         -- Make `gpu` current before :cuda() so this replica's CUDA storage is
         -- allocated on that device (cutorch allocates on the current device).
         cutorch.setDevice(gpu)
         wrapped:add(model:clone():cuda(), gpu)
      end
   end
   cutorch.setDevice(opt.GPU)
   return wrapped
end
--- Build a single-replica DataParallelTable suitable for saving to disk.
-- Assumes `module` was produced by makeDataParallel, i.e. all of its children
-- are clones of the same network living on different GPUs, so keeping only
-- the first child loses nothing. The current device is switched to the global
-- opt.GPU and the kept replica is registered under that device id.
-- @param module  nn.DataParallelTable to shrink
-- @return a new nn.DataParallelTable(1) containing only module:get(1)
local function cleanDPT(module)
   local single = nn.DataParallelTable(1)
   cutorch.setDevice(opt.GPU)
   local firstReplica = module:get(1)
   single:add(firstReplica, opt.GPU)
   return single
end
--- Serialize a model, collapsing any DataParallelTable to a single replica.
-- A top-level nn.DataParallelTable is replaced by its first replica (via
-- cleanDPT). For an nn.Sequential, a fresh nn.Sequential is built whose DPT
-- children are collapsed the same way. Other children are added by reference,
-- not copied. The `model` argument itself is not modified.
-- @param filename  path passed to torch.save
-- @param model     nn.DataParallelTable or nn.Sequential
-- @raise error for any other module type
function saveDataParallel(filename, model)
   local kind = torch.type(model)
   if kind ~= 'nn.DataParallelTable' and kind ~= 'nn.Sequential' then
      error('This saving function only works with Sequential or DataParallelTable modules.')
   end

   local toSave
   if kind == 'nn.DataParallelTable' then
      toSave = cleanDPT(model)
   else
      toSave = nn.Sequential()
      for _, child in ipairs(model.modules) do
         if torch.type(child) == 'nn.DataParallelTable' then
            child = cleanDPT(child)
         end
         toSave:add(child)
      end
   end
   torch.save(filename, toSave)
end
--- Load a model saved by saveDataParallel and re-spread it over nGPU GPUs.
-- A top-level nn.DataParallelTable is rebuilt from its first replica (converted
-- with :float()) through makeDataParallel. For an nn.Sequential, every DPT
-- child is rebuilt the same way and replaced in place in the loaded model.
-- NOTE(review): when nGPU <= 1, makeDataParallel does not call :cuda(), so the
-- rebuilt parts stay float; callers presumably convert afterwards — confirm.
-- @param filename  path passed to torch.load
-- @param nGPU      number of GPUs to replicate onto
-- @return the reconstructed model
-- @raise error if the loaded object is neither Sequential nor DataParallelTable
function loadDataParallel(filename, nGPU)
   -- Presumably required so cudnn module classes are registered before
   -- torch.load deserializes them.
   if opt.backend == 'cudnn' then
      require 'cudnn'
   end

   local model = torch.load(filename)
   local kind = torch.type(model)
   if kind == 'nn.DataParallelTable' then
      return makeDataParallel(model:get(1):float(), nGPU)
   end
   if kind ~= 'nn.Sequential' then
      error('The loaded model is not a Sequential or DataParallelTable module.')
   end

   local children = model.modules
   for idx, child in ipairs(children) do
      if torch.type(child) == 'nn.DataParallelTable' then
         children[idx] = makeDataParallel(child:get(1):float(), nGPU)
      end
   end
   return model
end
--- Snapshot every random-number-generator state and write it with torch.save.
-- The saved table has three fields:
--   cpu     - torch.getRNGState() of the calling thread
--   gpu     - cutorch.getRNGState(d) for each device d = 1..getDeviceCount()
--   donkeys - worker RNG states, keyed by each worker's __threadid
-- Worker states are only collected when num_donkeys > 0. During collection the
-- pool is put in specific mode so that job `w` is dispatched to thread `w`, and
-- it is synchronized before and after.
-- @param filename     path passed to torch.save
-- @param donkeys      worker thread pool
-- @param num_donkeys  number of worker threads in `donkeys`
function saveRNGState(filename, donkeys, num_donkeys)
   local snapshot = {
      cpu = torch.getRNGState(),
      gpu = {},
      donkeys = {},
   }
   for dev = 1, cutorch.getDeviceCount() do
      snapshot.gpu[dev] = cutorch.getRNGState(dev)
   end

   if num_donkeys > 0 then
      donkeys:synchronize()
      donkeys:specific(true)
      -- Runs inside the worker. It must capture no upvalues, because the job
      -- function is serialized to the thread.
      local function readThreadRNG()
         return __threadid, torch.getRNGState()
      end
      -- Runs back on the main thread with the worker's return values.
      local function storeThreadRNG(tid, rng)
         snapshot.donkeys[tid] = rng
      end
      for worker = 1, num_donkeys do
         donkeys:addjob(worker, readThreadRNG, storeThreadRNG)
      end
      donkeys:synchronize()
      donkeys:specific(false)
   end

   torch.save(filename, snapshot)
end
--- Restore RNG states previously written by saveRNGState.
-- Asserts that the current GPU count and worker count match the snapshot.
-- It then restores the CPU state and each device's cutorch state. When
-- num_donkeys > 0, it also restores each worker's state: the pool is put in
-- specific mode, and job `w` runs on thread `w` and applies
-- snapshot.donkeys[__threadid].
-- @param filename     path passed to torch.load
-- @param donkeys      worker thread pool
-- @param num_donkeys  number of worker threads in `donkeys`
function loadRNGState(filename, donkeys, num_donkeys)
   local snapshot = torch.load(filename)
   assert(cutorch.getDeviceCount() == #snapshot.gpu, "Mismatch between number of GPUs and GPU RNG states")
   assert(num_donkeys == #snapshot.donkeys, "Mismatch in donkey number")

   torch.setRNGState(snapshot.cpu)
   for dev = 1, cutorch.getDeviceCount() do
      cutorch.setRNGState(snapshot.gpu[dev], dev)
   end

   if num_donkeys > 0 then
      donkeys:synchronize()
      donkeys:specific(true)
      -- Runs inside the worker. `snapshot` is captured as an upvalue, so it is
      -- serialized along with the job function.
      local function applyThreadRNG()
         torch.setRNGState(snapshot.donkeys[__threadid])
      end
      for worker = 1, num_donkeys do
         donkeys:addjob(worker, applyThreadRNG)
      end
      donkeys:synchronize()
      donkeys:specific(false)
   end
end