Skip to content

Commit

Permalink
Fix mixed precision check in TestCase.run_layer_test: compare with output_spec dtype instead of hardcoded float16 (#19297)
Browse files Browse the repository at this point in the history

* Fix mixed precision check: compare with output spec dtype instead of hardcoded float16

* Revert "Fix mixed precision check: compare with output spec dtype instead of hardcoded float16"

This reverts commit 94d5584.

* Restore changes

* Trying to reformat code

* Restore formatted code

* Fix formatting

* Fix black latest wrong formatting
  • Loading branch information
shkarupa-alex authored Mar 14, 2024
1 parent 6a266b8 commit c591329
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions keras/testing/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,14 +519,29 @@ def data_generator():

if run_mixed_precision_check:
layer = layer_cls(**{**init_kwargs, "dtype": "mixed_float16"})
input_spec = tree.map_structure(
lambda spec: KerasTensor(
spec.shape,
dtype=(
layer.compute_dtype
if layer.autocast
and backend.is_float_dtype(spec.dtype)
else spec.dtype
),
),
keras_tensor_inputs,
)
if isinstance(input_data, dict):
output_data = layer(**input_data, **call_kwargs)
output_spec = layer.compute_output_spec(**input_spec)
else:
output_data = layer(input_data, **call_kwargs)
for tensor in tree.flatten(output_data):
output_spec = layer.compute_output_spec(input_spec)
for tensor, spec in zip(
tree.flatten(output_data), tree.flatten(output_spec)
):
dtype = standardize_dtype(tensor.dtype)
if is_float_dtype(dtype):
self.assertEqual(dtype, "float16")
self.assertEqual(dtype, spec.dtype)
for weight in layer.weights:
dtype = standardize_dtype(weight.dtype)
if is_float_dtype(dtype):
Expand Down

0 comments on commit c591329

Please sign in to comment.