Skip to content

Commit

Permalink
Merge pull request #20322 from jakevdp:complex-plane-numpy2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617265717
  • Loading branch information
jax authors committed Mar 19, 2024
2 parents cc06836 + 9062cfb commit 6fa75aa
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3405,7 +3405,6 @@ def testOnComplexPlane(self, name, dtype, kind):
+ (['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_cuda else [])
+ (['q1', 'q2', 'q3', 'q4'] if is_cpu and dtype == np.complex128 else [])),
sinc = ['q1', 'q2', 'q3', 'q4'],
sign = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'],
arcsin = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
arccos = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
arctan = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
Expand All @@ -3417,6 +3416,9 @@ def testOnComplexPlane(self, name, dtype, kind):
expm1 = ['q1', 'q4', 'pinf'] if is_arm_cpu and dtype != np.complex128 else [],
)

if jtu.numpy_version() < (2, 0, 0):
regions_with_inaccuracies['sign'] = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj']

jnp_op = getattr(jnp, name)

if name == 'square':
Expand Down

0 comments on commit 6fa75aa

Please sign in to comment.