diff --git a/tests/lax_test.py b/tests/lax_test.py index 80a264d01fd9..a3980e0560c5 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3405,7 +3405,6 @@ def testOnComplexPlane(self, name, dtype, kind): + (['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_cuda else []) + (['q1', 'q2', 'q3', 'q4'] if is_cpu and dtype == np.complex128 else [])), sinc = ['q1', 'q2', 'q3', 'q4'], - sign = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'], arcsin = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], arccos = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], arctan = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], @@ -3417,6 +3416,9 @@ def testOnComplexPlane(self, name, dtype, kind): expm1 = ['q1', 'q4', 'pinf'] if is_arm_cpu and dtype != np.complex128 else [], ) + if jtu.numpy_version() < (2, 0, 0): + regions_with_inaccuracies['sign'] = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'] + jnp_op = getattr(jnp, name) if name == 'square':