From 36952eaf8406bc670d42eb555229d05ac61660cc Mon Sep 17 00:00:00 2001
From: Cristian Garcia
Date: Tue, 23 Jul 2024 17:02:03 +0100
Subject: [PATCH] [linen] add share_scope

---
 docs/api_reference/flax.linen/module.rst |   1 +
 flax/errors.py                           |  18 +++
 flax/linen/__init__.py                   |   1 +
 flax/linen/module.py                     |  89 ++++++++++++++
 tests/linen/linen_module_test.py         | 141 +++++++++++++++++++++++
 5 files changed, 250 insertions(+)

diff --git a/docs/api_reference/flax.linen/module.rst b/docs/api_reference/flax.linen/module.rst
index 200bd124ef..44e92dd0a2 100644
--- a/docs/api_reference/flax.linen/module.rst
+++ b/docs/api_reference/flax.linen/module.rst
@@ -11,3 +11,4 @@ Module
 .. autofunction:: init
 .. autofunction:: init_with_output
 .. autofunction:: intercept_methods
+.. autofunction:: share_scope
diff --git a/flax/errors.py b/flax/errors.py
index 8863ad06eb..7284c6e3fb 100644
--- a/flax/errors.py
+++ b/flax/errors.py
@@ -691,6 +691,24 @@ def __call__(self, x):
 
   def __init__(self):
     super().__init__("Can't call `unbind()` on unbound modules")
+class CallShareScopeOnUnboundModuleError(FlaxError):
+  """This error occurs when you are trying to call ``nn.share_scope`` on an unbound
+  Module. For instance, when you try to use ``nn.share_scope`` at the top level::
+
+    from flax import linen as nn
+
+    class CustomDense(nn.Dense):
+      def __call__(self, x):
+        return super().__call__(x) + 1
+
+    custom_dense = CustomDense(5)
+    dense = nn.Dense(5)  # has the parameters
+
+    nn.share_scope(custom_dense, dense)  # <-- ERROR!
+  """
+
+  def __init__(self):
+    super().__init__("Can't call `share_scope` on unbound modules")
 
 class InvalidInstanceModuleError(FlaxError):
   """This error occurs when you are trying to call ``.init()``, ``.init_with_output()``, ``.apply()`` or ``.bind()``
diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py
index f01ed92880..a6e716052b 100644
--- a/flax/linen/__init__.py
+++ b/flax/linen/__init__.py
@@ -117,6 +117,7 @@
   merge_param as merge_param,
   nowrap as nowrap,
   override_named_call as override_named_call,
+  share_scope as share_scope,
 )
 from .normalization import (
   BatchNorm as BatchNorm,
diff --git a/flax/linen/module.py b/flax/linen/module.py
index a9772da92f..8be82da495 100644
--- a/flax/linen/module.py
+++ b/flax/linen/module.py
@@ -3174,3 +3174,92 @@ class CompactNameScope:
 
   def __call__(self, *args, **kwargs) -> Any:
     ...
+
+
+def share_scope(module: Module, other: Module, /):
+  """Modifies one of the Modules so that they share the same scope. This is useful
+  when you want to wrap a Module and extend its functionality without changing the
+  parameter structure.
+
+  ``share_scope`` takes two Modules, ``module`` and ``other``. ``module`` will use
+  ``other``'s scope if ``other`` has a scope and it is not a descendant of ``module``'s
+  scope::
+
+    >>> import flax.linen as nn
+    >>> import jax
+    >>> from jax import numpy as jnp, random
+    ...
+    >>> class DenseLoRA(nn.Module):
+    ...   base: nn.Dense
+    ...   rank: int
+    ...
+    ...   def setup(self):
+    ...     nn.share_scope(self, self.base)
+    ...
+    ...   @nn.compact
+    ...   def __call__(self, x: jax.Array):
+    ...     din, dout = x.shape[-1], self.base.features
+    ...     A = self.param('A', nn.zeros_init(), (din, self.rank))
+    ...     B = self.param('B', nn.zeros_init(), (self.rank, dout))
+    ...     return self.base(x) + x @ A @ B
+    ...
+    >>> class Model(nn.Module):
+    ...   @nn.compact
+    ...   def __call__(self, x: jax.Array):
+    ...     dense = nn.Dense(10)  # base scope
+    ...     return DenseLoRA(dense, rank=2)(x)  # reuse the base scope
+    ...
+    >>> model = Model()
+    ...
+    >>> params = model.init(random.key(0), jnp.ones((1, 5)))['params']
+    >>> list(params['Dense_0'].keys())
+    ['A', 'B', 'kernel', 'bias']
+
+  When ``other``'s scope is a descendant of ``module``'s scope, then ``other``
+  will use ``module``'s scope instead::
+
+    >>> class DenseLoRA(nn.Module):
+    ...   features: int
+    ...   rank: int
+    ...
+    ...   def setup(self):
+    ...     self.child = nn.Dense(self.features)
+    ...     nn.share_scope(self, self.child)
+    ...
+    ...   @nn.compact
+    ...   def __call__(self, x: jax.Array):
+    ...     din, dout = x.shape[-1], self.features
+    ...     A = self.param('A', nn.zeros_init(), (din, self.rank))
+    ...     B = self.param('B', nn.zeros_init(), (self.rank, dout))
+    ...     return self.child(x) + x @ A @ B
+    ...
+    >>> class Model(nn.Module):
+    ...   @nn.compact
+    ...   def __call__(self, x: jax.Array):
+    ...     return DenseLoRA(10, rank=2)(x)
+    ...
+    >>> model = Model()
+    ...
+    >>> params = model.init(random.key(0), jnp.ones((1, 5)))['params']
+    >>> list(params['DenseLoRA_0'].keys())
+    ['A', 'B', 'kernel', 'bias']
+  """
+  if module.scope is None:
+    raise errors.CallShareScopeOnUnboundModuleError()
+
+  def _is_child_scope(scope: Scope, other: Scope) -> bool:
+    target: Scope | None = other
+
+    while target is not None:
+      if target is scope:
+        return True
+      target = target.parent
+    return False
+
+  if other.scope is not None and _is_child_scope(module.scope, other.scope):
+    # `other`'s scope is a true child of `module`'s scope: overwrite it
+    object.__setattr__(other, 'scope', module.scope)
+  else:
+    # `other` has its own independent scope: overwrite `module`'s
+    # scope instead, so that the sharing is preserved
+    object.__setattr__(module, 'scope', other.scope)
diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py
index d2664003f3..30bfe993c2 100644
--- a/tests/linen/linen_module_test.py
+++ b/tests/linen/linen_module_test.py
@@ -3167,5 +3167,146 @@ def test_frozendict_flag(self):
     self.assertTrue(isinstance(params, dict))
 
 
+class ShareScopeTest(absltest.TestCase):
+  def test_basic(self):
+    class DenseLoRA(nn.Module):
+      inner: nn.Dense
+      rank: int
+
+      def setup(self):
+        nn.share_scope(self, self.inner)
+
+      @nn.compact
+      def __call__(self, x: jax.Array):
+        din, dout = x.shape[-1], self.inner.features
+        A = self.param('A', nn.zeros_init(), (din, self.rank))
+        B = self.param('B', nn.zeros_init(), (self.rank, dout))
+        return self.inner(x) + x @ A @ B
+
+    dense_lora = DenseLoRA(nn.Dense(10), rank=2)
+
+    params = dense_lora.init(random.key(0), jnp.ones((1, 5)))['params']
+
+    self.assertIn('kernel', params)
+    self.assertIn('bias', params)
+    self.assertIn('A', params)
+    self.assertIn('B', params)
+
+  def test_child_scope(self):
+    class DenseLoRA(nn.Module):
+      rank: int
+
+      def setup(self):
+        self.child = nn.Dense(10)
+        nn.share_scope(self, self.child)
+
+      @nn.compact
+      def __call__(self, x: jax.Array):
+        din, dout = x.shape[-1], self.child.features
+        A = self.param('A', nn.zeros_init(), (din, self.rank))
+        B = self.param('B', nn.zeros_init(), (self.rank, dout))
+        return self.child(x) + x @ A @ B
+
+    dense_lora = DenseLoRA(rank=2)
+
+    params = dense_lora.init(random.key(0), jnp.ones((1, 5)))['params']
+
+    self.assertIn('kernel', params)
+    self.assertIn('bias', params)
+    self.assertIn('A', params)
+    self.assertIn('B', params)
+
+  def test_in_compact(self):
+    class DenseLoRA(nn.Module):
+      rank: int
+
+      def setup(self):
+        self.child = nn.Dense(10)
+        nn.share_scope(self, self.child)
+
+      @nn.compact
+      def __call__(self, x: jax.Array):
+        din, dout = x.shape[-1], self.child.features
+        A = self.param('A', nn.zeros_init(), (din, self.rank))
+        B = self.param('B', nn.zeros_init(), (self.rank, dout))
+        return self.child(x) + x @ A @ B
+
+    class Model(nn.Module):
+      @nn.compact
+      def __call__(self, x: jax.Array):
+        return DenseLoRA(rank=2)(x)
+
+    model = Model()
+
+    params = model.init(random.key(0), jnp.ones((1, 5)))['params']
+
+    self.assertIn('kernel', params['DenseLoRA_0'])
+    self.assertIn('bias', params['DenseLoRA_0'])
+    self.assertIn('A', params['DenseLoRA_0'])
+    self.assertIn('B', params['DenseLoRA_0'])
+
+  def test_adopt_child_name(self):
+    class DenseLoRA(nn.Module):
+      inner: nn.Dense
+      rank: int
+
+      def setup(self):
+        nn.share_scope(self, self.inner)
+
+      @nn.compact
+      def __call__(self, x: jax.Array):
+        din, dout = x.shape[-1], self.inner.features
+        A = self.param('A', nn.zeros_init(), (din, self.rank))
+        B = self.param('B', nn.zeros_init(), (self.rank, dout))
+        return self.inner(x) + x @ A @ B
+
+    class Model(nn.Module):
+      @nn.compact
+      def __call__(self, x: jax.Array):
+        return DenseLoRA(nn.Dense(10), rank=2)(x)
+
+    model = Model()
+
+    params = model.init(random.key(0), jnp.ones((1, 5)))['params']
+
+    self.assertIn('kernel', params['Dense_0'])
+    self.assertIn('bias', params['Dense_0'])
+    self.assertIn('A', params['Dense_0'])
+    self.assertIn('B', params['Dense_0'])
+
+  def test_other_scope_is_none(self):
+    class DenseLoRA(nn.Module):
+      inner: nn.Dense
+      rank: int
+
+      def setup(self):
+        nn.share_scope(self, self.inner)
+
+      @nn.compact
+      def __call__(self, x: jax.Array):
+        din, dout = x.shape[-1], self.inner.features
+        A = self.param('A', nn.zeros_init(), (din, self.rank))
+        B = self.param('B', nn.zeros_init(), (self.rank, dout))
+        return self.inner(x) + x @ A @ B
+
+    class Model(nn.Module):
+      def setup(self):
+        # here Dense doesn't have a scope yet
+        self.dense_lora = DenseLoRA(nn.Dense(10), rank=2)
+
+      @nn.compact
+      def __call__(self, x: jax.Array):
+        return self.dense_lora(x)
+
+    model = Model()
+
+    params = model.init(random.key(0), jnp.ones((1, 5)))['params']
+
+    self.assertIn('kernel', params['dense_lora'])
+    self.assertIn('bias', params['dense_lora'])
+    self.assertIn('A', params['dense_lora'])
+    self.assertIn('B', params['dense_lora'])
+
+
 if __name__ == '__main__':
   absltest.main()
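
A minimal usage sketch of the API added above (illustrative only, not part of the patch). It reuses the ``DenseLoRA`` wrapper from the ``share_scope`` docstring; the ``pretrained_params`` name is a hypothetical stand-in for a real checkpoint. Because both modules share one scope, the plain ``nn.Dense`` parameters can be merged into the wrapped model unchanged::

    import jax
    import flax.linen as nn
    from jax import numpy as jnp, random


    class DenseLoRA(nn.Module):
      base: nn.Dense
      rank: int

      def setup(self):
        nn.share_scope(self, self.base)

      @nn.compact
      def __call__(self, x: jax.Array):
        din, dout = x.shape[-1], self.base.features
        A = self.param('A', nn.zeros_init(), (din, self.rank))
        B = self.param('B', nn.zeros_init(), (self.rank, dout))
        return self.base(x) + x @ A @ B


    # Hypothetical pretrained parameters for a plain nn.Dense(10)
    # (stand-in for loading a real checkpoint).
    pretrained_params = nn.Dense(10).init(random.key(0), jnp.ones((1, 5)))['params']

    # Thanks to nn.share_scope, 'kernel'/'bias' live next to 'A'/'B' at the
    # top level of DenseLoRA's params, so the pretrained values merge directly.
    model = DenseLoRA(nn.Dense(10), rank=2)
    lora_params = model.init(random.key(1), jnp.ones((1, 5)))['params']
    params = {**lora_params, **pretrained_params}

    y = model.apply({'params': params}, jnp.ones((1, 5)))  # shape (1, 10)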