diff --git a/flax/experimental/nnx/nnx/graph.py b/flax/experimental/nnx/nnx/graph.py index 51e1ca8480..950f7193dd 100644 --- a/flax/experimental/nnx/nnx/graph.py +++ b/flax/experimental/nnx/nnx/graph.py @@ -258,7 +258,7 @@ class NodeDef(tp.Generic[Node], reprlib.Representable): attributes: tuple[Key, ...] subgraphs: _HashableMapping[Key, tp.Union['NodeDef[tp.Any]', Index]] static_fields: _HashableMapping[Key, tp.Any] - variables: _HashableMapping[Key, Index] + leaves: _HashableMapping[Key, Index | None] metadata: tp.Any @classmethod @@ -269,7 +269,7 @@ def create( attributes: tuple[Key, ...], subgraphs: tp.Iterable[tuple[Key, tp.Union['NodeDef[tp.Any]', Index]]], static_fields: tp.Iterable[tuple[Key, tp.Any]], - variables: tp.Iterable[tuple[Key, Index]], + leaves: tp.Iterable[tuple[Key, Index | None]], metadata: tp.Any, ): return cls( @@ -278,7 +278,7 @@ def create( attributes=attributes, subgraphs=_HashableMapping(subgraphs), static_fields=_HashableMapping(static_fields), - variables=_HashableMapping(variables), + leaves=_HashableMapping(leaves), metadata=metadata, ) @@ -292,7 +292,7 @@ def __nnx_repr__(self): yield reprlib.Attr( 'static_fields', reprlib.PrettyMapping(self.static_fields) ) - yield reprlib.Attr('variables', reprlib.PrettyMapping(self.variables)) + yield reprlib.Attr('variables', reprlib.PrettyMapping(self.leaves)) yield reprlib.Attr('metadata', self.metadata) @@ -396,9 +396,9 @@ def _graph_flatten( else: index = -1 - subgraphs: list[tuple[Key, tp.Union[NodeDef[Node], int]]] = [] + subgraphs: list[tuple[Key, tp.Union[NodeDef[Node], Index]]] = [] static_fields: list[tuple[Key, tp.Any]] = [] - variables: list[tuple[Key, int]] = [] + leaves: list[tuple[Key, Index | None]] = [] values, metadata = node_impl.flatten(node) for key, value in values: @@ -407,13 +407,14 @@ def _graph_flatten( subgraphs.append((key, nodedef)) elif isinstance(value, Variable): if value in refmap: - variables.append((key, refmap[value])) + leaves.append((key, refmap[value])) else: flat_state[(*path, key)] = value.to_state() variable_index = refmap[value] = len(refmap) - variables.append((key, variable_index)) + leaves.append((key, variable_index)) elif is_state_leaf(value): flat_state[(*path, key)] = value + leaves.append((key, None)) else: static_fields.append((key, value)) @@ -423,7 +424,7 @@ def _graph_flatten( attributes=tuple(key for key, _ in values), subgraphs=subgraphs, static_fields=static_fields, - variables=variables, + leaves=leaves, metadata=metadata, ) return nodedef @@ -487,41 +488,53 @@ def _graph_unflatten( def _get_children(): children: dict[Key, NodeLeaf | Node] = {} + # NOTE: we could allw adding new StateLeafs here if unkown_keys := set(state) - set(nodedef.attributes): raise ValueError(f'Unknown keys: {unkown_keys}') + # for every key in attributes there are 6 possible cases: + # - (2) the key can either be present in the state or not + # - (3) the key can be a subgraph, a leaf, or a static attribute for key in nodedef.attributes: - if key in nodedef.static_fields: - children[key] = nodedef.static_fields[key] - elif key not in state: + if key not in state: # TODO(cgarcia): maybe we shouldn't support unflattening with missing keys? # if key is not present create an empty types - if key in nodedef.subgraphs: + if key in nodedef.static_fields: + children[key] = nodedef.static_fields[key] + elif key in nodedef.subgraphs: # if the key is a subgraph we create an empty node subgraphdef = nodedef.subgraphs[key] if isinstance(subgraphdef, int): # subgraph exists, take it from the cache children[key] = index_to_ref[subgraphdef] else: - # create an empty node + # create a node from an empty state, reasoning: + # * its a node with no state + # * its a node with state but only through references of already + # created nodes substate = {} children[key] = _graph_unflatten( subgraphdef, substate, index_to_ref, idxmap ) - elif key in nodedef.variables: - variable_index = nodedef.variables[key] - if variable_index in index_to_ref: + elif key in nodedef.leaves: + leaf_index = nodedef.leaves[key] + if leaf_index is not None and leaf_index in index_to_ref: # variable exists, take it from the cache - children[key] = index_to_ref[variable_index] + children[key] = index_to_ref[leaf_index] else: # key for a variable is missing, raise an error raise ValueError( - f'Expected key for Variable but was not found in state: {key!r}' + f'Expected key {key!r} in state while building node of type ' + f'{nodedef.type.__name__}.' ) else: raise RuntimeError(f'Unknown static field: {key!r}') else: value = state[key] + if key in nodedef.static_fields: + raise ValueError( + f'Got state for static field {key!r}, this is not supported.' + ) if key in nodedef.subgraphs: if is_state_leaf(value): raise ValueError( @@ -532,39 +545,50 @@ def _get_children(): subgraphdef = nodedef.subgraphs[key] if isinstance(subgraphdef, int): - node = index_to_ref[subgraphdef] + children[key] = index_to_ref[subgraphdef] else: - node = children[key] = _graph_unflatten( + children[key] = _graph_unflatten( subgraphdef, value, index_to_ref, idxmap ) - elif key in nodedef.variables: - variable_index = nodedef.variables[key] - if variable_index in index_to_ref: - children[key] = index_to_ref[variable_index] + elif key in nodedef.leaves: + if not is_state_leaf(value): + raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}') + + leaf_index = nodedef.leaves[key] + + if leaf_index is None: + # if the leaf is None, it means that the value was originally + # a non-VariableState leaf, however we allow providing a + # VariableState presumbly created by modifying the State + if isinstance(value, VariableState): + value = value.to_variable() + children[key] = value + elif leaf_index in index_to_ref: + # add an existing variable + children[key] = index_to_ref[leaf_index] else: + # its a unseen variable, create a new one if not isinstance(value, VariableState): raise ValueError( f'Expected a Variable type for {key!r}, but got {type(value)}.' ) - if idxmap is not None and variable_index in idxmap: - variable = idxmap[variable_index] + # when idxmap is present, check if the Varable exists there + # and update existing variables if it does + if idxmap is not None and leaf_index in idxmap: + variable = idxmap[leaf_index] if not isinstance(variable, Variable): raise ValueError( f'Expected a Variable type for {key!r}, but got {type(variable)}.' ) variable.copy_from_state(value) - else: + else: # if it doesn't, create a new variable assert isinstance(value, VariableState) variable = value.to_variable() children[key] = variable - index_to_ref[variable_index] = variable - elif is_state_leaf(value): - if isinstance(value, VariableState): - value = value.to_variable() - children[key] = value + index_to_ref[leaf_index] = variable else: - raise RuntimeError + raise RuntimeError(f'Unknown key: {key!r}, this is a bug.') return children diff --git a/flax/experimental/nnx/tests/test_graph_utils.py b/flax/experimental/nnx/tests/test_graph_utils.py index 0c25222f41..64a07a1938 100644 --- a/flax/experimental/nnx/tests/test_graph_utils.py +++ b/flax/experimental/nnx/tests/test_graph_utils.py @@ -60,9 +60,7 @@ def test_unflatten_empty(self): graphdef, state = nnx.split(g) - with pytest.raises( - ValueError, match='Expected key for Variable but was not found in state' - ): + with pytest.raises(ValueError, match='Expected key'): nnx.graph.unflatten(graphdef, nnx.State({})) def test_update_dynamic(self):