-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathreplay_memory.py
104 lines (79 loc) · 2.63 KB
/
replay_memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import random
# TODO: make this list-based (i.e. variable sized)
class ReplayMemory:
    """Fixed-capacity ring buffer of transitions stored in pre-allocated arrays.

    Each slot holds one observation together with its terminal flag, action,
    reward and an arbitrary ``info`` object.  Writes go to consecutive slots
    and wrap around to slot 0 once ``size`` slots have been written, so the
    oldest entries are overwritten first.

    Attributes:
        size: number of slots in the buffer.
        n: number of entries considered stored; capped at ``size - 1``.
        i: index of the most recently written slot (``-1`` when empty).
    """

    def __init__(self, size, dimO, dimA, dtype=np.float32):
        """Allocate the backing arrays.

        Args:
            size: number of slots in the buffer.
            dimO: observation dimension(s), an int or a sequence of ints.
            dimA: action dimension(s), an int or a sequence of ints.
            dtype: element type of the observation array.  Actions and
                rewards are always stored as ``np.float32``.
        """
        self.size = size
        # Leading axis indexes the slot; remaining axes are the per-item shape.
        obs_shape = (size,) + tuple(np.atleast_1d(dimO).tolist())
        act_shape = (size,) + tuple(np.atleast_1d(dimA).tolist())
        self.observations = np.empty(obs_shape, dtype=dtype)
        self.actions = np.empty(act_shape, dtype=np.float32)
        self.rewards = np.empty(size, dtype=np.float32)
        # np.bool_ instead of the `np.bool` alias, which NumPy 1.24 removed.
        self.terminals = np.empty(size, dtype=np.bool_)
        self.info = np.empty(size, dtype=object)
        self.n = 0
        self.i = -1

    def reset(self):
        """Mark the buffer as empty; the backing arrays are not cleared."""
        self.n = 0
        self.i = -1

    def enqueue(self, observation, terminal, action, reward, info=None):
        """Write one entry into the next slot, overwriting the oldest if full.

        Args:
            observation: value stored in the observation array.
            terminal: tells whether this observation is the last one.
            action: value stored in the action array.
            reward: scalar reward.
            info: arbitrary object stored alongside the entry.
        """
        self.i = (self.i + 1) % self.size
        self.observations[self.i, ...] = observation
        self.terminals[self.i] = terminal
        self.actions[self.i, ...] = action
        self.rewards[self.i] = reward
        self.info[self.i, ...] = info
        self.n = min(self.size - 1, self.n + 1)

    def minibatch(self, size):
        """Sample ``size`` transitions uniformly at random (with replacement).

        Slot indices are drawn from ``[0, n - 2)``; the successor observation
        of index ``k`` is read from slot ``k + 1``.  The slot written most
        recently is never sampled, because its successor slot holds either
        stale or not-yet-written data.

        Args:
            size: number of samples (anything ``np.random.randint`` accepts
                as its ``size`` argument).

        Returns:
            Tuple ``(o, a, r, t, o2, info)`` of observations, actions,
            rewards, terminal flags, successor observations and info objects.
            ``t`` is the terminal flag of the sampled slot itself, not of its
            successor.

        Raises:
            ValueError: if too few entries are stored to form a transition.
        """
        high = self.n - 2  # exclusive upper bound of sampled slot indices
        if high <= 0:
            raise ValueError(
                "need at least 3 stored entries to sample, have {}".format(self.n)
            )

        head = self.i
        if 0 <= head < high:
            # The write head only falls inside the sampling range after the
            # buffer has wrapped.  Slot `head` then holds the newest entry
            # while slot `head + 1` holds the oldest, so that pair is not a
            # real transition.  Draw from a range one shorter and shift every
            # draw at or above `head` up by one to skip it.
            if high == 1:
                raise ValueError("no valid transition available to sample")
            indices = np.random.randint(0, high - 1, size)
            indices = indices + (indices >= head)
        else:
            indices = np.random.randint(0, high, size)

        o = self.observations[indices, ...]
        a = self.actions[indices]
        r = self.rewards[indices]
        t = self.terminals[indices]
        o2 = self.observations[indices + 1, ...]
        info = self.info[indices, ...]
        return o, a, r, t, o2, info

    def __repr__(self):
        """Return a dump of the first ``n`` observations, actions, rewards and terminals."""
        stored = slice(0, self.n)
        o = self.observations[stored, ...]
        a = self.actions[stored]
        r = self.rewards[stored]
        t = self.terminals[stored]
        s = """
OBSERVATIONS
{}
ACTIONS
{}
REWARDS
{}
TERMINALS
{}
""".format(o, a, r, t)
        return s
# TODO: relocate test
if __name__ == '__main__':
    # Smoke test: fill the buffer with consecutive integers, then check that
    # every sampled pair is consecutive and never straddles the buffer end.
    s = 100
    rm = ReplayMemory(s, 1, 1)
    i = 0
    while i < 100:
        is_terminal = (i % 3 == 0)
        rm.enqueue(i, is_terminal, i, i, i)
        i += 1

    for i in range(1000):
        batch = rm.minibatch(10)
        o, a, r, t, o2, info = batch
        # successor observation must be exactly one step after the sampled one
        assert all(o2 - 1 == o), "error: o and o2"
        # the last slot has no successor inside the buffer
        assert all(o != s - 1), "error: o wrap over rm. o = "+str(o)
        # the first slot must never appear as a successor
        assert all(o2 != 0), "error: o2 wrap over rm"