Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] cleanup graph #3915

Merged
merged 1 commit into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 58 additions & 34 deletions flax/experimental/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ class NodeDef(tp.Generic[Node], reprlib.Representable):
attributes: tuple[Key, ...]
subgraphs: _HashableMapping[Key, tp.Union['NodeDef[tp.Any]', Index]]
static_fields: _HashableMapping[Key, tp.Any]
variables: _HashableMapping[Key, Index]
leaves: _HashableMapping[Key, Index | None]
metadata: tp.Any

@classmethod
Expand All @@ -269,7 +269,7 @@ def create(
attributes: tuple[Key, ...],
subgraphs: tp.Iterable[tuple[Key, tp.Union['NodeDef[tp.Any]', Index]]],
static_fields: tp.Iterable[tuple[Key, tp.Any]],
variables: tp.Iterable[tuple[Key, Index]],
leaves: tp.Iterable[tuple[Key, Index | None]],
metadata: tp.Any,
):
return cls(
Expand All @@ -278,7 +278,7 @@ def create(
attributes=attributes,
subgraphs=_HashableMapping(subgraphs),
static_fields=_HashableMapping(static_fields),
variables=_HashableMapping(variables),
leaves=_HashableMapping(leaves),
metadata=metadata,
)

Expand All @@ -292,7 +292,7 @@ def __nnx_repr__(self):
yield reprlib.Attr(
'static_fields', reprlib.PrettyMapping(self.static_fields)
)
yield reprlib.Attr('variables', reprlib.PrettyMapping(self.variables))
yield reprlib.Attr('variables', reprlib.PrettyMapping(self.leaves))
yield reprlib.Attr('metadata', self.metadata)


Expand Down Expand Up @@ -396,9 +396,9 @@ def _graph_flatten(
else:
index = -1

subgraphs: list[tuple[Key, tp.Union[NodeDef[Node], int]]] = []
subgraphs: list[tuple[Key, tp.Union[NodeDef[Node], Index]]] = []
static_fields: list[tuple[Key, tp.Any]] = []
variables: list[tuple[Key, int]] = []
leaves: list[tuple[Key, Index | None]] = []

values, metadata = node_impl.flatten(node)
for key, value in values:
Expand All @@ -407,13 +407,14 @@ def _graph_flatten(
subgraphs.append((key, nodedef))
elif isinstance(value, Variable):
if value in refmap:
variables.append((key, refmap[value]))
leaves.append((key, refmap[value]))
else:
flat_state[(*path, key)] = value.to_state()
variable_index = refmap[value] = len(refmap)
variables.append((key, variable_index))
leaves.append((key, variable_index))
elif is_state_leaf(value):
flat_state[(*path, key)] = value
leaves.append((key, None))
else:
static_fields.append((key, value))

Expand All @@ -423,7 +424,7 @@ def _graph_flatten(
attributes=tuple(key for key, _ in values),
subgraphs=subgraphs,
static_fields=static_fields,
variables=variables,
leaves=leaves,
metadata=metadata,
)
return nodedef
Expand Down Expand Up @@ -487,41 +488,53 @@ def _graph_unflatten(
def _get_children():
children: dict[Key, NodeLeaf | Node] = {}

# NOTE: we could allow adding new StateLeafs here
if unkown_keys := set(state) - set(nodedef.attributes):
raise ValueError(f'Unknown keys: {unkown_keys}')

# for every key in attributes there are 6 possible cases:
# - (2) the key can either be present in the state or not
# - (3) the key can be a subgraph, a leaf, or a static attribute
for key in nodedef.attributes:
if key in nodedef.static_fields:
children[key] = nodedef.static_fields[key]
elif key not in state:
if key not in state:
# TODO(cgarcia): maybe we shouldn't support unflattening with missing keys?
# if key is not present create an empty type
if key in nodedef.subgraphs:
if key in nodedef.static_fields:
children[key] = nodedef.static_fields[key]
elif key in nodedef.subgraphs:
# if the key is a subgraph we create an empty node
subgraphdef = nodedef.subgraphs[key]
if isinstance(subgraphdef, int):
# subgraph exists, take it from the cache
children[key] = index_to_ref[subgraphdef]
else:
# create an empty node
# create a node from an empty state, reasoning:
# * it's a node with no state
# * it's a node with state but only through references of already
# created nodes
substate = {}
children[key] = _graph_unflatten(
subgraphdef, substate, index_to_ref, idxmap
)
elif key in nodedef.variables:
variable_index = nodedef.variables[key]
if variable_index in index_to_ref:
elif key in nodedef.leaves:
leaf_index = nodedef.leaves[key]
if leaf_index is not None and leaf_index in index_to_ref:
# variable exists, take it from the cache
children[key] = index_to_ref[variable_index]
children[key] = index_to_ref[leaf_index]
else:
# key for a variable is missing, raise an error
raise ValueError(
f'Expected key for Variable but was not found in state: {key!r}'
f'Expected key {key!r} in state while building node of type '
f'{nodedef.type.__name__}.'
)
else:
raise RuntimeError(f'Unknown static field: {key!r}')
else:
value = state[key]
if key in nodedef.static_fields:
raise ValueError(
f'Got state for static field {key!r}, this is not supported.'
)
if key in nodedef.subgraphs:
if is_state_leaf(value):
raise ValueError(
Expand All @@ -532,39 +545,50 @@ def _get_children():
subgraphdef = nodedef.subgraphs[key]

if isinstance(subgraphdef, int):
node = index_to_ref[subgraphdef]
children[key] = index_to_ref[subgraphdef]
else:
node = children[key] = _graph_unflatten(
children[key] = _graph_unflatten(
subgraphdef, value, index_to_ref, idxmap
)

elif key in nodedef.variables:
variable_index = nodedef.variables[key]
if variable_index in index_to_ref:
children[key] = index_to_ref[variable_index]
elif key in nodedef.leaves:
if not is_state_leaf(value):
raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')

leaf_index = nodedef.leaves[key]

if leaf_index is None:
# if the leaf is None, it means that the value was originally
# a non-VariableState leaf, however we allow providing a
# VariableState presumably created by modifying the State
if isinstance(value, VariableState):
value = value.to_variable()
children[key] = value
elif leaf_index in index_to_ref:
# add an existing variable
children[key] = index_to_ref[leaf_index]
else:
# it's an unseen variable, create a new one
if not isinstance(value, VariableState):
raise ValueError(
f'Expected a Variable type for {key!r}, but got {type(value)}.'
)
if idxmap is not None and variable_index in idxmap:
variable = idxmap[variable_index]
# when idxmap is present, check if the Variable exists there
# and update existing variables if it does
if idxmap is not None and leaf_index in idxmap:
variable = idxmap[leaf_index]
if not isinstance(variable, Variable):
raise ValueError(
f'Expected a Variable type for {key!r}, but got {type(variable)}.'
)
variable.copy_from_state(value)
else:
else: # if it doesn't, create a new variable
assert isinstance(value, VariableState)
variable = value.to_variable()
children[key] = variable
index_to_ref[variable_index] = variable
elif is_state_leaf(value):
if isinstance(value, VariableState):
value = value.to_variable()
children[key] = value
index_to_ref[leaf_index] = variable
else:
raise RuntimeError
raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')

return children

Expand Down
4 changes: 1 addition & 3 deletions flax/experimental/nnx/tests/test_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def test_unflatten_empty(self):

graphdef, state = nnx.split(g)

with pytest.raises(
ValueError, match='Expected key for Variable but was not found in state'
):
with pytest.raises(ValueError, match='Expected key'):
nnx.graph.unflatten(graphdef, nnx.State({}))

def test_update_dynamic(self):
Expand Down
Loading