# The basic training template and IMPALA implementation for this project were taken from https://github.com/facebookresearch/torchbeast
# The vision network and the LSTM were replaced with our gated TransformerXL architectures
# A more efficient form of batching is used to feed data into the learner
'''
This file is for running on DMLab.
TODO: Want to be able to run both DMLab and Atari (shouldn't require very large changes).
TODO: Want to trim off part of a chunk if it is fully masked in every element of the batch: this may cause
issues anywhere we assume the chunk length is T.
'''
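# Overall flow (IMPALA-style): actor processes run act() to fill shared rollout buffers,
# learner threads pull batches via get_batch() and call learn(), which chops each rollout
# into chunks of --chunk_size, runs them through the transformer while carrying memory
# between chunks, and applies V-trace off-policy corrections before each gradient step.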
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import argparse
import logging
import pprint
import threading
import time
import timeit
import traceback
import typing
from StableTransformersReplication.transformer_xl import MemTransformerLM
from adaptive_span2.models import TransformerSeq as AdaptiveTransformer
os.environ["OMP_NUM_THREADS"] = "1" # Necessary for multithreading.
import torch
from torch import multiprocessing as mp
from torch import nn
from torch.nn import functional as F
from torchbeast.core import environment as dmlab_environment
try:
from torchbeast import dmlab_wrappers
except ImportError:
    print('NO DMLAB module')  # for the case where Atari is run on a machine without dmlab
from Model.core import prof, vtrace, file_writer
from Model.core import environment as atari_environment
from Model import atari_wrappers
# yapf: disable
parser = argparse.ArgumentParser(description="PyTorch Scalable Agent")
parser.add_argument("--env", type=str, default="PongNoFrameskip-v4",
help="Gym environment.")
parser.add_argument("--level_name", type=str, default="explore_goal_locations_small",
help="dmlab30 level name")
parser.add_argument("--mode", default="train",
choices=["train", "test", "test_render"],
help="Training or test mode.")
parser.add_argument("--xpid", default=None,
help="Experiment id (default: None).")
parser.add_argument("--sleep_length", default=20, type=int,
help="time between print statements from main thread")
# Architecture setting
parser.add_argument("--n_layer", default=4, type=int,
help="num layers in transformer decoder")
parser.add_argument("--d_inner", default=2048, type=int,
help="the position wise ff network dimension -> d_model x d_inner")
parser.add_argument("--use_gate", action='store_true',
help="whether to use gating in transformer decoder")
#Adaptive Transformer Settings
parser.add_argument("--use_adaptive", action='store_true',
help="whether to use the adaptive transformer, if not use TXL")
parser.add_argument("--attn_span", default=1024, type=int,
help="Attention span of adaptive transformer")
parser.add_argument("--pers_mem_size", default=0, type=int,
help="Number of persistent memory vectors")
parser.add_argument("--adapt_span_loss", default=0.000002, type=float,
help="the loss coefficient for span lengths")
parser.add_argument("--adapt_span_ramp", default=32, type=int,
help="ramp length of the soft masking function")
parser.add_argument("--adapt_span_init", default=0.3, type=float,
help="initial attention span ratio")
parser.add_argument("--adapt_span_cache", action='store_true',
help="adapt cache size as well to reduce memory usage")
parser.add_argument("--dropout", default=0.1, type=float,
help="dropout rate of ReLU and attention in adaptive model")
# Training settings.
parser.add_argument('--learner_no_mem', action='store_true',
help='if true then learner function doesnt use memory')
parser.add_argument('--debug', action='store_true',
help='set logging level to debug')
parser.add_argument("--atari", default=False, type=bool,
help="Whether to run atari (otherwise runs DMLab)")
parser.add_argument("--disable_checkpoint", action="store_true",
help="Disable saving checkpoint.")
parser.add_argument("--savedir", default="./logs/torchbeast",
help="Root dir where experiment data will be saved.")
parser.add_argument("--num_actors", default=32, type=int, metavar="N",
help="Number of actors (default: 4).")
parser.add_argument("--total_steps", default=100000, type=int, metavar="T",
help="Total environment steps to train for.")
parser.add_argument("--batch_size", default=16, type=int, metavar="B",
help="Learner batch size.")
parser.add_argument("--unroll_length", default=1000, type=int, metavar="T",
help="The unroll length (time dimension).")
parser.add_argument("--num_buffers", default=None, type=int,
metavar="N", help="Number of shared-memory buffers.")
parser.add_argument("--num_learner_threads", "--num_threads", default=4, type =int,
metavar="N", help="Number learner threads.")
parser.add_argument("--disable_cuda", action="store_true",
help="Disable CUDA.")
parser.add_argument("--chunk_size", default=100, type=int,
help="Size of chunks to chop batch into")
parser.add_argument("--mem_len", default=100, type=int,
help="Length of memory segment for TXL")
parser.add_argument('--use_pretrained', action='store_true',
help='use the pretrained model identified by --xpid')
parser.add_argument('--action_repeat', default=4, type=int,
help='number of times to repeat an action, default=4')
parser.add_argument('--stats_episodes', default=100, type=int,
help='report the mean episode returns of the last n episodes')
parser.add_argument("--use_lstm", action="store_true",
help="Use LSTM in agent model.")
# Loss settings.
parser.add_argument("--entropy_cost", default=0.01,
type=float, help="Entropy cost/multiplier.")
parser.add_argument("--baseline_cost", default=0.5,
type=float, help="Baseline cost/multiplier.")
parser.add_argument("--discounting", default=0.99,
type=float, help="Discounting factor.")
parser.add_argument("--reward_clipping", default="abs_one",
choices=["abs_one", "none"],
help="Reward clipping.")
# Optimizer settings.
parser.add_argument("--weight_decay", default=0.0,
type=float)
parser.add_argument("--learning_rate", default=0.0004,
type=float, metavar="LR", help="Learning rate.")
parser.add_argument("--alpha", default=0.99, type=float,
help="RMSProp smoothing constant.")
parser.add_argument("--momentum", default=0, type=float,
help="momentum for SGD or RMSProp")
parser.add_argument("--epsilon", default=0.01, type=float,
help="RMSProp epsilon.")
parser.add_argument("--grad_norm_clipping", default=40.0, type=float,
help="Global gradient norm clip.")
parser.add_argument('--optim', default='RMSProp', type=str,
                    choices=['adam', 'sgd', 'adagrad', 'RMSProp'],
                    help='optimizer to use.')
parser.add_argument('--scheduler', default='cosine', type=str,
choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant', 'torchLR','linear_decay'],
help='lr scheduler to use.')
parser.add_argument('--warmup_step', type=float, default=0,
                    help='number of linear warmup steps for the learning rate')
parser.add_argument('--steps_btw_sched_updates', type=int, default=10000,
help='number of steps between scheduler updates')
parser.add_argument('--decay_rate', type=float, default=0.5,
help='decay factor when ReduceLROnPlateau is used')
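# NOTE: get_scheduler()'s 'dev_perf' branch reads flags.patience, which is otherwise not
# defined in this file; the flag below is added so that option works. The default of 10
# is an assumption, not a value from the original configuration.
parser.add_argument('--patience', type=int, default=10,
                    help='patience (in scheduler updates) for ReduceLROnPlateau when --scheduler dev_perf is used')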
parser.add_argument('--lr_min', type=float, default=0.0,
help='minimum learning rate during annealing')
parser.add_argument('--static-loss-scale', type=float, default=1,
help='Static loss scale, positive power of 2 values can '
'improve fp16 convergence.')
parser.add_argument('--dynamic-loss-scale', action='store_true',
help='Use dynamic loss scaling. If supplied, this argument'
' supersedes --static-loss-scale.')
parser.add_argument('--eta_min', type=float, default=0.0,
help='min learning rate for cosine scheduler')
# yapf: enable
logging.basicConfig(
format=(
"[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s"
),
level=0,
)
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
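# Buffers maps each rollout field (e.g. "frame", "reward", "done") to a list of
# flags.num_buffers shared-memory tensors, each holding one rollout of length
# unroll_length + 1 (see create_buffers below).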
def compute_baseline_loss(advantages, padding_mask):
if padding_mask is not None:
advantages = advantages * padding_mask
return 0.5 * torch.sum(advantages ** 2)
# padding_mask has 0's wherever padding should mask the logits
def compute_entropy_loss(logits, padding_mask):
"""Return the entropy loss, i.e., the negative entropy of the policy."""
policy = F.softmax(logits, dim=-1)
log_policy = F.log_softmax(logits, dim=-1)
if padding_mask is not None:
log_policy = log_policy * padding_mask.unsqueeze(2)
return torch.sum(policy * log_policy)
def compute_policy_gradient_loss(logits, actions, advantages, padding_mask):
cross_entropy = F.nll_loss(
F.log_softmax(torch.flatten(logits, 0, 1), dim=-1),
target=torch.flatten(actions, 0, 1),
reduction="none",
)
cross_entropy = cross_entropy.view_as(advantages)
if padding_mask is not None:
cross_entropy = cross_entropy * padding_mask
return torch.sum(cross_entropy * advantages.detach())
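# Shape convention for the loss helpers above (a sketch; exact shapes come from the learner
# outputs): logits are [T, B, num_actions], actions and advantages are [T, B], and
# padding_mask is [T, B] with 1.0 at real steps and 0.0 at padded steps, so padded
# timesteps contribute nothing to the policy-gradient, baseline, or entropy terms.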
def act(
flags,
actor_index: int,
free_queue: mp.SimpleQueue,
full_queue: mp.SimpleQueue,
model: torch.nn.Module,
buffers: Buffers,
initial_agent_state_buffers,
level_name
):
try:
logging.info("Actor %i started.", actor_index)
timings = prof.Timings() # Keep track of how fast things are.
seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
# gym_env.seed(seed)
gym_env = create_env(flags=flags, seed=seed)
if flags.atari:
env = atari_environment.Environment(gym_env)
else:
#DMLAB CHANGES
env = dmlab_environment.Environment(gym_env)
env_output = env.initial()
env_output['done'] = torch.tensor([[0]], dtype=torch.uint8)
agent_state = model.initial_state(batch_size=1)
mems = None
if flags.use_adaptive:
mems = model.core.initial_cache(batch_size=1, device=None)
agent_output, unused_state, mems, pad_mask1, _ = model(env_output, agent_state, mems)
while True:
index = free_queue.get()
if index is None:
break
# explicitly make done False to allow the loop to run
# Don't need to set 'done' to true since now take step out of done state
# when do arrive at 'done'
# env_output['done'] = torch.tensor([0], dtype=torch.uint8)
# Write old rollout end.
for key in env_output:
buffers[key][index][0, ...] = env_output[key]
for key in agent_output:
buffers[key][index][0, ...] = agent_output[key]
for i, tensor in enumerate(agent_state):
initial_agent_state_buffers[index][i][...] = tensor
            # Do one new rollout, until flags.unroll_length
t = 0
logging.debug('STARTING UP ACTOR: %i', actor_index)
while t < flags.unroll_length and not env_output['done'].item():
# for t in range(flags.unroll_length):
timings.reset()
                # Removed since this will never be true here (moved to after the loop)
# if env_output['done'].item():
# mems = None
with torch.no_grad():
agent_output, agent_state, mems, pad_mask1, _ = model(env_output, agent_state, mems)
#if actor_index == 0:
# logging.debug('actor: t: {}, mems size: {}, mem_padding size: {}'.format(t, mems[0].shape, mem_padding))
timings.time("model")
# TODO : Check if this probability skipping can compromise granularity
# repeat_times = torch.randint(low=2, high=flags.action_repeat + 1, size=(1,)).item()
for el in range(flags.action_repeat):
env_output = env.step(agent_output["action"])
if env_output['done'].item():
break
timings.time("step")
for key in env_output:
buffers[key][index][t + 1, ...] = env_output[key]
for key in agent_output:
buffers[key][index][t + 1, ...] = agent_output[key]
timings.time("write")
t += 1
if env_output['done'].item():
#for key in env_output:
# buffers[key][index][t + 1, ...] = env_output[key]
#for key in agent_output:
# buffers[key][index][t + 1, ...] = agent_output[key]
mems = None
if flags.use_adaptive:
mems = model.core.initial_cache(batch_size=1, device=None)
# Take arbitrary step to reset environment
logging.debug('actor: {}, RETURN: {}'.format(actor_index, env_output['episode_return']))
env_output = env.step(torch.tensor([2]))
buffers['len_traj'][index][0] = t
if t != flags.unroll_length:
# TODO Is there a potential bug here
buffers['done'][index][t + 1:] = torch.tensor([True]).repeat(flags.unroll_length - t)
#logging.debug('Done rollout actor: %i', actor_index)
full_queue.put(index)
if actor_index == 0:
logging.info("Actor %i: %s", actor_index, timings.summary())
except KeyboardInterrupt:
pass # Return silently.
except Exception as e:
logging.error("Exception in worker process %i", actor_index)
traceback.print_exc()
# print()
raise e
def get_batch(
flags,
free_queue: mp.SimpleQueue,
full_queue: mp.SimpleQueue,
buffers: Buffers,
initial_agent_state_buffers,
timings,
lock=threading.Lock(),
):
logging.debug('STARTING GET_BATCH')
with lock:
timings.time("lock")
indices = [full_queue.get() for _ in range(flags.batch_size)]
        # TODO: Check if emptying full_queue and then re-adding to it takes very long;
        # it seems like the only way to ensure a batch of similar-length elements.
        # One problem with doing this is that if we get a really short trajectory, we may never
        # end up using it. DON'T CHANGE THIS FOR NOW.
timings.time("dequeue")
batch = {
key: torch.stack([buffers[key][m] for m in indices], dim=1) for key in buffers
}
initial_agent_state = (
torch.cat(ts, dim=1)
for ts in zip(*[initial_agent_state_buffers[m] for m in indices])
)
timings.time("batch")
for m in indices:
free_queue.put(m)
timings.time("enqueue")
batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()}
initial_agent_state = tuple(
t.to(device=flags.device, non_blocking=True) for t in initial_agent_state
)
timings.time("device")
logging.debug('Returned GetBATCH')
return batch, initial_agent_state
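# The returned batch stacks flags.batch_size rollouts along dim=1, so batch[key] has shape
# [unroll_length + 1, batch_size, ...]; learn() below slices it along the time dimension
# into chunks of flags.chunk_size, carrying transformer memory between chunks.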
def learn(
flags,
actor_model,
model,
batch,
initial_agent_state,
optimizer,
scheduler,
lock=threading.Lock(), # noqa: B008
):
"""Performs a learning (optimization) step."""
with lock:
"""
put a lock on the central learner,
send the trajectories to it.
Update the parameters of the central learner,
copy the parameters of the central learner back to the actors
"""
# TODO: Chop up batch into smaller pieces to run through TXL one at a time (caching previous as memory)
# TODO: Change batch function to look for trajectories of similar lengths
# TODO: Add in adaptive attention (and think of how things change (for ex no memory))
# print({key: batch[key].shape for key in batch})
mems = None
if flags.use_adaptive:
mems = model.core.initial_cache(batch_size=flags.batch_size, device=flags.device)
# initialize stats
stats = {
"episode_returns": list(),
"mean_episode_return": list(),
"total_loss": 0,
"pg_loss": 0,
"baseline_loss": 0,
"entropy_loss": 0,
"num_unpadded_steps": 0,
"len_max_traj": 0,
"learning_rate":optimizer.param_groups[0]['lr']
}
logging.debug('AT LEARN')
for i in range(0, flags.unroll_length + 1, flags.chunk_size):
mini_batch = {key: batch[key][i:i + flags.chunk_size] for key in batch if key != 'len_traj'}
            # Note that the initial agent state isn't used by the transformer (I think this is the LSTM hidden state).
            # This will need to change if we want to use this with an LSTM.
if mini_batch['done'].shape[0] != flags.chunk_size:
logging.debug('BREAKING WITH SHAPE : %s', mini_batch['done'].shape)
break #This would break around memory padding
#TODO Trim mini_batch if all dones at the end: If everything is done just continue here
# CAN DO THIS by looking at buffers['len_traj']
# For now just say that if more than half the minibatch is done, then continue
mini_batch_size = torch.prod(torch.tensor(mini_batch['done'].size())).item()
            if mini_batch['done'].sum().item() == mini_batch_size:  # alternative: > mini_batch_size / 2
                logging.debug('Breaking with all elements done')
break
#if mini_batch['done'].sum().item() > 0:
# print(mini_batch['done'])
# print('FOUND ONE')
logging.debug('MiniBatch shape: %s', mini_batch['done'].shape)
tmp_mask = torch.zeros_like(mini_batch["done"]).bool()
if flags.learner_no_mem:
mems = None
learner_outputs, unused_state, mems, curpad_mask, ind_first_done = model(mini_batch, initial_agent_state,
mems=mems)
if mini_batch['done'].any():
print('********Should see some return*********')
# www = time.time()
# torch.save(mini_batch['done'],'./'+str(www)+'mini_batch_done.pt')
# print("mini_batch['done'] true at ", www)
# torch.save(ind_first_done, './' + str(www) + 'ind_first_done.pt')
#to_print = False
#if mini_batch['done'].sum().item() > 0:
# print('INds done: ', ind_first_done)
# print('MEM PADDING AFTER: ', mem_padding)
# to_print = True
# Here mem_padding is same as "batch" padding for this iteration so can use
# for masking loss
# if mini_batch["done"].any().item():
# print('Indfirstdone: ',ind_first_done)
# print('miniBATCH DONE: ', mini_batch["done"])
# print('Mem padding: ', mem_padding)
# Take final value function slice for bootstrapping.
# this is the final value from this trajectory
if ind_first_done is not None:
# B dimensional tensor
bootstrap_value = learner_outputs["baseline"][ind_first_done, range(flags.batch_size)]
else:
bootstrap_value = learner_outputs["baseline"][-1]
# Move from obs[t] -> action[t] to action[t] -> obs[t].
mini_batch = {key: tensor[1:] for key, tensor in mini_batch.items()}
learner_outputs = {key: tensor[:-1] for key, tensor in learner_outputs.items()}
# Using learner_outputs to predict batch since batch is always one ahead of learner_outputs?
rewards = mini_batch["reward"]
if flags.reward_clipping == "abs_one":
clipped_rewards = torch.clamp(rewards, -1, 1)
elif flags.reward_clipping == "none":
clipped_rewards = rewards
discounts = (~mini_batch["done"]).float() * flags.discounting
vtrace_returns = vtrace.from_logits(
behavior_policy_logits=mini_batch["policy_logits"],
            target_policy_logits=learner_outputs["policy_logits"],  # the learner's current policy is the V-trace target; the behavior logits above come from the actors
actions=mini_batch["action"],
discounts=discounts,
rewards=clipped_rewards,
values=learner_outputs["baseline"],
bootstrap_value=bootstrap_value,
            ind_first_done=ind_first_done,  # the -1 to compensate for the one-step-shifted arrays
            # is handled inside from_importance_weights
)
# TODO Next Step: the losses also have to be computed with the padding, think on a structure of mask
# to do this efficiently
# Advantages are [rollout_len, batch_size]
# First we mask out vtrace_returns.pg_advantages where there is padding which fixes pg_loss
pad_mask = (~(curpad_mask.squeeze(0)[1:])).float() if curpad_mask is not None else None
#if to_print:
# print('AFTER WARDS 2 mem_padding: ', mem_padding)
# print('Pad_mask: ', pad_mask)
pg_loss = compute_policy_gradient_loss(
learner_outputs["policy_logits"],
mini_batch["action"],
vtrace_returns.pg_advantages,
pad_mask
)
baseline_loss = flags.baseline_cost * compute_baseline_loss(
vtrace_returns.vs - learner_outputs["baseline"],
pad_mask
)
entropy_loss = flags.entropy_cost * compute_entropy_loss(
learner_outputs["policy_logits"],
pad_mask
)
total_loss = pg_loss + baseline_loss + entropy_loss
            # Now add the L1 norm of the adaptive-span parameters (already multiplied
            # by the scaling coefficient, a chosen hyperparameter).
if flags.use_adaptive:
total_loss += model.core.get_adaptive_span_loss()
# tmp_mask is defined above
if ind_first_done is not None:
rows_to_use = []
cols_to_use = []
for i, val in enumerate(ind_first_done):
if val != -1:
rows_to_use.append(val)
cols_to_use.append(i)
tmp_mask[rows_to_use, cols_to_use] = True # NOT RIGHT FOR COLS THAT DIDNT FINISH
tmp_mask = tmp_mask[1:] # This is how they initially had it so will keep like this
# if mini_batch["done"].any().item():
# print('TMP MASK: ',tmp_mask)
# print('BATCH DONE: ', mini_batch["done"])
# print('shape1: {}, shape2: {}'.format(tmp_mask.shape, mini_batch['done'].shape))
# episode_returns = mini_batch["episode_return"][mini_batch["done"]]
episode_returns = mini_batch["episode_return"][tmp_mask]
num_unpadded_steps = (~curpad_mask).sum().item() if curpad_mask is not None else mini_batch_size
stats_per_chunk = {
"episode_returns": tuple(episode_returns.cpu().numpy()),
"mean_episode_return": torch.mean(episode_returns).item(),
"total_loss": total_loss.item(),
"pg_loss": pg_loss.item(),
"baseline_loss": baseline_loss.item(),
"entropy_loss": entropy_loss.item(),
"num_unpadded_steps": num_unpadded_steps,
"len_max_traj": batch['len_traj'].max().item()
}
logging.debug('in learn with stats_per_chunk : %s', str(stats_per_chunk))
optimizer.zero_grad()
total_loss.backward()
# append the current stats_per_chunk with overall stats
stats['episode_returns'].extend(tuple(episode_returns.cpu().numpy()))
            # NaN check (NaN != NaN): skip appending when episode_returns is empty and the mean is NaN
            if torch.mean(episode_returns).item() == torch.mean(episode_returns).item():
                stats["mean_episode_return"].append(torch.mean(episode_returns).item())
stats["total_loss"] += total_loss.item()
stats["pg_loss"] += pg_loss.item()
stats["baseline_loss"] += baseline_loss.item()
stats["entropy_loss"] += entropy_loss.item()
stats["num_unpadded_steps"] += num_unpadded_steps
nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)
optimizer.step()
# scheduler is being stepped in the lock of batch_and_learn itself
        # update len_max_traj separately since it doesn't depend on minibatches
stats["len_max_traj"] = batch['len_traj'].max().item()
# update the losses as the mean
total_num_minibatches = (flags.unroll_length + 1) // flags.chunk_size
stats["mean_episode_return"] = sum(stats["mean_episode_return"]) / total_num_minibatches
stats["total_loss"] /= total_num_minibatches
stats["pg_loss"] /= total_num_minibatches
stats["baseline_loss"] /= total_num_minibatches
stats["entropy_loss"] /= total_num_minibatches
actor_model.load_state_dict(model.state_dict())
return stats
def create_buffers(flags, obs_shape, num_actions) -> Buffers:
T = flags.unroll_length
specs = dict(
frame=dict(size=(T + 1, *obs_shape), dtype=torch.uint8),
reward=dict(size=(T + 1,), dtype=torch.float32),
done=dict(size=(T + 1,), dtype=torch.bool),
episode_return=dict(size=(T + 1,), dtype=torch.float32),
episode_step=dict(size=(T + 1,), dtype=torch.int32),
policy_logits=dict(size=(T + 1, num_actions), dtype=torch.float32),
baseline=dict(size=(T + 1,), dtype=torch.float32),
last_action=dict(size=(T + 1,), dtype=torch.int64),
action=dict(size=(T + 1,), dtype=torch.int64),
len_traj=dict(size=(1,), dtype=torch.int32) # is min(length til trajectory is done, T)
)
buffers: Buffers = {key: [] for key in specs}
for _ in range(flags.num_buffers):
for key in buffers:
buffers[key].append(torch.zeros(**specs[key]).share_memory_())
return buffers
def get_optimizer(flags, parameters):
optimizer = None
if flags.optim.lower() == 'sgd':
optimizer = torch.optim.SGD(parameters, lr=flags.learning_rate, momentum=flags.momentum)
elif flags.optim.lower() == 'adam':
optimizer = torch.optim.Adam(parameters, lr=flags.learning_rate)
elif flags.optim.lower() == 'adagrad':
optimizer = torch.optim.Adagrad(parameters, lr=flags.learning_rate)
return optimizer
def get_scheduler(flags, optimizer):
scheduler = None
if flags.scheduler == 'cosine':
        # here we do not set eta_min to lr_min to stay backward compatible,
        # because in previous versions eta_min defaulted to 0
        # rather than to the default value of lr_min (1e-6)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
(flags.total_steps-flags.warmup_step) // flags.steps_btw_sched_updates,
eta_min=flags.eta_min)
elif flags.scheduler == 'inv_sqrt':
# originally used for Transformer (in Attention is all you need)
def lr_lambda(step):
# return a multiplier instead of a learning rate
if step == 0 and flags.warmup_step == 0:
return 1.
else:
return 1. / (step ** 0.5) if step > flags.warmup_step \
else step / (flags.warmup_step ** 1.5)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
elif flags.scheduler == 'dev_perf':
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
factor=flags.decay_rate, patience=flags.patience,
min_lr=flags.lr_min)
elif flags.scheduler == 'constant':
pass
return scheduler
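# Note: 'constant' (and the 'torchLR' / 'linear_decay' choices) fall through and return None
# here; train() then substitutes a LambdaLR with a linear decay over total_steps, while
# 'linear_decay' itself is applied manually inside batch_and_learn().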
def train(flags): # pylint: disable=too-many-branches, too-many-statements
if flags.debug:
logging.root.setLevel(level=logging.DEBUG)
else:
logging.root.setLevel(level=logging.INFO)
logging.debug('First Debug message')
# load the previous config if use_pretrained is true
if flags.use_pretrained:
logging.info('Using Pretrained Model')
#TODO Check if this loading below works properly
class Bunch(object):
def __init__(self, adict):
self.__dict__.update(adict)
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'logs/torchbeast/' + flags.xpid + '/model.tar')
        pretrained_model = torch.load(model_path, map_location='cpu' if flags.disable_cuda else 'cuda')
flags = Bunch(pretrained_model['flags'])
flags.use_pretrained = True
if flags.xpid is None:
flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
plogger = file_writer.FileWriter(
xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
)
checkpointpath = os.path.expandvars(
os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
)
if flags.num_buffers is None: # Set sensible default for num_buffers.
flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
if flags.num_actors >= flags.num_buffers:
raise ValueError("num_buffers should be larger than num_actors")
if flags.num_buffers < flags.batch_size:
raise ValueError("num_buffers should be larger than batch_size")
T = flags.unroll_length
B = flags.batch_size
flags.device = None
if not flags.disable_cuda and torch.cuda.is_available():
logging.info("Using CUDA.")
flags.device = torch.device("cuda")
else:
logging.info("Not using CUDA.")
flags.device = torch.device("cpu")
env = create_env(flags)
if flags.atari:
"""model is each of the actors, running parallel. The upcoming block ctx.Process(...)"""
model = Net(env.observation_space.shape, env.action_space.n, flags=flags)
buffers = create_buffers(flags, env.observation_space.shape, model.num_actions)
else:
# DMLAB CHANGES
"""model is each of the actors, running parallel. The upcoming block ctx.Process(...)"""
model = Net(env.initial().shape, len(dmlab_environment.DEFAULT_ACTION_SET), flags=flags)
buffers = create_buffers(flags, env._observation().shape, model.num_actions)
model.share_memory()
# Add initial RNN state.
initial_agent_state_buffers = []
for _ in range(flags.num_buffers):
state = model.initial_state(batch_size=1)
for t in state:
t.share_memory_()
initial_agent_state_buffers.append(state)
actor_processes = []
ctx = mp.get_context("fork")
free_queue = ctx.SimpleQueue()
full_queue = ctx.SimpleQueue()
for i in range(flags.num_actors):
actor = ctx.Process(
target=act,
args=(
flags,
i,
free_queue,
full_queue,
model,
buffers,
initial_agent_state_buffers,
flags.level_name
),
)
actor.start()
actor_processes.append(actor)
"""learner_model is the central learner, which takes in the experiences and updates itself"""
if flags.atari:
learner_model = Net(
env.observation_space.shape, env.action_space.n, flags=flags).to(device=flags.device)
else:
# DMLAB CHANGES
learner_model = Net(
env._observation().shape, len(dmlab_environment.DEFAULT_ACTION_SET), flags=flags).to(device=flags.device)
# DMLAB CHANGES END
print('--------------- TOTAL MODEL PARAMETERS : {} ---------------'.format(get_model_parameters(learner_model)))
optimizer = get_optimizer(flags, learner_model.parameters())
if optimizer is None:
# Use the default optimizer used in monobeast
optimizer = torch.optim.RMSprop(
learner_model.parameters(),
lr=flags.learning_rate,
momentum=flags.momentum,
eps=flags.epsilon,
alpha=flags.alpha,
weight_decay=flags.weight_decay
)
def lr_lambda(epoch):
return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps
scheduler = get_scheduler(flags, optimizer)
if scheduler is None:
# use the default scheduler as used in monobeast
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
last_n_episode_return_key = "last_{}_episode_returns".format(flags.stats_episodes)
logger = logging.getLogger("logfile")
stat_keys = [
"total_loss",
"mean_episode_return",
last_n_episode_return_key,
"max_return_achieved",
"pg_loss",
"baseline_loss",
"entropy_loss",
"learning_rate",
]
logger.info("# Step\t%s", "\t".join(stat_keys))
step, stats = 0, {}
last_n_episode_returns = torch.zeros((flags.stats_episodes))
steps_since_sched_update = 0
if flags.use_pretrained:
logging.info('Using Pretrained Model -> loading learner_model, optimizer, scheduler states')
learner_model.load_state_dict(pretrained_model['model_state_dict'])
optimizer.load_state_dict(pretrained_model['optimizer_state_dict'])
scheduler.load_state_dict(pretrained_model['scheduler_state_dict'])
def batch_and_learn(i, lock=threading.Lock()):
"""Thread target for the learning process."""
nonlocal step, stats, steps_since_sched_update, last_n_episode_returns
        # TODO : last_n_episode_returns and curr_index will be incorrect if you use more than one learner thread, keep that in mind
curr_index = -1
max_return = -1e5
max_return_step = 0
timings = prof.Timings()
while step < flags.total_steps:
timings.reset()
zz1 = time.time()
batch, agent_state = get_batch(
flags,
free_queue,
full_queue,
buffers,
initial_agent_state_buffers,
timings,
)
zz2 = time.time()
logging.debug('Before Learn')
stats = learn(
flags, model, learner_model, batch, agent_state, optimizer, scheduler
)
logging.debug('After Learn')
logging.debug('stats: %s ', stats)
timings.time("learn")
with lock:
# step-wise learning rate annealing
if flags.scheduler in ['cosine', 'constant', 'dev_perf','linear_decay']:
# linear warmup stage
if step < flags.warmup_step:
curr_lr = flags.learning_rate * step / flags.warmup_step
optimizer.param_groups[0]['lr'] = curr_lr
elif flags.scheduler == 'cosine':
                        # TODO: Right now the number of scheduler updates depends on T and B, which isn't ideal.
                        # It would be better to step based on the number of non-padded entries in the padding mask.
                        # We could also make when we take a step conditional on the step count
                        # (e.g. step every 10000 steps or so).
if steps_since_sched_update >= flags.steps_btw_sched_updates:
scheduler.step()
steps_since_sched_update = 0
elif flags.scheduler == 'linear_decay':
#print('LR before: ', optimizer.param_groups[0]['lr'])
multiplier = 1-min(step,flags.total_steps)/flags.total_steps
optimizer.param_groups[0]['lr'] = flags.learning_rate * multiplier
#print('LR AFTER : ',optimizer.param_groups[0]['lr'])
elif flags.scheduler == 'inv_sqrt':
scheduler.step()
episode_returns = stats.get("episode_returns", None)
if episode_returns:
for el in episode_returns:
last_n_episode_returns[(curr_index + 1) % flags.stats_episodes] = el.item()
curr_index += 1
if el.item() >= max_return:
max_return = el.item()
max_return_step = step
stats.update({last_n_episode_return_key: last_n_episode_returns.mean().item()})
stats.update({'max_return_achieved':'{} at step {}'.format(max_return, max_return_step)})
to_log = dict(step=step)
to_log.update({k: stats.get(k, None) for k in stat_keys})
# Now keep track of the max span per layer and log them in the csv file if adaptive is enabled
if flags.use_adaptive:
# Get max span per layer in learner_model
max_spans = []
for layer in learner_model.core.layers:
max_spans.append(layer.attn.attn.adaptive_span._mask.get_current_max_size())
print('MAX SPANS : ', max_spans)
                    # use a separate loop variable so the thread index `i` is not shadowed
                    for layer_idx, span_val in enumerate(max_spans):
                        to_log.update({'max_span_layer_' + str(layer_idx): span_val})
plogger.log(to_log)
# print('updating step from {} to {}'.format(step, step+(T*B)))
if len(stats) > 0:
step += stats['num_unpadded_steps'] #stats.get('num_unpadded_steps', 0) #T * B
steps_since_sched_update += stats['num_unpadded_steps'] #.get('num_unpadded_steps', 0)
print('act took : ',zz2-zz1,' learn took : ',time.time()-zz2)
if i == 0:
logging.info("Batch and learn: %s", timings.summary())
for m in range(flags.num_buffers):
free_queue.put(m)
threads = []
for i in range(flags.num_learner_threads):
thread = threading.Thread(
target=batch_and_learn, name="batch-and-learn-%d" % i, args=(i,)
)
thread.start()
threads.append(thread)
    logging.debug('FINISHED starting batch_and_learn threads')
def checkpoint():
if flags.disable_checkpoint:
return
logging.info("Saving checkpoint to %s", checkpointpath)
torch.save(
{
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"flags": vars(flags),
},
checkpointpath,
)
timer = timeit.default_timer
try:
last_checkpoint_time = timer()
        logging.debug('initialized stats_episodes')
while step < flags.total_steps:
start_step = step
start_time = timer()
time.sleep(flags.sleep_length)
if timer() - last_checkpoint_time > 10 * 60: # Save every 10 min.
checkpoint()
last_checkpoint_time = timer()
sps = (step - start_step) / (timer() - start_time)
episode_returns = stats.get("episode_returns", None)
if episode_returns:
mean_return = (
"Return per episode: %.1f. " % stats["mean_episode_return"]
)
# print(episode_returns)
# print(type(episode_returns[0]))
# torch.save(episode_returns, './ep_return.pt')
else:
mean_return = ""
total_loss = stats.get("total_loss", float("inf"))
# TODO : We also should save the model if the loss is the best loss seen so far
            # TODO : call checkpoint() here with some different prefix
# if not best_val_loss or val_loss < best_val_loss:
# if not args.debug:
# with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f:
# torch.save(model, f)
# with open(os.path.join(args.work_dir, 'optimizer.pt'), 'wb') as f:
# torch.save(optimizer.state_dict(), f)
# best_val_loss = val_loss
logging.info(
"Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s",
step,
sps,
total_loss,
mean_return,
pprint.pformat(stats),
)
except KeyboardInterrupt:
return # Try joining actors then quit.
else:
for thread in threads:
thread.join()
logging.info("Learning finished after %d steps.", step)
finally:
for _ in range(flags.num_actors):
free_queue.put(None)
for actor in actor_processes:
actor.join(timeout=1)
checkpoint()
plogger.close()
def test(flags, num_episodes: int = 10):
if flags.xpid is None:
checkpointpath = "./latest/model.tar"
else:
checkpointpath = os.path.expandvars(
os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
)
gym_env = create_env(flags)
if flags.atari:
env = atari_environment.Environment(gym_env)
else:
#DMLAB CHANGES
env = dmlab_environment.Environment(gym_env)
if flags.atari:
model = Net(env.gym_env.observation_space.shape, env.action_space.n, flags=flags)
else:
model = Net(env.initial().shape, len(dmlab_environment.DEFAULT_ACTION_SET), flags=flags)
model.eval()
checkpoint = torch.load(checkpointpath, map_location="cpu")