[nnx] use jax-style transforms API in nnx_basics #4155

Merged
1 commit merged into main from nnx-fix-nnx-basics-2
Sep 2, 2024

@cgarciae (Collaborator) commented Aug 29, 2024

What does this PR do?

Updates nnx_basics to use the new JAX-like transforms syntax in the Transforms section. Also adds some additional notes about nnx.scan.

@@ -225,22 +225,24 @@ input (scan over layers).

Notice the following:
1. The `create_model` function creates a (single) `MLP` object that is lifted by
`nnx.vmap` to have an additional dimension of size `axis_size`.
Collaborator

Can you say more explicitly that 5 MLP layers are created because keys has a leading axis of size 5, and vmap inferred the axis_size from it? As written this is a bit confusing, because the user has to guess how to customize the number of layers here.

I would still prefer that we set the number of layers explicitly on the nnx.vmap line, rather than implicitly on the keys line.
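
To make the point of this comment concrete, here is a minimal sketch in plain JAX rather than NNX. It is not the notebook's code: `init_layer` and `create_stack` are hypothetical names, a single weight matrix stands in for the `MLP`, and `jax.vmap` is used on the assumption that `nnx.vmap` infers the axis size the same way. The layer count is passed explicitly, but the mapped axis size still comes from the leading axis of `keys`.

```python
import jax

def init_layer(key):
  # Hypothetical stand-in for one MLP layer: a single 10x10 weight matrix.
  return jax.random.normal(key, (10, 10))

def create_stack(key, num_layers):
  # The layer count is an explicit argument, but vmap still infers the
  # mapped axis size from the leading axis of `keys`.
  keys = jax.random.split(key, num_layers)
  return jax.vmap(init_layer)(keys)

stack = create_stack(jax.random.key(0), num_layers=5)
print(stack.shape)  # (5, 10, 10)
```

Wrapping the key split in a helper like this is one possible way to surface the layer count at the call site; whether the notebook should do so is the open question in this thread.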

2. The `forward` function indexes the `MLP` object's state to get a different set of
parameters at each step.
3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and
`Dropout` layers from within `forward` to the `model` reference outside.
3. The `nnx.scan` transform consciously deviates from its JAX equivalent in order to mimic
Collaborator

Can we add links to the API docs for nnx.scan, nnx.vmap, jax.lax.scan and jax.vmap? Most people reading this will be new to JAX/Flax and won't know these transforms.

Collaborator Author

Good idea. I will add the references on their first appearance above.

def create_model(rngs: nnx.Rngs):
  return MLP(10, 32, 10, rngs=rngs)
keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)
Collaborator

Can we add a line that prints the shapes of the model's params, to show this additional axis?
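
As a rough illustration of what such a line could show, the sketch below uses plain JAX with a hypothetical `init_layer` in place of the notebook's `MLP`; the parameter names `kernel` and `bias` are assumptions made for the example. Mapping every array leaf to its shape makes the extra leading axis of size 5 visible. In the notebook the same idea would presumably be applied to the output of `nnx.state(model)`.

```python
import jax
import jax.numpy as jnp

def init_layer(key):
  # Hypothetical stand-in for one layer's parameters (10 inputs, 32 outputs).
  return {"kernel": jax.random.normal(key, (10, 32)), "bias": jnp.zeros((32,))}

keys = jax.random.split(jax.random.key(0), 5)
params = jax.vmap(init_layer)(keys)

# Map every array leaf to its shape to reveal the extra leading axis.
shapes = jax.tree_util.tree_map(lambda x: x.shape, params)
print(shapes["kernel"])  # (5, 10, 32)
print(shapes["bias"])    # (5, 32)
```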

@@ -225,22 +225,24 @@ input (scan over layers).

Collaborator

I think we'd want to explain a bit about scan here. What about changing this line to:

Next, let's take a look at a different example, which uses nnx.vmap to create a stack of multiple MLP layers and nnx.scan to iteratively apply each layer of the stack to the input.

Collaborator Author

Thanks!
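
The pattern described in the suggested sentence (build a stack of layers with vmap, then apply them one after another with scan) can be sketched in plain JAX as follows. This is an illustrative analogue under stated assumptions, not the notebook's NNX code: `init_layer` and `apply_layer` are hypothetical names, each layer is a single dense layer with a `tanh` activation, and `jax.lax.scan` stands in for `nnx.scan`, whose signature differs.

```python
import jax
import jax.numpy as jnp

def init_layer(key):
  # Hypothetical stand-in for one MLP layer: a 10x10 dense layer.
  return {"w": jax.random.normal(key, (10, 10)) * 0.1, "b": jnp.zeros((10,))}

def apply_layer(x, layer):
  # scan body: `x` is the carry, `layer` is one slice of the stacked params.
  y = jnp.tanh(x @ layer["w"] + layer["b"])
  return y, None  # new carry, no per-step output

# vmap creates a stack of 5 layers along a new leading axis.
keys = jax.random.split(jax.random.key(0), 5)
stack = jax.vmap(init_layer)(keys)

# scan threads the input through the layers one at a time.
x = jnp.ones((3, 10))
y, _ = jax.lax.scan(apply_layer, x, stack)
print(y.shape)  # (3, 10)
```

The result matches a plain Python loop that indexes `stack` layer by layer; scan expresses that loop as a single traced operation.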

@cgarciae changed the title from "[nnx] use jax-like transforms API in nnx_basics" to "[nnx] use jax-style transforms API in nnx_basics" on Sep 2, 2024
@copybara-service copybara-service bot merged commit 98dff5e into main Sep 2, 2024
18 checks passed
@copybara-service copybara-service bot deleted the nnx-fix-nnx-basics-2 branch September 2, 2024 14:35