
[RFC] Adopt DLPack as cross-language C ABI stable data structure for array exchange #1

Closed
tqchen opened this issue Aug 17, 2020 · 59 comments
Labels
RFC Request for comments

Comments

tqchen commented Aug 17, 2020

In order for an ndarray system to interact with a variety of frameworks, a stable in-memory data structure is needed.

DLPack is one such data structure that allows exchange between major frameworks. It was developed with input from many deep learning system core developers. Highlights include:

  • Minimal and stable: a simple header
    • The spec has stayed roughly unchanged for more than four years.
  • Designed for cross-hardware use: CPU, CUDA, OpenCL, Vulkan, ROCm, Hexagon
  • Already a "standard" with wide community adoption and support; the ones I am aware of:
    • Frameworks: tensorflow/jax, pytorch, mxnet
    • Libraries: dgl, spaCy, etc.
    • Compilers: TVM
  • Clean C ABI compatibility
    • This means you can create and access it from any language (see the ctypes sketch after this list)
    • It is also essential for building JIT and AOT compilers that support these data types.
  • High-performance considerations
    • The data field must align to 256 bytes (for aligned loads); byte_offset can be used to offset into the array when necessary
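
Because the struct layout is C ABI stable, it can be described from any language without compiling against the header. As a rough illustration (a sketch based on my reading of dlpack.h; verify the field layout against the actual header before relying on it), the structs look roughly like this through Python's ctypes:

import ctypes

# Sketch of the DLPack structs via ctypes; layout mirrors dlpack.h
# (an illustration, not the spec itself).

class DLContext(ctypes.Structure):
    # Device the tensor lives on, e.g. (kDLGPU, 0) for the first CUDA device.
    _fields_ = [("device_type", ctypes.c_int),
                ("device_id", ctypes.c_int)]

class DLDataType(ctypes.Structure):
    # Parametric dtype: type code + bit width + vector lanes.
    _fields_ = [("code", ctypes.c_uint8),
                ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]

class DLTensor(ctypes.Structure):
    _fields_ = [("data", ctypes.c_void_p),                    # 256-byte aligned, possibly opaque
                ("ctx", DLContext),
                ("ndim", ctypes.c_int),
                ("dtype", DLDataType),
                ("shape", ctypes.POINTER(ctypes.c_int64)),
                ("strides", ctypes.POINTER(ctypes.c_int64)),  # NULL means compact row major
                ("byte_offset", ctypes.c_uint64)]

class DLManagedTensor(ctypes.Structure):
    pass

DeleterFunc = ctypes.CFUNCTYPE(None, ctypes.POINTER(DLManagedTensor))
DLManagedTensor._fields_ = [("dl_tensor", DLTensor),
                            ("manager_ctx", ctypes.c_void_p),  # opaque owner handle
                            ("deleter", DeleterFunc)]          # called once by the consumer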

The main design rationale of DLPack is minimalism. DLPack drops consideration of the allocator and device API and focuses on the minimal data structure, while still accounting for cross-hardware support (e.g. the data field is opaque on platforms that do not support normal addressing).

It also simplifies parts of the design to avoid legacy issues (e.g. everything is assumed to be row major, strides can be used to support other cases, and the complexity of considering more layouts is avoided).

After building frameworks around the related data structures for a while, and seeing the ecosystem grow around them, I am quite convinced that DLPack is an important candidate, if not the best one, for the C ABI array data structure.

Given that array exchange is one goal of the consortium, it would be great to see if DLPack can be used as the stable C ABI structure for array exchange.

If the proposal receives a positive response from the community, we would be more than happy to explore options for building a neutral, open governance model (e.g. like the Apache model) that continues to oversee the evolution of the spec -- for example, donating DLPack to the data-api consortium or hosting it in a clean GitHub org.

@tqchen tqchen changed the title [RFC] Adopting DLPack as cross-language C ABI stable data structure for array exchange [RFC] Adopt DLPack as cross-language C ABI stable data structure for array exchange Aug 17, 2020
@szha szha added the RFC Request for comments label Aug 18, 2020
shoyer commented Aug 18, 2020

Thanks for bringing up this suggestion! DLPack does look quite promising for array interoperability.

From the perspective of Python data APIs, one aspect of DLPack that is not clear to me is how to use it at the level of Python/CPython objects, i.e., the equivalent of __array_interface__, __cuda_array_interface__ and/or Python's buffer protocol.

Does DLPack even expose a Python object for wrapping DLPack tensors? It looks like right now JAX and PyTorch just use PyCapsule objects? That's probably fine but worth standardizing.

szha commented Aug 18, 2020

Does DLPack even expose a Python object for wrapping DLPack tensors?

Not yet, though it's fairly straightforward to expose one. Here's one through ctypes:

https://github.com/apache/incubator-mxnet/blob/2610c10701c2b8155dbf094aaecba37ebbf67d0f/python/mxnet/dlpack.py#L63-L81

the equivalent of __array_interface__, __cuda_array_interface__ and/or Python's buffer protocol.

For dlpack, there are two main differences from array interfaces (see https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h#L132-L148):

  • coordination of writing
  • the data descriptor for complex data types

I believe the former is intentional so that it's easier to conform. The latter can (and I think should) be extended in dlpack.

tqchen commented Aug 18, 2020

Right now most of the frameworks we know of already conform to the convention used by PyTorch/JAX/TF/TVM (these APIs are in Python). For example:

  • A framework can export a DLPack object in a PyCapsule
  • The PyCapsule can be consumed exactly once (think of move semantics in C++ and Rust)
  • If the PyCapsule is not consumed, the deleter of the DLPack tensor will be called during destruction of the PyCapsule
  • If the PyCapsule is consumed:
    • the consumer will mark the PyCapsule as "used_dltensor" (this is the current convention used by most frameworks; see the sketch below)
    • alternatively, we can also directly change the deleter of the consumed PyCapsule to None
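
A hedged sketch of the consumer side of this convention, using the real CPython capsule C API through ctypes (error handling simplified; the "dltensor"/"used_dltensor" names follow the convention above):

import ctypes

# Consumer-side sketch of the capsule convention described above (CPython only).
capi = ctypes.pythonapi
capi.PyCapsule_IsValid.restype = ctypes.c_int
capi.PyCapsule_IsValid.argtypes = [ctypes.py_object, ctypes.c_char_p]
capi.PyCapsule_GetPointer.restype = ctypes.c_void_p
capi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
capi.PyCapsule_SetName.restype = ctypes.c_int
capi.PyCapsule_SetName.argtypes = [ctypes.py_object, ctypes.c_char_p]

def consume_dlpack_capsule(capsule):
    """Take ownership of the DLManagedTensor inside a 'dltensor' capsule."""
    if not capi.PyCapsule_IsValid(capsule, b"dltensor"):
        raise ValueError("invalid capsule: already consumed, or not a DLTensor")
    ptr = capi.PyCapsule_GetPointer(capsule, b"dltensor")
    # Rename so the capsule's destructor knows not to run the deleter;
    # the consumer is now responsible for calling the deleter exactly once.
    capi.PyCapsule_SetName(capsule, b"used_dltensor")
    return ptr  # address of the DLManagedTensor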

Example APIs

Complex Number

Thanks @szha for the comment. We could certainly fold the complex data type into DLDataType. However, that is an interesting topic that could need a bit more discussion.

The main reason is that there are quite a few ways complex numbers can be stored (e.g. array of structs vs. struct of arrays) for performance reasons, and different frameworks/hardware might choose different approaches. A more thorough discussion might be necessary.

aregm commented Aug 18, 2020

@tqchen for 1-dim arrays, what is the difference between DLPack and the Arrow format?

byronyi commented Aug 18, 2020

DLPack drops consideration of the allocator and device API and focuses on the minimal data structure.

I would suggest integrating DLPack with the stream executor API, including async malloc/dealloc, fine-grained read/write barriers, etc., which is a de facto standard in high-performance training frameworks.

Without proper compute/transfer stream synchronization between frameworks, pretending that accessing an array in device memory is the same as accessing host memory causes either the overhead of global barriers or memory inconsistency for DLPack arrays.

tqchen commented Aug 18, 2020

Thanks @byronyi for the comment about async device stream support -- this is something we have thought about very carefully.

This is a design tradeoff in terms of how much people want to standardize vs. how much is left to the frameworks themselves.

Most deep learning frameworks have their own internal mechanisms for managing async device computation: for example, MXNet has its dependency scheduler, and TF-RT has its own scheduler (which relies on its internal future system).

While it is entirely possible to introduce broader-scope API standardization by incorporating stream/executor scheduling, the result is the cost of more standardization and harder adoption by frameworks -- what if framework A comes up with another graph scheduler that is faster than the vanilla stream executor? (This is entirely possible.)

So the rationale is: given that the allocator / async scheduler part is a bulky piece that is still evolving, we take a more conservative approach and only standardize the part we can agree on -- namely the data structures.

This does not prevent frameworks from agreeing on additional conventions during exchange. For example, if PyTorch and TVM use the same CUDA stream during exchange, there is no need for synchronization barriers. In many cases, agreeing on the default convention is a good enough compromise -- for example, syncing to the default CUDA stream is usually not a bad choice.

Now, it is certainly possible to introduce additional layers of standardization for the allocator or scheduler on top of DLPack, since scheduling and data structure are orthogonal. But in my experience, this part is still in flux and it is relatively harder to get frameworks to agree.

szha commented Aug 18, 2020

I'd agree with that assessment; we can regard scheduling coordination as out of scope for now.

tqchen commented Aug 18, 2020

@aregm wrt Arrow vs DLPack: I believe they are designed with different goals in mind.

Based on my understanding, Arrow is a good format for dataframe exchange. The key rationale is to represent the data in a compact in-memory format that is also friendly to common dataframe processing. From that perspective, the metadata is defined with considerations such as support for non-POD data types, variable-length encoding, etc.

DLPack focuses more on computation and support for more hardware variations (due to its background in deep learning systems). As a result, there are several key design choices that may not be present in Arrow's arrays. Note that these are all subtle but important design decisions (since the representation of a POD-type array can be as simple as a data pointer plus a length). Most of the rationale is documented in the DLPack header file as well; I list some of the choices here:

  • Besides the data pointer field, there is a byte_offset to represent an offset from the data pointer. This accommodates array slicing when the device data pointer is opaque (does not support host-side addressing), as in the case of common accelerators, Vulkan, and OpenCL.
  • Instead of a plain type code that enumerates the types (e.g. int8, int32, int64, float32), the data type field is parametric (supporting bits, type code, and lanes), which allows us to represent vector types like int4x2. This is important for representing basic vector types, especially those in the sub-byte category; see the illustration after this list.
    • Right now the supported base types include 'float', 'int', 'uint', and 'bfloat' (bfloat16 for deep learning acceleration).
  • A context field represents the device context (including CPU, CUDA, AMD GPU, Vulkan, OpenCL).
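
To illustrate the parametric dtype: each type is a (code, bits, lanes) triple rather than an entry in a flat enum. The numeric codes below follow my reading of the dlpack.h enum (kDLInt=0, kDLUInt=1, kDLFloat=2, kDLBfloat=4), so treat them as assumptions to verify against the header:

# (type code, bits, lanes) triples instead of a flat dtype enum.
# Assumed codes from dlpack.h: kDLInt=0, kDLUInt=1, kDLFloat=2, kDLBfloat=4.
float32  = dict(code=2, bits=32, lanes=1)  # ordinary scalar float
bfloat16 = dict(code=4, bits=16, lanes=1)  # deep learning accelerator type
int4x2   = dict(code=0, bits=4,  lanes=2)  # two 4-bit ints packed together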

@rgommers

Given that array exchange is one goal of the consortium, it would be great to see if DLPack can be used as the stable C ABI structure for array exchange.

Agreed, this is an interesting topic and fits well with the goals of this consortium. We're starting with Python API standards docs, and I think this would be separate, but it makes a lot of sense to treat it in a very similar way.

One of the things DLPack doesn't yet seem to have is docs (except for the README and info in the header) - the content of the conversation in this issue tells me more about purpose, scope, use cases and semantics than what I can find in the DLPack repo.

If the proposal receives a positive response from the community, we would be more than happy to explore options for building a neutral, open governance model (e.g. like the Apache model) that continues to oversee the evolution of the spec -- for example, donating DLPack to the data-api consortium or hosting it in a clean GitHub org.

Thanks for mentioning that. It looks to me like the repo with the reference implementation for DLPack is in good hands today, so I wouldn't be in a hurry to move it. If we get consensus on DLPack being standardized, I'd be more inclined to do the docs (including purpose etc. I mentioned above) here, and reference the current repo for implementation.

....
So the rationale is: given that the allocator / async scheduler part is a bulky piece that is still evolving, we take a more conservative approach and only standardize the part we can agree on -- namely the data structures.

This makes perfect sense to me, and is how we approach the Python API standardization as well.

Example APIs

I have to say the Python API looks a little awkward to me. Referencing dlpack as a name assumes a level of knowledge from the user that really would be better hidden. Compare with the buffer protocol in Python, which "just works" but is invisible to users - they just call a constructor function like numpy.asarray.

The "consume exactly once" is something that doesn't commonly exist in Python usage right now. I'm thinking of:

In [1]: import numpy as np                                                     

In [2]: import torch                                                           

In [3]: x = np.arange(3)                                                       

In [4]: t = torch.tensor(x)  # copies data                                     

In [5]: t2 = torch.as_tensor(x)  # shares memory                               

In [6]: x[0] = 9                                                               

In [7]: x                                                                      
Out[7]: array([9, 1, 2])

In [8]: t                                                                      
Out[8]: tensor([0, 1, 2])

In [9]: t2                                                                     
Out[9]: tensor([9, 1, 2])

Now we have a third type of construction, which doesn't copy but also doesn't share - instead it consumes. So what comes to mind at the Python level is something like a __dlpack__ method plus a constructor name similar to as_tensor for this behaviour.

tqchen commented Aug 18, 2020

Thanks @rgommers. Re docs: agreed, the original purpose of DLPack is to specify the C data structure, and most of the rationale is documented in the C header file. On the other hand, it would be useful to write down the Python API calling conventions and provide more docs in this area.

Clarification wrt "consume exactly once": it does not mean that we are moving the memory from numpy to torch. Instead, the convention means that the PyCapsule can only be consumed exactly once. The exporter (that calls to_dlpack) still retains the memory.

To rephrase your example using to_dlpack/from_dlpack in the PyTorch API convention:

import numpy as np                                                     
import torch                                                           

x = np.arange(3)                                                       
capsule = np.to_dlpack(x)         
# consumes the capsule
t2 = torch.from_dlpack(capsule)

x[0] = 9
print(t2)
>> tensor([9, 1, 2])

# The following code throws because capsule is already consumed.
t3 = torch.from_dlpack(capsule)

The way things work is that when the consumer later chooses to deallocate, it calls the deleter in the DLManagedTensor. A common implementation of the deleter then decreases the refcount of the array object.

For example, in order to implement np.to_dlpack, we would call Py_INCREF on the numpy object and put the object pointer into the manager_ctx field. The deleter then calls Py_DECREF.

The memory will be released only after both x and t2 go out of scope. Notably, one can choose not to consume the capsule at all; in that case, the PyCapsule will call the deleter itself, and there won't be a memory leak.

So to sum up, the above mechanism should align with the example you provided. For example, we could just redirect __dlpack__ to to_dlpack, and call from_dlpack from the as_tensor function. A minimal sketch of the exporter-side refcounting follows.
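
A minimal sketch of that refcount-based scheme through ctypes -- an illustration of the convention under CPython, not an official API (the helper names pin_for_export/release_in_deleter are made up; in a real exporter the deleter body would be wrapped in a CFUNCTYPE callback stored in the DLManagedTensor):

import ctypes
import numpy as np

# Exporter-side lifetime management: incref the source array, stash its
# address in manager_ctx, and decref it when the consumer calls the deleter.
capi = ctypes.pythonapi
capi.Py_IncRef.argtypes = [ctypes.py_object]
capi.Py_DecRef.argtypes = [ctypes.py_object]

def pin_for_export(array_obj):
    """Hold a reference on behalf of the consumer; result goes into manager_ctx."""
    capi.Py_IncRef(ctypes.py_object(array_obj))
    return id(array_obj)  # CPython detail: id() is the object's address

def release_in_deleter(manager_ctx):
    """Body of the DLManagedTensor deleter: drop the exporter's reference."""
    capi.Py_DecRef(ctypes.cast(manager_ctx, ctypes.py_object))

x = np.arange(3)
ctx = pin_for_export(x)  # refcount of x goes up
del x                    # x stays alive: the export still holds a reference
release_in_deleter(ctx)  # now the array can actually be freed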

rgommers commented Aug 18, 2020

Thanks @tqchen, makes sense. Is there a reason then for the consume-once behavior? Maybe related to how some device support and memory management functions work?

It's a little confusing, for example this works fine:

import jax
import jax.dlpack

import torch
import torch.utils.dlpack

j = jax.numpy.arange(3)
capsule = jax.dlpack.to_dlpack(j)
t = torch.utils.dlpack.from_dlpack(capsule)

But run the exact same code interactively, and you get a RuntimeError (presumably because the interpreter makes a call to __repr__ or something similar):

In [2]: %paste
import jax
import jax.dlpack

import torch
import torch.utils.dlpack
## -- End pasted text --

In [3]: j = jax.numpy.arange(3)

In [4]: capsule = jax.dlpack.to_dlpack(j)

In [5]: t = torch.utils.dlpack.from_dlpack(j)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-60aa16dd583b> in <module>
----> 1 t = torch.utils.dlpack.from_dlpack(j)

RuntimeError: from_dlpack received an invalid capsule. Note that DLTensor capsules can be consumed only once, so you might have already constructed a tensor from it once.

tqchen commented Aug 18, 2020

The consume-once requirement comes from how memory management is done in DLPack -- we need a language-agnostic way to signal memory recycling.

In particular, the DLManagedTensor contains a deleter that allows the consumer to signal that the tensor is no longer needed. Because of the way the signature is designed, we need to make sure there is a sole consumer of the DLManagedTensor, so the deleter is called only once, when the consumer no longer needs the memory (otherwise it would cause a double free).

Of course, we could also change the signature to include refcounting (e.g. call IncRef when there is a copy) in DLManagedTensor; however, that would mean an additional requirement that not every exporter might support.

Your particular repro code contains a typo: t = torch.utils.dlpack.from_dlpack(j) should be t = torch.utils.dlpack.from_dlpack(capsule).

@rgommers

Your particular repro code contains a typo

Oops, sorry for the noise - doing too many things at once.

Of course, we could also change the signature to include refcounting ... however, that would mean an additional requirement that not every exporter might support.

Yes, I'm not trying to suggest changes, just trying to wrap my head around how things work and the Python API. There's view-vs-copy semantics there as well, e.g. if I construct a torch.Tensor from a numpy.ndarray, they share memory and mutating the torch.Tensor affects both (in your example). Doing the same with PyTorch + JAX one can still mutate the torch.Tensor, but that doesn't affect the (immutable) JAX array.

tqchen commented Aug 18, 2020

In the specific case of DLPack, the data content should be mutable (as in the numpy example) from the consumer's PoV. I do not know what is happening in the JAX case; perhaps they make a copy (to preserve immutability) instead.

szha commented Aug 19, 2020

they share memory and mutating the torch.Tensor affects both (in your example)

This would require coordination in an asynchronous setting, and I'm not sure we'd want to make it an explicit requirement that this data exchange solves coordination of writes to the shared space. Also, regarding views, I think requiring anything beyond a read-only view may be troublesome, as it takes extra care to deal with effects in a compiler. It might be better to leave that decision to each framework.

tqchen commented Aug 20, 2020

A read-only view is fine. My take is that a generalization of read-only (a move of ownership) also makes sense. In terms of async writes, if both sides use the same stream, the behavior will still be correct. But I agree that it is something that can be defined as per-framework behavior.

@rgommers

@kkraus14 gave the feedback that for RAPIDS the Python-level usage of DLPack has been troublesome to support, due to the "delete on consumption" semantics, and that regular Python refcounting behavior (e.g. as with __cuda_array_interface__) is easier to support. @kkraus14 if you have specific issues you can link to here, that would be helpful.

@honnibal was positive about the spaCy/Thinc interop with PyTorch via DLPack on Twitter. @honnibal do you have any more thoughts on this? Anything you would add/change?

@kkraus14

I think it's more that there isn't an official Python container / spec anywhere, but everyone has followed suit in using a PyCapsule object and changing its name on use: https://github.com/rapidsai/cudf/blob/branch-0.16/python/cudf/cudf/_lib/dlpack.pyx#L32-L34

Then the deletion behavior is controlled based on the name: https://github.com/rapidsai/cudf/blob/branch-0.16/python/cudf/cudf/_lib/dlpack.pyx#L84-L93

On the other hand, for __cuda_array_interface__ everything is just based on Python refcounting and garbage collection. That said, this leaves issues when users want to hand lifetime management down to a C/C++ layer.

tqchen commented Aug 24, 2020

To summarize, support for deletion works fine via PyCapsule, except that the dependency on the "used_dltensor" name is a bit twisted. The deletion code needs to check that name, as in the Cython code linked by @kkraus14; functionality-wise, however, it works fine.

The way we use the PyCapsule spec itself can also be changed (though that would require PRs to the frameworks). For example, another, possibly cleaner, way is to simply consume the capsule and set its deleter to None, as sketched below.
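
A hedged sketch of that variant (PyCapsule_SetDestructor is a real CPython C API function; clearing the destructor on consumption is the part being proposed here, not an existing convention):

import ctypes

# Alternative to renaming the capsule: the consumer clears its destructor
# after taking ownership, so capsule destruction no longer runs the deleter.
capi = ctypes.pythonapi
capi.PyCapsule_SetDestructor.restype = ctypes.c_int
capi.PyCapsule_SetDestructor.argtypes = [ctypes.py_object, ctypes.c_void_p]

def take_ownership(capsule):
    # After this, the consumer is responsible for invoking the
    # DLManagedTensor's deleter exactly once.
    capi.PyCapsule_SetDestructor(capsule, None)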

rgommers added a commit to data-apis/array-api that referenced this issue Sep 8, 2020
This is incomplete, but there seems to be agreement that DLPack
is the best candidate for an interchange mechanism. So let's
document that, and improve the content as we go.

See data-apis/consortium-feedback#1
for the RFC about standardizing on DLPack. Design discussion
should be picked up there.
rgommers added a commit to data-apis/array-api that referenced this issue Sep 9, 2020
This is incomplete, but there seems to be agreement that DLPack
is the best candidate for an interchange mechanism. So let's
document that, and improve the content as we go.

See data-apis/consortium-feedback#1
for the RFC about standardizing on DLPack. Design discussion
should be picked up there.
@oleksandr-pavlyk

It would be very useful to see an OpenCL implementation of DLPack interoperability, specifically the use of DLContext.device_id.

Suppose a SYCL application would like to share data allocated via SYCL Unified Shared Memory (USM). USM shared memory is bound to a SYCL context, which a receiver needs in order to make sense of the pointer. The only way for the exporter to pass it along in the DLTensor is to agree that DLTensor.data is a pointer to a struct that holds the USM pointer and the associated SYCL context.

Is this going against the grain of intended DLPack usage?

tqchen commented Sep 14, 2020

@oleksandr-pavlyk We would then need to define a common way to refer to a device context. For example, in the case of CUDA, devices can simply be referred to by numbers. If there is an additional convention agreed upon between the applications (e.g. what SYCL context 0 means), then such an exchange is possible, as in the case of CUDA.

My understanding is that some level of standardization is necessary. If each application still holds its own SYCL context, then such an exchange is harder than in the CUDA case, as it is not very realistic for an application to understand the device context from another application.

@oleksandr-pavlyk

@tqchen The SYCL context is not an int (see https://developer.codeplay.com/products/computecpp/ce/api-reference/classcl_1_1sycl_1_1context). It may encapsulate a sequence of devices on a common platform (using the same driver).
In SYCL, data transfer between devices in the same context can be optimized by the SYCL runtime to happen directly, avoiding the host.

Here is a table comparing CUDA-world entities to SYCL-world ones: https://developer.codeplay.com/products/computecpp/ce/guides/sycl-for-cuda-developers/migration

In the case of OpenCL, DLTensor.data is documented to point to a cl_mem object, which encapsulates the OpenCL context.

One way of using DLPack to share data referenced by USM pointers is for the receiver and the exporter to agree that DLTensor.data will point to a struct with two void* members, one being the USM pointer and the other a reference to the cl::sycl::context.
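
For concreteness, the proposed container might look like the following ctypes sketch (purely hypothetical -- neither the struct nor its field names are part of DLPack):

import ctypes

# Hypothetical pair that DLTensor.data would point to under this proposal;
# not part of DLPack itself.
class USMArrayData(ctypes.Structure):
    _fields_ = [("usm_ptr", ctypes.c_void_p),       # the USM allocation
                ("sycl_context", ctypes.c_void_p)]  # handle to the cl::sycl::context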

tqchen commented Sep 21, 2020

Thanks @oleksandr-pavlyk. I understand your proposal (and I made that remark in my last comment) and how SYCL works.

However, as in my last comment: passing a cl::context around would require the consumer to make use of a SYCL context passed from another application.

From the application developer's PoV, such additional flexibility on the data structure side can increase the overhead of application development (I am speaking from my past experience developing deep learning systems), since most applications would like to manage their own context and may not be ready to directly use a context passed in externally (e.g. due to the need to synchronize with other internal data under the internal context, etc.).

So in this case, a programming model like CUDA's is still desirable: SYCL or the applications could agree on a set of contexts beforehand (e.g. put them in a table) and use integers to refer to these contexts. Of course, there is no standardization in this area yet.

@honnibal

@honnibal was positive about the spaCy/Thinc interop with PyTorch via DLPack on Twitter. @honnibal do you have any more thoughts on this? Anything you would add/change?

To flesh out a little what we're doing:

We use DLPack to exchange arrays between CuPy and PyTorch, which is allowing us to backprop through layers implemented in different frameworks. We're also using DLPack inside a custom allocator for CuPy. Instead of having CuPy ask for memory directly from the device, the allocator gets memory from PyTorch, and passes it to CuPy via DLPack. This prevents memory contention between the libraries. I haven't tested the MXNet integration very heavily, but we expect the MXNet interoperation to work the same. We've been eagerly awaiting TensorFlow support for DLPack. Heck, we'd settle for even a way to get a buffer with a device copy. Currently we can't communicate with TensorFlow without copying data via the CPU, which I find quite unbelievable.

So far things are working fine. However, I understand that the DLPack standard may introduce complexities that I'm not seeing, as I'm relying on other people's implementations. We would have no problem adopting a different standard instead.

@rgommers

Having some form of standardization probably would be useful.

I'm inclined to add a note of caution to the API standard doc now. Agreed it would be useful and is probably going to become more important over time.

tqchen commented Dec 16, 2020

I agree, just calling it out to clarify the status.

leofang commented Dec 18, 2020

Sorry to bring up a question if this was already discussed somewhere 😅 I am a newcomer here trying to catch up with the massive discussion:

Is this device_id guaranteed to be consistent between libraries for all device types, corresponding to the way the OS driver (e.g. CUDA) labels them?

Why don't we look up which device it is through the cudaPointerAttributes/hipPointerAttribute_t struct associated with the device pointer? This would be guaranteed to work on NVIDIA/AMD GPUs, at the driver level, so in theory DLPack doesn't even need to contain this information, just the pointer address and the array metadata. At least this is what CuPy does when encountering unowned memory (allocated from other libraries).

I imagine OpenCL/SYCL might have similar look-up capability, but I am not familiar with them enough and need to do my homework.

tqchen commented Dec 21, 2020

@leofang The specific property really depends on how the driver is implemented. While it can be true for a unified memory model (the CUDA and ROCm cases), such an API is not guaranteed for opaque memory addresses (in the case of OpenCL, Vulkan, Metal).

leofang commented Dec 30, 2020

Ah I missed it, sorry @tqchen!

Such an API is not guaranteed for opaque memory addresses (in the case of OpenCL, Vulkan, Metal).

Thanks, it's good to confirm. I suppose OpenCL is the most important player for the purpose of Array API.

@rgommers

I opened data-apis/array-api#106 to add relevant content from this discussion to the API standard document.

@rgommers

There's still the issue of where to put DLPack docs; right now they're mostly in the C header. High-level docs like purpose, scope, semantics, and the Python API are missing from the dmlc/dlpack repo - as discussed higher up at #1 (comment).

Links to implementations and helpful content, like how to put together a ctypes or cffi interface, are mostly contained in this discussion. We could put them in a separate Sphinx-generated site hosted from this org, using the same theme as https://data-apis.github.io/array-api/latest/ and making it a similar API.

Or we could just add more docs to https://github.com/dmlc/dlpack, either in its README or with html docs. It's mostly up to your preference, I think, @tqchen - what do you think? I'm happy to help either way.

tqchen commented Dec 31, 2020

We are open to both options. Given that it is simple enough, we agree that we could work to improve data-apis/array-api#106 and cross-reference.

@rgommers

Given that it is simple enough, we agree that we could work to improve data-apis/array-api#106 and cross-reference.

That sounds good to me.

@oleksandr-pavlyk

I think it is important that DLPack support sharing of data referenced by SYCL 2020's unified shared memory (USM) pointers. A SYCL implementation may be based on OpenCL, but need not be; for example, oneAPI's DPC++ defaults to using Level Zero. This calls for extending the device types enum, for example with kDLDPCPP.

A USM pointer is bound to a SYCL context. A device kernel that accesses the pointer must be submitted to a queue associated with this same context, or an asynchronous error is thrown.

The DLPack exporter and receiver must either explicitly assume a common SYCL context (i.e. a mismatch is a documented user error), or DLPack must provide a way for the receiver to obtain the SYCL context the USM pointer is bound to -- if only to make a copy, but ideally to submit kernels to a queue associated with that context.

Attempting to accomplish the latter, __sycl_usm_array_interface__ was proposed for sharing USM-referenced memory. It provides a SYCL object syclobj, which can be either a SYCL context or a SYCL queue Python object as defined by dpctl, with the guarantee that the data referenced by the USM pointer can be accessed using this SYCL object.

The question of the SYCL context is related to device_id. Given a USM pointer and the associated SYCL context, it is possible to retrieve the device where it was allocated using sycl::get_pointer_device(void*, sycl::context&). The closest SYCL comes to supporting a device_id is that the triple {enum sycl::backend, enum sycl::info::device_type, int relative_id} can be used to select a parent device, but this cannot identify a sycl::device obtained using the sycl::device::create_sub_devices method.

[Quote from SYCL 2020 provisional standard, ch. 4.6.4]
A SYCL device can be partitioned into multiple SYCL devices, by calling the create_sub_devices() member
function template. The resulting SYCL devices are considered sub devices, and it is valid to partition these sub devices further.

I would therefore urge DLPack to consider adding a way to provide additional metadata associated with the data. This could be used to provide a device_id, a reference to a SYCL context/queue, or other information specific to each device kind.

tqchen commented Jan 5, 2021

Thanks @oleksandr-pavlyk. Happy to discuss further.

However, as in my last comment: passing a cl::context around would require the consumer to make use of a SYCL context passed from another application. From the application developer's PoV, such additional flexibility on the data structure side can increase the overhead of application development (I am speaking from my past experience developing deep learning systems), since most applications would like to manage their own context and may not be ready to directly use a context passed in externally (e.g. due to the need to synchronize with other internal data under the internal context, etc.).

As explained in my earlier reply (above), the main problem here is SYCL's and OpenCL's lack of standardization of a "default context". From a developer's PoV, I personally think this is a hidden gem of CUDA that is easily overlooked.

I fully understand that, from a pure feasibility perspective, your proposal works:

  • F0: Each application can create its own context, which is more flexible; everyone can enjoy a local context without worrying about conflicting with each other.
  • F1: If context A is passed to another app, ideally the other app submits kernels on context A, so no copy is involved.

The reality, though, is that the lack of a standardized default means frameworks will usually create their own internal contexts and manage them as part of their schedulers. So it is technically quite hard to create a solution that uses F1, since most async scheduling already happens internally, using framework-maintained contexts. Additionally, because there was no default to agree upon, copying is just inevitable. Note that this is not a feasibility argument, just how frameworks and schedulers are built.

The standardized default in CUDA, although perhaps a bit restrictive, brings simplicity and something people can agree on. Usually we can rest assured that torch.cuda(0) and mxnet.gpu(0) refer to the same thing, as long as they also agree on the stream. Such simplicity and commonality is really powerful for both developers and users, and it also makes standardization easy.

In sum, IMHO this is a lesson that SYCL or DPC++ should learn from CUDA. Tiny details like this set CUDA and SYCL apart, and they actually matter a lot when we start to talk about exchange between frameworks (as common ground is needed).

The solution to this problem is also not hard: create a standard context table API that allows users to easily refer to a context as common_context.dpcpp(0), without worrying about context creation, and encourage frameworks to use it. Such standardization could happen in SYCL, DPC++, or even data-api, and I guess it would be a useful future topic. In CUDA, the standardization happens in the driver, which makes agreement really easy.

@oleksandr-pavlyk since you are involved in SYCL/DPC++, perhaps this is also feedback that you could bring back to that community.

@oleksandr-pavlyk

@tqchen Thanks for the feedback. I agree that the existence of a default context per device brings simplicity and drives adoption. I have communicated this feedback to the DPC++ community as well. This should definitely be the suggested workflow for the Python ecosystem.

However, I feel there should be a way to support data sharing among applications that do not use the default context, albeit not intended for mainstream usage. This might be relevant in the future when sharing data allocated in one package on a sub-device with another package that intends to use the whole device. Since the default context is per device, a device and its sub-devices cannot share the same default context.

To enable data transfer, packages need to negotiate contexts, either by telling the exporter which context data must be bound to, or by allowing the exporter to share the context with consumers, ideally in the metadata accompanying the data being shared.

tqchen commented Jan 6, 2021

@oleksandr-pavlyk Thanks. Right now there are discussions about possible stream support in dmlc/dlpack#57

The current S0 proposal allows the producer and consumer to exchange the stream (context) via an optional stream parameter, which I believe we could also specify later for SYCL and DPC++.

Specifically, the metadata requested by the consumer can be provided through the stream parameter, and the producer can perform the necessary synchronization. The pros/cons of different API styles are also discussed in that thread. The API should help in cases where a specific stream/context needs to be exchanged, without necessarily embedding that info into the data structure itself.

leofang commented Jan 27, 2021

For those who are interested: Complex number support for DLPack is being added in dmlc/dlpack#58 to get ready for the Array API (see data-apis/array-api#102 / data-apis/array-api#105).

tqchen commented May 7, 2021

Thanks everyone for the great discussion. This RFC can be closed now that DLPack exchange is officially in the Array API.

@tqchen tqchen closed this as completed May 7, 2021
rgommers commented May 7, 2021

Thanks @tqchen!

jroesch commented May 7, 2021

I know I didn't participate at all in the process but awesome to see this happen, thanks for all the hard work and discussion folks!

leofang commented May 7, 2021

Thanks @tqchen and everyone!
