Skip to content

Commit

Permalink
[linen] fold rngs on jit to improve caching
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654699932
  • Loading branch information
Cristian Garcia authored and Flax Authors committed Jul 22, 2024
1 parent 4e83e09 commit afcaf66
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 13 deletions.
34 changes: 22 additions & 12 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,18 @@
from . import axes_scan, meta
from .frozen_dict import freeze, unfreeze
from .scope import (
CollectionFilter,
DenyList, # pylint: disable=g-multiple-import
Filter,
PRNGSequenceFilter,
Scope,
group_collections,
in_filter,
intersect_filters,
is_filter_empty,
subtract_filters,
union_filters,
CollectionFilter,
DenyList, # pylint: disable=g-multiple-import
Filter,
LazyRng,
PRNGSequenceFilter,
Scope,
group_collections,
in_filter,
intersect_filters,
is_filter_empty,
subtract_filters,
union_filters,
)

traceback_util.register_exclusion(__file__)
Expand Down Expand Up @@ -1605,8 +1606,17 @@ def inner(
**kwargs,
):
with jit_context.push((scope_fn, repack_fn)):
scopes = jax.tree_util.tree_leaves(scope_fn(variable_groups, rng_groups))
scopes: list[Scope] = jax.tree_util.tree_leaves(
scope_fn(variable_groups, rng_groups)
)
mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes)

rng_groups = jax.tree.map(
lambda x: x.fold() if isinstance(x, LazyRng) else x,
rng_groups,
is_leaf=lambda x: isinstance(x, LazyRng),
)

fingerprint = (mutable, module_hash_key)
return jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs)

Expand Down
11 changes: 11 additions & 0 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def create(
else:
return LazyRng(rng, suffix)

def fold(self):
key = self.as_jax_rng()
return LazyRng(key, ())


def _legacy_rng_fold_in(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey:
"""Legacy RNG folding."""
Expand Down Expand Up @@ -601,6 +605,13 @@ def default_name(self, prefix: str) -> str:
return name
i += 1

def fold_rngs(self):
"""Folds the rngs of this scope into the parent scope."""
self._check_valid()
for name, rng in self.rngs.items():
assert isinstance(rng, LazyRng)
self.rngs[name] = rng.fold()

def push(
self, name: str | None = None, prefix: str = '', reuse=False
) -> 'Scope':
Expand Down
3 changes: 2 additions & 1 deletion flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def get_module_scopes(module, args=None, kwargs=None):
VariablePlaceholders and Module instances replaced with InstancePlaceholders
that are compatible with jax functions.
"""
scopes = []
scopes: list[Scope] = []
refs = {}

# Gather scopes associated with Variables and Module instances passed as
Expand Down Expand Up @@ -620,6 +620,7 @@ def core_fn(
trafo_fn = transform(*core_fns, **trafo_kwargs)

module_scopes, args, kwargs = get_module_scopes(self, args, kwargs)

if not multi_scope:
if len(module_scopes) != 1:
# TODO(levskaya): transforms like jvp & vjp have args that follow the
Expand Down
26 changes: 26 additions & 0 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,6 +1798,32 @@ def __call__(self, a: jax.Array, b: str):
self.assertEqual(s, 'hi')
np.testing.assert_array_equal(y, jnp.array(1.0))

def test_jit_and_sow(self):
class Inner(nn.Module):

@nn.compact
def __call__(self, x):
self.sow('intermediates', 'loss', jnp.sum(x))
return x + 1

class Outer(nn.Module):

def setup(self):
self.inner = Inner()

@nn.jit
def __call__(self, x):
return self.inner(x)

m = Outer()
x = jnp.ones((2, 2))
vs = m.init(random.key(0), x)
y, updates = m.apply(vs, x, mutable=['intermediates'])
np.testing.assert_array_equal(
updates['intermediates']['inner']['loss'], 4.0
)
np.testing.assert_array_equal(y, 2)

def test_jit_repr_hash(self):
n = 0

Expand Down

0 comments on commit afcaf66

Please sign in to comment.