From 19c9ffb451288afc06bda8236f7fa104cc88f7d3 Mon Sep 17 00:00:00 2001 From: jerrodparker20 Date: Thu, 2 Apr 2020 06:56:17 -0600 Subject: [PATCH] Now only create gated modules if necessary --- StableTransformersReplication/transformer_xl.py | 7 +++++-- adaptive_span2/models.py | 6 ++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/StableTransformersReplication/transformer_xl.py b/StableTransformersReplication/transformer_xl.py index e69371e..fb8f2d0 100644 --- a/StableTransformersReplication/transformer_xl.py +++ b/StableTransformersReplication/transformer_xl.py @@ -118,8 +118,11 @@ def __init__(self, n_head, d_model, d_head, d_inner, dropout,use_gate,use_stable self.use_gate = use_gate self.use_stable_version = use_stable_version - self.gate_mha = GRUGate(d_model) - self.gate_mlp = GRUGate(d_model) + + if self.use_gate: + self.gate_mha = GRUGate(d_model) + self.gate_mlp = GRUGate(d_model) + self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) diff --git a/adaptive_span2/models.py b/adaptive_span2/models.py index 4aea8dc..472ea91 100644 --- a/adaptive_span2/models.py +++ b/adaptive_span2/models.py @@ -185,8 +185,10 @@ def __init__(self, hidden_size, flags, **kargs): self.use_gate = flags.use_gate self.use_stable_version = True #use_stable_version - self.gate_mha = GRUGate(hidden_size) - self.gate_mlp = GRUGate(hidden_size) + + if self.use_gate: + self.gate_mha = GRUGate(hidden_size) + self.gate_mlp = GRUGate(hidden_size) def forward_orig(self, h, h_cache, key_pe):