Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve interoperability between SyclDevice and DLPack devices #1953

Merged
merged 13 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 73 additions & 18 deletions dpctl/_sycl_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ cdef class SyclDevice(_SyclDevice):

Args:
arg (str, optional):
The argument can be a selector string or ``None``.
The argument can be a selector string, another
:class:`dpctl.SyclDevice`, or ``None``.
Defaults to ``None``.

Raises:
Expand All @@ -293,9 +294,7 @@ cdef class SyclDevice(_SyclDevice):
SyclDeviceCreationError:
If the :class:`dpctl.SyclDevice` object creation failed.
TypeError:
If the list of :class:`dpctl.SyclDevice` objects was empty,
or the input capsule contained a null pointer or could not
be renamed.
If the argument is not a :class:`dpctl.SyclDevice` or string.
"""
@staticmethod
cdef SyclDevice _create(DPCTLSyclDeviceRef dref):
Expand Down Expand Up @@ -363,9 +362,9 @@ cdef class SyclDevice(_SyclDevice):
"Could not create a SyclDevice from default selector"
)
else:
raise ValueError(
raise TypeError(
"Invalid argument. Argument should be a str object specifying "
"a SYCL filter selector string."
"a SYCL filter selector string or another SyclDevice."
)

def print_device_info(self):
Expand Down Expand Up @@ -1557,7 +1556,7 @@ cdef class SyclDevice(_SyclDevice):
cdef int i

if ncounts == 0:
raise TypeError(
raise ValueError(
"Non-empty object representing list of counts is expected."
)
counts_buff = <size_t *> malloc((<size_t> ncounts) * sizeof(size_t))
Expand Down Expand Up @@ -1659,7 +1658,7 @@ cdef class SyclDevice(_SyclDevice):
Created sub-devices.

Raises:
TypeError:
ValueError:
If the ``partition`` keyword argument is not specified or
the affinity domain string is not legal or is not one of the
three supported options.
Expand Down Expand Up @@ -1695,7 +1694,7 @@ cdef class SyclDevice(_SyclDevice):
_partition_affinity_domain_type._next_partitionable
)
else:
raise TypeError(
raise ValueError(
"Partition affinity domain {} is not understood.".format(
partition
)
Expand All @@ -1708,11 +1707,11 @@ cdef class SyclDevice(_SyclDevice):
else:
try:
partition = int(partition)
return self.create_sub_devices_equally(partition)
except Exception as e:
raise TypeError(
"Unsupported type of sub-device argument"
) from e
return self.create_sub_devices_equally(partition)

@property
def parent_device(self):
Expand Down Expand Up @@ -1877,7 +1876,7 @@ cdef class SyclDevice(_SyclDevice):
A Python string representing a filter selector string.

Raises:
TypeError:
ValueError:
If the device is a sub-device.

:Example:
Expand All @@ -1902,7 +1901,7 @@ cdef class SyclDevice(_SyclDevice):
else:
# this a sub-device, free it, and raise an exception
DPCTLDevice_Delete(pDRef)
raise TypeError("This SyclDevice is not a root device")
raise ValueError("This SyclDevice is not a root device")

cdef int get_backend_and_device_type_ordinal(self):
""" If this device is a root ``sycl::device``, returns the ordinal
Expand Down Expand Up @@ -1950,9 +1949,7 @@ cdef class SyclDevice(_SyclDevice):

cdef int get_overall_ordinal(self):
""" If this device is a root ``sycl::device``, returns the ordinal
position of this device in the vector ``sycl::device::get_devices()``
filtered to contain only devices with the same backend as this
device.
position of this device in the vector ``sycl::device::get_devices()``.

Returns -1 if the device is a sub-device, or the device could not
be found in the vector.
Expand Down Expand Up @@ -1985,9 +1982,9 @@ cdef class SyclDevice(_SyclDevice):
A Python string representing a filter selector string.

Raises:
TypeError:
If the device is a sub-device.
ValueError:
If the device is a sub-device.

If no match for the device was found in the vector
returned by ``sycl::device::get_devices()``

Expand Down Expand Up @@ -2026,7 +2023,7 @@ cdef class SyclDevice(_SyclDevice):
else:
# this a sub-device, free it, and raise an exception
DPCTLDevice_Delete(pDRef)
raise TypeError("This SyclDevice is not a root device")
raise ValueError("This SyclDevice is not a root device")
else:
if include_backend:
BTy = DPCTLDevice_GetBackend(self._device_ref)
Expand All @@ -2045,6 +2042,64 @@ cdef class SyclDevice(_SyclDevice):
else:
return str(relId)

def get_unpartitioned_parent_device(self):
""" get_unpartitioned_parent_device()

Returns the unpartitioned parent device of this device.

If this device is already an unpartitioned, root device,
the same device is returned.

Returns:
dpctl.SyclDevice:
A parent, unpartitioned :class:`dpctl.SyclDevice` instance, or
``self`` if already a root device.
"""
cdef DPCTLSyclDeviceRef pDRef = NULL
cdef DPCTLSyclDeviceRef tDRef = NULL
pDRef = DPCTLDevice_GetParentDevice(self._device_ref)
if pDRef is NULL:
return self
else:
tDRef = DPCTLDevice_GetParentDevice(pDRef)
while tDRef is not NULL:
DPCTLDevice_Delete(pDRef)
pDRef = tDRef
tDRef = DPCTLDevice_GetParentDevice(pDRef)
return SyclDevice._create(pDRef)

def get_device_id(self):
oleksandr-pavlyk marked this conversation as resolved.
Show resolved Hide resolved
""" get_device_id()
For an unpartitioned device, returns the canonical index of this device
in the list of devices visible to dpctl.

Returns:
int:
The index of the device.

Raises:
ValueError:
If the device could not be found.

:Example:
.. code-block:: python

import dpctl
gpu_dev = dpctl.SyclDevice("gpu")
i = gpu_dev.get_device_id
devs = dpctl.get_devices()
assert devs[i] == gpu_dev
"""
cdef int dev_id = -1
cdef SyclDevice dev

dev = self.get_unpartitioned_parent_device()
dev_id = dev.get_overall_ordinal()
if dev_id < 0:
raise ValueError("device could not be found")
return dev_id


cdef api DPCTLSyclDeviceRef SyclDevice_GetDeviceRef(SyclDevice dev):
"""
C-API function to get opaque device reference from
Expand Down
6 changes: 6 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@
uint64,
)
from dpctl.tensor._device import Device
from dpctl.tensor._dldevice_conversions import (
dldevice_to_sycl_device,
sycl_device_to_dldevice,
)
from dpctl.tensor._dlpack import from_dlpack
from dpctl.tensor._indexing_functions import (
extract,
Expand Down Expand Up @@ -388,4 +392,6 @@
"take_along_axis",
"put_along_axis",
"top_k",
"dldevice_to_sycl_device",
"sycl_device_to_dldevice",
]
39 changes: 39 additions & 0 deletions dpctl/tensor/_dldevice_conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .._sycl_device import SyclDevice
from ._usmarray import DLDeviceType


def dldevice_to_sycl_device(dl_dev: tuple):
if isinstance(dl_dev, tuple):
if len(dl_dev) != 2:
raise ValueError("dldevice tuple must have length 2")
else:
raise TypeError(
f"dl_dev is expected to be a 2-tuple, got " f"{type(dl_dev)}"
)
if dl_dev[0] != DLDeviceType.kDLOneAPI:
raise ValueError("dldevice type must be kDLOneAPI")
return SyclDevice(str(dl_dev[1]))


def sycl_device_to_dldevice(dev: SyclDevice):
if not isinstance(dev, SyclDevice):
raise TypeError(
"dev is expected to be a SyclDevice, got " f"{type(dev)}"
)
return (DLDeviceType.kDLOneAPI, dev.get_device_id())
2 changes: 0 additions & 2 deletions dpctl/tensor/_dlpack.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ cpdef object to_dlpack_versioned_capsule(usm_ndarray array, bint copied) except
cpdef object numpy_to_dlpack_versioned_capsule(ndarray array, bint copied) except +
cpdef object from_dlpack_capsule(object dltensor) except +

cdef int get_parent_device_ordinal_id(SyclDevice dev) except *

cdef class DLPackCreationError(Exception):
"""
A DLPackCreateError exception is raised when constructing
Expand Down
27 changes: 2 additions & 25 deletions dpctl/tensor/_dlpack.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -177,28 +177,6 @@ cdef object _get_default_context(c_dpctl.SyclDevice dev):

return default_context


cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except -1:
cdef DPCTLSyclDeviceRef pDRef = NULL
cdef DPCTLSyclDeviceRef tDRef = NULL
cdef c_dpctl.SyclDevice p_dev

pDRef = DPCTLDevice_GetParentDevice(dev.get_device_ref())
if pDRef is not NULL:
# if dev is a sub-device, find its parent
# and return its overall ordinal id
tDRef = DPCTLDevice_GetParentDevice(pDRef)
while tDRef is not NULL:
DPCTLDevice_Delete(pDRef)
pDRef = tDRef
tDRef = DPCTLDevice_GetParentDevice(pDRef)
p_dev = c_dpctl.SyclDevice._create(pDRef)
return p_dev.get_overall_ordinal()

# return overall ordinal id of argument device
return dev.get_overall_ordinal()


cdef int get_array_dlpack_device_id(
usm_ndarray usm_ary
) except -1:
Expand All @@ -224,14 +202,13 @@ cdef int get_array_dlpack_device_id(
"on non-partitioned SYCL devices on platforms where "
"default_context oneAPI extension is not supported."
)
device_id = ary_sycl_device.get_overall_ordinal()
else:
if not usm_ary.sycl_context == default_context:
raise DLPackCreationError(
"to_dlpack_capsule: DLPack can only export arrays based on USM "
"allocations bound to a default platform SYCL context"
)
device_id = get_parent_device_ordinal_id(ary_sycl_device)
device_id = ary_sycl_device.get_device_id()

if device_id < 0:
raise DLPackCreationError(
Expand Down Expand Up @@ -1086,7 +1063,7 @@ def from_dlpack(x, /, *, device=None, copy=None):
d = device.sycl_device
else:
d = device
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>d))
dl_device = (device_OneAPI, d.get_device_id())
if dl_device is not None:
if (dl_device[0] not in [device_OneAPI, device_CPU]):
raise ValueError(
Expand Down
14 changes: 7 additions & 7 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1304,16 +1304,16 @@ cdef class usm_ndarray:
DLPackCreationError:
when the ``device_id`` could not be determined.
"""
cdef int dev_id = c_dlpack.get_parent_device_ordinal_id(<c_dpctl.SyclDevice>self.sycl_device)
if dev_id < 0:
try:
dev_id = self.sycl_device.get_device_id()
except ValueError as e:
raise c_dlpack.DLPackCreationError(
"Could not determine id of the device where array was allocated."
)
else:
return (
DLDeviceType.kDLOneAPI,
dev_id,
)
return (
DLDeviceType.kDLOneAPI,
dev_id,
)

def __eq__(self, other):
return dpctl.tensor.equal(self, other)
Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/test_sycl_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def test_cpython_api_SyclContext_Make():

def test_invalid_capsule():
cap = create_invalid_capsule()
with pytest.raises(ValueError):
with pytest.raises(TypeError):
dpctl.SyclContext(cap)


Expand Down
Loading
Loading