Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add repeat to the specification #690

Merged
merged 21 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spec/draft/API_specification/manipulation_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Objects in API
flip
moveaxis
permute_dims
repeat
reshape
roll
squeeze
Expand Down
48 changes: 47 additions & 1 deletion src/array_api_stubs/_draft/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"flip",
"moveaxis",
"permute_dims",
"repeat",
"reshape",
"roll",
"squeeze",
Expand All @@ -15,7 +16,7 @@
]


from ._types import List, Optional, Tuple, Union, array
from ._types import List, Optional, Tuple, Union, Sequence, array


def broadcast_arrays(*arrays: array) -> List[array]:
Expand Down Expand Up @@ -159,6 +160,51 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
"""


def repeat(
x: array,
repeats: Union[int, Sequence[int], array],
/,
*,
axis: Optional[int] = None,
) -> array:
"""
Repeats elements of an array.

Parameters
----------
x: array
input array containing elements to repeat.
kgryte marked this conversation as resolved.
Show resolved Hide resolved
repeats: Union[int, Sequence[int], array]
the number of repetitions for each element.

If ``axis`` is ``None``, let ``N = prod(x.shape)`` and

- if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(N,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(N,)``).
- if ``repeats`` is a sequence of integers, ``len(repeats)`` must be broadcast compatible with the shape ``(N,)`` (i.e., the number of sequence elements be either ``1`` or ``N``).
- if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape `(N,)`.

If ``axis`` is not ``None``, let ``M = x.shape[axis]`` and

- if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(M,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(M,)``).
- if ``repeats`` is a sequence of integers, ``len(repeats)`` must be broadcast compatible with the shape ``(M,)`` (i.e., the number of sequence elements must be either ``1`` or ``M``).
- if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape ``(M,)``.

If ``repeats`` is an array, the array must have an integer data type.


.. note::
For specification-conforming array libraries supporting hardware acceleration, providing an array for ``repeats`` may cause device synchronization due to an unknown output shape. For those array libraries where synchronization concerns are applicable, conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array.
kgryte marked this conversation as resolved.
Show resolved Hide resolved

axis: Optional[int]
the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must flatten the input array ``x`` and then repeat elements of the flattened input array and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``.

Returns
-------
out: array
an output array containing repeated elements. The returned array must have the same data type as ``x``. If ``axis`` is ``None``, the returned array must be a one-dimensional array; otherwise, the returned array have the same shape as ``x``, except for the axis (dimension) along which elements were repeated.
kgryte marked this conversation as resolved.
Show resolved Hide resolved
"""


def reshape(
x: array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None
) -> array:
Expand Down
Loading