From b2ef949cceb01c53d231a4da9cbfbaa12cea981d Mon Sep 17 00:00:00 2001
From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Date: Mon, 18 Mar 2024 15:25:43 -0700
Subject: [PATCH] Use Value dim shape for Attention compute_output_shape
 (#19284)

* Use Value dim shape for Attention compute_output_shape

* Fix attention layer compute output shape

* fix format

* check compute_output_shape with output

---
 keras/dtype_policies/dtype_policy.py     |  3 ---
 keras/layers/attention/attention.py      |  3 ++-
 keras/layers/attention/attention_test.py | 15 +++++++++++++++
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/keras/dtype_policies/dtype_policy.py b/keras/dtype_policies/dtype_policy.py
index c546022f5f5..7c8586c2a6e 100644
--- a/keras/dtype_policies/dtype_policy.py
+++ b/keras/dtype_policies/dtype_policy.py
@@ -173,9 +173,6 @@ def _parse_name(self, name):
             return "float16", "float32"
         elif name == "mixed_bfloat16":
             return "bfloat16", "float32"
-        elif name == "uint8":
-            dtype = backend.standardize_dtype(name)
-            return dtype, dtype
         try:
             dtype = backend.standardize_dtype(name)
             return dtype, dtype
diff --git a/keras/layers/attention/attention.py b/keras/layers/attention/attention.py
index 22586e02959..b42b4c05634 100644
--- a/keras/layers/attention/attention.py
+++ b/keras/layers/attention/attention.py
@@ -242,7 +242,8 @@ def compute_mask(self, inputs, mask=None):
         return ops.convert_to_tensor(mask[0])
 
     def compute_output_shape(self, input_shape):
-        return input_shape[0]
+        """Returns the query shape, but with value's feature dim last."""
+        return (*input_shape[0][:-1], input_shape[1][-1])
 
     def _validate_inputs(self, inputs, mask=None):
         """Validates arguments of the call method."""
diff --git a/keras/layers/attention/attention_test.py b/keras/layers/attention/attention_test.py
index c6010b67461..102717994ea 100644
--- a/keras/layers/attention/attention_test.py
+++ b/keras/layers/attention/attention_test.py
@@ -342,3 +342,18 @@ def test_attention_compute_mask_with_different_input_shapes(self):
         computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)
         computed_mask = ops.convert_to_numpy(computed_mask)
         self.assertTrue(np.array_equal(computed_mask, valid_mask))
+
+    def test_attention_compute_output_shape(self):
+        layer = layers.Attention()
+
+        query = np.random.random((2, 3, 4))
+        value = np.random.random((2, 3, 5))
+        key = np.random.random((2, 3, 4))
+        output = layer([query, value, key])
+        self.assertAllEqual(output.shape, value.shape)
+        self.assertAllEqual(
+            layer.compute_output_shape(
+                input_shape=[query.shape, value.shape, key.shape]
+            ),
+            output.shape,
+        )
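
Note on the dtype_policy.py hunk: the deleted "uint8" branch in _parse_name was
redundant, because the generic fallback directly below it standardizes the name
and returns the same (dtype, dtype) pair, so "uint8" parses identically after
the removal. A simplified sketch of the surviving control flow (a hypothetical
stand-in, not the real Keras method; standardize is a placeholder for
backend.standardize_dtype):

def _parse_name_sketch(name):
    # Hypothetical, simplified stand-in for DTypePolicy._parse_name.
    def standardize(n):
        return n  # placeholder for backend.standardize_dtype

    if name == "mixed_float16":
        return "float16", "float32"
    elif name == "mixed_bfloat16":
        return "bfloat16", "float32"
    # The generic path below already covers "uint8" and every other plain
    # dtype name, which is why the dedicated branch could be deleted.
    dtype = standardize(name)
    return dtype, dtype

assert _parse_name_sketch("uint8") == ("uint8", "uint8")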
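
Note on the attention.py hunk: the new compute_output_shape mirrors the shape
algebra of dot-product attention. The score matrix softmax(Q @ K^T) has shape
(batch, T_q, T_v), so the weighted sum over value keeps the query's sequence
length but inherits the value's feature dimension; the old
"return input_shape[0]" was wrong whenever query and value had different last
dims. A minimal plain-NumPy sketch (illustrative only, not the Keras
implementation):

import numpy as np

batch, t_q, t_v, dim_k, dim_v = 2, 3, 6, 4, 5
query = np.random.random((batch, t_q, dim_k))
key = np.random.random((batch, t_v, dim_k))
value = np.random.random((batch, t_v, dim_v))

scores = query @ key.transpose(0, 2, 1)         # (batch, t_q, t_v)
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)  # softmax over t_v
output = weights @ value                        # (batch, t_q, dim_v)

# Query length is kept; the feature dim comes from value, matching
# (*input_shape[0][:-1], input_shape[1][-1]) in the patch.
assert output.shape == (batch, t_q, dim_v)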