diff --git a/src/transformers/models/marian/convert_marian_to_pytorch.py b/src/transformers/models/marian/convert_marian_to_pytorch.py
index c4feb213954918..418cc91e25fda3 100644
--- a/src/transformers/models/marian/convert_marian_to_pytorch.py
+++ b/src/transformers/models/marian/convert_marian_to_pytorch.py
@@ -480,6 +480,8 @@ def __init__(self, source_dir, eos_token_id=0):
         if "Wpos" in self.state_dict:
            raise ValueError("Wpos key in state dictionary")
        self.state_dict = dict(self.state_dict)
+        if cfg["tied-embeddings-all"]:
+            cfg["tied-embeddings-src"] = True
         self.share_encoder_decoder_embeddings = cfg["tied-embeddings-src"]

         # create the tokenizer here because we need to know the eos_token_id
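
For context, here is a minimal standalone sketch of the normalization the added lines perform. In Marian configs, `tied-embeddings-all: true` ties the source, target, and output embeddings together, which subsumes `tied-embeddings-src`; without the patch, a config that sets only `tied-embeddings-all` would leave `share_encoder_decoder_embeddings` unset on the next line. The `cfg` dict below is a hypothetical stand-in for the parsed Marian config, not the converter's actual object:

```python
# Hypothetical Marian config: all embeddings tied, but the narrower
# "tied-embeddings-src" key is explicitly False (or could be missing).
cfg = {"tied-embeddings-all": True, "tied-embeddings-src": False}

# "tied-embeddings-all" implies "tied-embeddings-src", so normalize the
# config before reading the single flag the converter relies on.
if cfg["tied-embeddings-all"]:
    cfg["tied-embeddings-src"] = True

share_encoder_decoder_embeddings = cfg["tied-embeddings-src"]
assert share_encoder_decoder_embeddings is True
```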