First-class support for execution on user-provided CUDA stream & context #6255

Open
mzient opened this issue Sep 21, 2021 · 47 comments

Comments

@mzient

mzient commented Sep 21, 2021

Hello,
I've looked at the documentation and the headers, but there's no mention of CUDA contexts or streams except in the PyTorch helper. The functions halide_cuda_acquire_context and halide_cuda_get_stream could theoretically do the trick, but they rely on replacing symbols, which is quite brittle, and I'd expect this method to fail catastrophically if more than one shared object in the process defines these symbols. Is there a better way to supply these? Perhaps they could become members of halide_device_interface_t or halide_device_interface_impl_t (the latter is probably better suited for that).

Is there a way to achieve that which I've missed, or is it a missing feature?

@abadams
Member

abadams commented Sep 21, 2021

The intended way to do it is indeed clobbering those symbols, ideally in the main executable so that contexts and streams are managed centrally rather than in a shared library.

But yeah, we could also make them replaceable by function pointer, as we do with things like set_custom_do_par_for

@mzient
Author

mzient commented Sep 21, 2021

> ideally in the main executable so that contexts and streams are managed centrally rather than in a shared library

My main executable is Python and the target library is DALI. There will usually be some DL framework alongside it (PyTorch, TF or MXNet) and possibly other libraries, so a conflict is likely. The fact that there is a helper to do this with PyTorch makes me quite uncomfortable.

@zvookin
Member

zvookin commented Sep 21, 2021

It is basically a missing feature and the set_custom_* stuff doesn't really help. Building the context info into the device structure doesn't make much sense, but the implementation pointers certainly could be different for different replacements of the runtime.

The main feature I've had in mind for a while is to be able to change the names of runtime symbols for a given compilation. This will allow having multiple independent runtimes in the same executable. However, it does not solve the issue of managing separate contexts of execution when using the same runtime implementation. That requires rethinking user_context, which also needs to happen.

@mzient
Author

mzient commented Sep 22, 2021

I think that runtime parameters (such as context, stream, threadpool implementation and the like) should be somehow passed to realize. This would allow using one function instance (and importantly - one compilation) in multiple independent realizations.

@abadams
Member

abadams commented Sep 22, 2021

The current way to do that is to make an object that represents your runtime and pass it as a user_context argument. Then you have a single override of the runtime that just casts the void * user context back to your abstract runtime interface and calls a virtual method on it.
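
A minimal sketch of that pattern, with no Halide dependency, might look like the following. AppRuntime, FixedContextRuntime, DeviceHandle and example_acquire_context are all hypothetical names invented for illustration. In a real AOT build the override would instead use the name and signature that the Halide runtime declares (for example halide_cuda_acquire_context).

```cpp
#include <cassert>

// Hypothetical stand-in for a device handle such as a CUDA context.
using DeviceHandle = void *;

// Hypothetical abstract runtime interface owned by the application.
struct AppRuntime {
    virtual ~AppRuntime() = default;
    // Returns 0 on success and stores the handle in *out.
    virtual int acquire_context(DeviceHandle *out) = 0;
};

// One concrete runtime per execution environment (e.g. per GPU or per stream).
struct FixedContextRuntime : AppRuntime {
    explicit FixedContextRuntime(DeviceHandle h) : handle(h) {}
    int acquire_context(DeviceHandle *out) override {
        *out = handle;
        return 0;
    }
    DeviceHandle handle;
};

// The single override: cast user_context back to the interface and dispatch
// through a virtual call. The name and signature are illustrative only and
// are not Halide's.
extern "C" int example_acquire_context(void *user_context, DeviceHandle *out) {
    assert(user_context != nullptr && out != nullptr);
    return static_cast<AppRuntime *>(user_context)->acquire_context(out);
}
```

The pointer passed as user_context should already have the base type (AppRuntime *), so that the cast back from void * recovers exactly the pointer that went in.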

@abadams
Member

abadams commented Sep 22, 2021

I'm not sure how well all this plays with JIT though. Most of our runtime-replacement shenanigans are in the context of AOT compilation.

@mzient
Author

mzient commented Sep 22, 2021

> I'm not sure how well all this plays with JIT though

Does the trick with halide_cuda_acquire_context and halide_cuda_get_stream work with JIT? Not being able to specify execution environment to a JIT-compiled function is a show stopper for me.

@abadams
Member

abadams commented Sep 22, 2021

Symbol replacement with JIT is definitely a different beast. It's done at runtime here: https://github.com/halide/Halide/blob/master/src/JITModule.cpp#L184

I'm not sure what the right way is to get a user-provided symbol into that resolution procedure. @zvookin should know. I suppose we could plumb a map<string, void *> through with the compile_jit call

@abadams
Member

abadams commented Sep 22, 2021

It's possible that the map passed to Pipeline::set_jit_externs might resolve before the runtime modules. If not we could change it so that it does.

@mzient
Author

mzient commented Sep 22, 2021

I don't quite follow how set_jit_externs would help here, especially with GPU code. I'm not trying to replace a device function.

@abadams
Member

abadams commented Sep 22, 2021

You might be able to do something like my_pipeline.set_jit_externs({{"halide_cuda_acquire_context", &my_acquire_context}, ...}) to dynamically override parts of the runtime on a per-Pipeline basis in JIT mode.

@mzient
Author

mzient commented Sep 22, 2021

Wow, that seems like an abuse... But if it works, then it'll do until a proper solution is in place.

@abadams
Member

abadams commented Sep 22, 2021

I give it a solid 30% chance of working :)

@abadams
Member

abadams commented Sep 22, 2021

Alas, no it doesn't. I think this only works for direct entrypoints that the pipeline calls into the runtime, not for calls that the runtime makes to itself.

@mzient
Author

mzient commented Sep 22, 2021

Does it mean that shadowing halide_cuda_acquire_context at link time will not work, either?

@abadams
Member

abadams commented Sep 22, 2021

It will work in an AOT compilation.

@mzient
Author

mzient commented Sep 22, 2021

AOT is not an option for me - I specifically wanted to use Halide for its JIT. I'm not trying to solve issues related to scheduling/optimization/... but rather configurability and fusion - inherently run-time things.

@abadams
Member

abadams commented Sep 22, 2021

Here's a (horrifying) workaround for you pending a real solution. You can dynamically find symbols in the runtime and clobber them like so.

    auto modules = Internal::JITSharedRuntime::get(nullptr, Target{"host-cuda"}, false);
    for (auto &m : modules) {
        auto s = m.find_symbol_by_name("_ZN6Halide7Runtime8Internal4Cuda7contextE");
        if (s.address) {
            // Found it
            *((uint64_t *)s.address) = my_cuda_context;
        }
    }

Here I just assign to the context variable that acquire_context returns by default. You could also possibly clobber the function acquire_context, but overwriting memory containing code seems even more horrifying than clobbering the memory containing a state variable.

@abadams
Member

abadams commented Sep 22, 2021

Note that the above is not at all thread-safe. That state variable is normally protected by a mutex.

@mzient
Author

mzient commented Sep 22, 2021

I see that cuda.cpp is in itself quite horrifying... global context doesn't quite cut it in multi-gpu systems. If it was at least thread-local...

@abadams
Member

abadams commented Sep 22, 2021

The runtimes are supposed to be things that just work for most casual users, but big users provide their own application-specific code to replace much of the runtime. These users use Halide as an AOT compiler, so the JIT story is much less well fleshed out.

thread_local doesn't really work for us because the Halide pipeline itself may do work in a thread pool.
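
A small illustration of why, using only the standard library. Here current_device is a hypothetical stand-in for a thread-local CUDA context, and no Halide or CUDA calls are involved. A value set on the calling thread is not visible to a worker thread, which is exactly the situation a thread pool creates:

```cpp
#include <thread>

// Hypothetical stand-in for a thread-local CUDA context; -1 means "unset".
thread_local int current_device = -1;

// Reads the thread-local from a freshly started worker thread, the way a
// thread-pool worker would when it runs part of a pipeline.
int device_seen_by_worker() {
    int seen = 0;
    std::thread worker([&seen] { seen = current_device; });
    worker.join();  // join synchronizes, so reading `seen` afterwards is safe
    return seen;
}
```

If the main thread sets current_device = 3 and then calls device_seen_by_worker(), the worker still observes the initial value, because each thread has its own copy of the variable.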

@mzient
Author

mzient commented Sep 22, 2021

Let me ask a different question then - how long (ballpark figure) would it take to:
a) add such functionality
b) review and upstream such functionality, if someone at NVIDIA (likely me) was to issue a PR?

Of course, if I go ahead with modifying Halide, I can work on my fork, but only if there's a reasonable chance that the changes will be merged in the not-so-distant future (~3 months?).

@abadams
Member

abadams commented Sep 22, 2021

The simplest fix is to treat acquire_context and friends like custom_malloc, custom_do_task, etc, and the other things in the JITHandlers struct that we can override. I think that would take a day or two and be uncontroversial to merge. It would mostly be boilerplatey copying of how custom_do_task works. If you did that, the state would be local to the Pipeline object rather than per-realize call, which is not ideal but is at least somewhat useful. If you're not using CPU parallelism I suppose your override could access a thread local.

That approach is sort of kicking the can though. Really we need a way to pass a user_context object through a realize call in JIT mode so that you can pass in your own context object and have your runtime overrides use it meaningfully. Right now in JIT mode we hijack the user context to be pointers to the Pipeline-specific set of overrides, which makes such an approach difficult.

@mzient
Author

mzient commented Oct 1, 2021

@abadams Update:

So far I see that there are two issues. The first seems to be relatively easy to solve and it would make things prettier at the same time. There are functions like this in JITModule.cpp:

int do_task_handler(void *context, halide_task f, int idx,
                    uint8_t *closure) {
    if (context) {
        JITUserContext *jit_user_context = (JITUserContext *)context;
        return (*jit_user_context->handlers.custom_do_task)(context, f, idx, closure);
    } else {
        return (*active_handlers.custom_do_task)(context, f, idx, closure);
    }
}

They could be reworked as:

int do_task_handler(void *context, halide_task f, int idx,
                    uint8_t *closure) {
    if (context) {
        JITUserContext *jit_user_context = (JITUserContext *)context;
        return (*jit_user_context->handlers.custom_do_task)(jit_user_context->user_context, f, idx, closure);
        // -------------------------------------------------^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    } else {
        return (*active_handlers.custom_do_task)(context, f, idx, closure);
    }
}

It looks as though the actual runtimes don't look at user_context - and that's understandable, since they don't know what JITUserContext is - so making this change wouldn't break anything except some extremely hacky user code.

Another thing that needs fixing - and it's, in a way, more complicated - is ErrorBuffer which does use custom_context - this prevents the user from providing custom context without also providing a custom error handler. This could be detected, though, and handled in one of the following ways:

  1. Use stateless error handler
  2. Raise an error clearly stating the issue and suggesting a solution (provide custom error handler).

@mzient
Author

mzient commented Oct 1, 2021

Currently I'm experimenting with adding a new Param type - UserContextParam which bypasses the check_name and cannot be used as an Expr (I can't see a reason to use it in an expression) - it will be passed to ParamMap, so the signatures of realize functions should remain unaffected.

@mzient
Author

mzient commented Oct 5, 2021

In the draft, I've differentiated default, or "active", JIT handlers and custom JIT handlers - the former will receive a JITUserContext pointer as their 1st argument whereas the latter will receive JITUserContext::user_context - this works (read: tests pass) and allows the user to pass the same callbacks to set_custom_xxx in JIT mode and in custom_xxx in halide_device_interface_t in AOT mode. An alternative would be to change the signatures of the JIT handlers that take user_context to explicitly take JITUserContext* instead of void* - after all, it's what they've been receiving all along. This loses the allure of having the same custom runtime function for JIT and AOT, but the default JIT runtime functions would be a little bit less brittle.

@mzient
Author

mzient commented Oct 5, 2021

@abadams
I'm starting to think that the whole approach of doing it via some universal interface (like halide_device_interface_t) is greatly overdoing things. The original problem is CUDA-specific and perhaps all it needs is a CUDA-specific solution. I know it's generally not how things seem to work in Halide - but I don't know if that's a design or just that there never was much drive towards more explicit control over specific runtimes in JIT mode.
A simple solution would be to replace all runtime functions that have weak linkage with function pointers, initialized to point to the default implementations. This would make replacing the runtime as simple as overwriting these pointers with something custom. The problem with this solution is that it still doesn't help with scenarios where multiple libraries in a process use Halide.
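
To make the function-pointer idea concrete, here is a rough sketch that does not depend on Halide. StreamHandle, default_get_stream, get_stream_hook and runtime_get_stream are hypothetical names, not existing runtime symbols. The hook is a single process-wide pointer, so it shows the limitation as well: whichever library writes it last wins.

```cpp
// Hypothetical stand-in for a stream handle such as a CUDA stream.
using StreamHandle = void *;

// Default implementation, standing in for what the stock runtime would do.
int default_get_stream(void *user_context, StreamHandle *stream) {
    (void)user_context;
    *stream = nullptr;  // the null handle stands in for the default stream
    return 0;
}

// Replaceable hook, initialized to point at the default implementation.
int (*get_stream_hook)(void *, StreamHandle *) = default_get_stream;

// Call sites inside the runtime go through the pointer instead of relying on
// weak linkage, so an override is just an assignment to get_stream_hook.
int runtime_get_stream(void *user_context, StreamHandle *stream) {
    return get_stream_hook(user_context, stream);
}
```

An override is installed with a plain assignment such as get_stream_hook = my_get_stream, and restored by assigning default_get_stream back.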

@abadams
Member

abadams commented Oct 7, 2021

When jitting, a single shared runtime is created. That doesn't necessarily need to be the case. For any given Pipeline we could swap in a different one, e.g. with certain functions deleted and provided by the user instead. Maybe that would be a reasonable technique. Each JIT user of Halide in a process could use a different runtime.

Doesn't really solve the problem in the AOT context though.

@abadams
Member

abadams commented Oct 7, 2021

In the AOT context we've been kicking around the ability to put runtimes in a custom namespace, again with certain functions deleted, so different users could have their own runtimes customized in different ways with no shared state.

@mzient
Author

mzient commented Oct 8, 2021

> For any given Pipeline we could swap in a different one, e.g. with certain functions deleted and provided by the user instead.

This solution does not exist yet, right? But even assuming we can do it, I can hardly see how we could keep one instance of a CUDA runtime and replace CUDA-specific functions in it. Promoting things like halide_cuda_get_stream to a generic runtime function doesn't seem right - it makes CUDA-specific concepts creep into the generic interface, which is not exactly good.
I think that maybe there should be some kind of target context, which would default to a shared one, but which could also be explicitly controlled by some target-specific functions. Something like this:

#include <HalideCuda.h>  // this is an imaginary header
#include <Halide.h>
#include <cuda.h>

using namespace Halide;

int main() {
  cuInit(0);  // cuInit takes a flags argument, which must be 0
  Target host_cuda("host-cuda-user_context");
  TargetContext ctx = host_cuda.create_context();  // get some opaque target context handle (imaginary API)
  Buffer<uint8_t> buf;
  CUcontext cuctx;
  cuDevicePrimaryCtxRetain(&cuctx, 0);
  CUstream stream = nullptr;  // placeholder: substitute the stream Halide should use
  Halide::Cuda::SetContext(ctx, cuctx);  // imaginary API
  Halide::Cuda::SetStream(ctx, stream);  // imaginary API
  buf.copy_to_device(host_cuda, ctx);
  Func f = ...;
  f.compile_jit(host_cuda);
  Buffer<uint8_t> out = f.realize(ctx, {{ input, buf }});
  return 0;
}

Now, if the user doesn't provide the target context, a default one will be used, so the "just works" approach still just works.
I know that this changes a lot in how runtimes work, adding an extra parameter to just about any run-time function, but this seems like a clean solution that:
a) allows the runtimes to be configurable without leaking target specific features to generic APIs - TargetContext is an opaque handle
b) allows users to have fine-grained control over target-specific context - we can stuff as much target-specific logic to a runtime as we want (again, because this logic does not leak to the generic interface)
c) allows multiple instances of Halide to work within one process

@mzient
Author

mzient commented Oct 8, 2021

Meanwhile, I've modified the CUDA runtime to use a thread_local CUDA context and stream (in addition to a global one) and added functions that allow the user to modify both the global and the thread-local context (and stream). These functions, however, need to be extracted dynamically from the runtime, which is far from perfect. Also, this doesn't always work since you can mix and match GPU acceleration and parallelization.

@abadams
Member

abadams commented Oct 8, 2021

So you're saying you have a short-term workaround? The core devs are still talking about how to make a context object work and what it should contain. The original proposed design (the whole runtime is just an object that you can pass in) had problems with things like dead-stripping not working.

@mzient
Author

mzient commented Oct 8, 2021

I have a workaround, but it's ugly, as it still requires that I get some functions from the runtime by name. Moreover, it simply doesn't work if I mix .parallel and CUDA offloading - not that it's likely to happen in production, but it's not forbidden, so a solution which doesn't support that would be rather hard to merge.

I was thinking a lot about the context objects that are passed about and I think that there are far too few of them, making any sensible customization hard or brittle or impossible, depending on a particular use case.
The context objects which I can think of are:

  • runtime context (like JITUserContext)
  • target or device context - the one passed to things like halide_xxx_device_malloc and similar
  • user context - passed to user-defined functions in the pipeline (like JIT externs)
    Perhaps we should go even further and pass a separate context pointer for each custom handler instead of having a shared one - this would allow, for example, to have a proper user_context in user-defined handlers and still pass a pointer to ErrorBuffer for error handler.

> The original proposed design (the whole runtime is just an object that you can pass in) had problems with things like dead-stripping not working.

Could you elaborate? I have to admit I don't know what dead-stripping means in this context.

By the way, I don't know what "whole runtime" means here. If it's something like moving all globals from, say, cuda.cpp, to an object, then it's probably going too far the other way. I think CUDA is an excellent example of a runtime where some parts can - and should - be shared, whereas others should be kept in a context object.

What should be shared:

  • compiled kernels
  • per-CUcontext free lists
  • perhaps some pool of per-device contexts

What should be kept in a context:

  • CUDA context
  • CUDA stream

As in the example above, this context could be used not only in realize, but also in things like host/device and device/device copies.
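
The split described above could be sketched roughly as follows. This is only an illustration under stated assumptions: SharedState, ExecutionContext and get_or_compile are hypothetical names, the handles are plain void * stand-ins for CUcontext and CUstream, and "compiling" a kernel is reduced to handing out an integer id.

```cpp
#include <map>
#include <mutex>
#include <string>

using ContextHandle = void *;  // stand-in for CUcontext
using StreamHandle = void *;   // stand-in for CUstream

// Process-wide state that every caller can share: here, a per-context cache
// of compiled kernels, protected by a mutex.
class SharedState {
public:
    // Returns the cached id for a kernel, "compiling" it on first use.
    int get_or_compile(ContextHandle ctx, const std::string &kernel_name) {
        std::lock_guard<std::mutex> lock(mutex_);
        auto &per_context = kernels_[ctx];
        auto it = per_context.find(kernel_name);
        if (it == per_context.end()) {
            it = per_context.emplace(kernel_name, next_id_++).first;
        }
        return it->second;
    }

private:
    std::mutex mutex_;
    std::map<ContextHandle, std::map<std::string, int>> kernels_;
    int next_id_ = 0;
};

// Per-call state that the user controls explicitly and passes along with
// each realization or buffer copy.
struct ExecutionContext {
    ContextHandle cuda_context = nullptr;
    StreamHandle stream = nullptr;
};
```

Two ExecutionContext objects that name the same CUDA context but different streams would then reuse the same cached kernel while still running on separate streams.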

@mzient
Author

mzient commented Oct 9, 2021

The PR #6301 is merely illustrating what works for what I'd call a typical GPU schedule, but I'm very much aware of its lack of generality, as I've stated before.

Another (possibly cheap), but also quite ugly, option would be to override the device interface in Target. With that, I could have a runtime defined totally in user code and just pass it to the target. Now, if the device interface pointer is not null, it would be used instead of going through get_device_interface_for_device_api. Such a custom runtime could have all sorts of non-general features, but being in user code, it would not be a problem.
Of course, this solution has many drawbacks:

  1. It requires the user to reinvent/copy the runtime. The effort of writing ground-up may be hard to justify and copying may pose licensing issues to some customers.
  2. If the device interface changes, these users will have to update their custom runtimes.
  3. In the end, it may not be trivial to implement and require similar effort in codegen to providing a proper device context. The main problem is make_device_interface_call, which selects the device interface by API type, defying attempts to tie the interface to Target.

@mzient
Author

mzient commented Oct 12, 2021

@abadams @zvookin Do you have any opinion on which approach is better (TargetContext or configurable device_interface)? Should I even consider solutions that change the signatures of runtime functions?

@abadams
Member

abadams commented Oct 12, 2021

I think I have an idea for a simple solution. I'll try to hack something up this morning.

@abadams
Member

abadams commented Oct 12, 2021

Here is part 1 of a solution: #6313

Part 2 of a short-term solution would be adding acquire_context and friends to the JITHandlers struct so that they can be hooked per realize call.

A longer term solution would allow for overriding any runtime function when JIT-compiling, instead of just a blessed subset.

@mzient
Author

mzient commented Oct 12, 2021

That's great! I thought about making JITUserContext explicit, but, being a third party, I thought it was too much of a breaking change.

Still, the issue of "and friends" remains - which "friends" exactly can we handle without leaking too much backend-specific semantics to the handlers?

@abadams
Member

abadams commented Oct 13, 2021

Anything that's extern "C" halide_ in any of the HalideRuntime*.h includes is fair game to expose. Those are the things you can clobber in AOT mode. A bunch of these are indeed cuda-specific (the ones in HalideRuntimeCUDA.h), so I'm not worried about having backend-specific things in JITHandlers. I was thinking of adding halide_cuda_acquire_context, halide_cuda_release_context, and halide_cuda_get_stream to HalideRuntimeCUDA.h, and then also adding those three functions to the JITHandlers struct.

@mzient
Author

mzient commented Oct 13, 2021

@abadams I've looked at your PR and I think there's one major gap there: how to use the same set of runtime overrides in Buffer::copy_to_device? Of course, we could add an overload which takes const halide_device_interface_t* instead of DeviceAPI and a user_context, but I think it would be somewhat inconsistent to have one part of the API which is customized by shadowing runtime functions in halide_device_interface_t and another part of the API which implements the same customization with a different method (passing a custom JITUserContext).

@abadams
Member

abadams commented Oct 13, 2021

Oh I forgot about the overloads that take a DeviceAPI. I guess I'll add variants that take a JITUserContext first arg

@abadams
Member

abadams commented Oct 13, 2021

Oh yuck, the existing Buffer methods take the context as the last arg. Hrm.

@abadams
Member

abadams commented Oct 13, 2021

The Halide::Buffer helpers are supposed to be a direct mirror of the Halide::Runtime::Buffer helpers, except that you pass a DeviceAPI instead of a halide_device_interface_t. This is not perfectly true because they also need a target arg in some cases to resolve things. Nonetheless, I felt it was most consistent with the existing API to add an optional JITUserContext * last arg.

@mzient
Author

mzient commented Oct 14, 2021

As I've mentioned before, there's a big "gotcha" there - make_device_interface_call. This happens in codegen and directly uses DeviceAPI, so we still need to pass some kind of context there. That's where I hit the wall - I didn't feel comfortable modifying codegen without at least discussing the approaches. That's also what led me to the conclusion that some kind of target context is probably the cleanest and most future-proof option (but of course I'm just a guest here, so the final design is up to you).

@abadams
Member

abadams commented Oct 14, 2021

I don't see why that's a gotcha? The default device interface should do fine, if the functions in it are willing to dispatch based on the context, just like with how we handle custom_print, custom_do_par_for, and other runtime functions referenced in JITHandlers
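
Roughly, "dispatch based on the context" could look like the sketch below. It is self-contained and does not use Halide: Handlers, CallContext, default_acquire_context and acquire_context are hypothetical names, with CallContext playing the role that JITUserContext plays in JIT mode.

```cpp
#include <cassert>

// Hypothetical stand-in for a device handle such as a CUDA context.
using DeviceHandle = void *;

// Hypothetical per-call overrides; a null pointer means "not overridden".
struct Handlers {
    int (*acquire_context)(void *user_context, DeviceHandle *out) = nullptr;
};

// Hypothetical analogue of JITUserContext: overrides plus the caller's own pointer.
struct CallContext {
    Handlers handlers;
    void *user_context = nullptr;
};

// Stand-in for the stock behaviour: hand out one process-wide handle.
DeviceHandle shared_handle = nullptr;
int default_acquire_context(void *, DeviceHandle *out) {
    *out = shared_handle;
    return 0;
}

// Entry point that the default device interface's functions would call.
// It uses the override from the context if there is one, else the default.
int acquire_context(void *context, DeviceHandle *out) {
    assert(out != nullptr);
    if (context) {
        auto *call = static_cast<CallContext *>(context);
        if (call->handlers.acquire_context) {
            return call->handlers.acquire_context(call->user_context, out);
        }
    }
    return default_acquire_context(context, out);
}
```

With this shape the generated code keeps calling the same default device interface, and only the context pointer that travels with the call decides which implementation runs.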

@mzient
Author

mzient commented Oct 15, 2021

The set of functions that we have JIT handlers for and the set of functions in halide_device_interface_t are disjoint. I didn't have a reason to think they would be hooked the same way, since they are not used the same way.

device interface

    int (*device_malloc)(void *user_context, struct halide_buffer_t *buf);
    int (*device_free)(void *user_context, struct halide_buffer_t *buf);
    int (*device_sync)(void *user_context, struct halide_buffer_t *buf);
    int (*device_release)(void *user_context);
    int (*copy_to_host)(void *user_context, struct halide_buffer_t *buf);
    int (*copy_to_device)(void *user_context, struct halide_buffer_t *buf);
    int (*device_and_host_malloc)(void *user_context, struct halide_buffer_t *buf);
    int (*device_and_host_free)(void *user_context, struct halide_buffer_t *buf);
    int (*buffer_copy)(void *user_context, struct halide_buffer_t *src,
                       const struct halide_device_interface_t *dst_device_interface, struct halide_buffer_t *dst);
    int (*device_crop)(void *user_context,
                       const struct halide_buffer_t *src,
                       struct halide_buffer_t *dst);
    int (*device_slice)(void *user_context,
                        const struct halide_buffer_t *src,
                        int slice_dim, int slice_pos,
                        struct halide_buffer_t *dst);
    int (*device_release_crop)(void *user_context,
                               struct halide_buffer_t *buf);
    int (*wrap_native)(void *user_context, struct halide_buffer_t *buf, uint64_t handle);
    int (*detach_native)(void *user_context, struct halide_buffer_t *buf);

JIT handlers

    void (*custom_print)(void *, const char *){nullptr};
    void *(*custom_malloc)(void *, size_t){nullptr};
    void (*custom_free)(void *, void *){nullptr};
    int (*custom_do_task)(void *, halide_task, int, uint8_t *){nullptr};
    int (*custom_do_par_for)(void *, halide_task, int, int, uint8_t *){nullptr};
    void (*custom_error)(void *, const char *){nullptr};
    int32_t (*custom_trace)(void *, const halide_trace_event_t *){nullptr};
    void *(*custom_get_symbol)(const char *name){nullptr};
    void *(*custom_load_library)(const char *name){nullptr};
    void *(*custom_get_library_symbol)(void *lib, const char *name){nullptr};

The functions which I think we should hook for CUDA are different still, but much better suited for device_interface than jit handlers.

@mzient
Author

mzient commented Oct 15, 2021

Perhaps we should have a customized get_device_interface_for_device_api instead? This one function would solve most of the problems at hand - but then again, we need to be able to use the same device interface in Buffer functions called outside of call_jit_code.
