From aa2c05d5d6cc7b40404d753dc04672eb1675573e Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 18 Jun 2024 18:56:11 +0100 Subject: [PATCH] [nnx] pure API --- docs/api_reference/flax.nnx/graph.rst | 1 + flax/nnx/__init__.py | 2 + flax/nnx/nnx/graph.py | 79 +++++++++++++++++++++++++++ flax/nnx/tests/graph_utils_test.py | 77 +++++++++++++++++++++++++- 4 files changed, 156 insertions(+), 3 deletions(-) diff --git a/docs/api_reference/flax.nnx/graph.rst b/docs/api_reference/flax.nnx/graph.rst index fa5b456ebd..d944e3c7bf 100644 --- a/docs/api_reference/flax.nnx/graph.rst +++ b/docs/api_reference/flax.nnx/graph.rst @@ -13,6 +13,7 @@ graph .. autofunction:: graphdef .. autofunction:: iter_graph .. autofunction:: clone +.. autofunction:: call .. autoclass:: GraphDef :members: diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 0ed7392b5a..df2d3af0c4 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -35,6 +35,7 @@ from .nnx.filterlib import Nothing as Nothing from .nnx.graph import GraphDef as GraphDef from .nnx.graph import GraphState as GraphState +from .nnx.graph import GraphDefState as GraphDefState from .nnx.object import Object as Object from .nnx.helpers import Dict as Dict from .nnx.helpers import List as List @@ -53,6 +54,7 @@ from .nnx.graph import state as state from .nnx.graph import graphdef as graphdef from .nnx.graph import iter_graph as iter_graph +from .nnx.graph import call as call from .nnx.nn import initializers as initializers from .nnx.nn.activations import celu as celu from .nnx.nn.activations import elu as elu diff --git a/flax/nnx/nnx/graph.py b/flax/nnx/nnx/graph.py index 50d20c94e6..a3f1ca6ba1 100644 --- a/flax/nnx/nnx/graph.py +++ b/flax/nnx/nnx/graph.py @@ -375,6 +375,8 @@ def _graphdef_unflatten( _graphdef_unflatten, ) +GraphDefState = tuple[GraphDef[A], GraphState] + def flatten( x: Node, @@ -1523,6 +1525,83 @@ def clone(node: Node) -> Node: return merge(graphdef, state) +def call( + graphdef_state: tuple[GraphDef[A], GraphState], / +) -> ApplyCaller[tuple[GraphDef[A], GraphState]]: + """Calls a method underlying graph node defined by a (GraphDef, State) pair. + + ``call`` takes a ``(GraphDef, State)`` pair and creates a proxy object that can be + used to call methods on the underlying graph node. When a method is called, the + output is returned along with a new (GraphDef, State) pair that represents the + updated state of the graph node. ``call`` is equivalent to :func:`merge` > ``method`` + > :func:`split`` but is more convenient to use in pure JAX functions. + + Example:: + + >>> from flax import nnx + >>> import jax + >>> import jax.numpy as jnp + ... + >>> class StatefulLinear(nnx.Module): + ... def __init__(self, din, dout, rngs): + ... self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) + ... self.b = nnx.Param(jnp.zeros((dout,))) + ... self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32)) + ... + ... def increment(self): + ... self.count += 1 + ... + ... def __call__(self, x): + ... self.increment() + ... return x @ self.w + self.b + ... + >>> linear = StatefulLinear(3, 2, nnx.Rngs(0)) + >>> linear_state = nnx.split(linear) + ... + >>> @jax.jit + ... def forward(x, linear_state): + ... y, linear_state = nnx.call(linear_state)(x) + ... return y, linear_state + ... + >>> x = jnp.ones((1, 3)) + >>> y, linear_state = forward(x, linear_state) + >>> y, linear_state = forward(x, linear_state) + ... + >>> linear = nnx.merge(*linear_state) + >>> linear.count.value + Array(2, dtype=uint32) + + The proxy object returned by ``call`` supports indexing and attribute access + to access nested methods. In the example below, the ``increment`` method indexing + is used to call the ``increment`` method of the ``StatefulLinear`` module + at the ``b`` key of a ``nodes`` dictionary. + + >>> rngs = nnx.Rngs(0) + >>> nodes = dict( + ... a=StatefulLinear(3, 2, rngs), + ... b=StatefulLinear(2, 1, rngs), + ... ) + ... + >>> node_state = nnx.split(nodes) + >>> # use attribute access + >>> _, node_state = nnx.call(node_state)['b'].increment() + ... + >>> nodes = nnx.merge(*node_state) + >>> nodes['a'].count.value + Array(0, dtype=uint32) + >>> nodes['b'].count.value + Array(1, dtype=uint32) + """ + + def pure_caller(accessor: DelayedAccessor, *args, **kwargs): + node = merge(*graphdef_state) + method = accessor(node) + out = method(*args, **kwargs) + return out, split(node) + + return CallableProxy(pure_caller) # type: ignore + + def iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]: """Iterates over all nested nodes and leaves of the given graph node, including the current node. diff --git a/flax/nnx/tests/graph_utils_test.py b/flax/nnx/tests/graph_utils_test.py index a879b17e6d..3ef12152b0 100644 --- a/flax/nnx/tests/graph_utils_test.py +++ b/flax/nnx/tests/graph_utils_test.py @@ -14,14 +14,30 @@ from functools import partial from threading import Thread + import jax +import jax.numpy as jnp import pytest +from absl.testing import absltest + +from flax import nnx, struct + -from flax import nnx -from flax import struct +class StatefulLinear(nnx.Module): + def __init__(self, din, dout, rngs): + self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32)) + def increment(self): + self.count.value += 1 -class TestGraphUtils: + def __call__(self, x): + self.count.value += 1 + return x @ self.w + self.b[None] + + +class TestGraphUtils(absltest.TestCase): def test_flatten(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] @@ -399,6 +415,61 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): assert m2 is m assert m2.ref is m2 + def test_call_jit_update(self): + class Counter(nnx.Module): + def __init__(self): + self.count = jnp.zeros(()) + + def inc(self): + self.count += 1 + return 1 + + graph_state = nnx.split(Counter()) + + @jax.jit + def update(graph_state: nnx.GraphDefState[Counter]): + out, graph_state = nnx.call(graph_state).inc() + self.assertEqual(out, 1) + return graph_state + + graph_state = update(graph_state) + graph_state = update(graph_state) + + counter = nnx.merge(*graph_state) + + self.assertEqual(counter.count, 2) + + def test_stateful_linear(self): + linear = StatefulLinear(3, 2, nnx.Rngs(0)) + linear_state = nnx.split(linear) + + @jax.jit + def forward(x, pure_linear: nnx.GraphDefState[StatefulLinear]): + y, pure_linear = nnx.call(pure_linear)(x) + return y, pure_linear + + x = jnp.ones((1, 3)) + y, linear_state = forward(x, linear_state) + y, linear_state = forward(x, linear_state) + + self.assertEqual(linear.count.value, 0) + new_linear = nnx.merge(*linear_state) + self.assertEqual(new_linear.count.value, 2) + + def test_getitem(self): + rngs = nnx.Rngs(0) + nodes = dict( + a=StatefulLinear(3, 2, rngs), + b=StatefulLinear(2, 1, rngs), + ) + node_state = nnx.split(nodes) + _, node_state = nnx.call(node_state)['b'].increment() + + nodes = nnx.merge(*node_state) + + self.assertEqual(nodes['a'].count.value, 0) + self.assertEqual(nodes['b'].count.value, 1) + class SimpleModule(nnx.Module): pass