
Commit

Fix tf shared embedding (#17730)
* fix the naming

* from pt in test for now

* make style

* slow test and removed from_pt
ArthurZucker authored and sgugger committed Jun 16, 2022
1 parent 3981ee8 commit f8c8f4d
Showing 1 changed file with 8 additions and 32 deletions.
src/transformers/models/opt/modeling_tf_opt.py (8 additions, 32 deletions)
@@ -30,7 +30,6 @@
     TFModelInputType,
     TFPreTrainedModel,
     TFSharedEmbeddings,
-    TFWrappedEmbeddings,
     keras_serializable,
     unpack_inputs,
 )
@@ -495,31 +494,15 @@ def __init__(self, config: OPTConfig, load_weight_prefix=None, **kwargs):
         self.padding_idx = config.pad_token_id
         self.layerdrop = config.layerdrop
         num_embeddings = config.max_position_embeddings
-
-        self.shared = TFSharedEmbeddings(
-            config.vocab_size, config.word_embed_proj_dim, config.pad_token_id, name="model.decoder.embed_tokens"
+        self.embed_tokens = TFSharedEmbeddings(
+            config.vocab_size, config.word_embed_proj_dim, config.pad_token_id, name="embed_tokens"
         )
-
         self.embed_positions = TFOPTLearnedPositionalEmbedding(
             num_embeddings,
             config.hidden_size,
             name="embed_positions",
         )

-        # set tf scope correctly
-        if load_weight_prefix is None:
-            load_weight_prefix = "decoder.embed_tokens"
-
-        with tf.compat.v1.variable_scope(load_weight_prefix) as shared_abs_scope_name:
-            pass
-
-        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
-        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
-        embed_tokens.vocab_size = self.shared.vocab_size
-        embed_tokens.hidden_size = self.shared.hidden_size
-
-        self.embed_tokens = embed_tokens
-
         if config.word_embed_proj_dim != config.hidden_size:
             self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False)
             self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False)
@@ -538,17 +521,11 @@ def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

     def set_input_embeddings(self, new_embeddings):
-        self.shared.weight = new_embeddings
-        self.shared.vocab_size = self.shared.weight.shape[0]
-        # retrieve correct absolute scope for embed token wrapper
-        with tf.compat.v1.variable_scope("decoder.embed_tokens") as shared_abs_scope_name:
-            pass
-        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
-        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
-        self.set_embed_tokens(embed_tokens)
+        self.embed_tokens.vocab_size = new_embeddings.shape[0]
+        self.embed_tokens.weight = new_embeddings

     def get_input_embeddings(self):
-        return self.shared
+        return self.embed_tokens

     def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
         # create causal mask
@@ -731,7 +708,7 @@ def __init__(self, config: OPTConfig, **kwargs):
         self.decoder = TFOPTDecoder(config, name="decoder")

     def get_input_embeddings(self):
-        return self.decoder.shared
+        return self.decoder.embed_tokens

     def set_input_embeddings(self, new_embeddings):
         self.decoder.set_input_embeddings(new_embeddings)
@@ -797,7 +774,7 @@ def __init__(self, config: OPTConfig, **kwargs):
         self.model = TFOPTMainLayer(config, name="model")

     def get_input_embeddings(self):
-        return self.model.decoder.shared
+        return self.model.decoder.embed_tokens

     def set_input_embeddings(self, new_embeddings):
         self.model.set_input_embeddings(new_embeddings)
@@ -1013,8 +990,7 @@ def call(
             training=training,
         )

-        logits = self.model.decoder.shared(outputs[0], mode="linear")
-
+        logits = self.model.decoder.embed_tokens(outputs[0], mode="linear")
         loss = None
         if labels is not None:
             # shift labels to the left and cut last logit token
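For context, the sketch below (not part of the commit; the toy sizes and the import path are assumptions on my part) illustrates the behaviour the diff relies on: a single TFSharedEmbeddings layer acts as the token-embedding lookup in its default mode and as the weight-tied output projection in mode="linear", which is why the decoder can expose it directly as embed_tokens and TFOPTForCausalLM can reuse it to compute the logits.

# Sketch only, not from the commit: TFSharedEmbeddings used in both modes.
# Assumes TensorFlow and a transformers version of this era are installed;
# vocab_size/hidden_size are arbitrary toy values.
import tensorflow as tf
from transformers.modeling_tf_utils import TFSharedEmbeddings

shared = TFSharedEmbeddings(vocab_size=100, hidden_size=16, name="embed_tokens")

input_ids = tf.constant([[4, 8, 15]])           # (batch, seq)
hidden_states = shared(input_ids)               # embedding lookup -> (1, 3, 16)
logits = shared(hidden_states, mode="linear")   # tied projection  -> (1, 3, 100)

print(hidden_states.shape, logits.shape)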
