From bd5e9bef33eb54a61c1d2bdfbf4e85563607f733 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 29 Jan 2024 09:21:23 -0800 Subject: [PATCH] testOgrid: make test compatible with NumPy 2.0 --- tests/lax_numpy_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 91bf646a5f0b..bdfee5958f4d 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4803,9 +4803,9 @@ def testMgrid(self): def testOgrid(self): # wrap indexer for appropriate dtype defaults. np_ogrid = _indexer_with_default_outputs(np.ogrid) - def assertListOfArraysEqual(xs, ys): - self.assertIsInstance(xs, list) - self.assertIsInstance(ys, list) + def assertSequenceOfArraysEqual(xs, ys): + self.assertIsInstance(xs, (list, tuple)) + self.assertIsInstance(ys, (list, tuple)) self.assertEqual(len(xs), len(ys)) for x, y in zip(xs, ys): self.assertArraysEqual(x, y) @@ -4814,10 +4814,10 @@ def assertListOfArraysEqual(xs, ys): self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: jnp.ogrid[:5])()) self.assertArraysEqual(np_ogrid[1:7:2], jnp.ogrid[1:7:2]) # List of arrays - assertListOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,]) - assertListOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3]) - assertListOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3]) - assertListOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11]) + assertSequenceOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,]) + assertSequenceOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3]) + assertSequenceOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3]) + assertSequenceOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11]) # Corner cases self.assertArraysEqual(np_ogrid[:], jnp.ogrid[:]) # Complex number steps