Skip to content

Commit

Permalink
Support copy and device keywords in from_dlpack (#741)
Browse files Browse the repository at this point in the history
* support copy in from_dlpack

* specify copy stream

* allow 3-way copy arg to align all constructors

* update to reflect the discussions

* clairfy a bit and fix typos

* sync the copy docs

* clarify what's 'on CPU'

* try to make linter happy

* remove namespace leak clause, clean up, and add an example

* make linter happy

* fix Sphinx complaint about Enum

* add/update v2023-specific notes on device

* remove a note on kDLCPU

---------

Co-authored-by: Ralf Gommers <[email protected]>
  • Loading branch information
leofang and rgommers authored Feb 14, 2024
1 parent 474ec2b commit 71b01c1
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 8 deletions.
51 changes: 46 additions & 5 deletions src/array_api_stubs/_draft/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ def __dlpack__(
*,
stream: Optional[Union[int, Any]] = None,
max_version: Optional[tuple[int, int]] = None,
dl_device: Optional[tuple[Enum, int]] = None,
copy: Optional[bool] = None,
) -> PyCapsule:
"""
Exports the array for consumption by :func:`~array_api.from_dlpack` as a DLPack capsule.
Expand Down Expand Up @@ -324,6 +326,12 @@ def __dlpack__(
- ``> 2``: stream number represented as a Python integer.
- Using ``1`` and ``2`` is not supported.
.. note::
When ``dl_device`` is provided explicitly, ``stream`` must be a valid
construct for the specified device type. In particular, when ``kDLCPU``
is in use, ``stream`` must be ``None`` and a synchronization must be
performed to ensure data safety.
.. admonition:: Tip
:class: important
Expand All @@ -333,12 +341,40 @@ def __dlpack__(
not want to think about stream handling at all, potentially at the
cost of more synchronizations than necessary.
max_version: Optional[tuple[int, int]]
The maximum DLPack version that the *consumer* (i.e., the caller of
the maximum DLPack version that the *consumer* (i.e., the caller of
``__dlpack__``) supports, in the form of a 2-tuple ``(major, minor)``.
This method may return a capsule of version ``max_version`` (recommended
if it does support that), or of a different version.
This means the consumer must verify the version even when
`max_version` is passed.
dl_device: Optional[tuple[enum.Enum, int]]
the DLPack device type. Default is ``None``, meaning the exported capsule
should be on the same device as ``self`` is. When specified, the format
must be a 2-tuple, following that of the return value of :meth:`array.__dlpack_device__`.
If the device type cannot be handled by the producer, this function must
raise ``BufferError``.
The v2023.12 standard only mandates that a compliant library should offer a way for
``__dlpack__`` to return a capsule referencing an array whose underlying memory is
accessible to the Python interpreter (represented by the ``kDLCPU`` enumerator in DLPack).
If a copy must be made to enable this support but ``copy`` is set to ``False``, the
function must raise ``ValueError``.
Other device kinds will be considered for standardization in a future version of this
API standard.
copy: Optional[bool]
boolean indicating whether or not to copy the input. If ``True``, the
function must always copy (performed by the producer). If ``False``, the
function must never copy, and raise a ``BufferError`` in case a copy is
deemed necessary (e.g. if a cross-device data movement is requested, and
it is not possible without a copy). If ``None``, the function must reuse
the existing memory buffer if possible and copy otherwise. Default: ``None``.
When a copy happens, the ``DLPACK_FLAG_BITMASK_IS_COPIED`` flag must be set.
.. note::
If a copy happens, and if the consumer-provided ``stream`` and ``dl_device``
can be understood by the producer, the copy must be performed over ``stream``.
Returns
-------
Expand Down Expand Up @@ -394,22 +430,25 @@ def __dlpack__(
# here to tell users that the consumer's max_version is too
# old to allow the data exchange to happen.
And this logic for the consumer in ``from_dlpack``:
And this logic for the consumer in :func:`~array_api.from_dlpack`:
.. code:: python
try:
x.__dlpack__(max_version=(1, 0))
x.__dlpack__(max_version=(1, 0), ...)
# if it succeeds, store info from the capsule named "dltensor_versioned",
# and need to set the name to "used_dltensor_versioned" when we're done
except TypeError:
x.__dlpack__()
x.__dlpack__(...)
This logic is also applicable to handling of the new ``dl_device`` and ``copy``
keywords.
.. versionchanged:: 2022.12
Added BufferError.
.. versionchanged:: 2023.12
Added the ``max_version`` keyword.
Added the ``max_version``, ``dl_device``, and ``copy`` keywords.
"""

def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
Expand All @@ -436,6 +475,8 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
METAL = 8
VPI = 9
ROCM = 10
CUDA_MANAGED = 13
ONE_API = 14
"""

def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
Expand Down
46 changes: 43 additions & 3 deletions src/array_api_stubs/_draft/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,19 +214,36 @@ def eye(
"""


def from_dlpack(x: object, /) -> array:
def from_dlpack(
x: object,
/,
*,
device: Optional[device] = None,
copy: Optional[bool] = None,
) -> array:
"""
Returns a new array containing the data from another (array) object with a ``__dlpack__`` method.
Parameters
----------
x: object
input (array) object.
device: Optional[device]
device on which to place the created array. If ``device`` is ``None`` and ``x`` supports DLPack, the output array must be on the same device as ``x``. Default: ``None``.
The v2023.12 standard only mandates that a compliant library should offer a way for ``from_dlpack`` to return an array
whose underlying memory is accessible to the Python interpreter, when the corresponding ``device`` is provided. If the
array library does not support such cases at all, the function must raise ``BufferError``. If a copy must be made to
enable this support but ``copy`` is set to ``False``, the function must raise ``ValueError``.
Other device kinds will be considered for standardization in a future version of this API standard.
copy: Optional[bool]
boolean indicating whether or not to copy the input. If ``True``, the function must always copy. If ``False``, the function must never copy, and raise ``BufferError`` in case a copy is deemed necessary (e.g. if a cross-device data movement is requested, and it is not possible without a copy). If ``None``, the function must reuse the existing memory buffer if possible and copy otherwise. Default: ``None``.
Returns
-------
out: array
an array containing the data in `x`.
an array containing the data in ``x``.
.. admonition:: Note
:class: note
Expand All @@ -238,19 +255,42 @@ def from_dlpack(x: object, /) -> array:
BufferError
The ``__dlpack__`` and ``__dlpack_device__`` methods on the input array
may raise ``BufferError`` when the data cannot be exported as DLPack
(e.g., incompatible dtype or strides). It may also raise other errors
(e.g., incompatible dtype, strides, or device). It may also raise other errors
when export fails for other reasons (e.g., not enough memory available
to materialize the data). ``from_dlpack`` must propagate such
exceptions.
AttributeError
If the ``__dlpack__`` and ``__dlpack_device__`` methods are not present
on the input array. This may happen for libraries that are never able
to export their data with DLPack.
ValueError
If data exchange is possible via an explicit copy but ``copy`` is set to ``False``.
Notes
-----
See :meth:`array.__dlpack__` for implementation suggestions for `from_dlpack` in
order to handle DLPack versioning correctly.
A way to move data from two array libraries to the same device (assumed supported by both libraries) in
a library-agnostic fashion is illustrated below:
.. code:: python
def func(x, y):
xp_x = x.__array_namespace__()
xp_y = y.__array_namespace__()
# Other functions than `from_dlpack` only work if both arrays are from the same library. So if
# `y` is from a different one than `x`, let's convert `y` into an array of the same type as `x`:
if not xp_x == xp_y:
y = xp_x.from_dlpack(y, copy=True, device=x.device)
# From now on use `xp_x.xxxxx` functions, as both arrays are from the library `xp_x`
...
.. versionchanged:: 2023.12
Added device and copy support.
"""


Expand Down

0 comments on commit 71b01c1

Please sign in to comment.