Skip to content

Commit

Permalink
jnp.sign: use x/abs(x) for complex arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 17, 2024
1 parent 08837a9 commit dbefbdc
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
13 changes: 1 addition & 12 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def _arccosh(x: ArrayLike, /) -> Array:
arccosh = _one_to_one_unop(np.arccosh, _arccosh, True)
tanh = _one_to_one_unop(np.tanh, lax.tanh, True)
arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True)
sign = _one_to_one_unop(np.sign, lax.sign)
sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True)
cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True)

Expand Down Expand Up @@ -257,18 +258,6 @@ def rint(x: ArrayLike, /) -> Array:
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)


@_wraps(np.sign, module='numpy')
@jit
def sign(x: ArrayLike, /) -> Array:
check_arraylike('sign', x)
dtype = dtypes.dtype(x)
if dtypes.issubdtype(dtype, np.complexfloating):
re = lax.real(x)
return lax.complex(
lax.sign(_where(re != 0, re, lax.imag(x))), _constant_like(re, 0))
return lax.sign(x)


@_wraps(np.copysign, module='numpy')
@jit
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
Expand Down
21 changes: 20 additions & 1 deletion tests/lax_numpy_operators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
default_dtypes = float_dtypes + int_dtypes
inexact_dtypes = float_dtypes + complex_dtypes
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
real_dtypes = float_dtypes + int_dtypes + unsigned_dtypes
all_dtypes = number_dtypes + bool_dtypes


Expand Down Expand Up @@ -272,7 +273,9 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
[]),
op_record("rint", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_default, [], check_dtypes=False),
op_record("sign", 1, number_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
# numpy < 2.0.0 has a different convention for complex sign.
op_record("sign", 1, real_dtypes if jtu.numpy_version() < (2, 0, 0) else number_dtypes,
all_shapes, jtu.rand_some_inf_and_nan, []),
# numpy 1.16 has trouble mixing uint and bfloat16, so we test these separately.
op_record("copysign", 2, default_dtypes + unsigned_dtypes,
all_shapes, jtu.rand_some_inf_and_nan, [], check_dtypes=False),
Expand Down Expand Up @@ -646,6 +649,22 @@ def testShiftOpAgainstNumpy(self, op, dtypes, shapes):
self._CompileAndCheck(op, args_maker)
self._CheckAgainstNumpy(np_op, op, args_maker)

# This test can be deleted once we test against NumPy 2.0.
@jtu.sample_product(
shape=all_shapes,
dtype=complex_dtypes
)
def testSignComplex(self, shape, dtype):
rng = jtu.rand_default(self.rng())
if jtu.numpy_version() >= (2, 0, 0):
np_fun = np.sign
else:
np_fun = lambda x: x / np.where(x == 0, 1, abs(x))
jnp_fun = jnp.sign
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

def testDeferToNamedTuple(self):
class MyArray(NamedTuple):
arr: jax.Array
Expand Down

0 comments on commit dbefbdc

Please sign in to comment.