Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add call #4004

Merged
merged 1 commit into from
Jul 17, 2024
Merged

[nnx] add call #4004

merged 1 commit into from
Jul 17, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Jun 18, 2024

What does this PR do?

Docs: preview.

Adds nnx.call which takes a (GraphDef, State) pair and creates a proxy object that can be
used to call methods on the underlying graph node. When a method is called, the
output is returned along with a new (GraphDef, State) pair that represents the
updated state of the graph node.

from flax import nnx
import jax
import jax.numpy as jnp

class StatefulLinear(nnx.Module):
  def __init__(self, din, dout, rngs):
    self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))

  def __call__(self, x):
    self.count.value += 1
    return x @ self.w + self.b

linear = StatefulLinear(3, 2, nnx.Rngs(0))
linear_state = nnx.split(linear)

@jax.jit
def forward(x, linear_state):
  y, linear_state = nnx.call(linear_state)(x)
  return y, linear_state

x = jnp.ones((1, 3))
y, linear_state = forward(x, linear_state)
y, linear_state = forward(x, linear_state)

linear = nnx.merge(*linear_state)
assert linear.count.value == 2

@cgarciae cgarciae force-pushed the nnx-pure branch 5 times, most recently from 68928b0 to 22793db Compare June 21, 2024 11:31
@PhilipVinc
Copy link
Contributor

Hey @cgarciae ! I'd love this to be merged because I'd like to play with it... Is there any reason this is being left behind?

@cgarciae cgarciae force-pushed the nnx-pure branch 4 times, most recently from 1fdbca8 to f08ee63 Compare July 9, 2024 15:02
@cgarciae
Copy link
Collaborator Author

cgarciae commented Jul 9, 2024

Hey @PhilipVinc, I was wondering if there was a way to avoid increasing the API surface. If you check the updated notes this new version, Pure is now a NamedTuple and nnx.split now returns Pure instances when no filters are passed.

@cgarciae cgarciae requested a review from superbobry July 9, 2024 15:12
@cgarciae cgarciae force-pushed the nnx-pure branch 3 times, most recently from e8041ae to aa2c05d Compare July 11, 2024 11:25
@cgarciae cgarciae changed the title [nnx] pure API [nnx] add call Jul 11, 2024
@copybara-service copybara-service bot merged commit c4066cc into main Jul 17, 2024
18 checks passed
@copybara-service copybara-service bot deleted the nnx-pure branch July 17, 2024 21:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants