-
Notifications
You must be signed in to change notification settings - Fork 0
/
dt_training.h
36 lines (30 loc) · 1.37 KB
/
dt_training.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#ifndef DT_TRAINING
#define DT_TRAINING
/*
 * Default training hyper-parameters.
 *
 * Each value can be overridden at build time (e.g. -DMAX_DEPTH=6) without
 * editing this header; the defaults below apply otherwise.
 *
 * NOTE(review): MAX_NODES presumably bounds how many nodes GetNewNode() can
 * hand out. A full binary tree limited to MAX_DEPTH levels needs up to
 * 2^MAX_DEPTH - 1 nodes (31 for the default 5), or 2^(MAX_DEPTH+1) - 1 (63)
 * if depth is counted from 0. Confirm the counting convention in split(), and
 * keep MAX_NODES large enough whenever MAX_DEPTH is raised.
 */
#ifndef MAX_DEPTH
#define MAX_DEPTH 5   /* presumably the max_depth value handed to split() */
#endif

#ifndef MAX_NODES
#define MAX_NODES 40  /* presumably the capacity of the node pool */
#endif

#ifndef MIN_SIZE
#define MIN_SIZE 20   /* presumably the min_size value handed to split() */
#endif
#include <stdint.h>
#include "fixed.h"
#include "dataset.h"
#include "pipeline.h"
/*
 * One node of the decision tree produced by the training routines declared
 * in this header.
 *
 * The field order and types below form the struct layout; do not reorder.
 *
 * NOTE(review): every node embeds two uint16_t[MEMORY_SIZE] arrays, so
 * sizeof(struct Node) grows linearly with MEMORY_SIZE. Keep this in mind
 * when many nodes are held in memory at once.
 */
struct Node {
    /* Split rule. Presumably a sample is routed by comparing its value in
     * column `feature` against `threshold` (confirm direction in the .c). */
    fixed threshold;
    uint16_t feature;       /* presumably a column index in [0, N_FEATURE) */

    /* Partition produced by the split. Presumably these hold sample
     * indices, with the counters tracking how many entries are in use. */
    uint16_t Left_group[MEMORY_SIZE];
    uint16_t Right_group[MEMORY_SIZE];
    fixed left_counter;
    fixed right_counter;

    /* Class labels for each side; presumably set when a side is terminal. */
    uint16_t left_class;
    uint16_t right_class;

    /* State word; its meaning is not visible from this header. */
    uint16_t taken;

    /* Child subtrees; presumably NULL for a terminal side. */
    struct Node *left;
    struct Node *right;
};
/*
 * Decision-tree training API.
 *
 * The implementations are not visible from this header. Descriptions marked
 * "presumably" are inferred from names and signatures only; confirm them
 * against the .c file.
 *
 * The array parameters below decay to pointers. The leading dimension
 * (MEMORY_SIZE+UPDATE_THR) documents the expected buffer size but is not
 * checked by the compiler. The N_FEATURE column count is part of the type.
 */

/**
 * Train a tree on the first `size` rows of `max_samples`.
 *
 * @param max_samples feature matrix with N_FEATURE columns per row
 *                    (presumably one row per training sample).
 * @param root        node to train into (presumably from GetNewNode()).
 * @param y_train     presumably the class label of each row.
 * @param size        number of rows to use.
 * @return            a tree node pointer (presumably the trained root).
 */
struct Node *decision_tree_training(fixed max_samples[MEMORY_SIZE+UPDATE_THR][N_FEATURE],
                                    struct Node *root,
                                    uint16_t y_train[MEMORY_SIZE+UPDATE_THR],
                                    uint16_t size);

/**
 * Choose a split for the `size` sample indices in `group`.
 *
 * Presumably evaluates candidate (feature, threshold) pairs and records the
 * best one, together with its partition, in `root`.
 */
struct Node *get_split(fixed max_samples[MEMORY_SIZE+UPDATE_THR][N_FEATURE],
                       struct Node *root,
                       uint16_t *group,
                       uint16_t y_train[MEMORY_SIZE+UPDATE_THR],
                       uint16_t size);

/**
 * Partition the `size` indices in `group` on column `feature` at `threshold`.
 *
 * Presumably fills root->Left_group / root->Right_group and their counters.
 */
struct Node *split_samples(fixed max_samples[MEMORY_SIZE+UPDATE_THR][N_FEATURE],
                           struct Node *root,
                           uint16_t *group,
                           uint16_t feature,
                           fixed threshold,
                           uint16_t size);

/**
 * Score the partition currently stored in `root`.
 *
 * @return presumably the Gini impurity of the split, as a `fixed` value.
 */
fixed gini_index(struct Node *root, uint16_t y_train[MEMORY_SIZE+UPDATE_THR]);

/**
 * Grow the subtree below `node`.
 *
 * Presumably recursive: `depth` is the current level, and growth stops at
 * `max_depth` or when a group has fewer than `min_size` samples.
 */
struct Node *split(fixed max_samples[MEMORY_SIZE+UPDATE_THR][N_FEATURE],
                   struct Node *node,
                   uint16_t y_train[MEMORY_SIZE+UPDATE_THR],
                   uint16_t max_depth,
                   uint16_t min_size,
                   uint16_t depth);

/**
 * Obtain a fresh node.
 *
 * Declared with (void): before C23, an empty parameter list "()" means
 * "unspecified parameters", so call sites were not type-checked. This form
 * remains compatible with a zero-parameter definition.
 */
struct Node *GetNewNode(void);

/**
 * Compute the class label for a leaf built from the `size` indices in
 * `group` (presumably the majority label in y_train).
 */
uint16_t to_terminal(uint16_t *group, uint16_t y_train[MEMORY_SIZE+UPDATE_THR], uint16_t size);
#endif