diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py index 0bc456fe22..29dfd152e3 100644 --- a/tests/link/jax/test_tensor_basic.py +++ b/tests/link/jax/test_tensor_basic.py @@ -176,7 +176,7 @@ def test_jax_split_not_supported(self): UserWarning, match="Split node does not have constant split positions." ): fn = pytensor.function([a], a_splits, mode="JAX") - # It raises an informative ConcretizationTypeError, but there's an AttributeError that surpsasses it + # It raises an informative ConcretizationTypeError, but there's an AttributeError that surpasses it with pytest.raises(AttributeError): fn(np.zeros((6, 4), dtype=pytensor.config.floatX)) @@ -184,7 +184,9 @@ def test_jax_split_not_supported(self): a_splits = at.split(a, splits_size=[2, 4], n_splits=2, axis=split_axis) with pytest.warns(UserWarning, match="Split node does not have constant axis."): fn = pytensor.function([a, split_axis], a_splits, mode="JAX") - with pytest.raises(jax.errors.TracerIntegerConversionError): + # Same as above, an AttributeError surpasses the `TracerIntegerConversionError` + # Both errors are included for backwards compatibility + with pytest.raises((AttributeError, jax.errors.TracerIntegerConversionError)): fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0)