diff --git a/tests/api_test.py b/tests/api_test.py index 8d301c742318..d97b3c731d9f 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5519,7 +5519,7 @@ def test_vjp_caching(self): def test_vjp_caching_static_argnums(self): identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x), static_argnums=(1,)) - _, f_vjp = jax.vjp(identity, 1., True) + _, f_vjp = jax.vjp(lambda x: identity(x, True), 1.) with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 for _ in range(20): f_vjp(1.)[0].block_until_ready()