diff --git a/jax/_src/callback.py b/jax/_src/callback.py index b0b66b2f4e75..0942d04ab44c 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -244,7 +244,7 @@ def _check_shape_dtype(shape_dtype): dt = np.dtype(shape_dtype.dtype) if dtypes.canonicalize_dtype(dt) != dt: raise ValueError( - "Cannot return 64-bit values when `jax_enable_x64` is disabled") + "result_shape_dtypes cannot specify 64-bit types when `jax_enable_x64` is disabled") def pure_callback( diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index d2572b5d2d05..239559a4740f 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2469,16 +2469,12 @@ def _wrapped_callback(*args): "Mismatched number of outputs from callback. " "Expected: {}, Actual: {}".format(len(result_avals), len(out_vals))) # Handle Python literals, and custom arrays, e.g., tf.Tensor. - out_vals = tuple(np.asarray(a) for a in out_vals) + out_vals = tuple(xla.canonicalize_dtype(np.asarray(a)) for a in out_vals) for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)): if out_val.shape != out_aval.shape: raise RuntimeError( f"Incorrect output shape for return value #{i}: " f"Expected: {out_aval.shape}, Actual: {out_val.shape}") - if out_val.dtype != dtypes.canonicalize_dtype(out_val.dtype): - raise RuntimeError( - "Cannot return 64-bit values when `jax_enable_x64` is disabled. " - f"Actual: {out_val.dtype}") if out_val.dtype != out_aval.dtype: raise RuntimeError( f"Incorrect output dtype for return value #{i}: " diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 0cd51631ad20..2329fe65e052 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -117,10 +117,10 @@ def f(x): lambda x: returned_literal, core.ShapedArray((), expect_dtype), x ) - if not config.enable_x64.value: - ctx = self.assertRaisesRegex(Exception, "Cannot return 64-bit values") - elif expect_dtype in (np.int32, np.float32): - ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype") + if not config.enable_x64.value and expect_dtype in (np.int64, np.float64): + ctx = self.assertRaisesRegex(Exception, "result_shape_dtypes cannot specify 64-bit types") + elif config.enable_x64.value and expect_dtype in (np.int32, np.float32): + ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value") else: ctx = contextlib.nullcontext() @@ -247,7 +247,7 @@ def f(): jax.effects_barrier() @with_pure_and_io_callbacks - def test_callback_with_wrong_dtype_outputs(self, *, callback): + def test_callback_with_wrong_dtype_outputs(self, *, callback=io_callback_ordered): def _cb(): return np.array([1], np.float64) @@ -257,9 +257,14 @@ def f(): # Calling a function expected a f32 return value but getting f64 return callback(_cb, core.ShapedArray((1,), np.float32)) - with self.assertRaises(RuntimeError): - f() + if config.enable_x64.value: + ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value") + else: + ctx = contextlib.nullcontext() + with ctx: + res = f() jax.effects_barrier() + self.assertAllClose(res, np.array([1], np.float32)) @with_pure_and_io_callbacks def test_callback_with_wrongly_specified_64_bit_dtype(self, *, callback):