Skip to content

Commit

Permalink
Merge pull request #24452 from jakevdp:insert-doc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688624762
  • Loading branch information
Google-ML-Automation committed Oct 22, 2024
2 parents 1a2737b + 48dd153 commit 1e41d5e
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8307,6 +8307,9 @@ def delete(
may specify ``assume_unique_indices=True`` to perform the operation in a
manner that does not require static indices.
See also:
- :func:`jax.numpy.insert`: insert entries into an array.
Examples:
Delete entries from a 1D array:
Expand Down Expand Up @@ -8400,9 +8403,55 @@ def delete(
return a[tuple(slice(None) for i in range(axis)) + (mask,)]


@util.implements(np.insert)
def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike,
axis: int | None = None) -> Array:
"""Insert entries into an array at specified indices.
JAX implementation of :func:`numpy.insert`.
Args:
arr: array object into which values will be inserted.
obj: slice or array of indices specifying insertion locations.
values: array of values to be inserted.
axis: specify the insertion axis in the case of multi-dimensional
arrays. If unspecified, ``arr`` will be flattened.
Returns:
A copy of ``arr`` with values inserted at the specified locations.
See also:
- :func:`jax.numpy.delete`: delete entries from an array.
Examples:
Inserting a single value:
>>> x = jnp.arange(5)
>>> jnp.insert(x, 2, 99)
Array([ 0, 1, 99, 2, 3, 4], dtype=int32)
Inserting multiple identical values using a slice:
>>> jnp.insert(x, slice(None, None, 2), -1)
Array([-1, 0, 1, -1, 2, 3, -1, 4], dtype=int32)
Inserting multiple values using an index:
>>> indices = jnp.array([4, 2, 5])
>>> values = jnp.array([10, 11, 12])
>>> jnp.insert(x, indices, values)
Array([ 0, 1, 11, 2, 3, 10, 4, 12], dtype=int32)
Inserting columns into a 2D array:
>>> x = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> indices = jnp.array([1, 3])
>>> values = jnp.array([[10, 11],
... [12, 13]])
>>> jnp.insert(x, indices, values, axis=1)
Array([[ 1, 10, 2, 3, 11],
[ 4, 12, 5, 6, 13]], dtype=int32)
"""
util.check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values)
a = asarray(arr)
values_arr = asarray(values)
Expand Down

0 comments on commit 1e41d5e

Please sign in to comment.