diff --git a/train.py b/train.py index 9709f90..cefb2b5 100644 --- a/train.py +++ b/train.py @@ -444,7 +444,8 @@ def learn( learner_outputs, unused_state, mems, curpad_mask, ind_first_done = model(mini_batch, initial_agent_state, mems=mems) - # if mini_batch['done'].any(): + if mini_batch['done'].any(): + print('********Should see some return*********') # www = time.time() # torch.save(mini_batch['done'],'./'+str(www)+'mini_batch_done.pt') # print("mini_batch['done'] true at ", www)