Skip to content

Commit

Permalink
Improve docs for jnp.angle and jnp.flip
Browse files Browse the repository at this point in the history
  • Loading branch information
rajasekharporeddy committed Jun 7, 2024
1 parent 55d0f5e commit 6d94ae3
Showing 1 changed file with 93 additions and 2 deletions.
95 changes: 93 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,62 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
return flip(transpose(m, perm), ax2)


@util.implements(np.flip, lax_description=_ARRAY_VIEW_DOC)
def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
"""Reverse the order of elements of an array along the given axis.
JAX implementation of :func:`numpy.flip`.
Args:
m: Array.
axis: integer or sequence of integers. Specifies along which axis or axes
should the array elements be reversed. Default is ``None``, which flips
along all axes.
Returns:
An array with the elements in reverse order along ``axis``.
See Also:
- :func:`jax.numpy.fliplr`: reverse the order along axis 1 (left/right)
- :func:`jax.numpy.flipud`: reverse the order along axis 0 (up/down)
Example:
>>> x1 = jnp.array([[1, 2],
... [3, 4]])
>>> jnp.flip(x1)
Array([[4, 3],
[2, 1]], dtype=int32)
If ``axis`` is specified with an integer, then ``jax.numpy.flip`` reverses
the array along that particular axis only.
>>> jnp.flip(x1, axis=1)
Array([[2, 1],
[4, 3]], dtype=int32)
>>> x2 = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x2
Array([[[1, 2],
[3, 4]],
<BLANKLINE>
[[5, 6],
[7, 8]]], dtype=int32)
>>> jnp.flip(x2)
Array([[[8, 7],
[6, 5]],
<BLANKLINE>
[[4, 3],
[2, 1]]], dtype=int32)
When ``axis`` is specified with a sequence of integers, then
``jax.numpy.flip`` reverses the array along the specified axes.
>>> jnp.flip(x2, axis=[1, 2])
Array([[[4, 3],
[2, 1]],
<BLANKLINE>
[[8, 7],
[6, 5]]], dtype=int32)
"""
util.check_arraylike("flip", m)
return _flip(asarray(m), reductions._ensure_optional_axes(axis))

Expand Down Expand Up @@ -752,9 +806,46 @@ def isreal(x: ArrayLike) -> Array:
i = ufuncs.imag(x)
return lax.eq(i, _lax_const(i, 0))

@util.implements(np.angle)

@partial(jit, static_argnames=['deg'])
def angle(z: ArrayLike, deg: bool = False) -> Array:
"""Return the angle of a complex valued number or array.
JAX implementation of :func:`numpy.angle`.
Args:
z: A complex number or an array of complex numbers.
deg: Boolean. If ``True``, returns the result in degrees else returns
in radians. Default is ``False``.
Returns:
An array of counterclockwise angle of each element of ``z``, with the same
shape as ``z`` of dtype float.
Example:
If ``z`` is a number
>>> z1 = 2+3j
>>> jnp.angle(z1)
Array(0.98279375, dtype=float32, weak_type=True)
If ``z`` is an array
>>> z2 = jnp.array([[1+3j, 2-5j],
... [4-3j, 3+2j]])
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.angle(z2))
[[ 1.25 -1.19]
[-0.64 0.59]]
If ``deg=True``.
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.angle(z2, deg=True))
[[ 71.57 -68.2 ]
[-36.87 33.69]]
"""
re = ufuncs.real(z)
im = ufuncs.imag(z)
dtype = _dtype(re)
Expand Down

0 comments on commit 6d94ae3

Please sign in to comment.