-
Notifications
You must be signed in to change notification settings - Fork 0
/
extractDeepCD.lua
139 lines (122 loc) · 4.26 KB
/
extractDeepCD.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
require 'cudnn'
require 'cunn'
require 'image'
require 'mattorch'
require 'nn'
require 'stn'
require 'trepl'
require 'xlua'
local pl = require('pl.import_into')()
-- Command-line options. -networkType is required; the others have defaults.
local cmd = torch.CmdLine()
cmd:option('-dataName', '', 'dataset directory name under gooddata/')
cmd:option('-imageNum', 6, 'number of images to process (1..imageNum)')
cmd:option('-batchSize', 128, 'number of patches per network:forward call')
cmd:option('-modelDir', './model/', 'directory containing the .t7 model files')
cmd:option('-networkType', '', 'DeepCD_2S | DeepCD_2S_new | DeepCD_2S_noSTN | DeepCD_Sp (required)') -- necessary
local option = cmd:parse(arg)
assert(option.networkType ~= '', 'you should specify the network type')
-- The patches are split into mini-batches of this size, so it must be positive.
assert(option.batchSize >= 1, 'batchSize must be at least 1')
local dataName = option.dataName
local imageNum = option.imageNum
local batchSize = option.batchSize
local modelDir = option.modelDir
local networkType = option.networkType
-- Pretrained model file (inside modelDir) for each supported network type.
local modelFiles = {
  DeepCD_2S       = 'DeepCD_2Stream_liberty.t7',
  DeepCD_2S_new   = 'DeepCD_2S_DDU_liberty.t7',
  DeepCD_2S_noSTN = 'DeepCD_noSTN_2Stream.t7',
  DeepCD_Sp       = 'DeepCD_Split_liberty.t7',
}
local model = modelFiles[networkType]
if model == nil then
  print(' you should define model path first')
  -- Exit with a failure status so calling scripts can detect the error
  -- (a bare os.exit() reports success).
  os.exit(1)
end
local network = torch.load(paths.concat(modelDir, model)):cuda()
-----------------------------------------------------------------------------
--Remove mulconstant---------------------------------------------------------
-- Where the layer to strip lives in each network type, written as
-- {outer module index, inner module index, index of the layer to remove}.
-- Per the section header this layer is a MulConstant (not checked here).
-- Network types missing from this table are left unmodified.
local mulConstantPath = {
  DeepCD_2S       = {1, 2, 13},
  DeepCD_2S_new   = {1, 2, 12},
  DeepCD_2S_noSTN = {1, 2, 12},
  DeepCD_Sp       = {7, 2, 6},
}
local stripAt = mulConstantPath[networkType]
if stripAt then
  local container = network:get(stripAt[1]):get(stripAt[2])
  container:remove(stripAt[3])
end
print('use '..networkType..' to extract the feature of '..dataName)

-- Width of the descriptor produced by each output branch of the network:
-- branch 1 is saved as descriptor_lead, branch 2 as descriptor_complete.
local descriptorDims = {128, 256}

--- Load the keypoint frames and patches of one image from a .mat file.
-- Due to the mattorch format the loaded matrices are transposed, so the
-- first channel of every patch is transposed back here.
-- @param path path of the R_64_patch.mat file
-- @return frame tensor, patch tensor converted to FloatTensor
local function loadPatches(path)
  local fileContent = mattorch.load(path)
  local frame = fileContent.frame:clone()
  local patch = fileContent.local_norm_patch_32:clone()
  -- Transpose channel 1 of all patches at once. This is equivalent to
  -- transposing each patch individually. The clone() is needed because the
  -- source is a transposed view of the same storage being written.
  local channel = patch:select(2, 1)
  channel:copy(channel:transpose(2, 3):clone())
  -- Converted to float before being moved to the GPU batch by batch.
  return frame, patch:float()
end

--- Forward the patches through the network in mini-batches of batchSize.
-- @param patch tensor of patches, indexed by patch along dimension 1
-- @param dim width of the descriptor returned by the currently selected branch
-- @return tensor of size (number of patches) x dim
local function extractDescriptors(patch, dim)
  local descriptor = torch.Tensor(patch:size(1), dim)
  local descriptorSplit = descriptor:split(batchSize)
  for i, batch in ipairs(patch:split(batchSize)) do
    descriptorSplit[i]:copy(network:forward(batch:cuda()))
  end
  return descriptor
end

-- Append a single SelectTable and switch its index to choose which output
-- branch network:forward returns. This avoids removing and re-adding the
-- selector at hard-coded module positions for every image and branch.
local selector = nn.SelectTable(1)
network:add(selector)

-- Loop variable is not named `image`, to avoid shadowing the image package.
for imageIdx = 1, imageNum do
  print(' image: '..imageIdx)
  -- Each patch file is loaded once and reused for both branches.
  local frame, patch = loadPatches(paths.concat('gooddata', dataName, 'patch', tostring(imageIdx), 'R_64_patch.mat'))

  selector.index = 1
  local descriptorLead = extractDescriptors(patch, descriptorDims[1])
  selector.index = 2
  local descriptorComplete = extractDescriptors(patch, descriptorDims[2])

  local outputContent = {frame = frame, descriptor_lead = descriptorLead, descriptor_complete = descriptorComplete}
  mattorch.save(paths.concat('gooddata', dataName, 'patch', tostring(imageIdx), 'R_64_'..networkType..'.mat'), outputContent)
  collectgarbage()
end