[RFC] Adopt DLPack as cross-language C ABI stable data structure for array exchange #1
Comments
Thanks for bringing up this suggestion! DLPack does look quite promising for array interoperability. From the perspective of Python data APIs, one aspect of DLPack that is not clear to me is how to use it at the level of Python/CPython objects, i.e. the equivalent of the existing array interface protocols. Does DLPack even expose a Python object for wrapping DLPack tensors? It looks like right now JAX and PyTorch just use their own to_dlpack/from_dlpack functions.
Not yet, though it's fairly straightforward to expose one. Here's one through ctypes:
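A minimal ctypes sketch of what such a wrapper could look like, following the struct layout in dlpack.h at the time of this discussion (DLContext was the name of the device struct then); this is illustrative, not the snippet originally posted:

```python
import ctypes


class DLContext(ctypes.Structure):
    # Device the data lives on: device_type is the DLDeviceType enum
    # (e.g. 1 == kDLCPU, 2 == kDLGPU), device_id is the ordinal of that device.
    _fields_ = [("device_type", ctypes.c_int),
                ("device_id", ctypes.c_int)]


class DLDataType(ctypes.Structure):
    # Type code (0 == int, 1 == uint, 2 == float), bit width, and vector lanes.
    _fields_ = [("code", ctypes.c_uint8),
                ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]


class DLTensor(ctypes.Structure):
    _fields_ = [("data", ctypes.c_void_p),
                ("ctx", DLContext),
                ("ndim", ctypes.c_int),
                ("dtype", DLDataType),
                ("shape", ctypes.POINTER(ctypes.c_int64)),
                ("strides", ctypes.POINTER(ctypes.c_int64)),
                ("byte_offset", ctypes.c_uint64)]


class DLManagedTensor(ctypes.Structure):
    # Declared first and filled in below, because the deleter refers back to it.
    pass


# The deleter receives the DLManagedTensor itself so the exporter can free
# whatever bookkeeping it attached through manager_ctx.
DLManagedTensorDeleter = ctypes.CFUNCTYPE(None, ctypes.POINTER(DLManagedTensor))

DLManagedTensor._fields_ = [("dl_tensor", DLTensor),
                            ("manager_ctx", ctypes.c_void_p),
                            ("deleter", DLManagedTensorDeleter)]
```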
For dlpack, there are two main differences from array interfaces (see https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h#L132-L148):
I believe the former is intentional so that it's easier to conform. The latter can (and I think should) be extended in dlpack.
Right now most of the frameworks we know of already conform to the convention used by PyTorch/JAX/TF/TVM (these APIs are in Python); see for example:
Example APIs
Complex numbers: Thanks @szha for the comment. We could certainly fold the complex data type into DLDataType. However, that is an interesting topic that could need a bit more discussion. The main reason is that there are quite a few ways a complex number can be stored (e.g. array of structs vs struct of arrays) for performance reasons, and different frameworks/hardware might choose different approaches. A more thorough discussion might be necessary.
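For illustration, a small NumPy sketch of the two layouts mentioned (purely illustrative; DLPack itself does not prescribe either):

```python
import numpy as np

n = 4

# "Array of structs": real and imaginary parts interleaved in memory,
# which is how NumPy's complex64 stores values: r0 i0 r1 i1 ...
aos = np.zeros(n, dtype=np.complex64)

# "Struct of arrays": separate contiguous planes for the real and imaginary
# parts, which some frameworks/hardware prefer for vectorization.
soa_real = np.zeros(n, dtype=np.float32)
soa_imag = np.zeros(n, dtype=np.float32)

# Both represent the same logical complex vector, but a single
# (data pointer, shape, strides) description only captures the interleaved
# layout directly.
```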
@tqchen for 1-dim arrays, what is the difference between DLPack and the Arrow format?
I would suggest integrating DLPack with the stream executor API, including async malloc/dealloc, fine-grained read/write barriers, etc., which is a de-facto standard in high-performance training frameworks. Without proper compute/transfer stream synchronization between frameworks, pretending that accessing an array in device memory is the same as accessing host memory causes either the overhead of global barriers or memory inconsistency for DLPack arrays.
Thanks @byronyi for the comment about async device stream support; this is something we have thought about very carefully. It is a design trade-off in terms of how many parts people want to standardize vs how many parts are left to the frameworks themselves. Most deep learning frameworks have their own internal mechanism for managing async device computation: for example, MXNet has the dependency scheduler, and TF-RT has its own scheduler (that relies on its internal future system). While it is totally possible to introduce a broader-scope API standardization by incorporating stream/executor scheduling, the result is the cost of more standardization and harder adoption by the frameworks -- what if framework A comes up with another graph scheduler that is faster than the vanilla stream executor? (This is totally possible.) So the rationale is that, given the allocator / async scheduler part is a bulky piece that is still evolving, we take a more conservative approach and only standardize the part we can agree on -- namely the data structures. This does not prevent frameworks from agreeing on additional conventions during exchange; for example, if PyTorch and TVM use the same CUDA stream during exchange, there is no need for synchronization barriers. In many cases, agreeing on a default convention is a good-enough compromise -- for example, synchronizing with the default CUDA stream is usually not a bad choice. Now, it is certainly possible to introduce additional layers of standardization of the allocator or scheduler on top of DLPack -- since scheduling and data structure are orthogonal. But based on my experience, this part is still in flux and it is relatively harder to get the frameworks' agreement.
I'd agree with the assessment, and we can regard scheduling coordination as out of scope for now.
@aregm with respect to Arrow and DLPack: I believe they are designed with different design goals in mind. Based on my understanding, Arrow is a good format for dataframe exchange. The key rationale is to represent the data in a compact in-memory format that is also friendly to common dataframe-related processing. From that perspective, its metadata is defined with considerations such as support for non-POD data types, variable-length encoding, etc. DLPack focuses more on computation, and on support for more hardware variations (due to its background in deep learning systems). As a result there are several key design choices that may not be present in Arrow's arrays. Note that these are all subtle but important design decisions (since the representation of a POD-type array can be as simple as the data pointer plus a length). Most of the rationales are documented in the DLPack header file as well; I list some of the choices here:
Agreed, this is an interesting topic and fits well with the goals of this consortium. We're starting with Python API standards docs, and I think this would be separate, but it makes a lot of sense to treat it in a very similar way. One of the things DLPack doesn't yet seem to have is docs (except for the README and info in the header) - the content of the conversation in this issue tells me more about purpose, scope, use cases and semantics than what I can find in the DLPack repo.
Thanks for mentioning that. It looks to me like the repo with the reference implementation for DLPack is in good hands today, so I wouldn't be in a hurry to move it. If we get consensus on DLPack being standardized, I'd be more inclined to do the docs (including purpose etc. I mentioned above) here, and reference the current repo for implementation.
This makes perfect sense to me, and is how we approach the Python API standardization as well.
I have to say the Python API looks a little awkward to me. The "consume exactly once" semantics is something that doesn't commonly exist in Python usage right now. I'm thinking of: construction that copies the data, and construction that shares memory (a view).
Now we have a third type of construction, which doesn't copy but also doesn't share - instead it consumes. So what comes to mind at the Python level is something like a move of ownership.
Thanks @rgommers, re docs: agree, the original purpose of DLPack is to specify the C data structure, and most of the rationales are documented in the C header file. On the other hand, it would be useful to write down the Python API calling conventions and provide more docs in that area.

Clarification wrt "consume exactly once": it does not mean that we are moving the memory from numpy to torch. Instead, the convention means that the PyCapsule can only be consumed exactly once. The exporter (that calls to_dlpack) still retains the memory. To rephrase your example using the to_dlpack/from_dlpack convention of the PyTorch API:

```python
import numpy as np
import torch

x = np.arange(3)
capsule = np.to_dlpack(x)   # hypothetical numpy exporter, following the PyTorch convention
# consumes the capsule
t2 = torch.from_dlpack(capsule)
x[0] = 9
print(t2)
# >> tensor([9, 1, 2])

# The following line throws because the capsule is already consumed.
t3 = torch.from_dlpack(capsule)
```

The way things work is that when the consumer chooses to de-allocate later, it calls into the deleter in the DLManagedTensor. A common implementation of a deleter will then decrease the refcount of the exporting array object. For example, to implement np.to_dlpack above, the deleter could hold a reference to x and release it when called. The memory will be released only after both the exporter-side x and the consumer-side t2 no longer need it.

So to sum up, the above mechanism should be aligned with the example you provide. For example, we could just redirect the deleter to the normal refcount decrease of the exporting array.
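To make the deleter mechanism concrete, here is a rough exporter-side sketch reusing the ctypes definitions from the sketch earlier in this thread; the names numpy_to_dlpack_struct, _live_exports and _deleter are made up for illustration and are not NumPy or PyTorch API:

```python
import ctypes
import numpy as np

# Reuses DLContext, DLDataType, DLManagedTensor and DLManagedTensorDeleter
# from the earlier ctypes sketch.

# Keeps exported arrays (and their shape/stride buffers) alive until the
# consumer calls the deleter.
_live_exports = {}


@DLManagedTensorDeleter
def _deleter(managed_ptr):
    # Called exactly once by the sole consumer when it is done with the memory;
    # dropping our reference hands things back to ordinary Python refcounting.
    _live_exports.pop(ctypes.addressof(managed_ptr.contents), None)


def numpy_to_dlpack_struct(x):
    """Fill a DLManagedTensor describing a CPU float array `x` (sketch only)."""
    managed = DLManagedTensor()
    t = managed.dl_tensor
    t.data = x.ctypes.data
    t.ctx = DLContext(1, 0)                              # 1 == kDLCPU
    t.ndim = x.ndim
    t.dtype = DLDataType(2, x.dtype.itemsize * 8, 1)     # 2 == kDLFloat
    shape = (ctypes.c_int64 * x.ndim)(*x.shape)
    strides = (ctypes.c_int64 * x.ndim)(*(s // x.itemsize for s in x.strides))
    t.shape = ctypes.cast(shape, ctypes.POINTER(ctypes.c_int64))
    t.strides = ctypes.cast(strides, ctypes.POINTER(ctypes.c_int64))
    managed.deleter = _deleter
    # The exporter retains the memory: `x` stays alive here until the deleter runs.
    _live_exports[ctypes.addressof(managed)] = (x, shape, strides, managed)
    return managed


arr = np.arange(3, dtype=np.float64)
mt = numpy_to_dlpack_struct(arr)
```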
Thanks @tqchen, makes sense. Is there a reason then for the consume-once? Maybe related to how device support and memory management work? It's a little confusing; for example this works fine:
But run the exact same code interactively, and you get an error.
The consume-once requirement comes from how memory management is done in DLPack -- we need a language-agnostic way to signal memory recycling. In particular, the DLManagedTensor contains a deleter that allows the consumer to signal that the tensor is no longer needed. Because of the way the signature is designed, we need to make sure there is a sole consumer of the DLManagedTensor, so that the deleter is called only once, when the consumer no longer needs the memory (otherwise it would cause a double free). Of course, we could also change the signature to include refcounting (e.g. call IncRef when there is a copy) in DLManagedTensor; however, that would be an additional requirement that not every exporter might support. Your particular repr code contains a typo.
Oops, sorry for the noise - doing too many things at once.
Yes, I'm not trying to suggest changes, just trying to wrap my head around how things work and the Python API. There's view-vs-copy semantics there as well, e.g. if I construct a ...
In the specific case of DLPack, the data content should be able to mutate (as in the numpy example) from the consumer's PoV. I do not know what is happening in the JAX case; perhaps they generate a copy (to preserve immutability) instead.
This would require coordination in an asynchronous setting, and I'm not sure we'd want to make the explicit requirement that this data exchange solves the coordination of writing to the shared space. Also, regarding views, I think requiring anything beyond a read-only view may be troublesome, as it takes extra care to deal with effects in a compiler. It might be better to leave that decision to each framework.
The read-only view is fine. My take is that a generalization of read-only (move of ownership) also makes sense. In terms of async writes, if both use the same stream, the behavior will still be correct. But I agree that it is something that can be defined as per-framework behavior.
@kkraus14 gave the feedback that for RAPIDS the Python-level usage of DLPack has been troublesome to support, due to the semantics of "delete on consumption", and that regular Python refcounting behavior would be preferable. @honnibal was positive about the spaCy/Thinc interop with PyTorch via DLPack on Twitter. @honnibal do you have any more thoughts on this? Anything you would add/change?
I think it's more that there isn't an official Python container / spec anywhere, but everyone has followed suit in using a PyCapsule object and changing its name on use: https://github.com/rapidsai/cudf/blob/branch-0.16/python/cudf/cudf/_lib/dlpack.pyx#L32-L34. Then the deletion behavior is controlled based on the name: https://github.com/rapidsai/cudf/blob/branch-0.16/python/cudf/cudf/_lib/dlpack.pyx#L84-L93. On the other hand, for ...
To summarize, the support for deletion is fine via PyCapsule, except that the dependency on the "used_dltensor" name is a bit twisted. The deletion code needs to check that field, as per the Cython code linked by @kkraus14; however, functionality-wise it works fine. The way we use the PyCapsule spec itself can also be changed (however, that also requires potential PRs to the frameworks). For example, another possibly cleaner way is to simply consume the capsule and set its deleter to a no-op once consumed.
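For reference, a small sketch of the consumer-side naming convention described above, using the CPython capsule API through ctypes (consume_dlpack_capsule is an illustrative name, not an API of any of the frameworks discussed):

```python
import ctypes

capi = ctypes.pythonapi
capi.PyCapsule_IsValid.restype = ctypes.c_int
capi.PyCapsule_IsValid.argtypes = [ctypes.py_object, ctypes.c_char_p]
capi.PyCapsule_GetPointer.restype = ctypes.c_void_p
capi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
capi.PyCapsule_SetName.restype = ctypes.c_int
capi.PyCapsule_SetName.argtypes = [ctypes.py_object, ctypes.c_char_p]

# Name constants kept at module scope: PyCapsule_SetName stores the pointer,
# so the bytes objects must stay alive as long as the capsule does.
_FRESH = b"dltensor"
_USED = b"used_dltensor"


def consume_dlpack_capsule(capsule):
    """Return the address of the DLManagedTensor in `capsule`, marking it consumed."""
    if not capi.PyCapsule_IsValid(capsule, _FRESH):
        # Not a DLPack capsule, or already consumed (and renamed) by someone else.
        raise ValueError("expected an unconsumed 'dltensor' capsule")
    ptr = capi.PyCapsule_GetPointer(capsule, _FRESH)
    # Rename so that (a) a second from_dlpack-style call fails, and (b) the
    # capsule destructor knows it must not call the deleter: the consumer now
    # owns the DLManagedTensor and calls the deleter itself when finished.
    capi.PyCapsule_SetName(capsule, _USED)
    return ptr
```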
This is incomplete, but there seems to be agreement that DLPack is the best candidate for an interchange mechanism. So let's document that, and improve the content as we go. See data-apis/consortium-feedback#1 for the RFC about standardizing on DLPack. Design discussion should be picked up there.
It would be very useful to see an example of how this would work for SYCL. Suppose a SYCL application would like to share data allocated via SYCL Unified Shared Memory. The USM shared memory is bound to a SYCL context, which a receiver needs in order to make sense of the pointer. The only way for the exporter to pass that context along would seem to be within the DLTensor itself. Is this going against the grain of intended DLPack usage?
@oleksandr-pavlyk We will then need to define a common way to refer to a device context. For example, in the case of CUDA, the devices can simply be referred to by numbers. If there is an additional convention agreed upon between the applications (e.g. what SYCL context 0 means), then such exchange is possible as in the CUDA case. My understanding is that some level of standardization is necessary; if each application still holds its own SYCL context, then such exchange becomes harder than in the CUDA case, as it is not very realistic for an application to understand a device context coming from another application.
@tqchen The SYCL context is not something that can be referred to by a simple number. Here is a table comparing CUDA-world entities to SYCL-world ones: https://developer.codeplay.com/products/computecpp/ce/guides/sycl-for-cuda-developers/migration. In the case of OpenCL, the SYCL context wraps an underlying cl_context. One way of using DLPack to share data referenced by USM pointers is for the receiver and the exporter to agree that the SYCL context is passed along with the USM pointer.
Thanks @oleksandr-pavlyk. I understand your proposal (and I made that remark in my last comment) and how SYCL works. However, as noted in that comment, passing cl::context around would require the consumer to make use of a SYCL context passed in from another application. From the application developer's PoV, such additional flexibility on the data structure side can increase the overhead of application development (I am speaking from my past experience developing deep learning systems). Most applications want to manage their own context, and may not be ready to directly use a context passed in externally (e.g. due to the need to synchronize with other internal data under an internal context), so in this case a programming model like CUDA's is still desirable. If SYCL or the applications can agree on a set of contexts beforehand (e.g. put them in a table), an integer can be used to refer to these contexts. Of course, there is no standardization around this area yet.
To flesh out a little what we're doing: We use DLPack to exchange arrays between CuPy and PyTorch, which is allowing us to backprop through layers implemented in different frameworks. We're also using DLPack inside a custom allocator for CuPy. Instead of having CuPy ask for memory directly from the device, the allocator gets memory from PyTorch, and passes it to CuPy via DLPack. This prevents memory contention between the libraries. I haven't tested the MXNet integration very heavily, but we expect the MXNet interoperation to work the same. We've been eagerly awaiting TensorFlow support for DLPack. Heck, we'd settle for even a way to get a buffer with a device copy. Currently we can't communicate with TensorFlow without copying data via the CPU, which I find quite unbelievable. So far things are working fine. However, I understand that the DLPack standard may introduce complexities that I'm not seeing, as I'm relying on other people's implementations. We would have no problem adopting a different standard instead.
I'm inclined to add a note of caution to the API standard doc now. Agreed it would be useful and is probably going to become more important over time.
I agree, just call it out to clarify the status.
Sorry to bring up a question if this was already discussed somewhere 😅 I am a newcomer here trying to catch up with the massive discussion:
Why don't we look up which device it is through the driver, e.g. by querying the pointer's attributes? I imagine OpenCL/SYCL might have a similar look-up capability, but I am not familiar enough with them and need to do my homework.
@leofang The specific property really depends on how the driver is implemented. While it can be true for a unified memory model (the CUDA/ROCm case), such an API is not guaranteed for opaque memory addresses (as in OpenCL, Vulkan, Metal).
Ah I missed it, sorry @tqchen!
Thanks, it's good to confirm. I suppose OpenCL is the most important player for the purpose of the Array API.
I opened data-apis/array-api#106 to add relevant content from this discussion to the API standard document.
There's still the issue of where to put DLPack docs; right now it's mostly in the C header. High-level docs like purpose, scope, semantics and the Python API are missing from the DLPack repo. Links to implementations and helpful content, like how to put together a ctypes or cffi interface, are mostly contained in this discussion. We could put them in a separate Sphinx-generated site and host it from this org, using the same theme as https://data-apis.github.io/array-api/latest/ and making it look similar. Or we could just add more docs to https://github.com/dmlc/dlpack, either in its README or with html docs. It's mostly up to your preference I think @tqchen, what do you think? I'm happy to help either way.
We are open to both options. Given it is simple enough, we agree that we could work to improve data-apis/array-api#106 and cross-reference.
That sounds good to me.
I think it is important that DLPack supports sharing of data referenced by SYCL 2020's unified shared memory (USM) pointers. A SYCL implementation may be based on OpenCL, but need not be; for example oneAPI's DPCPP defaults to using Level Zero. This calls for extending the device types enum, for example with a SYCL-specific entry. The USM pointer is bound to a SYCL context. A device kernel that accesses the pointer must be submitted to a queue associated with this same context, or an asynchronous error is thrown. The DLPack exporter and receiver must either explicitly assume a common SYCL context (i.e. a mismatch is a documented user error), or DLPack must provide a way for the receiver to obtain the SYCL context the USM pointer is bound to, if only to make a copy, but ideally to submit kernels to a queue associated with that context. Attempting to accomplish the latter raises the question of how the SYCL context can be communicated alongside the data.
I would therefore urge DLPack to consider adding a way to provide additional metadata associated with the data. This could be used to provide, for example, the SYCL context the USM pointer is bound to.
Thanks @oleksandr-pavlyk. Happy to discuss further.
As explained in my earlier reply above, the main problem here is SYCL and OpenCL's lack of standardization in terms of the "default context". From a developer's PoV I personally think this is a hidden gem of CUDA that is easy to overlook. I fully understand your proposal, and from a purely feasibility perspective:
The reality, though, is that the lack of a standardized default means frameworks will usually create their own internal contexts and manage them as part of their scheduler. So it is technically quite hard to create a solution that uses F1, since most of the async scheduling already happens internally, using framework-maintained contexts. Additionally, because there wasn't a default to agree upon, a copy is just inevitable. Note that this is not a feasibility argument, just how frameworks and schedulers are being built. The standardized default in CUDA, although perhaps a bit restrictive, brings simplicity and something people can agree on. Usually we can rest assured that synchronizing with the default stream is a good-enough choice. In sum, IMHO this is a lesson that SYCL or DPCPP should learn from CUDA. Tiny details like this set CUDA and SYCL apart and actually matter a lot when we start to talk about exchange between frameworks (as common ground is needed). The solution to this problem is also not hard: create a standard context table API that allows users to easily refer to contexts, e.g. by an integer index, as in the CUDA case. @oleksandr-pavlyk since you are involved in SYCL/DPCPP, perhaps this is also feedback that you could bring back to that community.
@tqchen Thanks for the feedback. I am in agreement that the existence of a default context per device brings simplicity and drives adoption. I have communicated this feedback to the DPCPP community as well. This definitely should be the suggested workflow for the Python ecosystem. However, I feel there should be a way to support data sharing among applications that do not use the default context, albeit not intended for mainstream usage. This might be relevant in the future when sharing data allocated in one package on a sub-device with another package that intends to use the whole device. Since the default context is per device, a device and its sub-devices cannot share the same default context. To enable data transfer, packages need to negotiate contexts, either by telling the exporter which context the data must be bound to, or by allowing the exporter to share the context with consumers, ideally in the metadata accompanying the data being shared.
@oleksandr-pavlyk Thanks. Right now there are discussions about possible stream support in dmlc/dlpack#57. The current S0 proposal allows the producer and consumer to exchange the stream (context) via an optional stream parameter, which I believe we could also specify later for SYCL and DPCPP. Specifically, the metadata requested by the consumer can be provided through the stream parameter, and the producer can perform the necessary synchronization. The pros/cons of different API styles are also discussed in that thread. The API should help for the case where a specific stream/context needs to be exchanged, while not necessarily embedding that info into the data structure itself.
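As a rough illustration of the event-based pattern such a stream exchange enables (a CuPy sketch assuming a CUDA device; this is not the S0 proposal itself, just the kind of producer/consumer synchronization it allows without a device-wide barrier):

```python
import cupy as cp

producer_stream = cp.cuda.Stream(non_blocking=True)
consumer_stream = cp.cuda.Stream(non_blocking=True)

with producer_stream:
    x = cp.arange(1 << 20, dtype=cp.float32) * 2.0   # producer work enqueued on its own stream

# Instead of synchronizing the whole device, the producer records an event and
# the consumer's stream waits on it; after that, the exported tensor is safe to
# use from consumer_stream without further synchronization.
ready = cp.cuda.Event()
ready.record(producer_stream)
consumer_stream.wait_event(ready)

with consumer_stream:
    capsule = x.toDlpack()   # hand off via DLPack; consumer enqueues its kernels on consumer_stream
```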
For those who are interested: Complex number support for DLPack is being added in dmlc/dlpack#58 to get ready for the Array API (see data-apis/array-api#102 / data-apis/array-api#105).
Thanks everyone for the great discussions. This RFC can be closed now that DLPack exchange is officially in the Array API.
Thanks @tqchen!
I know I didn't participate at all in the process but awesome to see this happen, thanks for all the hard work and discussion folks!
Thanks @tqchen and everyone!
In order for an ndarray system to interact with a variety of frameworks, a stable in-memory data structure is needed.
DLPack is one such data structure that allows exchange between major frameworks. It is developed with input from many deep learning system core developers. Highlights include:
The main design rationale of DLPack is minimalism. DLPack drops consideration of the allocator and device API and focuses on the minimal data structure, while still considering the need for cross-hardware support (e.g. the data field is opaque for platforms that do not support normal addressing).
It also simplifies some of the design to remove legacy issues (e.g. everything is assumed to be row-major, strides can be used to support other cases, and this avoids the complexity of considering more layouts).
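For instance, a single strided description covers both row-major data and its transpose without a separate layout flag (note that NumPy reports strides in bytes, while DLPack strides are counted in elements):

```python
import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)  # row-major; byte strides (12, 4)
y = x.T                                           # same buffer, byte strides (4, 12)

# In DLPack terms both views share one data pointer; y would be described with
# shape (3, 2) and element strides (1, 3) instead of requiring a column-major
# layout flag or a copy.
assert y.base is x and y.strides == (4, 12)
```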
After building frameworks around the related data structures for a while, and seeing an ecosystem grow around them, I am quite convinced that DLPack should be one important candidate, if not the best one, for the C ABI array data structure.
Given that array exchange is one goal of the consortium, it would be great to see if DLPack can be used as the stable C ABI structure for array exchange.
If the proposal receives a positive response from the community, we would be more than happy to explore options to build a neutral, open governance model (e.g. like the Apache model) that continues to oversee the evolution of the spec -- for example, donating DLPack to the data-api consortium or hosting it in a clean GitHub org.