MLPclassificationLoss.m
function [f,g] = MLPclassificationLoss(w,X,y,nHidden,nLabels)
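% Squared-error loss (plus a small L2 penalty) and gradient for a
% multi-layer perceptron with tanh hidden units.
%
%   w        - vector of all network weights, flattened
%   X        - nInstances by nVars matrix of inputs
%   y        - nInstances by nLabels matrix of targets
%   nHidden  - vector giving the number of units in each hidden layer
%   nLabels  - number of output units
%
% Returns the loss f and, when two outputs are requested, the gradient g
% with respect to w (in the same layout as w).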
[nInstances,nVars] = size(X);

% Form Weights
inputWeights = reshape(w(1:nVars*nHidden(1)),nVars,nHidden(1));
offset = nVars*nHidden(1);
for h = 2:length(nHidden)
    hiddenWeights{h-1} = reshape(w(offset+1:offset+nHidden(h-1)*nHidden(h)),nHidden(h-1),nHidden(h));
    offset = offset+nHidden(h-1)*nHidden(h);
end
outputWeights = w(offset+1:offset+nHidden(end)*nLabels);
outputWeights = reshape(outputWeights,nHidden(end),nLabels);

f = 0;
if nargout > 1
    gInput = zeros(size(inputWeights));
    for h = 2:length(nHidden)
        gHidden{h-1} = zeros(size(hiddenWeights{h-1}));
    end
    gOutput = zeros(size(outputWeights));
end
% Compute Output
for i = 1:nInstances
    % Forward pass: linear map followed by tanh at each hidden layer
    ip{1} = X(i,:)*inputWeights;
    fp{1} = tanh(ip{1});
    for h = 2:length(nHidden)
        ip{h} = fp{h-1}*hiddenWeights{h-1};
        fp{h} = tanh(ip{h});
    end
    yhat = fp{end}*outputWeights;

    % Squared error for this instance
    relativeErr = yhat-y(i,:);
    f = f + sum(relativeErr.^2);

    if nargout > 1
        err = 2*relativeErr; % derivative of the squared error w.r.t. yhat

        % Output Weights
        for c = 1:nLabels
            gOutput(:,c) = gOutput(:,c) + err(c)*fp{end}';
        end
        if length(nHidden) > 1
            % Last Layer of Hidden Weights
            % (sech(x).^2 is the derivative of tanh(x))
            clear backprop
            for c = 1:nLabels
                backprop(c,:) = err(c)*(sech(ip{end}).^2.*outputWeights(:,c)');
                gHidden{end} = gHidden{end} + fp{end-1}'*backprop(c,:);
            end
            backprop = sum(backprop,1);

            % Other Hidden Layers
            for h = length(nHidden)-2:-1:1
                backprop = (backprop*hiddenWeights{h+1}').*sech(ip{h+1}).^2;
                gHidden{h} = gHidden{h} + fp{h}'*backprop;
            end

            % Input Weights
            backprop = (backprop*hiddenWeights{1}').*sech(ip{1}).^2;
            gInput = gInput + X(i,:)'*backprop;
        else
            % Input Weights
            for c = 1:nLabels
                gInput = gInput + err(c)*X(i,:)'*(sech(ip{end}).^2.*outputWeights(:,c)');
            end
        end
    end
end
f = f + 1/1000 * norm(w)*norm(w); % L2 penalty: (1/1000)*||w||^2

% Put Gradient into vector
if nargout > 1
    g = zeros(size(w));
    g(1:nVars*nHidden(1)) = gInput(:);
    offset = nVars*nHidden(1);
    for h = 2:length(nHidden)
        g(offset+1:offset+nHidden(h-1)*nHidden(h)) = gHidden{h-1}(:);
        offset = offset+nHidden(h-1)*nHidden(h);
    end
    g(offset+1:offset+nHidden(end)*nLabels) = gOutput(:);
    g = g + 1/500 * w; % gradient of the L2 penalty: 2*(1/1000)*w
end
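
A minimal usage sketch, not part of the original file: it builds random data, checks a few coordinates of the analytic gradient against central finite differences, and then runs plain gradient descent on the regularized loss. The problem sizes, hidden-layer widths, target encoding, step size, and iteration count below are illustrative assumptions, and the function is called by its file name, MLPclassificationLoss.

% --- Usage sketch (illustrative values only) ---
nVars = 5; nLabels = 3; nInstances = 20;   % assumed problem sizes
nHidden = [10 10];                          % assumed hidden layer sizes

X = randn(nInstances,nVars);
y = sign(randn(nInstances,nLabels));        % assumed +/-1 targets; real-valued targets also work

% Number of parameters, matching the unpacking order inside the loss
nParams = nVars*nHidden(1);
for h = 2:length(nHidden)
    nParams = nParams + nHidden(h-1)*nHidden(h);
end
nParams = nParams + nHidden(end)*nLabels;

w = 0.1*randn(nParams,1);

% Check a few coordinates of the analytic gradient with central differences
[f,g] = MLPclassificationLoss(w,X,y,nHidden,nLabels);
delta = 1e-6;
for j = 1:5
    e = zeros(nParams,1); e(j) = delta;
    gNum = (MLPclassificationLoss(w+e,X,y,nHidden,nLabels) ...
          - MLPclassificationLoss(w-e,X,y,nHidden,nLabels))/(2*delta);
    fprintf('coordinate %d: analytic %g, numeric %g\n',j,g(j),gNum);
end

% A few steps of plain gradient descent on the regularized loss
stepSize = 1e-3;
for iter = 1:100
    [f,g] = MLPclassificationLoss(w,X,y,nHidden,nLabels);
    w = w - stepSize*g;
end
fprintf('loss after %d iterations: %g\n',iter,f);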