Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[callback] Fix io_callback for callbacks that return Python literals. #20433

Merged
merged 1 commit into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2461,15 +2461,22 @@ def _wrapped_callback(*args):
raise RuntimeError(
"Mismatched number of outputs from callback. "
"Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
# Handle Python literals, and custom arrays, e.g., tf.Tensor.
out_vals = tuple(np.asarray(a) for a in out_vals)
for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
if out_val.shape != out_aval.shape:
raise RuntimeError(
f"Incorrect output shape for return value {i}: "
"Expected: {}, Actual: {}".format(out_aval.shape, out_val.shape))
f"Incorrect output shape for return value #{i}: "
f"Expected: {out_aval.shape}, Actual: {out_val.shape}")
if out_val.dtype != dtypes.canonicalize_dtype(out_val.dtype):
raise RuntimeError(
"Cannot return 64-bit values when `jax_enable_x64` is disabled. "
f"Actual: {out_val.dtype}")
if out_val.dtype != out_aval.dtype:
raise RuntimeError(
f"Incorrect output dtype for return value {i}: "
"Expected: {}, Actual: {}".format(out_aval.dtype, out_val.dtype))
f"Incorrect output dtype for return value #{i}: "
f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}")

if platform == "tpu":
# On TPU we cannot receive empty arrays. So, we return from the wrapped
# callback only the non-empty results, and we will create empty constants
Expand Down
71 changes: 70 additions & 1 deletion tests/python_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import collections
import contextlib
import functools
import logging
import textwrap
Expand All @@ -27,11 +28,11 @@
from jax._src import core
from jax._src import dispatch
from jax._src import maps
from jax._src.maps import xmap
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.maps import xmap
from jax.experimental import io_callback
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
Expand Down Expand Up @@ -72,6 +73,7 @@ def tearDownModule():
for flavor in ("io_unordered", "io_ordered", "pure")
)


class PythonCallbackTest(jtu.JaxTestCase):

def setUp(self):
Expand All @@ -93,6 +95,73 @@ def f(x):
out = f(0.)
self.assertEqual(out, 1.)

@parameterized.named_parameters(
dict(
testcase_name=f"{flavor}_expect_dtype_{expect_dtype}",
callback=dict(
io_unordered=io_calback_unordered,
io_ordered=io_callback_ordered,
pure=jax.pure_callback,
)[flavor],
expect_dtype=expect_dtype,
)
for flavor in ("io_unordered", "io_ordered", "pure")
for expect_dtype in (np.int32, np.int64, np.float32, np.float64)
)
def test_callback_returning_python_literal(self, *, callback, expect_dtype):
returned_literal = 42 if expect_dtype in (np.int32, np.int64) else 42.0

@jax.jit
def f(x):
return callback(
lambda x: returned_literal, core.ShapedArray((), expect_dtype), x
)

if not config.enable_x64.value:
ctx = self.assertRaisesRegex(Exception, "Cannot return 64-bit values")
elif expect_dtype in (np.int32, np.float32):
ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype")
else:
ctx = contextlib.nullcontext()

with ctx:
out = f(0.0)
jax.effects_barrier()
self.assertEqual(out, returned_literal)

@with_pure_and_io_callbacks
def test_callback_returning_custom_array(self, *, callback):
# Some users write the callback in TF, returning a tf.Tensor. We don't
# want to add TF as a dependency, but simulate that use case with a
# custom array class.
class CustomArray:

def __init__(self, a: np.ndarray):
self.a = a

@property
def shape(self):
return self.a.shape

@property
def dtype(self):
return self.a.dtype

def __array__(self):
return self.a

@jax.jit
def f(x):
return callback(
lambda x: CustomArray(np.array(42.0, dtype=np.float32)),
core.ShapedArray((), np.float32),
x,
)

out = f(0.0)
jax.effects_barrier()
self.assertEqual(out, 42.0)

@parameterized.named_parameters(
dict(testcase_name=f"{flavor}_{dtype}",
dtype=dtype,
Expand Down
Loading