Merge pull request #4102 from google:linen-transparent
PiperOrigin-RevId: 658378936
Flax Authors committed Aug 1, 2024
2 parents cd6218f + 36952ea commit d20f594
Showing 5 changed files with 250 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/api_reference/flax.linen/module.rst
@@ -11,3 +11,4 @@ Module
.. autofunction:: init
.. autofunction:: init_with_output
.. autofunction:: intercept_methods
.. autofunction:: share_scope
18 changes: 18 additions & 0 deletions flax/errors.py
@@ -691,6 +691,24 @@ def __call__(self, x):
  def __init__(self):
    super().__init__("Can't call `unbind()` on unbound modules")


class CallShareScopeOnUnboundModuleError(FlaxError):
  """This error occurs when you are trying to call ``nn.share_scope`` on an unbound
  Module. For instance, when you try to use ``nn.share_scope`` at the top-level::

    from flax import linen as nn

    class CustomDense(nn.Dense):
      def __call__(self, x):
        return super().__call__(x) + 1

    custom_dense = CustomDense(5)
    dense = nn.Dense(5)  # has the parameters

    nn.share_scope(custom_dense, dense)  # <-- ERROR!
  """

  def __init__(self):
    super().__init__("Can't call `share_scope` on unbound modules")

class InvalidInstanceModuleError(FlaxError):
"""This error occurs when you are trying to call ``.init()``, ``.init_with_output()``, ``.apply()`` or ``.bind()``
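As the ``share_scope`` docstring and the new tests in this change show, the way to avoid this error is to call ``nn.share_scope`` from inside a bound Module, typically in ``setup``. A minimal sketch of that pattern (the ``PlusOne`` wrapper below is hypothetical and not part of this change):

import flax.linen as nn
from jax import numpy as jnp, random

class PlusOne(nn.Module):
  """Hypothetical wrapper that extends a Dense without adding parameters."""
  base: nn.Dense

  def setup(self):
    # `self` is bound by the time setup runs, so share_scope is valid here
    nn.share_scope(self, self.base)

  @nn.compact
  def __call__(self, x):
    return self.base(x) + 1

params = PlusOne(nn.Dense(5)).init(random.key(0), jnp.ones((1, 3)))['params']
# the wrapped Dense's parameters live directly under PlusOne's scope: {'kernel', 'bias'}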
1 change: 1 addition & 0 deletions flax/linen/__init__.py
@@ -117,6 +117,7 @@
  merge_param as merge_param,
  nowrap as nowrap,
  override_named_call as override_named_call,
  share_scope as share_scope,
)
from .normalization import (
  BatchNorm as BatchNorm,
89 changes: 89 additions & 0 deletions flax/linen/module.py
@@ -3174,3 +3174,92 @@ class CompactNameScope:

  def __call__(self, *args, **kwargs) -> Any:
    ...


def share_scope(module: Module, other: Module, /):
  """Modifies one of the Modules such that they share the same scope. This is useful
  when you want to wrap a Module and extend its functionality without changing the
  parameter structure.

  ``share_scope`` takes two Modules, ``module`` and ``other``. ``module`` will use
  ``other``'s scope if ``other`` has a scope and it's not a descendant of ``module``'s
  scope::

    >>> import flax.linen as nn
    >>> import jax
    >>> from jax import numpy as jnp, random
    ...
    >>> class DenseLoRA(nn.Module):
    ...   base: nn.Dense
    ...   rank: int
    ...
    ...   def setup(self):
    ...     nn.share_scope(self, self.base)
    ...
    ...   @nn.compact
    ...   def __call__(self, x: jax.Array):
    ...     din, dout = x.shape[-1], self.base.features
    ...     A = self.param('A', nn.zeros_init(), (din, self.rank))
    ...     B = self.param('B', nn.zeros_init(), (self.rank, dout))
    ...     return self.base(x) + x @ A @ B
    ...
    >>> class Model(nn.Module):
    ...   @nn.compact
    ...   def __call__(self, x: jax.Array):
    ...     dense = nn.Dense(10)  # base scope
    ...     return DenseLoRA(dense, rank=2)(x)  # reuse the base scope
    ...
    >>> model = Model()
    ...
    >>> params = model.init(random.key(0), jnp.ones((1, 5)))['params']
    >>> list(params['Dense_0'].keys())
    ['A', 'B', 'kernel', 'bias']

  When ``other``'s scope is a descendant of ``module``'s scope then ``other``
  will use ``module``'s scope instead::

    >>> class DenseLoRA(nn.Module):
    ...   features: int
    ...   rank: int
    ...
    ...   def setup(self):
    ...     self.child = nn.Dense(self.features)
    ...     nn.share_scope(self, self.child)
    ...
    ...   @nn.compact
    ...   def __call__(self, x: jax.Array):
    ...     din, dout = x.shape[-1], self.features
    ...     A = self.param('A', nn.zeros_init(), (din, self.rank))
    ...     B = self.param('B', nn.zeros_init(), (self.rank, dout))
    ...     return self.child(x) + x @ A @ B
    ...
    >>> class Model(nn.Module):
    ...   @nn.compact
    ...   def __call__(self, x: jax.Array):
    ...     return DenseLoRA(10, rank=2)(x)
    ...
    >>> model = Model()
    ...
    >>> params = model.init(random.key(0), jnp.ones((1, 5)))['params']
    >>> list(params['DenseLoRA_0'].keys())
    ['A', 'B', 'kernel', 'bias']
  """
  if module.scope is None:
    raise errors.CallShareScopeOnUnboundModuleError()

  def _is_child_scope(scope: Scope, other: Scope) -> bool:
    target: Scope | None = other

    while target is not None:
      if target is scope:
        return True
      target = target.parent
    return False

  if other.scope is not None and _is_child_scope(module.scope, other.scope):
    # `other`'s scope is a descendant of `module`'s scope, so `other`
    # adopts `module`'s scope
    object.__setattr__(other, 'scope', module.scope)
  else:
    # `other` has its own independent scope, so `module` adopts `other`'s
    # scope to preserve the sharing
    object.__setattr__(module, 'scope', other.scope)
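Because the wrapper and the wrapped ``Dense`` end up with a single shared scope, the parameter structure of the base module is unchanged, which is what makes wrapping useful in practice: existing ``Dense`` parameters can be reused by the wrapper without any renaming. A hedged sketch of that workflow, reusing the ``DenseLoRA`` from the docstring above (the parameter-merging step is illustrative and not part of this change):

import flax.linen as nn
from jax import numpy as jnp, random

class DenseLoRA(nn.Module):  # same definition as in the docstring above
  base: nn.Dense
  rank: int

  def setup(self):
    nn.share_scope(self, self.base)

  @nn.compact
  def __call__(self, x):
    din, dout = x.shape[-1], self.base.features
    A = self.param('A', nn.zeros_init(), (din, self.rank))
    B = self.param('B', nn.zeros_init(), (self.rank, dout))
    return self.base(x) + x @ A @ B

dense = nn.Dense(10)
pretrained = dense.init(random.key(0), jnp.ones((1, 5)))['params']   # {'kernel', 'bias'}

lora = DenseLoRA(dense, rank=2)
lora_params = lora.init(random.key(1), jnp.ones((1, 5)))['params']   # {'A', 'B', 'kernel', 'bias'}

# identical names at the same level, so the pretrained weights slot straight in
lora_params = {**lora_params, **pretrained}
y = lora.apply({'params': lora_params}, jnp.ones((1, 5)))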
141 changes: 141 additions & 0 deletions tests/linen/linen_module_test.py
@@ -3167,5 +3167,146 @@ def test_frozendict_flag(self):
    self.assertTrue(isinstance(params, dict))


class ShareScopeTest(absltest.TestCase):
  def test_basic(self):
    class DenseLoRA(nn.Module):
      inner: nn.Dense
      rank: int

      def setup(self):
        nn.share_scope(self, self.inner)

      @nn.compact
      def __call__(self, x: jax.Array):
        din, dout = x.shape[-1], self.inner.features
        A = self.param('A', nn.zeros_init(), (din, self.rank))
        B = self.param('B', nn.zeros_init(), (self.rank, dout))
        return self.inner(x) + x @ A @ B

    dense_lora = DenseLoRA(nn.Dense(10), rank=2)

    params = dense_lora.init(random.key(0), jnp.ones((1, 5)))['params']

    self.assertIn('kernel', params)
    self.assertIn('bias', params)
    self.assertIn('A', params)
    self.assertIn('B', params)

  def test_child_scope(self):
    class DenseLoRA(nn.Module):
      rank: int

      def setup(self):
        self.child = nn.Dense(10)
        nn.share_scope(self, self.child)

      @nn.compact
      def __call__(self, x: jax.Array):
        din, dout = x.shape[-1], self.child.features
        A = self.param('A', nn.zeros_init(), (din, self.rank))
        B = self.param('B', nn.zeros_init(), (self.rank, dout))
        return self.child(x) + x @ A @ B

    dense_lora = DenseLoRA(rank=2)

    params = dense_lora.init(random.key(0), jnp.ones((1, 5)))['params']

    self.assertIn('kernel', params)
    self.assertIn('bias', params)
    self.assertIn('A', params)
    self.assertIn('B', params)

  def test_in_compact(self):
    class DenseLoRA(nn.Module):
      rank: int

      def setup(self):
        self.child = nn.Dense(10)
        nn.share_scope(self, self.child)

      @nn.compact
      def __call__(self, x: jax.Array):
        din, dout = x.shape[-1], self.child.features
        A = self.param('A', nn.zeros_init(), (din, self.rank))
        B = self.param('B', nn.zeros_init(), (self.rank, dout))
        return self.child(x) + x @ A @ B

    class Model(nn.Module):
      @nn.compact
      def __call__(self, x: jax.Array):
        return DenseLoRA(rank=2)(x)

    model = Model()

    params = model.init(random.key(0), jnp.ones((1, 5)))['params']

    self.assertIn('kernel', params['DenseLoRA_0'])
    self.assertIn('bias', params['DenseLoRA_0'])
    self.assertIn('A', params['DenseLoRA_0'])
    self.assertIn('B', params['DenseLoRA_0'])

  def test_adopt_child_name(self):
    class DenseLoRA(nn.Module):
      inner: nn.Dense
      rank: int

      def setup(self):
        nn.share_scope(self, self.inner)

      @nn.compact
      def __call__(self, x: jax.Array):
        din, dout = x.shape[-1], self.inner.features
        A = self.param('A', nn.zeros_init(), (din, self.rank))
        B = self.param('B', nn.zeros_init(), (self.rank, dout))
        return self.inner(x) + x @ A @ B

    class Model(nn.Module):
      @nn.compact
      def __call__(self, x: jax.Array):
        return DenseLoRA(nn.Dense(10), rank=2)(x)

    model = Model()

    params = model.init(random.key(0), jnp.ones((1, 5)))['params']

    self.assertIn('kernel', params['Dense_0'])
    self.assertIn('bias', params['Dense_0'])
    self.assertIn('A', params['Dense_0'])
    self.assertIn('B', params['Dense_0'])

  def test_other_scope_is_none(self):
    class DenseLoRA(nn.Module):
      inner: nn.Dense
      rank: int

      def setup(self):
        nn.share_scope(self, self.inner)

      @nn.compact
      def __call__(self, x: jax.Array):
        din, dout = x.shape[-1], self.inner.features
        A = self.param('A', nn.zeros_init(), (din, self.rank))
        B = self.param('B', nn.zeros_init(), (self.rank, dout))
        return self.inner(x) + x @ A @ B

    class Model(nn.Module):
      def setup(self):
        # here Dense doesn't have a scope yet
        self.dense_lora = DenseLoRA(nn.Dense(10), rank=2)

      @nn.compact
      def __call__(self, x: jax.Array):
        return self.dense_lora(x)

    model = Model()

    params = model.init(random.key(0), jnp.ones((1, 5)))['params']

    self.assertIn('kernel', params['dense_lora'])
    self.assertIn('bias', params['dense_lora'])
    self.assertIn('A', params['dense_lora'])
    self.assertIn('B', params['dense_lora'])


if __name__ == '__main__':
  absltest.main()
