diff --git a/spec/draft/API_specification/manipulation_functions.rst b/spec/draft/API_specification/manipulation_functions.rst index 680efb55f..395c1c3e2 100644 --- a/spec/draft/API_specification/manipulation_functions.rst +++ b/spec/draft/API_specification/manipulation_functions.rst @@ -25,6 +25,7 @@ Objects in API flip moveaxis permute_dims + repeat reshape roll squeeze diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 74538a8a3..4d7a17dda 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -6,6 +6,7 @@ "flip", "moveaxis", "permute_dims", + "repeat", "reshape", "roll", "squeeze", @@ -159,6 +160,53 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: """ +def repeat( + x: array, + repeats: Union[int, array], + /, + *, + axis: Optional[int] = None, +) -> array: + """ + Repeats each element of an array a specified number of times on a per-element basis. + + .. admonition:: Data-dependent output shape + :class: important + + When ``repeats`` is an array, the shape of the output array for this function depends on the data values in the ``repeats`` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing the values in ``repeats``. Accordingly, such libraries may choose to omit support for ``repeats`` arrays; however, conforming implementations must support providing a literal ``int``. See :ref:`data-dependent-output-shapes` section for more details. + + Parameters + ---------- + x: array + input array containing elements to repeat. + repeats: Union[int, array] + the number of repetitions for each element. + + If ``axis`` is ``None``, let ``N = prod(x.shape)`` and + + - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(N,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(N,)``). + - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape `(N,)`. + + If ``axis`` is not ``None``, let ``M = x.shape[axis]`` and + + - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(M,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(M,)``). + - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape ``(M,)``. + + If ``repeats`` is an array, the array must have an integer data type. + + .. note:: + For specification-conforming array libraries supporting hardware acceleration, providing an array for ``repeats`` may cause device synchronization due to an unknown output shape. For those array libraries where synchronization concerns are applicable, conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array. + + axis: Optional[int] + the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must flatten the input array ``x`` and then repeat elements of the flattened input array and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``. + + Returns + ------- + out: array + an output array containing repeated elements. The returned array must have the same data type as ``x``. If ``axis`` is ``None``, the returned array must be a one-dimensional array; otherwise, the returned array must have the same shape as ``x``, except for the axis (dimension) along which elements were repeated. + """ + + def reshape( x: array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None ) -> array: