-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update sparse.grad() to support re-packing gradients into PyTrees #19760
Update sparse.grad() to support re-packing gradients into PyTrees #19760
Conversation
@jakevdp This PR is ready for review if you get the chance. |
@jakevdp Finally getting back around to this, apologies for the glacial pace. I adopted your suggestions with some minor fixes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Hi - we're seeing a couple test failures here that have been fixed on the main branch. Can you rebase against the most recent main branch commit so we can try again? |
… pytrees & test cases
Fixes #16582
This PR modifies the postprocessing step of
sparse.grad()
to reconstruct the input PyTrees that were indexed for autodiff. Previously, only the gradient corresponding to the first element of a PyTree would be returned.This PR includes several test cases to verify the behavior of
sparse.grad()
matches that ofjax.grad()
when gradients are taken with respect to pytrees.