diff --git a/keras/testing/test_case.py b/keras/testing/test_case.py index da5567fec57..fb8d917c531 100644 --- a/keras/testing/test_case.py +++ b/keras/testing/test_case.py @@ -417,14 +417,29 @@ def data_generator(): if run_mixed_precision_check: layer = layer_cls(**{**init_kwargs, "dtype": "mixed_float16"}) + input_spec = tree.map_structure( + lambda spec: KerasTensor( + spec.shape, + dtype=( + layer.compute_dtype + if layer.autocast + and backend.is_float_dtype(spec.dtype) + else spec.dtype + ), + ), + keras_tensor_inputs, + ) if isinstance(input_data, dict): output_data = layer(**input_data, **call_kwargs) + output_spec = layer.compute_output_spec(**input_spec) else: output_data = layer(input_data, **call_kwargs) - for tensor in tree.flatten(output_data): + output_spec = layer.compute_output_spec(input_spec) + for tensor, spec in zip( + tree.flatten(output_data), tree.flatten(output_spec) + ): dtype = standardize_dtype(tensor.dtype) - if is_float_dtype(dtype): - self.assertEqual(dtype, "float16") + self.assertEqual(dtype, spec.dtype) for weight in layer.weights: dtype = standardize_dtype(weight.dtype) if is_float_dtype(dtype):