First-class support for execution on user-provided CUDA stream & context #6255
The intended way to do it is indeed clobbering those symbols, ideally in the main executable, so that contexts and streams are managed centrally rather than in a shared library. But yeah, we could also make them replaceable by function pointer, as we do with things like `set_custom_do_par_for`.
My main executable is Python and the target library is DALI. There will usually be some DL framework alongside it (PyTorch, TF or MXNet) and possibly other libraries, so the conflict is likely. The fact that there is a helper to do this with PyTorch makes me quite uncomfortable.
It is basically a missing feature. The main feature I've had in mind for a while is to be able to change the names of runtime symbols for a given compilation. This would allow having multiple independent runtimes in the same executable. However, it does not solve the issue of managing separate contexts of execution when using the same runtime implementation. That requires rethinking.
I think that runtime parameters (such as context, stream, threadpool implementation and the like) should be somehow passed to `realize`.
The current way to do that is to make an object that represents your runtime and pass it as a `user_context` argument. Then you have a single override of the runtime that just casts the `void *` user context back to your abstract runtime interface and calls a virtual method on it.
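A minimal sketch of that pattern in AOT mode, assuming an illustrative `MyRuntime` interface of your own (only the `halide_cuda_acquire_context` override point and its rough signature come from Halide's runtime; everything else here is an assumption):

```cpp
#include <cuda.h>

// Illustrative abstract runtime interface; not part of Halide.
struct MyRuntime {
    virtual int acquire_cuda_context(CUcontext *ctx) = 0;
    virtual ~MyRuntime() = default;
};

// Single override of the weak runtime symbol. The void * passed as
// user_context is cast back to the abstract interface.
extern "C" int halide_cuda_acquire_context(void *user_context,
                                           CUcontext *ctx, bool create) {
    MyRuntime *rt = reinterpret_cast<MyRuntime *>(user_context);
    return rt->acquire_cuda_context(ctx);
}
```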
I'm not sure how well all this plays with JIT though. Most of our runtime-replacement shenanigans are in the context of AOT compilation.
Does the trick with clobbering the symbols also work with JIT?
Symbol replacement with JIT is definitely a different beast. It's done at runtime here: https://github.com/halide/Halide/blob/master/src/JITModule.cpp#L184 I'm not sure what the right way is to get a user-provided symbol into that resolution procedure. @zvookin should know. I suppose we could plumb a user-provided symbol map through it.
It's possible that the map passed to `Pipeline::set_jit_externs` might resolve before the runtime modules. If not, we could change it so that it does.
I don't quite follow how `set_jit_externs` would help here.
You might be able to do something like `my_pipeline.set_jit_externs({{"halide_cuda_acquire_context", &my_acquire_context}, ...})` to dynamically override parts of the runtime on a per-Pipeline basis in JIT mode.
Wow, that seems like an abuse... But if it works, then it'll do until a proper solution is in place.
I give it a solid 30% chance of working :)
Alas, no it doesn't. I think this only works for direct entry points that the pipeline calls into the runtime, not for calls that the runtime makes to itself.
Does it mean that shadowing the symbols won't work either?
It will work in an AOT compilation.
AOT is not an option for me - I specifically wanted to use Halide for its JIT. I'm not trying to solve issues related to scheduling/optimization/... but rather configurability and fusion - inherently run-time things.
Here's a (horrifying) workaround for you pending a real solution. You can dynamically find symbols in the runtime and clobber them like so.
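Roughly, with `find_jit_runtime_symbol` as a hypothetical stand-in for however the symbol actually gets resolved out of the JIT'd runtime module:

```cpp
#include <cuda.h>

// Hypothetical: some mechanism that looks up a symbol inside the
// JIT-compiled runtime module by name.
extern void *find_jit_runtime_symbol(const char *name);

void clobber_cuda_context(CUcontext my_ctx) {
    // "context" is the global state variable in cuda.cpp that
    // halide_cuda_acquire_context hands out by default.
    CUcontext *runtime_ctx =
        (CUcontext *)find_jit_runtime_symbol("context");
    *runtime_ctx = my_ctx;  // overwrite the runtime's state variable
}
```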
Here I just assign to the context variable that acquire_context returns by default. You could also possibly clobber the function acquire_context itself, but overwriting memory containing code seems even more horrifying than clobbering the memory containing a state variable.
Note that the above is not at all thread-safe. That state variable is normally protected by a mutex.
I see that cuda.cpp is in itself quite horrifying... a global context doesn't quite cut it in multi-GPU systems. If it were at least thread-local...
The runtimes are supposed to be things that just work for most casual users, but big users provide their own application-specific code to replace much of the runtime. These users use Halide as an AOT compiler, so the JIT story is much less well fleshed out. `thread_local` doesn't really work for us because the Halide pipeline itself may do work in a thread pool.
Let me ask a different question then - how long (ballpark figure) would it take? Of course, if I go ahead with modifying Halide, I can work on my fork, but only if there's a reasonable chance that the changes will be merged in the not-so-distant future (~3 months?).
The simplest fix is to treat `acquire_context` and friends like `custom_malloc`, `custom_do_task`, etc., and the other things in the `JITHandlers` struct that we can override. I think that would take a day or two and be uncontroversial to merge. It would mostly be boilerplatey copying of how `custom_do_task` works. If you did that, the state would be local to the Pipeline object rather than per-realize call, which is not ideal but is at least somewhat useful. If you're not using CPU parallelism I suppose your override could access a thread local. That approach is sort of kicking the can though. Really we need a way to pass a `user_context` object through a `realize` call in JIT mode so that you can pass in your own context object and have your runtime overrides use it meaningfully. Right now in JIT mode we hijack the user context to be pointers to the Pipeline-specific set of overrides, which makes such an approach difficult.
@abadams Update: So far I see that there are two issues. The first seems to be relatively easy to solve, and it would make things prettier at the same time. There are functions like this in JITModule.cpp:

```cpp
int do_task_handler(void *context, halide_task f, int idx,
                    uint8_t *closure) {
    if (context) {
        JITUserContext *jit_user_context = (JITUserContext *)context;
        return (*jit_user_context->handlers.custom_do_task)(context, f, idx, closure);
    } else {
        return (*active_handlers.custom_do_task)(context, f, idx, closure);
    }
}
```
They could be reworked as:

```cpp
int do_task_handler(void *context, halide_task f, int idx,
                    uint8_t *closure) {
    if (context) {
        JITUserContext *jit_user_context = (JITUserContext *)context;
        return (*jit_user_context->handlers.custom_do_task)(jit_user_context->user_context, f, idx, closure);
        // -------------------------------------------------^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    } else {
        return (*active_handlers.custom_do_task)(context, f, idx, closure);
    }
}
```

It looks as though the actual runtimes don't look at `user_context` themselves, so this should be safe. Another thing that needs fixing - and it's, in a way, more complicated - is `ErrorBuffer`, which does use the context pointer.
Currently I'm experimenting with adding a new Param type.
In the draft, I've differentiated the default, or "active", JIT handlers and the custom JIT handlers - the former will receive the JITUserContext pointer as their first argument, whereas the latter will receive the user-provided `user_context`.
@abadams
When jitting, a single shared runtime is created. That doesn't necessarily need to be the case. For any given Pipeline we could swap in a different one, e.g. with certain functions deleted and provided by the user instead. Maybe that would be a reasonable technique. Each JIT user of Halide in a process could use a different runtime. Doesn't really solve the problem in the AOT context though.
In the AOT context we've been kicking around the ability to put runtimes in a custom namespace, again with certain functions deleted, so different users could have their own runtimes customized in different ways with no shared state.
This solution does not exist yet, right? But even assuming we can do it, I can hardly see how we could keep one instance of a CUDA runtime and replace CUDA-specific functions in it. Promoting the context and stream to explicit objects that the user can pass around seems cleaner to me - something along these lines:

```cpp
#include <HalideCuda.h>  // this is an imaginary header
#include <Halide.h>
#include <cuda.h>

using namespace Halide;

int main() {
    cuInit(0);
    Target host_cuda("host-cuda-user_context");
    TargetContext ctx = host_cuda.create_context();  // get some opaque target context handle

    CUcontext cuctx;
    cuDevicePrimaryCtxRetain(&cuctx, 0);
    cuCtxSetCurrent(cuctx);
    CUstream stream;
    cuStreamCreate(&stream, CU_STREAM_DEFAULT);

    Halide::Cuda::SetContext(ctx, cuctx);   // imaginary
    Halide::Cuda::SetStream(ctx, stream);   // imaginary

    Buffer<uint8_t> buf(1024);
    buf.copy_to_device(host_cuda, ctx);

    ImageParam input(UInt(8), 1);
    Func f = ...;  // some pipeline reading from input
    f.compile_jit(host_cuda);
    Buffer<uint8_t> out = f.realize(ctx, {{input, buf}});
    return 0;
}
```

Now, if the user doesn't provide the target context, a default one will be used, so the "just works" approach still just works.
Meanwhile, I've modified the CUDA runtime to use a thread_local CUDA context and stream (in addition to a global one) and functions that allow the user to modify both the global and the thread-local context (and stream). These functions, however, need to be extracted dynamically from the runtime, which is far from perfect. Also, this doesn't always work since you can mix and match GPU acceleration and parallelization.
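Schematically, that kind of modification looks something like this (names assumed; the real cuda.cpp guards this state with a mutex):

```cpp
#include <cuda.h>

// Assumed names: a thread-local override on top of the global context.
static CUcontext global_ctx = nullptr;
static thread_local CUcontext tls_ctx = nullptr;

// The setter that has to be extracted dynamically from the runtime.
extern "C" void my_set_thread_context(CUcontext ctx) {
    tls_ctx = ctx;
}

extern "C" int halide_cuda_acquire_context(void *user_context,
                                           CUcontext *ctx, bool create) {
    // Prefer the thread-local context; fall back to the global one.
    // This breaks when Halide's own thread pool runs GPU work on
    // other threads - the mix-and-match problem mentioned above.
    *ctx = tls_ctx ? tls_ctx : global_ctx;
    return 0;
}
```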
So you're saying you have a short-term workaround? The core devs are still talking about how to make a context object work and what it should contain. The original proposed design (the whole runtime is just an object that you can pass in) had problems with things like dead-stripping not working.
I have a workaround, but it's ugly, as it still requires that I get some functions from the runtime by name. Moreover, it simply doesn't work if I mix GPU acceleration with CPU parallelism. I was thinking a lot about the context objects that are passed about and I think that there are far too few of them, making any sensible customization hard, brittle or impossible, depending on a particular use case.
Could you elaborate? I have to admit I don't know what dead-stripping means in this context. By the way, I don't know what "whole runtime" means here. If it's something like moving all globals from, say, cuda.cpp, to an object, then it's probably going too far the other way. I think CUDA is an excellent example of a runtime where some parts can - and should - be shared, whereas others should be kept in a context object. What should be shared:
What should be kept in a context:
As in the example above, this context could be used not only in `realize`, but also in buffer operations like `copy_to_device`.
The PR #6301 is merely illustrating what works for what I'd call a typical GPU schedule, but I'm very much aware of its lack of generality, as I've stated before. Another (possibly cheap), but also quite ugly, option would be to override the device interface in Target. With that, I could have a runtime defined totally in user code and just pass it to the target. Now, if the device interface pointer is not null, it would be used instead of going through the default symbol resolution.
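A rough sketch of that idea, with `set_device_interface` as a purely hypothetical addition to Target (no such API exists):

```cpp
#include <Halide.h>
using namespace Halide;

// User-provided device interface; field initializers elided.
halide_device_interface_t my_cuda_interface = {
    // device_malloc, device_free, device_sync, ... supplied by user code
};

Target t("host-cuda");
t.set_device_interface(&my_cuda_interface);  // hypothetical API
// Pipelines compiled for t would then call into my_cuda_interface
// directly rather than resolving the default runtime's interface.
```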
I think I have an idea for a simple solution. I'll try to hack something up this morning.
Here is part 1 of a solution: #6313. Part 2 of a short-term solution would be adding `acquire_context` and friends to the `JITHandlers` struct so that they can be hooked per realize call. A longer-term solution would allow overriding any runtime function when JIT-compiling, instead of just a blessed subset.
That's great! I thought about making JITUserContext explicit, but, being a third party, I thought it was too much of a breaking change. Still, the issue of "and friends" remains - which "friends" exactly can we handle without leaking too much backend-specific semantics to the handlers?
Anything that's `extern "C" halide_` in any of the HalideRuntime*.h includes is fair game to expose. Those are the things you can clobber in AOT mode. A bunch of these are indeed cuda-specific (the ones in HalideRuntimeCuda.h), so I'm not worried about having backend-specific things in JITHandlers. I was thinking of adding `halide_cuda_acquire_context`, `halide_cuda_release_context`, and `halide_cuda_get_stream` to HalideRuntimeCuda.h, and then also adding those three functions to the JITHandlers struct.
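A sketch of what those three additions to JITHandlers might look like, with field names assumed by analogy with the existing `custom_*` handlers (`void *` stands in for the CUDA handle types, since JITModule.cpp wouldn't include cuda.h):

```cpp
// Assumed field names, mirroring the existing custom_do_task pattern.
struct JITHandlers {
    // ... existing handlers (custom_print, custom_do_task, ...) ...
    int (*custom_cuda_acquire_context)(void *, void ** /*CUcontext*/, bool){nullptr};
    int (*custom_cuda_release_context)(void *){nullptr};
    int (*custom_cuda_get_stream)(void *, void * /*CUcontext*/, void ** /*CUstream*/){nullptr};
};
```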
@abadams I've looked at your PR and I think there's one major gap there: how to use the same set of runtime overrides in buffer methods like `copy_to_device`.
Oh, I forgot about the overloads that take a DeviceAPI. I guess I'll add variants that take a JITUserContext first arg.
Oh yuck, the existing Buffer methods take the context as the last arg. Hrm.
The Halide::Buffer helpers are supposed to be a direct mirror of the Halide::Runtime::Buffer helpers, except that you pass a `DeviceAPI` instead of a `halide_device_interface_t`. This is not perfectly true because they also need a target arg in some cases to resolve things. Nonetheless, I felt it was most consistent with the existing API to add an optional `JITUserContext *` last arg.
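For illustration, the kind of signature this implies (a sketch, not the exact declaration in Halide.h):

```cpp
// Sketch of a Halide::Buffer helper with the optional trailing context.
int copy_to_device(const DeviceAPI &device,
                   const Target &target = get_jit_target_from_environment(),
                   JITUserContext *context = nullptr);
```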
As I've mentioned before, there's a big "gotcha" there - those methods go through the device interface, which the JIT handler overrides don't cover.
I don't see why that's a gotcha? The default device interface should do fine, if the functions in it are willing to dispatch based on the context, just like with how we handle custom_print, custom_do_par_for, and other runtime functions referenced in JITHandlers.
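A minimal sketch of that dispatch pattern (the `custom_device_malloc` field and `default_device_malloc` fallback are assumptions, not existing Halide names):

```cpp
// Sketch: a device interface function that dispatches on the user
// context, analogous to how custom_print etc. are handled.
extern "C" int dispatching_device_malloc(void *context,
                                         halide_buffer_t *buf) {
    if (context) {
        JITUserContext *jit_ctx = (JITUserContext *)context;
        if (jit_ctx->handlers.custom_device_malloc) {  // assumed field
            return jit_ctx->handlers.custom_device_malloc(context, buf);
        }
    }
    return default_device_malloc(context, buf);  // assumed fallback
}
```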
The sets of functions that we have JIT handlers for and the functions in halide_device_interface_t are disjoint. I didn't have a reason to think they would be hooked the same way, since they are not used the same way.

Device interface:

```cpp
int (*device_malloc)(void *user_context, struct halide_buffer_t *buf);
int (*device_free)(void *user_context, struct halide_buffer_t *buf);
int (*device_sync)(void *user_context, struct halide_buffer_t *buf);
int (*device_release)(void *user_context);
int (*copy_to_host)(void *user_context, struct halide_buffer_t *buf);
int (*copy_to_device)(void *user_context, struct halide_buffer_t *buf);
int (*device_and_host_malloc)(void *user_context, struct halide_buffer_t *buf);
int (*device_and_host_free)(void *user_context, struct halide_buffer_t *buf);
int (*buffer_copy)(void *user_context, struct halide_buffer_t *src,
                   const struct halide_device_interface_t *dst_device_interface,
                   struct halide_buffer_t *dst);
int (*device_crop)(void *user_context,
                   const struct halide_buffer_t *src,
                   struct halide_buffer_t *dst);
int (*device_slice)(void *user_context,
                    const struct halide_buffer_t *src,
                    int slice_dim, int slice_pos,
                    struct halide_buffer_t *dst);
int (*device_release_crop)(void *user_context,
                           struct halide_buffer_t *buf);
int (*wrap_native)(void *user_context, struct halide_buffer_t *buf, uint64_t handle);
int (*detach_native)(void *user_context, struct halide_buffer_t *buf);
```

JIT handlers:

```cpp
void (*custom_print)(void *, const char *){nullptr};
void *(*custom_malloc)(void *, size_t){nullptr};
void (*custom_free)(void *, void *){nullptr};
int (*custom_do_task)(void *, halide_task, int, uint8_t *){nullptr};
int (*custom_do_par_for)(void *, halide_task, int, int, uint8_t *){nullptr};
void (*custom_error)(void *, const char *){nullptr};
int32_t (*custom_trace)(void *, const halide_trace_event_t *){nullptr};
void *(*custom_get_symbol)(const char *name){nullptr};
void *(*custom_load_library)(const char *name){nullptr};
void *(*custom_get_library_symbol)(void *lib, const char *name){nullptr};
```

The functions which I think we should hook for CUDA are different still, but much better suited for the device interface than JIT handlers.
Perhaps we should have a customized device interface, then.
Hello,
I've looked at the documentation and the headers, but there's no mention of CUDA contexts or streams except in the PyTorch helper. The functions `halide_cuda_acquire_context` and `halide_cuda_get_stream` could theoretically do the trick, but they rely on replacing symbols, which is quite brittle, and I'd expect this method to fail catastrophically if more than one shared object in the process defines these symbols. Is there a better way to supply these? Perhaps they could become members of `halide_device_interface_t` or `halide_device_interface_impl_t` (the latter is probably better suited for that). Is there a way to achieve this that I've missed, or is it a missing feature?