-
Notifications
You must be signed in to change notification settings - Fork 0
/
selftest.py
112 lines (88 loc) · 2.98 KB
/
selftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os, unittest
import logging
import time
import numpy as np
import bird_utils
# Set CUDA_VISIBLE_DEVICES to "-1" before any test runs
# (presumably to force CPU-only execution -- confirm against bird_utils).
os.environ.update(CUDA_VISIBLE_DEVICES="-1")

# Print NumPy arrays in full: no element-count truncation and no line wrapping.
np.set_printoptions(threshold=np.inf, linewidth=np.inf)

# Root logger emits everything from DEBUG upwards with a timestamped prefix.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
class MyBirdTestCase(unittest.TestCase):
    """Self tests for ``bird_utils`` plus a few NumPy / stdlib sanity checks."""

    def test_AggregateCounter(self):
        """Append 0..99 to ``AggregateCounter(10)`` and check ``avg()``."""
        counter = bird_utils.AggregateCounter(10)
        for i in range(100):
            counter.append(i)
        # 94.5 is the mean of 90..99, i.e. of the last ten values appended
        # (presumably the constructor argument is a window size -- confirm).
        self.assertEqual(counter.avg(), 94.5, "should be 94.5")

    def test_numpy_conditional(self):
        """Check the vectorised form of ``y = a + (c if b > 0 else 0)``."""
        a = np.array(
            [
                [1, 2, 3, 4, 5],
                [6, 7, 8, 9, 10],
            ]
        )
        b = np.array(
            [
                [0, 0, 1, 1, 1],
                [0, 0, 0, 1, 1],
            ]
        )
        c = np.ones((2, 5))
        mask = b > 0
        y = a + np.where(mask, c, 0)

        logging.info("a is %s", a)
        logging.info("b is %s", b)
        logging.info("c is %s", c)
        logging.info("mask is %s", mask)
        logging.info("y is %s", y)

        # Only the positions where b > 0 receive the +1 contribution from c.
        expected = np.array(
            [
                [1, 2, 4, 5, 6],
                [6, 7, 8, 10, 11],
            ]
        )
        np.testing.assert_array_equal(y, expected)

    def test_code(self):
        """Slicing with ``[:-2]`` drops the last two elements."""
        a = [1, 2, 3, 4]
        trimmed = a[:-2]
        print(trimmed)
        self.assertEqual(trimmed, [1, 2])

    def test_get_filename(self):
        """Log the lexicographically greatest entry of ``logs/tensorboard``.

        The test is skipped, rather than reported as an error, when the
        directory is missing or empty, since it depends on local state.
        """
        log_dir = 'logs/tensorboard'
        if not os.path.isdir(log_dir):
            self.skipTest("directory %s does not exist" % log_dir)
        entries = os.listdir(log_dir)
        if not entries:
            self.skipTest("directory %s is empty" % log_dir)
        name = max(entries)
        logging.info('filename is %s', name)

    def test_replay_buffer(self):
        """Run the same timing baseline against both replay-buffer classes."""
        max_limit = 50 * 10000
        insert_count = 100 * 10000
        do_replay_buffer_baseline(
            'deque', bird_utils.RandomReplayBuffer(max_limit), insert_count
        )
        do_replay_buffer_baseline(
            'list', bird_utils.RandomReplayBuffer_ImplAsList(max_limit), insert_count
        )
def do_replay_buffer_baseline(name, replay_buffer, insert_count, *,
                              sample_loop=100, sample_batch=1024):
    """Time append and sample operations on ``replay_buffer`` and log the results.

    Four phases run in order, each logged with its elapsed seconds:

    1. append ``insert_count`` items;
    2. call ``sample(sample_batch)`` ``sample_loop`` times;
    3. append a further ``insert_count * 3`` items;
    4. repeat the sampling of phase 2.

    Args:
        name: Label written into every log line to identify the buffer.
        replay_buffer: Object exposing ``append(item)`` and ``sample(batch)``.
        insert_count: Number of items appended in phase 1 (phase 3 uses 3x).
        sample_loop: How many ``sample`` calls each sampling phase makes.
        sample_batch: Batch size passed to every ``sample`` call.

    Returns:
        None. Timings are reported through ``logging.info`` only.
    """
    # Test 1: initial fill.
    elapsed = _timed_appends(replay_buffer, insert_count)
    logging.info("1 %s initialize %s item takes %s seconds", name, insert_count, elapsed)

    # Test 2: sampling after the initial fill.
    elapsed = _timed_samples(replay_buffer, sample_loop, sample_batch)
    logging.info("2 %s sample %s item over %s times takes %s seconds", name, sample_batch, sample_loop, elapsed)

    # Test 3: add a lot more data.
    add_item_cnt = insert_count * 3
    elapsed = _timed_appends(replay_buffer, add_item_cnt)
    logging.info("3 %s add %s item takes %s seconds", name, add_item_cnt, elapsed)

    # Test 4: sampling again after the additional appends.
    elapsed = _timed_samples(replay_buffer, sample_loop, sample_batch)
    logging.info("4 %s sample %s item over %s times takes %s seconds", name, sample_batch, sample_loop, elapsed)


def _timed_appends(replay_buffer, count):
    """Append ``count`` ``(i, str(i))`` tuples and return the elapsed seconds."""
    # perf_counter is monotonic, so the interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    started = time.perf_counter()
    for i in range(count):
        replay_buffer.append((i, str(i)))
    return time.perf_counter() - started


def _timed_samples(replay_buffer, loops, batch):
    """Call ``replay_buffer.sample(batch)`` ``loops`` times and return the elapsed seconds."""
    started = time.perf_counter()
    for _ in range(loops):
        replay_buffer.sample(batch)
    return time.perf_counter() - started
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()