[nnx] fix custom_vjp #4306

Merged
merged 1 commit into from
Oct 30, 2024

@cgarciae cgarciae commented Oct 17, 2024

What does this PR do?

  • Solves the issue where FwdFn is called outside of an update context. When this happens FwdFn behaves as a pure function.
  • Correctly passes the non-differentiable states as captures instead of using global state.

Solves #4265.

@cgarciae cgarciae force-pushed the nnx-custom-vjp-issue branch 10 times, most recently from 8fc186c to ed402f6 Compare October 21, 2024 14:34
@cgarciae cgarciae self-assigned this Oct 21, 2024
@cgarciae cgarciae force-pushed the nnx-custom-vjp-issue branch 4 times, most recently from de6a581 to 3e17c06 Compare October 25, 2024 15:39
@hrbigelow

Hi @cgarciae - I'm excited to try this fix!

Just wanted to mention a slight discrepancy between the bwd signatures that jax and nnx custom_vjp expect:

nnx bwd: res, (ins_g, outs_g) -> tangent
jax bwd: res, outs_g -> tangent

The docs here seem to be showing an example of jax.custom_vjp.
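
For reference, here is a minimal sketch of the plain JAX convention (`res, outs_g -> tangent`). The function `f` and the input values are illustrative assumptions and are not taken from this PR. Going by the signatures listed above, the nnx variant would instead receive a tuple `(ins_g, outs_g)` as the second argument of bwd.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x, y):
    # Illustrative primal function: sin(x) * y
    return jnp.sin(x) * y


def f_fwd(x, y):
    # Forward pass: return the primal output plus residuals for the backward pass.
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)


def f_bwd(res, outs_g):
    # JAX convention: bwd takes (residuals, output cotangent) and
    # returns one cotangent per primal input, as a tuple.
    cos_x, sin_x, y = res
    return (cos_x * outs_g * y, sin_x * outs_g)


f.defvjp(f_fwd, f_bwd)

# Gradients with respect to both inputs, routed through f_bwd.
dx, dy = jax.grad(f, argnums=(0, 1))(1.0, 2.0)
```

Here `dx` should equal `cos(x) * y` and `dy` should equal `sin(x)`, matching the analytic derivatives of `sin(x) * y`.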

@cgarciae
Collaborator Author

@hrbigelow this is correct! I expanded a bit more on the docs to reflect this.

BTW: thanks for the wait. I've been trying to get this transform right, but it's taken some rewrites to solve some edge cases.

@copybara-service copybara-service bot merged commit e4dad9c into main Oct 30, 2024
19 checks passed
@copybara-service copybara-service bot deleted the nnx-custom-vjp-issue branch October 30, 2024 13:15