diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 361300c44a9c..b0b66b2f4e75 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -14,6 +14,7 @@ """Module for JAX callbacks.""" from __future__ import annotations +import dataclasses from collections.abc import Sequence import logging import functools @@ -21,6 +22,7 @@ import numpy as np +import jax from jax._src import core from jax._src import dispatch from jax._src import dtypes @@ -46,10 +48,27 @@ map, unsafe_map = util.safe_map, map +@dataclasses.dataclass(frozen=True) +class _FlatCallback: + """A Python function callable with flat arguments and results. + + An instance of this class is used as a parameter for the callback primitives. + We prefer it to an anonymous flattened function because it produces + equal objects when we call the same Python function with the same argument + structure. + """ + callback_func: Callable[..., Any] + in_tree: tree_util.PyTreeDef # (args, kwargs) pytree for `callback_func`. + + def __call__(self, *flat_args: jax.Array) -> Sequence[jax.Array]: + args, kwargs = tree_util.tree_unflatten(self.in_tree, flat_args) + return tree_util.tree_leaves(self.callback_func(*args, **kwargs)) + + def pure_callback_impl( *args, result_avals, - callback: Callable[..., Any], + callback: _FlatCallback, sharding: SingleDeviceSharding | None, vectorized: bool, ): @@ -68,7 +87,7 @@ def pure_callback_impl( @pure_callback_p.def_abstract_eval def pure_callback_abstract_eval( *avals, - callback: Callable[..., Any], + callback: _FlatCallback, result_avals, sharding: SingleDeviceSharding | None, vectorized: bool, @@ -100,7 +119,7 @@ def pure_callback_batching_rule( args, dims, *, - callback, + callback: _FlatCallback, sharding: SingleDeviceSharding | None, vectorized: bool, result_avals: Sequence[core.ShapedArray], @@ -193,7 +212,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None): def pure_callback_lowering( - ctx, *args, callback, sharding: SingleDeviceSharding | None, **params + ctx, *args, callback: _FlatCallback, sharding: SingleDeviceSharding | None, **params ): def _callback(*flat_args): return tuple( @@ -265,10 +284,6 @@ def pure_callback( .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html """ - def _flat_callback(*flat_args): - args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) - return tree_util.tree_leaves(callback(*args, **kwargs)) - flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) tree_util.tree_map(_check_shape_dtype, result_shape_dtypes) result_avals = tree_util.tree_map( @@ -276,7 +291,7 @@ def _flat_callback(*flat_args): flat_result_avals, out_tree = tree_util.tree_flatten(result_avals) out_flat = pure_callback_p.bind( *flat_args, - callback=_flat_callback, + callback=_FlatCallback(callback, in_tree), result_avals=tuple(flat_result_avals), sharding=sharding, vectorized=vectorized, @@ -378,7 +393,7 @@ class OrderedIOEffect(effects.Effect): def io_callback_impl( *args, result_avals, - callback: Callable[..., Any], + callback: _FlatCallback, sharding: SingleDeviceSharding | None, ordered: bool, ): @@ -397,7 +412,7 @@ def io_callback_impl( @io_callback_p.def_effectful_abstract_eval def io_callback_abstract_eval( *avals, - callback: Callable[..., Any], + callback: _FlatCallback, result_avals, sharding: SingleDeviceSharding | None, ordered: bool, @@ -516,10 +531,6 @@ def io_callback( .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html """ - def _flat_callback(*flat_args): - args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) - return tree_util.tree_leaves(callback(*args, **kwargs)) - flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) tree_util.tree_map(_check_shape_dtype, result_shape_dtypes) flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes) @@ -528,7 +539,7 @@ def _flat_callback(*flat_args): flat_args = map(core.raise_as_much_as_possible, flat_args) out_flat = io_callback_p.bind( *flat_args, - callback=_flat_callback, + callback=_FlatCallback(callback, in_tree), result_avals=tuple(flat_result_avals), sharding=sharding, ordered=ordered, diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 0cc15840a54e..0cd51631ad20 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -586,6 +586,20 @@ def f(x): self.assertIn(f"jax.{api_name} failed", output) self.assertIn("Traceback (most recent call last)", output) + @with_pure_and_io_callbacks + def test_compilation_caching(self, *, callback): + def f_outside(x): + return 2 * x + + def fun(x): + return callback(f_outside, x, x) + + x = np.arange(6, dtype=np.int32).reshape((2, 3)) + with jtu.count_primitive_compiles() as count: + for _ in range(3): + self.assertAllClose(2 * x, fun(x)) + self.assertEqual(count[0], 1) + class PureCallbackTest(jtu.JaxTestCase):