Skip to content

Commit

Permalink
Add repeat to the specification (#690)
Browse files Browse the repository at this point in the history
  • Loading branch information
kgryte authored Feb 22, 2024
1 parent 11273e6 commit 9d200ea
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
1 change: 1 addition & 0 deletions spec/draft/API_specification/manipulation_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Objects in API
flip
moveaxis
permute_dims
repeat
reshape
roll
squeeze
Expand Down
48 changes: 48 additions & 0 deletions src/array_api_stubs/_draft/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"flip",
"moveaxis",
"permute_dims",
"repeat",
"reshape",
"roll",
"squeeze",
Expand Down Expand Up @@ -159,6 +160,53 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
"""


def repeat(
x: array,
repeats: Union[int, array],
/,
*,
axis: Optional[int] = None,
) -> array:
"""
Repeats each element of an array a specified number of times on a per-element basis.
.. admonition:: Data-dependent output shape
:class: important
When ``repeats`` is an array, the shape of the output array for this function depends on the data values in the ``repeats`` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing the values in ``repeats``. Accordingly, such libraries may choose to omit support for ``repeats`` arrays; however, conforming implementations must support providing a literal ``int``. See :ref:`data-dependent-output-shapes` section for more details.
Parameters
----------
x: array
input array containing elements to repeat.
repeats: Union[int, array]
the number of repetitions for each element.
If ``axis`` is ``None``, let ``N = prod(x.shape)`` and
- if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(N,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(N,)``).
- if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape `(N,)`.
If ``axis`` is not ``None``, let ``M = x.shape[axis]`` and
- if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(M,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(M,)``).
- if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape ``(M,)``.
If ``repeats`` is an array, the array must have an integer data type.
.. note::
For specification-conforming array libraries supporting hardware acceleration, providing an array for ``repeats`` may cause device synchronization due to an unknown output shape. For those array libraries where synchronization concerns are applicable, conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array.
axis: Optional[int]
the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must flatten the input array ``x`` and then repeat elements of the flattened input array and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``.
Returns
-------
out: array
an output array containing repeated elements. The returned array must have the same data type as ``x``. If ``axis`` is ``None``, the returned array must be a one-dimensional array; otherwise, the returned array must have the same shape as ``x``, except for the axis (dimension) along which elements were repeated.
"""


def reshape(
x: array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None
) -> array:
Expand Down

0 comments on commit 9d200ea

Please sign in to comment.