
Add support for Keras mask & causal mask to MultiHeadAttention #16619

Merged (6 commits) on Jul 11, 2022

Conversation

@ageron (Contributor) commented May 29, 2022

This is a first go at adding Keras masking support and causal masking support to the MultiHeadAttention layer.

See the discussion in #16248.

@ageron changed the title from "Add support for Keras masking and causal masking" to "Add support for Keras mask & causal mask to MultiHeadAttention" on May 29, 2022
@fchollet (Member) left a comment

Thank you for the PR! For consistency, I would favor the same API as the one used in KerasNLP here: https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/layers/transformer_decoder.py#L191

That is, a use_causal_mask argument in call(). Otherwise it would not change the PR much.

@fchollet (Member)

FYI @mattdangerw @chenmoneygithub

@ageron (Contributor, Author) commented May 29, 2022

Hi @fchollet , thanks for reviewing this PR.

Note that the keras.layers.Attention layer uses an init parameter called causal, as noted by @haifeng-jin in #16248, so it seems to me that consistency within Keras matters more than consistency with KerasNLP?

That said, I just made the change you requested and I'm happy to push it if you prefer. I don't have a strong opinion here: perhaps it's the Attention layer that should be updated to use use_causal_mask in the call() method.

@fchollet (Member)

This is a good question -- we should make a choice one way or the other and standardize everything (Keras, KerasNLP) on that choice. Let me chat with the team and we'll figure out what to do.

@chenmoneygithub (Contributor)

@fchollet To me, causal is not the best argument name, as it is too abstract. I did a code search and there are only a few usages of the causal argument of tf.keras.layers.Attention in google3, so changing the argument name won't be too hard.

@ageron (Contributor, Author) commented May 30, 2022

Sounds good. I just replaced causal in the constructor with use_causal_mask in the call() method.
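
For reference, here is a rough usage sketch of the resulting call()-level API (toy shapes and values, assuming a Keras build that includes this change; not code from the PR itself):

```python
import tensorflow as tf

# Toy example: the Embedding layer's mask_zero=True produces an implicit
# Keras padding mask that MultiHeadAttention can now pick up automatically.
embed = tf.keras.layers.Embedding(input_dim=1000, output_dim=64, mask_zero=True)
mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=16)

token_ids = tf.constant([[3, 7, 12, 0, 0]])  # trailing zeros are padding
x = embed(token_ids)                         # carries an implicit Keras mask

# use_causal_mask=True adds the lower-triangular (causal) mask on top of
# the implicit padding mask.
y = mha(query=x, value=x, use_causal_mask=True)
```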

@fchollet (Member)

I also think it makes sense to have the argument in call() because that's where we pass the other masking-related arguments. So let's keep use_causal_mask in call().

@chenmoneygithub could you please file an issue to change the Attention layer API to adopt this convention?

@chenmoneygithub (Contributor)

@fchollet Sure!

@fchollet (Member) left a comment

LGTM

google-ml-butler bot added the kokoro:force-run and ready to pull labels on Jun 2, 2022
@mattdangerw (Member)

@ageron did some more testing and came across a few things.

  1. I think we need to be more defensive about assuming mask types. We have real-world usage where the mask (explicit or implicit) is floating point, which leads to errors when doing the bitwise &. Maybe we can just make sure we cast all masks to bool before working with them (see the sketch after this list)? We should probably add a test for this too.

  2. In an encoder where Q, K, and V are all the same, with an implicit mask, this code looks like it will build an attention mask with zeros along the bottom and along the right. I am used to seeing an attention mask with zeros along only one of those dimensions. This colab shows the difference. Am I reading your code correctly? Do you know what the right approach is here?
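
A minimal sketch of the defensive casting suggested in point 1 (the helper name and structure are illustrative, not the PR's actual code):

```python
import tensorflow as tf

def merge_masks(padding_mask, causal_mask):
    """Combine two attention masks, casting to bool first so that a float
    0./1. mask does not break the logical AND."""
    if padding_mask is not None:
        padding_mask = tf.cast(padding_mask, tf.bool)
    if causal_mask is not None:
        causal_mask = tf.cast(causal_mask, tf.bool)
    if padding_mask is None:
        return causal_mask
    if causal_mask is None:
        return padding_mask
    return tf.logical_and(padding_mask, causal_mask)
```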

@ageron (Contributor, Author) commented Jun 15, 2022

Hi @mattdangerw ,
Sorry for the late response, got Covid in the family.
Thanks for the careful review, you make two good points.
I'll take care of the mask types.
However, regarding the padding difference, after some thought I'm not sure it matters, since it only affects masked-out tokens, whose outputs will normally be ignored anyway.

For example, let's take a three word sentence padded to 5 tokens like this:

"I am happy <pad> <pad>"

With causal self-attention, we want "I" to attend only to "I", "am" to attend only to "I am", and "happy" to attend only to "I am happy". So far so good. Beyond that, it doesn't really matter what the two padding tokens attend to, since their output representations will be ignored downstream anyway.

I tried to think of ways in which it could matter, but I couldn't find any, other than possibly a performance difference. Wdyt?
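
To make the comparison concrete, here is a small sketch of the two mask variants for the 5-token example above (illustrative code, not taken from the PR):

```python
import tensorflow as tf

# "I am happy <pad> <pad>": True marks real tokens.
padding = tf.constant([True, True, True, False, False])

# Causal mask: query i may attend to keys j <= i (lower triangle).
causal = tf.linalg.band_part(tf.ones((5, 5)), -1, 0) > 0

# Variant 1: mask padded keys only (zeros along the right).
keys_masked = tf.logical_and(causal, padding[tf.newaxis, :])

# Variant 2: also mask padded queries (zeros along the bottom too).
both_masked = tf.logical_and(keys_masked, padding[:, tf.newaxis])

# The two variants differ only in rows 3 and 4, i.e. the <pad> queries,
# whose output representations are ignored downstream anyway.
```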

@mattdangerw (Member)

> I tried to think of ways in which it could matter, but I couldn't find any, other than possibly a performance difference. Wdyt?

Yeah that makes sense to me. Thanks for the explanation. Let's go with what we have up.

Could you add the defensive casting to bool types for the masks, and a test?

@ageron (Contributor, Author) commented Jun 25, 2022

Hi @mattdangerw ,
Thanks for your feedback. I'll add the defensive casting and a test by ~Wednesday; sounds good?

google-ml-butler bot removed the ready to pull label on Jun 26, 2022
@ageron (Contributor, Author) commented Jun 26, 2022

Hi @mattdangerw, I just added defensive casting to bool (with tests).

@mattdangerw (Member)

@ageron thank you so much! I will kick off another round of testing on this.

google-ml-butler bot added the kokoro:force-run and ready to pull labels on Jun 29, 2022
@mattdangerw (Member)

(approving just to trigger our import flow for testing)

@mattdangerw (Member)

Still looking at this. We have one production failure from an oddly shaped Keras implicit mask we are trying to figure out.

Hopefully no more action needed here, but will ping if anything comes up!

Labels: ready to pull, size:M
6 participants