Use Value dim shape for Attention compute_output_shape #19284
Conversation
Thanks for the PR! Unit tests are failing, I think some golden values need to be updated.
def test_attention_compute_output_shape(self):
    layer = layers.Attention()
    input_shape = [(2, 8, 7), (2, 8, 5), (2, 8, 7)]  # Shapes of Q, V, K
    self.assertAllEqual(layer.compute_output_shape(input_shape), (2, 8, 5))
Please call the layer on an input and read its shape, to ensure the computed shape matches the actual output shape.
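A minimal sketch of that check (the random inputs and variable names are illustrative, assuming the Keras 3 functional API):

import numpy as np
from keras import layers

# Build random Q/V/K tensors matching the shapes used in the test,
# call the layer, and compare the actual output shape with the
# shape reported by compute_output_shape.
query = np.random.random((2, 8, 7)).astype("float32")
value = np.random.random((2, 8, 5)).astype("float32")
key = np.random.random((2, 8, 7)).astype("float32")

layer = layers.Attention()
output = layer([query, value, key])

computed = layer.compute_output_shape([query.shape, value.shape, key.shape])
assert tuple(output.shape) == tuple(computed) == (2, 8, 5)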
Updated the test to check against output.shape.
Force-pushed from 9b26297 to ececf6e.
@@ -173,9 +173,6 @@ def _parse_name(self, name):
             return "float16", "float32"
         elif name == "mixed_bfloat16":
             return "bfloat16", "float32"
-        elif name == "uint8":
This is redundant. It's already addressed in the try block, so removing it.
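A rough sketch of the pattern being referred to (not the exact Keras source): any plain dtype name, including "uint8", falls through to the generic standardization in the try block, so a dedicated elif branch is unnecessary.

from keras import backend

def _parse_name_sketch(name):
    # Mixed-precision policies are handled explicitly.
    if name == "mixed_float16":
        return "float16", "float32"
    elif name == "mixed_bfloat16":
        return "bfloat16", "float32"
    # Any other dtype name ("uint8", "float32", ...) is standardized here,
    # which is why a separate "uint8" branch is redundant.
    try:
        dtype = backend.standardize_dtype(name)
        return dtype, dtype
    except ValueError:
        raise ValueError(f"Cannot convert '{name}' to a dtype policy.")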
LGTM, thanks!
- …nse`
- Add qlora-like technique to `quantized_call` in `Dense`
- Update `save_own_variables` and `load_own_variables`
- Update `benchmark.py`
- update version string.
- Set dtype policy for uint8 (keras-team#19327)
  * Set Quantization policy for uint8 to float
  * Add uint8 to dtype_policies
- Use Value dim shape for Attention compute_output_shape (keras-team#19284)
  * Use Value dim shape for Attention compute_output_shape
  * Fix attention layer compute output shape
  * fix format
  * check compute_output_shape with output
- Update `quantized_call` in `EinsumDense` to support training with quantized weights
Fixes #19257 by using the Value dim shape for `compute_output_shape` of the Attention layer.
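For context, the change amounts to deriving the last output dimension from the value tensor rather than the query. A hedged sketch of that logic as a standalone helper (the function name is hypothetical, not the merged implementation):

def attention_output_shape(input_shape):
    # input_shape is [query_shape, value_shape] or [query_shape, value_shape, key_shape].
    query_shape, value_shape = input_shape[0], input_shape[1]
    # Batch and time dims come from the query; the feature dim comes from the value.
    return (*query_shape[:-1], value_shape[-1])

print(attention_output_shape([(2, 8, 7), (2, 8, 5), (2, 8, 7)]))  # (2, 8, 5)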