Skip to content

Commit

Permalink
[Bugfix] Fix Disco-CUDAGraph Integration
Browse files Browse the repository at this point in the history
This PR fixes a bug introduced in #15827, after which the CUDAGraph's
stream was discarded.
  • Loading branch information
junrushao committed Oct 4, 2023
1 parent 9f0ac49 commit 88536e1
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/runtime/disco/nccl/nccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,19 @@ struct CCLThreadLocalContext {
deviceStream_t default_stream = nullptr;
ncclComm_t comm;

// Destroys the NCCL communicator `comm`; the call is wrapped in NCCL_CALL.
// This version does not touch `default_stream`.
// NOTE(review): appears to be the removed (pre-commit) side of the diff — confirm.
void Clear() { NCCL_CALL(ncclCommDestroy(comm)); }
// Tears down the state held by this context: destroys the NCCL communicator
// `comm` (wrapped in NCCL_CALL) and, when `default_stream` is non-null,
// destroys that stream as well.
void Clear() {
  NCCL_CALL(ncclCommDestroy(comm));
  if (default_stream != nullptr) {
    StreamDestroy(default_stream);
    // Drop the stale handle so that GetDefaultStream() can never hand out a
    // destroyed stream and a repeated Clear() does not destroy it twice.
    default_stream = nullptr;
  }
}

// Always returns a null stream handle; `default_stream` is not consulted.
// NOTE(review): appears to be the removed (pre-commit) side of the diff — confirm.
deviceStream_t GetDefaultStream() { return nullptr; }
// Picks the stream to use for this context.
//
// Looks up the global function "runtime.get_<device>_stream" in the TVM
// registry (the lookup must succeed), calls it with no arguments and treats
// the result as a stream handle. A non-null handle is returned as-is;
// otherwise the context's own `default_stream` is returned.
// NOTE(review): per the commit description this is what keeps the CUDAGraph
// stream from being discarded — confirm against the registered function.
deviceStream_t GetDefaultStream() {
  const auto* stream_getter =
      tvm::runtime::Registry::Get("runtime.get_" TVM_DISCO_DEVICE_NAME "_stream");
  ICHECK(stream_getter != nullptr);
  void* raw_handle = (*stream_getter)().operator void*();
  deviceStream_t current = static_cast<deviceStream_t>(raw_handle);
  if (current != nullptr) {
    return current;
  }
  return default_stream;
}

static CCLThreadLocalContext* Get() {
thread_local static CCLThreadLocalContext ctx;
Expand Down

0 comments on commit 88536e1

Please sign in to comment.