Add support for Keras mask & causal mask to MultiHeadAttention #16619
Conversation
Thank you for the PR! For consistency, I would favor the same API as the one used in KerasNLP here: https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/layers/transformer_decoder.py#L191 (as in, a `use_causal_mask` argument in `call()`). It would not change the PR much otherwise.
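For reference, a minimal usage sketch of the proposed API from the user's side, assuming the argument lands on `call()` as `use_causal_mask` exactly as suggested:

```python
import tensorflow as tf

# Sketch of the proposed API (assumes a `use_causal_mask` call argument).
mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)
x = tf.random.normal((1, 5, 16))  # (batch, seq_len, features)

# Each position attends only to itself and earlier positions.
y = mha(query=x, value=x, use_causal_mask=True)
print(y.shape)  # (1, 5, 16)
```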
Hi @fchollet, thanks for reviewing this PR. That said, I just made the change you requested, and I'm happy to push it if you prefer. I don't have a strong opinion about this.
This is a good question -- we should make a choice one way or the other and standardize everything (Keras, KerasNLP) on that choice. Let me chat with the team and we'll figure out what to do.
@fchollet To me
Sounds good. I just replaced
I also think it makes sense to have the argument in `call()`. @chenmoneygithub, could you please file an issue to change the corresponding KerasNLP code?
@fchollet Sure!
LGTM
@ageron did some more testing and came across a few things.
Hi @mattdangerw, for example, let's take a three-word sentence padded to 5 tokens, like this: "I am happy <pad> <pad>". With causal self-attention, we want "I" to attend only to "I", "am" to attend only to "I am", and "happy" to attend only to "I am happy". So far so good. Next, it doesn't really matter what the two padding tokens attend to, since their output representation will be ignored downstream anyway. I tried to think of ways in which it could matter, but I couldn't find any, other than possibly a performance difference. Wdyt?
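To make that example concrete, here is a small NumPy sketch (illustration only, not the PR's code) of the combined causal-plus-padding mask for that sentence:

```python
import numpy as np

seq_len = 5          # "I am happy <pad> <pad>"
n_real_tokens = 3

# Causal mask: position i may attend to positions 0..i (lower triangular).
causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))

# Padding mask over keys: padding tokens should not be attended to.
key_padding_mask = np.arange(seq_len) < n_real_tokens   # [T, T, T, F, F]

combined = causal_mask & key_padding_mask[np.newaxis, :]
print(combined.astype(int))
# [[1 0 0 0 0]    "I"     attends to "I"
#  [1 1 0 0 0]    "am"    attends to "I am"
#  [1 1 1 0 0]    "happy" attends to "I am happy"
#  [1 1 1 0 0]    <pad>   rows: their outputs are ignored downstream,
#  [1 1 1 0 0]]           so what they attend to does not matter
```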
Yeah, that makes sense to me. Thanks for the explanation. Let's go with what we have up. Could you add the defensive casting to bool types for the masks, and a test?
Hi @mattdangerw,
Hi @mattdangerw, I just added defensive casting to bool (with tests).
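For illustration, a hedged sketch of what defensive bool casting can look like when merging an attention mask with a causal mask; the helper name and structure below are hypothetical, not the PR's exact code:

```python
import tensorflow as tf

def combine_masks(attention_mask, causal_mask):
    """Hypothetical helper: cast incoming masks to bool before combining them,
    since user-provided or Keras-computed masks may arrive as int or float."""
    if attention_mask is not None:
        attention_mask = tf.cast(attention_mask, tf.bool)
    if causal_mask is not None:
        causal_mask = tf.cast(causal_mask, tf.bool)
    if attention_mask is None:
        return causal_mask
    if causal_mask is None:
        return attention_mask
    return tf.logical_and(attention_mask, causal_mask)
```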
@ageron thank you so much! I will kick off another round of testing on this.
(approving just to trigger our import flow for testing)
Still looking at this. We have one production failure from an oddly shaped Keras implicit mask we are trying to figure out. Hopefully no more action needed here, but will ping if anything comes up!
This is a first go at adding Keras masking support and causal masking support to the MultiHeadAttention layer. See the discussion in #16248.
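As a rough sketch of the behavior this PR aims to enable (the `use_causal_mask` call argument discussed above is assumed), the implicit Keras mask produced by `Embedding(mask_zero=True)` would be picked up and combined with the causal mask automatically:

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(None,), dtype="int32")
# mask_zero=True makes Embedding emit an implicit Keras mask for padding tokens.
x = tf.keras.layers.Embedding(input_dim=1000, output_dim=32, mask_zero=True)(inputs)
# With mask support, MultiHeadAttention should pick up that implicit mask and
# combine it with the causal mask, without an explicit attention_mask argument.
x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=32)(
    x, x, use_causal_mask=True
)
model = tf.keras.Model(inputs, x)
model.summary()
```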