TF: BART compatible with XLA generation #17479
Conversation
The documentation is not available anymore as the PR was closed or merged.
@ydshieh tagging you for TF review, as Matt is off and you are also familiar with generate :)
Actually not very familiar, but would love to get more involved 😃. Thanks for tagging me!
Looks very nice to me! All the changes to `modeling_tf_bart.py` are 100% ok/good for me. I'd maybe just add one slow XLA test to `test_modeling_bart.py`, and once this works we can proceed and make all the tests pass?
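A rough sketch of what such a slow XLA test could look like (the checkpoint name, shapes, and assertion are assumptions, not the test that was actually added):

```python
import unittest

import tensorflow as tf
from transformers import BartTokenizer, TFBartForConditionalGeneration
from transformers.testing_utils import slow


class TFBartXLAGenerationTest(unittest.TestCase):
    @slow
    def test_xla_greedy_generate_matches_eager(self):
        tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
        model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

        # Pad to a fixed length so the XLA-compiled function keeps a single signature.
        inputs = tokenizer(
            ["My friends are cool but they"], return_tensors="tf", padding="max_length", max_length=16
        )

        eager_output = model.generate(**inputs, max_new_tokens=8, do_sample=False)
        xla_generate = tf.function(model.generate, jit_compile=True)
        xla_output = xla_generate(**inputs, max_new_tokens=8, do_sample=False)

        # Compare decoded text, so padding differences in the raw outputs do not matter.
        self.assertEqual(
            tokenizer.batch_decode(eager_output, skip_special_tokens=True),
            tokenizer.batch_decode(xla_output, skip_special_tokens=True),
        )
```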
Hi, @gante I left a few comments so far.

Questions:

- From the change in `prepare_inputs_for_generation`, both in the PR for TF-GPT2 and this PR, my understanding of the main change is: we need to use the (decoder) attention mask in order to calculate the correct `position_ids` for both left/right padding, and this is done using `tf.math.cumsum`. Do I understand these PRs correctly? (A small illustrative sketch is included right after this comment.)
- Why do we need `decoder_position_ids` when `past_key_values` is passed?

  `transformers/src/transformers/models/bart/modeling_tf_bart.py`, lines 943 to 945 in 9089b7b:

  ```python
  if position_ids is None:
      if past_key_values is not None:
          raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
  ```

  - This is not in TF-GPT2, where it is calculated as (`transformers/src/transformers/models/gpt2/modeling_tf_gpt2.py`, lines 388 to 389 in 26e5e12):

    ```python
    if position_ids is None:
        position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
    ```

  - In `prepare_inputs_for_generation` (this PR), we also compute it as (`transformers/src/transformers/models/bart/modeling_tf_bart.py`, lines 1420 to 1421 in 9089b7b):

    ```python
    elif past is not None:  # non xla + past
        decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
    ```

    so it seems it could be computed inside the model, either by `TFBartLearnedPositionalEmbedding` or as `TFGPT2MainLayer` does, right?

  I know you mentioned this guard is copied from Flax, but I am just wondering if it is a real necessity. (I feel that if it is really necessary, the same guard should also exist in GPT-2.)
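A minimal sketch of the cumsum idea referenced in the first question above (illustrative only, not the exact code from either PR): an exclusive cumulative sum of the attention mask yields position ids that stay correct for both left- and right-padded inputs.

```python
import tensorflow as tf

# One left-padded row (two pad tokens, then three real tokens)
# and one right-padded row (three real tokens, then two pad tokens).
left_padded_mask = tf.constant([[0, 0, 1, 1, 1]])
right_padded_mask = tf.constant([[1, 1, 1, 0, 0]])

# The exclusive cumsum counts how many real tokens precede each position,
# i.e. the position id each real token should get.
print(tf.math.cumsum(left_padded_mask, axis=-1, exclusive=True))   # values: [[0 0 0 1 2]]
print(tf.math.cumsum(right_padded_mask, axis=-1, exclusive=True))  # values: [[0 1 2 3 3]]
# In both cases the real tokens receive positions 0, 1, 2; the padded positions
# are ignored by the attention mask, so their values do not matter.
```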
I continued a bit, but need to take a rest before looking at `_update_model_kwargs_for_xla_generation`.
```python
    decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
else:  # non xla + non past
    decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
```
As far as I understand, this (`else`) case (i.e. when `past` is `None`) does NOT require `decoder_position_ids`; `TFBartLearnedPositionalEmbedding` will take care of creating it.

Also, we don't need to broadcast here (unless XLA requires the explicit shape here).
That is correct, since the guard is only active when the past is passed (not the case here). However, the return statement would fail, because it is expecting a `decoder_position_ids` variable. I'd rather make it explicit than implicit :)

(removed the broadcast)
```python
# cut decoder_input_ids if past is used
if past is not None:
    decoder_input_ids = decoder_input_ids[:, -1:]

if decoder_attention_mask is not None:  # xla
```
nit: Following TF-GPT2's `prepare_inputs_for_generation`, maybe the following would be more consistent? I would guess `decoder_position_ids` would never be passed to `prepare_inputs_for_generation`, so it will always be created here. (But then the question becomes why we even have such a check in TF-GPT2?)

```python
decoder_position_ids = kwargs.get("decoder_position_ids", None)
...
if decoder_position_ids is None:
    if decoder_attention_mask is not None:  # xla
        ...
    elif past is not None:  # non xla + past
        ...
    else:  # non xla + non past
        ...
```
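For concreteness, here is a hypothetical filled-in version of the structure suggested above, stitched together from the branches quoted elsewhere in this review; the xla branch is an assumption based on the `tf.math.cumsum` approach mentioned earlier, and may differ from the PR's actual code.

```python
import tensorflow as tf


def pick_decoder_position_ids(decoder_input_ids, decoder_attention_mask=None, past=None, decoder_position_ids=None):
    """Hypothetical helper mirroring the three cases discussed in this review (not the PR's code)."""
    if decoder_position_ids is None:
        if decoder_attention_mask is not None:  # xla
            # Assumed: the next position is the number of already-attended tokens (exclusive cumsum).
            decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
        elif past is not None:  # non xla + past
            # The next position equals the (unpadded) past length.
            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
        else:  # non xla + non past
            # Prompt pass: positions 0 .. seq_len - 1 for every row.
            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
    return decoder_position_ids
```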
However, see my comment for the (merged) TF-GPT2 PR
https://github.com/huggingface/transformers/pull/17426/files#r889687584
I agree with your concerns, but I'm not going to worry about consistency for now :) I haven't managed to get beam search to work, so there is a chance I will have to rewrite all these generate-related functions.
After beam search is working, then yes, I'd like to revisit these models and make a template for each kind (i.e. decoder-only or encoder-decoder model). Would that work for you, @ydshieh?
Good for me, @gante :-)
Left a few comments for `_update_model_kwargs_for_xla_generation`:

- It's NP-hard for me to understand 😢, so I tried to add some comments in the code. Hopefully you will find them helpful and merge them.
- Need some explanations on `decoder_attention_mask` 🙏
```python
decoder_attention_mask = tf.concat(
    [
        tf.ones((batch_size, 1), dtype=tf.int32),
        tf.zeros((batch_size, num_padding_values), dtype=tf.int32),
        tf.ones((batch_size, 1), dtype=tf.int32),
    ],
    axis=1,
)
```
[Update]

It looks like this block is for the case where the generation of `decoder_start_token_id` was done, so `num_padding_values` would equal `max_length - 1 - 1`.

It's still not clear to me why we put the zeros before the second block of ones.

In general, I think it would be great if we could put more comments along the code to explain things. (You definitely know things much better, but it would be beneficial to other developers 😄)

[Original comment]

I am not able to understand this block so far. The `decoder_attention_mask` normally has the same length as the current input sequence. I guess maybe here you want to keep the shape fixed (i.e. with `max_length` steps). But this block gives a length of `2 + num_padding_values`?
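A tiny numerical illustration of the quoted concat, assuming `batch_size = 2` and `max_length = 5` (so `num_padding_values = max_length - 1 - 1 = 3`); the interpretation of the three segments follows the discussion above and is my reading, not a statement from the PR.

```python
import tensorflow as tf

batch_size, max_length = 2, 5
num_padding_values = max_length - 1 - 1  # = 3

decoder_attention_mask = tf.concat(
    [
        tf.ones((batch_size, 1), dtype=tf.int32),                    # the decoder_start_token_id (already in the past)
        tf.zeros((batch_size, num_padding_values), dtype=tf.int32),  # presumably: past slots not filled yet
        tf.ones((batch_size, 1), dtype=tf.int32),                    # presumably: the token fed at the current step
    ],
    axis=1,
)
print(decoder_attention_mask)
# values: [[1 0 0 0 1]
#          [1 0 0 0 1]]   -> total length 2 + num_padding_values == max_length
```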
Also, it looks like this discards the `decoder_attention_mask` in `model_kwargs` (if provided). In TF-GPT2, this case is handled. But probably there is some assumption that `decoder_attention_mask` is never provided to `generate` for TF-Bart, and will only be added in `_update_model_kwargs_for_xla_generation`?
Yeah, this can definitely be improved (it was copy/pasted from T5). I will not touch it in this PR though; the beam search PR will rewrite it out of necessity 😭
Hey @ydshieh 👋 answering your questions:

- Correct 👍
- In the original PT code and eager-execution TF, the position ids can be obtained by default (i.e. when not explicitly passed) from the past length, as the past length corresponds to the next position id if there is no left padding. In FLAX and XLA TF, the past is zero-padded, so the past length no longer gives the right position id. As such, it is dangerous to leave the default path active -- this path should only be used in `generate` anyways, and the updated `generate` passes the position ids. (GPT-2 should also get the same guard, to be safe!)
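A toy illustration of the point above, with assumed shapes rather than the PR's code: a zero-padded (XLA-style) past has a constant length that no longer equals the next position id, while the decoder attention mask still recovers it.

```python
import tensorflow as tf

max_length = 6
batch_size, num_heads, head_dim = 1, 2, 4

# Eager / PT-style past after 3 generated tokens: its length is 3,
# so "past length == next position id" holds.
eager_past_key = tf.zeros((batch_size, num_heads, 3, head_dim))
print(eager_past_key.shape[2])  # 3 -> correct next position id

# XLA-style past: zero-padded up to max_length from the start, so its length is constant.
xla_past_key = tf.zeros((batch_size, num_heads, max_length, head_dim))
print(xla_past_key.shape[2])  # 6 -> not the next position id

# The decoder attention mask still knows how many positions are real.
decoder_attention_mask = tf.constant([[1, 1, 1, 0, 0, 0]])
print(int(tf.reduce_sum(decoder_attention_mask, axis=-1)[0]))  # 3 -> correct next position id
```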
OK, I might have got it now. Thank you @gante!
I didn't think it through in a thorough way, but in … it seems to me that we could cut … (of course, we need to pass the current length info to …).

@gante I don't want to make you too busy. I will let you judge if this is a good idea and, even if it is, whether we should change it now or later. I know we want to publish our work soon!
I would love to do that, and it would be a great idea to simplify the code, but sadly XLA does not allow dynamic-sized slices (i.e. cutting …).
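A small sketch of the constraint being described (assumed example, not from the PR): under `jit_compile=True` a slice must have a size known at trace time, so a fixed `[:, -1:]` cut compiles while cutting to a runtime-dependent length does not.

```python
import tensorflow as tf


@tf.function(jit_compile=True)
def keep_last_token(decoder_input_ids):
    # Static-size slice: the result always has shape [batch, 1], so XLA can compile it.
    return decoder_input_ids[:, -1:]


print(keep_last_token(tf.constant([[5, 6, 7, 8]])))  # values: [[8]]

# By contrast, a slice whose size depends on a runtime tensor value, e.g.
#
#     @tf.function(jit_compile=True)
#     def cut_to_current_length(decoder_input_ids, current_length):
#         return decoder_input_ids[:, :current_length]
#
# produces a dynamically shaped result and fails to compile with XLA.
```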
I finally finished the review, sorry for taking so long.

Thank you for this awesome work, @gante! LGTM for the logic (thanks for the explanations!)

I would encourage explaining things a bit more along the code, as mentioned in a few of my comments. (I only reviewed the Bart-related files.)

For example, the `decoder_attention_mask = tf.concat(...)` block seems to suggest the prompt would have sequence length 1 (i.e. `[decoder_start_token_id]`), and I am totally fine with that if this is the only use case for Bart (I believe so). However, from the method itself, it looks like it can handle any prompt -- especially regarding the treatment of `past`. A comment mentioning what case/assumption the block is dealing with would be great.
Think we can move towards finishing this PR here :-)
@patrickvonplaten it is ready to merge -- would you like to make a final review, or can I merge the PR? :)
Looks clean - good to go for me!
Would indeed be nice to eventually replace BART's slow temporary test with an XLA beam search test.
* Also propagate changes to blenderbot, blenderbot_small, marian, mbart, and pegasus
What does this PR do?
Adds `position_ids` to `TFBart`, so that we can do generation with a padded past -- a requirement for XLA generation. This PR was built on top of #17426 (so it will contain its diff until it gets merged), and is a requirement for #17458.

🚨 Important notes: `make fix-copies` is still pending (several files have copies from Bart). Their `prepare_inputs_for_generation` were copied from Bart, but the models do not have the `position_ids` input. If the PR gets a positive review, I will propagate the changes to the affected models.
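For illustration, this is roughly what XLA generation looks like from the user side once the pieces are in place (a sketch following the pattern later documented for TF XLA generation; the checkpoint name and generation kwargs are assumptions, not taken from this PR):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = TFAutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

# Compiling generate with XLA requires fixed shapes end to end, which is why the past
# has to be padded and why the model needs explicit (decoder) position ids.
xla_generate = tf.function(model.generate, jit_compile=True)

# Pad the inputs to a fixed length so the compiled function is not retraced per input size.
inputs = tokenizer(
    ["UN Chief Says There Is No Military Solution in Syria"],
    return_tensors="tf",
    padding="max_length",
    max_length=32,
)
outputs = xla_generate(**inputs, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```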