
sparse.grad only returns the gradient with respect to the first element of a PyTree #16582

Closed
al-jshen opened this issue Jun 28, 2023 · 4 comments · Fixed by #19760
Labels: bug (Something isn't working)

al-jshen commented Jun 28, 2023

Description

When applying sparse.grad from jax.experimental.sparse to a function whose first argument is a PyTree, only the gradient with respect to the first item in the PyTree is returned. This is both unexpected and inconsistent with the behaviour of jax.grad, which returns a gradient with the same tree structure as the input.

Here is a small working example demonstrating this behaviour:

import jax
import jax.numpy as jnp
from jax.experimental import sparse

def foo1(wb, x, y): # first argument is a tuple with w first and b second
    w, b = wb
    return ((w @ x + b) - y).sum()

def foo2(bw, x, y): # first argument is a tuple with b first and w second
    b, w = bw
    return ((w @ x + b) - y).sum()

rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, 4)
w = jax.random.normal(keys[0], (3, 3))
b = jax.random.normal(keys[1], (3,))
x = jax.random.normal(keys[2], (3,))
y = jax.random.normal(keys[3], (3,))
"""

Here are the outputs for the different tests:

# normal jax.grad, W first and B second

jax.grad(foo1)((w, b), x, y)

(Array([[-0.47994015,  0.42577833,  0.765658  ],
        [-0.47994015,  0.42577833,  0.765658  ],
        [-0.47994015,  0.42577833,  0.765658  ]], dtype=float32),
 Array([1., 1., 1.], dtype=float32))

# ============================

# normal jax.grad, B first and W second

jax.grad(foo2)((b, w), x, y)

(Array([1., 1., 1.], dtype=float32),
 Array([[-0.47994015,  0.42577833,  0.765658  ],
        [-0.47994015,  0.42577833,  0.765658  ],
        [-0.47994015,  0.42577833,  0.765658  ]], dtype=float32))

# ============================

# sparse.grad, W first and B second. only the gradient with respect to W is returned!

sparse.grad(foo1)((w, b), x, y)

Array([[-0.47994015,  0.42577833,  0.765658  ],
       [-0.47994015,  0.42577833,  0.765658  ],
       [-0.47994015,  0.42577833,  0.765658  ]], dtype=float32)

# ============================

# sparse.grad, B first and W second. only the gradient with respect to B is returned!

sparse.grad(foo2)((b, w), x, y)

Array([1., 1., 1.], dtype=float32)
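
# ============================

One possible workaround (not suggested in this report, and untested here) is to pass the tuple's leaves as separate positional arguments so that each differentiated argument is a single array; this assumes sparse.grad accepts an argnums parameter the way jax.grad does:

# hypothetical workaround: expose the tuple's leaves as separate arguments
# so each argnum maps to exactly one array
def foo1_flat(w, b, x, y):
    return foo1((w, b), x, y)

# should return the pair (dL/dw, dL/db), matching jax.grad(foo1)
gw, gb = sparse.grad(foo1_flat, argnums=(0, 1))(w, b, x, y)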

What jax/jaxlib version are you using?

jax v0.4.13

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

No response

al-jshen added the bug label on Jun 28, 2023
jakevdp self-assigned this on Jul 6, 2023
Blair-Johnson (Contributor) commented

Have there been any updates on this bug or hints about where it may originate from?

jakevdp (Collaborator) commented Feb 8, 2024

Hey - sorry for being silent here. This is a bug in how sparse.grad is implemented. I don't think we have any plans to fix it at the moment: jax.experimental.sparse is experimental, and you should expect it to have some rough edges.

Blair-Johnson (Contributor) commented

It looks like this is a bug in the logic for postprocessing gradients:
https://github.com/google/jax/blob/c1f234a95cb0932cd23ad63a9ddbe0a8d43333b7/jax/experimental/sparse/ad.py#L69-L71
Currently, if argnums indexes a single pytree, such as a dictionary of parameters, only the first of the computed gradients is returned. The logic doesn't account for a pytree being unpacked into multiple arguments when it is flattened.
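
In outline, the fix needs to regroup the flat leaf gradients by each selected argument's treedef before returning them. A minimal sketch of that repacking, with a hypothetical helper name (an illustration of the idea, not the actual jax source):

import jax.tree_util as tree_util

# hypothetical helper: regroup a flat list of leaf gradients into one
# pytree per selected argument
def repack_grads(flat_grads, arg_treedefs):
    grads, pos = [], 0
    for treedef in arg_treedefs:
        n = treedef.num_leaves
        grads.append(tree_util.tree_unflatten(treedef, flat_grads[pos:pos + n]))
        pos += n
    # match jax.grad: a single argnum returns its gradient directly,
    # multiple argnums return a tuple of gradients
    return grads[0] if len(grads) == 1 else tuple(grads)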

I opened a draft PR which passes the current sparse tests and should match the behavior of jax.grad() when argnums indexes a pytree. The current sparse test suite clearly lacks coverage of this pytree repacking behavior, so writing tests for it is the next step if this looks reasonable.
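
For reference, a test along these lines would catch the regression (a hypothetical sketch, assuming sparse.grad is meant to mirror jax.grad for pytree inputs):

import jax
import jax.numpy as jnp
from jax.experimental import sparse

def loss(params, x):
    # params is a pytree (dict) with two leaves
    return jnp.sum(params['w'] @ x + params['b'])

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros(3)}
x = jnp.ones(3)

dense_grads = jax.grad(loss)(params, x)
sparse_grads = sparse.grad(loss)(params, x)

# the gradients should have the same tree structure either way
assert (jax.tree_util.tree_structure(dense_grads)
        == jax.tree_util.tree_structure(sparse_grads))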

Blair-Johnson (Contributor) commented

Added some tests and took the PR out of draft.
