Move DDP queue_reduction to C++ (pytorch#12852)
Summary:
Fully working version, continuing from goldsborough's initial version.

Waiting on the stream guard to be merged before adding more stream-performance logic to the C++ version.
Pull Request resolved: pytorch#12852

Differential Revision: D10468696

Pulled By: teng-li

fbshipit-source-id: 8e46d408796973817abfd9dbd6566e0ca5b7a13f
teng-li authored and facebook-github-bot committed Oct 22, 2018
1 parent 8682999 commit 8d3e7e2
Showing 8 changed files with 287 additions and 116 deletions.
32 changes: 30 additions & 2 deletions test/test_c10d.py
@@ -740,9 +740,9 @@ def test_dist_broadcast_coalesced(self):
tensors = torch.zeros(10, device=device).chunk(5)

c10d._dist_broadcast_coalesced(
+    process_group,
     tensors,
-    buffer_size=10,
-    process_group=process_group)
+    buffer_size=10)

if not self.is_master:
self.assertEqual(tensors, target)
@@ -841,6 +841,34 @@ def test_fp16(self):
any(torch.isinf(p.grad).any() for p in ddp_model.parameters())
)

@skip_if_not_nccl
def test_queue_reduction(self):
# Set up process group.
store = c10d.FileStore(self.file.name)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

# Get this process' split of devices.
devices = gpus_for_rank(self.world_size)[self.rank]
grads_batch = [(torch.ones(10, device=torch.device('cuda', d)) *
(self.rank + 1)).chunk(5)
for d in devices]

work, local_grad_sum = c10d._queue_reduction(process_group,
grads_batch,
devices)
# The first return value is the allreduce work item.
self.assertTrue(isinstance(work, c10d.Work))
# The second return value is the tensor that will hold the allreduced
# gradients once the work item completes.
self.assertTrue(isinstance(local_grad_sum, torch.Tensor))

# Wait for the allreduce to finish.
work.wait()

# The expected result of the allreduce is the average: rank r contributes
# ones * (r + 1), so the sum over all ranks is world_size * (world_size + 1) / 2,
# and dividing by world_size gives (world_size + 1) / 2.
self.assertEqual(local_grad_sum,
torch.ones(10) * (self.world_size + 1) / 2.0)


if __name__ == '__main__':
assert not torch.cuda._initialized, "test_distributed must not have initialized CUDA context on main process"

198 changes: 148 additions & 50 deletions torch/csrc/cuda/nccl.cpp
@@ -3,15 +3,20 @@
#include "torch/csrc/utils/functional.h"
#include "torch/csrc/utils/hash.h"

#include <unordered_map>
#include <sstream>
#include <limits>
#include <type_traits>
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

#include <THC/THC.h>
#include <THC/THCStream.h>

namespace torch { namespace cuda { namespace nccl {
#include <limits>
#include <sstream>
#include <type_traits>
#include <unordered_map>

namespace torch {
namespace cuda {
namespace nccl {

using namespace at;

@@ -27,7 +32,7 @@ struct NcclCommList {
std::unique_ptr<ncclComm_t[]> comms;
int ndevices;
NcclCommList(const std::vector<int>& devices)
: comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
: comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
NCCL_CHECK(ncclCommInitAll(comms.get(), devices.size(), devices.data()));
}
NcclCommList(NcclCommList&& foo) = default;
@@ -62,10 +67,13 @@ struct NcclCommList {

using device_list = std::vector<int>;
// accesses to this object have to be guarded by THC's CudaFreeMutex
static std::unordered_map<device_list, NcclCommList, torch::hash<device_list>> _communicators;
static std::unordered_map<device_list, NcclCommList, torch::hash<device_list>>
_communicators;

ArrayRef<ncclComm_t> _get_communicators(TensorList inputs) {
static auto get_device = [](const at::Tensor& t) -> int { return t.get_device(); };
static auto get_device = [](const at::Tensor& t) -> int {
return t.get_device();
};
device_list devices = fmap(inputs, get_device);
auto it = _communicators.find(devices);
if (it == _communicators.end())
@@ -78,18 +86,30 @@ ncclDataType_t _get_data_type(const Type& type) {
throw std::runtime_error("Unconvertible NCCL type");
}
switch (type.scalarType()) {
case at::kFloat : return ncclFloat;
case at::kHalf : return ncclHalf;
case at::kDouble : return ncclDouble;
case at::kLong : return ncclInt64;
case at::kInt : return ncclInt;
case at::kChar : return ncclChar;
case at::kByte : return ncclChar;
default: throw std::runtime_error("Unconvertible NCCL type");
case at::kFloat:
return ncclFloat;
case at::kHalf:
return ncclHalf;
case at::kDouble:
return ncclDouble;
case at::kLong:
return ncclInt64;
case at::kInt:
return ncclInt;
case at::kChar:
return ncclChar;
case at::kByte:
return ncclChar;
default:
throw std::runtime_error("Unconvertible NCCL type");
}
}

void _check_inputs(TensorList inputs, TensorList outputs, int input_multiplier, int output_multiplier) {
void _check_inputs(
TensorList inputs,
TensorList outputs,
int input_multiplier,
int output_multiplier) {
// len(inputs) == len(outputs)
size_t len = inputs.size();

@@ -99,7 +119,8 @@ void _check_inputs(TensorList inputs, TensorList outputs, int input_multiplier,

if (len != outputs.size()) {
std::stringstream err;
err << "inputs and outputs sequences have to be of the same length, but got input of length " << len << " and output of length " << outputs.size();
err << "inputs and outputs sequences have to be of the same length, but got input of length "
<< len << " and output of length " << outputs.size();
throw std::runtime_error(err.str());
}

@@ -111,13 +132,15 @@ void _check_inputs(TensorList inputs, TensorList outputs, int input_multiplier,
auto input = inputs[i];
auto output = outputs[i];

if (!(input.type().is_cuda() && !input.type().is_sparse()
&& output.type().is_cuda() && !output.type().is_sparse())) {
throw std::runtime_error("input and output elements have to be cuda dense Tensors");
if (!(input.type().is_cuda() && !input.type().is_sparse() &&
output.type().is_cuda() && !output.type().is_sparse())) {
throw std::runtime_error(
"input and output elements have to be cuda dense Tensors");
}

if (!(type == input.type() && type == output.type())) {
throw std::runtime_error("all inputs and outputs must be of the same Tensor type");
throw std::runtime_error(
"all inputs and outputs must be of the same Tensor type");
}

if (!input.is_contiguous() || !output.is_contiguous()) {
@@ -138,11 +161,13 @@ void _check_inputs(TensorList inputs, TensorList outputs, int input_multiplier,

// all inputs must be same size
if (input.numel() != numel) {
throw std::runtime_error("all inputs must have the same number of elements");
throw std::runtime_error(
"all inputs must have the same number of elements");
}

if (output.numel() * output_multiplier != numel * input_multiplier) {
throw std::runtime_error("output must be of size input_size * size_multiplier");
throw std::runtime_error(
"output must be of size input_size * size_multiplier");
}
}
}
@@ -152,8 +177,8 @@ void _check_inputs(TensorList inputs, TensorList outputs, int input_multiplier,
bool is_available(TensorList tensors) {
#ifdef USE_NCCL
device_set devices;
for (auto & tensor : tensors) {
auto & type = tensor.type();
for (auto& tensor : tensors) {
auto& type = tensor.type();
if (!type.is_cuda() || type.is_sparse())
return false;
if (!tensor.is_contiguous())
Expand All @@ -180,50 +205,123 @@ std::uint64_t version() {
}

namespace {
// NCCL changed the numerical type used for count between NCCL1 and NCCL2.
// So we use the following struct, which gets the type of the second argument
// of T, if T is a function type, with ncclBcast, to get that type statically
// and programmatically.
// NCCL changed the numerical type used for count between NCCL1 and NCCL2.
// So we use the following struct, which gets the type of the second argument
// of T, if T is a function type, with ncclBcast, to get that type statically
// and programmatically.

template<typename T>
struct GetSecondArgType;
template <typename T>
struct GetSecondArgType;

template<typename R, typename Arg0, typename Arg1, typename ...Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
typedef typename std::decay<Arg1>::type type;
};
template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
typedef typename std::decay<Arg1>::type type;
};

constexpr auto count_max = std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();
}
constexpr auto count_max =
std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();
} // namespace

size_t get_max_count() {
return count_max;
}
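
As an aside, the count_max machinery above can be read in isolation. Below is a self-contained sketch of the same GetSecondArgType trick, using a made-up fake_bcast signature in place of ncclBcast (the name and parameters are illustrative only, not part of this diff):

#include <cstddef>
#include <type_traits>

// Extract the type of a function's second parameter, as in nccl.cpp above.
template <typename T>
struct GetSecondArgType;

template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
  typedef typename std::decay<Arg1>::type type;
};

// Stand-in for ncclBcast: the element count is the second parameter.
void fake_bcast(void* buff, size_t count, int root);

// The count type is recovered statically from the function type, so callers
// can take std::numeric_limits<...>::max() without hard-coding int vs. size_t.
static_assert(
    std::is_same<GetSecondArgType<decltype(fake_bcast)>::type, size_t>::value,
    "the second argument of fake_bcast is size_t");

This is why count_max automatically tracks whatever integer type the linked NCCL version uses for its count argument.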


void broadcast(TensorList tensors, const stream_list& streams, const comm_list& user_comms) {
void broadcast(
TensorList tensors,
const stream_list& streams,
const comm_list& user_comms) {
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
_check_inputs(tensors, tensors, 1, 1);
ncclDataType_t data_type = _get_data_type(tensors[0].type());
int64_t numel = tensors[0].numel();

std::lock_guard<std::mutex> free_mutex(*(THCCachingAllocator_getCudaFreeMutex()));
const auto comms = user_comms.empty() ? _get_communicators(tensors) : ArrayRef<ncclComm_t>(user_comms);
std::lock_guard<std::mutex> free_mutex(
*(THCCachingAllocator_getCudaFreeMutex()));
const auto comms = user_comms.empty() ? _get_communicators(tensors)
: ArrayRef<ncclComm_t>(user_comms);

auto thcState = at::globalContext().lazyInitCUDA();
at::DeviceGuard device_guard;
AutoNcclGroup nccl_group_guard;
for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
device_guard.set_index(tensors[i].get_device());
// TODO: use current stream
const auto stream = (streams.empty() || !streams[i]) ? nullptr : THCStream_stream(streams[i]);
AT_CHECK(static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
"Broadcast tensor has ", numel, " elements, which exceeds the "
"maximum NCCL supports (", count_max, ")");
NCCL_CHECK(ncclBcast(tensors[i].data_ptr(), numel, data_type, 0, comms[i], stream));
const auto stream = (streams.empty() || !streams[i])
? THCState_getCurrentStream(thcState)
: THCStream_stream(streams[i]);
AT_CHECK(
static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
"Broadcast tensor has ",
numel,
" elements, which exceeds the "
"maximum NCCL supports (",
count_max,
")");
NCCL_CHECK(ncclBcast(
tensors[i].data_ptr(), numel, data_type, 0, comms[i], stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}

void reduce(
const std::vector<at::Tensor>& inputs,
std::vector<at::Tensor>& outputs,
int32_t root,
int32_t op,
c10::optional<std::vector<at::cuda::CUDAStream>> streams,
c10::optional<std::vector<ncclComm_t>> comms) {
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
AT_CHECK(
root >= 0 && static_cast<size_t>(root) < inputs.size(), "invalid root");

_check_inputs(inputs, outputs, 1, 1);
const auto len = inputs.size();

ncclDataType_t data_type = _get_data_type(inputs[0].type());

const auto count = inputs[0].numel();
std::lock_guard<std::mutex> lock(*(THCCachingAllocator_getCudaFreeMutex()));
auto comms_ref = comms ? ArrayRef<ncclComm_t>(*comms)
                       : _get_communicators(inputs);

auto thcState = at::globalContext().lazyInitCUDA();

at::DeviceGuard device_guard;
AutoNcclGroup nccl_group_guard;
for (size_t i = 0; i < len; i++) {
device_guard.set_index(inputs[i].device().index());
// Default to the current THC stream
cudaStream_t stream = THCState_getCurrentStream(thcState);

if (streams && (*streams)[i]) {
stream = (*streams)[i].stream();
}
NCCL_CHECK(ncclReduce(
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
(ncclRedOp_t)op,
root,
comms_ref[i],
stream));
}
#else
throw std::runtime_error("PyTorch built without NCCL support");
AT_ERROR("PyTorch built without NCCL support");
#endif
}

}}}
void reduce(
std::vector<at::Tensor>& inputs,
int32_t root,
int32_t op,
c10::optional<std::vector<at::cuda::CUDAStream>> streams,
c10::optional<std::vector<ncclComm_t>> comms) {
reduce(inputs, /*outputs=*/inputs, root, op, streams, comms);
}
} // namespace nccl
} // namespace cuda
} // namespace torch
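
For context on the new reduce() entry points above, here is a minimal usage sketch. It assumes a CUDA build with NCCL, at least two visible GPUs, and that the declarations are reachable via torch/csrc/cuda/nccl.h; sum_into_device0 is an illustrative helper, not part of this change:

#include <torch/csrc/cuda/nccl.h>

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

#include <vector>

// Reduce per-device copies of a tensor into the copy on device 0.
// op 0 corresponds to ncclSum; root 0 means the reduced result is written
// into inputs[0], since the single-argument overload reuses inputs as outputs.
void sum_into_device0() {
  std::vector<at::Tensor> inputs = {
      at::ones({10}, at::TensorOptions().device(at::Device(at::kCUDA, 0))),
      at::ones({10}, at::TensorOptions().device(at::Device(at::kCUDA, 1)))};

  torch::cuda::nccl::reduce(
      inputs, /*root=*/0, /*op=*/0, c10::nullopt, c10::nullopt);
}

Passing c10::nullopt for the streams and communicators falls back to the current THC stream on each device and to the cached communicator clique, as in the implementation above.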
(Diffs for the remaining 6 changed files are not shown.)
