Skip to content

Commit

Permalink
Merge pull request #4004 from google:nnx-pure
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653369837
  • Loading branch information
Flax Authors committed Jul 17, 2024
2 parents 62127f8 + f89e62c commit c4066cc
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/api_reference/flax.nnx/graph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ graph
.. autofunction:: graphdef
.. autofunction:: iter_graph
.. autofunction:: clone
.. autofunction:: call

.. autoclass:: GraphDef
:members:
Expand Down
2 changes: 2 additions & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .nnx.filterlib import Nothing as Nothing
from .nnx.graph import GraphDef as GraphDef
from .nnx.graph import GraphState as GraphState
from .nnx.graph import GraphDefState as GraphDefState
from .nnx.object import Object as Object
from .nnx.helpers import Dict as Dict
from .nnx.helpers import List as List
Expand All @@ -53,6 +54,7 @@
from .nnx.graph import state as state
from .nnx.graph import graphdef as graphdef
from .nnx.graph import iter_graph as iter_graph
from .nnx.graph import call as call
from .nnx.nn import initializers as initializers
from .nnx.nn.activations import celu as celu
from .nnx.nn.activations import elu as elu
Expand Down
92 changes: 92 additions & 0 deletions flax/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ def _graphdef_unflatten(
_graphdef_unflatten,
)

# Alias for a ``(GraphDef, GraphState)`` pair, i.e. the argument type that
# ``call`` accepts. Subscript it with the node type, e.g. ``GraphDefState[A]``.
GraphDefState = tuple[GraphDef[A], GraphState]


def flatten(
x: Node,
Expand Down Expand Up @@ -1523,6 +1525,96 @@ def clone(node: Node) -> Node:
return merge(graphdef, state)


def call(
  graphdef_state: tuple[GraphDef[A], GraphState], /
) -> ApplyCaller[tuple[GraphDef[A], GraphState]]:
  """Calls a method of the graph node defined by a ``(GraphDef, State)`` pair.

  ``call`` takes a ``(GraphDef, State)`` pair and returns a proxy object that
  can be used to call methods on the underlying graph node. Each method call
  returns the method's output together with a new ``(GraphDef, State)`` pair
  that represents the updated state of the graph node. ``call`` is equivalent
  to :func:`merge` > ``method`` > :func:`split` but is more convenient to use
  in pure JAX functions.

  Example::

    >>> from flax import nnx
    >>> import jax
    >>> import jax.numpy as jnp
    ...
    >>> class StatefulLinear(nnx.Module):
    ...   def __init__(self, din, dout, rngs):
    ...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
    ...     self.b = nnx.Param(jnp.zeros((dout,)))
    ...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
    ...
    ...   def increment(self):
    ...     self.count += 1
    ...
    ...   def __call__(self, x):
    ...     self.increment()
    ...     return x @ self.w + self.b
    ...
    >>> linear = StatefulLinear(3, 2, nnx.Rngs(0))
    >>> linear_state = nnx.split(linear)
    ...
    >>> @jax.jit
    ... def forward(x, linear_state):
    ...   y, linear_state = nnx.call(linear_state)(x)
    ...   return y, linear_state
    ...
    >>> x = jnp.ones((1, 3))
    >>> y, linear_state = forward(x, linear_state)
    >>> y, linear_state = forward(x, linear_state)
    ...
    >>> linear = nnx.merge(*linear_state)
    >>> linear.count.value
    Array(2, dtype=uint32)

  The proxy object returned by ``call`` supports indexing and attribute access
  to reach nested methods. In the example below, indexing selects the
  ``StatefulLinear`` module stored at the ``b`` key of a ``nodes`` dictionary,
  and attribute access then selects its ``increment`` method.

    >>> class StatefulLinear(nnx.Module):
    ...   def __init__(self, din, dout, rngs):
    ...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
    ...     self.b = nnx.Param(jnp.zeros((dout,)))
    ...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
    ...
    ...   def increment(self):
    ...     self.count += 1
    ...
    ...   def __call__(self, x):
    ...     self.increment()
    ...     return x @ self.w + self.b
    ...
    >>> rngs = nnx.Rngs(0)
    >>> nodes = dict(
    ...   a=StatefulLinear(3, 2, rngs),
    ...   b=StatefulLinear(2, 1, rngs),
    ... )
    ...
    >>> node_state = nnx.split(nodes)
    >>> # use indexing followed by attribute access
    >>> _, node_state = nnx.call(node_state)['b'].increment()
    ...
    >>> nodes = nnx.merge(*node_state)
    >>> nodes['a'].count.value
    Array(0, dtype=uint32)
    >>> nodes['b'].count.value
    Array(1, dtype=uint32)

  Args:
    graphdef_state: the ``(GraphDef, State)`` pair that defines the graph node
      (positional-only).

  Returns:
    A proxy object. Calling a method selected through the proxy returns a
    ``(output, (GraphDef, State))`` tuple.
  """

  def pure_caller(accessor: DelayedAccessor, *args, **kwargs):
    # Every invocation rebuilds a fresh node from the captured pair, so the
    # caller's ``graphdef_state`` is never modified in place.
    live_node = merge(*graphdef_state)
    # ``accessor`` replays the recorded attribute/index lookups on the node
    # to obtain the callable that should be invoked.
    result = accessor(live_node)(*args, **kwargs)
    # Split only after the call so the returned pair reflects its updates.
    return result, split(live_node)

  return CallableProxy(pure_caller)  # type: ignore


def iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
"""Iterates over all nested nodes and leaves of the given graph node, including the current node.
Expand Down
77 changes: 74 additions & 3 deletions flax/nnx/tests/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,30 @@

from functools import partial
from threading import Thread

import jax
import jax.numpy as jnp
import pytest
from absl.testing import absltest

from flax import nnx, struct


from flax import nnx
from flax import struct
class StatefulLinear(nnx.Module):
  """Linear layer test fixture that also tracks a ``count`` variable."""

  def __init__(self, din, dout, rngs):
    # Kernel of shape (din, dout) drawn uniformly with a key taken from rngs.
    kernel = jax.random.uniform(rngs(), (din, dout))
    self.w = nnx.Param(kernel)
    # Bias of shape (dout,), initialised to zeros.
    bias = jnp.zeros((dout,))
    self.b = nnx.Param(bias)
    # Unsigned scalar counter, starting at zero.
    start = jnp.array(0, dtype=jnp.uint32)
    self.count = nnx.Variable(start)

  def increment(self):
    """Adds one to the ``count`` variable."""
    self.count.value += 1

class TestGraphUtils:
def __call__(self, x):
  """Bumps ``count`` by one and returns ``x @ w + b[None]``."""
  self.count.value += 1
  projected = x @ self.w
  # ``[None]`` adds a leading axis to the bias before the addition.
  return projected + self.b[None]


class TestGraphUtils(absltest.TestCase):
def test_flatten(self):
a = {'a': 1, 'b': nnx.Param(2)}
g = [a, 3, a, nnx.Param(4)]
Expand Down Expand Up @@ -399,6 +415,61 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
assert m2 is m
assert m2.ref is m2

def test_call_jit_update(self):
  """``nnx.call`` inside ``jax.jit`` returns the output and updated state."""

  class Counter(nnx.Module):
    def __init__(self):
      self.count = jnp.zeros(())

    def inc(self):
      self.count += 1
      return 1

  @jax.jit
  def update(pure_counter: nnx.GraphDefState[Counter]):
    result, pure_counter = nnx.call(pure_counter).inc()
    # ``inc`` returns a plain Python int, so it can be checked while tracing.
    self.assertEqual(result, 1)
    return pure_counter

  pure_counter = nnx.split(Counter())
  # Thread the pair through two jitted updates.
  for _ in range(2):
    pure_counter = update(pure_counter)

  merged = nnx.merge(*pure_counter)
  self.assertEqual(merged.count, 2)

def test_stateful_linear(self):
  """Calling a split module via ``nnx.call`` updates only the returned pair."""
  module = StatefulLinear(3, 2, nnx.Rngs(0))
  pure_module = nnx.split(module)

  @jax.jit
  def forward(x, pure_linear: nnx.GraphDefState[StatefulLinear]):
    out, pure_linear = nnx.call(pure_linear)(x)
    return out, pure_linear

  inputs = jnp.ones((1, 3))
  # Two forward passes, each feeding the previous pass's returned pair.
  for _ in range(2):
    _, pure_module = forward(inputs, pure_module)

  # The original module keeps its initial count ...
  self.assertEqual(module.count.value, 0)
  # ... while the module rebuilt from the returned pair has counted both calls.
  rebuilt = nnx.merge(*pure_module)
  self.assertEqual(rebuilt.count.value, 2)

def test_getitem(self):
  """The ``nnx.call`` proxy supports indexing into a dict of modules."""
  rngs = nnx.Rngs(0)
  modules = {
    'a': StatefulLinear(3, 2, rngs),
    'b': StatefulLinear(2, 1, rngs),
  }
  pure_modules = nnx.split(modules)

  # Index to the 'b' entry, then call its ``increment`` method.
  _, pure_modules = nnx.call(pure_modules)['b'].increment()
  rebuilt = nnx.merge(*pure_modules)

  # Only the module that was called has advanced its counter.
  self.assertEqual(rebuilt['a'].count.value, 0)
  self.assertEqual(rebuilt['b'].count.value, 1)


class SimpleModule(nnx.Module):
  """Minimal ``nnx.Module`` subclass that defines no attributes or methods."""
  pass
Expand Down

0 comments on commit c4066cc

Please sign in to comment.