[WIP] FSMT bart-like refactor #11218
Conversation
def make_weight(self, num_positions, embedding_dim, padding_idx, dtype=torch.float32):
    weight = self.get_embedding(num_positions, embedding_dim, padding_idx)
    weight = weight.to(dtype)
Handles fp16 when resizing the weights.
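A minimal sketch of the idea, assuming a fairseq-style sinusoidal embedding whose table is regrown on demand; the helper name and the position offset are illustrative, not the PR's actual code:

```python
# Hedged sketch: when the position table must grow for a longer batch, regenerate it
# in the dtype the module currently uses, so an fp16 model does not get an fp32 table.
def maybe_resize_positions(embed_positions, input_ids):
    seq_len = input_ids.shape[1]
    max_pos = embed_positions.padding_idx + 1 + seq_len  # fairseq-style offset (assumption)
    if max_pos > embed_positions.weight.size(0):
        embed_positions.make_weight(
            max_pos,
            embed_positions.embedding_dim,
            embed_positions.padding_idx,
            dtype=embed_positions.weight.dtype,  # preserve fp16/bf16/fp32
        )
```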
if decoder_input_ids is None:
    decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
IMO decoder_input_ids should not be created from input_ids for this model, because FSMT uses two different vocabs for the encoder and the decoder. Deriving them this way is really only applicable to BART, which uses the input_ids for tasks like classification and QA, which FSMT does not support.
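A hedged illustration of the point, using a hypothetical helper rather than the PR's code: building decoder inputs by shifting the target-side labels keeps them in the decoder vocab, whereas shifting input_ids would inject source-vocab ids into the decoder.

```python
import torch

def build_decoder_input_ids(labels: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
    # Shift the target ids one step to the right and prepend the decoder start token.
    decoder_input_ids = labels.new_full(labels.shape, decoder_start_token_id)
    decoder_input_ids[:, 1:] = labels[:, :-1]
    return decoder_input_ids

labels = torch.tensor([[101, 102, 103, 2]])  # target-vocab ids ending in EOS (= 2)
print(build_decoder_input_ids(labels, decoder_start_token_id=2))
# tensor([[  2, 101, 102, 103]])
```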
Yes!
Thanks a lot for working on this!
"FSMTEncoder", | ||
"FSMTDecoder", |
This should not be needed anymore.
Thank you for doing this refactoring, @patil-suraj!
It's a bit hard to review since all the code is moved around, so there is no easy diff to follow. While I skimmed through it, I trust your expertise and the tests on the correctness.
With regards to the memory/performance regression (thank you for running this important check!): could it be that it was introduced in the initial Bart refactor? I.e. perhaps run the same check on Bart pre and post the PR that did the main refactoring (when all the Barts were split up)? If so, the issue is bigger and needs to be looked at in the PR that introduced it.
Thanks Stas! I'm not sure what exactly introduced this memory/speed regression, so I'm going to investigate it and won't merge this PR before that.
I remember that I checked that the Bart refactor didn't show any regression both on the forward pass and …
assert attention_mask.dim() == 2
return attention_mask.eq(0)

# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right
def shift_tokens_right(input_ids, pad_token_id):
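For context, a paraphrased sketch of what the MBart-style helper referenced in this diff does (not the verbatim library code): each row's last non-pad token, normally EOS, is wrapped around to position 0 and everything else is shifted one step to the right.

```python
import torch

def mbart_style_shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    prev_output_tokens = input_ids.clone()
    # Index of the last non-pad token in each row (usually the EOS token).
    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
    prev_output_tokens[:, 0] = decoder_start_tokens
    return prev_output_tokens

ids = torch.tensor([[5, 6, 7, 2, 1, 1]])  # 2 = EOS, 1 = pad
print(mbart_style_shift_tokens_right(ids, pad_token_id=1))
# tensor([[2, 5, 6, 7, 2, 1]])
```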
I'm actually not sure whether using MBart's way of creating the decoder_input_ids here is correct. Note that MBart uses a very special language-dependent pattern when fine-tuning, so that the decoder_start_token_id depends on the language, see Fig. 1 of the paper. IMO, this is not the case for FSMT: if I understand correctly, FSMT always uses the EOS/SEP token as the decoder_start_token_id = sep_token_id = "</s>" => I therefore think it's better to copy the shift_tokens_right function from Bart and set decoder_start_token_id = 2 as a default in the config. This would make this function much simpler. Would be great if you could confirm, @stas00.
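Something along the lines of Bart's helper, sketched here under the assumption that decoder_start_token_id comes from the config as suggested:

```python
import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    # Every row starts with the fixed decoder start token; the rest is the input shifted right.
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    # Replace possible -100 values in labels by pad_token_id.
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids

labels = torch.tensor([[11, 12, 13, 2]])
print(shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2))
# tensor([[ 2, 11, 12, 13]])
```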
Yes, you are right. It is also equivalent in the sense that, in FSMT, the sequence ends with EOS and this method then moves it to the beginning. But yeah, I will copy BART's function, which is more explicit here.
Indeed! Thank you for catching that, @patrickvonplaten!
return outputs


class PretrainedFSMTModel(PreTrainedModel):
Maybe rename this to FSMTPreTrainedModel and deprecate the old naming, as is done for Bart:

class PretrainedBartModel(BartPretrainedModel):
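A sketch of how the Bart-style deprecation could carry over to FSMT; FSMTPreTrainedModel is the proposed new name, not existing code, and the class bodies are simplified:

```python
import warnings

from transformers import FSMTConfig, PreTrainedModel


class FSMTPreTrainedModel(PreTrainedModel):
    config_class = FSMTConfig
    base_model_prefix = "model"


class PretrainedFSMTModel(FSMTPreTrainedModel):
    # Old name kept as a thin alias; warn when it is subclassed, mirroring Bart's pattern.
    def __init_subclass__(cls, **kwargs):
        warnings.warn(
            "The class `PretrainedFSMTModel` has been deprecated, please use `FSMTPreTrainedModel` instead.",
            FutureWarning,
        )
        super().__init_subclass__(**kwargs)
```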
Thanks a lot for the PR! Great to see many more Copied from statements! I'm also very curious where the memory/speed regression comes from. I think it's a good idea to make sure that the original Bart refactor didn't introduce this regression - maybe you can check the exact same testing params with Bart before the refactor and after, and then see from there.
Another small nit: after this PR you should be able to delete this line in the common tests:
transformers/tests/test_modeling_common.py, line 404 (at 5a34d8d):

model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
@patil-suraj, what needs to be done to complete this? Last we talked there was a performance regression and the suggestion was to test Bart's performance pre and post its original refactoring.
@patil-suraj, FYI, recently I made a few fixes to the model to make it work with Deepspeed. Are we still planning to merge this PR? As we said earlier, if there is a regression it'll be on the whole Bart family, so perhaps it might be easier to just merge this? Otherwise a lot of time gets wasted getting back to it again and again and not getting anywhere. Thank you.
Closing as this PR is super old |
What does this PR do?
This PR refactors FSMT to align it with the other (bart-like) seq-2-seq models in the library. Similar to Bart, the time dimension is now always in second place and the batch dimension always in first place. Also, the cache is refactored to consist of tuples instead of a dict. This refactor is very similar to #10501.
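To make the cache change concrete, here is a rough, illustrative picture of the bart-like layout; the shapes and counts below are made up for the example, not taken from a real checkpoint:

```python
import torch

# Each decoder layer contributes a tuple of four tensors (self-attn key/value,
# cross-attn key/value), each of shape (batch_size, num_heads, seq_len, head_dim);
# the full cache is a tuple of these per-layer tuples instead of a dict.
batch_size, num_heads, tgt_len, src_len, head_dim, num_layers = 2, 16, 5, 7, 64, 6
layer_cache = (
    torch.zeros(batch_size, num_heads, tgt_len, head_dim),  # self-attention key
    torch.zeros(batch_size, num_heads, tgt_len, head_dim),  # self-attention value
    torch.zeros(batch_size, num_heads, src_len, head_dim),  # cross-attention key
    torch.zeros(batch_size, num_heads, src_len, head_dim),  # cross-attention value
)
past_key_values = tuple(layer_cache for _ in range(num_layers))
print(len(past_key_values), len(past_key_values[0]))  # 6 layers, 4 tensors per layer
```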
I have verified that all slow tests are passing and that all metrics (BLEU score) can be reproduced. I ran the evaluation of the following four models and the results are similar to those reported in the model cards.
Benchmarking
This PR however introduces some speed and memory regression, which I'm currently investigating.
On this PR:
On master: