-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdriver.py
74 lines (61 loc) · 1.75 KB
/
driver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from DatasetGenerator import gp_datagen
from OnlineClassifiers import train_al_perceptron
from GraphingUtils import *
from ALRules import *
import matplotlib.pyplot as plt
import numpy as np
import random
# Sizes of the training split and the held-out test split.
N_train = 1000
N_test = 100

# Draw a single dataset (seed=2 is passed through to the generator) and cut it
# along its first axis: the leading N_train entries are used for training and
# the trailing N_test entries for evaluation.
data = gp_datagen(N_train + N_test, seed=2)
data_train, data_test = data[:N_train], data[N_train:]
def acc(classifier, data=None):
    """Return ``classifier.score`` evaluated on a labelled dataset.

    Args:
        classifier: object exposing ``score(X, y)``; presumably the classifier
            returned by ``train_al_perceptron`` -- TODO confirm its API.
        data: 2-D array whose column 0 holds one feature vector per sample and
            whose column 1 holds the labels. Defaults to the module-level
            ``data_test`` split, so existing ``acc(clf)`` calls are unchanged.

    Returns:
        Whatever ``classifier.score`` returns for the stacked features and
        integer labels.
    """
    if data is None:
        data = data_test
    # Column 0 presumably holds per-sample vectors in an object column; stack
    # them into a single (n_samples, n_features) matrix for ``score``.
    features = np.stack(data[:, 0])
    labels = data[:, 1].astype(int)
    return classifier.score(features, labels)
# Classifiers trained under each active-learning rule, and how many training
# points each one queried, keyed by "<rule name> <parameter>".
clfs = {}
n_pts = {}


def _random_rule(prob):
    """Return a rule that accepts each candidate point with probability *prob*."""
    # The factory binds ``prob`` at creation time; a lambda written directly in
    # the loop would late-bind the loop variable instead.
    return lambda *args: random.random() < prob


# Baseline: accept points at random. The first linspace value (0.0) is skipped
# because a zero acceptance probability would query nothing.
for p in np.linspace(0, 1, 20)[1:]:
    clf_random, ixs_random = train_al_perceptron(
        data_train,
        rule=_random_rule(p)
    )
    clfs['random {:0.2f}'.format(p)] = clf_random
    n_pts['random {:0.2f}'.format(p)] = len(ixs_random)

# Parameter sweep for the predefined decision rules, 0.001 .. 1000 evenly
# spaced on a log scale. Note: np.logspace takes *exponents* (10**start ..
# 10**stop), so np.logspace(0.001, 1000, 10) would overflow to inf; geomspace
# takes the endpoint values directly.
params = np.geomspace(0.001, 1000, 10)
for label, sampler in (('b-sampling', b_sampling), ('lm-sampling', lm_sampling)):
    for param in params:
        clf, ixs = train_al_perceptron(
            data_train,
            rule=predefined_func_decision(param, sampler)
        )
        # Runs that queried no points are not recorded.
        if len(ixs) > 0:
            # '.3g' keeps small parameters distinct; '{:0.2f}' would render
            # both 0.001 and 0.00464 as '0.00' and overwrite one run.
            key = '{} {:.3g}'.format(label, param)
            clfs[key] = clf
            n_pts[key] = len(ixs)
# Score of every trained classifier, keyed exactly like ``clfs``.
accs = {k: acc(v) for k, v in clfs.items()}

plt.figure()
rule_types = (
    'random',
    'b-sampling',
    'lm-sampling'
)
for rule in rule_types:
    # Keys look like "<rule> <param>": match the rule name exactly rather than
    # as a bare substring, so one rule name can never capture another's runs.
    # Sort by points sampled so each '-o' line runs left to right instead of
    # zig-zagging in dict-insertion order.
    keys = sorted(
        (k for k in accs if k.rsplit(' ', 1)[0] == rule),
        key=n_pts.get
    )
    plt.plot(
        [n_pts[k] for k in keys],
        [accs[k] for k in keys],
        '-o', alpha=0.5, label=rule
    )
plt.xlabel('Points Sampled')
plt.ylabel('Accuracy')
# Labels are attached to each line above, so the legend cannot drift out of
# sync with the plotting order.
plt.legend()
plt.show()
# NOTE(review): plot_points comes from GraphingUtils; whether it opens its own
# figure or calls plt.show() is not visible here -- confirm.
plot_points(data_test)