Skip to content

Commit

Permalink
Now only create gated modules if necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
jerrodparker20 committed Apr 2, 2020
1 parent 42815da commit 19c9ffb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
7 changes: 5 additions & 2 deletions StableTransformersReplication/transformer_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,11 @@ def __init__(self, n_head, d_model, d_head, d_inner, dropout,use_gate,use_stable

self.use_gate = use_gate
self.use_stable_version = use_stable_version
self.gate_mha = GRUGate(d_model)
self.gate_mlp = GRUGate(d_model)

if self.use_gate:
self.gate_mha = GRUGate(d_model)
self.gate_mlp = GRUGate(d_model)

self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)

Expand Down
6 changes: 4 additions & 2 deletions adaptive_span2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,10 @@ def __init__(self, hidden_size, flags, **kargs):

self.use_gate = flags.use_gate
self.use_stable_version = True #use_stable_version
self.gate_mha = GRUGate(hidden_size)
self.gate_mlp = GRUGate(hidden_size)

if self.use_gate:
self.gate_mha = GRUGate(hidden_size)
self.gate_mlp = GRUGate(hidden_size)

def forward_orig(self, h, h_cache, key_pe):

Expand Down

0 comments on commit 19c9ffb

Please sign in to comment.