forked from Damcy/prioritized-experience-replay
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbinary_heap.py
211 lines (182 loc) · 6.84 KB
/
binary_heap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: [email protected]
# description: binary max-heap of (priority, experience id) pairs with id <-> position lookups
import sys
import math
import utility
class BinaryHeap(object):
    """Max-heap of ``(priority, experience id)`` pairs with position lookups.

    The heap is stored 1-indexed in the dict ``priority_queue``
    (position -> ``(priority, e_id)``); the children of position ``i`` live
    at ``2 * i`` and ``2 * i + 1``.  Two side tables are kept in sync with it:

    * ``e2p``: experience id -> heap position
    * ``p2e``: heap position -> experience id

    Only positions ``1..size`` are ever present in the three dicts.
    """

    def __init__(self, priority_size=100, priority_init=None, replace=True):
        """
        :param priority_size: maximum number of entries kept in the heap
        :param priority_init: optional initial content. NOTE(review): the
            expected shape is not shown in this file; it is assumed to be
            either a mapping whose values are (priority, e_id) pairs or an
            iterable of such pairs -- confirm against callers.
        :param replace: when the heap is full, overwrite the entry in the
            last slot (True) or reject the insert (False)
        """
        self.e2p = {}
        self.p2e = {}
        self.replace = replace
        # Always start from a valid empty heap so that every method works
        # regardless of whether priority_init was supplied.
        self.priority_queue = {}
        self.size = 0
        self.max_size = priority_size

        if priority_init is not None:
            if hasattr(priority_init, 'values'):
                entries = priority_init.values()
            else:
                entries = priority_init
            # Going through update() keeps the index tables consistent and
            # applies the max_size / replace policy to the initial content.
            for priority, e_id in entries:
                self.update(priority, e_id)

    def __repr__(self):
        """
        :return: string of the priority queue, with level info
        """
        if self.size == 0:
            return 'No element in heap!'

        # floor(log2(i)) computed exactly on integers; the float logarithm
        # can be off by one ulp for some powers of two.
        max_level = self.size.bit_length() - 1
        level = -1
        parts = []
        for i in range(1, self.size + 1):
            now_level = i.bit_length() - 1
            if level != now_level:
                if level != -1:
                    parts.append('\n')
                # each row is labelled with its distance from the deepest level
                parts.append(str(max_level - now_level))
                level = now_level
            priority, e_id = self.priority_queue[i]
            parts.append(' {}:{} '.format(priority, e_id))
        return ''.join(parts)

    def check_full(self):
        """
        :return: True when the heap already holds max_size entries
        """
        return self.size >= self.max_size

    def _swap(self, a, b):
        """Exchange the entries at positions a and b and refresh both index tables."""
        queue = self.priority_queue
        queue[a], queue[b] = queue[b], queue[a]
        for pos in (a, b):
            e_id = queue[pos][1]
            self.e2p[e_id] = pos
            self.p2e[pos] = e_id

    def _insert(self, priority, e_id):
        """
        insert new experience id with priority

        When the heap is full and ``replace`` is set, the entry sitting in the
        last slot is overwritten; otherwise the insert is rejected and the
        heap is left untouched.

        :param priority: priority value
        :param e_id: experience id
        :return: bool, True if the entry was stored
        """
        if self.size >= self.max_size:
            # self.size < 1 here means max_size <= 0: there is no slot to reuse.
            if not self.replace or self.size < 1:
                sys.stderr.write('Error: no space left to add experience id %d with priority value %f\n' % (e_id, priority))
                return False
            slot = self.size
            # Forget the evicted id, otherwise a later update() on it would
            # silently overwrite whatever lives at its old position.
            self.e2p.pop(self.priority_queue[slot][1], None)
        else:
            self.size += 1
            slot = self.size

        self.priority_queue[slot] = (priority, e_id)
        self.p2e[slot] = e_id
        self.e2p[e_id] = slot

        self.up_heap(slot)
        return True

    def update(self, priority, e_id):
        """
        update priority value according its experience id; ids that are not
        currently in the heap are inserted instead

        :param priority: new priority value
        :param e_id: experience id
        :return: bool
        """
        p_id = self.e2p.get(e_id)
        # Anything outside 1..size is not a live position (e.g. a leftover
        # sentinel written by external code), so treat the id as new.
        if p_id is None or not 1 <= p_id <= self.size:
            return self._insert(priority, e_id)

        self.priority_queue[p_id] = (priority, e_id)
        self.p2e[p_id] = e_id

        # The new priority may be lower or higher than the old one; at most
        # one of the two passes actually moves the entry.
        self.down_heap(p_id)
        self.up_heap(p_id)
        return True

    def get_max_priority(self):
        """
        get max priority, if no experience, return 1
        :return: max priority if size > 0 else 1
        """
        if self.size > 0:
            return self.priority_queue[1][0]
        return 1

    def pop(self):
        """
        pop out the max priority value with its experience id
        :return: priority value & experience id, or (False, False) if empty
        """
        if self.size == 0:
            sys.stderr.write('Error: no value in heap, pop failed\n')
            return False, False

        pop_priority, pop_e_id = self.priority_queue[1]

        # Detach the last slot completely so no stale entry survives past
        # the new size, then drop the popped id from the lookup table.
        last_pos = self.size
        last_priority, last_e_id = self.priority_queue.pop(last_pos)
        self.p2e.pop(last_pos, None)
        self.e2p.pop(pop_e_id, None)
        self.size -= 1

        if self.size > 0:
            # Move the former last entry to the root and sift it down.
            self.priority_queue[1] = (last_priority, last_e_id)
            self.e2p[last_e_id] = 1
            self.p2e[1] = last_e_id
            self.down_heap(1)

        return pop_priority, pop_e_id

    def up_heap(self, i):
        """
        upward balance
        :param i: tree node i
        :return: None
        """
        while i > 1:
            parent = i // 2
            if not self.priority_queue[parent][0] < self.priority_queue[i][0]:
                break
            self._swap(parent, i)
            i = parent

    def down_heap(self, i):
        """
        downward balance
        :param i: tree node i
        :return: None
        """
        if i < 1:
            return
        queue = self.priority_queue
        while i <= self.size:
            greatest = i
            left, right = i * 2, i * 2 + 1
            # Positions are 1-indexed, so `size` itself is a valid child.
            if left <= self.size and queue[left][0] > queue[greatest][0]:
                greatest = left
            if right <= self.size and queue[right][0] > queue[greatest][0]:
                greatest = right
            if greatest == i:
                break
            self._swap(i, greatest)
            i = greatest

    def get_priority(self):
        """
        get all priority value
        :return: list of priority, in heap-position order
        """
        return [self.priority_queue[i][0] for i in range(1, self.size + 1)]

    def get_e_id(self):
        """
        get all experience id in priority queue
        :return: list of experience ids, in heap-position order
        """
        return [self.priority_queue[i][1] for i in range(1, self.size + 1)]

    def balance_tree(self):
        """
        rebalance priority queue so that positions follow descending priority
        :return: None
        """
        live = [self.priority_queue[i] for i in range(1, self.size + 1)]
        # list.sort is stable, so equal priorities keep their relative order.
        live.sort(key=lambda entry: entry[0], reverse=True)

        # reconstruct priority_queue
        self.priority_queue.clear()
        self.p2e.clear()
        self.e2p.clear()
        for pos, (priority, e_id) in enumerate(live, 1):
            self.priority_queue[pos] = (priority, e_id)
            self.p2e[pos] = e_id
            self.e2p[e_id] = pos
        # A descending sequence already satisfies the max-heap property
        # (every parent position precedes its children), so no sift pass
        # is needed afterwards.

    def priority_to_experience(self, priority_ids):
        """
        retrieve experience ids by priority ids
        :param priority_ids: list of priority id (heap positions)
        :return: list of experience id
        """
        return [self.p2e[i] for i in priority_ids]