diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index ceb45f1898ca..40032e285f54 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -173,6 +173,7 @@ def _arccosh(x: ArrayLike, /) -> Array: arccosh = _one_to_one_unop(np.arccosh, _arccosh, True) tanh = _one_to_one_unop(np.tanh, lax.tanh, True) arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True) +sign = _one_to_one_unop(np.sign, lax.sign) sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True) cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True) @@ -257,18 +258,6 @@ def rint(x: ArrayLike, /) -> Array: return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) -@_wraps(np.sign, module='numpy') -@jit -def sign(x: ArrayLike, /) -> Array: - check_arraylike('sign', x) - dtype = dtypes.dtype(x) - if dtypes.issubdtype(dtype, np.complexfloating): - re = lax.real(x) - return lax.complex( - lax.sign(_where(re != 0, re, lax.imag(x))), _constant_like(re, 0)) - return lax.sign(x) - - @_wraps(np.copysign, module='numpy') @jit def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index c16dbb467c47..29f41a6201f3 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -56,6 +56,7 @@ default_dtypes = float_dtypes + int_dtypes inexact_dtypes = float_dtypes + complex_dtypes number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes +real_dtypes = float_dtypes + int_dtypes + unsigned_dtypes all_dtypes = number_dtypes + bool_dtypes @@ -272,7 +273,9 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, []), op_record("rint", 1, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False), - op_record("sign", 1, number_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + # numpy < 2.0.0 has a different convention for complex sign. + op_record("sign", 1, real_dtypes if jtu.numpy_version() < (2, 0, 0) else number_dtypes, + all_shapes, jtu.rand_some_inf_and_nan, []), # numpy 1.16 has trouble mixing uint and bfloat16, so we test these separately. op_record("copysign", 2, default_dtypes + unsigned_dtypes, all_shapes, jtu.rand_some_inf_and_nan, [], check_dtypes=False), @@ -646,6 +649,22 @@ def testShiftOpAgainstNumpy(self, op, dtypes, shapes): self._CompileAndCheck(op, args_maker) self._CheckAgainstNumpy(np_op, op, args_maker) + # This test can be deleted once we test against NumPy 2.0. + @jtu.sample_product( + shape=all_shapes, + dtype=complex_dtypes + ) + def testSignComplex(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + if jtu.numpy_version() >= (2, 0, 0): + np_fun = np.sign + else: + np_fun = lambda x: x / np.where(x == 0, 1, abs(x)) + jnp_fun = jnp.sign + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + def testDeferToNamedTuple(self): class MyArray(NamedTuple): arr: jax.Array