-
Notifications
You must be signed in to change notification settings - Fork 4
/
onlinenode.h
173 lines (145 loc) · 4.66 KB
/
onlinenode.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#ifndef ONLINENODE_H_
#define ONLINENODE_H_
#include <vector>
#include "data.h"
#include "hp.h"
#include "randomtest.h"
#include "utilities.h"
using namespace std;
class OnlineNode {
public:
OnlineNode() {
m_isLeaf = true;
}
OnlineNode(const Hyperparameters &hp, const int &numClasses, const int &depth) :
m_numClasses(&numClasses), m_depth(depth), label(6,0), m_isLeaf(true), m_counter(0.0), m_label(6,-1),
m_parentCounter(0.0), m_hp(&hp) {
for (int i = 0; i <numClasses; i++) {
m_labelStats.push_back(0.0);
}
// Creating random tests
for (int i = 0; i < hp.numRandomTests; i++) {
HyperplaneFeature test(numClasses);
m_onlineTests.push_back(test);
}
}
int m_transformlabel;
OnlineNode(const Hyperparameters &hp, const int &numClasses, const int &depth, const vector<double> &parentStats,vector<float> extantlabel) :
m_numClasses(&numClasses),m_depth(depth), label(6,0.0),m_isLeaf(true), m_counter(0.0), m_label(6,-1),
m_parentCounter(0.0), m_hp(&hp){
m_labelStats = parentStats;
m_label = extantlabel;
m_parentCounter = sum(m_labelStats);
// Creating random tests
for (int i = 0; i < hp.numRandomTests; i++) {
HyperplaneFeature test(numClasses);
m_onlineTests.push_back(test);
}
}
~OnlineNode() {
if (!m_isLeaf) {
delete m_leftChildNode;
delete m_rightChildNode;
}
}
vector<float> label;
void update(Sample &sample,DataSet &dataset_tr){
m_counter += sample.w;
m_labelStats[sample.ma] += sample.w;
for(int i=0;i<6;i++)
{
label[i]+=sample.y[i];
}
if (m_isLeaf) {
// Update online tests
for (int i = 0; i < m_hp->numRandomTests; i++) {
m_onlineTests[i].update(sample);
}
for(int i=0;i<6;i++){
m_label[i]=label[i]/m_counter;
}
// Decide for split
if (shouldISplit()) {
m_isLeaf = false;
// Find the best online test
int maxIndex = 0;
double maxScore = -1e100, score;
for (int i = 0; i < m_hp->numRandomTests; i++) {
score = m_onlineTests[i].score(dataset_tr);
if (score > maxScore) {
maxScore = score;
maxIndex = i;
}
}
m_bestTest = m_onlineTests[maxIndex];
m_onlineTests.clear();
// Split
pair<vector<double> , vector<double> > parentStats = m_bestTest.getStats();
m_rightChildNode = new OnlineNode(*m_hp, *m_numClasses, m_depth + 1,parentStats.first,m_label);
m_leftChildNode = new OnlineNode(*m_hp, *m_numClasses,m_depth + 1,parentStats.second,m_label);
}
}
else {
if (m_bestTest.eval(sample)) {
m_rightChildNode->update(sample,dataset_tr);
} else {
m_leftChildNode->update(sample,dataset_tr);
}
}
}
Result eval(Sample &sample) {
if (m_isLeaf) {
Result result;
if (m_counter + m_parentCounter) {
result.confidence = m_labelStats;
result.prediction = m_label;
} else {
for (int i = 0; i < *m_numClasses; i++) {
result.confidence.push_back(1.0 / *m_numClasses);
}
for(int i=0;i<6;i++)
result.prediction.push_back (0.0);
}
return result;
} else {
if (m_bestTest.eval(sample)) {
return m_rightChildNode->eval(sample);
} else {
return m_leftChildNode->eval(sample);
}
}
}
private:
const int *m_numClasses;
int m_depth;
bool m_isLeaf;
double m_counter;
vector<float> m_label;
double m_parentCounter;
const Hyperparameters *m_hp;
vector<double> m_labelStats;
OnlineNode* m_leftChildNode;
OnlineNode* m_rightChildNode;
vector<HyperplaneFeature> m_onlineTests;
HyperplaneFeature m_bestTest;
bool shouldISplit() {
bool isPure = false;
for (int i = 0; i < *m_numClasses; i++) {
if (m_labelStats[i] == m_counter + m_parentCounter) {
isPure = true;
break;
}
}
if (isPure) {
return false;
}
if (m_depth >= m_hp->maxDepth) { // Do not split if the max depth is reached
return false;
}
if (m_counter < m_hp->counterThreshold) { // Do not split if with not enough samples
return false;
}
return true;
}
};
#endif /* ONLINENODE_H_ */