Skip to content

Commit

Permalink
Merge pull request #4399 from google:nnx-more-optimizations
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700449551
  • Loading branch information
Flax Authors committed Nov 26, 2024
2 parents 6bc9858 + 861936f commit fe8fb59
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 76 deletions.
116 changes: 106 additions & 10 deletions docs_nnx/nnx_basics.ipynb

Large diffs are not rendered by default.

118 changes: 61 additions & 57 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,24 @@ def __treescope_repr__(self, path, subtree_renderer):
jax.tree_util.register_static(VariableDef)


@dataclasses.dataclass(frozen=True, slots=True)
class SubGraphAttribute:
  """Entry of ``NodeDef.attributes`` for an attribute that is itself a node.

  ``_graph_flatten`` appends one of these for every child value for which
  ``is_node(value)`` is true, storing the result of recursively flattening
  that child. ``_graph_unflatten`` dispatches on this type to rebuild the
  child node (or fetch it from the reference cache when ``value`` is a
  ``NodeRef``).
  """

  # Key under which the child was yielded by ``node_impl.flatten``.
  key: Key
  # Definition of the child node, or a reference to an already-seen node.
  value: NodeDef[tp.Any] | NodeRef[tp.Any]


@dataclasses.dataclass(frozen=True, slots=True)
class StaticAttribute:
  """Entry of ``NodeDef.attributes`` for a plain (non-node, non-Variable) value.

  ``_graph_flatten`` appends one of these for every child value that is
  neither a node nor a ``Variable``; ``jax.Array`` / ``np.ndarray`` values are
  rejected there with a ``ValueError`` before reaching this type.
  ``_graph_unflatten`` restores ``value`` as-is and raises ``ValueError`` if
  the incoming state contains an entry for the same key.
  """

  # Key under which the value was yielded by ``node_impl.flatten``.
  key: Key
  # The static value itself, stored inside the graph definition.
  value: tp.Any


@dataclasses.dataclass(frozen=True, slots=True)
class LeafAttribute:
  """Entry of ``NodeDef.attributes`` for an attribute holding a ``Variable``.

  ``_graph_flatten`` appends one of these for every child value that is a
  ``Variable``: a ``VariableDef`` the first time the variable is encountered,
  or a ``NodeRef`` pointing at its index when it is already in ``ref_index``.
  """

  # Key under which the variable was yielded by ``node_impl.flatten``.
  key: Key
  # Definition of a newly indexed variable, or a reference to a known one.
  value: VariableDef | NodeRef[tp.Any]


@dataclasses.dataclass(frozen=True, repr=False, slots=True)
class NodeDef(GraphDef[Node], reprlib.Representable):
"""A dataclass that denotes the tree structure of a
Expand All @@ -298,10 +316,7 @@ class NodeDef(GraphDef[Node], reprlib.Representable):

type: tp.Type[Node]
index: int
attributes: tuple[Key, ...]
subgraphs: HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]
static_fields: HashableMapping[Key, tp.Any]
leaves: HashableMapping[Key, VariableDef | NodeRef[tp.Any]]
attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...]
metadata: tp.Any
index_mapping: HashableMapping[Index, Index] | None

Expand All @@ -310,20 +325,14 @@ def create(
cls,
type: tp.Type[Node],
index: int,
attributes: tuple[Key, ...],
subgraphs: tp.Iterable[tuple[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]],
static_fields: tp.Iterable[tuple[Key, tp.Any]],
leaves: tp.Iterable[tuple[Key, VariableDef | NodeRef[tp.Any]]],
attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...],
metadata: tp.Any,
index_mapping: tp.Mapping[Index, Index] | None,
):
return cls(
type=type,
index=index,
attributes=attributes,
subgraphs=HashableMapping(subgraphs),
static_fields=HashableMapping(static_fields),
leaves=HashableMapping(leaves),
metadata=metadata,
index_mapping=HashableMapping(index_mapping)
if index_mapping is not None
Expand All @@ -335,12 +344,7 @@ def __nnx_repr__(self):

yield reprlib.Attr('type', self.type.__name__)
yield reprlib.Attr('index', self.index)
yield reprlib.Attr('attributes', self.attributes)
yield reprlib.Attr('subgraphs', reprlib.PrettyMapping(self.subgraphs))
yield reprlib.Attr(
'static_fields', reprlib.PrettyMapping(self.static_fields)
)
yield reprlib.Attr('leaves', reprlib.PrettyMapping(self.leaves))
yield reprlib.Attr('attributes', reprlib.PrettySequence(self.attributes))
yield reprlib.Attr('metadata', self.metadata)
yield reprlib.Attr(
'index_mapping',
Expand All @@ -352,18 +356,15 @@ def __nnx_repr__(self):
def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
'index': self.index,
'attributes': self.attributes,
'subgraphs': dict(self.subgraphs),
'static_fields': dict(self.static_fields),
'leaves': dict(self.leaves),
'metadata': self.metadata,
},
path=path,
subtree_renderer=subtree_renderer,
object_type=type(self),
attributes={
'type': self.type,
'index': self.index,
'attributes': self.attributes,
'metadata': self.metadata,
},
path=path,
subtree_renderer=subtree_renderer,
)

def apply(
Expand Down Expand Up @@ -426,40 +427,39 @@ def _graph_flatten(
else:
index = -1

subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = []
static_fields: list[tuple[Key, tp.Any]] = []
leaves: list[tuple[Key, VariableDef | NodeRef]] = []
attributes: list[SubGraphAttribute | StaticAttribute | LeafAttribute] = []

values, metadata = node_impl.flatten(node)
for key, value in values:
if is_node(value):
nodedef = _graph_flatten((*path, key), ref_index, flat_state, value)
subgraphs.append((key, nodedef))
# subgraphs.append((key, nodedef))
attributes.append(SubGraphAttribute(key, nodedef))
elif isinstance(value, Variable):
if value in ref_index:
leaves.append((key, NodeRef(type(value), ref_index[value])))
attributes.append(
LeafAttribute(key, NodeRef(type(value), ref_index[value]))
)
else:
flat_state[(*path, key)] = value.to_state()
variable_index = ref_index[value] = len(ref_index)
variabledef = VariableDef(
type(value), variable_index, HashableMapping(value.get_metadata())
)
leaves.append((key, variabledef))
attributes.append(LeafAttribute(key, variabledef))
else:
if isinstance(value, (jax.Array, np.ndarray)):
path_str = '/'.join(map(str, (*path, key)))
raise ValueError(
f'Arrays leaves are not supported, at {path_str!r}: {value}'
)
static_fields.append((key, value))
# static_fields.append((key, value))
attributes.append(StaticAttribute(key, value))

nodedef = NodeDef.create(
type=node_impl.type,
index=index,
attributes=tuple(key for key, _ in values),
subgraphs=subgraphs,
static_fields=static_fields,
leaves=leaves,
attributes=tuple(attributes),
metadata=metadata,
index_mapping=None,
)
Expand Down Expand Up @@ -529,22 +529,20 @@ def _graph_unflatten(

def _get_children():
children: dict[Key, NodeLeaf | Node] = {}

# NOTE: we could allow adding new StateLeafs here
if unkown_keys := set(state) - set(nodedef.attributes):
raise ValueError(f'Unknown keys: {unkown_keys}')
state_keys: set = set(state.keys())

# for every key in attributes there are 6 possible cases:
# - (2) the key can either be present in the state or not
# - (3) the key can be a subgraph, a leaf, or a static attribute
for key in nodedef.attributes:
for attribute in nodedef.attributes:
key = attribute.key
if key not in state:
# if key is not present create an empty types
if key in nodedef.static_fields:
children[key] = nodedef.static_fields[key]
elif key in nodedef.subgraphs:
if type(attribute) is StaticAttribute:
children[key] = attribute.value
elif type(attribute) is SubGraphAttribute:
# if the key is a subgraph we create an empty node
subgraphdef = nodedef.subgraphs[key]
subgraphdef = attribute.value
assert not isinstance(subgraphdef, VariableDef)
if isinstance(subgraphdef, NodeRef):
# subgraph exists, take it from the cache
Expand All @@ -558,8 +556,8 @@ def _get_children():
children[key] = _graph_unflatten(
subgraphdef, substate, index_ref, index_ref_cache
)
elif key in nodedef.leaves:
variabledef = nodedef.leaves[key]
elif type(attribute) is LeafAttribute:
variabledef = attribute.value
if variabledef.index in index_ref:
# variable exists, take it from the cache
children[key] = index_ref[variabledef.index]
Expand All @@ -572,19 +570,21 @@ def _get_children():
else:
raise RuntimeError(f'Unknown static field: {key!r}')
else:
state_keys.remove(key)
value = state[key]
if key in nodedef.static_fields:
# if key in nodedef.static_fields:
if type(attribute) is StaticAttribute:
raise ValueError(
f'Got state for static field {key!r}, this is not supported.'
)
if key in nodedef.subgraphs:
elif type(attribute) is SubGraphAttribute:
if is_state_leaf(value):
raise ValueError(
f'Expected value of type {nodedef.subgraphs[key]} for '
f'Expected value of type {attribute.value} for '
f'{key!r}, but got {value!r}'
)
assert isinstance(value, dict)
subgraphdef = nodedef.subgraphs[key]
subgraphdef = attribute.value

if isinstance(subgraphdef, NodeRef):
children[key] = index_ref[subgraphdef.index]
Expand All @@ -593,8 +593,8 @@ def _get_children():
subgraphdef, value, index_ref, index_ref_cache
)

elif key in nodedef.leaves:
variabledef = nodedef.leaves[key]
elif type(attribute) is LeafAttribute:
variabledef = attribute.value

if variabledef.index in index_ref:
# add an existing variable
Expand Down Expand Up @@ -631,6 +631,10 @@ def _get_children():
else:
raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')

# NOTE: we could allow adding new StateLeafs here
if state_keys:
raise ValueError(f'Unknown keys: {state_keys}')

return children

if isinstance(node_impl, GraphNodeImpl):
Expand Down
12 changes: 11 additions & 1 deletion flax/nnx/reprlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,14 @@ def __nnx_repr__(self):
yield Object(type='', value_sep=': ', start='{', end='}')

for key, value in self.mapping.items():
yield Attr(repr(key), value)
yield Attr(repr(key), value)

@dataclasses.dataclass(repr=False)
class PrettySequence(Representable):
  """Representable wrapper around a sequence.

  ``__nnx_repr__`` yields an ``Object`` header configured with ``'['`` and
  ``']'`` delimiters, an empty type name and an empty key/value separator,
  followed by one ``Attr`` with an empty name per element of the sequence.
  """

  # The wrapped sequence; elements are emitted in iteration order.
  list: tp.Sequence

  def __nnx_repr__(self):
    # Header first, then every element as a nameless attribute.
    yield Object(type='', value_sep='', start='[', end=']')
    yield from (Attr('', item) for item in self.list)
3 changes: 1 addition & 2 deletions flax/nnx/scripts/run-all-examples.bash
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
set -e

source .venv/bin/activate
cd flax/nnx

for f in $(find examples/toy_examples -name "*.py" -maxdepth 1); do
for f in $(find examples/nnx_toy_examples -name "*.py" -maxdepth 1); do
echo -e "\n---------------------------------"
echo "$f"
echo "---------------------------------"
Expand Down
15 changes: 10 additions & 5 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,11 +1341,16 @@ def per_node_def(nd: graph.NodeDef | graph.NodeRef):
global_index_mapping[nd.index] = nd.index
if isinstance(nd, graph.NodeRef):
return
for sub_nd in nd.subgraphs.values():
per_node_def(sub_nd)
for l in nd.leaves.values():
if isinstance(l, (graph.VariableDef, graph.NodeRef)) and l.index >= 0:
global_index_mapping[l.index] = l.index

for attribute in nd.attributes:
if type(attribute) is graph.SubGraphAttribute:
per_node_def(attribute.value)
elif (
type(attribute) is graph.LeafAttribute
and isinstance(attribute.value, (graph.VariableDef, graph.NodeRef))
and attribute.value.index >= 0
):
global_index_mapping[attribute.value.index] = attribute.value.index
return

per_node_def(ns._graphdef)
Expand Down
2 changes: 1 addition & 1 deletion tests/nnx/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def __init__(self):

assert 'tree' in state
assert 'a' in state.tree
assert graphdef.subgraphs['tree'].type is nnx.graph.GenericPytree
assert graphdef.attributes[0].value.type is nnx.graph.GenericPytree

m2 = nnx.merge(graphdef, state)

Expand Down

0 comments on commit fe8fb59

Please sign in to comment.