[linen] generalize transform caching
* Renames `decorator_lift_transform_jit` to `decorator_lift_transform_cached` and `module_class_lift_transform_jit` to `module_class_lift_transform_cached`, and generalizes them to accept a `transform`.
* Adds `lift_transform_cached` to allow lifting any transform using the functions above (see the usage sketch below).
* Updates `lift.checkpoint` so it can be used with `lift_transform_cached`.
* Fixes potential bug in `lift.jit`.

PiperOrigin-RevId: 649420171
Cristian Garcia authored and Flax Authors committed Jul 22, 2024
1 parent 65cd19e commit 9b7f29f
Showing 3 changed files with 99 additions and 55 deletions.
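
As a rough sketch of what the new entry point enables (illustrative and untested; `lift.checkpoint` and `lift_transform_cached` are real symbols from this diff, while `MyModule` and the shapes are placeholders):

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core import lift
from flax.linen.transforms import lift_transform_cached


class MyModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(8)(x)


# Lift an arbitrary core transform through the shared caching machinery; this
# is the same route nn.jit itself now takes with lift.jit.
RematModule = lift_transform_cached(lift.checkpoint, MyModule)
variables = RematModule().init(jax.random.PRNGKey(0), jnp.ones((1, 8)))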
65 changes: 40 additions & 25 deletions flax/core/lift.py
@@ -15,20 +15,20 @@
"""Jax transform lifting."""

import collections
import functools
from typing import (
Any,
TypeVar,
)
from collections.abc import Callable, Iterable, Mapping, Sequence
import contextlib
import dataclasses
import functools
import threading
from typing import Any, Generic, TypeVar
import warnings

from flax import traceback_util
from flax.typing import (
In,
Out,
InOutAxis,
InOutScanAxis,
In,
InOutAxis,
InOutScanAxis,
Out,
)
import jax
from jax import random
@@ -51,6 +51,26 @@

traceback_util.register_exclusion(__file__)

A = TypeVar('A')


@dataclasses.dataclass
class TransformContext(Generic[A], threading.local):
"""Context for a transform."""

stack: list[A] = dataclasses.field(default_factory=list)

@contextlib.contextmanager
def push(self, a: A):
self.stack.append(a)
try:
yield
finally:
self.stack.pop()

def get(self) -> A:
return self.stack[-1]
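
For orientation, a minimal usage sketch of the helper above (illustrative only, not part of the change; the value type is arbitrary): a value is pushed for the duration of a `with` block, and the innermost one is read back with `get`.

ctx = TransformContext[str]()

with ctx.push('outer'):
  with ctx.push('inner'):
    assert ctx.get() == 'inner'  # innermost pushed value wins
  assert ctx.get() == 'outer'    # previous value is restored on exit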


def tree_map_rngs(fn, tree):
"""Needed for mapping JAX random.* functions over PRNGKey leaves."""
@@ -1416,12 +1436,12 @@ def checkpoint(
This function is aliased to ``lift.remat`` just like ``jax.remat``.
Args:
fn: scope function for which intermediate computations should be
re-computed when computing gradients.
fn: scope function for which intermediate computations should be re-computed
when computing gradients.
variables: The variable collections that are lifted. By default all
collections are lifted.
rngs: The PRNG sequences that are lifted. By default all PRNG sequences
are lifted.
rngs: The PRNG sequences that are lifted. By default all PRNG sequences are
lifted.
concrete: Optional, boolean indicating whether ``fun`` may involve
value-dependent Python control flow (default ``False``). Support for such
control flow is optional, and disabled by default, because in some
@@ -1440,6 +1460,7 @@ def checkpoint(
arguments as static can avoid ConcretizationTypeErrors when tracing, but
at the cost of more retracing overheads.
policy: Experimental checkpoint policy, see ``jax.checkpoint``.
Returns:
A wrapped version of ``fn``. When computing gradients, intermediate
computations will be re-computed.
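
For context, a minimal linen-level sketch of the behavior this docstring describes, using the public `nn.remat` alias (illustrative only; module and layer sizes are arbitrary):

import flax.linen as nn
import jax
import jax.numpy as jnp


class Block(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.relu(nn.Dense(128)(x))


# Wrapping the module class makes the backward pass recompute intermediates
# instead of storing them during the forward pass.
RematBlock = nn.remat(Block)
model = RematBlock()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 128)))
grads = jax.grad(lambda v, x: model.apply(v, x).sum())(variables, jnp.ones((1, 128)))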
@@ -1554,8 +1575,7 @@ def jit(
# Close over scope_fn & repack_fn to avoid recompilation
# this is impure but we use the fingerprint arg to differentiate between cases
# where scope_fn or repack_fn actually produce non-identical results.
scope_fn = None # type: Callable | None
repack_fn = None # type: Callable | None
jit_context = TransformContext[tuple[Callable, Callable]]()

@functools.partial(
jax.jit,
@@ -1567,33 +1587,28 @@
)
@functools.wraps(fn)
def jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs):
nonlocal scope_fn, repack_fn
scope_fn, repack_fn = jit_context.get()
hash_key = fingerprint[1]
# fingerprint is only used to differentiate the cache signature
del fingerprint
# del fingerprint
scope = scope_fn(variable_groups, rng_groups) # pylint: disable=not-callable
y = fn(scope, hash_key, *args, **kwargs)
return y, repack_fn(scope) # pylint: disable=not-callable

def inner(
scope_fun,
repack_fun,
scope_fn,
repack_fn,
variable_groups,
rng_groups,
module_hash_key,
*args,
**kwargs,
):
nonlocal scope_fn, repack_fn
try:
scope_fn = scope_fun
repack_fn = repack_fun
with jit_context.push((scope_fn, repack_fn)):
scopes = jax.tree_util.tree_leaves(scope_fn(variable_groups, rng_groups))
mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes)
fingerprint = (mutable, module_hash_key)
return jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs)
finally:
scope_fn, repack_fn = None, None

return pack(
inner, (variables,), (variables,), (rngs,), name='jit', enable_kwargs=True
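
To make the caching pattern above concrete, here is a standalone sketch (not Flax code; every name is a placeholder): the non-hashable callables reach the jitted function through a thread-local stack, while the hashable `fingerprint` static argument is what actually keys `jax.jit`'s cache.

import functools
import threading

import jax
import jax.numpy as jnp


class _CallableStack(threading.local):
  def __init__(self):
    self.items = []


_ctx = _CallableStack()


@functools.partial(jax.jit, static_argnums=(0,))
def _cached(fingerprint, x):
  # The fingerprint only differentiates cache entries; the callable itself is
  # looked up impurely from the thread-local stack, as `jitted` does above.
  del fingerprint
  return _ctx.items[-1](x)


def call(fn, fingerprint, x):
  _ctx.items.append(fn)
  try:
    return _cached(fingerprint, x)
  finally:
    _ctx.items.pop()


call(jnp.sin, 'sin', jnp.ones(3))  # traces once
call(jnp.sin, 'sin', jnp.ones(3))  # same fingerprint: cache hit, no retrace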
66 changes: 36 additions & 30 deletions flax/linen/transforms.py
@@ -484,6 +484,8 @@ def _get_fingerprint(name: str, value: Any) -> tuple[str, Any]:

if isinstance(obj, str):
return obj
elif hasattr(obj, '__fn_or_cls__'): # support PaxConfig objects
return _fingerprint_recursive(obj.__fn_or_cls__, path, seen_modules)
elif isinstance(obj, Module):
fingerprint: Any
if obj._id in seen_modules:
@@ -562,7 +564,7 @@ def _check_field_is_hashable(path: tuple[str, ...], x: Any):
raise ValueError(f"Value at '{path_name}' is not hashable: {e}") from e


def decorator_lift_transform_jit(class_fn, **trafo_kwargs):
def decorator_lift_transform_cached(transform, class_fn, **trafo_kwargs):
"""Decorator for lifted transform.
Similar to `decorator_lift_transform` but specialized for `jit`, it reuses the
@@ -572,7 +574,6 @@ def decorator_lift_transform_jit(class_fn, **trafo_kwargs):
# Due to the ordering of method decorators, we must wrap the class_fn
# with the module state management wrapper first to maintain Module state
# correctly.
transform = lift.jit
multi_scope = True

if isinstance(class_fn, tuple):
@@ -640,11 +641,12 @@ def core_fn(
return wrapped_fn


def module_class_lift_transform_jit(module_class, methods=None, **trafo_kwargs):
def module_class_lift_transform_cached(
transform, module_class, methods=None, **trafo_kwargs
):
"""Module class lift transform."""
# TODO(marcvanzee): Improve docstrings (#1977).
# TODO(levskaya): find nicer argument convention for multi-method case?
transform = lift.jit
trafo_args = ()

# Prepare per-method transform args, kwargs.
@@ -765,6 +767,24 @@ def lift_transform(
raise errors.TransformTargetError(target)


def lift_transform_cached(
transform, target, *trafo_args, methods=None, **trafo_kwargs
):
"""Applies to class or as a decorator on class fns."""
# TODO(marcvanzee): Improve docstrings (#1977).
if _is_module_class(target):
return module_class_lift_transform_cached(
transform, target, *trafo_args, methods=methods, **trafo_kwargs
)
# we presume this is being used as a function decorator in class definition
elif callable(target) and not isinstance(target, Module):
return decorator_lift_transform_cached(
transform, target, *trafo_args, **trafo_kwargs
)
else:
raise errors.TransformTargetError(target)
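
In practice this dispatcher is reached through the public wrappers; a brief illustrative sketch of both target kinds with `nn.jit`, which the hunk below reroutes through `lift_transform_cached` (modules and shapes are placeholders):

import flax.linen as nn
import jax
import jax.numpy as jnp


class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(4)(x)


JittedMLP = nn.jit(MLP)  # Module-class path -> module_class_lift_transform_cached
variables = JittedMLP().init(jax.random.PRNGKey(0), jnp.ones((1, 3)))


class AddOne(nn.Module):
  @nn.jit  # method-decorator path -> decorator_lift_transform_cached
  def __call__(self, x):
    return x + 1


y = AddOne().apply({}, jnp.ones(3))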


def lift_direct_transform(
transform: Callable[..., Any],
targets: tuple[Callable[..., Any], ...],
@@ -941,8 +961,8 @@ def jit(
A wrapped version of target, set up for just-in-time compilation.
"""
# TODO(marcvanzee): Improve docstrings (#1977).
if _is_module_class(target):
return module_class_lift_transform_jit(
return lift_transform_cached(
lift.jit,
target,
variables=variables,
rngs=rngs,
@@ -952,21 +972,7 @@
device=device,
backend=backend,
methods=methods,
)
# we presume this is being used as a function decorator in class definition
elif callable(target) and not isinstance(target, Module):
return decorator_lift_transform_jit(
target,
variables=variables,
rngs=rngs,
static_argnums=static_argnums,
static_argnames=static_argnames,
donate_argnums=donate_argnums,
device=device,
backend=backend,
)
else:
raise errors.TransformTargetError(target)
)


def checkpoint(
@@ -1044,15 +1050,15 @@ def checkpoint(
# lifted function
static_argnums = jax.tree_util.tree_map(lambda x: x - 1, static_argnums)
return lift_transform(
lift.checkpoint,
target,
variables=variables,
rngs=rngs,
concrete=concrete,
static_argnums=static_argnums,
prevent_cse=prevent_cse,
policy=policy,
methods=methods,
lift.checkpoint,
target,
variables=variables,
rngs=rngs,
concrete=concrete,
static_argnums=static_argnums,
prevent_cse=prevent_cse,
policy=policy,
methods=methods,
)


23 changes: 23 additions & 0 deletions tests/linen/linen_transforms_test.py
@@ -1838,6 +1838,29 @@ def __call__(self, x):
y = m.apply({}, x)
self.assertEqual(n, 1)

def test_jit_recursive(self):
n = 0

class Foo(nn.Module):

@partial(nn.jit, static_argnames='recurse_once')
def __call__(self, x, *, recurse_once: bool = True):
nonlocal n
n += 1
if recurse_once:
x = self(x, recurse_once=False)
return x + 1

x = jnp.array(1.0)
m = Foo()

self.assertEqual(n, 0)

y = m.apply({}, x)
self.assertEqual(n, 2)
y = m.apply({}, x)
self.assertEqual(n, 2)

@parameterized.named_parameters(('class', True), ('method', False))
def test_jit_reuse_hash(self, jit_class: bool):
n = 0
