Skip to content

Commit

Permalink
[Bugfix] Fix Disco-CUDAGraph Integration
Browse files Browse the repository at this point in the history
This PR fixes a bug introduced in #15827, after which the CUDA graph's
stream was discarded.
  • Loading branch information
junrushao committed Oct 4, 2023
1 parent 9f0ac49 commit e79e582
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/runtime/disco/nccl/nccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,19 @@ struct CCLThreadLocalContext {
deviceStream_t default_stream = nullptr;
ncclComm_t comm;

void Clear() { NCCL_CALL(ncclCommDestroy(comm)); }
/*!
 * \brief Release the resources held by this context.
 *
 * Destroys the communicator and, when one is set, the default stream.
 */
void Clear() {
  NCCL_CALL(ncclCommDestroy(comm));
  if (default_stream != nullptr) {
    StreamDestroy(default_stream);
    // Forget the destroyed handle: GetDefaultStream() falls back to
    // default_stream, so leaving it set would hand out a dead stream and
    // would make this branch destroy it a second time.
    default_stream = nullptr;
  }
}

deviceStream_t GetDefaultStream() { return nullptr; }
/*!
 * \brief Pick the stream to use.
 *
 * Calls the globally registered function
 * "runtime.get_" TVM_DISCO_DEVICE_NAME "_stream" and uses the handle it
 * returns; when that handle is null, default_stream is used instead.
 *
 * \return The handle from the registered function if non-null, otherwise default_stream.
 */
deviceStream_t GetDefaultStream() {
  const auto* get_stream_fn =
      tvm::runtime::Registry::Get("runtime.get_" TVM_DISCO_DEVICE_NAME "_stream");
  // The lookup must succeed; a missing registration is a hard error.
  ICHECK(get_stream_fn != nullptr);
  void* raw_handle = (*get_stream_fn)().operator void*();
  if (raw_handle != nullptr) {
    // A non-null handle from the registry takes precedence.
    return static_cast<deviceStream_t>(raw_handle);
  }
  return default_stream;
}

static CCLThreadLocalContext* Get() {
thread_local static CCLThreadLocalContext ctx;
Expand Down Expand Up @@ -171,6 +181,9 @@ void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
// Step up local context of NCCL
int device_id = device_ids[worker->worker_id];
SetDevice(device_id);
#if TVM_NCCL_RCCL_SWITCH == 0
StreamCreate(&ctx->default_stream);
#endif
Device device{TVM_DISCO_DEVICE_TYPE, device_id};
worker->default_device = device;
worker->ccl = TVM_DISCO_CCL_NAME;
Expand Down

0 comments on commit e79e582

Please sign in to comment.