Skip to content

Commit

Permalink
add compact_name_scope v2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 600454908
  • Loading branch information
Cristian Garcia authored and Flax Authors committed Jan 22, 2024
1 parent 77da09e commit 59c21cf
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 1 deletion.
1 change: 1 addition & 0 deletions flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
Variable as Variable,
apply as apply,
compact as compact,
compact_name_scope as compact_name_scope,
disable_named_call as disable_named_call,
enable_named_call as enable_named_call,
init_with_output as init_with_output,
Expand Down
114 changes: 113 additions & 1 deletion flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,84 @@ def nowrap(fun: _CallableT) -> _CallableT:
return fun


def compact_name_scope(fun: _CallableT) -> _CallableT:
"""Creates compact submodules from a method.
This is a decorator that allows you to define compact submodules from a
method. It's intention is to make it easier to port code Haiku code to Flax
by providing the same functionality.
Example::
>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
>>> from flax.core import pretty_repr
...
>>> class Foo(nn.Module):
... @nn.compact_name_scope
... def up(self, x):
... return nn.Dense(3)(x)
...
... @nn.compact_name_scope
... def down(self, x):
... return nn.Dense(3)(x)
...
... def __call__(self, x):
... return self.up(x) + self.down(x)
...
>>> module = Foo()
>>> variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
>>> params = variables['params']
>>> print(pretty_repr(jax.tree_map(jnp.shape, params)))
{
down: {
Dense_0: {
bias: (3,),
kernel: (2, 3),
},
},
up: {
Dense_0: {
bias: (3,),
kernel: (2, 3),
},
},
}
You can also use ``compact_name_scope`` inside ``@compact`` methods or even other
``compact_name_scope`` methods. Methods that are decorated with ``compact_name_scope``
can also be called directly from ``init`` or ``apply`` via the ``method`` argument::
>>> y_down = module.apply({'params': params}, jnp.ones((1, 2)), method='down')
>>> y_down.shape
(1, 3)
Args:
fun: The Module method to mark as compact_name_scope.
Returns:
The given function ``fun`` marked as compact_name_scope.
"""

@functools.wraps(fun)
def compact_name_scope_wrapper(self: nn.Module, *args, **kwargs):
name = fun.__name__
if not hasattr(self, '_compact_name_scope_modules'):
raise ValueError(
f'Cannot call compact_name_scope method {name!r} on a Module that has not been '
f'setup. This is likely because you are calling {name!r} '
'from outside of init or apply.'
)
module = self._compact_name_scope_modules[name]
return module(*args, **kwargs)

compact_name_scope_wrapper.compact_name_scope = True # type: ignore[attr-defined]
compact_name_scope_wrapper.inner_fun = fun # type: ignore[attr-defined]
compact_name_scope_wrapper.nowrap = True # type: ignore[attr-defined]
return compact_name_scope_wrapper # type: ignore[return-value]


def _get_local_method_names(
cls: Any, exclude: Iterable[str] = ()
) -> Tuple[str, ...]:
Expand Down Expand Up @@ -955,6 +1033,7 @@ def __init_subclass__(cls, kw_only: bool = False, **kwargs: Any) -> None:
# We wrap user-defined methods including setup and __call__ to enforce
# a number of different checks and to provide clear error messages.
cls._verify_single_or_no_compact()
cls._find_compact_name_scope_methods()
cls._wrap_module_attributes()
# Set empty class defaults.
cls._state = _uninitialized_module_internal_state # type: ignore[attr-defined]
Expand Down Expand Up @@ -1046,6 +1125,17 @@ def _verify_single_or_no_compact(cls):
if n_compact_fns > 1:
raise errors.MultipleMethodsCompactError()

@classmethod
def _find_compact_name_scope_methods(cls):
"""Finds all compact_name_scope methods in the class."""
methods = [m[0] for m in inspect.getmembers(cls, predicate=callable)]
compact_name_scope_fns = tuple(
method_name
for method_name in methods
if hasattr(getattr(cls, method_name), 'compact_name_scope')
)
cls._compact_name_scope_methods = compact_name_scope_fns

@classmethod
def _wrap_module_attributes(cls):
"""Wraps user-defined non-inherited methods and descriptors with state
Expand Down Expand Up @@ -1347,6 +1437,7 @@ def _register_submodules(self, name, val):

def adopt_attr_modules(cache, queue, suffix, subvalue):
if isinstance(subvalue, Module):
current_name = subvalue.name
adopted_name = None
if subvalue.parent is None:
# Preserve sharing-by-reference relationships during adoption
Expand All @@ -1366,7 +1457,11 @@ def adopt_attr_modules(cache, queue, suffix, subvalue):
if subvalue.name is None:
object.__setattr__(subvalue, 'parent', self)
if adopted_name is None:
adopted_name = f'{name}{suffix}'
adopted_name = (
f'{name}{suffix}'
if not isinstance(subvalue, NonTransparent)
else current_name
)
object.__setattr__(subvalue, 'name', adopted_name)
queue.append(subvalue)
return subvalue
Expand Down Expand Up @@ -1397,6 +1492,14 @@ def _try_setup(self, shallow: bool = False) -> None:
self._register_submodules(field.name, getattr(self, field.name))
if not shallow:
self.setup()
# create NonTransparent Modules
self._compact_name_scope_modules = {
name: NonTransparent(
getattr(type(self), name).inner_fun, lambda: self, name=name
)
for name in self._compact_name_scope_methods
}

# We run static checks abstractly once for setup before any transforms
# to detect name collisions and other python errors.
elif self._state.setup_called == SetupState.NEW:
Expand Down Expand Up @@ -2835,3 +2938,12 @@ def init_wrapper(*args, **kwargs):
return init_fn(*args, **kwargs)[1]

return init_wrapper


class NonTransparent(Module):
fn: Callable
module_fn: Callable[[], Module]

@compact
def __call__(self, *args, **kwargs) -> Any:
return self.fn(self.module_fn(), *args, **kwargs)
70 changes: 70 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2487,6 +2487,76 @@ def my_property(self):
self.assertEqual(obj_loaded.b, 'ok')
self.assertEqual(obj_loaded.my_property, 'okok')

def test_compact_name_scope(self):
class Foo(nn.Module):
@nn.compact_name_scope
def up(self, x):
return nn.Dense(3)(x)

@nn.compact_name_scope
def down(self, x):
return nn.Dense(3)(x)

@nn.compact
def __call__(self, x):
return self.up(x) + self.down(x) + nn.Dense(3)(x)

m = Foo()
x = jnp.ones((1, 2))

self.assertEqual(set(m._compact_name_scope_methods), {'up', 'down'})

variables = m.init(random.key(0), x)
params = variables['params']

self.assertIn('Dense_0', params)
self.assertIn('down', params)
self.assertIn('up', params)
self.assertIn('Dense_0', params['down'])
self.assertIn('Dense_0', params['up'])

y = m.apply(variables, x)
y_up = m.apply(variables, x, method='up')
y_down = m.apply(variables, x, method='down')

assert y.shape == (1, 3)
assert y_up.shape == (1, 3)
assert y_down.shape == (1, 3)

def test_compact_name_scope_outside_compact(self):
class Foo(nn.Module):
@nn.compact_name_scope
def up(self, x):
return nn.Dense(3)(x)

@nn.compact_name_scope
def down(self, x):
return nn.Dense(3)(x)

def __call__(self, x):
return self.up(x) + self.down(x)

m = Foo()
x = jnp.ones((1, 2))

self.assertEqual(set(m._compact_name_scope_methods), {'up', 'down'})

variables = m.init(random.key(0), x)
params = variables['params']

self.assertIn('down', params)
self.assertIn('up', params)
self.assertIn('Dense_0', params['down'])
self.assertIn('Dense_0', params['up'])

y = m.apply(variables, x)
y_up = m.apply(variables, x, method='up')
y_down = m.apply(variables, x, method='down')

assert y.shape == (1, 3)
assert y_up.shape == (1, 3)
assert y_down.shape == (1, 3)


class LeakTests(absltest.TestCase):
def test_tracer_leaks(self):
Expand Down

0 comments on commit 59c21cf

Please sign in to comment.