-
Notifications
You must be signed in to change notification settings - Fork 0
/
bpRBM.m
97 lines (64 loc) · 2.32 KB
/
bpRBM.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
%function [vishid, visbiases, hidbiases] = bpRBM(trainSet, hiddenLayerSize)
% bpRBM  Train a tied-weight, one-hidden-layer sigmoid network on squared
% reconstruction error using mini-batch updates with momentum and L2
% weight decay.
%
% The function header above is commented out, so this file currently runs
% as a script: trainSet and hiddenLayerSize must already exist in the
% caller's workspace.
%
% Inputs (workspace variables):
%   trainSet        - inputSize x numExamples matrix; columns are examples
%                     (they are sliced into mini-batches of batchSize columns).
%   hiddenLayerSize - number of hidden units.
%
% Results (left in the workspace):
%   vishid    - hiddenLayerSize x inputSize weights; used as the encoder
%               (vishid*data) and, transposed, as the decoder (vishid'*h).
%   visbiases - inputSize x 1 zeros; bias terms are disabled below and
%               never updated.
%   hidbiases - hiddenLayerSize x 1 zeros; likewise never updated.
%
% Side effects: prints progress/statistics and calls plotrf (not defined in
% this file) once per epoch.

%%%%%%%% learning params %%%%%%%%%%%%%%%%
inputSize = size(trainSet,1);
learnrate = 0.01;     % initial rate; lowered by the schedule inside the epoch loop
momentum = 0.9;       % overwritten by the schedule at the start of every epoch
weightdecay = 0.001;  % L2 penalty coefficient applied to vishid
%%%%%%%%% algorithm params %%%%%%%%%%%%%
numepochs = 50;
batchSize = 50;
% Examples beyond the last full batch are skipped every epoch.
numBatches = floor(size(trainSet,2)/batchSize);
vishidinc = 0;   % momentum accumulator for vishid (scalar 0 expands on first update)
visbiasinc = 0;  % unused while bias terms are disabled
hidbiasinc = 0;  % unused while bias terms are disabled
%%%%%%%%% initialize weights %%%%%%%%%%%%%%
vishid = 0.01*randn(hiddenLayerSize, inputSize);
visbiases = zeros(inputSize,1);
hidbiases = zeros(hiddenLayerSize,1);
%%%%%%%%%% maybe make trainSet between 0 and 1??? %%%%%
% Optional min-max rescaling of the data into [0, 1] (currently disabled):
% trainSet = trainSet - min(min(trainSet));
% trainSet = trainSet./max(max(trainSet));
for epoch = 1:numepochs
    errsum = 0;
    tic;
    %%%%%%%% momentum and learning-rate schedule %%%%%%%%
    % Low momentum for the first 5 epochs, then 0.9.
    if epoch > 5
        momentum = 0.9;
    else
        momentum = 0.5;
    end
    % Check the larger threshold first. In the original order the
    % "epoch > 100" branch was unreachable, because every such epoch also
    % satisfies "epoch > 25". Rates for epochs 1..100 are unchanged.
    if epoch > 100
        learnrate = 0.005;
    elseif epoch > 25
        learnrate = 0.001;
    end
    for ex = 1:numBatches
        if mod(ex, 1000) == 0
            fprintf('ex = %d\n', ex);
        end
        data = trainSet(:, ((ex-1)*batchSize+1):(ex*batchSize));
        %%%%%%% run data through network %%%%%%%%%%%%%%%%%%%%%
        % Encode with vishid and apply a logistic sigmoid.
        % Bias term disabled: ... + repmat(hidbiases,1,batchSize)
        hidact = 1./(1 + exp(-(vishid*data)));
        %%% old code for stochastic update %%%
        %hidprobs = hidact;
        %hidstates = hidprobs > rand(size(hidprobs));
        % Decode with the transposed (tied) weights.
        output = vishid'*hidact; % + repmat(visbiases,1,batchSize);
        %%%%%%% calc error and update weights %%%%%%%%%%%%%%%%%%
        delta = data - output;
        batchErr = sum(sum(delta.^2));  % not named "error" to avoid shadowing MATLAB's error()
        errsum = errsum + batchErr;
        % Step along hidact*delta' (the decoder-path gradient, hidact held
        % fixed) minus L2 weight decay, accumulated with momentum.
        % NOTE(review): the encoder-path term of the tied-weight gradient,
        % ((hidact.*(1-hidact)).*(vishid*delta))*data', is not included --
        % confirm this simplification is intended.
        vishidinc = momentum*vishidinc + ...
            learnrate*(hidact*(delta') - weightdecay*vishid);
        vishid = vishid + vishidinc;
        %%%%% normalize weights? hmm... %%%%%%%%%%%%%%%
        %vishid = vishid./(max(max(vishid)));
    end
    % Output Statistics
    fprintf('Epoch %d\t Error %f\t W-Norm %f\t Time %f\n', ...
        epoch, errsum, norm(vishid(:)), toc);
    % plotrf is defined elsewhere; presumably it displays the columns of
    % vishid' on a floor(sqrt(inputSize))-wide grid -- TODO confirm.
    plotrf(vishid', floor(inputSize^.5), 'temp');
end
%end