diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 00da3802d80c..cf16383bead0 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2461,15 +2461,22 @@ def _wrapped_callback(*args): raise RuntimeError( "Mismatched number of outputs from callback. " "Expected: {}, Actual: {}".format(len(result_avals), len(out_vals))) + # Handle Python literals, and custom arrays, e.g., tf.Tensor. + out_vals = tuple(np.asarray(a) for a in out_vals) for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)): if out_val.shape != out_aval.shape: raise RuntimeError( - f"Incorrect output shape for return value {i}: " - "Expected: {}, Actual: {}".format(out_aval.shape, out_val.shape)) + f"Incorrect output shape for return value #{i}: " + f"Expected: {out_aval.shape}, Actual: {out_val.shape}") + if out_val.dtype != dtypes.canonicalize_dtype(out_val.dtype): + raise RuntimeError( + "Cannot return 64-bit values when `jax_enable_x64` is disabled. " + f"Actual: {out_val.dtype}") if out_val.dtype != out_aval.dtype: raise RuntimeError( - f"Incorrect output dtype for return value {i}: " - "Expected: {}, Actual: {}".format(out_aval.dtype, out_val.dtype)) + f"Incorrect output dtype for return value #{i}: " + f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}") + if platform == "tpu": # On TPU we cannot receive empty arrays. So, we return from the wrapped # callback only the non-empty results, and we will create empty constants diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index ff743200ab80..a84740026e91 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +import contextlib import functools import logging import textwrap @@ -27,11 +28,11 @@ from jax._src import core from jax._src import dispatch from jax._src import maps -from jax._src.maps import xmap from jax._src import test_util as jtu from jax._src import util from jax._src.lib import xla_client from jax._src.lib import xla_extension_version +from jax._src.maps import xmap from jax.experimental import io_callback from jax.experimental import pjit from jax.experimental.shard_map import shard_map @@ -72,6 +73,7 @@ def tearDownModule(): for flavor in ("io_unordered", "io_ordered", "pure") ) + class PythonCallbackTest(jtu.JaxTestCase): def setUp(self): @@ -93,6 +95,73 @@ def f(x): out = f(0.) self.assertEqual(out, 1.) + @parameterized.named_parameters( + dict( + testcase_name=f"{flavor}_expect_dtype_{expect_dtype}", + callback=dict( + io_unordered=io_calback_unordered, + io_ordered=io_callback_ordered, + pure=jax.pure_callback, + )[flavor], + expect_dtype=expect_dtype, + ) + for flavor in ("io_unordered", "io_ordered", "pure") + for expect_dtype in (np.int32, np.int64, np.float32, np.float64) + ) + def test_callback_returning_python_literal(self, *, callback, expect_dtype): + returned_literal = 42 if expect_dtype in (np.int32, np.int64) else 42.0 + + @jax.jit + def f(x): + return callback( + lambda x: returned_literal, core.ShapedArray((), expect_dtype), x + ) + + if not config.enable_x64.value: + ctx = self.assertRaisesRegex(Exception, "Cannot return 64-bit values") + elif expect_dtype in (np.int32, np.float32): + ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype") + else: + ctx = contextlib.nullcontext() + + with ctx: + out = f(0.0) + jax.effects_barrier() + self.assertEqual(out, returned_literal) + + @with_pure_and_io_callbacks + def test_callback_returning_custom_array(self, *, callback): + # Some users write the callback in TF, returning a tf.Tensor. We don't + # want to add TF as a dependency, but simulate that use case with a + # custom array class. + class CustomArray: + + def __init__(self, a: np.ndarray): + self.a = a + + @property + def shape(self): + return self.a.shape + + @property + def dtype(self): + return self.a.dtype + + def __array__(self): + return self.a + + @jax.jit + def f(x): + return callback( + lambda x: CustomArray(np.array(42.0, dtype=np.float32)), + core.ShapedArray((), np.float32), + x, + ) + + out = f(0.0) + jax.effects_barrier() + self.assertEqual(out, 42.0) + @parameterized.named_parameters( dict(testcase_name=f"{flavor}_{dtype}", dtype=dtype,