Use Value dim shape for Attention compute_output_shape (#19284)
* Use Value dim shape for Attention compute_output_shape

* Fix attention layer compute output shape

* fix format

* check compute_output_shape with output
sampathweb authored Mar 18, 2024
1 parent 4c35630 commit b2ef949
Showing 3 changed files with 18 additions and 4 deletions.
keras/dtype_policies/dtype_policy.py (0 additions & 3 deletions)
@@ -173,9 +173,6 @@ def _parse_name(self, name):
             return "float16", "float32"
         elif name == "mixed_bfloat16":
             return "bfloat16", "float32"
-        elif name == "uint8":
-            dtype = backend.standardize_dtype(name)
-            return dtype, dtype
         try:
             dtype = backend.standardize_dtype(name)
             return dtype, dtype
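
The deleted branch was redundant rather than wrong: a plain dtype name such as "uint8" already falls through to the generic try block, which standardizes the name and returns it as both the compute and the variable dtype. A minimal stand-alone sketch of the resulting control flow, where standardize() is only a placeholder for backend.standardize_dtype and the real method's error handling is omitted:

    # Simplified stand-in for DTypePolicy._parse_name after this commit.
    def _parse_name(name, standardize=lambda n: n):
        if name == "mixed_float16":
            return "float16", "float32"
        elif name == "mixed_bfloat16":
            return "bfloat16", "float32"
        # "uint8" (like any other plain dtype name) takes this generic
        # path, which is why the dedicated branch above could be dropped.
        dtype = standardize(name)
        return dtype, dtype

    assert _parse_name("uint8") == ("uint8", "uint8")
    assert _parse_name("mixed_bfloat16") == ("bfloat16", "float32")
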
keras/layers/attention/attention.py (2 additions & 1 deletion)
@@ -242,7 +242,8 @@ def compute_mask(self, inputs, mask=None):
         return ops.convert_to_tensor(mask[0])
 
     def compute_output_shape(self, input_shape):
-        return input_shape[0]
+        """Returns the query shape, with the last dim taken from the value."""
+        return (*input_shape[0][:-1], input_shape[1][-1])
 
     def _validate_inputs(self, inputs, mask=None):
         """Validates arguments of the call method."""
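
Why the value's last dim is the right choice: in dot-product attention the scores have shape (batch, T_q, T_v), and multiplying them with a value tensor of shape (batch, T_v, dim_v) yields (batch, T_q, dim_v). The old code returned the query's shape unchanged, which is wrong whenever dim_v differs from the query's feature dim. A minimal NumPy sketch of that shape algebra (an illustration of the principle, not the Keras layer's internals):

    import numpy as np

    batch, t_q, t_v, dim_k, dim_v = 2, 3, 6, 4, 5
    query = np.random.random((batch, t_q, dim_k))
    key = np.random.random((batch, t_v, dim_k))
    value = np.random.random((batch, t_v, dim_v))

    scores = query @ key.transpose(0, 2, 1)         # (batch, t_q, t_v)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over t_v
    output = weights @ value                        # (batch, t_q, dim_v)

    # Matches the new compute_output_shape: query dims, value's last dim.
    assert output.shape == (*query.shape[:-1], value.shape[-1])
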
keras/layers/attention/attention_test.py (16 additions & 0 deletions)
@@ -342,3 +342,19 @@ def test_attention_compute_mask_with_different_input_shapes(self):
         computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)
         computed_mask = ops.convert_to_numpy(computed_mask)
         self.assertTrue(np.array_equal(computed_mask, valid_mask))
+
+    def test_attention_compute_output_shape(self):
+        layer = layers.Attention()
+
+        query = np.random.random((2, 3, 4))
+        value = np.random.random((2, 3, 5))
+        key = np.random.random((2, 3, 4))
+        layer = layers.Attention()
+        output = layer([query, value, key])
+        self.assertAllEqual(output.shape, value.shape)
+        self.assertAllEqual(
+            layer.compute_output_shape(
+                input_shape=[query.shape, value.shape, key.shape]
+            ),
+            output.shape,
+        )
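
The test above exercises concrete shapes only; compute_output_shape also accepts symbolic shapes with unknown dims, such as a None batch size. A small usage sketch (assuming the standard Keras 3 import path):

    from keras import layers

    layer = layers.Attention()
    # Shapes are (batch, timesteps, features); None marks an unknown batch.
    out_shape = layer.compute_output_shape(
        input_shape=[(None, 3, 4), (None, 3, 5), (None, 3, 4)]
    )
    assert out_shape == (None, 3, 5)  # query length, value feature dim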
