-
Notifications
You must be signed in to change notification settings - Fork 7
/
train.py
278 lines (251 loc) · 13.6 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
import argparse
import gym
import numpy as np
import os
import tensorflow as tf
import tempfile
import time
import json
import random
import rlattack.common.tf_util as U
from rlattack import logger
from rlattack import deepq
from rlattack.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from rlattack.common.misc_util import (
boolean_flag,
pickle_load,
pretty_eta,
relatively_safe_pickle_dump,
set_global_seeds,
RunningAvg,
SimpleMonitor
)
from rlattack.common.schedules import LinearSchedule, PiecewiseSchedule
# when updating this to non-deperecated ones, it is important to
# copy over LazyFrames
from rlattack.common.atari_wrappers_deprecated import wrap_dqn
from rlattack.common.azure_utils import Container
from model import model, dueling_model
from statistics import statistics
def parse_args():
parser = argparse.ArgumentParser("DQN experiments for Atari games")
# Environment
parser.add_argument("--env", type=str, default="Pong", help="name of the game")
parser.add_argument("--seed", type=int, default=42, help="which seed to use")
# Core DQN parameters
parser.add_argument("--replay-buffer-size", type=int, default=int(1e6), help="replay buffer size")
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate for Adam optimizer")
parser.add_argument("--num-steps", type=int, default=int(2e8), help="total number of steps to run the environment for")
parser.add_argument("--batch-size", type=int, default=32, help="number of transitions to optimize at the same time")
parser.add_argument("--learning-freq", type=int, default=4, help="number of iterations between every optimization step")
parser.add_argument("--target-update-freq", type=int, default=40000, help="number of iterations between every target network update")
# Bells and whistles
boolean_flag(parser, "noisy", default=False, help="whether or not to NoisyNetwork")
boolean_flag(parser, "double-q", default=True, help="whether or not to use double q learning")
boolean_flag(parser, "dueling", default=False, help="whether or not to use dueling model")
boolean_flag(parser, "prioritized", default=False, help="whether or not to use prioritized replay buffer")
parser.add_argument("--prioritized-alpha", type=float, default=0.6, help="alpha parameter for prioritized replay buffer")
parser.add_argument("--prioritized-beta0", type=float, default=0.4, help="initial value of beta parameters for prioritized replay")
parser.add_argument("--prioritized-eps", type=float, default=1e-6, help="eps parameter for prioritized replay buffer")
# Checkpointing
parser.add_argument("--save-dir", type=str, default=None, required=True, help="directory in which training state and model should be saved.")
parser.add_argument("--save-azure-container", type=str, default=None,
help="It present data will saved/loaded from Azure. Should be in format ACCOUNT_NAME:ACCOUNT_KEY:CONTAINER")
parser.add_argument("--save-freq", type=int, default=1e5, help="save model once every time this many iterations are completed")
boolean_flag(parser, "load-on-start", default=True, help="if true and model was previously saved then training will be resumed")
#V: Attack Arguments #
parser.add_argument("--attack", type=str, default=None, help="Method to attack the model.")
parser.add_argument("--attack-init", type=int, default=0, help="Iteration no. to begin attacks")
parser.add_argument("--attack-prob", type=float, default=0.0, help="Probability of attack at each step, float in range 0 - 1.0" )
return parser.parse_args()
def make_env(game_name):
env = gym.make(game_name + "NoFrameskip-v4")
monitored_env = SimpleMonitor(env) # puts rewards and number of steps in info, before environment is wrapped
env = wrap_dqn(monitored_env) # applies a bunch of modification to simplify the observation space (downsample, make b/w)
return env, monitored_env
def maybe_save_model(savedir, container, state):
"""This function checkpoints the model and state of the training algorithm."""
if savedir is None:
return
start_time = time.time()
model_dir = "model-{}".format(state["num_iters"])
U.save_state(os.path.join(savedir, model_dir, "saved"))
if container is not None:
container.put(os.path.join(savedir, model_dir), model_dir)
relatively_safe_pickle_dump(state, os.path.join(savedir, 'training_state.pkl.zip'), compression=True)
if container is not None:
container.put(os.path.join(savedir, 'training_state.pkl.zip'), 'training_state.pkl.zip')
relatively_safe_pickle_dump(state["monitor_state"], os.path.join(savedir, 'monitor_state.pkl'))
if container is not None:
container.put(os.path.join(savedir, 'monitor_state.pkl'), 'monitor_state.pkl')
logger.log("Saved model in {} seconds\n".format(time.time() - start_time))
def maybe_load_model(savedir, container):
"""Load model if present at the specified path."""
if savedir is None:
return
state_path = os.path.join(os.path.join(savedir, 'training_state.pkl.zip'))
if container is not None:
logger.log("Attempting to download model from Azure")
found_model = container.get(savedir, 'training_state.pkl.zip')
else:
found_model = os.path.exists(state_path)
if found_model:
state = pickle_load(state_path, compression=True)
model_dir = "model-{}".format(state["num_iters"])
if container is not None:
container.get(savedir, model_dir)
U.load_state(os.path.join(savedir, model_dir, "saved"))
logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
return state
if __name__ == '__main__':
args = parse_args()
# Parse savedir and azure container.
savedir = args.save_dir
if args.save_azure_container is not None:
account_name, account_key, container_name = args.save_azure_container.split(":")
container = Container(account_name=account_name,
account_key=account_key,
container_name=container_name,
maybe_create=True)
if savedir is None:
# Careful! This will not get cleaned up. Docker spoils the developers.
savedir = tempfile.TemporaryDirectory().name
else:
container = None
# Create and seed the env.
env, monitored_env = make_env(args.env)
if args.seed > 0:
set_global_seeds(args.seed)
env.unwrapped.seed(args.seed)
# V: Save arguments, configure log dump path to savedir #
if savedir:
with open(os.path.join(savedir, 'args.json'), 'w+') as f:
json.dump(vars(args), f)
logger.configure(dir=savedir) # log to savedir
with U.make_session(4) as sess:
# Create training graph and replay buffer
act, train, update_target, debug, craft_adv = deepq.build_train(
make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
q_func=dueling_model if args.dueling else model,
num_actions=env.action_space.n,
optimizer=tf.train.AdamOptimizer(learning_rate=args.lr, epsilon=1e-4),
gamma=0.99,
grad_norm_clipping=10,
double_q=args.double_q,
noisy=args.noisy,
attack = args.attack
)
approximate_num_iters = args.num_steps / 4
exploration = PiecewiseSchedule([
(0, 1.0),
(approximate_num_iters / 50, 0.1),
(approximate_num_iters / 5, 0.01)
], outside_value=0.01)
if args.prioritized:
replay_buffer = PrioritizedReplayBuffer(args.replay_buffer_size, args.prioritized_alpha)
beta_schedule = LinearSchedule(approximate_num_iters, initial_p=args.prioritized_beta0, final_p=1.0)
else:
replay_buffer = ReplayBuffer(args.replay_buffer_size)
U.initialize()
update_target()
num_iters = 0
# Load the model
state = maybe_load_model(savedir, container)
if state is not None:
num_iters, replay_buffer = state["num_iters"], state["replay_buffer"],
monitored_env.set_state(state["monitor_state"])
start_time, start_steps = None, None
steps_per_iter = RunningAvg(0.999)
iteration_time_est = RunningAvg(0.999)
obs = env.reset()
# Record the mean of the \sigma
sigma_name_list = []
sigma_list = []
for param in tf.trainable_variables():
# only record the \sigma in the action network
if 'sigma' in param.name and 'deepq/q_func/action_value' in param.name:
summary_name = param.name.replace('deepq/q_func/action_value/', '').replace('/', '.').split(':')[0]
sigma_name_list.append(summary_name)
sigma_list.append(tf.reduce_mean(tf.abs(param)))
f_mean_sigma = U.function(inputs=[], outputs=sigma_list)
# Statistics
writer = tf.summary.FileWriter(savedir, sess.graph)
im_stats = statistics(scalar_keys=['action', 'im_reward', 'td_errors', 'huber_loss']+sigma_name_list)
ep_stats = statistics(scalar_keys=['ep_reward', 'ep_length'])
# Main trianing loop
ep_length = 0
while True:
num_iters += 1
ep_length += 1
#V: Perturb observation if we are past the init stage and at a designated attack step #
#if craft_adv != None and (num_iters >= args.attack_init) and ((num_iters - args.attack_init) % args.attack_freq == 0) :
if craft_adv != None and (num_iters >= args.attack_init) and (random.random() <= args.attack_prob) :
obs = craft_adv(np.array(obs)[None])[0]
# Take action and store transition in the replay buffer.
if args.noisy:
# greedily choose
action = act(np.array(obs)[None], stochastic=False)[0]
else:
# epsilon greedy
action = act(np.array(obs)[None], update_eps=exploration.value(num_iters))[0]
new_obs, rew, done, info = env.step(action)
replay_buffer.add(obs, action, rew, new_obs, float(done))
obs = new_obs
if done:
obs = env.reset()
if (num_iters > max(5 * args.batch_size, args.replay_buffer_size // 20) and
num_iters % args.learning_freq == 0):
# Sample a bunch of transitions from replay buffer
if args.prioritized:
experience = replay_buffer.sample(args.batch_size, beta=beta_schedule.value(num_iters))
(obses_t, actions, rewards, obses_tp1, dones, weights, batch_idxes) = experience
else:
obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(args.batch_size)
weights = np.ones_like(rewards)
# Minimize the error in Bellman's equation and compute TD-error
td_errors, huber_loss = train(obses_t, actions, rewards, obses_tp1, dones, weights)
# Update the priorities in the replay buffer
if args.prioritized:
new_priorities = np.abs(td_errors) + args.prioritized_eps
replay_buffer.update_priorities(batch_idxes, new_priorities)
# Write summary
mean_sigma = f_mean_sigma()
im_stats.add_all_summary(writer, [action, rew, np.mean(td_errors), np.mean(huber_loss)]+mean_sigma, num_iters)
# Update target network.
if num_iters % args.target_update_freq == 0:
update_target()
if start_time is not None:
steps_per_iter.update(info['steps'] - start_steps)
iteration_time_est.update(time.time() - start_time)
start_time, start_steps = time.time(), info["steps"]
# Save the model and training state.
if num_iters > 0 and (num_iters % args.save_freq == 0 or info["steps"] > args.num_steps):
maybe_save_model(savedir, container, {
'replay_buffer': replay_buffer,
'num_iters': num_iters,
'monitor_state': monitored_env.get_state()
})
if info["steps"] > args.num_steps:
break
if done:
steps_left = args.num_steps - info["steps"]
completion = np.round(info["steps"] / args.num_steps, 1)
mean_ep_reward = np.mean(info["rewards"][-100:])
logger.record_tabular("% completion", completion)
logger.record_tabular("steps", info["steps"])
logger.record_tabular("iters", num_iters)
logger.record_tabular("episodes", len(info["rewards"]))
logger.record_tabular("reward (100 epi mean)", np.mean(info["rewards"][-100:]))
if not args.noisy:
logger.record_tabular("exploration", exploration.value(num_iters))
if args.prioritized:
logger.record_tabular("max priority", replay_buffer._max_priority)
fps_estimate = (float(steps_per_iter) / (float(iteration_time_est) + 1e-6)
if steps_per_iter._value is not None else "calculating...")
logger.dump_tabular()
logger.log()
logger.log("ETA: " + pretty_eta(int(steps_left / fps_estimate)))
logger.log()
# add summary for one episode
ep_stats.add_all_summary(writer, [mean_ep_reward, ep_length], num_iters)
ep_length = 0