Skip to content

Commit

Permalink
[callback] Allow external callbacks to return 64-bit values in 32-bit…
Browse files Browse the repository at this point in the history
… mode

Previously, prior to #20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In #20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:

```
def f_host():
  return 42.

io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```

However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.

In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.

In some sense this is replacing the change in  #20433 to add a canonicalization
step instead of an error.
  • Loading branch information
gnecula committed Apr 3, 2024
1 parent dcd45c8 commit 35b1cb7
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion jax/_src/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def _check_shape_dtype(shape_dtype):
dt = np.dtype(shape_dtype.dtype)
if dtypes.canonicalize_dtype(dt) != dt:
raise ValueError(
"Cannot return 64-bit values when `jax_enable_x64` is disabled")
"result_shape_dtypes cannot specify 64-bit types when `jax_enable_x64` is disabled")


def pure_callback(
Expand Down
6 changes: 1 addition & 5 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2469,16 +2469,12 @@ def _wrapped_callback(*args):
"Mismatched number of outputs from callback. "
"Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
# Handle Python literals, and custom arrays, e.g., tf.Tensor.
out_vals = tuple(np.asarray(a) for a in out_vals)
out_vals = tuple(xla.canonicalize_dtype(np.asarray(a)) for a in out_vals)
for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
if out_val.shape != out_aval.shape:
raise RuntimeError(
f"Incorrect output shape for return value #{i}: "
f"Expected: {out_aval.shape}, Actual: {out_val.shape}")
if out_val.dtype != dtypes.canonicalize_dtype(out_val.dtype):
raise RuntimeError(
"Cannot return 64-bit values when `jax_enable_x64` is disabled. "
f"Actual: {out_val.dtype}")
if out_val.dtype != out_aval.dtype:
raise RuntimeError(
f"Incorrect output dtype for return value #{i}: "
Expand Down
19 changes: 12 additions & 7 deletions tests/python_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ def f(x):
lambda x: returned_literal, core.ShapedArray((), expect_dtype), x
)

if not config.enable_x64.value:
ctx = self.assertRaisesRegex(Exception, "Cannot return 64-bit values")
elif expect_dtype in (np.int32, np.float32):
ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype")
if not config.enable_x64.value and expect_dtype in (np.int64, np.float64):
ctx = self.assertRaisesRegex(Exception, "result_shape_dtypes cannot specify 64-bit types")
elif config.enable_x64.value and expect_dtype in (np.int32, np.float32):
ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value")
else:
ctx = contextlib.nullcontext()

Expand Down Expand Up @@ -247,7 +247,7 @@ def f():
jax.effects_barrier()

@with_pure_and_io_callbacks
def test_callback_with_wrong_dtype_outputs(self, *, callback):
def test_callback_with_wrong_dtype_outputs(self, *, callback=io_callback_ordered):

def _cb():
return np.array([1], np.float64)
Expand All @@ -257,9 +257,14 @@ def f():
# Calling a function expected a f32 return value but getting f64
return callback(_cb, core.ShapedArray((1,), np.float32))

with self.assertRaises(RuntimeError):
f()
if config.enable_x64.value:
ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value")
else:
ctx = contextlib.nullcontext()
with ctx:
res = f()
jax.effects_barrier()
self.assertAllClose(res, np.array([1], np.float32))

@with_pure_and_io_callbacks
def test_callback_with_wrongly_specified_64_bit_dtype(self, *, callback):
Expand Down

0 comments on commit 35b1cb7

Please sign in to comment.