diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb index 30c2dcb499..351ae8b6e2 100644 --- a/docs_nnx/nnx_basics.ipynb +++ b/docs_nnx/nnx_basics.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "tags": [ "skip-execution" @@ -88,7 +88,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -180,7 +192,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -235,7 +259,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -359,7 +395,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -418,7 +466,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -467,7 +527,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -479,7 +551,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" @@ -580,7 +652,31 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -592,7 +688,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index fec21add20..2339f5c168 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -290,6 +290,24 @@ def __treescope_repr__(self, path, subtree_renderer): jax.tree_util.register_static(VariableDef) +@dataclasses.dataclass(frozen=True, slots=True) +class SubGraphAttribute: + key: Key + value: NodeDef[tp.Any] | NodeRef[tp.Any] + + +@dataclasses.dataclass(frozen=True, slots=True) +class StaticAttribute: + key: Key + value: tp.Any + + +@dataclasses.dataclass(frozen=True, slots=True) +class LeafAttribute: + key: Key + value: VariableDef | NodeRef[tp.Any] + + @dataclasses.dataclass(frozen=True, repr=False, slots=True) class NodeDef(GraphDef[Node], reprlib.Representable): """A dataclass that denotes the tree structure of a @@ -298,10 +316,7 @@ class NodeDef(GraphDef[Node], reprlib.Representable): type: tp.Type[Node] index: int - attributes: tuple[Key, ...] - subgraphs: HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]] - static_fields: HashableMapping[Key, tp.Any] - leaves: HashableMapping[Key, VariableDef | NodeRef[tp.Any]] + attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...] 
metadata: tp.Any index_mapping: HashableMapping[Index, Index] | None @@ -310,10 +325,7 @@ def create( cls, type: tp.Type[Node], index: int, - attributes: tuple[Key, ...], - subgraphs: tp.Iterable[tuple[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]], - static_fields: tp.Iterable[tuple[Key, tp.Any]], - leaves: tp.Iterable[tuple[Key, VariableDef | NodeRef[tp.Any]]], + attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...], metadata: tp.Any, index_mapping: tp.Mapping[Index, Index] | None, ): @@ -321,9 +333,6 @@ def create( type=type, index=index, attributes=attributes, - subgraphs=HashableMapping(subgraphs), - static_fields=HashableMapping(static_fields), - leaves=HashableMapping(leaves), metadata=metadata, index_mapping=HashableMapping(index_mapping) if index_mapping is not None @@ -335,12 +344,7 @@ def __nnx_repr__(self): yield reprlib.Attr('type', self.type.__name__) yield reprlib.Attr('index', self.index) - yield reprlib.Attr('attributes', self.attributes) - yield reprlib.Attr('subgraphs', reprlib.PrettyMapping(self.subgraphs)) - yield reprlib.Attr( - 'static_fields', reprlib.PrettyMapping(self.static_fields) - ) - yield reprlib.Attr('leaves', reprlib.PrettyMapping(self.leaves)) + yield reprlib.Attr('attributes', reprlib.PrettySequence(self.attributes)) yield reprlib.Attr('metadata', self.metadata) yield reprlib.Attr( 'index_mapping', @@ -352,18 +356,15 @@ def __nnx_repr__(self): def __treescope_repr__(self, path, subtree_renderer): import treescope # type: ignore[import-not-found,import-untyped] return treescope.repr_lib.render_object_constructor( - object_type=type(self), - attributes={ - 'type': self.type, - 'index': self.index, - 'attributes': self.attributes, - 'subgraphs': dict(self.subgraphs), - 'static_fields': dict(self.static_fields), - 'leaves': dict(self.leaves), - 'metadata': self.metadata, - }, - path=path, - subtree_renderer=subtree_renderer, + object_type=type(self), + attributes={ + 'type': self.type, + 'index': self.index, + 
'attributes': self.attributes, + 'metadata': self.metadata, + }, + path=path, + subtree_renderer=subtree_renderer, ) def apply( @@ -426,40 +427,39 @@ def _graph_flatten( else: index = -1 - subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = [] - static_fields: list[tuple[Key, tp.Any]] = [] - leaves: list[tuple[Key, VariableDef | NodeRef]] = [] + attributes: list[SubGraphAttribute | StaticAttribute | LeafAttribute] = [] values, metadata = node_impl.flatten(node) for key, value in values: if is_node(value): nodedef = _graph_flatten((*path, key), ref_index, flat_state, value) - subgraphs.append((key, nodedef)) + # subgraphs.append((key, nodedef)) + attributes.append(SubGraphAttribute(key, nodedef)) elif isinstance(value, Variable): if value in ref_index: - leaves.append((key, NodeRef(type(value), ref_index[value]))) + attributes.append( + LeafAttribute(key, NodeRef(type(value), ref_index[value])) + ) else: flat_state[(*path, key)] = value.to_state() variable_index = ref_index[value] = len(ref_index) variabledef = VariableDef( type(value), variable_index, HashableMapping(value.get_metadata()) ) - leaves.append((key, variabledef)) + attributes.append(LeafAttribute(key, variabledef)) else: if isinstance(value, (jax.Array, np.ndarray)): path_str = '/'.join(map(str, (*path, key))) raise ValueError( f'Arrays leaves are not supported, at {path_str!r}: {value}' ) - static_fields.append((key, value)) + # static_fields.append((key, value)) + attributes.append(StaticAttribute(key, value)) nodedef = NodeDef.create( type=node_impl.type, index=index, - attributes=tuple(key for key, _ in values), - subgraphs=subgraphs, - static_fields=static_fields, - leaves=leaves, + attributes=tuple(attributes), metadata=metadata, index_mapping=None, ) @@ -529,22 +529,20 @@ def _graph_unflatten( def _get_children(): children: dict[Key, NodeLeaf | Node] = {} - - # NOTE: we could allw adding new StateLeafs here - if unkown_keys := set(state) - set(nodedef.attributes): - raise ValueError(f'Unknown 
keys: {unkown_keys}') + state_keys: set = set(state.keys()) # for every key in attributes there are 6 possible cases: # - (2) the key can either be present in the state or not # - (3) the key can be a subgraph, a leaf, or a static attribute - for key in nodedef.attributes: + for attribute in nodedef.attributes: + key = attribute.key if key not in state: # if key is not present create an empty types - if key in nodedef.static_fields: - children[key] = nodedef.static_fields[key] - elif key in nodedef.subgraphs: + if type(attribute) is StaticAttribute: + children[key] = attribute.value + elif type(attribute) is SubGraphAttribute: # if the key is a subgraph we create an empty node - subgraphdef = nodedef.subgraphs[key] + subgraphdef = attribute.value assert not isinstance(subgraphdef, VariableDef) if isinstance(subgraphdef, NodeRef): # subgraph exists, take it from the cache @@ -558,8 +556,8 @@ def _get_children(): children[key] = _graph_unflatten( subgraphdef, substate, index_ref, index_ref_cache ) - elif key in nodedef.leaves: - variabledef = nodedef.leaves[key] + elif type(attribute) is LeafAttribute: + variabledef = attribute.value if variabledef.index in index_ref: # variable exists, take it from the cache children[key] = index_ref[variabledef.index] @@ -572,19 +570,21 @@ def _get_children(): else: raise RuntimeError(f'Unknown static field: {key!r}') else: + state_keys.remove(key) value = state[key] - if key in nodedef.static_fields: + # if key in nodedef.static_fields: + if type(attribute) is StaticAttribute: raise ValueError( f'Got state for static field {key!r}, this is not supported.' 
) - if key in nodedef.subgraphs: + elif type(attribute) is SubGraphAttribute: if is_state_leaf(value): raise ValueError( - f'Expected value of type {nodedef.subgraphs[key]} for ' + f'Expected value of type {attribute.value} for ' f'{key!r}, but got {value!r}' ) assert isinstance(value, dict) - subgraphdef = nodedef.subgraphs[key] + subgraphdef = attribute.value if isinstance(subgraphdef, NodeRef): children[key] = index_ref[subgraphdef.index] @@ -593,8 +593,8 @@ def _get_children(): subgraphdef, value, index_ref, index_ref_cache ) - elif key in nodedef.leaves: - variabledef = nodedef.leaves[key] + elif type(attribute) is LeafAttribute: + variabledef = attribute.value if variabledef.index in index_ref: # add an existing variable @@ -631,6 +631,10 @@ def _get_children(): else: raise RuntimeError(f'Unknown key: {key!r}, this is a bug.') + # NOTE: we could allow adding new StateLeafs here + if state_keys: + raise ValueError(f'Unknown keys: {state_keys}') + return children if isinstance(node_impl, GraphNodeImpl): diff --git a/flax/nnx/reprlib.py b/flax/nnx/reprlib.py index 855a3049b2..eaa58a051c 100644 --- a/flax/nnx/reprlib.py +++ b/flax/nnx/reprlib.py @@ -121,4 +121,14 @@ def __nnx_repr__(self): yield Object(type='', value_sep=': ', start='{', end='}') for key, value in self.mapping.items(): - yield Attr(repr(key), value) \ No newline at end of file + yield Attr(repr(key), value) + +@dataclasses.dataclass(repr=False) +class PrettySequence(Representable): + list: tp.Sequence + + def __nnx_repr__(self): + yield Object(type='', value_sep='', start='[', end=']') + + for value in self.list: + yield Attr('', value) \ No newline at end of file diff --git a/flax/nnx/scripts/run-all-examples.bash b/flax/nnx/scripts/run-all-examples.bash index ab896ebd6a..9fcfec0215 100644 --- a/flax/nnx/scripts/run-all-examples.bash +++ b/flax/nnx/scripts/run-all-examples.bash @@ -1,9 +1,8 @@ set -e source .venv/bin/activate -cd flax/nnx -for f in $(find examples/toy_examples -name "*.py" 
-maxdepth 1); do +for f in $(find examples/nnx_toy_examples -name "*.py" -maxdepth 1); do echo -e "\n---------------------------------" echo "$f" echo "---------------------------------" diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index c9a3c1c4be..994e582862 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -1341,11 +1341,16 @@ def per_node_def(nd: graph.NodeDef | graph.NodeRef): global_index_mapping[nd.index] = nd.index if isinstance(nd, graph.NodeRef): return - for sub_nd in nd.subgraphs.values(): - per_node_def(sub_nd) - for l in nd.leaves.values(): - if isinstance(l, (graph.VariableDef, graph.NodeRef)) and l.index >= 0: - global_index_mapping[l.index] = l.index + + for attribute in nd.attributes: + if type(attribute) is graph.SubGraphAttribute: + per_node_def(attribute.value) + elif ( + type(attribute) is graph.LeafAttribute + and isinstance(attribute.value, (graph.VariableDef, graph.NodeRef)) + and attribute.value.index >= 0 + ): + global_index_mapping[attribute.value.index] = attribute.value.index return per_node_def(ns._graphdef) diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index fb0496e07a..a7bbf178cb 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -303,7 +303,7 @@ def __init__(self): assert 'tree' in state assert 'a' in state.tree - assert graphdef.subgraphs['tree'].type is nnx.graph.GenericPytree + assert graphdef.attributes[0].value.type is nnx.graph.GenericPytree m2 = nnx.merge(graphdef, state)