Skip to content

Commit

Permalink
[linen] allows multiple compact methods
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jun 29, 2024
1 parent 15e0e8d commit 86292dd
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 31 deletions.
16 changes: 1 addition & 15 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,6 @@ def __init_subclass__(cls, kw_only: bool = False, **kwargs: Any) -> None:
cls._customized_dataclass_transform(kw_only)
# We wrap user-defined methods including setup and __call__ to enforce
# a number of different checks and to provide clear error messages.
cls._verify_single_or_no_compact()
cls._find_compact_name_scope_methods()
cls._wrap_module_attributes()
# Set empty class defaults.
Expand Down Expand Up @@ -1116,20 +1115,6 @@ def _customized_dataclass_transform(cls, kw_only: bool):

cls.__hash__ = _wrap_hash(cls.__hash__) # type: ignore[method-assign]

@classmethod
def _verify_single_or_no_compact(cls):
"""Statically verifies that at most a single method is labelled compact."""
methods = [m[0] for m in inspect.getmembers(cls, predicate=callable)]
n_compact_fns = len(
[
method_name
for method_name in methods
if hasattr(getattr(cls, method_name), 'compact')
]
)
if n_compact_fns > 1:
raise errors.MultipleMethodsCompactError()

@classmethod
def _find_compact_name_scope_methods(cls):
"""Finds all compact_name_scope methods in the class."""
Expand Down Expand Up @@ -1208,6 +1193,7 @@ def _call_wrapped_method(self, fun, args, kwargs):
raise errors.CallCompactUnboundModuleError()
is_recurrent = self._state.in_compact_method
self._state.in_compact_method = True
self._state.autoname_cursor = {}
_context.module_stack.append(self)
try:
# get call info
Expand Down
35 changes: 23 additions & 12 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,18 +513,29 @@ def __call__(self):

Foo({'a': ()}).apply({})

def test_only_one_compact_method(self):
msg = 'Only one method per class can be @compact'
with self.assertRaisesRegex(errors.MultipleMethodsCompactError, msg):

class MultipleCompactMethods(nn.Module):
@compact
def call1(self):
pass

@compact
def call2(self):
pass
def test_multiple_compact_methods(self):
"""Test that multiple methods with the @compact decorator can be used.
NOTE: in the near future we might want to have compact methods reset the
autoname_cursor such that Dense would be reused in the second method.
"""

class MultipleCompactMethods(nn.Module):
@compact
def __call__(self, x):
x = nn.Dense(1)(x)
return self.method(x)

@compact
def method(self, x):
x = nn.Dense(1)(x)
return x

m = MultipleCompactMethods()
variables = m.init(random.key(0), jnp.ones((1, 1)))
params = variables['params']
self.assertIn('Dense_0', params)
self.assertNotIn('Dense_1', params)

def test_only_one_compact_method_subclass(self):
class Dummy(nn.Module):
Expand Down
9 changes: 5 additions & 4 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2089,14 +2089,15 @@ def body_fn(mdl, c):
rngs={'params': random.key(1), 'loop': random.key(2)},
)
self.assertEqual(vars['state']['acc'], x)
np.testing.assert_array_equal(
vars['state']['rng_params'][0], vars['state']['rng_params'][1]
self.assertTrue(
jnp.equal(vars['state']['rng_params'][0], vars['state']['rng_params'][1])
)
with jax_debug_key_reuse(False):
np.testing.assert_array_compare(
operator.__ne__,
self.assertFalse(
jnp.equal(
vars['state']['rng_loop'][0],
vars['state']['rng_loop'][1],
)
)

def test_cond(self):
Expand Down

0 comments on commit 86292dd

Please sign in to comment.