Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Marian model conversion #30173

Merged
merged 4 commits into from
May 1, 2024
Merged

Conversation

zucchini-nlp
Copy link
Member

@zucchini-nlp zucchini-nlp commented Apr 11, 2024

What does this PR do?

Fixes #26216. After a bit of code exploration I found that problem was in the "tie_weights" method. The issue is that "tie_weights" does not clone, but simply change pointer output_embeddings -> input_embeddings. And since we tie weights at least twice during loading weights (once before loading state dict and the other is after), the weights were being loaded incorrectly.

Why the issue is only in MarianMT and not other models that tie weights? -> Because in other models either the weights already hold the same/tied values , or the order of loading parameters is different from MarianMT. When loading state dict into the MarianMT, "output_embeddings" are loaded last and therefore "input_embeddings" weights data is overriden to will hold same data as "output_embeddings". That way we lose access to actual "input_embeddings" weights data
This is somehow related to an old PR I found.

I converted the weights with python src/transformers/models/marian/convert_marian_tatoeba_to_pytorch.py --models fin-eng --save_dir converted and checked the correctness of results with the below script. Everything works good.

from transformers import AutoTokenizer, MarianMTModel

tokenizer = AutoTokenizer.from_pretrained("/home/raushan/converted/opus-mt-fin-eng/")
model = MarianMTModel.from_pretrained("/home/raushan/converted/opus-mt-fin-eng/")

inputs = ["Hei siellä", "Miten aurinko sanotaan suomeksi?"]
batch_tokenized = tokenizer(inputs, return_tensors="pt", padding=True)
model_output = model.generate(
    **batch_tokenized, max_new_tokens=100
)
batch_detokenized = tokenizer.batch_decode(
    model_output,
    skip_special_tokens=True,
)

print(batch_detokenized)

Marian tests including slow are all passsing on my end.

@@ -3840,7 +3840,6 @@ def _fix_key(key):
model_buffers = {".".join([prefix, key]) for key in model_buffers}
unexpected_keys = sorted(unexpected_keys - model_buffers)

model.tie_weights()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prob we can remove this, given a few lines above the weights before loading state dict are tied already

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was introduced in #25107, however, indeed there are no changeds to the model in the code above.
Can you make sure the bug fixed in the PR are still passing?

@@ -34,7 +34,6 @@

DEFAULT_REPO = "Tatoeba-Challenge"
DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")
LANG_CODE_URL = "https://datahub.io/core/language-codes/r/language-codes-3b2.csv"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one gives 404, so I just loaded dataset to the hub

@@ -622,6 +622,12 @@ def load_marian_model(self) -> MarianMTModel:
bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
model.model.decoder.embed_tokens.weight = decoder_wemb_tensor

# handle tied embeddings, otherwise "from_pretrained" loads them incorrectly
if self.cfg["tied-embeddings"]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and this is the actual fix, which gives equal weights for tied parameters

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! We were recently pinged about that.
Great digging!

@@ -3840,7 +3840,6 @@ def _fix_key(key):
model_buffers = {".".join([prefix, key]) for key in model_buffers}
unexpected_keys = sorted(unexpected_keys - model_buffers)

model.tie_weights()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was introduced in #25107, however, indeed there are no changeds to the model in the code above.
Can you make sure the bug fixed in the PR are still passing?

Comment on lines 627 to 628
wemb_tensor = nn.Parameter(torch.FloatTensor(self.wemb))
bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see where these are used?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, yeah, makes sense. They are not used anymore

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weird I cannot reply to the above comment. Anyway, I tested FSDP with fsdp_cpu_ram_efficient_loading: true, and looks like tying weights do not have big effect on CPU memory. But to be clear, summoning @pacman100 to confirm that as he was the PR contributor

@zucchini-nlp
Copy link
Member Author

@ArthurZucker i reverted the change for tying weights, to be consistent with 'main'. It was not the actual solution for Marian models, so I think it does not hurt whatever was the reason for it to be added. Requesting re-review :)

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super glad with these changes ! Thanks for fixing!

@zucchini-nlp zucchini-nlp merged commit 4bc9cb3 into huggingface:main May 1, 2024
19 checks passed
itazap pushed a commit that referenced this pull request May 14, 2024
* fix marian model coversion

* uncomment that line

* remove unnecessary code

* revert tie_weights, doesn't hurt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Some MarianMT models broken and output garbage
3 participants