Skip to content

Commit

Permalink
updated nnx.graph docstrings
Browse files — browse the repository at this point in the history
  • Loading branch information
chiamp committed Jun 7, 2024
1 parent 7c6a655 commit 0830a87
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 7 deletions.
3 changes: 3 additions & 0 deletions docs/api_reference/flax.nnx/graph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ graph
.. autofunction:: iter_graph
.. autofunction:: clone

.. autoclass:: GraphDef
:members:

.. autoclass:: UpdateContext
:members:

Expand Down
1 change: 1 addition & 0 deletions docs/api_reference/flax.nnx/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/i
nn/index
rnglib
spmd
state
training/index
transforms
variables
Expand Down
9 changes: 9 additions & 0 deletions docs/api_reference/flax.nnx/state.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
state
------------------------

.. automodule:: flax.nnx
.. currentmodule:: flax.nnx


.. autoclass:: State
:members:
4 changes: 2 additions & 2 deletions docs/nnx/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ Learn more
:hidden:
:maxdepth: 1

haiku_linen_vs_nnx
nnx_basics
mnist_tutorial
transforms
transforms
haiku_linen_vs_nnx
135 changes: 130 additions & 5 deletions flax/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,10 @@ def __penzai_repr__(self, path, subtree_renderer):

@dataclasses.dataclass(frozen=True, repr=False)
class GraphDef(tp.Generic[Node], reprlib.Representable):
"""A dataclass that denotes the tree structure of a
:class:`Module`. A ``GraphDef`` can be generated by either
calling :func:`split` or :func:`graphdef` on the :class:`Module`."""

nodedef: NodeDef[Node]
index_mapping: dict[Index, Index] | None

Expand Down Expand Up @@ -955,8 +959,8 @@ def split(
node: graph node to split.
filters: some optional filters to group the state into mutually exclusive substates.
Returns:
``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
filters are passed, a single ``State`` is returned.
:class:`GraphDef` and one or more :class:`State`'s equal to the number of filters passed. If no
filters are passed, a single :class:`State` is returned.
"""
if self.refmap is not None and self.idxmap is None:
raise ValueError(
Expand Down Expand Up @@ -1263,7 +1267,7 @@ def merge(
) -> A:
"""The inverse of :func:`split`.
``merge`` takes a :class:`GraphDef` and one or more :class:`State`s and creates
``merge`` takes a :class:`GraphDef` and one or more :class:`State`'s and creates
a new node with the same structure as the original node.
Example usage::
Expand Down Expand Up @@ -1293,6 +1297,8 @@ def merge(
graphdef: A :class:`GraphDef` object.
state: A :class:`State` object.
states: Additional :class:`State` objects.
Returns:
The merged :class:`Module`.
"""
if states:
state = GraphState.merge(state, *states)
Expand All @@ -1301,7 +1307,32 @@ def merge(
return node


def update(node, state: GraphState, /, *states: GraphState) -> None:
def update(node, state: State, /, *states: State) -> None:
"""Update the given graph node with a new :class:`State` in-place.
Example usage::
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> def loss_fn(model, x, y):
... return jnp.mean((y - model(x))**2)
>>> prev_loss = loss_fn(model, x, y)
>>> grads = nnx.grad(loss_fn)(model, x, y)
>>> new_state = jax.tree.map(lambda p, g: p - 0.1*g, nnx.state(model), grads)
>>> nnx.update(model, new_state)
>>> assert loss_fn(model, x, y) < prev_loss
Args:
node: A graph node to update.
state: A :class:`State` object.
states: Additional :class:`State` objects.
"""
if states:
state = GraphState.merge(state, *states)

Expand Down Expand Up @@ -1330,6 +1361,35 @@ def state(
node,
*filters: filterlib.Filter,
) -> tp.Union[GraphState, tuple[GraphState, ...]]:
"""Similar to :func:`split` but only returns the :class:`State`'s indicated by the filters.
Example usage::
>>> from flax import nnx
>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
... self.linear = nnx.Linear(2, 3, rngs=rngs)
... def __call__(self, x):
... return self.linear(self.batch_norm(x))
>>> model = Model(rngs=nnx.Rngs(0))
>>> # get the learnable parameters from the batch norm and linear layer
>>> params = nnx.state(model, nnx.Param)
>>> # get the batch statistics from the batch norm layer
>>> batch_stats = nnx.state(model, nnx.BatchStat)
>>> # get them separately
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> # get them together
>>> state = nnx.state(model)
Args:
node: A graph node object.
filters: One or more :class:`Variable` objects to filter by.
Returns:
One or more :class:`State` mappings.
"""
state = flatten(node)[1]

states: GraphState | tuple[GraphState, ...]
Expand All @@ -1344,6 +1404,21 @@ def state(


def graphdef(node: tp.Any, /) -> GraphDef[tp.Any]:
  """Retrieve the :class:`GraphDef` of the given graph node.

  Example usage::

    >>> from flax import nnx
    >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> graphdef, _ = nnx.split(model)
    >>> assert graphdef == nnx.graphdef(model)

  Args:
    node: A graph node object.
  Returns:
    The :class:`GraphDef` of the :class:`Module` object.
  """
  # Only the structural definition is needed; the state and index
  # information produced by ``flatten`` are discarded.
  node_graphdef, _, _ = flatten(node)
  return node_graphdef

Expand All @@ -1369,6 +1444,40 @@ def pop(
def pop(
node, *filters: filterlib.Filter
) -> tp.Union[GraphState, tuple[GraphState, ...]]:
"""Pop one or more :class:`Variable` types from the graph node.
Example usage::
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.linear1 = nnx.Linear(2, 3, rngs=rngs)
... self.linear2 = nnx.Linear(3, 4, rngs=rngs)
... def __call__(self, x):
... x = self.linear1(x)
... self.sow(nnx.Intermediate, 'i', x)
... x = self.linear2(x)
... return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')
>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> intermediates = nnx.pop(model, nnx.Intermediate)
>>> assert intermediates['i'].value[0].shape == (1, 3)
>>> assert not hasattr(model, 'i')
Args:
node: A graph node object.
filters: One or more :class:`Variable` objects to filter by.
Returns:
The popped :class:`State` containing the :class:`Variable`
objects that were filtered for.
"""
if len(filters) == 0:
raise ValueError('Expected at least one filter')

Expand All @@ -1394,12 +1503,28 @@ def pop(


def clone(node: Node) -> Node:
  """Create a deep copy of the given graph node.

  Example usage::

    >>> from flax import nnx
    >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> cloned_model = nnx.clone(model)
    >>> model.bias.value += 1
    >>> assert (model.bias.value != cloned_model.bias.value).all()

  Args:
    node: A graph node object.
  Returns:
    A deep copy of the :class:`Module` object.
  """
  # Round-tripping through split/merge yields a structurally identical
  # node backed by fresh state, i.e. a deep copy.
  node_graphdef, node_state = split(node)
  return merge(node_graphdef, node_state)


def iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
"""Iterates over all nested nodes and leaves of a graph node, including the current node.
"""Iterates over all nested nodes and leaves of the given graph node, including the current node.
``iter_graph`` creates a generator that yields path and value pairs, where
the path is a tuple of strings or integers representing the path to the value from the
Expand Down
91 changes: 91 additions & 0 deletions flax/nnx/nnx/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ def __penzai_repr__(self, path, subtree_renderer):


class State(MutableMapping[K, V], reprlib.Representable):
"""A pytree-like structure that contains a ``Mapping`` from strings or
integers to leaves. A valid leaf type is either :class:`Variable`,
``jax.Array``, ``numpy.ndarray`` or nested ``State``'s. A ``State``
can be generated by either calling :func:`split` or :func:`state` on
the :class:`Module`."""

def __init__(
self,
mapping: tp.Union[
Expand Down Expand Up @@ -173,6 +179,32 @@ def split(
def split(
self, first: filterlib.Filter, /, *filters: filterlib.Filter
) -> tp.Union[State[K, V], tuple[State[K, V], ...]]:
"""Split a ``State`` into one or more ``State``'s. The
user must pass at least one ``Filter`` (i.e. :class:`Variable`),
and the filters must be exhaustive (i.e. they must cover all
:class:`Variable` types in the ``State``).
Example usage::
>>> from flax import nnx
>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
... self.linear = nnx.Linear(2, 3, rngs=rngs)
... def __call__(self, x):
... return self.linear(self.batchnorm(x))
>>> model = Model(rngs=nnx.Rngs(0))
>>> state = nnx.state(model)
>>> param, batch_stats = state.split(nnx.Param, nnx.BatchStat)
Arguments:
first: The first filter
filters: The optional, additional filters to group the state into mutually exclusive substates.
Returns:
One or more ``States`` equal to the number of filters passed.
"""
filters = (first, *filters)
*states_, rest = _split_state(self, *filters)

Expand Down Expand Up @@ -211,6 +243,34 @@ def filter(
/,
*filters: filterlib.Filter,
) -> tp.Union[State[K, V], tuple[State[K, V], ...]]:
"""Filter a ``State`` into one or more ``State``'s. The
user must pass at least one ``Filter`` (i.e. :class:`Variable`).
This method is similar to :meth:`split() <flax.nnx.nnx.State.state.split>`,
except the filters can be non-exhaustive.
Example usage::
>>> from flax import nnx
>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
... self.linear = nnx.Linear(2, 3, rngs=rngs)
... def __call__(self, x):
... return self.linear(self.batchnorm(x))
>>> model = Model(rngs=nnx.Rngs(0))
>>> state = nnx.state(model)
>>> param = state.filter(nnx.Param)
>>> batch_stats = state.filter(nnx.BatchStat)
>>> param, batch_stats = state.filter(nnx.Param, nnx.BatchStat)
Arguments:
first: The first filter
filters: The optional, additional filters to group the state into mutually exclusive substates.
Returns:
One or more ``States`` equal to the number of filters passed.
"""
*states_, _rest = _split_state(self, first, *filters)

assert len(states_) == len(filters) + 1
Expand All @@ -225,6 +285,37 @@ def filter(

@staticmethod
def merge(state: State[K, V], /, *states: State[K, V]) -> State[K, V]:
"""The inverse of :meth:`split() <flax.nnx.nnx.State.state.split>`.
``merge`` takes one or more ``State``'s and creates
a new ``State``.
Example usage::
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
... self.linear = nnx.Linear(2, 3, rngs=rngs)
... def __call__(self, x):
... return self.linear(self.batchnorm(x))
>>> model = Model(rngs=nnx.Rngs(0))
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> params.linear.bias.value += 1
>>> state = nnx.State.merge(params, batch_stats)
>>> nnx.update(model, state)
>>> assert (model.linear.bias.value == jnp.array([1, 1, 1])).all()
Args:
state: A ``State`` object.
states: Additional ``State`` objects.
Returns:
The merged ``State``.
"""
states = (state, *states)

if len(states) == 1:
Expand Down

0 comments on commit 0830a87

Please sign in to comment.