From 9caf59d68be30c4a90d3518cc0d69fce9ef9e316 Mon Sep 17 00:00:00 2001 From: Selam Waktola Date: Mon, 6 May 2024 13:43:55 -0700 Subject: [PATCH] improve documentation for ix_ --- jax/_src/numpy/lax_numpy.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ffb031a5f2bc..dc5ea58f4154 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3158,9 +3158,39 @@ def _i0_jvp(primals, tangents): primal_out, tangent_out = jax.jvp(i0.fun, primals, tangents) return primal_out, where(primals[0] == 0, 0.0, tangent_out) - -@util.implements(np.ix_) def ix_(*args: ArrayLike) -> tuple[Array, ...]: + """Return a multi-dimensional grid (open mesh) from N one-dimensional sequences. + + JAX implementation of :func:`numpy.ix_`. + + Args: + *args: N one-dimensional arrays + + Returns: + Tuple of Jax arrays forming an open mesh, each with N dimensions. + + See Also: + - :obj:`jax.numpy.ogrid` + - :obj:`jax.numpy.mgrid` + - :func:`jax.numpy.meshgrid` + + Example: + >>> rows = jnp.array([0, 2]) + >>> cols = jnp.array([1, 3]) + >>> open_mesh = jnp.ix_(rows, cols) + >>> open_mesh + (Array([[0], + [2]], dtype=int32), Array([[1, 3]], dtype=int32)) + >>> [grid.shape for grid in open_mesh] + [(2, 1), (1, 2)] + >>> x = jnp.array([[10, 20, 30, 40], + ... [50, 60, 70, 80], + ... [90, 100, 110, 120], + ... [130, 140, 150, 160]]) + >>> x[open_mesh] + Array([[ 20, 40], + [100, 120]], dtype=int32) + """ util.check_arraylike("ix", *args) n = len(args) output = []