Skip to content

Commit

Permalink
Fix failing test due to change in JAX behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and twiecki committed Jun 30, 2023
1 parent ce0b503 commit 057afed
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/link/jax/test_tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,17 @@ def test_jax_split_not_supported(self):
UserWarning, match="Split node does not have constant split positions."
):
fn = pytensor.function([a], a_splits, mode="JAX")
# It raises an informative ConcretizationTypeError, but there's an AttributeError that surpsasses it
# It raises an informative ConcretizationTypeError, but there's an AttributeError that surpasses it
with pytest.raises(AttributeError):
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))

split_axis = iscalar("split_axis")
a_splits = at.split(a, splits_size=[2, 4], n_splits=2, axis=split_axis)
with pytest.warns(UserWarning, match="Split node does not have constant axis."):
fn = pytensor.function([a, split_axis], a_splits, mode="JAX")
with pytest.raises(jax.errors.TracerIntegerConversionError):
# Same as above, an AttributeError surpasses the `TracerIntegerConversionError`
# Both errors are included for backwards compatibility
with pytest.raises((AttributeError, jax.errors.TracerIntegerConversionError)):
fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0)


Expand Down

0 comments on commit 057afed

Please sign in to comment.