Skip to content

Commit

Permalink
Merge pull request #18754 from mattjj:fix-float0
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586781308
  • Loading branch information
jax authors committed Nov 30, 2023
2 parents 5b3fc1b + 43ed74f commit efb4924
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5519,7 +5519,7 @@ def test_vjp_caching(self):
def test_vjp_caching_static_argnums(self):
identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
static_argnums=(1,))
_, f_vjp = jax.vjp(identity, 1., True)
_, f_vjp = jax.vjp(lambda x: identity(x, True), 1.)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
for _ in range(20):
f_vjp(1.)[0].block_until_ready()
Expand Down

0 comments on commit efb4924

Please sign in to comment.