Skip to content

Commit

Permalink
Merge pull request #20534 from gnecula:callback_64bit
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621507392
  • Loading branch information
jax authors committed Apr 3, 2024
2 parents dcd45c8 + 35b1cb7 commit d89f0d6
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion jax/_src/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def _check_shape_dtype(shape_dtype):
dt = np.dtype(shape_dtype.dtype)
if dtypes.canonicalize_dtype(dt) != dt:
raise ValueError(
"Cannot return 64-bit values when `jax_enable_x64` is disabled")
"result_shape_dtypes cannot specify 64-bit types when `jax_enable_x64` is disabled")


def pure_callback(
Expand Down
6 changes: 1 addition & 5 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2469,16 +2469,12 @@ def _wrapped_callback(*args):
"Mismatched number of outputs from callback. "
"Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
# Handle Python literals, and custom arrays, e.g., tf.Tensor.
out_vals = tuple(np.asarray(a) for a in out_vals)
out_vals = tuple(xla.canonicalize_dtype(np.asarray(a)) for a in out_vals)
for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
if out_val.shape != out_aval.shape:
raise RuntimeError(
f"Incorrect output shape for return value #{i}: "
f"Expected: {out_aval.shape}, Actual: {out_val.shape}")
if out_val.dtype != dtypes.canonicalize_dtype(out_val.dtype):
raise RuntimeError(
"Cannot return 64-bit values when `jax_enable_x64` is disabled. "
f"Actual: {out_val.dtype}")
if out_val.dtype != out_aval.dtype:
raise RuntimeError(
f"Incorrect output dtype for return value #{i}: "
Expand Down
19 changes: 12 additions & 7 deletions tests/python_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ def f(x):
lambda x: returned_literal, core.ShapedArray((), expect_dtype), x
)

if not config.enable_x64.value:
ctx = self.assertRaisesRegex(Exception, "Cannot return 64-bit values")
elif expect_dtype in (np.int32, np.float32):
ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype")
if not config.enable_x64.value and expect_dtype in (np.int64, np.float64):
ctx = self.assertRaisesRegex(Exception, "result_shape_dtypes cannot specify 64-bit types")
elif config.enable_x64.value and expect_dtype in (np.int32, np.float32):
ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value")
else:
ctx = contextlib.nullcontext()

Expand Down Expand Up @@ -247,7 +247,7 @@ def f():
jax.effects_barrier()

@with_pure_and_io_callbacks
def test_callback_with_wrong_dtype_outputs(self, *, callback):
def test_callback_with_wrong_dtype_outputs(self, *, callback=io_callback_ordered):

def _cb():
return np.array([1], np.float64)
Expand All @@ -257,9 +257,14 @@ def f():
# Calling a function expected a f32 return value but getting f64
return callback(_cb, core.ShapedArray((1,), np.float32))

with self.assertRaises(RuntimeError):
f()
if config.enable_x64.value:
ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value")
else:
ctx = contextlib.nullcontext()
with ctx:
res = f()
jax.effects_barrier()
self.assertAllClose(res, np.array([1], np.float32))

@with_pure_and_io_callbacks
def test_callback_with_wrongly_specified_64_bit_dtype(self, *, callback):
Expand Down

0 comments on commit d89f0d6

Please sign in to comment.