[FLIP] JAX-style NNX Transforms #4107

Open
cgarciae opened this issue Jul 26, 2024 · 0 comments
JAX-style NNX Transforms

Motivation

NNX allows users to use Modules at the top level thanks to their eager initialization and self-contained state. This naturally leads users to apply transforms to them, and they soon start experimenting with NNX transforms. Since NNX Modules resemble PyTrees in that they contain Arrays, new users often attempt to apply JAX conventions, for example:

@nnx.vmap(in_axes=(1, 0))
def f(m1: Module, m2: Module):
  ...

However, this can be misleading. Currently, NNX transforms follow Linen's convention of treating input Modules as a single unit (all Modules are split together to preserve shared references) and provide APIs for transforming that State separately. The previous example effectively translates to:

# this is what is really happening
@nnx.vmap(in_axes=(IGNORE, IGNORE), state_axes={BatchStat: None, ...: 0})
def f(m1: Module, m2: Module):
  ...

Note that IGNORE is not a real symbol, but represents the fact that any value placed here won't affect the outcome, as Modules are replaced by empty PyTree placeholders (similar to None). The state_axes parameter controls how the State is vectorized through a mapping of high-level Filters to their desired axes. In this example, ... (ellipsis) is a filter that accepts everything, so by default all States are vectorized on the 0th axis.

To express their original intention, users must resort to more complex custom filters that guess the index of each Module in the monolith. While this is straightforward in simple cases, users generally need to calculate the index (Modules appear in the order specified by jax.tree.leaves over the args):

select_m1 = lambda path, value: path[0] == 0
select_m2 = lambda path, value: path[0] == 1

# To select modules individually, you must create a filter (which can be tricky)
@nnx.vmap(state_axes={select_m1: 1, select_m2: 0})
def f(m1: Module, m2: Module):
  ...
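
To illustrate why computing the index can be error-prone, here is a small sketch of how JAX orders leaves when flattening arguments. The string placeholders and the argument layout are hypothetical stand-ins for Modules; only jax.tree_util.tree_leaves is a real JAX call.

```python
import jax

# Hypothetical call arguments: a tuple of positional values and a dict of
# keyword values. Strings stand in for Modules here.
args = (("m1", "m2"), {"b": "m_b", "a": "m_a"})

# Leaves come back in flattening order. Dict entries are visited in
# sorted-key order, not insertion order, so "m_a" precedes "m_b".
leaves = jax.tree_util.tree_leaves(args)

# Position of each placeholder in the flattened list.
index_of = {leaf: i for i, leaf in enumerate(leaves)}
print(leaves)
print(index_of["m_b"])
```

The value under key "b" was written first in the dict but lands last after flattening, which is the kind of detail a path-based filter has to account for.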

What if JAX conventions Just Worked™?

This proposal aims to align NNX transforms with users' expectations based on their JAX experience, making the syntax work as intuitively as possible. The original example would function as if m1 and m2 were PyTrees vectorized in axes 1 and 0 respectively:

@nnx.vmap(in_axes=(1, 0))
def f(m1: Module, m2: Module):
  ...

The primary advantage of this approach is that for vmap and scan, we could eliminate the state_axes and split_rngs arguments, relying solely on the in_axes API. This syntax alone would likely suffice for 80-90% of use cases, as users tend to manage state in predictable ways.
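
For comparison, this is the plain JAX behavior the proposal mirrors, sketched with dict PyTrees standing in for the state of m1 and m2. The names m1_state and m2_state and the shapes are assumptions made for illustration; no NNX API is used.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the state of two Modules.
m1_state = {"w": jnp.ones((3, 5))}  # batch axis is axis 1 (size 5)
m2_state = {"w": jnp.ones((5, 3))}  # batch axis is axis 0 (size 5)

def f(s1, s2):
    # Inside vmap the batch axis is removed, so both leaves have shape (3,).
    return jnp.dot(s1["w"], s2["w"])

# Each in_axes entry applies to every leaf of the matching argument.
out = jax.vmap(f, in_axes=(1, 0))(m1_state, m2_state)
print(out.shape)
```

Under the proposal, passing Modules with in_axes=(1, 0) would be expected to behave the same way for their state.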

The Lift symbols

To enable more fine-grained state control within each Module, we introduce the Lift API. By using special types containing State Filters in place of a tree prefix, state lifting can now be done structurally. This allows different Filters to be applied to different Modules in the arguments without the need for complex path-based filters. Ideally, each transform would support its own Lift type, adding the desired behavior through existing JAX APIs.

For example, in vmap, we could allow StateAxes instances (vmap's Lift type) to be accepted by in/out_axes to control how substates are handled by mapping state Filters to an axis specifier:

state_axes = StateAxes({Param: 1, BatchStat: None})

@nnx.vmap(in_axes=(state_axes, 0))
def f(m1: Module, m2: Module):
  ...

In this case, m1's Params are vectorized in axis 1 while its BatchStats are broadcast, and m2's entire state is vectorized in axis 0.
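
The closest plain JAX analogue is a tree prefix in in_axes. The sketch below uses a dict prefix in place of StateAxes; the keys "params" and "batch_stats" are assumptions chosen to mirror Param and BatchStat, and the shapes are arbitrary.

```python
import jax
import jax.numpy as jnp

# Hypothetical state for two Modules, keyed by the kind of state.
m1_state = {"params": jnp.ones((3, 4)), "batch_stats": jnp.zeros((3,))}
m2_state = {"params": jnp.ones((4, 3))}

def f(s1, s2):
    # Per-example shapes: s1["params"] is (3,), s1["batch_stats"] is (3,)
    # (broadcast, not mapped), and s2["params"] is (3,).
    return s1["params"] * s2["params"] + s1["batch_stats"]

# Dict prefix for the first argument: map "params" over axis 1 and
# broadcast "batch_stats". The second argument is mapped over axis 0.
in_axes = ({"params": 1, "batch_stats": None}, 0)
out = jax.vmap(f, in_axes=in_axes)(m1_state, m2_state)
print(out.shape)
```

StateAxes would play the role of the dict prefix here, with Filters rather than dict keys selecting the substates.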

For nnx.grad, we could allow DiffState to be used in the argnums parameter to specify both the position of the argument to be differentiated and a Filter specifying the differentiable State of the Module:

grads = nnx.grad(loss_fn, argnums=(DiffState(0, LoRAParam),))(model, x, y)
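
As a rough analogy in plain JAX, selecting a differentiable substate amounts to splitting the parameters and differentiating only one part. In this sketch the split into lora and base dicts, the shapes, and the loss are all assumptions for illustration; DiffState itself is not used.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters: a frozen base weight and a trainable low-rank update.
base = {"w": jnp.ones((4, 2))}
lora = {"a": jnp.ones((4, 1)), "b": jnp.zeros((1, 2))}
x = jnp.ones((8, 4))
y = jnp.zeros((8, 2))

def loss_fn(lora, base, x, y):
    # Effective weight is the base weight plus the low-rank correction.
    w = base["w"] + lora["a"] @ lora["b"]
    return jnp.mean((x @ w - y) ** 2)

# argnums=0 differentiates only the first argument, so the gradients
# have the structure of `lora` and the base weight is left untouched.
grads = jax.grad(loss_fn, argnums=0)(lora, base, x, y)
print(jax.tree_util.tree_map(lambda g: g.shape, grads))
```

DiffState(0, LoRAParam) would express the same selection without manually splitting the Module's state.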

RNG Handling

To simplify RNG state handling, we propose removing the separate split_rngs parameter in vmap and scan. Instead, we suggest introducing a new nnx.split_rngs API that would manage RNG handling before and after the transformation. This approach provides more explicit control to the user and aligns better with JAX transform behavior.
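
The JAX behavior this aligns with is splitting keys explicitly before the transform and mapping over them. The sketch below shows only that plain JAX pattern; the signature of the proposed nnx.split_rngs is not specified here, so it is not used.

```python
import jax

key = jax.random.PRNGKey(0)

# Split the key explicitly before the transform: one key per mapped instance.
keys = jax.random.split(key, 5)

def init(k):
    # Each instance draws its own parameters from its own key.
    return jax.random.normal(k, (3,))

# vmap maps over the leading axis of `keys`, giving 5 independent draws.
params = jax.vmap(init)(keys)
print(params.shape)
```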

Consistent Aliasing

To ensure the correctness of transformations with objects that obey reference semantics, we must enforce consistent lifting/lowering specifications for all aliases of a reference. Transforms must adhere to two rules:

  1. All aliases of a reference must receive the exact same lifting/lowering specification.
  2. Captured references are not allowed on the output of transformed functions.

For example:

@nnx.vmap(in_axes=(m1_axes, m2_axes, m1_axes), out_axes=m2_axes)
def f(m1, m2, m1_alias):
  return m2

m2 = f(m1, m2, m1)

Here, m1 has two input aliases as it is passed as the first and third input to f, but this is acceptable because m1_axes is assigned to both in in_axes. m2 is passed as the second input and has an output alias, which is also acceptable because m2_axes is assigned in both in_axes and out_axes.
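
A minimal sketch of how rule 1 could be checked for top-level arguments, assuming aliases are detected by object identity. The function check_consistent_aliasing and the Module class are hypothetical and are not part of NNX.

```python
def check_consistent_aliasing(refs, specs):
    """Raise ValueError if the same object appears with two different specs."""
    first_seen = {}  # id(ref) -> (position, spec) of the first occurrence
    for position, (ref, spec) in enumerate(zip(refs, specs)):
        prev_position, prev_spec = first_seen.setdefault(id(ref), (position, spec))
        if prev_spec != spec:
            raise ValueError(
                f"reference at position {position} has spec {spec!r}, but its "
                f"alias at position {prev_position} has spec {prev_spec!r}"
            )

class Module:
    """Hypothetical stand-in for an object with reference semantics."""

m1, m2 = Module(), Module()
m1_axes, m2_axes = 0, 1

# Inputs (m1, m2, m1) followed by the output m2, as in the example above.
check_consistent_aliasing(
    [m1, m2, m1, m2], [m1_axes, m2_axes, m1_axes, m2_axes]
)  # accepted: every alias has the same spec
```

Calling check_consistent_aliasing([m1, m1], [0, 1]) would raise, since the two aliases of m1 receive different specifications.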

Let's examine some examples of programs that should be rejected based on these criteria:

Inconsistent input aliases

Consider a function with two arguments m1 and m2 being vectorized in axis 0 and 1 respectively. Passing the same Module as both arguments would be inconsistent:

@nnx.vmap(in_axes=(0, 1))
def f(m1: Module, m2: Module):
  ...

f(m, m)  # This should be rejected

Inconsistent input / output aliases

Now consider an identity function g under vmap with in_axes=0 and out_axes=1. In JAX, this would result in transposing the arrays in the inputs:

@nnx.vmap(in_axes=0, out_axes=1)
def g(m: Module):
  return m

While this appears correct, in NNX this behavior is not well-defined because shared mutable references behave as auxiliary outputs. Under the hood, g is converted into a function that has the inputs as an extra first output, and out_axes is set to the same values as in_axes for that output:

@nnx.vmap(in_axes=0, out_axes=(0, 1))
def g_real(m: Module):
  return m, m

This return structure reveals an inconsistency: we're attempting to lower m with both out_axes=0 and out_axes=1.
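
For reference, the plain JAX transposition mentioned above can be reproduced with arrays, which have no reference identity and therefore no such conflict. The array contents are arbitrary.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)

# Identity under vmap: map over axis 0 of the input and stack the
# per-example results along axis 1 of the output.
y = jax.vmap(lambda a: a, in_axes=0, out_axes=1)(x)

print(y.shape)
print(bool((y == x.T).all()))
```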

Inconsistent aliases in nested structures

Similar issues can arise in less obvious cases, such as when m is contained within another structure:

@nnx.vmap(in_axes=0, out_axes=1)
def f(m: Module):
  return SomeModule(m)

This means we must traverse the entire graph of both inputs and outputs to check for consistent assignments. The same problem occurs when passing shared reference inputs/outputs with different specifications:

shared = Shared()
m1, m2 = Foo(shared), Foo(shared)

@nnx.vmap(in_axes=(0, 1))
def f(m1, m2):  # shared is passed through both
  ...
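
The traversal could look roughly like the sketch below, which walks instance attributes and records the specification reaching each reference. Node, reachable, and find_conflicts are hypothetical names, and a real implementation would operate on the NNX graph rather than on raw attributes.

```python
class Node:
    """Hypothetical stand-in for an object with reference semantics."""

class Shared(Node):
    pass

class Foo(Node):
    def __init__(self, shared):
        self.shared = shared

def reachable(root):
    """Yield every Node reachable from root through attributes, root included."""
    stack, seen = [root], set()
    while stack:
        node = stack.pop()
        if not isinstance(node, Node) or id(node) in seen:
            continue
        seen.add(id(node))
        yield node
        stack.extend(vars(node).values())

def find_conflicts(args, specs):
    """Return the Nodes that are reached with two different specs."""
    assigned, conflicts = {}, []
    for arg, spec in zip(args, specs):
        for node in reachable(arg):
            # Keep the first spec seen for this node; flag any later mismatch.
            if assigned.setdefault(id(node), spec) != spec:
                conflicts.append(node)
    return conflicts

shared = Shared()
m1, m2 = Foo(shared), Foo(shared)
conflicts = find_conflicts((m1, m2), (0, 1))
print(conflicts[0] is shared)
```

With identical specifications, for example (0, 0), the same call returns an empty list.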

Captured Modules cannot be outputs

Finally, let's consider the second consistent aliasing rule, which states that captured Modules cannot be outputs. The main issue here is that NNX needs to split all input references together to track changes, but captured Modules bypass this process. Treating them as new references would result in implicit cloning:

m = SomeModule()

@nnx.vmap(out_axes=0, axis_size=5)
def f():
  return m

assert m is not f()  # implicit cloning

To preserve reference identity, we must disallow captured Modules as outputs. In practice, we can detect captured Modules using the trace level context machinery used to restrict stateful updates on Modules from a different level.

Recap

In this document, we have:

  • Discussed issues with the current implementation that make it unintuitive for JAX users.
  • Proposed refactoring NNX transforms to allow users to use regular JAX semantics when interacting with objects, removing extra arguments introduced by NNX transforms.
  • Introduced the use of Lift types in JAX APIs to compensate for the lack of a "prefix" notion in NNX objects, enabling independent lifting of Module substates.
  • Proposed a new nnx.split_rngs API to replace the split_rngs arguments in vmap and scan, making RNG handling an explicit operation and giving users more control.
  • Analyzed edge cases resulting from aliasing shared mutable references and proposed enforcing consistent aliasing on all transforms with semantics over the inputs.