diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index f9d3ca400abe..bfbae4c790ec 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -140,9 +140,19 @@ struct CCLThreadLocalContext {
   deviceStream_t default_stream = nullptr;
   ncclComm_t comm;
 
-  void Clear() { NCCL_CALL(ncclCommDestroy(comm)); }
+  void Clear() {
+    NCCL_CALL(ncclCommDestroy(comm));
+    if (default_stream != nullptr) {
+      StreamDestroy(default_stream);
+    }
+  }
 
-  deviceStream_t GetDefaultStream() { return nullptr; }
+  deviceStream_t GetDefaultStream() {
+    const auto* func = tvm::runtime::Registry::Get("runtime.get_" TVM_DISCO_DEVICE_NAME "_stream");
+    ICHECK(func != nullptr);
+    deviceStream_t stream = static_cast<deviceStream_t>((*func)().operator void*());
+    return stream == nullptr ? default_stream : stream;
+  }
 
   static CCLThreadLocalContext* Get() {
     thread_local static CCLThreadLocalContext ctx;
@@ -171,6 +181,9 @@ void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
   // Step up local context of NCCL
   int device_id = device_ids[worker->worker_id];
   SetDevice(device_id);
+#if TVM_NCCL_RCCL_SWITCH == 0
+  StreamCreate(&ctx->default_stream);
+#endif
   Device device{TVM_DISCO_DEVICE_TYPE, device_id};
   worker->default_device = device;
   worker->ccl = TVM_DISCO_CCL_NAME;