From 6d94ae32745e7353d1813264cab678ac044d2bf8 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Fri, 7 Jun 2024 10:03:07 +0530 Subject: [PATCH] Improve docs for jnp.angle and jnp.flip --- jax/_src/numpy/lax_numpy.py | 95 ++++++++++++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 2e222bf6c612..9b50117332f7 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -716,8 +716,62 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: return flip(transpose(m, perm), ax2) -@util.implements(np.flip, lax_description=_ARRAY_VIEW_DOC) def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: + """Reverse the order of elements of an array along the given axis. + + JAX implementation of :func:`numpy.flip`. + + Args: + m: Array. + axis: integer or sequence of integers. Specifies along which axis or axes + should the array elements be reversed. Default is ``None``, which flips + along all axes. + + Returns: + An array with the elements in reverse order along ``axis``. + + See Also: + - :func:`jax.numpy.fliplr`: reverse the order along axis 1 (left/right) + - :func:`jax.numpy.flipud`: reverse the order along axis 0 (up/down) + + Example: + >>> x1 = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.flip(x1) + Array([[4, 3], + [2, 1]], dtype=int32) + + If ``axis`` is specified with an integer, then ``jax.numpy.flip`` reverses + the array along that particular axis only. + + >>> jnp.flip(x1, axis=1) + Array([[2, 1], + [4, 3]], dtype=int32) + + >>> x2 = jnp.arange(1, 9).reshape(2, 2, 2) + >>> x2 + Array([[[1, 2], + [3, 4]], + + [[5, 6], + [7, 8]]], dtype=int32) + >>> jnp.flip(x2) + Array([[[8, 7], + [6, 5]], + + [[4, 3], + [2, 1]]], dtype=int32) + + When ``axis`` is specified with a sequence of integers, then + ``jax.numpy.flip`` reverses the array along the specified axes. + + >>> jnp.flip(x2, axis=[1, 2]) + Array([[[4, 3], + [2, 1]], + + [[8, 7], + [6, 5]]], dtype=int32) + """ util.check_arraylike("flip", m) return _flip(asarray(m), reductions._ensure_optional_axes(axis)) @@ -752,9 +806,46 @@ def isreal(x: ArrayLike) -> Array: i = ufuncs.imag(x) return lax.eq(i, _lax_const(i, 0)) -@util.implements(np.angle) + @partial(jit, static_argnames=['deg']) def angle(z: ArrayLike, deg: bool = False) -> Array: + """Return the angle of a complex valued number or array. + + JAX implementation of :func:`numpy.angle`. + + Args: + z: A complex number or an array of complex numbers. + deg: Boolean. If ``True``, returns the result in degrees else returns + in radians. Default is ``False``. + + Returns: + An array of counterclockwise angle of each element of ``z``, with the same + shape as ``z`` of dtype float. + + Example: + + If ``z`` is a number + + >>> z1 = 2+3j + >>> jnp.angle(z1) + Array(0.98279375, dtype=float32, weak_type=True) + + If ``z`` is an array + + >>> z2 = jnp.array([[1+3j, 2-5j], + ... [4-3j, 3+2j]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.angle(z2)) + [[ 1.25 -1.19] + [-0.64 0.59]] + + If ``deg=True``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.angle(z2, deg=True)) + [[ 71.57 -68.2 ] + [-36.87 33.69]] + """ re = ufuncs.real(z) im = ufuncs.imag(z) dtype = _dtype(re)