-
Notifications
You must be signed in to change notification settings - Fork 0
/
elm_kernel_test.m
92 lines (67 loc) · 2.56 KB
/
elm_kernel_test.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
function Output = elm_kernel_test(TestingData, model)
% ELM_KERNEL_TEST  Evaluate a trained kernel ELM model on a labelled test set.
%   Output = elm_kernel_test(TestingData, model)
%
%   TestingData : matrix with one sample per row; column 1 holds the target
%                 label, columns 2:end hold the features.
%   model       : struct with fields
%                   Kernel_type, Kernel_para - kernel passed to kernel_matrix
%                   X                        - training samples, one per row
%                   beta                     - output weights; kernel * beta
%                                              gives one score per column
%                   label                    - label value for each score
%                                              column
%
%   Output.PredictLabel    : model.label of the highest-scoring column per row
%   Output.TestingAccuracy : fraction of rows whose PredictLabel equals the
%                            label in column 1 (NaN for an empty test set,
%                            since that is 0/0)
%   Output.TestingTime     : seconds spent on kernel evaluation and prediction
%                            (accuracy bookkeeping is not included)

%%%%%%%%%%% Split testing dataset into labels and features
TestLabel = TestingData(:, 1);
TestData  = TestingData(:, 2:end);

% Use a private timer id so a caller's own tic/toc measurement is not reset.
t_start = tic;
Omega_test = kernel_matrix(TestData, model.Kernel_type, model.Kernel_para, model.X);
TY = Omega_test * model.beta;              % one row of class scores per sample
[~, PredictLabelInx] = max(TY, [], 2);     % winning score column per sample
Output.PredictLabel = model.label(PredictLabelInx);
elapsed = toc(t_start);

% Compare as column vectors: model.label may be a row vector, in which case
% PredictLabel is a row while TestLabel is a column.
n_test  = size(TY, 1);
n_wrong = nnz(Output.PredictLabel(:) ~= TestLabel(:));
Output.TestingAccuracy = 1 - n_wrong / n_test;
Output.TestingTime = elapsed;
end
%%%%%%%%%%%%%%%%%% Kernel Matrix %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function omega = kernel_matrix(Xtrain,kernel_type, kernel_pars,Xt)
% KERNEL_MATRIX  Kernel matrix between the rows of two data matrices.
%   omega = kernel_matrix(Xtrain, kernel_type, kernel_pars) returns the
%   N-by-N matrix omega(i,j) = k(Xtrain(i,:), Xtrain(j,:)), N = size(Xtrain,1).
%   omega = kernel_matrix(Xtrain, kernel_type, kernel_pars, Xt) returns the
%   N-by-M matrix omega(i,j) = k(Xtrain(i,:), Xt(j,:)), M = size(Xt,1).
%
%   Supported kernel_type values (p = kernel_pars):
%     'RBF_kernel'   exp(-||x - y||^2 / p(1))
%     'lin_kernel'   x * y'                        (p is not used)
%     'poly_kernel'  (x * y' + p(1)) ^ p(2)
%     'wav_kernel'   cos(p(3) * (sum(x) - sum(y)) / p(2)) * exp(-||x - y||^2 / p(1))
%
%   Any other kernel_type raises the error 'kernel_matrix:unknownKernel'.
self_kernel = (nargin < 4);
if self_kernel
    % Both sides are Xtrain; helpers use the flag to keep the A*A' product form.
    Xt = Xtrain;
end

switch kernel_type
    case 'RBF_kernel'
        omega = exp(-km_sq_dist(Xtrain, Xt, self_kernel) ./ kernel_pars(1));
    case 'lin_kernel'
        omega = km_gram(Xtrain, Xt, self_kernel);
    case 'poly_kernel'
        omega = (km_gram(Xtrain, Xt, self_kernel) + kernel_pars(1)) .^ kernel_pars(2);
    case 'wav_kernel'
        % sum_diff(i,j) = sum(Xtrain(i,:)) - sum(Xt(j,:))
        sum_diff = bsxfun(@minus, sum(Xtrain, 2), sum(Xt, 2)');
        omega = cos(kernel_pars(3) * sum_diff ./ kernel_pars(2)) ...
            .* exp(-km_sq_dist(Xtrain, Xt, self_kernel) ./ kernel_pars(1));
    otherwise
        error('kernel_matrix:unknownKernel', ...
            'Unsupported kernel_type ''%s''.', kernel_type);
end
end

function G = km_gram(A, B, same)
% KM_GRAM  Inner products between rows: G(i,j) = A(i,:) * B(j,:)'.
%   When SAME is true, B is ignored and A*A' is used, written in the same
%   form as the original self-kernel code so its numerics are preserved.
if same
    G = A*A';
else
    G = A*B';
end
end

function D = km_sq_dist(A, B, same)
% KM_SQ_DIST  Pairwise squared Euclidean distances ||A(i,:) - B(j,:)||^2.
%   Uses ||a||^2 + ||b||^2 - 2 a.b. Cancellation can make entries slightly
%   negative, so they are clamped to 0. NaN entries are left as NaN.
sqA = sum(A.^2, 2);
if same
    sqB = sqA;
else
    sqB = sum(B.^2, 2);
end
D = bsxfun(@plus, sqA, sqB') - 2*km_gram(A, B, same);
% Logical-index clamp instead of max(D,0): max would turn NaN into 0.
D(D < 0) = 0;
end