Skip to content

Commit

Permalink
Eliminate CUDAStream nullptr in NCCL (pytorch#13089)
Browse files Browse the repository at this point in the history
Summary:
As the title says, we should always use the current stream on the device in NCCL.

This unblocks ezyang's further work.
Pull Request resolved: pytorch#13089

Reviewed By: ezyang

Differential Revision: D10847172

Pulled By: teng-li

fbshipit-source-id: 7fc7c4248b5efa1971d2af4d43f62d3379debfe4
  • Loading branch information
teng-li authored and facebook-github-bot committed Oct 25, 2018
1 parent fc1c8f8 commit b4d0dc7
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 75 deletions.
16 changes: 8 additions & 8 deletions torch/csrc/cuda/nccl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,13 @@ void broadcast(
const auto comms = user_comms.empty() ? _get_communicators(tensors)
: ArrayRef<ncclComm_t>(user_comms);

auto thcState = at::globalContext().lazyInitCUDA();
at::DeviceGuard device_guard;
AutoNcclGroup nccl_group_guard;
for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
device_guard.set_index(tensors[i].get_device());
int device = tensors[i].get_device();
device_guard.set_index(device);
const auto stream = (streams.empty() || !streams[i])
? THCState_getCurrentStream(thcState)
? at::cuda::getCurrentCUDAStream(device).stream()
: THCStream_stream(streams[i]);
AT_CHECK(
static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
Expand Down Expand Up @@ -287,14 +287,14 @@ void reduce(
auto comms_ref =
comms ? _get_communicators(inputs) : ArrayRef<ncclComm_t>(*comms);

auto thcState = at::globalContext().lazyInitCUDA();

at::DeviceGuard device_guard;
AutoNcclGroup nccl_group_guard;
for (size_t i = 0; i < len; i++) {
device_guard.set_index(inputs[i].device().index());
// Default to the current THC stream
cudaStream_t stream = THCState_getCurrentStream(thcState);
int device = inputs[i].device().index();
device_guard.set_index(device);

// Default to the current stream
cudaStream_t stream = at::cuda::getCurrentCUDAStream(device).stream();

if (streams && (*streams)[i]) {
stream = (*streams)[i].stream();
Expand Down
Loading

0 comments on commit b4d0dc7

Please sign in to comment.