-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy path03_generate_rnn_data.py
105 lines (67 loc) · 2.51 KB
/
03_generate_rnn_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Usage: python 03_generate_rnn_data.py [--N NUM_EPISODES]
from vae.arch import VAE
import argparse
import config
import numpy as np
import os
# Data directories, relative to the current working directory. Each keeps a
# trailing slash because file names are appended by plain string concatenation.
ROOT_DIR_NAME = "./data/"                         # also receives initial_z.npz
ROLLOUT_DIR_NAME = ROOT_DIR_NAME + "rollout/"     # episode files read by this script
SERIES_DIR_NAME = ROOT_DIR_NAME + "series/"       # encoded episodes written by this script
def get_filelist(N, rollout_dir=None):
    """Return the sorted rollout file names to encode, capped at ``N``.

    Args:
        N: Maximum number of episode files to use. May be an ``int`` or a
            numeric string (argparse yields a string when ``--N`` is given
            on the command line); it is converted with ``int()``.
        rollout_dir: Directory to list. Defaults to ``ROLLOUT_DIR_NAME``.

    Returns:
        ``(filelist, N)``: ``filelist`` holds the directory entries in
        sorted order with ``.DS_Store`` removed, truncated to at most ``N``
        names; ``N`` is lowered to the number of available entries when
        there are fewer than requested.
    """
    if rollout_dir is None:
        rollout_dir = ROLLOUT_DIR_NAME
    # Without this, comparing an int length against a CLI string raises
    # TypeError on Python 3.
    N = int(N)
    filelist = sorted(x for x in os.listdir(rollout_dir) if x != '.DS_Store')
    length_filelist = len(filelist)
    if length_filelist > N:
        filelist = filelist[:N]
    if length_filelist < N:
        N = length_filelist
    return filelist, N
def encode_episode(vae, episode):
    """Run one stored episode through the VAE encoder.

    Args:
        vae: Object whose ``encoder.predict(obs)`` returns a 3-tuple; only
            the first two items (used as ``mu`` and ``log_var``) are kept.
        episode: Mapping (e.g. a loaded ``.npz``) providing ``'obs'``,
            ``'action'``, ``'reward'`` and ``'done'`` arrays.

    Returns:
        Tuple ``(mu, log_var, action, reward, done, initial_mu,
        initial_log_var)``. ``done`` is cast to ``int``. ``reward`` becomes
        1 only where the raw reward is positive AND ``done`` is 0, else 0.
        ``initial_mu``/``initial_log_var`` are row 0 of ``mu``/``log_var``
        (assumes the encoder output is 2-D with one row per observation —
        TODO confirm against the VAE definition).
    """
    obs, actions, raw_reward, raw_done = (
        episode[key] for key in ('obs', 'action', 'reward', 'done')
    )
    done_flags = raw_done.astype(int)
    # Terminal steps never count as rewarded, whatever their raw reward.
    positive = np.where(raw_reward > 0, 1, 0)
    reward_flags = positive * np.where(done_flags == 0, 1, 0)
    mu, log_var, _unused = vae.encoder.predict(obs)
    first_mu, first_log_var = mu[0, :], log_var[0, :]
    return mu, log_var, actions, reward_flags, done_flags, first_mu, first_log_var
def main(args):
    """Encode the selected rollout episodes with the trained VAE and save them.

    For every file returned by ``get_filelist``, a compressed archive with
    ``mu``, ``log_var``, ``action``, ``reward`` and ``done`` is written to
    ``SERIES_DIR_NAME`` under the same file name. Afterwards, the row-0
    latent statistics of all successfully encoded episodes are stacked and
    saved to ``ROOT_DIR_NAME + 'initial_z.npz'``.

    Args:
        args: Parsed CLI namespace; only ``args.N`` (maximum number of
            episodes, int or numeric string) is read.

    Raises:
        Exception: re-raised from ``vae.set_weights`` when the weights
            file cannot be loaded.
        RuntimeError: if not a single episode could be encoded (nothing is
            saved in that case, so an existing initial_z.npz is untouched).
    """
    # argparse gives a str when --N is passed on the command line.
    N = int(args.N)
    vae = VAE()
    try:
        vae.set_weights('./vae/weights.h5')
    except Exception as e:
        print(e)
        print("./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first")
        raise

    filelist, N = get_filelist(N)
    # savez_compressed does not create parent directories; without this
    # every episode would fail and be reported as "Skipped".
    os.makedirs(SERIES_DIR_NAME, exist_ok=True)

    file_count = 0
    initial_mus = []
    initial_log_vars = []
    last_mu_shape = None
    for filename in filelist:
        try:
            # np.load on an .npz keeps the archive open; the context manager
            # closes it once the arrays have been read.
            with np.load(ROLLOUT_DIR_NAME + filename) as rollout_data:
                mu, log_var, action, reward, done, initial_mu, initial_log_var = encode_episode(vae, rollout_data)
            np.savez_compressed(SERIES_DIR_NAME + filename, mu=mu, log_var=log_var, action=action, reward=reward, done=done)
        except Exception as e:
            # Best-effort: report the unreadable/unencodable episode and move on.
            print(e)
            print('Skipped {}...'.format(filename))
            continue
        initial_mus.append(initial_mu)
        initial_log_vars.append(initial_log_var)
        # Track the shape only for episodes that were fully saved.
        last_mu_shape = mu.shape
        file_count += 1
        if file_count % 50 == 0:
            print('Encoded {} / {} episodes'.format(file_count, N))

    print('Encoded {} / {} episodes'.format(file_count, N))
    if not initial_mus:
        # Previously this crashed with UnboundLocalError on `mu`.
        raise RuntimeError('No episodes were encoded from {} - nothing to save'.format(ROLLOUT_DIR_NAME))

    initial_mus = np.array(initial_mus)
    initial_log_vars = np.array(initial_log_vars)
    print('ONE MU SHAPE = {}'.format(last_mu_shape))
    print('INITIAL MU SHAPE = {}'.format(initial_mus.shape))
    np.savez_compressed(ROOT_DIR_NAME + 'initial_z.npz', initial_mu=initial_mus, initial_log_var=initial_log_vars)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=('Generate RNN data'))
parser.add_argument('--N',default = 10000, help='number of episodes to use to train')
args = parser.parse_args()
main(args)