diff --git a/examples/examples_test.py b/examples/examples_test.py index c9cb2991c030..fd705a4ef799 100644 --- a/examples/examples_test.py +++ b/examples/examples_test.py @@ -26,6 +26,9 @@ from jax import lax from jax import random import jax.numpy as jnp +from jax._src import test_util as jtu + +del jtu # Needed for flags sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from examples import kernel_lsq