From 8d3e7e2fcb860d8c9b25cb6622f1a0aded7f2218 Mon Sep 17 00:00:00 2001
From: Teng Li
Date: Mon, 22 Oct 2018 16:00:18 -0700
Subject: [PATCH] Move DDP queue_reduction to C++ (#12852)

Summary:
Fully working version, continuing from goldsborough's initial version.
Waiting on the stream guard to be merged before adding more stream perf
logic to the C++ version. (A usage sketch of the new _queue_reduction
binding follows the diff below.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/12852

Differential Revision: D10468696

Pulled By: teng-li

fbshipit-source-id: 8e46d408796973817abfd9dbd6566e0ca5b7a13f
---
 test/test_c10d.py                    |  32 ++++-
 torch/csrc/cuda/nccl.cpp             | 198 ++++++++++++++++++++-------
 torch/csrc/cuda/nccl.h               |  47 +++++--
 torch/csrc/cuda/python_nccl.cpp      |  25 +---
 torch/csrc/distributed/c10d/ddp.cpp  |  37 ++++-
 torch/csrc/distributed/c10d/ddp.h    |  17 ++-
 torch/csrc/distributed/c10d/init.cpp |  19 ++-
 torch/nn/parallel/distributed.py     |  28 +---
 8 files changed, 287 insertions(+), 116 deletions(-)

diff --git a/test/test_c10d.py b/test/test_c10d.py
index 64cf4611cd153..6ce7a32f5826c 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -740,9 +740,9 @@ def test_dist_broadcast_coalesced(self):
         tensors = torch.zeros(10, device=device).chunk(5)
 
         c10d._dist_broadcast_coalesced(
+            process_group,
             tensors,
-            buffer_size=10,
-            process_group=process_group)
+            buffer_size=10)
 
         if not self.is_master:
             self.assertEqual(tensors, target)
@@ -841,6 +841,34 @@ def test_fp16(self):
             any(torch.isinf(p.grad).any() for p in ddp_model.parameters())
         )
 
+    @skip_if_not_nccl
+    def test_queue_reduction(self):
+        # Set up process group.
+        store = c10d.FileStore(self.file.name)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+        # Get this process' split of devices.
+        devices = gpus_for_rank(self.world_size)[self.rank]
+        grads_batch = [(torch.ones(10, device=torch.device('cuda', d)) *
+                        (self.rank + 1)).chunk(5)
+                       for d in devices]
+
+        work, local_grad_sum = c10d._queue_reduction(process_group,
+                                                     grads_batch,
+                                                     devices)
+        # The first return value should be the allreduce work item.
+        self.assertTrue(isinstance(work, c10d.Work))
+        # The second return value will be the finished allreduced gradients.
+        self.assertTrue(isinstance(local_grad_sum, torch.Tensor))
+
+        # Wait for the allreduce to finish.
+        work.wait()
+
+        # The expected result of the allreduce should be the average
+        self.assertEqual(local_grad_sum,
+                         torch.ones(10) * (self.world_size + 1) / 2.0)
+
+
 if __name__ == '__main__':
     assert not torch.cuda._initialized, "test_distributed must not have initialized CUDA context on main process"
diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp
index e769b85f8e2d9..09a004bc2e47e 100644
--- a/torch/csrc/cuda/nccl.cpp
+++ b/torch/csrc/cuda/nccl.cpp
@@ -3,15 +3,20 @@
 #include "torch/csrc/utils/functional.h"
 #include "torch/csrc/utils/hash.h"
 
-#include
-#include
-#include
-#include
 #include
+#include
+
 #include
 #include
 
-namespace torch { namespace cuda { namespace nccl {
+#include
+#include
+#include
+#include
+
+namespace torch {
+namespace cuda {
+namespace nccl {
 
 using namespace at;
 
@@ -27,7 +32,7 @@ struct NcclCommList {
   std::unique_ptr<ncclComm_t[]> comms;
   int ndevices;
   NcclCommList(const std::vector<int>& devices)
-    : comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
+      : comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
     NCCL_CHECK(ncclCommInitAll(comms.get(), devices.size(), devices.data()));
   }
   NcclCommList(NcclCommList&& foo) = default;
@@ -62,10 +67,13 @@ struct NcclCommList {
 using device_list = std::vector<int>;
 // accesses to this object have to be guarded by THC's CudaFreeMutex
-static std::unordered_map<device_list, NcclCommList, torch::hash<device_list>> _communicators;
+static std::unordered_map<device_list, NcclCommList, torch::hash<device_list>>
+    _communicators;
 
 ArrayRef<ncclComm_t> _get_communicators(TensorList inputs) {
-  static auto get_device = [](const at::Tensor& t) -> int { return t.get_device(); };
+  static auto get_device = [](const at::Tensor& t) -> int {
+    return t.get_device();
+  };
   device_list devices = fmap(inputs, get_device);
   auto it = _communicators.find(devices);
   if (it == _communicators.end())
@@ -78,18 +86,30 @@ ncclDataType_t _get_data_type(const Type& type) {
     throw std::runtime_error("Unconvertible NCCL type");
   }
   switch (type.scalarType()) {
-  case at::kFloat : return ncclFloat;
-  case at::kHalf : return ncclHalf;
-  case at::kDouble : return ncclDouble;
-  case at::kLong : return ncclInt64;
-  case at::kInt : return ncclInt;
-  case at::kChar : return ncclChar;
-  case at::kByte : return ncclChar;
-  default: throw std::runtime_error("Unconvertible NCCL type");
+    case at::kFloat:
+      return ncclFloat;
+    case at::kHalf:
+      return ncclHalf;
+    case at::kDouble:
+      return ncclDouble;
+    case at::kLong:
+      return ncclInt64;
+    case at::kInt:
+      return ncclInt;
+    case at::kChar:
+      return ncclChar;
+    case at::kByte:
+      return ncclChar;
+    default:
+      throw std::runtime_error("Unconvertible NCCL type");
   }
 }
 
-void _check_inputs(TensorList inputs, TensorList outputs, int input_multiplier, int output_multiplier) {
+void _check_inputs(
+    TensorList inputs,
+    TensorList outputs,
+    int input_multiplier,
+    int output_multiplier) {
   // len(inputs) == len(outputs)
   size_t len = inputs.size();
 
@@ -99,7 +119,8 @@ void _check_inputs(TensorList inputs, TensorList outputs, int input_multiplier,
 
   if (len != outputs.size()) {
     std::stringstream err;
-    err << "inputs and outputs sequences have to be of the same length, but got input of length " << len << " and output of length " << outputs.size();
+    err << "inputs and outputs sequences have to be of the same length, but got input of length "
+        << len << " and output of length " << outputs.size();
     throw std::runtime_error(err.str());
   }
 
@@ -111,13 +132,15 @@ void _check_inputs(TensorList inputs, TensorList outputs, int input_multiplier,
     auto input = inputs[i];
     auto output = outputs[i];
 
-    if (!(input.type().is_cuda() && !input.type().is_sparse()
-          && output.type().is_cuda() && !output.type().is_sparse())) {
-      throw std::runtime_error("input and output elements have to be cuda dense Tensors");
+    if (!(input.type().is_cuda() && !input.type().is_sparse() &&
+          output.type().is_cuda() && !output.type().is_sparse())) {
+      throw std::runtime_error(
+          "input and output elements have to be cuda dense Tensors");
     }
 
     if (!(type == input.type() && type == output.type())) {
-      throw std::runtime_error("all inputs and outputs must be of the same Tensor type");
+      throw std::runtime_error(
+          "all inputs and outputs must be of the same Tensor type");
     }
 
     if (!input.is_contiguous() || !output.is_contiguous()) {
@@ -138,11 +161,13 @@ void _check_inputs(TensorList inputs, TensorList outputs, int input_multiplier,
 
     // all inputs must be same size
     if (input.numel() != numel) {
-      throw std::runtime_error("all inputs must have the same number of elements");
+      throw std::runtime_error(
+          "all inputs must have the same number of elements");
     }
 
     if (output.numel() * output_multiplier != numel * input_multiplier) {
-      throw std::runtime_error("output must be of size input_size * size_multiplier");
+      throw std::runtime_error(
+          "output must be of size input_size * size_multiplier");
     }
   }
 }
@@ -152,8 +177,8 @@ void _check_inputs(TensorList inputs, TensorList outputs, int input_multiplier,
 bool is_available(TensorList tensors) {
 #ifdef USE_NCCL
   device_set devices;
-  for (auto & tensor : tensors) {
-    auto & type = tensor.type();
+  for (auto& tensor : tensors) {
+    auto& type = tensor.type();
     if (!type.is_cuda() || type.is_sparse())
       return false;
     if (!tensor.is_contiguous())
@@ -180,50 +205,123 @@ std::uint64_t version() {
 }
 
 namespace {
-  // NCCL changed the numerical type used for count between NCCL1 and NCCL2.
-  // So we use the following struct, which gets the type of the second argument
-  // of T, if T is a function type, with ncclBcast, to get that type statically
-  // and programmatically.
+// NCCL changed the numerical type used for count between NCCL1 and NCCL2.
+// So we use the following struct, which gets the type of the second argument
+// of T, if T is a function type, with ncclBcast, to get that type statically
+// and programmatically.
 
-  template<typename T>
-  struct GetSecondArgType;
+template <typename T>
+struct GetSecondArgType;
 
-  template<typename R, typename Arg0, typename Arg1, typename... Args>
-  struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
-    typedef typename std::decay<Arg1>::type type;
-  };
+template <typename R, typename Arg0, typename Arg1, typename... Args>
+struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
+  typedef typename std::decay<Arg1>::type type;
+};
 
-  constexpr auto count_max = std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();
-}
+constexpr auto count_max =
+    std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();
+} // namespace
 
 size_t get_max_count() {
   return count_max;
 }
 
-
-void broadcast(TensorList tensors, const stream_list& streams, const comm_list& user_comms) {
+void broadcast(
+    TensorList tensors,
+    const stream_list& streams,
+    const comm_list& user_comms) {
 #ifdef USE_NCCL
   using namespace torch::cuda::nccl::detail;
   _check_inputs(tensors, tensors, 1, 1);
   ncclDataType_t data_type = _get_data_type(tensors[0].type());
   int64_t numel = tensors[0].numel();
 
-  std::lock_guard<std::mutex> free_mutex(*(THCCachingAllocator_getCudaFreeMutex()));
-  const auto comms = user_comms.empty() ? _get_communicators(tensors) : ArrayRef<ncclComm_t>(user_comms);
+  std::lock_guard<std::mutex> free_mutex(
+      *(THCCachingAllocator_getCudaFreeMutex()));
+  const auto comms = user_comms.empty() ? _get_communicators(tensors)
+                                        : ArrayRef<ncclComm_t>(user_comms);
+
+  auto thcState = at::globalContext().lazyInitCUDA();
 
   at::DeviceGuard device_guard;
   AutoNcclGroup nccl_group_guard;
   for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
     device_guard.set_index(tensors[i].get_device());
-    // TODO: use current stream
-    const auto stream = (streams.empty() || !streams[i]) ? nullptr : THCStream_stream(streams[i]);
-    AT_CHECK(static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
-             "Broadcast tensor has ", numel, " elements, which exceeds the "
-             "maximum NCCL supports (", count_max, ")");
-    NCCL_CHECK(ncclBcast(tensors[i].data_ptr(), numel, data_type, 0, comms[i], stream));
+    const auto stream = (streams.empty() || !streams[i])
+        ? THCState_getCurrentStream(thcState)
+        : THCStream_stream(streams[i]);
+    AT_CHECK(
+        static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
+        "Broadcast tensor has ",
+        numel,
+        " elements, which exceeds the "
+        "maximum NCCL supports (",
+        count_max,
+        ")");
+    NCCL_CHECK(ncclBcast(
+        tensors[i].data_ptr(), numel, data_type, 0, comms[i], stream));
   }
 #else
+  AT_ERROR("PyTorch built without NCCL support");
+#endif
+}
+
+void reduce(
+    const std::vector<at::Tensor>& inputs,
+    std::vector<at::Tensor>& outputs,
+    int32_t root,
+    int32_t op,
+    c10::optional<std::vector<at::cuda::CUDAStream>> streams,
+    c10::optional<std::vector<ncclComm_t>> comms) {
+#ifdef USE_NCCL
+  using namespace torch::cuda::nccl::detail;
+  AT_CHECK(
+      root >= 0 && static_cast<size_t>(root) < inputs.size(), "invalid root");
+
+  _check_inputs(inputs, outputs, 1, 1);
+  const auto len = inputs.size();
+
+  ncclDataType_t data_type = _get_data_type(inputs[0].type());
+
+  const auto count = inputs[0].numel();
+  std::lock_guard<std::mutex> lock(*(THCCachingAllocator_getCudaFreeMutex()));
+  auto comms_ref =
+      comms ? _get_communicators(inputs) : ArrayRef<ncclComm_t>(*comms);
+
+  auto thcState = at::globalContext().lazyInitCUDA();
+
+  at::DeviceGuard device_guard;
+  AutoNcclGroup nccl_group_guard;
+  for (size_t i = 0; i < len; i++) {
+    device_guard.set_index(inputs[i].device().index());
+    // Default to the current THC stream
+    cudaStream_t stream = THCState_getCurrentStream(thcState);
+
+    if (streams && (*streams)[i]) {
+      stream = (*streams)[i].stream();
+    }
+    NCCL_CHECK(ncclReduce(
+        inputs[i].data_ptr(),
+        outputs[i].data_ptr(),
+        count,
+        data_type,
+        (ncclRedOp_t)op,
+        root,
+        comms_ref[i],
+        stream));
   }
 #else
-  throw std::runtime_error("PyTorch built without NCCL support");
+  AT_ERROR("PyTorch built without NCCL support");
 #endif
 }
 
-}}}
+void reduce(
+    std::vector<at::Tensor>& inputs,
+    int32_t root,
+    int32_t op,
+    c10::optional<std::vector<at::cuda::CUDAStream>> streams,
+    c10::optional<std::vector<ncclComm_t>> comms) {
+  reduce(inputs, /*outputs=*/inputs, root, op, streams, comms);
+}
+} // namespace nccl
+} // namespace cuda
+} // namespace torch
diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h
index 349d8bcfdf507..afd4ffc3ac5b8 100644
--- a/torch/csrc/cuda/nccl.h
+++ b/torch/csrc/cuda/nccl.h
@@ -1,10 +1,18 @@
 #pragma once
 
-#include
 #include
+#include
 #include
+#include
+
+#include
+
+#include
+#include
 
-namespace torch { namespace cuda { namespace nccl {
+namespace torch {
+namespace cuda {
+namespace nccl {
 
 // NOTE: this is exposed only so that python_nccl.cpp can use some of these helpers.
 // Don't use them outside of these files.
@@ -32,8 +40,11 @@ struct AutoNcclGroup {
 };
 
 at::ArrayRef<ncclComm_t> _get_communicators(at::TensorList inputs);
-void _check_inputs(at::TensorList inputs, at::TensorList outputs,
-                   int input_multiplier, int output_multiplier);
+void _check_inputs(
+    at::TensorList inputs,
+    at::TensorList outputs,
+    int input_multiplier,
+    int output_multiplier);
 ncclDataType_t _get_data_type(const at::Type& type);
 
 } // namespace detail
@@ -42,11 +53,31 @@ using comm_list = std::vector<ncclComm_t>;
 using stream_list = std::vector<THCStream*>;
 
 std::uint64_t version();
+
 bool is_available(at::TensorList tensors);
-void broadcast(at::TensorList tensors,
-               const stream_list& streams = {},
-               const comm_list& user_comms = {});
+
+void broadcast(
+    at::TensorList tensors,
+    const stream_list& streams = {},
+    const comm_list& user_comms = {});
 
 size_t get_max_count();
 
-}}}
+void reduce(
+    const std::vector<at::Tensor>& inputs,
+    std::vector<at::Tensor>& outputs,
+    int32_t root = 0,
+    int32_t op = ncclSum,
+    at::optional<std::vector<at::cuda::CUDAStream>> streams = c10::nullopt,
+    at::optional<std::vector<ncclComm_t>> user_comms = c10::nullopt);
+
+void reduce(
+    std::vector<at::Tensor>& inputs,
+    int32_t root = 0,
+    int32_t op = ncclSum,
+    c10::optional<std::vector<at::cuda::CUDAStream>> streams = c10::nullopt,
+    c10::optional<std::vector<ncclComm_t>> user_comms = c10::nullopt);
+
+} // namespace nccl
+} // namespace cuda
+} // namespace torch
diff --git a/torch/csrc/cuda/python_nccl.cpp b/torch/csrc/cuda/python_nccl.cpp
index 79a859f20d107..44232aa046867 100644
--- a/torch/csrc/cuda/python_nccl.cpp
+++ b/torch/csrc/cuda/python_nccl.cpp
@@ -7,8 +7,10 @@
 #include "torch/csrc/cuda/THCP.h"
 #include "torch/csrc/cuda/nccl.h"
 #include "torch/csrc/Exceptions.h"
+#include "torch/csrc/utils/functional.h"
 
 #include
+
 #include
 #include
 
@@ -129,29 +131,12 @@ PyObject * THCPModule_nccl_reduce(PyObject *self, PyObject *args) {
 
   std::vector<at::Tensor> inputs = extract_tensors(_inputs);
   std::vector<at::Tensor> outputs = extract_tensors(_outputs);
-  std::vector<THCStream*> streams = unpack_streams(_streams, inputs.size());
+  std::vector<THCStream*> thc_streams = unpack_streams(_streams, inputs.size());
+  std::vector<at::cuda::CUDAStream> streams = fmap<at::cuda::CUDAStream>(thc_streams);
   auto user_comms = unpack_comms(_comms, inputs.size());
-  THPUtils_assert(root >= 0 && (size_t)root < inputs.size(), "invalid root");
-
   with_no_gil([&]{
-    _check_inputs(inputs, outputs, 1, 1);
-    size_t len = inputs.size();
-
-    ncclDataType_t data_type = _get_data_type(inputs[0].type());
-
-    int64_t count = inputs[0].numel();
-    std::lock_guard<std::mutex> lock(*(THCCachingAllocator_getCudaFreeMutex()));
-    auto comms = user_comms.empty() ? _get_communicators(inputs) : ArrayRef<ncclComm_t>(user_comms);
-    at::DeviceGuard device_guard;
-    AutoNcclGroup nccl_group_guard;
-    for (size_t i = 0; i < len; i++) {
-      int device = inputs[i].get_device();
-      device_guard.set_index(device);
-      auto stream = (streams[i] == nullptr) ? nullptr : THCStream_stream(streams[i]);
-      NCCL_CHECK(ncclReduce(inputs[i].data_ptr(), outputs[i].data_ptr(),
-          count, data_type, (ncclRedOp_t) op, root, comms[i], stream));
-    }
+    torch::cuda::nccl::reduce(inputs, outputs, root, op, streams, user_comms);
   });
 
   Py_RETURN_NONE;
diff --git a/torch/csrc/distributed/c10d/ddp.cpp b/torch/csrc/distributed/c10d/ddp.cpp
index 34a63ab613969..d32098d8873ce 100644
--- a/torch/csrc/distributed/c10d/ddp.cpp
+++ b/torch/csrc/distributed/c10d/ddp.cpp
@@ -3,12 +3,16 @@
 #include
 #include
 
+#include
+
 #include
 #include
 
 #include
 #include
+#include
+#include
 #include
 
 namespace c10d {
@@ -30,9 +34,9 @@ void copyBroadcastTensorsToReplicas(
 } // namespace
 
 void distBroadcastCoalesced(
+    ProcessGroup& processGroup,
     std::vector<at::Tensor>& tensors,
-    int64_t bufferSize,
-    ProcessGroup& processGroup) {
+    int64_t bufferSize) {
   auto tensorGroups = torch::utils::take_tensors(tensors, bufferSize);
   // We store single-element vectors in `flatTensors` because
   // `ProcessGroup::broadcast` takes a reference to a vector, which must be
@@ -91,7 +95,7 @@ void syncParams(
 
   if (broadcastBuffers && !bufferData[0].empty()) {
     // Do an inter-node sync first.
-    distBroadcastCoalesced(bufferData[0], broadcastBucketSize, processGroup);
+    distBroadcastCoalesced(processGroup, bufferData[0], broadcastBucketSize);
     // Then an intra-node sync if we have more than one device.
     if (devices.size() > 1) {
       auto result = torch::cuda::broadcast_coalesced(
@@ -101,4 +105,31 @@ void syncParams(
     }
   }
 }
+
+std::tuple<std::shared_ptr<ProcessGroup::Work>, at::Tensor> queueReduction(
+    ProcessGroup& processGroup,
+    std::vector<std::vector<at::Tensor>>& gradsBatch,
+    const std::vector<int64_t>& devices) {
+  AT_ASSERT(!gradsBatch.empty());
+  AT_ASSERT(!devices.empty());
+
+  // TODO: create a copy stream to do the async memory copy and
+  // Intra node reduction with profiler perf work
+  std::vector<at::Tensor> gradsBatchCoalesced;
+  for (size_t devIdx = 0; devIdx < devices.size(); ++devIdx) {
+    at::DeviceGuard guard(devices[devIdx]);
+    gradsBatchCoalesced.push_back(
+        torch::utils::flatten_dense_tensors(gradsBatch[devIdx]));
+  }
+
+  if (devices.size() > 1) {
+    torch::cuda::nccl::reduce(gradsBatchCoalesced, 0);
+  }
+
+  gradsBatchCoalesced[0] /= processGroup.getSize();
+
+  std::vector<at::Tensor> allreduceInput = {gradsBatchCoalesced[0]};
+  auto reductionWork = processGroup.allreduce(allreduceInput);
+
+  return std::make_tuple(reductionWork, gradsBatchCoalesced[0]);
+}
 } // namespace c10d
diff --git a/torch/csrc/distributed/c10d/ddp.h b/torch/csrc/distributed/c10d/ddp.h
index f3aa61d6017bc..af5969e7ca681 100644
--- a/torch/csrc/distributed/c10d/ddp.h
+++ b/torch/csrc/distributed/c10d/ddp.h
@@ -1,20 +1,20 @@
 #pragma once
 
+#include
+
 #include
+#include
 #include
 #include
+#include
 #include
 
-namespace c10d {
-class ProcessGroup;
-} // namespace c10d
-
 namespace c10d {
 
 void distBroadcastCoalesced(
+    ProcessGroup& processGroup,
     std::vector<at::Tensor>& tensors,
-    int64_t bufferSize,
-    ProcessGroup& processGroup);
+    int64_t bufferSize);
 
 void syncParams(
     ProcessGroup& processGroup,
@@ -23,4 +23,9 @@ void syncParams(
     const std::vector<int64_t>& devices,
     int64_t broadcastBucketSize,
     bool broadcastBuffers);
+
+std::tuple<std::shared_ptr<ProcessGroup::Work>, at::Tensor> queueReduction(
+    ProcessGroup& processGroup,
+    std::vector<std::vector<at::Tensor>>& gradsBatch,
+    const std::vector<int64_t>& devices);
 } // namespace c10d
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 72bcac27cf875..f1880f789dc89 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -1,4 +1,4 @@
-#include "torch/csrc/python_headers.h"
+#include <torch/csrc/python_headers.h>
 
 #include
 #include
@@ -33,8 +33,7 @@ template <typename T>
 using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
 
 PyObject* c10d_init(PyObject* _unused) {
-  auto c10d_module =
-      THPObjectPtr(PyImport_ImportModule("torch.distributed"));
+  auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed"));
   if (!c10d_module) {
     throw python_error();
   }
@@ -318,7 +317,8 @@ PyObject* c10d_init(PyObject* _unused) {
             } else if (!interface.empty()) {
               attr.iface = interface;
             } else {
-              // Neither argument is specified; Gloo itself will use the hostname
+              // Neither argument is specified; Gloo itself will use the
+              // hostname
               // Nothing specified, default to something useful
             }
             return ::gloo::transport::tcp::CreateDevice(attr);
@@ -381,10 +381,11 @@ PyObject* c10d_init(PyObject* _unused) {
   module.def(
       "_dist_broadcast_coalesced",
       &::c10d::distBroadcastCoalesced,
+      py::arg("process_group"),
      py::arg("tensors"),
       py::arg("buffer_size"),
-      py::arg("process_group"),
       py::call_guard<py::gil_scoped_release>());
+
   module.def(
       "_sync_params",
       &::c10d::syncParams,
@@ -395,6 +396,14 @@ PyObject* c10d_init(PyObject* _unused) {
       py::arg("broadcast_bucket_size"),
       py::arg("broadcast_buffers"),
       py::call_guard<py::gil_scoped_release>());
+
+  module.def(
+      "_queue_reduction",
+      &::c10d::queueReduction,
+      py::arg("process_group"),
+      py::arg("grads_batch"),
+      py::arg("devices"),
+      py::call_guard<py::gil_scoped_release>());
 #endif
 
   Py_RETURN_TRUE;
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 2358c53e0572c..41ef5d03bf447 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -261,7 +261,7 @@ def train(self, mode=True):
             module.train(mode)
 
     def _dist_broadcast_coalesced(self, tensors, buffer_size):
-        dist._dist_broadcast_coalesced(tensors, buffer_size, self.process_group)
+        dist._dist_broadcast_coalesced(self.process_group, tensors, buffer_size)
 
     def _sync_params(self):
         if len(self.device_ids) > 1:
@@ -355,27 +355,11 @@ def distributed_data_parallel_hook(*unused):
         return distributed_data_parallel_hook
 
     def _queue_reduction(self, bucket_idx):
-        grads_batch = self.buckets[bucket_idx]
-        grads_batch_coalesced = []
-
-        # coalesce the bucket
-        for dev_id, dev_grads_batch in zip(self.device_ids, grads_batch):
-            with torch.cuda.device(dev_id):
-                dev_grads_batch_coalesced = _flatten_dense_tensors(dev_grads_batch)
-                grads_batch_coalesced.append(dev_grads_batch_coalesced)
-
-        # reduce to the first GPU in self.device_ids
-        if len(self.device_ids) > 1:
-            nccl.reduce(grads_batch_coalesced, root=0, streams=self.default_streams)
-
-        # divide by the number of processes here to reduce chances of overflow
-        grads_batch_coalesced[0] /= self.process_group.size()
-
-        # now work on the first gpu
-        reduction_work = self.process_group.allreduce([grads_batch_coalesced[0]],
-                                                      self.allreduce_opts)
-        self.reduction_works[bucket_idx] = reduction_work
-        self.buckets_coalesced[bucket_idx] = grads_batch_coalesced[0]
+        result = dist._queue_reduction(self.process_group,
+                                       self.buckets[bucket_idx],
+                                       self.device_ids)
+        self.reduction_works[bucket_idx] = result[0]
+        self.buckets_coalesced[bucket_idx] = result[1]
 
     def _sync_reduction_works(self):
         # Now only work on the first GPU of self.device_ids, uncoalesce
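
Usage sketch (editor's note, not part of the patch): the new binding can be
driven directly, outside of DistributedDataParallel, in essentially the same
way the added test_queue_reduction test drives it. The following is a minimal
sketch only. The run_rank helper, the RANK/WORLD_SIZE/STORE_PATH environment
variables, the one-GPU-per-rank layout, and the external launcher are all
illustrative assumptions; it also presumes the c10d bindings are exposed on
torch.distributed, as the PyImport_ImportModule("torch.distributed") call in
init.cpp suggests.

    # Hypothetical driver sketch. Assumes an NCCL-enabled build, at least one
    # GPU per rank, and that RANK, WORLD_SIZE, and a shared STORE_PATH are
    # provided by whatever launches one process per rank.
    import os

    import torch
    import torch.distributed as c10d


    def run_rank(rank, world_size, store_path):
        store = c10d.FileStore(store_path)
        process_group = c10d.ProcessGroupNCCL(store, rank, world_size)

        # One device per rank in this sketch; DDP would pass its device_ids.
        devices = [rank % torch.cuda.device_count()]

        # One "bucket" of gradients per device, shaped like DDP's buckets.
        grads_batch = [(torch.ones(10, device=torch.device('cuda', d)) *
                        (rank + 1)).chunk(5)
                       for d in devices]

        # Queue the intra-node reduction plus inter-node allreduce in C++.
        work, local_grad_sum = c10d._queue_reduction(process_group,
                                                     grads_batch,
                                                     devices)
        work.wait()

        # Each rank contributed (rank + 1) and queueReduction divides by the
        # world size, so the result is the average of the contributions.
        expected = torch.ones(10) * (world_size + 1) / 2.0
        print(torch.allclose(local_grad_sum.cpu(), expected))


    if __name__ == '__main__':
        run_rank(int(os.environ['RANK']),
                 int(os.environ['WORLD_SIZE']),
                 os.environ['STORE_PATH'])

Inside DistributedDataParallel itself, the same call is made from
_queue_reduction in torch/nn/parallel/distributed.py, with
self.buckets[bucket_idx] and self.device_ids taking the place of the
hand-built bucket and device list above.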