Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add compact_name_scope v2 #3640

Merged
1 commit merged into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
Variable as Variable,
apply as apply,
compact as compact,
compact_name_scope as compact_name_scope,
disable_named_call as disable_named_call,
enable_named_call as enable_named_call,
init_with_output as init_with_output,
Expand Down
114 changes: 113 additions & 1 deletion flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,84 @@ def nowrap(fun: _CallableT) -> _CallableT:
return fun


def compact_name_scope(fun: _CallableT) -> _CallableT:
"""Creates compact submodules from a method.

This is a decorator that allows you to define compact submodules from a
method. It's intention is to make it easier to port code Haiku code to Flax
by providing the same functionality.

Example::

>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
>>> from flax.core import pretty_repr
...
>>> class Foo(nn.Module):
... @nn.compact_name_scope
... def up(self, x):
... return nn.Dense(3)(x)
...
... @nn.compact_name_scope
... def down(self, x):
... return nn.Dense(3)(x)
...
... def __call__(self, x):
... return self.up(x) + self.down(x)
...
>>> module = Foo()
>>> variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
>>> params = variables['params']
>>> print(pretty_repr(jax.tree_map(jnp.shape, params)))
{
down: {
Dense_0: {
bias: (3,),
kernel: (2, 3),
},
},
up: {
Dense_0: {
bias: (3,),
kernel: (2, 3),
},
},
}

You can also use ``compact_name_scope`` inside ``@compact`` methods or even other
``compact_name_scope`` methods. Methods that are decorated with ``compact_name_scope``
can also be called directly from ``init`` or ``apply`` via the ``method`` argument::

>>> y_down = module.apply({'params': params}, jnp.ones((1, 2)), method='down')
>>> y_down.shape
(1, 3)

Args:
fun: The Module method to mark as compact_name_scope.

Returns:
The given function ``fun`` marked as compact_name_scope.
"""

@functools.wraps(fun)
def compact_name_scope_wrapper(self: nn.Module, *args, **kwargs):
name = fun.__name__
if not hasattr(self, '_compact_name_scope_modules'):
raise ValueError(
f'Cannot call compact_name_scope method {name!r} on a Module that has not been '
f'setup. This is likely because you are calling {name!r} '
'from outside of init or apply.'
)
module = self._compact_name_scope_modules[name]
return module(*args, **kwargs)

compact_name_scope_wrapper.compact_name_scope = True # type: ignore[attr-defined]
compact_name_scope_wrapper.inner_fun = fun # type: ignore[attr-defined]
compact_name_scope_wrapper.nowrap = True # type: ignore[attr-defined]
return compact_name_scope_wrapper # type: ignore[return-value]


def _get_local_method_names(
cls: Any, exclude: Iterable[str] = ()
) -> Tuple[str, ...]:
Expand Down Expand Up @@ -955,6 +1033,7 @@ def __init_subclass__(cls, kw_only: bool = False, **kwargs: Any) -> None:
# We wrap user-defined methods including setup and __call__ to enforce
# a number of different checks and to provide clear error messages.
cls._verify_single_or_no_compact()
cls._find_compact_name_scope_methods()
cls._wrap_module_attributes()
# Set empty class defaults.
cls._state = _uninitialized_module_internal_state # type: ignore[attr-defined]
Expand Down Expand Up @@ -1046,6 +1125,17 @@ def _verify_single_or_no_compact(cls):
if n_compact_fns > 1:
raise errors.MultipleMethodsCompactError()

@classmethod
def _find_compact_name_scope_methods(cls):
"""Finds all compact_name_scope methods in the class."""
methods = [m[0] for m in inspect.getmembers(cls, predicate=callable)]
compact_name_scope_fns = tuple(
method_name
for method_name in methods
if hasattr(getattr(cls, method_name), 'compact_name_scope')
)
cls._compact_name_scope_methods = compact_name_scope_fns

@classmethod
def _wrap_module_attributes(cls):
"""Wraps user-defined non-inherited methods and descriptors with state
Expand Down Expand Up @@ -1347,6 +1437,7 @@ def _register_submodules(self, name, val):

def adopt_attr_modules(cache, queue, suffix, subvalue):
if isinstance(subvalue, Module):
current_name = subvalue.name
adopted_name = None
if subvalue.parent is None:
# Preserve sharing-by-reference relationships during adoption
Expand All @@ -1366,7 +1457,11 @@ def adopt_attr_modules(cache, queue, suffix, subvalue):
if subvalue.name is None:
object.__setattr__(subvalue, 'parent', self)
if adopted_name is None:
adopted_name = f'{name}{suffix}'
adopted_name = (
f'{name}{suffix}'
if not isinstance(subvalue, NonTransparent)
else current_name
)
object.__setattr__(subvalue, 'name', adopted_name)
queue.append(subvalue)
return subvalue
Expand Down Expand Up @@ -1397,6 +1492,14 @@ def _try_setup(self, shallow: bool = False) -> None:
self._register_submodules(field.name, getattr(self, field.name))
if not shallow:
self.setup()
# create NonTransparent Modules
self._compact_name_scope_modules = {
name: NonTransparent(
getattr(type(self), name).inner_fun, lambda: self, name=name
)
for name in self._compact_name_scope_methods
}

# We run static checks abstractly once for setup before any transforms
# to detect name collisions and other python errors.
elif self._state.setup_called == SetupState.NEW:
Expand Down Expand Up @@ -2835,3 +2938,12 @@ def init_wrapper(*args, **kwargs):
return init_fn(*args, **kwargs)[1]

return init_wrapper


class NonTransparent(Module):
fn: Callable
module_fn: Callable[[], Module]

@compact
def __call__(self, *args, **kwargs) -> Any:
return self.fn(self.module_fn(), *args, **kwargs)
70 changes: 70 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2487,6 +2487,76 @@ def my_property(self):
self.assertEqual(obj_loaded.b, 'ok')
self.assertEqual(obj_loaded.my_property, 'okok')

def test_compact_name_scope(self):
class Foo(nn.Module):
@nn.compact_name_scope
def up(self, x):
return nn.Dense(3)(x)

@nn.compact_name_scope
def down(self, x):
return nn.Dense(3)(x)

@nn.compact
def __call__(self, x):
return self.up(x) + self.down(x) + nn.Dense(3)(x)

m = Foo()
x = jnp.ones((1, 2))

self.assertEqual(set(m._compact_name_scope_methods), {'up', 'down'})

variables = m.init(random.key(0), x)
params = variables['params']

self.assertIn('Dense_0', params)
self.assertIn('down', params)
self.assertIn('up', params)
self.assertIn('Dense_0', params['down'])
self.assertIn('Dense_0', params['up'])

y = m.apply(variables, x)
y_up = m.apply(variables, x, method='up')
y_down = m.apply(variables, x, method='down')

assert y.shape == (1, 3)
assert y_up.shape == (1, 3)
assert y_down.shape == (1, 3)

def test_compact_name_scope_outside_compact(self):
class Foo(nn.Module):
@nn.compact_name_scope
def up(self, x):
return nn.Dense(3)(x)

@nn.compact_name_scope
def down(self, x):
return nn.Dense(3)(x)

def __call__(self, x):
return self.up(x) + self.down(x)

m = Foo()
x = jnp.ones((1, 2))

self.assertEqual(set(m._compact_name_scope_methods), {'up', 'down'})

variables = m.init(random.key(0), x)
params = variables['params']

self.assertIn('down', params)
self.assertIn('up', params)
self.assertIn('Dense_0', params['down'])
self.assertIn('Dense_0', params['up'])

y = m.apply(variables, x)
y_up = m.apply(variables, x, method='up')
y_down = m.apply(variables, x, method='down')

assert y.shape == (1, 3)
assert y_up.shape == (1, 3)
assert y_down.shape == (1, 3)


class LeakTests(absltest.TestCase):
def test_tracer_leaks(self):
Expand Down
Loading