Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update sparse.grad() to support re-packing gradients into PyTrees #19760

Merged
merged 1 commit into from
Jul 29, 2024
Merged

Update sparse.grad() to support re-packing gradients into PyTrees #19760

merged 1 commit into from
Jul 29, 2024

Conversation

Blair-Johnson
Copy link
Contributor

@Blair-Johnson Blair-Johnson commented Feb 12, 2024

Fixes #16582

This PR modifies the postprocessing step of sparse.grad() to reconstruct the input PyTrees that were indexed for autodiff. Previously, only the gradient corresponding to the first element of a PyTree would be returned.

This PR includes several test cases to verify the behavior of sparse.grad() matches that of jax.grad() when gradients are taken with respect to pytrees.

@Blair-Johnson Blair-Johnson changed the title Fix #16582: sparse.grad() not re-packing gradients to match argument PyTrees Fix #16582 sparse.grad() not re-packing gradients to match argument PyTrees Feb 12, 2024
@Blair-Johnson Blair-Johnson changed the title Fix #16582 sparse.grad() not re-packing gradients to match argument PyTrees Update sparse.grad() to support re-packing gradients into PyTrees Feb 12, 2024
@Blair-Johnson Blair-Johnson marked this pull request as ready for review March 27, 2024 20:27
@Blair-Johnson
Copy link
Contributor Author

@jakevdp This PR is ready for review if you get the chance.

@jakevdp jakevdp self-assigned this Apr 16, 2024
@jakevdp jakevdp self-requested a review April 16, 2024 20:16
tests/sparse_test.py Outdated Show resolved Hide resolved
tests/sparse_test.py Outdated Show resolved Hide resolved
@Blair-Johnson Blair-Johnson requested a review from jakevdp April 16, 2024 22:09
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Apr 19, 2024
jax/experimental/sparse/ad.py Outdated Show resolved Hide resolved
jax/experimental/sparse/ad.py Outdated Show resolved Hide resolved
jax/experimental/sparse/ad.py Outdated Show resolved Hide resolved
@Blair-Johnson
Copy link
Contributor Author

@jakevdp Finally getting back around to this, apologies for the glacial pace. I adopted your suggestions with some minor fixes.

@Blair-Johnson Blair-Johnson requested a review from jakevdp July 26, 2024 17:51
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@jakevdp
Copy link
Collaborator

jakevdp commented Jul 29, 2024

Hi - we're seeing a couple test failures here that have been fixed on the main branch. Can you rebase against the most recent main branch commit so we can try again?

@copybara-service copybara-service bot merged commit 9beb4f1 into jax-ml:main Jul 29, 2024
10 checks passed
@Blair-Johnson Blair-Johnson deleted the fix-pytree-grads-sparse branch July 29, 2024 18:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

sparse.grad only returns the gradient with respect to the first element of a PyTree
3 participants