Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Module.clone in deepclone mode for internal usage. #3459

Merged
merged 1 commit into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,8 @@ def clone_fn(m: Module) -> Module:
# _map_submodules will map over all submodules inside attrs
# value here can be any pytree, non-module values are ignored
for field_name, value in attrs.items():
if field_name == 'parent':
continue
attrs[field_name] = _map_submodules(clone_fn, value)

module = self.__class__(**attrs)
Expand Down
30 changes: 30 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2712,6 +2712,36 @@ def __call__(self, x):
with self.assertRaises(errors.NameInUseError):
vs = foo.init(k, x)

def test_internal_deep_clone(self):
class Child(nn.Module):
@nn.compact
def __call__(self, x):
w = self.param('w', nn.initializers.zeros, (5, x.shape[1]))
return x @ w

class Parent(nn.Module):
num_layers: int
child_template: Child

@nn.compact
def __call__(self, x):
for i in range(self.num_layers):
x = self.child_template.clone(
parent=self, _deep_clone=True, name=None
)(x)
return x

model = Parent(num_layers=2, child_template=Child())
x = jnp.ones((32, 5))
variables = model.init(jax.random.key(0), x)
output = model.apply(variables, x)
self.assertTrue(
jnp.all(
variables['params']['Child_0']['w']
== variables['params']['Child_1']['w']
)
)


class FrozenDictTests(absltest.TestCase):
def test_frozendict_flag(self):
Expand Down
Loading