-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.cpp
115 lines (87 loc) · 2.75 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <iostream>
#include <fstream>
#include <Eigen/Dense>
#include <cstdlib>
#include "Layer.h"
#include "Trainset.h"
#include "misc.h"
// Debug helper: prints the argument's source text followed by " = ", then its
// value on the next line. The argument is parenthesized so that expressions
// containing operators that bind looser than << (e.g. `a & b`, `c ? a : b`)
// are streamed as a single value instead of being split by the parser.
#define PRINT(x) (cout << #x << " = " << endl << (x) << endl)
// main() and PRINT rely on unqualified Eigen and std names (VectorXd, cout, ...).
using namespace Eigen;
using namespace std;
// Sizes used by main(): length of the input vector x, of the hidden
// activation y, and of the network output z.
constexpr int inputs = 2;
constexpr int hidden = 10;
constexpr int outputs = 1;
int main()
{
Layer w1(inputs,hidden);
Layer w2(hidden, outputs);
VectorXd x(inputs);
VectorXd y(hidden), dy(hidden), theta1(hidden);
VectorXd z(outputs), dz(outputs), theta2(outputs), ans(outputs);
Trainset train(2, 1, 100000);
Trainset testcase(2,1,5000);
int trainsize = train.getSize();
for(int i = 0; i < trainsize; i++) {
if (i % (trainsize / 100) == 0)
{
cout << "Training Process: " << (double) i / trainsize * 100 << '%' << " of " << trainsize << '\r';
cout.flush();
}
x = train.getInput(i);
theta1 = w1.calculate(x);
y = sigmoid(theta1);
theta2 = w2.calculate(y);
z = sigmoid(theta2);
ans = train.getOutput(i);
dz = ans - z;
MatrixXd dw1, dw2, diag_dtheta1, diag_dtheta2;
VectorXd db1, db2, dtheta1, dtheta2;
diag_dtheta1 = d_sigmoid(theta1).asDiagonal();
diag_dtheta2 = d_sigmoid(theta2).asDiagonal();
db2 = diag_dtheta2 * dz;
dw2 = db2 * y.transpose();
dy = w2.getMatrix().transpose() * diag_dtheta2 * dz;
db1 = diag_dtheta1 * dy;
dw1 = db1 * x.transpose();
// PRINT(x);
// PRINT(z);
// PRINT(dz);
double step = 0.1;
w1.setMatrix(w1.getMatrix() + dw1 * step);
w1.setOffset(w1.getOffset() + db1 * step);
w2.setMatrix(w2.getMatrix() + dw2 * step);
w2.setOffset(w2.getOffset() + db2 * step);
y = sigmoid(w1.calculate(x));
z = sigmoid(w2.calculate(y));
// PRINT(z);
}
for (int i = 0; i < testcase.getSize(); i++)
{
x = testcase.getInput(i);
y = sigmoid(w1.calculate(x));
z = sigmoid(w2.calculate(y));
testcase.setMyAns(z, i);
}
double accurate = testcase.calculatePrecision();
cout << endl << endl << "----------" << endl;
cout << "accurate: " << accurate * 100 << '%' << endl;
cout << "----------" << endl << endl;
fstream fout;
fout.open("plot/dot.txt", ios::out);
for (int i = 0; i < testcase.getSize(); i++)
{
x = testcase.getInput(i);
z = testcase.getMyAns(i);
for (int j = 0; j < x.rows(); j++)
{
fout << x[j] << ' ';
}
fout << z[0] << endl;
}
fout.close();
// PRINT(w1.getMatrix());
// PRINT(w1.getOffset());
// PRINT(w2.getMatrix());
// PRINT(w2.getOffset());
return 0;
}