Skip to content

Commit

Permalink
[nnx] cleanup graph
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed May 13, 2024
1 parent e0fa96f commit bffe7e0
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 37 deletions.
92 changes: 58 additions & 34 deletions flax/experimental/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ class NodeDef(tp.Generic[Node], reprlib.Representable):
attributes: tuple[Key, ...]
subgraphs: _HashableMapping[Key, tp.Union['NodeDef[tp.Any]', Index]]
static_fields: _HashableMapping[Key, tp.Any]
variables: _HashableMapping[Key, Index]
leaves: _HashableMapping[Key, Index | None]
metadata: tp.Any

@classmethod
Expand All @@ -269,7 +269,7 @@ def create(
attributes: tuple[Key, ...],
subgraphs: tp.Iterable[tuple[Key, tp.Union['GraphDef[tp.Any]', Index]]],
static_fields: tp.Iterable[tuple[Key, tp.Any]],
variables: tp.Iterable[tuple[Key, Index]],
leaves: tp.Iterable[tuple[Key, Index | None]],
metadata: tp.Any,
):
return cls(
Expand All @@ -278,7 +278,7 @@ def create(
attributes=attributes,
subgraphs=_HashableMapping(subgraphs),
static_fields=_HashableMapping(static_fields),
variables=_HashableMapping(variables),
leaves=_HashableMapping(leaves),
metadata=metadata,
)

Expand All @@ -292,7 +292,7 @@ def __nnx_repr__(self):
yield reprlib.Attr(
'static_fields', reprlib.PrettyMapping(self.static_fields)
)
yield reprlib.Attr('variables', reprlib.PrettyMapping(self.variables))
yield reprlib.Attr('variables', reprlib.PrettyMapping(self.leaves))
yield reprlib.Attr('metadata', self.metadata)


Expand Down Expand Up @@ -396,9 +396,9 @@ def _graph_flatten(
else:
index = -1

subgraphs: list[tuple[Key, tp.Union[NodeDef[Node], int]]] = []
subgraphs: list[tuple[Key, tp.Union[NodeDef[Node], Index]]] = []
static_fields: list[tuple[Key, tp.Any]] = []
variables: list[tuple[Key, int]] = []
leaves: list[tuple[Key, Index | None]] = []

values, metadata = node_impl.flatten(node)
for key, value in values:
Expand All @@ -407,13 +407,14 @@ def _graph_flatten(
subgraphs.append((key, nodedef))
elif isinstance(value, Variable):
if value in refmap:
variables.append((key, refmap[value]))
leaves.append((key, refmap[value]))
else:
flat_state[(*path, key)] = value.to_state()
variable_index = refmap[value] = len(refmap)
variables.append((key, variable_index))
leaves.append((key, variable_index))
elif is_state_leaf(value):
flat_state[(*path, key)] = value
leaves.append((key, None))
else:
static_fields.append((key, value))

Expand All @@ -423,7 +424,7 @@ def _graph_flatten(
attributes=tuple(key for key, _ in values),
subgraphs=subgraphs,
static_fields=static_fields,
variables=variables,
leaves=leaves,
metadata=metadata,
)
return nodedef
Expand Down Expand Up @@ -487,41 +488,53 @@ def _graph_unflatten(
def _get_children():
children: dict[Key, NodeLeaf | Node] = {}

# NOTE: we could allow adding new StateLeafs here
if unkown_keys := set(state) - set(nodedef.attributes):
raise ValueError(f'Unknown keys: {unkown_keys}')

# for every key in attributes there are 6 possible cases:
# - (2) the key can either be present in the state or not
# - (3) the key can be a subgraph, a leaf, or a static attribute
for key in nodedef.attributes:
if key in nodedef.static_fields:
children[key] = nodedef.static_fields[key]
elif key not in state:
if key not in state:
# TODO(cgarcia): maybe we shouldn't support unflattening with missing keys?
# if key is not present create an empty types
if key in nodedef.subgraphs:
if key in nodedef.static_fields:
children[key] = nodedef.static_fields[key]
elif key in nodedef.subgraphs:
# if the key is a subgraph we create an empty node
subgraphdef = nodedef.subgraphs[key]
if isinstance(subgraphdef, int):
# subgraph exists, take it from the cache
children[key] = index_to_ref[subgraphdef]
else:
# create an empty node
# create a node from an empty state, reasoning:
# * it's a node with no state
# * it's a node with state but only through references of already
# created nodes
substate = {}
children[key] = _graph_unflatten(
subgraphdef, substate, index_to_ref, idxmap
)
elif key in nodedef.variables:
variable_index = nodedef.variables[key]
if variable_index in index_to_ref:
elif key in nodedef.leaves:
leaf_index = nodedef.leaves[key]
if leaf_index is not None and leaf_index in index_to_ref:
# variable exists, take it from the cache
children[key] = index_to_ref[variable_index]
children[key] = index_to_ref[leaf_index]
else:
# key for a variable is missing, raise an error
raise ValueError(
f'Expected key for Variable but was not found in state: {key!r}'
f'Expected key {key!r} in state while building node of type '
f'{nodedef.type.__name__}.'
)
else:
raise RuntimeError(f'Unknown static field: {key!r}')
else:
value = state[key]
if key in nodedef.static_fields:
raise ValueError(
f'Got state for static field {key!r}, this is not supported.'
)
if key in nodedef.subgraphs:
if is_state_leaf(value):
raise ValueError(
Expand All @@ -532,39 +545,50 @@ def _get_children():
subgraphdef = nodedef.subgraphs[key]

if isinstance(subgraphdef, int):
node = index_to_ref[subgraphdef]
children[key] = index_to_ref[subgraphdef]
else:
node = children[key] = _graph_unflatten(
children[key] = _graph_unflatten(
subgraphdef, value, index_to_ref, idxmap
)

elif key in nodedef.variables:
variable_index = nodedef.variables[key]
if variable_index in index_to_ref:
children[key] = index_to_ref[variable_index]
elif key in nodedef.leaves:
if not is_state_leaf(value):
raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')

leaf_index = nodedef.leaves[key]

if leaf_index is None:
# if the leaf index is None, it means that the value was originally
# a non-VariableState leaf; however, we allow providing a
# VariableState, presumably created by modifying the State
if isinstance(value, VariableState):
value = value.to_variable()
children[key] = value
elif leaf_index in index_to_ref:
# add an existing variable
children[key] = index_to_ref[leaf_index]
else:
# it's an unseen variable, create a new one
if not isinstance(value, VariableState):
raise ValueError(
f'Expected a Variable type for {key!r}, but got {type(value)}.'
)
if idxmap is not None and variable_index in idxmap:
variable = idxmap[variable_index]
# when idxmap is present, check if the Variable exists there
# and update existing variables if it does
if idxmap is not None and leaf_index in idxmap:
variable = idxmap[leaf_index]
if not isinstance(variable, Variable):
raise ValueError(
f'Expected a Variable type for {key!r}, but got {type(variable)}.'
)
variable.copy_from_state(value)
else:
else: # if it doesn't, create a new variable
assert isinstance(value, VariableState)
variable = value.to_variable()
children[key] = variable
index_to_ref[variable_index] = variable
elif is_state_leaf(value):
if isinstance(value, VariableState):
value = value.to_variable()
children[key] = value
index_to_ref[leaf_index] = variable
else:
raise RuntimeError
raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')

return children

Expand Down
4 changes: 1 addition & 3 deletions flax/experimental/nnx/tests/test_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def test_unflatten_empty(self):

graphdef, state = nnx.split(g)

with pytest.raises(
ValueError, match='Expected key for Variable but was not found in state'
):
with pytest.raises(ValueError, match='Expected key'):
nnx.graph.unflatten(graphdef, nnx.State({}))

def test_update_dynamic(self):
Expand Down

0 comments on commit bffe7e0

Please sign in to comment.