diff --git a/tests/infeed_test.py b/tests/infeed_test.py index 37592d52fa49..572920fa4d3b 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -78,7 +78,7 @@ def f(x): @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. def testInfeedThenOutfeed(self): - hcb.stop_outfeed_receiver() + hcb._deprecated_stop_outfeed_receiver() @jax.jit def f(x): @@ -100,7 +100,7 @@ def f(x): self.assertAllClose(out, y + np.float32(1)) def testInfeedThenOutfeedInALoop(self): - hcb.stop_outfeed_receiver() + hcb._deprecated_stop_outfeed_receiver() def doubler(_, token): y, token = lax.infeed(