function [net] = netinit_mnist_dag(varargin)
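%NETINIT_MNIST_DAG  Initialize a steerable-filter DagNN for MNIST.
%   NET = NETINIT_MNIST_DAG(...) builds a MatConvNet DagNN with two
%   steerable convolution/pooling blocks, a steerable "fully connected"
%   block (CS3/FC1), two 1x1 convolutions and a softmax-log loss, and
%   initializes its parameters. Options are given as name/value pairs
%   parsed by VL_ARGPARSE (defaults below); note that several of the
%   options are parsed but not used in this file.
%
%   Requires MatConvNet plus the project's custom layers dagnn.Convsteer,
%   dagnn.PoolingAngle, dagnn.PoolingExternal and dagnn.BatchNormAngle
%   (assumed to be on the MATLAB path).
%
%   Example:
%     net = netinit_mnist_dag('weightDecay', 1e-2) ;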
opts.useDropout = false;
opts.useBnorm = true;
opts.modelDir = '.';
opts.weightDecay = 1 ;
opts.imchannels = 1;
opts.classWeights = [];
opts.warmstart = false;
opts.loadnetfrom = [];
opts = vl_argparse(opts, varargin) ;
opts.initBias = 0 ;
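% Reset the random number generator so parameter initialization is reproducible.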
rng('default');
rng(0) ;
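% Hyperparameters: angle_n is passed as 'angle_n' to the steerable layers
% (presumably the number of filter orientations); lr holds the learning-rate
% multipliers for filter ('...f') and bias ('...b') parameters; dropout_rate
% is the rate given to dagnn.DropOut; leak is the ReLU leak factor (0 gives a
% plain ReLU).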
angle_n = 17;
lr = [1 0.5] ;
dropout_rate = 0.75;
leak = 0.0;
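% Filter specification table: each row is passed as 'size' to a convolution
% block and reads [filter height, filter width, input channels, output
% channels, extra]. The fifth entry is consumed by the custom dagnn.Convsteer
% layer (its meaning is not defined in this file); it is 1 for the plain
% dagnn.Conv layers, where a trailing singleton is harmless.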
conv_filters = zeros(1,5);
% BLOCK 1: CS1
conv_filters(end,:) = [9,9,1,6,1];
%%%% MAX POOLING
% BLOCK 2: CS2
conv_filters(end+1,:) = [9,9,6,16,2];
%%%% MAX POOLING
% BLOCK 3: CS3
conv_filters(end+1,:) = [9,9,16,32,2];
conv_filters(end+1,:) = [1,1,32,128,1];
conv_filters(end+1,:) = [1,1,128,10,1];
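% Resulting architecture: two steerable conv blocks (CS1: 9x9, 1->6; CS2: 9x9,
% 6->16), each followed by orientation pooling, 2x2 max pooling and batch
% normalization; a third steerable block (CS3/FC1: 9x9, 16->32) pooled over
% orientations; a 1x1 conv to 128 channels with ReLU, batch norm and dropout;
% and a final 1x1 conv to the 10 MNIST classes with softmax-log loss.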
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% BLOCK 1: CS1
net = dagnn.DagNN() ;
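% Bookkeeping: 'id' indexes the current row of conv_filters (one per block);
% 'nid' numbers the intermediate variables x<nid> and the parameter name
% prefixes l<nid>.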
id = 1; nid = 1;
convBlock = dagnn.Convsteer('size', conv_filters(id,:), 'stride', 1,'angle_n',angle_n, 'hasBias', ...
true, 'pad', floor(conv_filters(id,1)/2)) ;
net.addLayer(['convsteer' num2str(id)], convBlock, {'input'}, {['x' num2str(nid)]}, ...
{['l' num2str(nid) 'f'], ['l' num2str(nid) 'b']}) ;
net.params(net.getParamIndex(['l' num2str(nid) 'f'])).learningRate = lr(1);
net.params(net.getParamIndex(['l' num2str(nid) 'f'])).weightDecay = opts.weightDecay;
net.params(net.getParamIndex(['l' num2str(nid) 'b'])).learningRate = lr(2);
net.params(net.getParamIndex(['l' num2str(nid) 'b'])).weightDecay = 0;
nid = nid + 1;
% --- RELU LAYER
net.addLayer(['relu' num2str(id)],dagnn.ReLU('leak',leak),['x' num2str(nid-1)],['x' num2str(nid)]);
split_nid = nid;
nid = nid + 1;
% --- POOLANGLE LAYER
net.addLayer(['poolangle_v' num2str(id)],dagnn.PoolingAngle('bins',0,'angle_n',angle_n),['x' num2str(split_nid)],['x' num2str(nid)]);
nid = nid + 1;
% --- POOL LAYER
poolBlock = dagnn.PoolingExternal('poolSize', 2, 'pad', [0 0 0 0]) ;
net.addLayer(['pool_v' num2str(id)], poolBlock,['x' num2str(nid-1)],['x' num2str(nid)]);
nid = nid + 1;
% --- BNORM LAYER
net.addLayer(['bnormangle' num2str(id)],dagnn.BatchNormAngle('numChannels',conv_filters(id,4)),...
['x' num2str(nid-1)],['x' num2str(nid)], ...
{['l' num2str(nid) 'm'], ['l' num2str(nid) 'b'],['l' num2str(nid) 'x']});
nid = nid + 1;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% BLOCK 2: CS2
id = id + 1;
convBlock = dagnn.Convsteer('size', conv_filters(id,:), 'stride', 1,'angle_n',angle_n, 'hasBias', ...
true, 'pad', floor(conv_filters(id,1)/2)) ;
net.addLayer(['convsteer' num2str(id)], convBlock, {['x' num2str(nid-1)]}, {['x' num2str(nid)]}, ...
{['l' num2str(nid) 'f'], ['l' num2str(nid) 'b']}) ;
net.params(net.getParamIndex(['l' num2str(nid) 'f'])).learningRate = lr(1);
net.params(net.getParamIndex(['l' num2str(nid) 'f'])).weightDecay = opts.weightDecay;
net.params(net.getParamIndex(['l' num2str(nid) 'b'])).learningRate = lr(2);
net.params(net.getParamIndex(['l' num2str(nid) 'b'])).weightDecay = 0;
nid = nid + 1;
% --- RELU LAYER
net.addLayer(['relu' num2str(id)],dagnn.ReLU('leak',leak),['x' num2str(nid-1)],['x' num2str(nid)]);
split_nid = nid;
nid = nid + 1;
% --- POOLANGLE LAYER
net.addLayer(['poolangle_v' num2str(id)],dagnn.PoolingAngle('bins',0,'angle_n',angle_n),['x' num2str(split_nid)],['x' num2str(nid)]);
nid = nid + 1;
% --- POOL LAYER
poolBlock = dagnn.PoolingExternal('poolSize', 2, 'pad', [0 0 0 0]) ;
net.addLayer(['pool_v' num2str(id)], poolBlock,['x' num2str(nid-1)],['x' num2str(nid)]);
nid = nid + 1;
% --- BNORM LAYER
net.addLayer(['bnormangle' num2str(id)],dagnn.BatchNormAngle('numChannels',conv_filters(id,4)),...
['x' num2str(nid-1)],['x' num2str(nid)], ...
{['l' num2str(nid) 'm'], ['l' num2str(nid) 'b'],['l' num2str(nid) 'x']});
nid = nid + 1;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% FC1 steerable
id = id + 1;
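% With 28x28 inputs and two 2x2 poolings (28 -> 14 -> 7, assuming each pooling
% halves the spatial size), a 9x9 convolution with pad 1 reduces the 7x7 maps
% to 1x1, so this block acts as a fully connected layer.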
convBlock = dagnn.Convsteer('size', conv_filters(id,:), 'stride', 1,'angle_n',angle_n, 'hasBias', ...
true, 'pad', 1) ;
net.addLayer(['convsteer' num2str(id)], convBlock, {['x' num2str(nid-1)]}, {['x' num2str(nid)]}, ...
{['l' num2str(nid) 'f'], ['l' num2str(nid) 'b']}) ;
net.params(net.getParamIndex(['l' num2str(nid) 'f'])).learningRate = lr(1);
net.params(net.getParamIndex(['l' num2str(nid) 'f'])).weightDecay = opts.weightDecay;
net.params(net.getParamIndex(['l' num2str(nid) 'b'])).learningRate = lr(2);
net.params(net.getParamIndex(['l' num2str(nid) 'b'])).weightDecay = 0;
nid = nid + 1;
% --- POOLANGLE LAYER
net.addLayer(['poolangle' num2str(id)],dagnn.PoolingAngle('bins',1,'angle_n',angle_n),['x' num2str(nid-1)],['x' num2str(nid)]);
nid = nid + 1;
% --- RELU LAYER
net.addLayer(['relu' num2str(id)],dagnn.ReLU('leak',leak),['x' num2str(nid-1)],['x' num2str(nid)]);
nid = nid + 1;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% FC2
id = id + 1;
convBlock = dagnn.Conv('size', conv_filters(id,:), 'stride', 1, 'hasBias', ...
true, 'pad', 0) ;
net.addLayer(['conv' num2str(id)], convBlock, {['x' num2str(nid-1)]}, {['x' num2str(nid)]}, ...
{['l' num2str(nid) 'f'], ['l' num2str(nid) 'b']}) ;
net.params(net.getParamIndex(['l' num2str(nid) 'f'])).learningRate = lr(1);
net.params(net.getParamIndex(['l' num2str(nid) 'f'])).weightDecay = opts.weightDecay;
net.params(net.getParamIndex(['l' num2str(nid) 'b'])).learningRate = lr(2);
net.params(net.getParamIndex(['l' num2str(nid) 'b'])).weightDecay = 0;
nid = nid + 1;
% --- RELU LAYER
net.addLayer(['relu' num2str(id)],dagnn.ReLU('leak',0.1),['x' num2str(nid-1)],['x' num2str(nid)]);
nid = nid + 1;
% --- BNORM LAYER
net.addLayer(['bnorm' num2str(id)],dagnn.BatchNorm('numChannels',conv_filters(id,4)),...
['x' num2str(nid-1)],['x' num2str(nid)], ...
{['l' num2str(nid) 'm'], ['l' num2str(nid) 'b'],['l' num2str(nid) 'x']});
nid = nid + 1;
% --- DROPOUT LAYER
dropoutBlock = dagnn.DropOut('rate', dropout_rate) ;
net.addLayer(['dropout' num2str(id)], dropoutBlock,['x' num2str(nid-1)],['x' num2str(nid)]);
nid = nid + 1;
% --- CONV LAYER
id = id + 1;
convBlock = dagnn.Conv('size', conv_filters(id,:), 'stride', 1, 'hasBias', ...
true, 'pad', 0) ;
net.addLayer(['conv' num2str(id)], convBlock, {['x' num2str(nid-1)]}, {'prediction'}, ...
{['l' num2str(nid) 'f'], ['l' num2str(nid) 'b']}) ;
net.params(net.getParamIndex(['l' num2str(nid) 'f'])).learningRate = lr(1);
net.params(net.getParamIndex(['l' num2str(nid) 'f'])).weightDecay = opts.weightDecay;
net.params(net.getParamIndex(['l' num2str(nid) 'b'])).learningRate = lr(2);
net.params(net.getParamIndex(['l' num2str(nid) 'b'])).weightDecay = 0;
net.addLayer('objective', ...
dagnn.Loss('loss', 'softmaxlog'), ...
{'prediction', 'label'}, 'objective') ;
net.addLayer('top1err', dagnn.Loss('loss', 'classerror'), ...
{'prediction','label'}, 'top1err') ;
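% Each layer above advanced 'nid' exactly once, so the layer index equals the
% 'nid' used when naming its parameters; the batch-norm parameters of layer
% nid are therefore l<nid>m / l<nid>b / l<nid>x and can be looked up by layer
% index here.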
for nid = 1:numel(net.layers)
if strcmp(net.layers(nid).name(1:2), 'bn')
net.params(net.getParamIndex(['l' num2str(nid) 'm'])).learningRate = 1;
net.params(net.getParamIndex(['l' num2str(nid) 'b'])).learningRate = 1;
net.params(net.getParamIndex(['l' num2str(nid) 'x'])).learningRate = 0.05;
net.params(net.getParamIndex(['l' num2str(nid) 'm'])).weightDecay = 0;
net.params(net.getParamIndex(['l' num2str(nid) 'b'])).weightDecay = 0;
end
end
% Meta parameters
net.meta.inputSize = [28 28 1] ;
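% Learning-rate schedule: 10 epochs at each of 0.3, 0.1, 0.03, 0.01, 0.003,
% 0.001 and 0.0003 (70 epochs in total).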
net.meta.trainOpts.learningRate = [0.3*ones(1,10) 0.1*ones(1,10) 0.03*ones(1,10) 0.01*ones(1,10) 0.003*ones(1,10) 0.001*ones(1,10) 0.0003*ones(1,10)] ;
net.meta.trainOpts.weightDecay = 1e-2 ;
net.meta.trainOpts.numEpochs = numel(net.meta.trainOpts.learningRate) ;
net.meta.trainOpts.batchSize = 600 ;
net.initParams();