TF: BART compatible with XLA generation #17479

Merged
gante merged 17 commits into huggingface:main from the xla_bart branch on Jun 20, 2022

Conversation

@gante (Member) commented May 30, 2022

What does this PR do?

Adds position_ids to TFBart, so that we can do generation with a padded past -- a requirement for XLA generation.

This PR was built on top of #17426 (so it will contain its diff until it gets merged), and is a requirement for #17458.

🚨 Important notes:

  1. Review suggestion: check the Bart file, then its test file. The other changes are either cosmetic changes (e.g. correcting comments) or the result of make fix-copies (several files have copies from Bart).
  2. There are several failing tests, but this is intentional -- some models' prepare_inputs_for_generation was copied from Bart, but those models do not have the position_ids input. If the PR gets a positive review, I will propagate the changes to the affected models.
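
For context, the kind of end-to-end usage this line of work targets looks roughly like the sketch below. This is a hedged illustration only: the checkpoint name, padding settings, and generation arguments are assumptions, not taken from this PR.

    import tensorflow as tf
    from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    model = TFAutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

    # XLA wants fixed shapes, so pad the inputs to a fixed length and compile generate once
    xla_generate = tf.function(model.generate, jit_compile=True)
    inputs = tokenizer(
        ["UN Chief says there is no military solution in Syria"],
        return_tensors="tf",
        padding="max_length",
        max_length=32,
    )
    outputs = xla_generate(**inputs, max_new_tokens=16)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))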

@HuggingFaceDocBuilderDev commented May 30, 2022

The documentation is not available anymore as the PR was closed or merged.

@gante gante mentioned this pull request May 30, 2022
@gante gante marked this pull request as ready for review May 31, 2022 13:07
@gante (Member, Author) commented May 31, 2022

@ydshieh tagging you for TF review, as Matt is off and you are also familiar with generate :)

@ydshieh (Collaborator) commented May 31, 2022

@ydshieh tagging you for TF review, as Matt is off and you are also familiar with generate :)

Actually not very familiar, but would love to get more involved 😃. Thanks for tagging me!

@patrickvonplaten (Contributor) left a comment

Looks very nice to me! All the changes to modeling_tf_bart.py are 100% fine for me. I'd maybe just add one slow XLA test to test_modeling_tf_bart.py, and once this works we can proceed and make all the tests pass?

@ydshieh (Collaborator) left a comment

Hi @gante, I left a few comments so far.

Questions:

  • From the change in prepare_inputs_for_generation, both in the PR for TF-GPT2 and this PR, my understanding of the main change is that we need to use the (decoder) attention mask in order to calculate the correct position_ids for both left/right padding, and this is done using tf.math.cumsum. Do I understand these PRs correctly?

  • Why do we need decoder_position_ids when past_key_values is passed?

    if position_ids is None:
        if past_key_values is not None:
            raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

    I know you mentioned this guard is copied from Flax, but I am just wondering if it is a real necessity. (I feel if it is really necessary, the same guard should also exist in GPT-2.)

@ydshieh (Collaborator) left a comment

I continued a bit, but need to take a rest before looking at _update_model_kwargs_for_xla_generation.

src/transformers/models/bart/modeling_tf_bart.py:
        decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
    else:  # non xla + non past
        decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)

@ydshieh (Collaborator) commented Jun 4, 2022

As far as I understand, this (else) case (i.e. when past is None) does NOT require decoder_position_ids. TFBartLearnedPositionalEmbedding will take care of creating it.

Also, we don't need to broadcast here (unless XLA requires the explicit shape).

@gante (Member, Author) replied

That is correct, since the guard is only active when the past is passed (not the case here). However, the return statement would fail, because it is expecting a decoder_position_ids variable. I'd rather make it explicit than implicit :)

(removed the broadcast)

src/transformers/models/bart/modeling_tf_bart.py:
    # cut decoder_input_ids if past is used
    if past is not None:
        decoder_input_ids = decoder_input_ids[:, -1:]

    if decoder_attention_mask is not None:  # xla

@ydshieh (Collaborator) commented

nit: Following TF-GPT2's prepare_inputs_for_generation, maybe the following would be more consistent?

I would guess decoder_position_ids would never be passed to prepare_inputs_for_generation, so it will always be created here.
(But then the question becomes: why do we even have such a check in TF-GPT2?)

        decoder_position_ids = kwargs.get("decoder_position_ids", None)
        ...

        if decoder_position_ids is None:
            if decoder_attention_mask is not None:  # xla
                ...
            elif past is not None:  # non xla + past
                ...
            else:  # non xla + non past
                ...
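
For concreteness, a hedged, fleshed-out version of this suggested structure could look like the sketch below (the cumsum trick and the branch layout follow the surrounding discussion; this is not necessarily the merged implementation):

    import tensorflow as tf

    def prepare_inputs_for_generation(decoder_input_ids, past=None, decoder_attention_mask=None, **kwargs):
        decoder_position_ids = kwargs.get("decoder_position_ids", None)

        # cut decoder_input_ids if past is used
        if past is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        if decoder_position_ids is None:
            if decoder_attention_mask is not None:  # xla
                # position of the current token = number of tokens attended to before it
                decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
            elif past is not None:  # non xla + past
                # eager mode: the (non-padded) past length is the next position
                decoder_position_ids = tf.fill((decoder_input_ids.shape[0], 1), past[0][0].shape[2])
            else:  # non xla + non past
                decoder_position_ids = tf.broadcast_to(
                    tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape
                )

        return {
            "decoder_input_ids": decoder_input_ids,
            "past_key_values": past,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_position_ids": decoder_position_ids,
        }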

@ydshieh (Collaborator) commented

However, see my comment for the (merged) TF-GPT2 PR

https://github.com/huggingface/transformers/pull/17426/files#r889687584

@gante (Member, Author) commented Jun 14, 2022

I agree with your concerns, but I'm not going to worry about consistency for now :) I haven't managed to get beam search to work, so there is a chance I will have to rewrite all these generate-related functions.

Once beam search is working, then yes, I'd like to revisit these models and make a template for each kind (i.e. decoder-only or encoder-decoder models). Would that work for you, @ydshieh?

@ydshieh (Collaborator) replied

Good for me, @gante :-)

@ydshieh (Collaborator) left a comment

Left a few comments for _update_model_kwargs_for_xla_generation.

  • It's NP-hard for me to understand 😢, so I tried to add some comments in the code. Hopefully you will find them helpful and merge them.
  • Need some explanations on decoder_attention_mask 🙏

src/transformers/models/bart/modeling_tf_bart.py, lines +1461 to +1472:
    decoder_attention_mask = tf.concat(
        [
            tf.ones((batch_size, 1), dtype=tf.int32),
            tf.zeros((batch_size, num_padding_values), dtype=tf.int32),
            tf.ones((batch_size, 1), dtype=tf.int32),
        ],
        axis=1,
    )

@ydshieh (Collaborator) commented Jun 6, 2022

[Update]
It looks like this block is for the case where the generation from decoder_start_token_id has just been done, so num_padding_values would equal max_length - 1 - 1.
It's still not clear to me why we put the zeros before the second ones.

In general, I think it would be great if we could add more comments explaining things along the code.
(You definitely know things much better, but it would be beneficial to other developers 😄)

[Original comment]
I am not able to understand this block so far.

The decoder_attention_mask normally has the same length as the current input sequence.
I guess maybe you want to keep the shape fixed here (i.e. at max_length steps).
But this block gives a length of 2 + num_padding_values?

@ydshieh (Collaborator) commented Jun 6, 2022

Also, it looks like this discards the decoder_attention_mask in model_kwargs (if provided). In TF-GPT2, this case is handled. But probably there is an assumption that decoder_attention_mask is never provided to generate for TF-Bart, and will only be added in _update_model_kwargs_for_xla_generation?

@gante (Member, Author) replied

Yeah, this can definitely be improved (it was copied over from T5). I will not touch it in this PR though -- the beam search PR will rewrite it out of necessity 😭
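
As a hedged numeric illustration of the block under discussion (assuming batch_size = 1 and max_length = 5, so num_padding_values = max_length - 2 = 3; the per-slot comments are an interpretation, not something confirmed in the PR):

    import tensorflow as tf

    batch_size, num_padding_values = 1, 3
    decoder_attention_mask = tf.concat(
        [
            tf.ones((batch_size, 1), dtype=tf.int32),   # slot of decoder_start_token_id
            tf.zeros((batch_size, num_padding_values), dtype=tf.int32),  # not-yet-generated slots
            tf.ones((batch_size, 1), dtype=tf.int32),   # slot written at the current step
        ],
        axis=1,
    )
    # -> [[1, 0, 0, 0, 1]], i.e. total length 2 + num_padding_values == max_length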

@gante (Member, Author) commented Jun 6, 2022

Hey @ydshieh 👋 answering your questions:

From the change in prepare_inputs_for_generation, both in the PR for TF-GPT2 and this PR, my understanding of the main change is that we need to use the (decoder) attention mask in order to calculate the correct position_ids for both left/right padding, and this is done using tf.math.cumsum. Do I understand these PRs correctly?

Correct 👍
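
A hedged sketch of that idea with made-up values (not the exact PR code):

    import tensorflow as tf

    attention_mask = tf.constant([[0, 0, 1, 1, 1],   # left-padded prompt
                                  [1, 1, 1, 0, 0]])  # right-padded prompt
    # An exclusive cumsum counts how many attended tokens precede each position,
    # which is the position id we want for the real tokens; the values on padding
    # positions are masked out by the model anyway.
    position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
    # -> [[0, 0, 0, 1, 2],
    #     [0, 1, 2, 3, 3]]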

Why do we need decoder_position_ids when past_key_values is passed?

In the original PT code and in eager-execution TF, the position ids can be obtained by default (i.e. when not explicitly passed) from the past length, as the past length corresponds to the next position id when there is no left padding. In Flax and XLA TF, the past is zero-padded, so the past length is no longer the correct default position id. As such, it is dangerous to leave the default path active -- this path should only be used in generate anyway, and the updated generate passes the position ids. (GPT-2 should also get the same guard, to be safe!)
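
A hedged illustration of that point (the values and the mask layout are illustrative, not taken from the PR):

    import tensorflow as tf

    # Eager TF / PT: after generating N tokens, the past has length N, so the next
    # position id can default to the past length.
    # XLA TF / Flax: the past keeps a fixed, zero-padded length for the whole loop,
    # so its length no longer tells us the next position -- but the decoder
    # attention mask does:
    decoder_attention_mask = tf.constant([[1, 1, 0, 0, 0, 1]])  # start token, one generated token, padding, current step
    decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
    # -> [[2]], i.e. the token written at this step gets position 2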

@ydshieh (Collaborator) commented Jun 6, 2022

OK, I think I got it. The past sent to the model is the padded (on the right) version! (Which is required by XLA to have a fixed shape during the loop, right?)

Thank you @gante !

@ydshieh (Collaborator) commented Jun 6, 2022

I haven't thought this through thoroughly, but in prepare_inputs_for_generation, when we return the actual inputs to a model,

it seems to me that we could cut past to the actual (non-padded) version. And when the model returns past, in _update_model_kwargs_for_xla_generation, we would just always pad on the right.

(Of course, we would need to pass the current length info to prepare_inputs_for_generation if we want to do so.)

  • this would keep model_kwargs["past"] compatible with XLA
  • the actual past passed to the model would be the same as before
    • in particular, it would not have length max_length - 1, so we would no longer have overhead due to the ever-increasing length
  • it might make the logic a bit easier in _update_model_kwargs_for_xla_generation

@gante I don't want to keep you too busy. I will let you judge whether this is a good idea, and even if it is, whether we should change it now or later. I know we want to publish our work soon!

@gante (Member, Author) commented Jun 6, 2022

it seems to me that we could cut past to the actual (non-padded) version.

I would love to do that, and it would be a great way to simplify the code, but sadly XLA does not allow dynamic-sized slices (i.e. cutting past based on the current length or based on its non-zero values). I had the same idea too, but then I came across this limitation (documented here) 😢 Sadly, we have to keep working with the full padded array everywhere when XLA is on.
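
A hedged sketch of the limitation described here (the exact error wording may differ, but a slice whose size depends on a runtime value is expected to fail XLA compilation):

    import tensorflow as tf

    @tf.function(jit_compile=True)
    def cut_past(past_key, current_length):
        # The output shape depends on the runtime value `current_length`,
        # which XLA cannot compile into a fixed-shape program.
        return past_key[:, :, :current_length, :]

    # cut_past(tf.zeros((2, 4, 8, 16)), tf.constant(3))  # expected to fail at compile time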

@ydshieh (Collaborator) left a comment

I finally finished the review -- sorry it took so long.
Thank you for this awesome work, @gante! LGTM for the logic (thanks for the explanations!)

I would encourage explaining things a bit more along the code, as mentioned in a few of my comments.
(I only reviewed the Bart-related files.)

For example, this place

    decoder_attention_mask = tf.concat(

seems to suggest the prompt has sequence length 1 (i.e. [decoder_start_token_id]), and I am totally fine with that if this is the only use case for Bart (I believe so).

However, from the method itself, it looks like it can handle any prompt.

  • especially the treatment of past

A comment mentioning what (case/assumption) the block is dealing with would be great.

@patrickvonplaten (Contributor) commented

Think we can move towards finishing this PR here :-)

@gante (Member, Author) commented Jun 15, 2022

@patrickvonplaten it is ready to merge -- would you like to make a final review, or can I merge the PR? :)

@patrickvonplaten (Contributor) left a comment

Looks clean - good to go for me!

Would indeed be nice to eventually replace BART's slow temporary test with an XLA beam search test.

@gante gante merged commit 132402d into huggingface:main Jun 20, 2022
@gante gante deleted the xla_bart branch June 20, 2022 10:07
younesbelkada pushed a commit to younesbelkada/transformers that referenced this pull request Jun 25, 2022
* Also propagate changes to blenderbot, blenderbot_small, marian, mbart, and pegasus