
Commit

Merge pull request #6 from jerrodparker20/gcpDebug
Gcp debug
shaktikshri authored Mar 31, 2020
2 parents 677f27e + 7100c10 commit 3d82324
Showing 24 changed files with 4,336 additions and 181 deletions.
1 change: 0 additions & 1 deletion Comparative Analysis of DRL and ERL.gslides

This file was deleted.

1 change: 0 additions & 1 deletion LearningLatentBehaviors.gdoc

This file was deleted.

1 change: 0 additions & 1 deletion Model/logs/torchbeast/example/fields.csv

This file was deleted.

50 changes: 0 additions & 50 deletions Model/logs/torchbeast/example/logs.csv

This file was deleted.

79 changes: 0 additions & 79 deletions Model/logs/torchbeast/example/meta.json

This file was deleted.

6 changes: 0 additions & 6 deletions Model/logs/torchbeast/example/out.log

This file was deleted.

1 change: 0 additions & 1 deletion Model/logs/torchbeast/latest

This file was deleted.

1 change: 0 additions & 1 deletion Presentation Outline.gdoc

This file was deleted.

1 change: 0 additions & 1 deletion ProjectImplementationDetails.gdoc

This file was deleted.

1 change: 0 additions & 1 deletion README.md

This file was deleted.

49 changes: 13 additions & 36 deletions StableTransformersReplication/transformer_xl.py
@@ -333,7 +333,7 @@ class MemTransformerLM(nn.Module):
     def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                  dropout, dropatt, tie_weight=True, d_embed=None,
                  div_val=1,
-                 tgt_len=None, ext_len=None, mem_len=1,
+                 tgt_len=None, ext_len=0, mem_len=1,
                  cutoffs=[], adapt_inp=False,
                  same_length=False, clamp_len=-1,
                  use_gate=True, use_stable_version=True):
@@ -354,7 +354,7 @@ def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
         self.tgt_len = tgt_len
         self.mem_len = mem_len
         self.ext_len = ext_len
-        self.max_klen = tgt_len + ext_len + mem_len
+        #self.max_klen = tgt_len + ext_len + mem_len

         self.layers = nn.ModuleList()

@@ -419,8 +419,13 @@ def _update_mems(self, hids, mems, qlen, mlen):
         end_idx = mlen + max(0, qlen - 0 - self.ext_len) # ext_len looks to usually be 0 (in their experiments anyways

         # TODO: I have changed beg_idx to 0 since want to use all memory, may want to change
-        # this once move to larger environments
-        beg_idx = 0 #max(0, end_idx - self.mem_len)
+        # this once move to larger environments (THIS HAS NOW BEEN CHANGED)
+
+        #HERE IS THE PROBLEM.
+        #print('hids shape: ', hids[0].shape)
+
+        beg_idx = max(0, end_idx - self.mem_len) #if hids[0].shape[0] > 1 else 0
+        #print('BEG IND: ', beg_idx)
         for i in range(len(hids)):

             cat = torch.cat([mems[i], hids[i]], dim=0)
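For quick reference, the hunk above reverts `beg_idx` from a hard-coded 0 back to a sliding window over the last `mem_len` steps. Below is a minimal, self-contained sketch of that window arithmetic written for this note; it is not repository code, and the layer count, `mem_len`, and tensor shapes are arbitrary:

```python
import torch

def update_mems_sketch(hids, mems, qlen, mlen, mem_len, ext_len=0):
    # hids/mems: one [length, batch, dim] tensor per layer.
    with torch.no_grad():
        end_idx = mlen + max(0, qlen - 0 - ext_len)
        beg_idx = max(0, end_idx - mem_len)            # keep only the newest mem_len steps
        new_mems = []
        for h, m in zip(hids, mems):
            cat = torch.cat([m, h], dim=0)             # [mlen + qlen, batch, dim]
            new_mems.append(cat[beg_idx:end_idx].detach())
        return new_mems

# Illustrative shapes: 2 layers, 6 cached steps, 4 fresh steps, memory capped at 6.
qlen, mlen, bsz, dim = 4, 6, 2, 8
hids = [torch.randn(qlen, bsz, dim) for _ in range(2)]
mems = [torch.randn(mlen, bsz, dim) for _ in range(2)]
print(update_mems_sketch(hids, mems, qlen, mlen, mem_len=6)[0].shape)
# torch.Size([6, 2, 8]) -- the window slid forward by qlen steps
```

With the old `beg_idx = 0`, the slice was effectively `cat[0:end_idx]`, so the cached memory grew by `qlen` on every call instead of staying capped at `mem_len`.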
@@ -432,7 +437,7 @@ def _update_mems(self, hids, mems, qlen, mlen):
     # TODO : We dropped dec_input since the first 2 dims of obs_emb should be the same as
     # that of dec_input, which is unrolled length = query length and batch_size
     # we saw this from core_input = core_input.view(T, B, -1) line 668 in monobeast_test.py
-    def _forward(self, obs_emb, padding_mask, mems=None):
+    def _forward(self, obs_emb, mems=None):

         qlen, bsz, _ = obs_emb.size() #qlen is number of characters in input ex

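The TODO above relies on the learner flattening each observation to `core_input.view(T, B, -1)` (cited from monobeast_test.py line 668), so `obs_emb` arrives as (unroll length, batch, features). A toy sketch of that layout, using invented shapes rather than the repository's actual observation size:

```python
import torch

T, B = 4, 2                              # unroll length and batch size (illustrative)
frames = torch.randn(T, B, 3, 8, 8)      # e.g. a stack of small image observations
core_input = frames.view(T, B, -1)       # flatten everything after the first two dims
print(core_input.shape)                  # torch.Size([4, 2, 192]) -- matches (qlen, bsz, _)
```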
@@ -449,20 +454,6 @@ def _forward(self, obs_emb, padding_mask, mems=None):
         dec_attn_mask = torch.triu(
             obs_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]

-        # TODO: Possibly make this more efficient (check how much things slow down)
-        # This part only runs when calling model in "learn" since in "act" we will
-        # never need padding
-        if not (padding_mask is None):
-            # concat the memory padding along with the padding_mask
-            dec_attn_mask = dec_attn_mask.repeat(1,1,bsz)
-            dec_attn_mask = dec_attn_mask | padding_mask
-            #print('Dec_attn_mask: ', dec_attn_mask[:,:,0])
-            #want the mlen diagonal to be 0's so that each query can attend
-            #to itself
-            dec_attn_mask[range(qlen), range(mlen, klen), :] = False
-            #print('AFTER: ', dec_attn_mask[:,:,0])
-            #print('ATTN SHAPE: ', dec_attn_mask.shape)
-
         hids = []
         pos_seq = torch.arange(klen-1, -1, -1.0, device=obs_emb.device,
                                dtype=obs_emb.dtype)
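As an aside, a tiny illustration (not repository code) of what the `torch.triu(..., diagonal=1+mlen)` call above produces: with the banned region shifted right by `mlen`, each query row can attend to every cached memory slot and to current-segment positions up to itself. The `qlen`/`mlen` values are chosen only so the printed mask stays readable:

```python
import torch

qlen, mlen = 3, 2
klen = qlen + mlen
# True (1) = masked out. Row i of the current segment may attend to the mlen memory
# columns and to current columns j with j - mlen <= i.
dec_attn_mask = torch.triu(torch.ones(qlen, klen), diagonal=1 + mlen).bool()[:, :, None]
print(dec_attn_mask[:, :, 0].int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)
```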
@@ … @@


        core_out = self.drop(core_out)

        #print('before update mems hids shape: {}, mems shape {}'.format(hids[0].shape,mems[0].shape if mems else None))
        new_mems = self._update_mems(hids, mems, mlen, qlen)

        return core_out, new_mems


-    def forward(self, data, mems, padding_mask, mem_padding):
-        #padding_mask should be shape 1 X (mlen+qlen) X batch_size,
-        #which we apply row wise
+    def forward(self, data, mems):

         if not mems:
             # print('INITIALIZED MEMS')
             mems = self.init_mems()

-        #Concatenate mem_padding and padding_mask (slight modifications if None)
-        padding_mask2 = mem_padding
-        if padding_mask2 is None:
-            padding_mask2 = padding_mask
-        elif padding_mask is not None:
-            padding_mask2 = torch.cat([mem_padding, padding_mask], dim=1)
-
-        if mem_padding is not None and padding_mask is not None:
-            print('Adding orig: ', padding_mask[:,:,0])
-            print('mem_padding: ', mem_padding[:,:,0])
-            print('Result: ', padding_mask2[:,:,0])

-        hidden, new_mems = self._forward(data, padding_mask=padding_mask2, mems=mems)
+        hidden, new_mems = self._forward(data, mems=mems)

         return hidden, new_mems

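To round out the simplified `forward(self, data, mems)` contract above, here is a hedged sketch of how a caller would thread the recurrent memory through successive unrolled segments. `ToyMemModel` is a stand-in invented for this note that only mimics the (data, mems) -> (hidden, new_mems) interface; none of its internals come from the repository:

```python
import torch
import torch.nn as nn

class ToyMemModel(nn.Module):
    """Minimal stand-in with the same call contract as the simplified MemTransformerLM.forward."""
    def __init__(self, d_model=8, mem_len=6, n_layer=2):
        super().__init__()
        self.mem_len, self.n_layer = mem_len, n_layer
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, data, mems):
        if not mems:  # empty list on the first call, mirroring init_mems()
            mems = [data.new_zeros(0, data.size(1), data.size(2)) for _ in range(self.n_layer)]
        hidden = self.proj(data)
        # Slide each layer's memory window forward, keeping the newest mem_len steps.
        new_mems = [torch.cat([m, hidden], dim=0)[-self.mem_len:].detach() for m in mems]
        return hidden, new_mems

model = ToyMemModel()
mems = []                                   # no memory yet
for _ in range(3):                          # three consecutive unrolled segments
    obs_emb = torch.randn(4, 2, 8)          # [unroll length, batch, d_model], illustrative
    hidden, mems = model(obs_emb, mems)
print(hidden.shape, mems[0].shape)          # torch.Size([4, 2, 8]) torch.Size([6, 2, 8])
```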
1 change: 0 additions & 1 deletion Transformer-XLCode/transformer-xl
Submodule transformer-xl deleted from 9ba2e2
1 change: 0 additions & 1 deletion hyperparameter search.gsheet

This file was deleted.

