From 9b7f29ff95426e25f9ff275a2b29442f1df664cf Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Thu, 4 Jul 2024 08:18:59 -0700 Subject: [PATCH] [linen] generalize transform caching * Renames `decorator_lift_transform_jit` to `decorator_lift_transform_cached` and `module_class_lift_transform_jit` to `module_class_lift_transform_cached`, and generalizes them to accept a `transform`. * Adds `lift_transform_cached` to allow lifting any transform using the functions above. * Updates `lift.checkpoint` so it can be used with `lift_transform_cached`. * Fixes potential bug in `lift.jit`. PiperOrigin-RevId: 649420171 --- flax/core/lift.py | 65 ++++++++++++++++----------- flax/linen/transforms.py | 66 +++++++++++++++------------- tests/linen/linen_transforms_test.py | 23 ++++++++++ 3 files changed, 99 insertions(+), 55 deletions(-) diff --git a/flax/core/lift.py b/flax/core/lift.py index e0de001f7a..24d2d5fa47 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -15,20 +15,20 @@ """Jax transform lifting.""" import collections -import functools -from typing import ( - Any, - TypeVar, -) from collections.abc import Callable, Iterable, Mapping, Sequence +import contextlib +import dataclasses +import functools +import threading +from typing import Any, Generic, TypeVar import warnings from flax import traceback_util from flax.typing import ( - In, - Out, - InOutAxis, - InOutScanAxis, + In, + InOutAxis, + InOutScanAxis, + Out, ) import jax from jax import random @@ -51,6 +51,26 @@ traceback_util.register_exclusion(__file__) +A = TypeVar('A') + + +@dataclasses.dataclass +class TransformContext(Generic[A], threading.local): + """Context for a transform.""" + + stack: list[A] = dataclasses.field(default_factory=list) + + @contextlib.contextmanager + def push(self, a: A): + self.stack.append(a) + try: + yield + finally: + self.stack.pop() + + def get(self) -> A: + return self.stack[-1] + def tree_map_rngs(fn, tree): """Needed for mapping JAX random.* functions over PRNGKey 
leaves.""" @@ -1416,12 +1436,12 @@ def checkpoint( This function is aliased to ``lift.remat`` just like ``jax.remat``. Args: - fn: scope function for which intermediate computations should be - re-computed when computing gradients. + fn: scope function for which intermediate computations should be re-computed + when computing gradients. variables: The variable collections that are lifted. By default all collections are lifted. - rngs: The PRNG sequences that are lifted. By default all PRNG sequences - are lifted. + rngs: The PRNG sequences that are lifted. By default all PRNG sequences are + lifted. concrete: Optional, boolean indicating whether ``fun`` may involve value-dependent Python control flow (default ``False``). Support for such control flow is optional, and disabled by default, because in some @@ -1440,6 +1460,7 @@ def checkpoint( arguments as static can avoid ConcretizationTypeErrors when tracing, but at the cost of more retracing overheads. policy: Experimental checkpoint policy, see ``jax.checkpoint``. + Returns: A wrapped version of ``fn``. When computing gradients intermediate computations will be re-computed when computing gradients. @@ -1554,8 +1575,7 @@ def jit( # Close over scope_fn & repack_fn to avoid recompilation # this is impure but we use the fingerprint arg to differentiate between cases # where scope_fn or repack_fn actually produce non-identical results. 
- scope_fn = None # type: Callable | None - repack_fn = None # type: Callable | None + jit_context = TransformContext[tuple[Callable, Callable]]() @functools.partial( jax.jit, @@ -1567,33 +1587,28 @@ def jit( ) @functools.wraps(fn) def jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs): - nonlocal scope_fn, repack_fn + scope_fn, repack_fn = jit_context.get() hash_key = fingerprint[1] # fingerprint is only used to differentiate the cache signature - del fingerprint + # del fingerprint scope = scope_fn(variable_groups, rng_groups) # pylint: disable=not-callable y = fn(scope, hash_key, *args, **kwargs) return y, repack_fn(scope) # pylint: disable=not-callable def inner( - scope_fun, - repack_fun, + scope_fn, + repack_fn, variable_groups, rng_groups, module_hash_key, *args, **kwargs, ): - nonlocal scope_fn, repack_fn - try: - scope_fn = scope_fun - repack_fn = repack_fun + with jit_context.push((scope_fn, repack_fn)): scopes = jax.tree_util.tree_leaves(scope_fn(variable_groups, rng_groups)) mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes) fingerprint = (mutable, module_hash_key) return jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs) - finally: - scope_fn, repack_fn = None, None return pack( inner, (variables,), (variables,), (rngs,), name='jit', enable_kwargs=True diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 8be73bc830..1c4643aff6 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -484,6 +484,8 @@ def _get_fingerprint(name: str, value: Any) -> tuple[str, Any]: if isinstance(obj, str): return obj + elif hasattr(obj, '__fn_or_cls__'): # support PaxConfig objects + return _fingerprint_recursive(obj.__fn_or_cls__, path, seen_modules) elif isinstance(obj, Module): fingerprint: Any if obj._id in seen_modules: @@ -562,7 +564,7 @@ def _check_field_is_hashable(path: tuple[str, ...], x: Any): raise ValueError(f"Value at '{path_name}' is not hashable: {e}") from e -def 
decorator_lift_transform_jit(class_fn, **trafo_kwargs): +def decorator_lift_transform_cached(transform, class_fn, **trafo_kwargs): """Decorator for lifted transform. Similar to `decorator_lift_transform` but specialized for `jit`, it reuses the @@ -572,7 +574,6 @@ def decorator_lift_transform_jit(class_fn, **trafo_kwargs): # Due to the ordering of method decorators, we must wrap the class_fn # with the module state management wrapper first to maintain Module state # correctly. - transform = lift.jit multi_scope = True if isinstance(class_fn, tuple): @@ -640,11 +641,12 @@ def core_fn( return wrapped_fn -def module_class_lift_transform_jit(module_class, methods=None, **trafo_kwargs): +def module_class_lift_transform_cached( + transform, module_class, methods=None, **trafo_kwargs +): """Module class lift transform.""" # TODO(marcvanzee): Improve docstrings (#1977). # TODO(levskaya): find nicer argument convention for multi-method case? - transform = lift.jit trafo_args = () # Prepare per-method transform args, kwargs. @@ -765,6 +767,24 @@ def lift_transform( raise errors.TransformTargetError(target) +def lift_transform_cached( + transform, target, *trafo_args, methods=None, **trafo_kwargs +): + """Applies to class or as a decorator on class fns.""" + # TODO(marcvanzee): Improve docstrings (#1977). + if _is_module_class(target): + return module_class_lift_transform_cached( + transform, target, *trafo_args, methods=methods, **trafo_kwargs + ) + # we presume this is being used as a function decorator in class definition + elif callable(target) and not isinstance(target, Module): + return decorator_lift_transform_cached( + transform, target, *trafo_args, **trafo_kwargs + ) + else: + raise errors.TransformTargetError(target) + + def lift_direct_transform( transform: Callable[..., Any], targets: tuple[Callable[..., Any], ...], @@ -941,8 +961,8 @@ def jit( A wrapped version of target, set up for just-in-time compilation. """ # TODO(marcvanzee): Improve docstrings (#1977). 
- if _is_module_class(target): - return module_class_lift_transform_jit( + return lift_transform_cached( + lift.jit, target, variables=variables, rngs=rngs, @@ -952,21 +972,7 @@ def jit( device=device, backend=backend, methods=methods, - ) - # we presume this is being used as a function decorator in class definition - elif callable(target) and not isinstance(target, Module): - return decorator_lift_transform_jit( - target, - variables=variables, - rngs=rngs, - static_argnums=static_argnums, - static_argnames=static_argnames, - donate_argnums=donate_argnums, - device=device, - backend=backend, - ) - else: - raise errors.TransformTargetError(target) + ) def checkpoint( @@ -1044,15 +1050,15 @@ def checkpoint( # lifted function static_argnums = jax.tree_util.tree_map(lambda x: x - 1, static_argnums) return lift_transform( - lift.checkpoint, - target, - variables=variables, - rngs=rngs, - concrete=concrete, - static_argnums=static_argnums, - prevent_cse=prevent_cse, - policy=policy, - methods=methods, + lift.checkpoint, + target, + variables=variables, + rngs=rngs, + concrete=concrete, + static_argnums=static_argnums, + prevent_cse=prevent_cse, + policy=policy, + methods=methods, ) diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 437dfbd59e..290a7c4213 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -1838,6 +1838,29 @@ def __call__(self, x): y = m.apply({}, x) self.assertEqual(n, 1) + def test_jit_recursive(self): + n = 0 + + class Foo(nn.Module): + + @partial(nn.jit, static_argnames='recurse_once') + def __call__(self, x, *, recurse_once: bool = True): + nonlocal n + n += 1 + if recurse_once: + x = self(x, recurse_once=False) + return x + 1 + + x = jnp.array(1.0) + m = Foo() + + self.assertEqual(n, 0) + + y = m.apply({}, x) + self.assertEqual(n, 2) + y = m.apply({}, x) + self.assertEqual(n, 2) + @parameterized.named_parameters(('class', True), ('method', False)) def 
test_jit_reuse_hash(self, jit_class: bool): n = 0