Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[linen] allows multiple compact methods #3808

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,6 @@ def __init_subclass__(cls, kw_only: bool = False, **kwargs: Any) -> None:
cls._customized_dataclass_transform(kw_only)
# We wrap user-defined methods including setup and __call__ to enforce
# a number of different checks and to provide clear error messages.
cls._verify_single_or_no_compact()
cls._find_compact_name_scope_methods()
cls._wrap_module_attributes()
# Set empty class defaults.
Expand Down Expand Up @@ -1116,20 +1115,6 @@ def _customized_dataclass_transform(cls, kw_only: bool):

cls.__hash__ = _wrap_hash(cls.__hash__) # type: ignore[method-assign]

@classmethod
def _verify_single_or_no_compact(cls):
"""Statically verifies that at most a single method is labelled compact."""
methods = [m[0] for m in inspect.getmembers(cls, predicate=callable)]
n_compact_fns = len(
[
method_name
for method_name in methods
if hasattr(getattr(cls, method_name), 'compact')
]
)
if n_compact_fns > 1:
raise errors.MultipleMethodsCompactError()

@classmethod
def _find_compact_name_scope_methods(cls):
"""Finds all compact_name_scope methods in the class."""
Expand Down
35 changes: 23 additions & 12 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,18 +513,29 @@ def __call__(self):

Foo({'a': ()}).apply({})

def test_only_one_compact_method(self):
msg = 'Only one method per class can be @compact'
with self.assertRaisesRegex(errors.MultipleMethodsCompactError, msg):

class MultipleCompactMethods(nn.Module):
@compact
def call1(self):
pass

@compact
def call2(self):
pass
def test_multiple_compact_methods(self):
"""Test that multiple methods with the @compact decorator can be used.

NOTE: in the near future we might want to have compact methods reset the
autoname_cursor such that Dense would be reused in the second method.
"""

class MultipleCompactMethods(nn.Module):
@compact
def __call__(self, x):
x = nn.Dense(1)(x)
return self.method(x)

@compact
def method(self, x):
x = nn.Dense(1)(x)
return x

m = MultipleCompactMethods()
variables = m.init(random.key(0), jnp.ones((1, 1)))
params = variables['params']
self.assertIn('Dense_0', params)
self.assertIn('Dense_1', params)

def test_only_one_compact_method_subclass(self):
class Dummy(nn.Module):
Expand Down
9 changes: 5 additions & 4 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2089,14 +2089,15 @@ def body_fn(mdl, c):
rngs={'params': random.key(1), 'loop': random.key(2)},
)
self.assertEqual(vars['state']['acc'], x)
np.testing.assert_array_equal(
vars['state']['rng_params'][0], vars['state']['rng_params'][1]
self.assertTrue(
jnp.equal(vars['state']['rng_params'][0], vars['state']['rng_params'][1])
)
with jax_debug_key_reuse(False):
np.testing.assert_array_compare(
operator.__ne__,
self.assertFalse(
jnp.equal(
vars['state']['rng_loop'][0],
vars['state']['rng_loop'][1],
)
)

def test_cond(self):
Expand Down
Loading