
Commit 7100c10

No need for masking pads in attention scores since we do not care about the pad embedding.

Any non-padded token only attends backwards anyway, which means it won't attend to padding.

jerrodparker20 committed Mar 31, 2020
1 parent 80d1d42 commit 7100c10
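
The claim above can be checked with a small standalone sketch (not from the repository; the sizes, the right-padding assumption, and the assumption that the memory slots contain no padding are illustrative). With the strictly causal mask that _forward builds via torch.triu(..., diagonal=1 + mlen), a query at a real (non-padded) position only sees keys at or before its own position, and right padding always sits after the last real token, so no real query ever attends to a pad:

# Minimal sketch (illustrative sizes; assumes right padding and a pad-free memory):
# the causal mask used in _forward already keeps real tokens from attending to pads.
import torch

qlen, mlen = 4, 2                 # current segment length, memory length
klen = qlen + mlen
seq_lens = [4, 2]                 # 2nd sequence: 2 real tokens followed by 2 pads

# Same construction as dec_attn_mask in _forward (True means "do not attend").
causal_mask = torch.triu(torch.ones(qlen, klen), diagonal=1 + mlen).bool()

for length in seq_lens:
    pad_keys = set(range(mlen + length, klen))     # pads are right-aligned
    for i in range(length):                        # only real (non-pad) queries
        visible = {j for j in range(klen) if not causal_mask[i, j]}
        assert visible.isdisjoint(pad_keys)
print("no real token attends to a padded position")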
Showing 2 changed files with 14 additions and 55 deletions.
49 changes: 5 additions & 44 deletions StableTransformersReplication/transformer_xl.py
@@ -437,7 +437,7 @@ def _update_mems(self, hids, mems, qlen, mlen):
# TODO : We dropped dec_input since the first 2 dims of obs_emb should be the same as
# that of dec_input, which is unrolled length = query length and batch_size
# we saw this from core_input = core_input.view(T, B, -1) line 668 in monobeast_test.py
def _forward(self, obs_emb, padding_mask, mems=None):
def _forward(self, obs_emb, mems=None):

qlen, bsz, _ = obs_emb.size() #qlen is number of characters in input ex

@@ -454,23 +454,6 @@ def _forward(self, obs_emb, padding_mask, mems=None):
dec_attn_mask = torch.triu(
obs_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]

# TODO: Possibly make this more efficient (check how much things slow down)
# This part only runs when calling model in "learn" since in "act" we will
# never need padding
if not (padding_mask is None):# and padding_mask.sum().item() > 0:
# concat the memory padding along with the padding_mask
#print('IN TXL')
#print('PADDING BEFORE: ', dec_attn_mask[:,:,0])
dec_attn_mask = dec_attn_mask.repeat(1,1,bsz)
dec_attn_mask = dec_attn_mask | padding_mask
#print('PADDING AFTER', [dec_attn_mask[:,:,x] for x in range(bsz)])
#print('Dec_attn_mask: ', dec_attn_mask[:,:,0])
#want the mlen diagonal to be 0's so that each query can attend
#to itself
dec_attn_mask[range(qlen), range(mlen, klen), :] = False
#print('AFTER: ', [dec_attn_mask[:,:,x] for x in range(bsz)])
#print('ATTN SHAPE: ', dec_attn_mask.shape)

hids = []
pos_seq = torch.arange(klen-1, -1, -1.0, device=obs_emb.device,
dtype=obs_emb.dtype)
@@ -502,37 +485,15 @@ def _forward(self, obs_emb, padding_mask, mems=None):
return core_out, new_mems


def forward(self, data, mems, padding_mask, mem_padding):
#padding_mask should be shape 1 X (mlen+qlen) X batch_size,
#which we apply row wise
def forward(self, data, mems):

if not mems:
# print('INITIALIZED MEMS')
mems = self.init_mems()

#Concatenate mem_padding and padding_mask (slight modifications if None)
padding_mask2 = mem_padding
#print('mem_padding', mem_padding.shape if mem_padding is not None else None)
#print('PADDING MASK: ', padding_mask.shape if padding_mask is not None else None)
if padding_mask2 is None:
padding_mask2 = padding_mask
elif padding_mask is not None:
padding_mask2 = torch.cat([mem_padding, padding_mask], dim=1)

'''
if mem_padding is not None and padding_mask is not None:
print('Adding orig: ', padding_mask[:,:,0])
print('mem_padding: ', mem_padding[:,:,0])
print('Result: ', padding_mask2[:,:,0])
print('DATA shape: ', data.shape)
print('mems shape: ', mems[0].shape)
'''
hidden, new_mems = self._forward(data, padding_mask=padding_mask2, mems=mems)

if padding_mask2 is not None:
padding_mask2 = padding_mask2[:,-self.mem_len:,:] #will me memory_padding at next iteration.

return hidden, new_mems, padding_mask2
hidden, new_mems = self._forward(data, mems=mems)

return hidden, new_mems


if __name__ == '__main__':
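For reference, the branch deleted from _forward above merged a per-batch padding mask into the causal mask and then re-opened the mlen-offset diagonal so each query could still attend to itself. A self-contained sketch of that (now removed) computation, with illustrative sizes and mask shapes taken from the diff; the commit drops it on the grounds stated in the message above:

# Sketch of the removed mask merging (illustrative sizes; True means "do not attend").
import torch

qlen, mlen, bsz = 3, 2, 2
klen = qlen + mlen

dec_attn_mask = torch.triu(torch.ones(qlen, klen), diagonal=1 + mlen).bool()[:, :, None]
padding_mask = torch.zeros(1, klen, bsz, dtype=torch.bool)   # 1 x (mlen+qlen) x batch_size
padding_mask[0, -1, 1] = True               # last key of the 2nd sequence is a pad

merged = dec_attn_mask.repeat(1, 1, bsz) | padding_mask      # broadcast OR over queries
merged[range(qlen), range(mlen, klen), :] = False            # re-open each query's own position
print(merged[:, :, 1])                      # combined mask for the padded sequence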
20 changes: 9 additions & 11 deletions train.py
@@ -224,7 +224,7 @@ def act(

agent_state = model.initial_state(batch_size=1)
mems, mem_padding = None, None
agent_output, unused_state, mems, mem_padding, pad_mask1, _ = model(env_output, agent_state, mems, mem_padding)
agent_output, unused_state, mems, pad_mask1, _ = model(env_output, agent_state, mems)
while True:
index = free_queue.get()
if index is None:
@@ -254,7 +254,7 @@ def act(
# mems = None

with torch.no_grad():
agent_output, agent_state, mems, mem_padding, pad_mask1, _ = model(env_output, agent_state, mems, mem_padding)
agent_output, agent_state, mems, pad_mask1, _ = model(env_output, agent_state, mems)
#if actor_index == 0:
# logging.debug('actor: t: {}, mems size: {}, mem_padding size: {}'.format(t, mems[0].shape, mem_padding))
timings.time("model")
@@ -413,10 +413,9 @@ def learn(

if flags.learner_no_mem:
mems = None
mem_padding = None

learner_outputs, unused_state, mems, mem_padding, curpad_mask, ind_first_done = model(mini_batch, initial_agent_state,
mems=mems, mem_padding=mem_padding)
learner_outputs, unused_state, mems, curpad_mask, ind_first_done = model(mini_batch, initial_agent_state,
mems=mems)
# if mini_batch['done'].any():
# www = time.time()
# torch.save(mini_batch['done'],'./'+str(www)+'mini_batch_done.pt')
@@ -963,8 +962,8 @@ def test(flags, num_episodes: int = 10):
if flags.mode == "test_render":
env.gym_env.render()

agent_outputs, core_state, mems, mem_padding, _, ind_first_done = model(observation, mems=mems,
mem_padding=mem_padding)
agent_outputs, core_state, mems, _, ind_first_done = model(observation, mems=mems)

observation = env.step(agent_outputs["action"])
if observation["done"].item():
returns.append(observation["episode_return"].item())
@@ -1121,7 +1120,7 @@ def initial_state(self, batch_size):
for _ in range(2)
)

def forward(self, inputs, core_state=(), mems=None, mem_padding=None):
def forward(self, inputs, core_state=(), mems=None):

x = inputs["frame"]
T, B, *_ = x.shape
@@ -1179,8 +1178,7 @@ def forward(self, inputs, core_state=(), mems=None, mem_padding=None):
# print('before mem mask: ', mem_padding.squeeze())

#Mem_pad_mask is the memory_mask to use at the next iteration
core_output, mems, mem_pad_mask = self.core(core_input, mems, padding_mask=padding_mask,
mem_padding=mem_padding) # core_input is of shape (T, B, ...)
core_output, mems = self.core(core_input, mems) # core_input is of shape (T, B, ...)
# core_output is (B, ...)
#if mem_padding is not None:
# print('padding_mask AFTER : ', padding_mask.squeeze())
@@ -1219,7 +1217,7 @@ def forward(self, inputs, core_state=(), mems=None, mem_padding=None):

return (
dict(policy_logits=policy_logits, baseline=baseline, action=action),
core_state, mems, mem_pad_mask, padding_mask, ind_first_done
core_state, mems, padding_mask, ind_first_done
)


