From 43ed74f817196bd3732d1cbbc7db6279f7e2aec8 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 30 Nov 2023 13:53:13 -0800 Subject: [PATCH] rewrite test not to include float0 broadcast --- tests/api_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/api_test.py b/tests/api_test.py index 8d301c742318..d97b3c731d9f 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5519,7 +5519,7 @@ def test_vjp_caching(self): def test_vjp_caching_static_argnums(self): identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x), static_argnums=(1,)) - _, f_vjp = jax.vjp(identity, 1., True) + _, f_vjp = jax.vjp(lambda x: identity(x, True), 1.) with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 for _ in range(20): f_vjp(1.)[0].block_until_ready()