
[WIP] FSMT bart-like refactor #11218

Closed
wants to merge 6 commits

Conversation

@patil-suraj (Contributor) commented Apr 13, 2021

What does this PR do?

This PR refactors FSMT to align it with the other Bart-like seq2seq models in the library.

Concretely, it follows the Bart convention of placing the batch dimension first and the time dimension second everywhere, and the cache is refactored to consist of tuples instead of a dict.
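
For illustration, here is a minimal sketch of the two conventions; the shapes and the per-layer cache layout below follow the general Bart convention and are not taken from this PR's code:

import torch

batch_size, seq_len, hidden_size, num_heads = 4, 8, 16, 2
head_dim = hidden_size // num_heads

# Old, fairseq-style FSMT: time-major hidden states and a dict-based cache.
hidden_states_time_major = torch.randn(seq_len, batch_size, hidden_size)
old_cache = {
    "self": {
        "prev_key": torch.randn(batch_size, num_heads, seq_len, head_dim),
        "prev_value": torch.randn(batch_size, num_heads, seq_len, head_dim),
    },
}

# Bart-like convention after the refactor: batch-major hidden states and a
# tuple of (self_key, self_value, cross_key, cross_value) tensors per layer.
hidden_states = hidden_states_time_major.transpose(0, 1)  # (batch, seq, hidden)
past_key_values = tuple(
    tuple(torch.randn(batch_size, num_heads, seq_len, head_dim) for _ in range(4))
    for _ in range(6)  # one entry per decoder layer
)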

This refactor is very similar to #10501.

I have verified that all slow tests pass and that all metrics (BLEU scores) can be reproduced. I ran the evaluation for the following four models, and the results are in line with those reported in the model cards (a sketch of how such an evaluation can be reproduced follows the list):

  • en-ru: 33.42
  • ru-en: 39.20
  • en-de: 42.83
  • de-en: 41.39
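
For reference, here is a minimal sketch of how a BLEU score for one of these checkpoints could be reproduced; the sentences and the sacrebleu usage are purely illustrative and this is not the evaluation script used for the numbers above:

import sacrebleu
import torch
from transformers import FSMTForConditionalGeneration, FSMTTokenizer

model_name = "facebook/wmt19-en-ru"
tokenizer = FSMTTokenizer.from_pretrained(model_name)
model = FSMTForConditionalGeneration.from_pretrained(model_name).eval()

# Toy source/reference pair; a real evaluation iterates over the WMT19 test set.
sources = ["The weather is nice today."]
references = ["Сегодня хорошая погода."]

with torch.no_grad():
    batch = tokenizer(sources, return_tensors="pt", padding=True)
    generated = model.generate(**batch, num_beams=5)
hypotheses = tokenizer.batch_decode(generated, skip_special_tokens=True)

print(sacrebleu.corpus_bleu(hypotheses, [references]).score)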

Benchmarking

This PR, however, introduces some speed and memory regression, which I'm currently investigating.
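
The tables below follow the output format of the transformers PyTorch benchmark utilities. A minimal sketch of how such numbers can be obtained (the exact invocation is not shown in the PR, so the arguments here are inferred from the tables):

from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

args = PyTorchBenchmarkArguments(
    models=["facebook/wmt19-en-ru"],
    batch_sizes=[4],
    sequence_lengths=[8, 32, 128, 512],
    inference=True,  # measure the forward pass
    memory=True,     # also track peak memory
)
benchmark = PyTorchBenchmark(args)
results = benchmark.run()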

On this PR:

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
     facebook/wmt19-en-ru            4               8             0.009     
     facebook/wmt19-en-ru            4               32            0.01      
     facebook/wmt19-en-ru            4              128            0.026     
     facebook/wmt19-en-ru            4              512            0.109     
--------------------------------------------------------------------------------

====================      INFERENCE - MEMORY - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length    Memory in MB 
--------------------------------------------------------------------------------
     facebook/wmt19-en-ru            4               8              2172     
     facebook/wmt19-en-ru            4               32             2200     
     facebook/wmt19-en-ru            4              128             2306     
     facebook/wmt19-en-ru            4              512             2792     
--------------------------------------------------------------------------------

On master:

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
     facebook/wmt19-en-ru            4               8             0.007     
     facebook/wmt19-en-ru            4               32            0.007     
     facebook/wmt19-en-ru            4              128            0.013     
     facebook/wmt19-en-ru            4              512            0.046     
--------------------------------------------------------------------------------

====================      INFERENCE - MEMORY - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length    Memory in MB 
--------------------------------------------------------------------------------
     facebook/wmt19-en-ru            4               8              2170     
     facebook/wmt19-en-ru            4               32             2176     
     facebook/wmt19-en-ru            4              128             2204     
     facebook/wmt19-en-ru            4              512             2356     
--------------------------------------------------------------------------------


@patil-suraj changed the title from "[WIP] FMST refactor" to "[WIP] FMST bart-like refactor" on Apr 13, 2021

def make_weight(self, num_positions, embedding_dim, padding_idx, dtype=torch.float32):
    weight = self.get_embedding(num_positions, embedding_dim, padding_idx)
    weight = weight.to(dtype)
@patil-suraj (Contributor, Author) commented:

handle fp16 when resizing weights.
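
A minimal illustration (not the PR's actual code) of why the cast matters: when the positional weight table is rebuilt for a longer sequence, the fresh tensor comes out in float32 and must be cast to the module's dtype, otherwise a half-precision model would mix fp32 and fp16 tensors.

import torch

old = torch.nn.Embedding(10, 8).half()       # model already running in fp16
new_table = torch.randn(20, 8)               # rebuilt, larger position table (fp32)
new_table = new_table.to(old.weight.dtype)   # the cast this change takes care of
resized = torch.nn.Embedding.from_pretrained(new_table)
assert resized.weight.dtype == torch.float16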

Comment on lines -309 to -310
if decoder_input_ids is None:
    decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
@patil-suraj (Contributor, Author) commented:

IMO decoder_input_ids should not be created from input_ids for this model, because it uses two different vocabs for the encoder and the decoder. Creating them this way really only makes sense for BART, which reuses the input_ids for tasks like classification and QA, which FSMT does not support.

Contributor reply:

Yes!

@patil-suraj changed the title from "[WIP] FMST bart-like refactor" to "[WIP] FSMT bart-like refactor" on Apr 13, 2021
@sgugger (Collaborator) left a comment:

Thanks a lot for working on this!

Comment on lines +93 to +94
"FSMTEncoder",
"FSMTDecoder",
Collaborator comment:

This should not be needed anymore.

@stas00 (Contributor) left a comment:

Thank you for doing this refactoring, @patil-suraj!

It's a bit hard to review since all the code is moved around, so there is no easy diff to follow. While I skimmed through it, I trust your expertise and the tests for correctness.

With regards to the memory/performance regression (thank you for running this important check!): could it be that it was introduced in the initial Bart refactor? I.e., perhaps run the same check on Bart pre and post the PR that did the main refactoring (when all the Barts were split up)? If so, the issue is bigger and needs to be looked at in the PR that introduced it.

@patil-suraj (Contributor, Author) commented:

Thanks Stas! I'm not sure what exactly introduced this memory/speed regression, so I'm going to investigate it and won't merge this PR before that.

@patrickvonplaten (Contributor) commented:

> Thank you for doing this refactoring, @patil-suraj!
>
> It's a bit hard to review since all the code is moved around, so there is no easy diff to follow. While I skimmed through it, I trust your expertise and the tests for correctness.
>
> With regards to the memory/performance regression (thank you for running this important check!): could it be that it was introduced in the initial Bart refactor? I.e., perhaps run the same check on Bart pre and post the PR that did the main refactoring (when all the Barts were split up)? If so, the issue is bigger and needs to be looked at in the PR that introduced it.

I remember checking that the Bart refactor didn't show any regression on either the forward pass or generate(). I might have overlooked something, however. It would definitely be a good idea to verify this first with the exact same testing params (batch_size=4, ...)!

    assert attention_mask.dim() == 2
    return attention_mask.eq(0)
# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right
def shift_tokens_right(input_ids, pad_token_id):
@patrickvonplaten (Contributor) commented Apr 19, 2021:

I'm actually not sure whether using MBart's way of creating the decoder_input_ids is correct here. Note that MBart uses a very special language-dependent pattern when fine-tuning, so the decoder_start_token_id depends on the language; see Fig. 1 of the paper.

IMO, this is not the case for FSMT. FSMT always uses the EOS/SEP token as the decoder start token, i.e. decoder_start_token_id = sep_token_id = "</s>", if I understand correctly. I therefore think it's better to copy the shift_tokens_right function from Bart and set decoder_start_token_id = 2 as a default in the config. This would make this function much simpler.

It would be great if you could confirm, @stas00.

@patil-suraj (Contributor, Author) commented Apr 19, 2021:

Yes, you are right. It's also equivalent in the sense that in FSMT the sequence ends with EOS, and this method moves it to the beginning. But yeah, I will copy BART's function, which is more explicit here.
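
For reference, a sketch of the Bart-style shift_tokens_right being discussed (adapted from the Bart implementation in transformers; the decoder_start_token_id default of 2 reflects the suggestion above, not this PR's final code):

import torch


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int = 2):
    """Shift input ids one token to the right, prepending the decoder start token."""
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    # labels may contain -100 (the ignore index); replace those with the pad token
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids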

Contributor reply:

Indeed! Thank you for catching that, @patrickvonplaten!

    return outputs


class PretrainedFSMTModel(PreTrainedModel):
Contributor comment:

maybe rename this to FSMTPreTrainedModel and deprecate this naming as it's done for Bart:

class PretrainedBartModel(BartPretrainedModel):
?
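
A sketch of what that rename-plus-deprecation could look like for FSMT (the names and class bodies are assumptions mirroring the Bart pattern, not code from this PR):

from transformers import FSMTConfig, PreTrainedModel


class FSMTPreTrainedModel(PreTrainedModel):
    """Base class following the naming convention used by the other models."""

    config_class = FSMTConfig
    base_model_prefix = "model"


class PretrainedFSMTModel(FSMTPreTrainedModel):
    """Kept as an alias so code importing the old name keeps working."""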

@patrickvonplaten (Contributor) left a comment:

Thanks a lot for the PR! Great to see many more "Copied from" statements! I'm also very curious where the memory/speed regression comes from. I think it's a good idea to make sure that the original Bart refactor didn't introduce this regression; maybe you can check the exact same testing params with Bart before and after the refactor and then go from there.

Another small nit: after this PR you should be able to delete this line in the common tests:

model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward

@huggingface huggingface deleted a comment from github-actions bot May 15, 2021
@stas00 added the WIP label on May 15, 2021
@stas00 (Contributor) commented May 15, 2021:

@patil-suraj, what needs to be done to complete this?

Last we talked there was a performance regression and the suggestion was to test Bart's performance pre and post its original refactoring.

@stas00 (Contributor) commented Jul 21, 2021:

@patil-suraj, FYI, recently I made a few fixes to the model to make it work with Deepspeed:
https://github.com/huggingface/transformers/pull/12477/files#diff-564f6d9b78eec17b410c924f868840770a9ad9649032bcf3754827317b9eaba3

Are we still planning to merge this PR? As we said earlier, if there is a regression it affects the whole Bart family, so perhaps it might be easier to just merge this? Otherwise a lot of time gets wasted coming back to it again and again without getting anywhere.

Thank you.

@ArthurZucker (Collaborator) commented:

Closing as this PR is super old

Labels: WIP