diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 4e81ce52ef908..22a7bed2fa587 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -29,6 +29,7 @@ from .collective import split # noqa: F401 from .collective import new_group # noqa: F401 +from .collective import is_available # noqa: F401 from .communication import ( stream, @@ -39,9 +40,11 @@ alltoall, alltoall_single, broadcast, + broadcast_object_list, reduce, send, scatter, + scatter_object_list, isend, recv, irecv, @@ -53,6 +56,7 @@ get_group, wait, barrier, + get_backend, ) # noqa: F401 from .auto_parallel import shard_op # noqa: F401 @@ -81,7 +85,9 @@ "spawn", "launch", "scatter", + "scatter_object_list", "broadcast", + "broadcast_object_list", "ParallelEnv", "new_group", "init_parallel_env", @@ -114,4 +120,6 @@ "isend", "irecv", "reduce_scatter", + "is_available", + "get_backend", ] diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 83777bbb2f924..7073758b9d52a 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -307,3 +307,21 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout): paddle.distributed.all_reduce(tmp, sync_op=True) paddle.distributed.wait(tmp) return gp + + +def is_available(): + """ + Check whether the distributed package is available. + + Returns: + Returns True if the distributed package is available, otherwise False. + + Examples: + .. code-block:: python + + import paddle + + print(paddle.distributed.is_available()) + + """ + return core.is_compiled_with_dist() diff --git a/python/paddle/distributed/communication/__init__.py b/python/paddle/distributed/communication/__init__.py index fb3408020d624..1d21e8103353c 100644 --- a/python/paddle/distributed/communication/__init__.py +++ b/python/paddle/distributed/communication/__init__.py @@ -13,11 +13,11 @@ # limitations under the License. from .all_gather import all_gather, all_gather_object from .all_reduce import all_reduce -from .broadcast import broadcast +from .broadcast import broadcast, broadcast_object_list from .reduce import reduce, ReduceOp from .send import send, isend from .recv import recv, irecv -from .scatter import scatter +from .scatter import scatter, scatter_object_list from .batch_isend_irecv import batch_isend_irecv, P2POp from .reduce_scatter import reduce_scatter from .all_to_all import alltoall, alltoall_single @@ -27,4 +27,5 @@ get_group, wait, barrier, + get_backend, ) diff --git a/python/paddle/distributed/communication/all_gather.py b/python/paddle/distributed/communication/all_gather.py index 45496c0e30d96..18f4bbab7ce08 100644 --- a/python/paddle/distributed/communication/all_gather.py +++ b/python/paddle/distributed/communication/all_gather.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import io -import pickle - import numpy as np import paddle import paddle.distributed.communication.stream as stream import paddle.fluid.framework as framework +from .serialization_utils import ( + convert_object_to_tensor, + convert_tensor_to_object, +) + def all_gather(tensor_list, tensor, group=None, sync_op=True): """ @@ -66,20 +68,6 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True): return stream.all_gather(tensor_list, tensor, group, sync_op) -def _convert_object_to_tensor(obj): - _pickler = pickle.Pickler - f = io.BytesIO() - _pickler(f).dump(obj) - data = np.frombuffer(f.getvalue(), dtype=np.uint8) - tensor = paddle.to_tensor(data) - return tensor, tensor.numel() - - -def _convert_tensor_to_object(tensor, len_of_tensor): - _unpickler = pickle.Unpickler - return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load() - - def all_gather_object(object_list, obj, group=None): """ @@ -117,7 +105,7 @@ def all_gather_object(object_list, obj, group=None): framework.in_dygraph_mode() ), "all_gather_object doesn't support static graph mode." - tensor, len_of_tensor = _convert_object_to_tensor(obj) + tensor, len_of_tensor = convert_object_to_tensor(obj) # gather len_of_tensor from all ranks list_len_of_tensor = [] @@ -135,5 +123,5 @@ def all_gather_object(object_list, obj, group=None): all_gather(tensor_list, input_tensor, group) for i, tensor in enumerate(tensor_list): object_list.append( - _convert_tensor_to_object(tensor, list_len_of_tensor[i]) + convert_tensor_to_object(tensor, list_len_of_tensor[i]) ) diff --git a/python/paddle/distributed/communication/broadcast.py b/python/paddle/distributed/communication/broadcast.py index eccd3bf983633..fd6c2219c8b25 100644 --- a/python/paddle/distributed/communication/broadcast.py +++ b/python/paddle/distributed/communication/broadcast.py @@ -12,7 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import paddle +import paddle.distributed as dist import paddle.distributed.communication.stream as stream +import paddle.fluid.framework as framework + +from .serialization_utils import ( + convert_object_to_tensor, + convert_tensor_to_object, +) def broadcast(tensor, src, group=None, sync_op=True): @@ -60,3 +68,70 @@ def broadcast(tensor, src, group=None, sync_op=True): sync_op=sync_op, use_calc_stream=False, ) + + +def broadcast_object_list(object_list, src, group=None): + """ + + Broadcast picklable objects from the source to all others. Similiar to broadcast(), but python object can be passed in. + + Args: + object_list (list): The list of objects to send if current rank is the source, or the list of objects to receive otherwise. + src (int): The source rank in global view. + group (Group): The group instance return by new_group or None for global default group. + + Returns: + None. + + Warning: + This API only supports the dygraph mode. + + Examples: + .. code-block:: python + + # required: distributed + import paddle.distributed as dist + + dist.init_parallel_env() + if dist.get_rank() == 0: + object_list = [{"foo": [1, 2, 3]}] + else: + object_list = [{"bar": [4, 5, 6]}] + dist.broadcast_object_list(object_list, src=1) + print(object_list) + # [{"bar": [4, 5, 6]}] (2 GPUs) + """ + assert ( + framework.in_dygraph_mode() + ), "broadcast_object_list doesn't support static graph mode." + + rank = dist.get_rank() + obj_tensors = [] + obj_nums = len(object_list) + + if rank == src: + obj_sizes = [] + for obj in object_list: + obj_tensor, obj_size = convert_object_to_tensor(obj) + obj_tensors.append(obj_tensor) + obj_sizes.append(obj_size) + obj_size_tensor = paddle.concat(obj_sizes) + else: + obj_size_tensor = paddle.empty([obj_nums], dtype="int64") + broadcast(obj_size_tensor, src) + + if rank == src: + # cast to uint8 to keep the same dtype + obj_data_tensor = paddle.concat(obj_tensors).cast("uint8") + else: + data_len = paddle.sum(obj_size_tensor).item() + obj_data_tensor = paddle.empty([data_len], dtype="uint8") + broadcast(obj_data_tensor, src) + + offset = 0 + for i in range(obj_nums): + data_len = obj_size_tensor[i] + object_list[i] = convert_tensor_to_object( + obj_data_tensor[offset : offset + data_len], data_len + ) + offset += data_len diff --git a/python/paddle/distributed/communication/group.py b/python/paddle/distributed/communication/group.py index a48c0c080f0e5..f0236a2bdbb39 100644 --- a/python/paddle/distributed/communication/group.py +++ b/python/paddle/distributed/communication/group.py @@ -19,7 +19,6 @@ import paddle.fluid.core as core import paddle.fluid.framework as framework import paddle.fluid.layer_helper as layer_helper -from paddle.fluid.framework import in_dygraph_mode class Group: @@ -236,7 +235,7 @@ def get_group(id=0): def _sync_calc_stream(tensor): - if in_dygraph_mode(): + if framework.in_dygraph_mode(): return paddle._legacy_C_ops.c_sync_calc_stream(tensor, tensor) else: op_type = 'c_sync_calc_stream' @@ -249,7 +248,7 @@ def _sync_calc_stream(tensor): def _sync_comm_stream(tensor, ring_id=0): - if in_dygraph_mode(): + if framework.in_dygraph_mode(): return paddle._legacy_C_ops.c_sync_comm_stream( [tensor], [tensor], 'ring_id', ring_id ) @@ -337,7 +336,7 @@ def barrier(group=None): ring_id = 0 if group is None else group.id barrier_tensor = paddle.full([1], 1, dtype="int32") - if in_dygraph_mode(): + if framework.in_dygraph_mode(): return paddle._legacy_C_ops.barrier( barrier_tensor, barrier_tensor, 'ring_id', ring_id ) @@ -352,3 +351,29 @@ def barrier(group=None): outputs={'Out': [barrier_tensor]}, attrs={'ring_id': ring_id}, ) + + +def get_backend(group=None): + """ + Get the backend of given group. + + Args: + group (Group): The group to work on. Use the global group as default. + + Returns: + Returns the name of the given group backend. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + + paddle.distributed.init_parallel_env() + paddle.distributed.get_backend() # NCCL + """ + if _warn_cur_rank_not_in_group(group): + raise RuntimeError("Invalid group specified") + + group = _get_global_group() if group is None else group + return group.backend diff --git a/python/paddle/distributed/communication/scatter.py b/python/paddle/distributed/communication/scatter.py index f551811721490..ee5886c414d61 100644 --- a/python/paddle/distributed/communication/scatter.py +++ b/python/paddle/distributed/communication/scatter.py @@ -12,7 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + +import paddle +import paddle.distributed as dist import paddle.distributed.communication.stream as stream +import paddle.fluid.framework as framework + +from .serialization_utils import ( + convert_object_to_tensor, + convert_tensor_to_object, +) def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True): @@ -59,3 +69,79 @@ def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True): # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1) """ return stream.scatter(tensor, tensor_list, src, group, sync_op) + + +def scatter_object_list( + out_object_list, in_object_list=None, src=0, group=None +): + """ + + Scatter picklable objects from the source to all others. Similiar to scatter(), but python object can be passed in. + + Args: + out_object_list (list): The list of objects to store the scattered objects. + in_object_list (list): The list of objects to scatter. Only objects on the src rank will be scattered. + src (int): The source rank in global view. + group (Group): The group instance return by new_group or None for global default group. + + Returns: + None. + + Warning: + This API only supports the dygraph mode. + + Examples: + .. code-block:: python + + # required: distributed + import paddle.distributed as dist + + dist.init_parallel_env() + out_object_list = [] + if dist.get_rank() == 0: + in_object_list = [{'foo': [1, 2, 3]}, {'foo': [4, 5, 6]}] + else: + in_object_list = [{'bar': [1, 2, 3]}, {'bar': [4, 5, 6]}] + dist.scatter_object_list(out_object_list, in_object_list, src=1) + print(out_object_list) + # [{'bar': [1, 2, 3]}] (2 GPUs, out for rank 0) + # [{'bar': [4, 5, 6]}] (2 GPUs, out for rank 1) + """ + assert ( + framework.in_dygraph_mode() + ), "scatter_object_list doesn't support static graph mode." + + rank = dist.get_rank() + in_obj_tensors = [] + in_obj_sizes = [] + + if rank == src: + for obj in in_object_list: + obj_tensor, obj_size = convert_object_to_tensor(obj) + in_obj_tensors.append(obj_tensor) + in_obj_sizes.append(obj_size) + max_obj_size_tensor = max(in_obj_sizes) + else: + # NOTE: shape can be [] after 0D tensor support + max_obj_size_tensor = paddle.empty([1], dtype="int64") + stream.broadcast(max_obj_size_tensor, src) + max_obj_size = int(max_obj_size_tensor.item()) + + # resize to the same size + in_tensor_list = [] + for tensor in in_obj_tensors: + numpy_data = tensor.numpy() + numpy_data = np.resize(numpy_data, [max_obj_size]) + in_tensor = paddle.to_tensor(numpy_data) + in_tensor_list.append(in_tensor) + out_tensor = paddle.empty([max_obj_size], dtype="uint8") + scatter(out_tensor, in_tensor_list if rank == src else None, src) + + # NOTE: shape can be [] after 0D tensor support + out_tensor_size = paddle.empty([1], dtype="int64") + scatter(out_tensor_size, in_obj_sizes if rank == src else None, src) + + out_object_list.clear() + out_object_list.append( + convert_tensor_to_object(out_tensor, out_tensor_size.item()) + ) diff --git a/python/paddle/distributed/communication/serialization_utils.py b/python/paddle/distributed/communication/serialization_utils.py new file mode 100644 index 0000000000000..f445a3f62bb23 --- /dev/null +++ b/python/paddle/distributed/communication/serialization_utils.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import pickle + +import numpy as np + +import paddle + + +def convert_object_to_tensor(obj): + _pickler = pickle.Pickler + f = io.BytesIO() + _pickler(f).dump(obj) + data = np.frombuffer(f.getvalue(), dtype=np.uint8) + tensor = paddle.to_tensor(data) + return tensor, tensor.numel() + + +def convert_tensor_to_object(tensor, len_of_tensor): + _unpickler = pickle.Unpickler + return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load() diff --git a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt index 4e19583be68ed..12161326cf0a7 100644 --- a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt @@ -127,6 +127,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) set_tests_properties(test_collective_broadcast_api PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST") endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_collective_broadcast_object_list_api MODULES + test_collective_broadcast_object_list_api ENVS + "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_collective_broadcast_object_list_api + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") +endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( test_collective_cpu_barrier_with_gloo MODULES @@ -223,6 +231,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) set_tests_properties(test_collective_scatter_api PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST") endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_collective_scatter_object_list_api MODULES + test_collective_scatter_object_list_api ENVS + "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_collective_scatter_object_list_api + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") +endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( test_collective_sendrecv MODULES test_collective_sendrecv ENVS diff --git a/python/paddle/fluid/tests/unittests/collective/collective_broadcast_object_list_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/collective_broadcast_object_list_api_dygraph.py new file mode 100644 index 0000000000000..7e34818ca1cd1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/collective_broadcast_object_list_api_dygraph.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import test_collective_api_base as test_base + +import paddle.distributed as dist +import paddle.fluid as fluid + + +class TestCollectiveBroadcastObjectListAPI( + test_base.TestCollectiveAPIRunnerBase +): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + object_list = [indata] + dist.broadcast_object_list(object_list, src=1) + return object_list + + +if __name__ == "__main__": + test_base.runtime_main( + TestCollectiveBroadcastObjectListAPI, "broadcast_object_list" + ) diff --git a/python/paddle/fluid/tests/unittests/collective/collective_scatter_object_list_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/collective_scatter_object_list_api_dygraph.py new file mode 100644 index 0000000000000..b53a35feac51a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/collective_scatter_object_list_api_dygraph.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import test_collective_api_base as test_base + +import paddle.distributed as dist +import paddle.fluid as fluid + + +class TestCollectiveScatterObjectListAPI(test_base.TestCollectiveAPIRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + data_len = len(indata) // 2 + in_object_list = [indata[:data_len], indata[data_len:]] + out_object_list = [] + dist.scatter_object_list(out_object_list, in_object_list, src=1) + return out_object_list + + +if __name__ == "__main__": + test_base.runtime_main( + TestCollectiveScatterObjectListAPI, "scatter_object_list" + ) diff --git a/python/paddle/fluid/tests/unittests/collective/process_group_nccl.py b/python/paddle/fluid/tests/unittests/collective/process_group_nccl.py index 6debf636958fe..130510d90cb04 100644 --- a/python/paddle/fluid/tests/unittests/collective/process_group_nccl.py +++ b/python/paddle/fluid/tests/unittests/collective/process_group_nccl.py @@ -46,10 +46,14 @@ def test_create_process_group_nccl(self): device_id = paddle.distributed.ParallelEnv().dev_id paddle.set_device('gpu:%d' % device_id) + assert paddle.distributed.is_available() + pg = init_process_group() print("rank:", pg.rank(), "size:", pg.size(), "name:", pg.name()) print("test new group api ok") + assert paddle.distributed.get_backend() == "NCCL" + # test allreduce sum # rank 0 x = np.random.random(self.shape).astype(self.dtype) diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_allgather_object_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_allgather_object_api.py index 4f7a0c35a860c..1c30baa6d81d5 100644 --- a/python/paddle/fluid/tests/unittests/collective/test_collective_allgather_object_api.py +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_allgather_object_api.py @@ -27,14 +27,7 @@ def test_allgather_nccl(self): "allgather_object", "nccl", static_mode="0", - dtype="pylist", - ) - self.check_with_place( - "collective_allgather_object_api_dygraph.py", - "allgather_object", - "nccl", - static_mode="0", - dtype="pydict", + dtype="pyobject", ) def test_allgather_gloo_dygraph(self): @@ -44,15 +37,7 @@ def test_allgather_gloo_dygraph(self): "gloo", "3", static_mode="0", - dtype="pylist", - ) - self.check_with_place( - "collective_allgather_object_api_dygraph.py", - "allgather_object", - "gloo", - "3", - static_mode="0", - dtype="pydict", + dtype="pyobject", ) diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_object_list_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_object_list_api.py new file mode 100644 index 0000000000000..cb8983208cadf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_object_list_api.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import test_collective_api_base as test_base + + +class TestCollectiveBroadcastObjectListAPI(test_base.TestDistBase): + def _setup_config(self): + pass + + def test_broadcast_nccl(self): + self.check_with_place( + "collective_broadcast_object_list_api_dygraph.py", + "broadcast_object_list", + "nccl", + static_mode="0", + dtype="pyobject", + ) + + def test_broadcast_gloo_dygraph(self): + self.check_with_place( + "collective_broadcast_object_list_api_dygraph.py", + "broadcast_object_list", + "gloo", + "3", + static_mode="0", + dtype="pyobject", + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_scatter_object_list_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_scatter_object_list_api.py new file mode 100644 index 0000000000000..c31378f29785f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_scatter_object_list_api.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import test_collective_api_base as test_base + + +class TestCollectiveScatterObjectListAPI(test_base.TestDistBase): + def _setup_config(self): + pass + + def test_scatter_nccl(self): + self.check_with_place( + "collective_scatter_object_list_api_dygraph.py", + "scatter_object_list", + "nccl", + static_mode="0", + dtype="pyobject", + ) + + def test_scatter_gloo_dygraph(self): + self.check_with_place( + "collective_scatter_object_list_api_dygraph.py", + "scatter_object_list", + "gloo", + "3", + static_mode="0", + dtype="pyobject", + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/testslist.csv b/python/paddle/fluid/tests/unittests/collective/testslist.csv index 5d554aeee880a..6c02e39c42231 100644 --- a/python/paddle/fluid/tests/unittests/collective/testslist.csv +++ b/python/paddle/fluid/tests/unittests/collective/testslist.csv @@ -14,6 +14,7 @@ test_collective_alltoall_single_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,ht test_collective_barrier_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_batch_isend_irecv,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_broadcast_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_broadcast_object_list_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_cpu_barrier_with_gloo,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_global_gather,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_global_scatter,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., @@ -26,6 +27,7 @@ test_collective_reduce_scatter,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_p test_collective_reduce_scatter_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_scatter,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_scatter_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_scatter_object_list_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_sendrecv,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_sendrecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py index d18469c90393c..ecabdb92fcb2a 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py @@ -58,22 +58,15 @@ def create_complex_test_data(shape=None, dtype=None, seed=None): return data -def create_pylist_test_data(shape=None, seed=None): +def create_pyobject_test_data(shape=None, seed=None): if seed: np.random.seed(seed) - # Generate random shape test case for xxx_object api - shape = np.random.randint(0, high=100, size=(2)).tolist() - data = np.random.random(shape).tolist() - return data - - -def create_pydict_test_data(shape=None, seed=None): - if seed: - np.random.seed(seed) - key = [i for i in range(0, shape[0])] - value = np.random.random(shape).tolist() - data = dict(zip(key, value)) - return data + list_shape = np.random.randint(0, high=100, size=(2)).tolist() + list_data = np.random.random(shape).tolist() + dict_key = [i for i in range(0, shape[0])] + dict_val = np.random.random(shape).tolist() + dict_data = dict(zip(dict_key, dict_val)) + return [list_data, dict_data] def create_test_data(shape=None, dtype=None, seed=None): @@ -94,10 +87,8 @@ def create_test_data(shape=None, dtype=None, seed=None): return create_int_test_data(shape=shape, dtype=dtype, seed=seed) elif dtype == "complex64" or dtype == "complex128": return create_complex_test_data(shape=shape, dtype=dtype, seed=seed) - elif dtype == "pylist": - return create_pylist_test_data(shape=shape, seed=seed) - elif dtype == "pydict": - return create_pydict_test_data(shape=shape, seed=seed) + elif dtype == "pyobject": + return create_pyobject_test_data(shape=shape, seed=seed) else: raise NotImplementedError("Unsupported dtype for creating test data.") @@ -342,7 +333,7 @@ def check_with_place( tr_out1 = np.vstack((tr1_out[0], tr1_out[1])) np.testing.assert_allclose(tr_out0, need_result, rtol=1e-05) np.testing.assert_allclose(tr_out1, need_result, rtol=1e-05) - if col_type == "allgather_object": + elif col_type == "allgather_object": need_result = [input1, input2] self.assertEqual(need_result, tr0_out) self.assertEqual(need_result, tr1_out) @@ -350,6 +341,10 @@ def check_with_place( need_result = input2 np.testing.assert_allclose(tr0_out[0], need_result, rtol=1e-05) np.testing.assert_allclose(tr1_out[0], need_result, rtol=1e-05) + elif col_type == "broadcast_object_list": + need_result = [input2] + self.assertEqual(need_result, tr0_out) + self.assertEqual(need_result, tr1_out) elif col_type == "reduce": need_result = input1 + input2 # bfloat16 precision loss comes from truncating the last 16 bits of float32, @@ -365,6 +360,12 @@ def check_with_place( need_result2 = need_result[need_result.shape[0] // 2 :] np.testing.assert_allclose(tr0_out[0], need_result1, rtol=1e-05) np.testing.assert_allclose(tr1_out[0], need_result2, rtol=1e-05) + elif col_type == "scatter_object_list": + need_result = input2 + need_result1 = [need_result[0 : len(need_result) // 2]] + need_result2 = [need_result[len(need_result) // 2 :]] + self.assertEqual(need_result1, tr0_out) + self.assertEqual(need_result2, tr1_out) elif col_type == "reduce_scatter": need_result = input1 + input2 need_result1 = need_result[0 : need_result.shape[0] // 2]