Skip to content

Commit

Permalink
[callback] Fix io_callback for callbacks that return Python literals.
Browse files Browse the repository at this point in the history
The internal implementation of io_callback and friends currently use .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal.

Fixed the checks that the callback returns values of expected shape and dtype,
and added tests.

PiperOrigin-RevId: 619100631
  • Loading branch information
gnecula authored and jax authors committed Mar 26, 2024
1 parent 33cf53c commit e07671d
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 5 deletions.
15 changes: 11 additions & 4 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2461,15 +2461,22 @@ def _wrapped_callback(*args):
raise RuntimeError(
"Mismatched number of outputs from callback. "
"Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
# Handle Python literals, and custom arrays, e.g., tf.Tensor.
out_vals = tuple(np.asarray(a) for a in out_vals)
for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
if out_val.shape != out_aval.shape:
raise RuntimeError(
f"Incorrect output shape for return value {i}: "
"Expected: {}, Actual: {}".format(out_aval.shape, out_val.shape))
f"Incorrect output shape for return value #{i}: "
f"Expected: {out_aval.shape}, Actual: {out_val.shape}")
if out_val.dtype != dtypes.canonicalize_dtype(out_val.dtype):
raise RuntimeError(
"Cannot return 64-bit values when `jax_enable_x64` is disabled. "
f"Actual: {out_val.dtype}")
if out_val.dtype != out_aval.dtype:
raise RuntimeError(
f"Incorrect output dtype for return value {i}: "
"Expected: {}, Actual: {}".format(out_aval.dtype, out_val.dtype))
f"Incorrect output dtype for return value #{i}: "
f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}")

if platform == "tpu":
# On TPU we cannot receive empty arrays. So, we return from the wrapped
# callback only the non-empty results, and we will create empty constants
Expand Down
71 changes: 70 additions & 1 deletion tests/python_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import collections
import contextlib
import functools
import logging
import textwrap
Expand All @@ -27,11 +28,11 @@
from jax._src import core
from jax._src import dispatch
from jax._src import maps
from jax._src.maps import xmap
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.maps import xmap
from jax.experimental import io_callback
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
Expand Down Expand Up @@ -72,6 +73,7 @@ def tearDownModule():
for flavor in ("io_unordered", "io_ordered", "pure")
)


class PythonCallbackTest(jtu.JaxTestCase):

def setUp(self):
Expand All @@ -93,6 +95,73 @@ def f(x):
out = f(0.)
self.assertEqual(out, 1.)

@parameterized.named_parameters(
dict(
testcase_name=f"{flavor}_expect_dtype_{expect_dtype}",
callback=dict(
io_unordered=io_calback_unordered,
io_ordered=io_callback_ordered,
pure=jax.pure_callback,
)[flavor],
expect_dtype=expect_dtype,
)
for flavor in ("io_unordered", "io_ordered", "pure")
for expect_dtype in (np.int32, np.int64, np.float32, np.float64)
)
def test_callback_returning_python_literal(self, *, callback, expect_dtype):
returned_literal = 42 if expect_dtype in (np.int32, np.int64) else 42.0

@jax.jit
def f(x):
return callback(
lambda x: returned_literal, core.ShapedArray((), expect_dtype), x
)

if not config.enable_x64.value:
ctx = self.assertRaisesRegex(Exception, "Cannot return 64-bit values")
elif expect_dtype in (np.int32, np.float32):
ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype")
else:
ctx = contextlib.nullcontext()

with ctx:
out = f(0.0)
jax.effects_barrier()
self.assertEqual(out, returned_literal)

@with_pure_and_io_callbacks
def test_callback_returning_custom_array(self, *, callback):
# Some users write the callback in TF, returning a tf.Tensor. We don't
# want to add TF as a dependency, but simulate that use case with a
# custom array class.
class CustomArray:

def __init__(self, a: np.ndarray):
self.a = a

@property
def shape(self):
return self.a.shape

@property
def dtype(self):
return self.a.dtype

def __array__(self):
return self.a

@jax.jit
def f(x):
return callback(
lambda x: CustomArray(np.array(42.0, dtype=np.float32)),
core.ShapedArray((), np.float32),
x,
)

out = f(0.0)
jax.effects_barrier()
self.assertEqual(out, 42.0)

@parameterized.named_parameters(
dict(testcase_name=f"{flavor}_{dtype}",
dtype=dtype,
Expand Down

0 comments on commit e07671d

Please sign in to comment.