-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist_bbb_test.py
60 lines (43 loc) · 1.55 KB
/
mnist_bbb_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import mxnet as mx
import numpy as np
from BayesByBackprop import BayesByBackprop
import BBBModel as bbb
# ---------------------------------------------------------------------------
# BayesByBackprop test
# ---------------------------------------------------------------------------
# Instantiate the from-scratch Bayes-by-Backprop wrapper with a fixed seed
# (presumably used to make weight sampling reproducible -- TODO confirm in
# BayesByBackprop).
model = BayesByBackprop(
    seed=0,
)
def transform(data, label, scale=126.0):
    """Cast one MNIST (image, label) pair to float32 and rescale the pixels.

    Used as the ``transform`` callback of ``mx.gluon.data.vision.MNIST``,
    which calls it positionally as ``transform(data, label)`` for every
    sample, so ``scale`` always takes its default there.

    Args:
        data: image pixels; any object with ``.astype`` (e.g. an MXNet
            NDArray or a NumPy array).
        label: class label; any object with ``.astype``.
        scale: divisor applied to the pixel values. The default of 126.0 is
            the value this script has always used. With raw 0-255 pixels it
            yields values in roughly [0, 2.02], not [0, 1]. Pass 255.0 for
            the conventional [0, 1] range. Models trained with the default
            presumably expect the same scaling at prediction time.

    Returns:
        A ``(data / scale, label)`` tuple, both cast to ``np.float32``.
    """
    return data.astype(np.float32) / scale, label.astype(np.float32)
# Load both MNIST splits, passing every sample through `transform`.
train_dataset = mx.gluon.data.vision.MNIST(
    transform=transform,
    train=True,
)
test_dataset = mx.gluon.data.vision.MNIST(
    transform=transform,
    train=False,
)

# 784 = 28 * 28 pixels per MNIST image; 10 digit classes.
num_inputs, num_outputs = 784, 10

# Uncomment to build and train the BayesByBackprop model defined above.
# model.define_model(num_inputs, num_outputs)
# model.train(train_dataset, test_dataset)
# ---------------------------------------------------------------------------
# BBBModel test
# ---------------------------------------------------------------------------
# Pickled model file for this id, e.g. "./db_models/m0.pkl".
model_db_path = "./db_models"
model_id = 0
model_path = "{}/m{}.pkl".format(model_db_path, model_id)
# bbb.train_MNIST(seed=model_id, model_path=model_path)

# Grab one transformed (data, label) pair from each split for a spot check.
sample_idx = 0

sample_train = train_dataset[sample_idx]
sample_train_data, sample_train_label = sample_train

sample_test = test_dataset[sample_idx]
sample_test_data, sample_test_label = sample_test

# Uncomment to run a single prediction with the stored model.
# output = model.predict(sample_test_data, model_path)
# print("label: ", sample_test_label)
# print("output: ", output[sample_idx].asscalar())
################
# BBBModel training
################
# Train a BBBModel on MNIST and store it at `model_path` (see bbb.train for
# the exact persistence behaviour).
# This reuses the `transform`, `train_dataset`/`test_dataset` and
# `num_inputs`/`num_outputs` defined earlier in this script. Re-declaring an
# identical transform and loading both MNIST splits a second time would only
# repeat the same slow dataset load.
bbb.train(train_dataset, test_dataset, num_inputs, num_outputs, model_id, model_path)