From 89f530aff34d883b5d3c497a630cd29051f278d8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 10 Jan 2024 13:03:36 -0800 Subject: [PATCH] [array API] implement jnp.pow; alias for jnp.power --- docs/jax.numpy.rst | 1 + jax/_src/numpy/ufuncs.py | 3 +++ jax/experimental/array_api/_elementwise_functions.py | 2 +- jax/numpy/__init__.py | 1 + jax/numpy/__init__.pyi | 1 + tests/lax_numpy_operators_test.py | 6 ++++-- tests/lax_numpy_test.py | 2 +- 7 files changed, 12 insertions(+), 4 deletions(-) diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 4f94b0672b41..b9f0e4c3435a 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -323,6 +323,7 @@ namespace; they are listed below. polysub polyval positive + pow power printoptions prod diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 275b2575d60d..076e8146d5bc 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -355,6 +355,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: # Handle cases #2 and #3 under a jit: return _power(x1, x2) +# Array API alias +pow = power + @partial(jit, inline=True) def _power(x1: ArrayLike, x2: ArrayLike) -> Array: x1, x2 = promote_shapes("power", x1, x2) # not dtypes diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py index 373d29098a16..1b68f7aee955 100644 --- a/jax/experimental/array_api/_elementwise_functions.py +++ b/jax/experimental/array_api/_elementwise_functions.py @@ -311,7 +311,7 @@ def positive(x, /): def pow(x1, x2, /): """Calculates an implementation-dependent approximation of exponentiation by raising each element x1_i (the base) of the input array x1 to the power of x2_i (the exponent), where x2_i is the corresponding element of the input array x2.""" x1, x2 = _promote_dtypes("pow", x1, x2) - return jax.numpy.power(x1, x2) + return jax.numpy.pow(x1, x2) def real(x, /): diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index d623d7cdaa36..6b9e85e5613d 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -403,6 +403,7 @@ nextafter as nextafter, not_equal as not_equal, positive as positive, + pow as pow, power as power, rad2deg as rad2deg, radians as radians, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 9b574fafb39f..ff539d557c9c 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -630,6 +630,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = ...) -> def polysub(a1: Array, a2: Array) -> Array: ... def polyval(p: Array, x: Array, *, unroll: int = ...) -> Array: ... def positive(x: ArrayLike, /) -> Array: ... +def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ... def power(x: ArrayLike, y: ArrayLike, /) -> Array: ... printoptions = _np.printoptions def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index 8b9f55e0dbfc..c42c5da23cca 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -255,6 +255,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, tolerance={dtypes.bfloat16: 4e-2, np.float16: 2e-2, np.float64: 1e-12}), op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("pow", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], + tolerance={np.complex128: 1e-14}, check_dtypes=False, alias="power"), op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], tolerance={np.complex128: 1e-14}, check_dtypes=False), op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default, []), @@ -458,8 +460,8 @@ def testOp(self, op_name, rng_factory, shapes, dtypes, check_dtypes, tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes) if jtu.test_device_matches(["tpu"]) and op_name in ( "arccosh", "arcsinh", "sinh", "cosh", "tanh", "sin", "cos", "tan", - "log", "log1p", "log2", "log10", "exp", "expm1", "exp2", "power", - "logaddexp", "logaddexp2", "i0", "acosh", "asinh"): + "log", "log1p", "log2", "log10", "exp", "expm1", "exp2", "pow", + "power", "logaddexp", "logaddexp2", "i0", "acosh", "asinh"): tol = jtu.join_tolerance(tol, 1e-4) tol = functools.reduce(jtu.join_tolerance, [tolerance, tol, jtu.default_tolerance()]) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 7d0537b21530..0ed7fd5501a0 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5615,7 +5615,7 @@ def testWrappedSignaturesMatch(self): # TODO(jakevdp): implement missing ufuncs UNIMPLEMENTED_UFUNCS = {'spacing', 'bitwise_invert', 'bitwise_left_shift', - 'bitwise_right_shift', 'pow', 'vecdot'} + 'bitwise_right_shift', 'vecdot'} def _all_numpy_ufuncs() -> Iterator[str]: