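"""train.py: training loop for the biodreaming Actor/Planner agents.

Each repetition trains an Actor (the policy) and a Planner (which predicts the
next environment state and reward) on a gym environment, alternating an awake
phase (learning from real environment interaction) with an optional dreaming
phase in which the Planner generates imagined rollouts used to further train
the Actor. Per-episode rewards are recorded with the Recorder monitor and the
trained agents are saved to disk after every repetition.
"""
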
import gym
import pickle
import numpy as np
from os import path
from tqdm import trange
from argparse import ArgumentParser
from argparse import Namespace
from src.agent import Actor, Planner
from src.config import Config
from src.monitor import Recorder
from collections import defaultdict

def main(args : Namespace):
    # * Initialization of environment
    env = gym.make(
        args.env,
        time_limit = args.timeout,
        render_mode = args.render_mode,
        num_actions = args.num_actions,
    )

    agent = Actor (Config[args.env])
    planner = Planner(Config[args.env])

    monitor = Recorder(args.monitor, do_raise=args.strict_monitor)
    monitor.criterion = lambda episode : episode % args.monitor_freq == 0

    for rep in trange(args.n_rep):
        agent.forget()
        planner.forget()
        monitor.reset()

        reward_tot = []
        reward_fin = []

        env.register_step_callback(monitor)
        agent.register_step_callback(monitor)
        planner.register_step_callback(monitor)

        iterator = trange(args.epochs, desc = 'Episode reward: ---')
        for episode in iterator:
            agent.reset()
            planner.reset()

            obs, info = env.reset(options = {'button_pressed' : True})

            # * §§§ Awake phase §§§
            r_tot = 0
            done, timeout = False, False

            while not done and not timeout:
                state = obs['agent_target']

                # * Agent action
                action = agent.step(state, deterministic = True, episode = episode)

                # * Planner action: prediction of next env state
                # Convert action to one-hot for concatenation with the observation
                # planner_obs = np.concatenate((np.eye(args.num_actions)[action], obs['agent_target']))
                pred_state, pred_reward = planner.step(
                    state,
                    action = np.eye(args.num_actions)[action],
                    deterministic = True,
                    episode = episode
                )

                # * Environment step
                obs, r_fin, done, timeout, info = env.step(action, episode = episode)

                # Update the agent using the reward signal, and the planner by comparing
                # its prediction against the next environment state and reward
                agent.accumulate_evidence(r_fin)
                planner.accumulate_evidence((pred_state, pred_reward), (obs['agent_target'], r_fin))
                planner.learn_from_evidence()

                r_tot += r_fin

            # Commit the monitor buffer after the episode ends to keep a clear episode separation in the data
            monitor.commit_buffer()

            agent.learn_from_evidence()

            reward_tot.append(r_tot)
            reward_fin.append(r_fin)

            iterator.set_description(f'Episode reward {r_fin:.2f}')
            # * §§§ Dreaming phase §§§
            # The planner stands in for the environment: it generates the next
            # observation and reward from the agent's action.
            for _ in range(args.num_dream):
                agent.reset()
                planner.reset()

                obs, info = env.reset(options = {'button_pressed' : True})
                obs = obs['agent_target']

                for dream_t in range(args.dream_len):
                    # * Agent action
                    action = agent.step(obs, deterministic = True)

                    # * Planner predicts the new observation
                    planner_obs = np.concatenate((action, obs))
                    obs, reward = planner.step(planner_obs, deterministic = True)

                    agent.accumulate_evidence(reward)

                agent.learn_from_evidence()
        monitor['reward_tot'].append(reward_tot)
        monitor['reward_fin'].append(reward_fin)

        # * Save agent
        agent.save(path.join(args.save_dir, f'agent_{args.env}_{str(rep).zfill(2)}.pkl'))

        # * Save planner
        planner.save(path.join(args.save_dir, f'planner_{args.env}_{str(rep).zfill(2)}.pkl'))

    # Save the monitored quantities
    filename = path.join(args.save_dir, f'stats_{args.env}.pkl')
    monitor.dump(filename)

if __name__ == '__main__':
    parser = ArgumentParser()

    parser.add_argument('-env', type = str, default = 'ButtonFood-v0', help = 'Environment to run')
    parser.add_argument('-n_rep', type = int, default = 10, help = 'Number of repetitions of the experiment')
    parser.add_argument('-epochs', type = int, default = 2000, help = 'Number of training episodes per repetition')
    parser.add_argument('-timeout', type = int, default = 1000, help = 'Maximum number of environment steps per episode (time limit)')
    parser.add_argument('-num_dream', type = int, default = 0, help = 'Number of planner dreams')
    parser.add_argument('-dream_len', type = int, default = 50, help = 'Length of each planner dream')
    parser.add_argument('-num_actions', type = int, default = 8, help = 'Number of available discrete actions. Use 0 to trigger continuous control.')
    # parser.add_argument('-dream_lag', default = 50, help = 'Time step of dreams')
    parser.add_argument('-render_mode', type = str, default = None, choices = ['human', 'rgb_array', None], help = 'Rendering mode')
    parser.add_argument('-monitor', type = str, nargs = '*', default = [], help = 'Path to monitor configuration for metric recording')
    parser.add_argument('-monitor_freq', type = int, default = 1, help = 'Episode frequency for metric recording')
    parser.add_argument('-save_dir', type = str, default = 'data', help = 'Directory to save data')
    parser.add_argument('-load_dir', type = str, default = 'data', help = 'Directory to load data')
    parser.add_argument('--strict_monitor', action='store_true', default=False, help='Raise on metric recording errors (strict mode)')

    args = parser.parse_args()

    main(args)
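
# Example invocation (illustrative; flag values are arbitrary):
#   python train.py -env ButtonFood-v0 -epochs 2000 -num_dream 10 -dream_len 50 -save_dir data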