Skip to content

Commit

Permalink
Remove uses of jax.experimental.host_callback.call
Browse files Browse the repository at this point in the history
The jax.experimental.host_callback module is deprecated and will be removed.

See #20385.

Most of the changes here have to do with the fact that io_callback does not pass the `device` to the callback. Fortunately, it seems that this code uses the device argument only for logging. I removed all uses of `device`.

PiperOrigin-RevId: 618402363
  • Loading branch information
gnecula authored and jax authors committed Mar 25, 2024
1 parent 3d8ffd4 commit 1423d25
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
14 changes: 10 additions & 4 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,14 +2462,20 @@ def _wrapped_callback(*args):
"Mismatched number of outputs from callback. "
"Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
if out_val.shape != out_aval.shape:
actual_shape = np.shape(out_val)
if actual_shape != out_aval.shape:
raise RuntimeError(
f"Incorrect output shape for return value {i}: "
"Expected: {}, Actual: {}".format(out_aval.shape, out_val.shape))
if out_val.dtype != out_aval.dtype:
"Expected: {}, Actual: {}".format(out_aval.shape, actual_shape))
actual_dtype = np.result_type(out_val)
if actual_dtype != dtypes.canonicalize_dtype(actual_dtype):
raise ValueError(
"Cannot return 64-bit values when `jax_enable_x64` is disabled")
if actual_dtype != out_aval.dtype:
raise RuntimeError(
f"Incorrect output dtype for return value {i}: "
"Expected: {}, Actual: {}".format(out_aval.dtype, out_val.dtype))
"Expected: {}, Actual: {}".format(out_aval.dtype, actual_dtype))

if platform == "tpu":
# On TPU we cannot receive empty arrays. So, we return from the wrapped
# callback only the non-empty results, and we will create empty constants
Expand Down
30 changes: 30 additions & 0 deletions tests/python_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import collections
import contextlib
import functools
import logging
import textwrap
Expand Down Expand Up @@ -72,6 +73,7 @@ def tearDownModule():
for flavor in ("io_unordered", "io_ordered", "pure")
)


class PythonCallbackTest(jtu.JaxTestCase):

def setUp(self):
Expand All @@ -93,6 +95,34 @@ def f(x):
out = f(0.)
self.assertEqual(out, 1.)

@parameterized.named_parameters(
dict(testcase_name=f"{flavor}_expect_dtype_{expect_dtype}",
callback=dict(io_unordered=io_calback_unordered,
io_ordered=io_callback_ordered,
pure=jax.pure_callback)[flavor],
expect_dtype=expect_dtype)
for flavor in ("io_unordered", "io_ordered", "pure")
for expect_dtype in (np.int32, np.int64, np.float32, np.float64)
)
def test_callback_returning_python_literal(self, *, callback, expect_dtype):
returned_literal = 42 if expect_dtype in (np.int32, np.int64) else 42.
@jax.jit
def f(x):
return callback(lambda x: returned_literal,
core.ShapedArray((), expect_dtype), x)

if not config.enable_x64.value:
ctx = self.assertRaisesRegex(Exception, "Cannot return 64-bit values")
elif expect_dtype in (np.int32, np.float32):
ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype")
else:
ctx = contextlib.nullcontext()

with ctx:
out = f(0.)
jax.effects_barrier()
self.assertEqual(out, returned_literal)

@parameterized.named_parameters(
dict(testcase_name=f"{flavor}_{dtype}",
dtype=dtype,
Expand Down

0 comments on commit 1423d25

Please sign in to comment.