Skip to content

Commit

Permalink
[nnx] fix initializing propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Aug 20, 2024
1 parent 71b5a46 commit 4e3c8c8
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
7 changes: 4 additions & 3 deletions flax/nnx/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _graph_node_flatten(self):
for key, value in vars(self).items()
if key != '_object__state'
)
return nodes, type(self)
return nodes, (type(self), self._object__state._initializing)

def _graph_node_set_key(self, key: Key, value: tp.Any):
if not isinstance(key, str):
Expand All @@ -214,9 +214,10 @@ def _graph_node_pop_key(self, key: Key):
return vars(self).pop(key)

@staticmethod
def _graph_node_create_empty(static: tuple[tp.Type[G], bool]) -> G:
    """Create an empty (uninitialized) graph node from its static metadata.

    ``static`` is the auxiliary data produced by ``_graph_node_flatten``: the
    node's concrete type together with the ``initializing`` flag captured at
    flatten time.  Restoring the flag here is the fix this commit makes —
    without it, a split/merge round-trip during ``lazy_init`` would lose the
    initializing state.

    Args:
      static: ``(node_type, initializing)`` pair recorded at flatten time.

    Returns:
      A bare instance of ``node_type`` whose only attribute is a fresh
      ``_object__state`` carrying the restored ``initializing`` flag; the
      remaining attributes are filled in later via ``_graph_node_set_key``.
    """
    node_type, initializing = static
    # Bypass __init__: the node's attributes are populated by the graph
    # machinery, not by the user-defined constructor.
    node = object.__new__(node_type)
    vars(node).update(_object__state=ObjectState(initializing))
    return node

def _graph_node_clear(self):
Expand Down
33 changes: 32 additions & 1 deletion flax/nnx/tests/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest
from absl.testing import absltest

from flax import nnx, struct
from flax import nnx, struct, linen


class StatefulLinear(nnx.Module):
Expand Down Expand Up @@ -470,6 +470,37 @@ def test_getitem(self):
self.assertEqual(nodes['a'].count.value, 0)
self.assertEqual(nodes['b'].count.value, 1)

def test_object_state_propagation(self):
    """The initializing flag must survive a split/merge round-trip."""
    outer = self

    class Foo(nnx.Module):
        def __call__(self):
            # Called from lazy_init, so the flag starts out set.
            outer.assertTrue(self._object__state.initializing)
            # Round-trip through the graph API; the flag must propagate
            # into the freshly merged instance.
            graphdef, state = nnx.split(self)
            self = nnx.merge(graphdef, state)
            outer.assertTrue(self._object__state.initializing)

    nnx.bridge.lazy_init(Foo())

def test_object_state_propagation_nested(self):
    """Initializing state must propagate through nnx.vmap into a nested ToNNX module."""

    class NNXOuter(nnx.Module):
        def __init__(self, dout: int, rngs: nnx.Rngs):
            self.inner = nnx.bridge.ToNNX(linen.Dense(dout), rngs=rngs)
            self.rngs = rngs

        def __call__(self, x):
            # Vectorize over 5 independent copies of the inner module's state.
            @partial(nnx.vmap, in_axes=None, state_axes={...: 0}, axis_size=5)
            def vmap_fn(inner, x):
                return inner(x)

            return vmap_fn(self.inner, x)

    inputs = jax.random.normal(jax.random.key(0), (2, 4))
    outer = NNXOuter(3, rngs=nnx.Rngs(0))
    nnx.bridge.lazy_init(outer, inputs)

    # axis_size=5 stacks five copies of the Dense(4 -> 3) parameters.
    kernel_shape = outer.inner.params['kernel'].shape
    bias_shape = outer.inner.params['bias'].shape
    self.assertEqual(kernel_shape, (5, 4, 3))
    self.assertEqual(bias_shape, (5, 3))

class SimpleModule(nnx.Module):
pass
Expand Down

0 comments on commit 4e3c8c8

Please sign in to comment.