diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index a83804dd3d3c..026f6aa3e332 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -326,6 +326,7 @@ namespace; they are listed below. polysub polyval positive + pow power printoptions prod diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 46e018114fbc..ceb45f1898ca 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -365,6 +365,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: # Handle cases #2 and #3 under a jit: return _power(x1, x2) +# Array API alias +pow = power + @partial(jit, inline=True) def _power(x1: ArrayLike, x2: ArrayLike) -> Array: x1, x2 = promote_shapes("power", x1, x2) # not dtypes diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py index a3084473bfd7..0442cef7b93b 100644 --- a/jax/experimental/array_api/_elementwise_functions.py +++ b/jax/experimental/array_api/_elementwise_functions.py @@ -311,7 +311,7 @@ def positive(x, /): def pow(x1, x2, /): """Calculates an implementation-dependent approximation of exponentiation by raising each element x1_i (the base) of the input array x1 to the power of x2_i (the exponent), where x2_i is the corresponding element of the input array x2.""" x1, x2 = _promote_dtypes("pow", x1, x2) - return jax.numpy.power(x1, x2) + return jax.numpy.pow(x1, x2) def real(x, /): diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index f71770f0a5d2..b305b8a5e2c0 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -407,6 +407,7 @@ nextafter as nextafter, not_equal as not_equal, positive as positive, + pow as pow, power as power, rad2deg as rad2deg, radians as radians, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 7b572f43c26d..8b8d0be502a9 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -633,6 +633,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = ...) -> def polysub(a1: Array, a2: Array) -> Array: ... def polyval(p: Array, x: Array, *, unroll: int = ...) -> Array: ... def positive(x: ArrayLike, /) -> Array: ... +def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ... def power(x: ArrayLike, y: ArrayLike, /) -> Array: ... printoptions = _np.printoptions def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index a169f6ffd695..c16dbb467c47 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -255,6 +255,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, tolerance={dtypes.bfloat16: 4e-2, np.float16: 2e-2, np.float64: 1e-12}), op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("pow", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], + tolerance={np.complex128: 1e-14}, check_dtypes=False, alias="power"), op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], tolerance={np.complex128: 1e-14}, check_dtypes=False), op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default, []), @@ -460,8 +462,8 @@ def testOp(self, op_name, rng_factory, shapes, dtypes, check_dtypes, tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes) if jtu.test_device_matches(["tpu"]) and op_name in ( "arccosh", "arcsinh", "sinh", "cosh", "tanh", "sin", "cos", "tan", - "log", "log1p", "log2", "log10", "exp", "expm1", "exp2", "power", - "logaddexp", "logaddexp2", "i0", "acosh", "asinh"): + "log", "log1p", "log2", "log10", "exp", "expm1", "exp2", "pow", + "power", "logaddexp", "logaddexp2", "i0", "acosh", "asinh"): tol = jtu.join_tolerance(tol, 1e-4) tol = functools.reduce(jtu.join_tolerance, [tolerance, tol, jtu.default_tolerance()]) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index d6d4489f00c1..6f725d93483e 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5651,8 +5651,7 @@ def testWrappedSignaturesMatch(self): _available_numpy_dtypes: list[str] = [dtype.__name__ for dtype in jtu.dtypes.all if dtype != dtypes.bfloat16] -# TODO(jakevdp): implement missing ufuncs -UNIMPLEMENTED_UFUNCS = {'spacing', 'pow'} +UNIMPLEMENTED_UFUNCS = {'spacing'} def _all_numpy_ufuncs() -> Iterator[str]: