`sparse.grad` only returns the gradient with respect to the first element of a PyTree #16582
Labels: bug (Something isn't working)
Description
When applying `sparse.grad` from `jax.experimental.sparse` to a function that takes a PyTree as its first argument, only the gradient with respect to the first item in the PyTree is returned. This is both unexpected and inconsistent with the behaviour of `jax.grad`, which returns a gradient with the same tree structure as the input.

Here is a small working example demonstrating this behaviour:
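A minimal sketch of such a comparison might look like the following. It is not the reporter's original snippet: the `loss` function, the `params` tuple, and the input `x` are hypothetical names chosen for illustration. It differentiates the same function with `jax.grad` and `sparse.grad` and prints the tree structure of each result so the two can be compared.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse


def loss(params, x):
    # Hypothetical objective: `params` is a PyTree (a tuple of two arrays).
    w, b = params
    return jnp.sum((w * x + b) ** 2)


params = (jnp.array([1.0, 2.0]), jnp.array([0.5, -0.5]))
x = jnp.array([3.0, 4.0])

# jax.grad returns gradients with the same tree structure as `params`.
dense_grads = jax.grad(loss)(params, x)

# sparse.grad is called the same way; on affected versions its result
# reportedly contains only the gradient for the first leaf of `params`.
sparse_grads = sparse.grad(loss)(params, x)

print("input structure:      ", jax.tree_util.tree_structure(params))
print("jax.grad structure:   ", jax.tree_util.tree_structure(dense_grads))
print("sparse.grad structure:", jax.tree_util.tree_structure(sparse_grads))
```

If the two printed gradient structures differ, the installed version is likely affected by the behaviour described above.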
Here are the outputs for the different tests:
What jax/jaxlib version are you using?
jax v0.4.13
Which accelerator(s) are you using?
No response
Additional system info
No response
NVIDIA GPU info
No response