Jax 0.4.27: ValueError: safe_map() argument 2 is shorter than argument 1 #716

Closed
GaetanLepage opened this issue May 7, 2024 · 8 comments


@GaetanLepage

Since jax 0.4.27, several tests fail with:

args = (_ClosureConvert(
  jaxpr={ lambda ; a:f32[47] b:f32[] c:i32[] d:bool[] e:bool[] f:i32[] g:f32[] h:f32[47]
    i:bool[...t 0x7fff4c435a90>,
  _makes_false_steps=False
), Traced<ShapedArray(float32[47])>with<DynamicJaxprTrace(level=5/1)>))))
kwds = {}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           ValueError: safe_map() argument 2 is shorter than argument 1
E           --------------------
E           For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
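The `safe_map()` error comes from a length-checked map. It is like the builtin `map`, but it refuses to silently truncate when the sequences it pairs up have different lengths. The sketch below is a simplified stand-in written for illustration, not JAX's own implementation; the name `safe_map` and its message are only modelled on the traceback above. It shows the kind of mismatch the error reports: the second sequence has fewer elements than the first.

```python
from typing import Any, Callable, Sequence


def safe_map(f: Callable[..., Any], first: Sequence[Any], *rest: Sequence[Any]) -> list[Any]:
    """Apply f elementwise across sequences, refusing to silently truncate.

    Illustrative stand-in only; this is not JAX's internal code.
    """
    n = len(first)
    # Positions count the sequences only (f excluded), starting from 1,
    # to mirror the wording of the error in the traceback.
    for position, seq in enumerate(rest, start=2):
        if len(seq) < n:
            raise ValueError(f"safe_map() argument {position} is shorter than argument 1")
        if len(seq) > n:
            raise ValueError(f"safe_map() argument {position} is longer than argument 1")
    return list(map(f, first, *rest))


# Pairing three names with only two values triggers the length check.
try:
    safe_map(lambda name, value: (name, value), ["a", "b", "c"], [1.0, 2.0])
except ValueError as err:
    print(err)  # safe_map() argument 2 is shorter than argument 1
```

The plain builtin `map` would instead stop after two pairs without complaint. The check exists to turn that silent truncation into a loud error.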
@patrick-kidger
Owner

Thanks for the report! I've reproduced this as an upstream bug in JAX: jax-ml/jax#21116

@GaetanLepage
Author

GaetanLepage commented May 8, 2024

Thank you for the quick response!
The following tests also fail with a similar error message:

=========================== short test summary info ============================
FAILED tests/test_ad.py::test_closure_convert_basic - assert [] == [140730519276576]
FAILED tests/test_debug.py::test_backward_nan - AssertionError: assert 'foo:\n   pri...pe=float32)\n' == 'foo:\n   pri...pe...
FAILED tests/test_noinline.py::test_abstract - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_num_traces - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_pytree_in - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_simple - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_mlp - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_vmap - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_jvp - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_grad - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_complicated - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
=========== 11 failed, 337 passed, 1 skipped, 39 warnings in 23.96s ============
E       TypeError: emit_python_callback() takes 6 positional arguments but 7 positional arguments (and 1 keyword-only argument) were given
E       --------------------
E       For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
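The `emit_python_callback()` failure is an ordinary Python `TypeError`. Its wording says the callee declares exactly six positional parameters and at least one keyword-only parameter, while the caller passes seven positional arguments plus one keyword-only argument. The failing tests are all in `tests/test_noinline.py`, which suggests (an assumption, not confirmed in this thread) that Equinox's `noinline` code still calls the JAX-internal function with an older argument list. The stub below is hypothetical and only reproduces the shape of the mismatch. It is not JAX's real signature.

```python
def upstream_stub(p1, p2, p3, p4, p5, p6, *, kw_only):
    """Hypothetical callee: six positional parameters plus one keyword-only one."""
    return None


try:
    # Hypothetical caller still passing a seventh positional argument.
    upstream_stub(1, 2, 3, 4, 5, 6, 7, kw_only="x")
except TypeError as err:
    # upstream_stub() takes 6 positional arguments but 7 positional arguments
    # (and 1 keyword-only argument) were given
    print(err)
```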

@GaetanLepage
Author

GaetanLepage commented May 8, 2024

Finally, one last test broke for me:

FAILED tests/test_debug.py::test_backward_nan - AssertionError: assert 'foo:\n   pri...pe=float32)\n' == 'foo:\n   pri...pe...
______________________________ test_backward_nan _______________________________
[gw13] linux -- Python 3.11.9 /nix/store/lpi16513bai8kg2bd841745vzk72475x-python3-3.11.9/bin/python3.11

capfd = <_pytest.capture.CaptureFixture object at 0x7ffff431d410>

    def test_backward_nan(capfd):
        @eqx.filter_custom_vjp
        def backward_nan(x):
            return x
    
        @backward_nan.def_fwd
        def backward_nan_fwd(perturbed, x):
            del perturbed
            return backward_nan(x), None
    
        @backward_nan.def_bwd
        def backward_nan_bwd(residual, grad_x, perturbed, x):
            del residual, grad_x, perturbed, x
            return jnp.nan
    
        @eqx.filter_jit
        @jax.grad
        def f(x, terminate):
            y = eqx.debug.backward_nan(x, name="foo", terminate=terminate)
            return backward_nan(y)
    
        capfd.readouterr()
        f(jnp.array(1.0), terminate=False)
        jax.effects_barrier()
        text, _ = capfd.readouterr()
>       assert (
            text
            == "foo:\n   primals=array(1., dtype=float32)\ncotangents=array(nan, dtype=float32)\n"  # noqa: E501
        )
E       AssertionError: assert 'foo:\n   pri...pe=float32)\n' == 'foo:\n   pri...pe=float32)\n'
E         
E           foo:
E         -    primals=array(1., dtype=float32)
E         ?            ^
E         +    primals=Array(1., dtype=float32)
E         ?            ^
E         - cotangents=array(nan, dtype=float32)...
E         
E         ...Full output truncated (3 lines hidden), use '-vv' to show

tests/test_debug.py:33: AssertionError

@patrick-kidger
Owner

Ah, it looks like a bit of other work might be required too. These other failures at least don't look too scary: I think they're just small perturbations from how things were before, which we can/should adjust for in Equinox.
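In the `test_backward_nan` diff above, the printed primal changed from `array(1., dtype=float32)` (a NumPy-style repr) to `Array(1., dtype=float32)` (apparently a JAX array repr). One way a test could absorb that kind of cosmetic perturbation is to normalise the repr prefix before comparing. This is a minimal sketch under that assumption and not necessarily what the eventual fix does. The helper `normalize_array_repr` is hypothetical, and only the primals line is used because the cotangents line is truncated in the log.

```python
import re


def normalize_array_repr(text: str) -> str:
    """Map the `Array(` repr prefix to NumPy's `array(` so both forms compare equal.

    Hypothetical test helper, for illustration only.
    """
    # \b requires a word boundary before "Array", so identifiers such as
    # "MyArray(" are left unchanged.
    return re.sub(r"\bArray\(", "array(", text)


old = "   primals=array(1., dtype=float32)"
new = "   primals=Array(1., dtype=float32)"
assert normalize_array_repr(new) == normalize_array_repr(old)
```

The trade-off is that a normalised comparison no longer catches a genuine change in what kind of array gets printed. Updating the expected string is the stricter alternative.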

@lockwo mentioned this issue May 8, 2024
@GaetanLepage
Author

With jax 0.4.28, the "ValueError: safe_map() argument 2 is shorter than argument 1" errors no longer occur.
However, these still happen:

E       TypeError: emit_python_callback() takes 6 positional arguments but 7 positional arguments (and 1 keyword-only argument) were given
E       --------------------
E       For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

in the following tests:

=========================== short test summary info ============================
FAILED tests/test_debug.py::test_backward_nan - AssertionError: assert 'foo:\n   pri...pe=float32)\n' == 'foo:\n   pri...pe...
FAILED tests/test_noinline.py::test_abstract - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_num_traces - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_simple - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_pytree_in - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_mlp - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_vmap - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_grad - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_jvp - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
FAILED tests/test_noinline.py::test_complicated - TypeError: emit_python_callback() takes 6 positional arguments but 7 positi...
=========== 10 failed, 436 passed, 1 skipped, 39 warnings in 47.15s ============
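A downstream packager tracking this thread could gate the affected tests on the installed JAX version. Here, the `safe_map` error was reported with jax 0.4.27 and was gone with 0.4.28. The sketch below parses only the leading numeric release segment using the standard library. The helper names are hypothetical, and the version bound encodes only what is reported in this thread.

```python
import re


def release_tuple(version: str) -> tuple[int, ...]:
    """Parse the leading numeric release segment, e.g. "0.4.28" -> (0, 4, 28)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def safe_map_bug_expected(jax_version: str) -> bool:
    """Per this thread: the safe_map ValueError appeared in 0.4.27, fixed in 0.4.28."""
    return release_tuple(jax_version) == (0, 4, 27)
```

In a pytest suite, this could feed `pytest.mark.xfail(safe_map_bug_expected(jax.__version__), reason=...)` on the affected tests. Pre-release suffixes such as `.dev…` are ignored by the parser, which may or may not be the desired behaviour.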

@patrick-kidger
Owner

Okay, I think things should be fixed with #719. I'll aim to do a new release shortly.

@GaetanLepage
Author

Cool! I'll test that as soon as the release is out.
Thanks @patrick-kidger !

GaetanLepage added a commit to GaetanLepage/nixpkgs that referenced this issue May 12, 2024
@GaetanLepage
Author

All good, thanks !
