-
Notifications
You must be signed in to change notification settings - Fork 42
/
benchmark.py
73 lines (56 loc) · 2.24 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import time
import numpy as np
import argparse
from tqdm import tqdm
from model.hlnet import HLNet
from model.dfanet import DFANet
from model.enet import ENet
from model.lednet import LEDNet
from model.mobilenet import MobileNet
from model.fast_scnn import Fast_SCNN
def _build_arg_parser():
    """Return the command-line parser for the benchmark script.

    Options:
        --image_size / -i  side length of the square random input (default 256)
        --batch_size / -b  images per ``predict`` call (default 3)
        --model_name       which model ``get_model`` builds (default 'hlnet')
        --nums             passed to the model as its class count (default 1)
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--image_size", '-i', type=int, default=256,
                     help="image size")
    cli.add_argument("--batch_size", '-b', type=int, default=3,
                     help="batch size")
    cli.add_argument("--model_name", type=str, default='hlnet',
                     choices=['hlnet', 'fastscnn', 'lednet', 'dfanet', 'enet', 'mobilenet'],
                     help="model's name")
    cli.add_argument("--nums", type=int, default=1,
                     help="output num")
    return cli


# The command line is parsed at import time; get_model() and main() read the
# resulting module-level values.
parser = _build_arg_parser()
args = parser.parse_args()
IMG_SIZE = args.image_size
CLS_NUM = args.nums
def get_model(name):
    """Build and return the model selected by *name*.

    Every model receives a ``(IMG_SIZE, IMG_SIZE, 3)`` input shape and
    ``CLS_NUM`` as its class count. Both are module-level values taken from
    the command line.

    Raises:
        NameError: if *name* matches none of the supported model keys.
    """
    # Ordered (key, builder) pairs. They are scanned with ``==`` rather than
    # looked up in a dict, so matching behaves exactly like an if/elif chain.
    builders = (
        ('hlnet',
         lambda shape: HLNet(input_shape=shape, cls_num=CLS_NUM)),
        ('fastscnn',
         lambda shape: Fast_SCNN(num_classes=CLS_NUM, input_shape=shape).model()),
        ('lednet',
         lambda shape: LEDNet(groups=2, classes=CLS_NUM, input_shape=shape).model()),
        ('dfanet',
         lambda shape: DFANet(input_shape=shape, cls_num=CLS_NUM, size_factor=2)),
        ('enet',
         lambda shape: ENet(input_shape=shape, cls_num=CLS_NUM)),
        ('mobilenet',
         lambda shape: MobileNet(input_shape=shape, cls_num=CLS_NUM)),
    )
    for key, build in builders:
        if name == key:
            return build((IMG_SIZE, IMG_SIZE, 3))
    raise NameError("No corresponding model...")
def main(iterations=500, warmup=1):
    """Benchmark your model in your local pc.

    Builds the model chosen by ``--model_name``, calls ``model.predict`` on a
    fixed random batch ``iterations`` times, and prints the mean seconds per
    call. The first ``warmup`` calls are excluded from the mean.

    Args:
        iterations: total number of timed ``predict`` calls (default 500,
            the previously hard-coded count).
        warmup: number of leading calls dropped from the average (default 1,
            matching the previous behaviour). Presumably these absorb one-time
            setup cost on the first call.

    Raises:
        ValueError: if ``warmup`` is negative or leaves no calls to average.
    """
    # Validate before building the model. Averaging an empty slice would
    # silently print nan.
    if warmup < 0 or iterations <= warmup:
        raise ValueError(
            "need 0 <= warmup < iterations, got warmup={}, iterations={}".format(
                warmup, iterations))
    model = get_model(args.model_name)
    # NOTE(review): randn yields float64. If the model expects float32, the
    # conversion cost is part of every timed call; confirm this is intended.
    inputs = np.random.randn(args.batch_size, args.image_size, args.image_size, 3)
    time_per_batch = np.empty(iterations)
    for step in tqdm(range(iterations)):
        # perf_counter is monotonic and high-resolution. time.time() can jump
        # if the system clock is adjusted during the run.
        start = time.perf_counter()
        model.predict(inputs, batch_size=args.batch_size)
        time_per_batch[step] = time.perf_counter() - start
    print(time_per_batch[warmup:].mean())
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()