-
Notifications
You must be signed in to change notification settings - Fork 28
/
test.py
39 lines (33 loc) · 1.16 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from datasets import get_ds
from cfg import get_cfg
from methods import get_method
from eval.sgd import eval_sgd
from eval.knn import eval_knn
from eval.lbfgs import eval_lbfgs
from eval.get_data import get_data
if __name__ == "__main__":
    # Evaluation entry point: build the configured method, optionally load a
    # checkpoint, collect (x, y) pairs for the `clf` and `test` splits via
    # get_data, and print the score returned by the evaluator named in cfg.clf.
    cfg = get_cfg()

    # Map each supported cfg.clf value to its evaluator. Resolving it up
    # front rejects an unsupported value before any model or data work is
    # done, rather than failing with a NameError at the final print.
    evaluators = {
        "sgd": eval_sgd,
        "knn": eval_knn,
        "lbfgs": eval_lbfgs,
    }
    if cfg.clf not in evaluators:
        raise ValueError(
            f"unsupported classifier {cfg.clf!r}; "
            f"expected one of {sorted(evaluators)}"
        )
    evaluate = evaluators[cfg.clf]

    model_full = get_method(cfg.method)(cfg)
    model_full.cuda().eval()
    if cfg.fname is None:
        # No checkpoint file given: the freshly constructed model is used.
        print("evaluating random model")
    else:
        model_full.load_state_dict(torch.load(cfg.fname))

    ds = get_ds(cfg.dataset)(None, cfg, cfg.num_workers)
    # Device handed to get_data: "cpu" for the lbfgs evaluator, "cuda"
    # otherwise. NOTE(review): presumably where the collected features are
    # kept -- confirm against eval.get_data.
    device = "cpu" if cfg.clf == "lbfgs" else "cuda"

    if cfg.eval_head:

        def model(x):
            """Return model_full.head applied to model_full.model(x)."""
            return model_full.head(model_full.model(x))

        out_size = cfg.emb
    else:
        model = model_full.model
        out_size = model_full.out_size

    x_train, y_train = get_data(model, ds.clf, out_size, device)
    x_test, y_test = get_data(model, ds.test, out_size, device)

    acc = evaluate(x_train, y_train, x_test, y_test)
    print(acc)