diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb
index 30c2dcb499..351ae8b6e2 100644
--- a/docs_nnx/nnx_basics.ipynb
+++ b/docs_nnx/nnx_basics.ipynb
@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {
"tags": [
"skip-execution"
@@ -88,7 +88,19 @@
{
"data": {
"text/html": [
- "
(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -180,7 +192,19 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -235,7 +259,19 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -359,7 +395,19 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -418,7 +466,19 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -467,7 +527,19 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -479,7 +551,7 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
],
"text/plain": [
""
@@ -580,7 +652,31 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -592,7 +688,7 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
],
"text/plain": [
""
diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py
index fec21add20..2339f5c168 100644
--- a/flax/nnx/graph.py
+++ b/flax/nnx/graph.py
@@ -290,6 +290,24 @@ def __treescope_repr__(self, path, subtree_renderer):
jax.tree_util.register_static(VariableDef)
+@dataclasses.dataclass(frozen=True, slots=True)
+class SubGraphAttribute:
+ key: Key
+ value: NodeDef[tp.Any] | NodeRef[tp.Any]
+
+
+@dataclasses.dataclass(frozen=True, slots=True)
+class StaticAttribute:
+ key: Key
+ value: tp.Any
+
+
+@dataclasses.dataclass(frozen=True, slots=True)
+class LeafAttribute:
+ key: Key
+ value: VariableDef | NodeRef[tp.Any]
+
+
@dataclasses.dataclass(frozen=True, repr=False, slots=True)
class NodeDef(GraphDef[Node], reprlib.Representable):
"""A dataclass that denotes the tree structure of a
@@ -298,10 +316,7 @@ class NodeDef(GraphDef[Node], reprlib.Representable):
type: tp.Type[Node]
index: int
- attributes: tuple[Key, ...]
- subgraphs: HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]
- static_fields: HashableMapping[Key, tp.Any]
- leaves: HashableMapping[Key, VariableDef | NodeRef[tp.Any]]
+ attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...]
metadata: tp.Any
index_mapping: HashableMapping[Index, Index] | None
@@ -310,10 +325,7 @@ def create(
cls,
type: tp.Type[Node],
index: int,
- attributes: tuple[Key, ...],
- subgraphs: tp.Iterable[tuple[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]],
- static_fields: tp.Iterable[tuple[Key, tp.Any]],
- leaves: tp.Iterable[tuple[Key, VariableDef | NodeRef[tp.Any]]],
+ attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...],
metadata: tp.Any,
index_mapping: tp.Mapping[Index, Index] | None,
):
@@ -321,9 +333,6 @@ def create(
type=type,
index=index,
attributes=attributes,
- subgraphs=HashableMapping(subgraphs),
- static_fields=HashableMapping(static_fields),
- leaves=HashableMapping(leaves),
metadata=metadata,
index_mapping=HashableMapping(index_mapping)
if index_mapping is not None
@@ -335,12 +344,7 @@ def __nnx_repr__(self):
yield reprlib.Attr('type', self.type.__name__)
yield reprlib.Attr('index', self.index)
- yield reprlib.Attr('attributes', self.attributes)
- yield reprlib.Attr('subgraphs', reprlib.PrettyMapping(self.subgraphs))
- yield reprlib.Attr(
- 'static_fields', reprlib.PrettyMapping(self.static_fields)
- )
- yield reprlib.Attr('leaves', reprlib.PrettyMapping(self.leaves))
+ yield reprlib.Attr('attributes', reprlib.PrettySequence(self.attributes))
yield reprlib.Attr('metadata', self.metadata)
yield reprlib.Attr(
'index_mapping',
@@ -352,18 +356,15 @@ def __nnx_repr__(self):
def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
return treescope.repr_lib.render_object_constructor(
- object_type=type(self),
- attributes={
- 'type': self.type,
- 'index': self.index,
- 'attributes': self.attributes,
- 'subgraphs': dict(self.subgraphs),
- 'static_fields': dict(self.static_fields),
- 'leaves': dict(self.leaves),
- 'metadata': self.metadata,
- },
- path=path,
- subtree_renderer=subtree_renderer,
+ object_type=type(self),
+ attributes={
+ 'type': self.type,
+ 'index': self.index,
+ 'attributes': self.attributes,
+ 'metadata': self.metadata,
+ },
+ path=path,
+ subtree_renderer=subtree_renderer,
)
def apply(
@@ -426,40 +427,39 @@ def _graph_flatten(
else:
index = -1
- subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = []
- static_fields: list[tuple[Key, tp.Any]] = []
- leaves: list[tuple[Key, VariableDef | NodeRef]] = []
+ attributes: list[SubGraphAttribute | StaticAttribute | LeafAttribute] = []
values, metadata = node_impl.flatten(node)
for key, value in values:
if is_node(value):
nodedef = _graph_flatten((*path, key), ref_index, flat_state, value)
- subgraphs.append((key, nodedef))
+ # subgraphs.append((key, nodedef))
+ attributes.append(SubGraphAttribute(key, nodedef))
elif isinstance(value, Variable):
if value in ref_index:
- leaves.append((key, NodeRef(type(value), ref_index[value])))
+ attributes.append(
+ LeafAttribute(key, NodeRef(type(value), ref_index[value]))
+ )
else:
flat_state[(*path, key)] = value.to_state()
variable_index = ref_index[value] = len(ref_index)
variabledef = VariableDef(
type(value), variable_index, HashableMapping(value.get_metadata())
)
- leaves.append((key, variabledef))
+ attributes.append(LeafAttribute(key, variabledef))
else:
if isinstance(value, (jax.Array, np.ndarray)):
path_str = '/'.join(map(str, (*path, key)))
raise ValueError(
f'Arrays leaves are not supported, at {path_str!r}: {value}'
)
- static_fields.append((key, value))
+ # static_fields.append((key, value))
+ attributes.append(StaticAttribute(key, value))
nodedef = NodeDef.create(
type=node_impl.type,
index=index,
- attributes=tuple(key for key, _ in values),
- subgraphs=subgraphs,
- static_fields=static_fields,
- leaves=leaves,
+ attributes=tuple(attributes),
metadata=metadata,
index_mapping=None,
)
@@ -529,22 +529,20 @@ def _graph_unflatten(
def _get_children():
children: dict[Key, NodeLeaf | Node] = {}
-
- # NOTE: we could allw adding new StateLeafs here
- if unkown_keys := set(state) - set(nodedef.attributes):
- raise ValueError(f'Unknown keys: {unkown_keys}')
+ state_keys: set = set(state.keys())
# for every key in attributes there are 6 possible cases:
# - (2) the key can either be present in the state or not
# - (3) the key can be a subgraph, a leaf, or a static attribute
- for key in nodedef.attributes:
+ for attribute in nodedef.attributes:
+ key = attribute.key
if key not in state:
# if key is not present create an empty types
- if key in nodedef.static_fields:
- children[key] = nodedef.static_fields[key]
- elif key in nodedef.subgraphs:
+ if type(attribute) is StaticAttribute:
+ children[key] = attribute.value
+ elif type(attribute) is SubGraphAttribute:
# if the key is a subgraph we create an empty node
- subgraphdef = nodedef.subgraphs[key]
+ subgraphdef = attribute.value
assert not isinstance(subgraphdef, VariableDef)
if isinstance(subgraphdef, NodeRef):
# subgraph exists, take it from the cache
@@ -558,8 +556,8 @@ def _get_children():
children[key] = _graph_unflatten(
subgraphdef, substate, index_ref, index_ref_cache
)
- elif key in nodedef.leaves:
- variabledef = nodedef.leaves[key]
+ elif type(attribute) is LeafAttribute:
+ variabledef = attribute.value
if variabledef.index in index_ref:
# variable exists, take it from the cache
children[key] = index_ref[variabledef.index]
@@ -572,19 +570,21 @@ def _get_children():
else:
raise RuntimeError(f'Unknown static field: {key!r}')
else:
+ state_keys.remove(key)
value = state[key]
- if key in nodedef.static_fields:
+ # if key in nodedef.static_fields:
+ if type(attribute) is StaticAttribute:
raise ValueError(
f'Got state for static field {key!r}, this is not supported.'
)
- if key in nodedef.subgraphs:
+ elif type(attribute) is SubGraphAttribute:
if is_state_leaf(value):
raise ValueError(
- f'Expected value of type {nodedef.subgraphs[key]} for '
+ f'Expected value of type {attribute.value} for '
f'{key!r}, but got {value!r}'
)
assert isinstance(value, dict)
- subgraphdef = nodedef.subgraphs[key]
+ subgraphdef = attribute.value
if isinstance(subgraphdef, NodeRef):
children[key] = index_ref[subgraphdef.index]
@@ -593,8 +593,8 @@ def _get_children():
subgraphdef, value, index_ref, index_ref_cache
)
- elif key in nodedef.leaves:
- variabledef = nodedef.leaves[key]
+ elif type(attribute) is LeafAttribute:
+ variabledef = attribute.value
if variabledef.index in index_ref:
# add an existing variable
@@ -631,6 +631,10 @@ def _get_children():
else:
raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')
+ # NOTE: we could allw adding new StateLeafs here
+ if state_keys:
+ raise ValueError(f'Unknown keys: {state_keys}')
+
return children
if isinstance(node_impl, GraphNodeImpl):
diff --git a/flax/nnx/reprlib.py b/flax/nnx/reprlib.py
index 855a3049b2..eaa58a051c 100644
--- a/flax/nnx/reprlib.py
+++ b/flax/nnx/reprlib.py
@@ -121,4 +121,14 @@ def __nnx_repr__(self):
yield Object(type='', value_sep=': ', start='{', end='}')
for key, value in self.mapping.items():
- yield Attr(repr(key), value)
\ No newline at end of file
+ yield Attr(repr(key), value)
+
+@dataclasses.dataclass(repr=False)
+class PrettySequence(Representable):
+ list: tp.Sequence
+
+ def __nnx_repr__(self):
+ yield Object(type='', value_sep='', start='[', end=']')
+
+ for value in self.list:
+ yield Attr('', value)
\ No newline at end of file
diff --git a/flax/nnx/scripts/run-all-examples.bash b/flax/nnx/scripts/run-all-examples.bash
index ab896ebd6a..9fcfec0215 100644
--- a/flax/nnx/scripts/run-all-examples.bash
+++ b/flax/nnx/scripts/run-all-examples.bash
@@ -1,9 +1,8 @@
set -e
source .venv/bin/activate
-cd flax/nnx
-for f in $(find examples/toy_examples -name "*.py" -maxdepth 1); do
+for f in $(find examples/nnx_toy_examples -name "*.py" -maxdepth 1); do
echo -e "\n---------------------------------"
echo "$f"
echo "---------------------------------"
diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py
index c9a3c1c4be..994e582862 100644
--- a/flax/nnx/transforms/iteration.py
+++ b/flax/nnx/transforms/iteration.py
@@ -1341,11 +1341,16 @@ def per_node_def(nd: graph.NodeDef | graph.NodeRef):
global_index_mapping[nd.index] = nd.index
if isinstance(nd, graph.NodeRef):
return
- for sub_nd in nd.subgraphs.values():
- per_node_def(sub_nd)
- for l in nd.leaves.values():
- if isinstance(l, (graph.VariableDef, graph.NodeRef)) and l.index >= 0:
- global_index_mapping[l.index] = l.index
+
+ for attribute in nd.attributes:
+ if type(attribute) is graph.SubGraphAttribute:
+ per_node_def(attribute.value)
+ elif (
+ type(attribute) is graph.LeafAttribute
+ and isinstance(attribute.value, (graph.VariableDef, graph.NodeRef))
+ and attribute.value.index >= 0
+ ):
+ global_index_mapping[attribute.value.index] = attribute.value.index
return
per_node_def(ns._graphdef)
diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py
index fb0496e07a..a7bbf178cb 100644
--- a/tests/nnx/graph_utils_test.py
+++ b/tests/nnx/graph_utils_test.py
@@ -303,7 +303,7 @@ def __init__(self):
assert 'tree' in state
assert 'a' in state.tree
- assert graphdef.subgraphs['tree'].type is nnx.graph.GenericPytree
+ assert graphdef.attributes[0].value.type is nnx.graph.GenericPytree
m2 = nnx.merge(graphdef, state)