Skip to content

Commit

Permalink
Merge pull request #19293 from jakevdp:jnp-pow
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 597371608
  • Loading branch information
jax authors committed Jan 10, 2024
2 parents 4e5430d + 1a39d8f commit bbf2ab0
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ namespace; they are listed below.
polysub
polyval
positive
pow
power
printoptions
prod
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
# Handle cases #2 and #3 under a jit:
return _power(x1, x2)

# Array API alias
pow = power

@partial(jit, inline=True)
def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
x1, x2 = promote_shapes("power", x1, x2) # not dtypes
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/array_api/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def positive(x, /):
def pow(x1, x2, /):
"""Calculates an implementation-dependent approximation of exponentiation by raising each element x1_i (the base) of the input array x1 to the power of x2_i (the exponent), where x2_i is the corresponding element of the input array x2."""
x1, x2 = _promote_dtypes("pow", x1, x2)
return jax.numpy.power(x1, x2)
return jax.numpy.pow(x1, x2)


def real(x, /):
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@
nextafter as nextafter,
not_equal as not_equal,
positive as positive,
pow as pow,
power as power,
rad2deg as rad2deg,
radians as radians,
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = ...) ->
def polysub(a1: Array, a2: Array) -> Array: ...
def polyval(p: Array, x: Array, *, unroll: int = ...) -> Array: ...
def positive(x: ArrayLike, /) -> Array: ...
def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def power(x: ArrayLike, y: ArrayLike, /) -> Array: ...
printoptions = _np.printoptions
def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
Expand Down
6 changes: 4 additions & 2 deletions tests/lax_numpy_operators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
tolerance={dtypes.bfloat16: 4e-2, np.float16: 2e-2,
np.float64: 1e-12}),
op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("pow", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
tolerance={np.complex128: 1e-14}, check_dtypes=False, alias="power"),
op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
tolerance={np.complex128: 1e-14}, check_dtypes=False),
op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default, []),
Expand Down Expand Up @@ -460,8 +462,8 @@ def testOp(self, op_name, rng_factory, shapes, dtypes, check_dtypes,
tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes)
if jtu.test_device_matches(["tpu"]) and op_name in (
"arccosh", "arcsinh", "sinh", "cosh", "tanh", "sin", "cos", "tan",
"log", "log1p", "log2", "log10", "exp", "expm1", "exp2", "power",
"logaddexp", "logaddexp2", "i0", "acosh", "asinh"):
"log", "log1p", "log2", "log10", "exp", "expm1", "exp2", "pow",
"power", "logaddexp", "logaddexp2", "i0", "acosh", "asinh"):
tol = jtu.join_tolerance(tol, 1e-4)
tol = functools.reduce(jtu.join_tolerance,
[tolerance, tol, jtu.default_tolerance()])
Expand Down
3 changes: 1 addition & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5651,8 +5651,7 @@ def testWrappedSignaturesMatch(self):
_available_numpy_dtypes: list[str] = [dtype.__name__ for dtype in jtu.dtypes.all
if dtype != dtypes.bfloat16]

# TODO(jakevdp): implement missing ufuncs
UNIMPLEMENTED_UFUNCS = {'spacing', 'pow'}
UNIMPLEMENTED_UFUNCS = {'spacing'}


def _all_numpy_ufuncs() -> Iterator[str]:
Expand Down

0 comments on commit bbf2ab0

Please sign in to comment.