Skip to content

Commit

Permalink
Improve documentation for jnp.put, jnp.place, jnp.fill_diagonal
Browse files Browse the repository at this point in the history
These are all the APIs that have an inplace parameter
  • Loading branch information
jakevdp committed Jul 29, 2024
1 parent e78e643 commit cfa1e78
Showing 1 changed file with 183 additions and 34 deletions.
217 changes: 183 additions & 34 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5222,18 +5222,71 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
return tril_indices(arr_shape[0], k=k, m=arr_shape[1])


@util.implements(np.fill_diagonal, lax_description="""
The semantics of :func:`numpy.fill_diagonal` is to modify arrays in-place, which
JAX cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.fill_diagonal`
adds the ``inplace`` parameter, which must be set to ``False`` by the user as a
reminder of this API difference.
""", extra_params="""
inplace : bool, default=True
If left to its default value of True, JAX will raise an error. This is because
the semantics of :func:`numpy.fill_diagonal` are to modify the array in-place,
which is not possible in JAX due to the immutability of JAX arrays.
""")
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace: bool = True) -> Array:
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *,
inplace: bool = True) -> Array:
"""Return a copy of the array with the diagonal overwritten.
JAX implementation of :func:`numpy.fill_diagonal`.
The semantics of :func:`numpy.fill_diagonal` are to modify arrays in-place, which
is not possible for JAX's immutable arrays. The JAX version returns a modified
copy of the input, and adds the ``inplace`` parameter which must be set to
`False`` by the user as a reminder of this API difference.
Args:
a: input array. Must have ``a.ndim >= 2``. If ``a.ndim >= 3``, then all
dimensions must be the same size.
val: scalar or array with which to fill the diagonal. If an array, it will
be flattened and repeated to fill the diagonal entries.
inplace: must be set to False to indicate that the input is not modified
in-place, but rather a modified copy is returned.
Returns:
A copy of ``a`` with the diagonal set to ``val``.
Examples:
>>> x = jnp.zeros((3, 3), dtype=int)
>>> jnp.fill_diagonal(x, jnp.array([1, 2, 3]), inplace=False)
Array([[1, 0, 0],
[0, 2, 0],
[0, 0, 3]], dtype=int32)
Unlike :func:`numpy.fill_diagonal`, the input ``x`` is not modified.
If the diagonal value has too many entries, it will be truncated
>>> jnp.fill_diagonal(x, jnp.arange(100, 200), inplace=False)
Array([[100, 0, 0],
[ 0, 101, 0],
[ 0, 0, 102]], dtype=int32)
If the diagonal has too few entries, it will be repeated:
>>> x = jnp.zeros((4, 4), dtype=int)
>>> jnp.fill_diagonal(x, jnp.array([3, 4]), inplace=False)
Array([[3, 0, 0, 0],
[0, 4, 0, 0],
[0, 0, 3, 0],
[0, 0, 0, 4]], dtype=int32)
For non-square arrays, the diagonal of the leading square slice is filled:
>>> x = jnp.zeros((3, 5), dtype=int)
>>> jnp.fill_diagonal(x, 1, inplace=False)
Array([[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0]], dtype=int32)
And for square N-dimensional arrays, the N-dimensional diagonal is filled:
>>> y = jnp.zeros((2, 2, 2))
>>> jnp.fill_diagonal(y, 1, inplace=False)
Array([[[1., 0.],
[0., 0.]],
<BLANKLINE>
[[0., 0.],
[0., 1.]]], dtype=float32)
"""
if inplace:
raise NotImplementedError("JAX arrays are immutable, must use inplace=False")
if wrap:
Expand Down Expand Up @@ -8830,19 +8883,64 @@ def _tile_to_size(arr: Array, size: int) -> Array:
return arr[:size] if arr.size > size else arr


@util.implements(np.place, lax_description="""
The semantics of :func:`numpy.place` is to modify arrays in-place, which JAX
cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.place` adds
the ``inplace`` parameter, which must be set to ``False`` by the user as a
reminder of this API difference.
""", extra_params="""
inplace : bool, default=True
If left to its default value of True, JAX will raise an error. This is because
the semantics of :func:`numpy.put` are to modify the array in-place, which is
not possible in JAX due to the immutability of JAX arrays.
""")
def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *,
inplace: bool = True) -> Array:
"""Update array elements based on a mask.
JAX implementation of :func:`numpy.place`.
The semantics of :func:`numpy.place` are to modify arrays in-place, which
is not possible for JAX's immutable arrays. The JAX version returns a modified
copy of the input, and adds the ``inplace`` parameter which must be set to
`False`` by the user as a reminder of this API difference.
Args:
arr: array into which values will be placed.
mask: boolean mask with the same size as ``arr``.
vals: values to be inserted into ``arr`` at the locations indicated
by mask. If too many values are supplied, they will be truncated.
If not enough values are supplied, they will be repeated.
inplace: must be set to False to indicate that the input is not modified
in-place, but rather a modified copy is returned.
Returns:
A copy of ``arr`` with masked values set to entries from `vals`.
See Also:
- :func:`jax.numpy.put`: put elements into an array at numerical indices.
- :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing
Examples:
>>> x = jnp.zeros((3, 5), dtype=int)
>>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
>>> mask
Array([[ True, False, False, True, False],
[False, True, False, False, True],
[False, False, True, False, False]], dtype=bool)
Placing a scalar value:
>>> jnp.place(x, mask, 1, inplace=False)
Array([[1, 0, 0, 1, 0],
[0, 1, 0, 0, 1],
[0, 0, 1, 0, 0]], dtype=int32)
In this case, ``jnp.place`` is similar to the masked array update syntax:
>>> x.at[mask].set(1)
Array([[1, 0, 0, 1, 0],
[0, 1, 0, 0, 1],
[0, 0, 1, 0, 0]], dtype=int32)
``place`` differs when placing values from an array. The array is repeated
to fill the masked entries:
>>> vals = jnp.array([1, 3, 5])
>>> jnp.place(x, mask, vals, inplace=False)
Array([[1, 0, 0, 3, 0],
[0, 5, 0, 0, 1],
[0, 0, 3, 0, 0]], dtype=int32)
"""
util.check_arraylike("place", arr, mask, vals)
data, mask_arr, vals_arr = asarray(arr), asarray(mask), ravel(vals)
if inplace:
Expand All @@ -8860,19 +8958,70 @@ def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *,
return data.ravel().at[indices].set(vals_arr, mode='drop').reshape(data.shape)


@util.implements(np.put, lax_description="""
The semantics of :func:`numpy.put` is to modify arrays in-place, which JAX
cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.put` adds
the ``inplace`` parameter, which must be set to ``False`` by the user as a
reminder of this API difference.
""", extra_params="""
inplace : bool, default=True
If left to its default value of True, JAX will raise an error. This is because
the semantics of :func:`numpy.put` are to modify the array in-place, which is
not possible in JAX due to the immutability of JAX arrays.
""")
def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
mode: str | None = None, *, inplace: bool = True) -> Array:
"""Put elements into an array at given indices.
JAX implementation of :func:`numpy.put`.
The semantics of :func:`numpy.put` are to modify arrays in-place, which
is not possible for JAX's immutable arrays. The JAX version returns a modified
copy of the input, and adds the ``inplace`` parameter which must be set to
`False`` by the user as a reminder of this API difference.
Args:
a: array into which values will be placed.
ind: array of indices over the flattened array at which to put values.
v: array of values to put into the array.
mode: string specifying how to handle out-of-bound indices. Supported values:
- ``"clip"`` (default): clip out-of-bound indices to the final index.
- ``"wrap"``: wrap out-of-bound indices to the beginning of the array.
inplace: must be set to False to indicate that the input is not modified
in-place, but rather a modified copy is returned.
Returns:
A copy of ``a`` with specified entries updated.
See Also:
- :func:`jax.numpy.place`: place elements into an array via boolean mask.
- :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
- :func:`jax.numpy.take`: extract values from an array at given indices.
Examples:
>>> x = jnp.zeros(5, dtype=int)
>>> indices = jnp.array([0, 2, 4])
>>> values = jnp.array([10, 20, 30])
>>> jnp.put(x, indices, values, inplace=False)
Array([10, 0, 20, 0, 30], dtype=int32)
This is equivalent to the following :attr:`jax.numpy.ndarray.at` indexing syntax:
>>> x.at[indices].set(values)
Array([10, 0, 20, 0, 30], dtype=int32)
There are two modes for handling out-of-bound indices. By default they are
clipped:
>>> indices = jnp.array([0, 2, 6])
>>> jnp.put(x, indices, values, inplace=False, mode='clip')
Array([10, 0, 20, 0, 30], dtype=int32)
Alternatively, they can be wrapped to the beginning of the array:
>>> jnp.put(x, indices, values, inplace=False, mode='wrap')
Array([10, 30, 20, 0, 0], dtype=int32)
For N-dimensional inputs, the indices refer to the flattened array:
>>> x = jnp.zeros((3, 5), dtype=int)
>>> indices = jnp.array([0, 7, 14])
>>> jnp.put(x, indices, values, inplace=False)
Array([[10, 0, 0, 0, 0],
[ 0, 0, 20, 0, 0],
[ 0, 0, 0, 0, 30]], dtype=int32)
"""
util.check_arraylike("put", a, ind, v)
arr, ind_arr, v_arr = asarray(a), ravel(ind), ravel(v)
if not arr.size or not ind_arr.size or not v_arr.size:
Expand Down

0 comments on commit cfa1e78

Please sign in to comment.