Skip to content

Commit

Permalink
Fix mixed precision check: compare with output spec dtype instead of …
Browse files Browse the repository at this point in the history
…hardcoded float16
  • Loading branch information
shkarupa-alex committed Mar 12, 2024
1 parent f189dec commit 94d5584
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions keras/testing/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,14 +417,26 @@ def data_generator():

if run_mixed_precision_check:
layer = layer_cls(**{**init_kwargs, "dtype": "mixed_float16"})
input_spec = tree.map_structure(
lambda spec: KerasTensor(
spec.shape,
dtype=layer.compute_dtype
if layer.autocast and backend.is_float_dtype(spec.dtype)
else spec.dtype,
),
keras_tensor_inputs,
)
if isinstance(input_data, dict):
output_data = layer(**input_data, **call_kwargs)
output_spec = layer.compute_output_spec(**input_spec)
else:
output_data = layer(input_data, **call_kwargs)
for tensor in tree.flatten(output_data):
output_spec = layer.compute_output_spec(input_spec)
for tensor, spec in zip(
tree.flatten(output_data), tree.flatten(output_spec)
):
dtype = standardize_dtype(tensor.dtype)
if is_float_dtype(dtype):
self.assertEqual(dtype, "float16")
self.assertEqual(dtype, spec.dtype)
for weight in layer.weights:
dtype = standardize_dtype(weight.dtype)
if is_float_dtype(dtype):
Expand Down

0 comments on commit 94d5584

Please sign in to comment.