Skip to content

Commit

Permalink
[callback] Improve caching effectiveness in presence of callbacks.
Browse files Browse the repository at this point in the history
Previously, the user-provided Python callback function was first
flattened and then the result passed as a primitive parameter to
the callback primitives. This means that two separate io_callback
invocations with the same Python callable will generate different
Jaxprs. To prevent this we defer the flattening to lowering time.
  • Loading branch information
gnecula committed Apr 1, 2024
1 parent 3f3986f commit 8b1b66d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 16 deletions.
43 changes: 27 additions & 16 deletions jax/_src/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
"""Module for JAX callbacks."""
from __future__ import annotations

import dataclasses
from collections.abc import Sequence
import logging
import functools
from typing import Any, Callable

import numpy as np

import jax
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
Expand All @@ -46,10 +48,27 @@
map, unsafe_map = util.safe_map, map


@dataclasses.dataclass(frozen=True)
class _FlatCallback:
"""A Python function callable with flat arguments and results.
An instance of this class is used as a parameter for the callback primitives.
We prefer it to an anonymous flattened function because it produces
equal objects when we call the same Python function with the same argument
structure.
"""
callback_func: Callable[..., Any]
in_tree: tree_util.PyTreeDef # (args, kwargs) pytree for `callback_func`.

def __call__(self, *flat_args: jax.Array) -> Sequence[jax.Array]:
args, kwargs = tree_util.tree_unflatten(self.in_tree, flat_args)
return tree_util.tree_leaves(self.callback_func(*args, **kwargs))


def pure_callback_impl(
*args,
result_avals,
callback: Callable[..., Any],
callback: _FlatCallback,
sharding: SingleDeviceSharding | None,
vectorized: bool,
):
Expand All @@ -68,7 +87,7 @@ def pure_callback_impl(
@pure_callback_p.def_abstract_eval
def pure_callback_abstract_eval(
*avals,
callback: Callable[..., Any],
callback: _FlatCallback,
result_avals,
sharding: SingleDeviceSharding | None,
vectorized: bool,
Expand Down Expand Up @@ -100,7 +119,7 @@ def pure_callback_batching_rule(
args,
dims,
*,
callback,
callback: _FlatCallback,
sharding: SingleDeviceSharding | None,
vectorized: bool,
result_avals: Sequence[core.ShapedArray],
Expand Down Expand Up @@ -193,7 +212,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):


def pure_callback_lowering(
ctx, *args, callback, sharding: SingleDeviceSharding | None, **params
ctx, *args, callback: _FlatCallback, sharding: SingleDeviceSharding | None, **params
):
def _callback(*flat_args):
return tuple(
Expand Down Expand Up @@ -265,18 +284,14 @@ def pure_callback(
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
"""
def _flat_callback(*flat_args):
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
return tree_util.tree_leaves(callback(*args, **kwargs))

flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
result_avals = tree_util.tree_map(
lambda x: core.ShapedArray(x.shape, x.dtype), result_shape_dtypes)
flat_result_avals, out_tree = tree_util.tree_flatten(result_avals)
out_flat = pure_callback_p.bind(
*flat_args,
callback=_flat_callback,
callback=_FlatCallback(callback, in_tree),
result_avals=tuple(flat_result_avals),
sharding=sharding,
vectorized=vectorized,
Expand Down Expand Up @@ -378,7 +393,7 @@ class OrderedIOEffect(effects.Effect):
def io_callback_impl(
*args,
result_avals,
callback: Callable[..., Any],
callback: _FlatCallback,
sharding: SingleDeviceSharding | None,
ordered: bool,
):
Expand All @@ -397,7 +412,7 @@ def io_callback_impl(
@io_callback_p.def_effectful_abstract_eval
def io_callback_abstract_eval(
*avals,
callback: Callable[..., Any],
callback: _FlatCallback,
result_avals,
sharding: SingleDeviceSharding | None,
ordered: bool,
Expand Down Expand Up @@ -516,10 +531,6 @@ def io_callback(
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
"""
def _flat_callback(*flat_args):
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
return tree_util.tree_leaves(callback(*args, **kwargs))

flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes)
Expand All @@ -528,7 +539,7 @@ def _flat_callback(*flat_args):
flat_args = map(core.raise_as_much_as_possible, flat_args)
out_flat = io_callback_p.bind(
*flat_args,
callback=_flat_callback,
callback=_FlatCallback(callback, in_tree),
result_avals=tuple(flat_result_avals),
sharding=sharding,
ordered=ordered,
Expand Down
14 changes: 14 additions & 0 deletions tests/python_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,20 @@ def f(x):
self.assertIn(f"jax.{api_name} failed", output)
self.assertIn("Traceback (most recent call last)", output)

@with_pure_and_io_callbacks
def test_compilation_caching(self, *, callback):
def f_outside(x):
return 2 * x

def fun(x):
return callback(f_outside, x, x)

x = np.arange(6, dtype=np.int32).reshape((2, 3))
with jtu.count_primitive_compiles() as count:
for _ in range(3):
self.assertAllClose(2 * x, fun(x))
self.assertEqual(count[0], 1)


class PureCallbackTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 8b1b66d

Please sign in to comment.