diff --git a/paddle/fluid/distributed/collective/ProcessGroupHeter.cc b/paddle/fluid/distributed/collective/ProcessGroupHeter.cc
index 31f9b26e732d1..0911a4a3e3e18 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupHeter.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupHeter.cc
@@ -255,12 +255,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send(
     std::vector<phi::DenseTensor>& in_tensors, int peer) {
-#if defined(PADDLE_WITH_NCCL)
-  PADDLE_ENFORCE_EQ(
-      CheckTensorsInCudaPlace(in_tensors), true,
-      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
-#endif
-
   PADDLE_ENFORCE_EQ(
       in_tensors.size(), 1,
       platform::errors::PreconditionNotMet(
@@ -299,12 +293,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send(
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv(
     std::vector<phi::DenseTensor>& out_tensors, int peer) {
-#if defined(PADDLE_WITH_NCCL)
-  PADDLE_ENFORCE_EQ(
-      CheckTensorsInCudaPlace(out_tensors), true,
-      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
-#endif
-
   PADDLE_ENFORCE_EQ(
       out_tensors.size(), 1,
       platform::errors::PreconditionNotMet(
@@ -343,7 +331,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv(
   end = std::chrono::high_resolution_clock::now();
   diff = end - start;
   VLOG(2) << "Time to copy tensor of dims(" << cpu_tensor.dims()
-          << ") from gpu to cpu for recv " << std::setw(9)
+          << ") from cpu to gpu for recv " << std::setw(9)
           << " is: " << diff.count() << " s" << std::endl;
   return CreateTask(rank_, CommType::RECV, out_tensors);
 }
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cc b/paddle/fluid/operators/collective/recv_v2_op.cc
index 3e8fa631507ab..494665544f0d3 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cc
@@ -44,15 +44,21 @@ class RecvOpV2 : public framework::OperatorWithKernel {
             "The size of the output shape must be greater than 0 "
             "but the value given is %d.",
             out_shape.size()));
-    for (size_t i = 0; i < out_shape.size(); ++i) {
-      PADDLE_ENFORCE_GE(out_shape[i], 1,
-                        platform::errors::InvalidArgument(
-                            "The shape attribute for recv_v2 must be set "
-                            "explicitly, but the %dth element is %d which "
-                            "is less than 1.",
-                            i, out_shape[i]));
+    bool dynamic_shape = ctx->Attrs().Get<bool>("dynamic_shape");
+    if (!dynamic_shape) {
+      // No need to check the output shape when dynamic_shape is set,
+      // since the shape will be received from send_v2 at runtime.
+      for (size_t i = 0; i < out_shape.size(); ++i) {
+        PADDLE_ENFORCE_GE(out_shape[i], 1,
+                          platform::errors::InvalidArgument(
+                              "The shape attribute for recv_v2 must be set "
+                              "explicitly, but the %dth element is %d which "
+                              "is less than 1. Alternatively, set "
+                              "dynamic_shape to True for both send_v2 and "
+                              "recv_v2.",
+                              i, out_shape[i]));
+      }
+      ctx->SetOutputDim("Out", phi::make_ddim(out_shape));
     }
-    ctx->SetOutputDim("Out", phi::make_ddim(out_shape));
   }
 }
@@ -87,6 +93,10 @@ class RecvOpV2Maker : public framework::OpProtoAndCheckerMaker {
         "use_calc_stream",
         "(bool default false) eject CUDA operations to calculation stream.")
         .SetDefault(false);
+    AddAttr<bool>(
+        "dynamic_shape",
+        "(bool default false) the send/recv will be done with dynamic shape.")
+        .SetDefault(false);
     AddComment(R"DOC(
 Recv Operator
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
index 7a2a802382f6c..f7a2e198db938 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -25,6 +25,85 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
+    NCCL_VERSION_CODE >= 2703
+framework::DDim recv_shape_info(const platform::Place &place,
+                                const gpuStream_t &stream,
+                                platform::NCCLComm *comm, const int &peer,
+                                distributed::ProcessGroup *group) {
+  if (!group) {
+    PADDLE_ENFORCE_EQ((stream != nullptr && comm != nullptr), true,
+                      platform::errors::InvalidArgument(
+                          "NCCLComm and Stream should be provided if NCCL is "
+                          "used to recv the shape info."));
+  }
+
+  paddle::experimental::DataType shape_dtype =
+      paddle::experimental::DataType::INT32;
+  ncclDataType_t nccl_dtype =
+      platform::ToNCCLDataType(framework::TransToProtoVarType(shape_dtype));
+
+  // step1: recv the shape size
+  framework::Tensor gpu_shape_size_tensor(shape_dtype);
+  if (!group) {
+    gpu_shape_size_tensor.Resize({1});
+    gpu_shape_size_tensor.mutable_data(place, shape_dtype);
+    auto *gpu_data = gpu_shape_size_tensor.data<int>();
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
+        gpu_data, 1, nccl_dtype, peer, comm->comm(), stream));
+  }
+
+  // copy the shape size tensor to cpu
+  framework::Tensor *cpu_shape_size_tensor = new framework::Tensor(shape_dtype);
+  cpu_shape_size_tensor->Resize({1});
+  cpu_shape_size_tensor->mutable_data(platform::CPUPlace(), shape_dtype);
+  if (group) {
+    std::vector<phi::DenseTensor> shape_size_tensor;
+    shape_size_tensor.emplace_back(*cpu_shape_size_tensor);
+    auto shape_size_task = group->Recv(shape_size_tensor, peer);
+  } else {
+    framework::TensorCopySync(gpu_shape_size_tensor, platform::CPUPlace(),
+                              cpu_shape_size_tensor);
+  }
+  auto *cpu_data = cpu_shape_size_tensor->data<int>();
+  int shape_size = cpu_data[0];
+  VLOG(3) << "recv the shape size: " << shape_size << " from peer";
+
+  // step2: recv the shape
+  framework::Tensor gpu_shape_tensor(shape_dtype);
+  if (!group) {
+    gpu_shape_tensor.Resize({shape_size});
+    gpu_shape_tensor.mutable_data(place, shape_dtype);
+    auto *gpu_shape_data = gpu_shape_tensor.data<int>();
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
+        gpu_shape_data, shape_size, nccl_dtype, peer, comm->comm(), stream));
+  }
+
+  // copy the shape tensor to cpu
+  framework::Tensor *cpu_shape_tensor = new framework::Tensor(shape_dtype);
+  cpu_shape_tensor->Resize({shape_size});
+  cpu_shape_tensor->mutable_data(platform::CPUPlace(), shape_dtype);
+  if (group) {
+    std::vector<phi::DenseTensor> shape_tensor;
+    shape_tensor.emplace_back(*cpu_shape_tensor);
+    auto shape_task = group->Recv(shape_tensor, peer);
+  } else {
+    framework::TensorCopySync(gpu_shape_tensor, platform::CPUPlace(),
+                              cpu_shape_tensor);
+  }
+  auto *cpu_shape_data = cpu_shape_tensor->data<int>();
+  std::vector<int> all_shape;
+  for (int i = 0; i < shape_size; ++i) {
+    all_shape.emplace_back(cpu_shape_data[i]);
+  }
+  framework::DDim new_dim;
+  new_dim = new_dim.reshape(all_shape);
+  VLOG(3) << "recv the shape: (" << new_dim << ") from peer";
+
+  return new_dim;
+}
+#endif
+
 template <typename T>
 class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
  public:
@@ -32,6 +111,7 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
 #if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
     NCCL_VERSION_CODE >= 2703
     int rid = ctx.Attr<int>("ring_id");
+    bool dynamic_shape = ctx.Attr<bool>("dynamic_shape");
     PADDLE_ENFORCE_GE(
         rid, 0,
         platform::errors::InvalidArgument(
@@ -53,7 +133,18 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
       auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
       auto out = ctx.Output<framework::LoDTensor>("Out");
       auto out_dims = out->dims();
-      out->mutable_data<T>(out_dims, place);
+
+      if (dynamic_shape) {
VLOG(3) << "recv_v2 will use dynamic shape with send_v2 for switch"; + framework::DDim new_dim = + recv_shape_info(ctx.GetPlace(), + /* gpuStream_t */ nullptr, + /* NCCLComm* */ nullptr, peer, pg); + out->Resize(new_dim); + out->mutable_data(new_dim, place); + } else { + out->mutable_data(out_dims, place); + } out_tensor.emplace_back(*out); auto task = pg->Recv(out_tensor, peer); @@ -79,6 +170,10 @@ class RecvOpV2CUDAKernel : public framework::OpKernel { auto *out_var = ctx.OutputVar("Out"); if (out_var->IsType()) { + PADDLE_ENFORCE_EQ( + dynamic_shape, false, + platform::errors::InvalidArgument("Dynamic shape for send/recv not " + "support LoDTensorArray for now.")); auto out_array = out_var->GetMutable(); for (size_t idx = 0; idx < out_array->size(); ++idx) { VLOG(3) << "LodTensorArray: idx(" << idx << ")"; @@ -99,7 +194,16 @@ class RecvOpV2CUDAKernel : public framework::OpKernel { auto out_dims = out->dims(); auto numel = out->numel(); - out->mutable_data(out_dims, place); + if (dynamic_shape) { + VLOG(3) << "recv_v2 will use dynamic shape with send_v2"; + framework::DDim new_dim = recv_shape_info(place, stream, comm, peer, + /* ProcessGroup* */ nullptr); + out->Resize(new_dim); + numel = out->numel(); + out->mutable_data(new_dim, place); + } else { + out->mutable_data(out_dims, place); + } PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( out->data(), numel, dtype, peer, comm->comm(), stream)); VLOG(3) << "rank " << comm->rank() << " recv " << phi::product(out->dims()) diff --git a/paddle/fluid/operators/collective/send_v2_op.cc b/paddle/fluid/operators/collective/send_v2_op.cc index 753a33268cc95..d685dd561bc74 100644 --- a/paddle/fluid/operators/collective/send_v2_op.cc +++ b/paddle/fluid/operators/collective/send_v2_op.cc @@ -70,6 +70,10 @@ class SendOpV2Maker : public framework::OpProtoAndCheckerMaker { "use_calc_stream", "(bool default false) eject CUDA operations to calculation stream.") .SetDefault(false); + AddAttr( + "dynamic_shape", + "(bool default false) the send/recv will be done with dynamic shape.") + .SetDefault(false); AddComment(R"DOC( Send Operator diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc index 57a3fe2e45d7e..8878b7c3449b9 100644 --- a/paddle/fluid/operators/collective/send_v2_op.cu.cc +++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc @@ -24,6 +24,76 @@ limitations under the License. 
 namespace paddle {
 namespace operators {
 
+#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
+    NCCL_VERSION_CODE >= 2703
+void send_shape_info(const framework::Tensor& x, const platform::Place& place,
+                     const gpuStream_t& stream, platform::NCCLComm* comm,
+                     const int& peer, distributed::ProcessGroup* group) {
+  if (!group) {
+    PADDLE_ENFORCE_EQ((stream != nullptr && comm != nullptr), true,
+                      platform::errors::InvalidArgument(
+                          "NCCLComm and Stream should be provided if NCCL is "
+                          "used to send the shape info."));
+  }
+  paddle::experimental::DataType shape_dtype =
+      paddle::experimental::DataType::INT32;
+  ncclDataType_t nccl_dtype =
+      platform::ToNCCLDataType(framework::TransToProtoVarType(shape_dtype));
+  auto dims = x.dims();
+  int shape_size = dims.size();
+
+  // step1: send the shape size
+  framework::Tensor cpu_shape_size_tensor(shape_dtype);
+  cpu_shape_size_tensor.Resize({1});
+  cpu_shape_size_tensor.mutable_data(platform::CPUPlace(), shape_dtype);
+  auto* cpu_data = cpu_shape_size_tensor.data<int>();
+  cpu_data[0] = shape_size;
+
+  if (group) {
+    std::vector<phi::DenseTensor> shape_size_tensor;
+    shape_size_tensor.emplace_back(cpu_shape_size_tensor);
+    auto shape_size_task = group->Send(shape_size_tensor, peer);
+  } else {
+    // copy the shape size tensor to gpu and send
+    framework::Tensor* gpu_shape_size_tensor =
+        new framework::Tensor(shape_dtype);
+    gpu_shape_size_tensor->Resize({1});
+    gpu_shape_size_tensor->mutable_data(place, shape_dtype);
+    framework::TensorCopySync(cpu_shape_size_tensor, place,
+                              gpu_shape_size_tensor);
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::ncclSend(gpu_shape_size_tensor->data<int>(), 1,
+                                    nccl_dtype, peer, comm->comm(), stream));
+  }
+  VLOG(3) << "send the shape size: " << shape_size << " to peer";
+
+  // step2: send the shape
+  framework::Tensor cpu_shape_tensor(shape_dtype);
+  cpu_shape_tensor.Resize({shape_size});
+  cpu_shape_tensor.mutable_data(platform::CPUPlace(), shape_dtype);
+  auto* cpu_shape_data = cpu_shape_tensor.data<int>();
+  for (int i = 0; i < shape_size; ++i) {
+    cpu_shape_data[i] = dims[i];
+  }
+
+  if (group) {
+    std::vector<phi::DenseTensor> shape_tensor;
+    shape_tensor.emplace_back(cpu_shape_tensor);
+    auto shape_task = group->Send(shape_tensor, peer);
+  } else {
+    // copy the shape tensor to gpu and send
+    framework::Tensor* gpu_shape_tensor = new framework::Tensor(shape_dtype);
+    gpu_shape_tensor->Resize({shape_size});
+    gpu_shape_tensor->mutable_data(place, shape_dtype);
+    framework::TensorCopySync(cpu_shape_tensor, place, gpu_shape_tensor);
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::ncclSend(gpu_shape_tensor->data<int>(), shape_size,
+                                    nccl_dtype, peer, comm->comm(), stream));
+  }
+  VLOG(3) << "send the shape: (" << dims << ") to peer";
+}
+#endif
+
 template <typename T>
 class SendOpV2CUDAKernel : public framework::OpKernel<T> {
  public:
@@ -31,6 +101,7 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
 #if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
     NCCL_VERSION_CODE >= 2703
     int rid = ctx.Attr<int>("ring_id");
+    bool dynamic_shape = ctx.Attr<bool>("dynamic_shape");
     PADDLE_ENFORCE_GE(
         rid, 0,
         platform::errors::InvalidArgument(
@@ -45,8 +116,17 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
     if (map->has(rid)) {
       // Use ProcessGroup
       distributed::ProcessGroup* pg = map->get(rid);
-      std::vector<phi::DenseTensor> in_tensor;
       auto x = ctx.Input<framework::LoDTensor>("X");
+
+      if (dynamic_shape) {
+        // dynamic shape for switch send/recv
+        VLOG(3) << "send_v2 will use dynamic shape with recv_v2 for switch";
+        send_shape_info(*x, ctx.GetPlace(),
+                        /* gpuStream_t */ nullptr,
+                        /* NCCLComm* */ nullptr, peer, pg);
+      }
+
+      std::vector<phi::DenseTensor> in_tensor;
       in_tensor.push_back(*x);
       auto task = pg->Send(in_tensor, peer);
       return;
@@ -68,6 +148,10 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
 
     auto* x_var = ctx.InputVar("X");
     if (x_var->IsType<framework::LoDTensorArray>()) {
+      PADDLE_ENFORCE_EQ(
+          dynamic_shape, false,
+          platform::errors::InvalidArgument("Dynamic shape for send/recv does "
+                                            "not support LoDTensorArray for now."));
       auto& x_array = x_var->Get<framework::LoDTensorArray>();
       for (size_t idx = 0; idx < x_array.size(); idx++) {
         VLOG(3) << "LodTensorArray: idx(" << idx << ")";
@@ -85,6 +169,12 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
     auto x = ctx.Input<framework::LoDTensor>("X");
     int numel = x->numel();
 
+    if (dynamic_shape) {
+      VLOG(3) << "send_v2 will use dynamic shape with recv_v2";
+      send_shape_info(*x, place, stream, comm, peer,
+                      /* ProcessGroup* */ nullptr);
+    }
+
     ncclDataType_t dtype =
         platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype()));
     PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index a781f314d3f20..cd03e55f25f61 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -337,6 +337,7 @@ def barrier(group=None):
 
 
 def _set_custom_gid(gid):
+    global _custom_gid
     _custom_gid = gid
 
 
@@ -363,6 +364,7 @@ def new_group(ranks=None, backend=None):
            paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)
 
     """
+    global _custom_gid
     global _group_map
     if in_dygraph_mode():
         global _default_group_name
diff --git a/python/paddle/fluid/tests/unittests/collective_sendrecv_op_dynamic_shape.py b/python/paddle/fluid/tests/unittests/collective_sendrecv_op_dynamic_shape.py
new file mode 100644
index 0000000000000..093af635f44f6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective_sendrecv_op_dynamic_shape.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import argparse
+import os
+import sys
+import signal
+import time
+import socket
+from contextlib import closing
+from six import string_types
+import math
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.profiler as profiler
+import paddle.fluid.unique_name as nameGen
+from paddle.fluid import core
+import unittest
+from multiprocessing import Process
+import paddle.fluid.layers as layers
+from functools import reduce
+from test_collective_base import TestCollectiveRunnerBase, runtime_main
+
+paddle.enable_static()
+
+
+class TestCollectiveSendRecvDynamicShape(TestCollectiveRunnerBase):
+    def __init__(self):
+        self.global_ring_id = 0
+
+    def get_model(self, main_prog, startup_program):
+        ring_id = self.global_ring_id
+        with fluid.program_guard(main_prog, startup_program):
+            tindata = layers.data(
+                name="tindata",
+                shape=[10, 1000],
+                dtype='float64',
+                append_batch_size=False)
+            if self.rank == 0:
+                main_prog.global_block().append_op(
+                    type="send_v2",
+                    inputs={'X': tindata},
+                    attrs={
+                        'ring_id': ring_id,
+                        'peer': 1,
+                        'use_calc_stream': True,
+                        'dynamic_shape': True
+                    })
+            else:
+                main_prog.global_block().append_op(
+                    type="recv_v2",
+                    outputs={'Out': tindata},
+                    attrs={
+                        'peer': 0,
+                        'ring_id': ring_id,
+                        'dtype': tindata.dtype,
+                        'out_shape': tindata.shape,
+                        'use_calc_stream': True,
+                        'dynamic_shape': True
+                    })
+        return tindata
+
+
+if __name__ == "__main__":
+    runtime_main(TestCollectiveSendRecvDynamicShape, "sendrecv_dynamic_shape",
+                 0)
diff --git a/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py b/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py
index 40bacaf59d2f3..d3bcd0a7e6985 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py
@@ -29,6 +29,10 @@ def _setup_config(self):
     def test_sendrecv(self):
         self.check_with_place("collective_sendrecv_op.py", "sendrecv")
 
+    def test_sendrecv_dynamic_shape(self):
+        self.check_with_place("collective_sendrecv_op_dynamic_shape.py",
+                              "sendrecv_dynamic_shape")
+
     def test_sendrecv_array(self):
         self.check_with_place("collective_sendrecv_op_array.py",
                               "sendrecv_array")
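
Reviewer note (not part of the patch): the dynamic_shape path added above boils down to a small metadata handshake that runs before the payload transfer. The sender first transmits the number of dimensions as a one-element int32 tensor, then the dims themselves, and only then the data; the receiver resizes its output tensor from that metadata before posting the real recv. The sketch below illustrates that protocol in plain Python with an in-process queue standing in for the NCCL/ProcessGroup channel; the helper names (send_with_dynamic_shape, recv_with_dynamic_shape) are illustrative only and are not part of Paddle's API.

```python
# Minimal sketch of the shape handshake, assuming a reliable ordered channel.
# A queue stands in for the point-to-point NCCL stream used by send_v2/recv_v2.
import queue

import numpy as np

channel = queue.Queue()  # stand-in for the peer-to-peer communication channel


def send_with_dynamic_shape(x: np.ndarray) -> None:
    # step1: send the number of dimensions (the patch calls this shape_size)
    channel.put(np.array([x.ndim], dtype=np.int32))
    # step2: send the dims themselves as an int32 tensor
    channel.put(np.array(x.shape, dtype=np.int32))
    # step3: send the flattened payload, as the data transfer does afterwards
    channel.put(x.reshape(-1))


def recv_with_dynamic_shape() -> np.ndarray:
    # step1: receive the number of dimensions
    ndim = int(channel.get()[0])
    # step2: receive the dims and size the output buffer accordingly
    dims = tuple(int(d) for d in channel.get()[:ndim])
    # step3: receive the payload and restore the original shape
    flat = channel.get()
    return flat.reshape(dims)


if __name__ == "__main__":
    sent = np.random.rand(7, 3).astype("float64")
    send_with_dynamic_shape(sent)
    received = recv_with_dynamic_shape()
    assert received.shape == (7, 3) and np.allclose(sent, received)
```

The extra round trip is the cost of not having to set a fully specified out_shape attribute on recv_v2 ahead of time, which is why the InferShape check is skipped when dynamic_shape is true.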
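Also worth calling out: the one-line `global _custom_gid` additions in collective.py are a real fix, because assigning to `_custom_gid` inside a function without the `global` declaration only binds a function-local name and leaves the module-level value untouched. A standalone illustration of the pitfall (hypothetical helper names, not Paddle code):

```python
# Demonstrates why the `global` declaration is needed when a function assigns
# to a module-level variable.
_custom_gid = None


def set_custom_gid_broken(gid):
    _custom_gid = gid  # creates a new local variable; the global is unchanged


def set_custom_gid_fixed(gid):
    global _custom_gid
    _custom_gid = gid  # rebinds the module-level name as intended


set_custom_gid_broken(7)
assert _custom_gid is None  # the version without `global` silently does nothing
set_custom_gid_fixed(7)
assert _custom_gid == 7
```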