-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
91 lines (82 loc) · 3.36 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import cv2
import string
from tqdm import tqdm
import click
import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from dataset.test_data import TestDataset
from dataset.text_data import TextDataset
from dataset.collate_fn import text_collate
from dataset.data_transform import Resize, Rotation, Translation, Scale
from models.model_loader import load_model
from torchvision.transforms import Compose
import editdistance
def test(net, data, abc, cuda, visualize, batch_size=256):
    """Evaluate a sequence-recognition model on a dataset.

    Args:
        net: model; called as ``net(imgs, decode=True)`` and expected to
            return a list of decoded strings, one per image in the batch.
        data: dataset whose samples provide "img", "seq" and "seq_len"
            (collated by ``text_collate``).
        abc: alphabet string mapping label indices back to characters.
        cuda: if True, move each input batch to the GPU.
        visualize: if True, show every image with its prediction via
            OpenCV; pressing 'q' aborts the evaluation early.
        batch_size: DataLoader batch size (default 256).

    Returns:
        Tuple ``(acc, avg_ed)``: exact-match accuracy and mean edit
        distance over all processed samples.
    """
    data_loader = DataLoader(data, batch_size=batch_size, num_workers=4, shuffle=False, collate_fn=text_collate)
    count = 0
    tp = 0
    avg_ed = 0
    iterator = tqdm(data_loader)
    for sample in iterator:
        imgs = Variable(sample["img"])
        if cuda:
            imgs = imgs.cuda()
        out = net(imgs, decode=True)
        # Labels are stored shifted by +1 (presumably index 0 is the CTC
        # blank — TODO confirm against the training code); shift back here.
        gt = (sample["seq"].numpy() - 1).tolist()
        lens = sample["seq_len"].numpy().tolist()
        pos = 0
        key = ''
        for i in range(len(out)):
            # Slice this sample's ground-truth label out of the flat batch.
            gts = ''.join(abc[c] for c in gt[pos:pos+lens[i]])
            pos += lens[i]
            if gts == out[i]:
                tp += 1
            else:
                # Only mismatches contribute edit distance (exact matches add 0).
                avg_ed += editdistance.eval(out[i], gts)
            count += 1
            if visualize:
                status = "pred: {}; gt: {}".format(out[i], gts)
                iterator.set_description(status)
                img = imgs[i].permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
                cv2.imshow("img", img)
                key = chr(cv2.waitKey() & 255)
                if key == 'q':
                    break
        if key == 'q':
            break
        if not visualize:
            # BUG FIX: the format string previously used {0} for both fields,
            # so the reported "avg_ed" was actually the accuracy repeated.
            iterator.set_description("acc: {0:.4f}; avg_ed: {1:.4f}".format(tp / count, avg_ed / count))
    # NOTE(review): raises ZeroDivisionError if the dataset is empty or 'q'
    # is pressed before any sample is processed — preserved original behavior.
    acc = tp / count
    avg_ed = avg_ed / count
    return acc, avg_ed
@click.command()
@click.option('--data-path', type=str, default=None, help='Path to dataset')
@click.option('--abc', type=str, default=string.digits+string.ascii_uppercase, help='Alphabet')
@click.option('--seq-proj', type=str, default="10x20", help='Projection of sequence')
@click.option('--backend', type=str, default="resnet18", help='Backend network')
@click.option('--snapshot', type=str, default=None, help='Pre-trained weights')
@click.option('--input-size', type=str, default="320x32", help='Input size')
@click.option('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
@click.option('--visualize', type=bool, default=False, help='Visualize output')
def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu, visualize):
    """CLI entry point: load a snapshot and report accuracy / edit distance.

    Uses TextDataset when --data-path is given, otherwise falls back to the
    synthetic TestDataset built from the given alphabet.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # BUG FIX: was `gpu is not ''`, which compares object identity rather
    # than string equality (a SyntaxWarning on Python >= 3.8 and dependent
    # on interning). Use equality to decide whether CUDA is enabled.
    cuda = gpu != ''
    # "WxH" strings -> [W, H] integer pairs.
    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose([
        Rotation(),
        Resize(size=(input_size[0], input_size[1]))
    ])
    if data_path is not None:
        data = TextDataset(data_path=data_path, mode="test", transform=transform)
    else:
        data = TestDataset(transform=transform, abc=abc)
    seq_proj = [int(x) for x in seq_proj.split('x')]
    # eval() disables dropout/batch-norm updates for inference.
    net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda).eval()
    acc, avg_ed = test(net, data, data.get_abc(), cuda, visualize)
    print("Accuracy: {}".format(acc))
    print("Edit distance: {}".format(avg_ed))
# Script entry point: invoke the click command only when run directly,
# not when this module is imported.
if __name__ == '__main__':
    main()