Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add collective communication APIs to improve completeness #49252

Merged
merged 5 commits into from
Dec 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/paddle/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from .collective import split # noqa: F401
from .collective import new_group # noqa: F401
from .collective import is_available # noqa: F401

from .communication import (
stream,
Expand All @@ -39,9 +40,11 @@
alltoall,
alltoall_single,
broadcast,
broadcast_object_list,
reduce,
send,
scatter,
scatter_object_list,
isend,
recv,
irecv,
Expand All @@ -53,6 +56,7 @@
get_group,
wait,
barrier,
get_backend,
) # noqa: F401

from .auto_parallel import shard_op # noqa: F401
Expand Down Expand Up @@ -81,7 +85,9 @@
"spawn",
"launch",
"scatter",
"scatter_object_list",
"broadcast",
"broadcast_object_list",
"ParallelEnv",
"new_group",
"init_parallel_env",
Expand Down Expand Up @@ -114,4 +120,6 @@
"isend",
"irecv",
"reduce_scatter",
"is_available",
"get_backend",
]
18 changes: 18 additions & 0 deletions python/paddle/distributed/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,21 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
paddle.distributed.all_reduce(tmp, sync_op=True)
paddle.distributed.wait(tmp)
return gp


def is_available():
    """
    Query whether the distributed package can be used.

    Returns:
        bool: True when Paddle was compiled with distributed support,
        otherwise False.

    Examples:
        .. code-block:: python

            import paddle

            print(paddle.distributed.is_available())

    """
    return core.is_compiled_with_dist()
5 changes: 3 additions & 2 deletions python/paddle/distributed/communication/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.
from .all_gather import all_gather, all_gather_object
from .all_reduce import all_reduce
from .broadcast import broadcast
from .broadcast import broadcast, broadcast_object_list
from .reduce import reduce, ReduceOp
from .send import send, isend
from .recv import recv, irecv
from .scatter import scatter
from .scatter import scatter, scatter_object_list
from .batch_isend_irecv import batch_isend_irecv, P2POp
from .reduce_scatter import reduce_scatter
from .all_to_all import alltoall, alltoall_single
Expand All @@ -27,4 +27,5 @@
get_group,
wait,
barrier,
get_backend,
)
26 changes: 7 additions & 19 deletions python/paddle/distributed/communication/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import pickle

import numpy as np

import paddle
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework

from .serialization_utils import (
convert_object_to_tensor,
convert_tensor_to_object,
)


def all_gather(tensor_list, tensor, group=None, sync_op=True):
"""
Expand Down Expand Up @@ -66,20 +68,6 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
return stream.all_gather(tensor_list, tensor, group, sync_op)


def _convert_object_to_tensor(obj):
    # Pickle the object (default protocol, same bytes as Pickler().dump),
    # then view the raw byte string as a uint8 numpy array and wrap it
    # in a paddle tensor. Returns the tensor together with its length.
    serialized = pickle.dumps(obj)
    byte_view = np.frombuffer(serialized, dtype=np.uint8)
    tensor = paddle.to_tensor(byte_view)
    return tensor, tensor.numel()


def _convert_tensor_to_object(tensor, len_of_tensor):
_unpickler = pickle.Unpickler
return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()


def all_gather_object(object_list, obj, group=None):
"""

Expand Down Expand Up @@ -117,7 +105,7 @@ def all_gather_object(object_list, obj, group=None):
framework.in_dygraph_mode()
), "all_gather_object doesn't support static graph mode."

tensor, len_of_tensor = _convert_object_to_tensor(obj)
tensor, len_of_tensor = convert_object_to_tensor(obj)

# gather len_of_tensor from all ranks
list_len_of_tensor = []
Expand All @@ -135,5 +123,5 @@ def all_gather_object(object_list, obj, group=None):
all_gather(tensor_list, input_tensor, group)
for i, tensor in enumerate(tensor_list):
object_list.append(
_convert_tensor_to_object(tensor, list_len_of_tensor[i])
convert_tensor_to_object(tensor, list_len_of_tensor[i])
)
75 changes: 75 additions & 0 deletions python/paddle/distributed/communication/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.distributed as dist
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework

from .serialization_utils import (
convert_object_to_tensor,
convert_tensor_to_object,
)


def broadcast(tensor, src, group=None, sync_op=True):
Expand Down Expand Up @@ -60,3 +68,70 @@ def broadcast(tensor, src, group=None, sync_op=True):
sync_op=sync_op,
use_calc_stream=False,
)


def broadcast_object_list(object_list, src, group=None):
    """

    Broadcast picklable objects from the source to all others. Similar to broadcast(), but python object can be passed in.

    Args:
        object_list (list): The list of objects to send if current rank is the source, or the list of objects to receive otherwise.
        src (int): The source rank in global view.
        group (Group): The group instance returned by new_group or None for the global default group.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                object_list = [{"foo": [1, 2, 3]}]
            else:
                object_list = [{"bar": [4, 5, 6]}]
            dist.broadcast_object_list(object_list, src=1)
            print(object_list)
            # [{"bar": [4, 5, 6]}] (2 GPUs)
    """
    # Object broadcast only works eagerly; pickling has no static-graph form.
    assert (
        framework.in_dygraph_mode()
    ), "broadcast_object_list doesn't support static graph mode."

    rank = dist.get_rank()
    obj_tensors = []
    obj_nums = len(object_list)

    # Step 1: broadcast the per-object byte lengths so that non-source
    # ranks can allocate a receive buffer of the right total size.
    if rank == src:
        obj_sizes = []
        for obj in object_list:
            # Each object becomes a uint8 tensor plus its length (in bytes).
            obj_tensor, obj_size = convert_object_to_tensor(obj)
            obj_tensors.append(obj_tensor)
            obj_sizes.append(obj_size)
        obj_size_tensor = paddle.concat(obj_sizes)
    else:
        obj_size_tensor = paddle.empty([obj_nums], dtype="int64")
    broadcast(obj_size_tensor, src)

    # Step 2: broadcast the concatenated payload bytes in a single tensor.
    if rank == src:
        # cast to uint8 to keep the same dtype
        obj_data_tensor = paddle.concat(obj_tensors).cast("uint8")
    else:
        data_len = paddle.sum(obj_size_tensor).item()
        obj_data_tensor = paddle.empty([data_len], dtype="uint8")
    broadcast(obj_data_tensor, src)

    # Step 3: every rank (including src) slices the payload back into
    # individual objects, overwriting object_list in place.
    # NOTE(review): data_len here is a 1-element tensor, so offset becomes a
    # tensor after the first iteration; slicing with tensor bounds appears to
    # be relied upon in dygraph — confirm against the paddle version in use.
    offset = 0
    for i in range(obj_nums):
        data_len = obj_size_tensor[i]
        object_list[i] = convert_tensor_to_object(
            obj_data_tensor[offset : offset + data_len], data_len
        )
        offset += data_len
33 changes: 29 additions & 4 deletions python/paddle/distributed/communication/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import paddle.fluid.core as core
import paddle.fluid.framework as framework
import paddle.fluid.layer_helper as layer_helper
from paddle.fluid.framework import in_dygraph_mode


class Group:
Expand Down Expand Up @@ -236,7 +235,7 @@ def get_group(id=0):


def _sync_calc_stream(tensor):
if in_dygraph_mode():
if framework.in_dygraph_mode():
return paddle._legacy_C_ops.c_sync_calc_stream(tensor, tensor)
else:
op_type = 'c_sync_calc_stream'
Expand All @@ -249,7 +248,7 @@ def _sync_calc_stream(tensor):


def _sync_comm_stream(tensor, ring_id=0):
if in_dygraph_mode():
if framework.in_dygraph_mode():
return paddle._legacy_C_ops.c_sync_comm_stream(
[tensor], [tensor], 'ring_id', ring_id
)
Expand Down Expand Up @@ -337,7 +336,7 @@ def barrier(group=None):
ring_id = 0 if group is None else group.id

barrier_tensor = paddle.full([1], 1, dtype="int32")
if in_dygraph_mode():
if framework.in_dygraph_mode():
return paddle._legacy_C_ops.barrier(
barrier_tensor, barrier_tensor, 'ring_id', ring_id
)
Expand All @@ -352,3 +351,29 @@ def barrier(group=None):
outputs={'Out': [barrier_tensor]},
attrs={'ring_id': ring_id},
)


def get_backend(group=None):
    """
    Get the backend of given group.

    Args:
        group (Group): The group to query; when None, the global default
            group is used.

    Returns:
        Returns the name of the given group backend.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle

            paddle.distributed.init_parallel_env()
            paddle.distributed.get_backend()  # NCCL
    """
    # Querying a group the current rank does not belong to is an error.
    if _warn_cur_rank_not_in_group(group):
        raise RuntimeError("Invalid group specified")

    if group is None:
        group = _get_global_group()
    return group.backend
86 changes: 86 additions & 0 deletions python/paddle/distributed/communication/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle
import paddle.distributed as dist
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework

from .serialization_utils import (
convert_object_to_tensor,
convert_tensor_to_object,
)


def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
Expand Down Expand Up @@ -59,3 +69,79 @@ def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
# [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
"""
return stream.scatter(tensor, tensor_list, src, group, sync_op)


def scatter_object_list(
    out_object_list, in_object_list=None, src=0, group=None
):
    """

    Scatter picklable objects from the source to all others. Similar to scatter(), but python object can be passed in.

    Args:
        out_object_list (list): The list of objects to store the scattered objects.
        in_object_list (list): The list of objects to scatter. Only objects on the src rank will be scattered.
        src (int): The source rank in global view.
        group (Group): The group instance returned by new_group or None for the global default group.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle.distributed as dist

            dist.init_parallel_env()
            out_object_list = []
            if dist.get_rank() == 0:
                in_object_list = [{'foo': [1, 2, 3]}, {'foo': [4, 5, 6]}]
            else:
                in_object_list = [{'bar': [1, 2, 3]}, {'bar': [4, 5, 6]}]
            dist.scatter_object_list(out_object_list, in_object_list, src=1)
            print(out_object_list)
            # [{'bar': [1, 2, 3]}] (2 GPUs, out for rank 0)
            # [{'bar': [4, 5, 6]}] (2 GPUs, out for rank 1)
    """
    # Object scatter only works eagerly; pickling has no static-graph form.
    assert (
        framework.in_dygraph_mode()
    ), "scatter_object_list doesn't support static graph mode."

    rank = dist.get_rank()
    in_obj_tensors = []
    in_obj_sizes = []

    # Step 1: on src, pickle every object into a uint8 tensor and record its
    # byte length; then broadcast the maximum length so all ranks agree on a
    # common (padded) slot size for the scatter.
    if rank == src:
        for obj in in_object_list:
            obj_tensor, obj_size = convert_object_to_tensor(obj)
            in_obj_tensors.append(obj_tensor)
            in_obj_sizes.append(obj_size)
        # max() over 1-element int64 tensors compares elementwise values.
        max_obj_size_tensor = max(in_obj_sizes)
    else:
        # NOTE: shape can be [] after 0D tensor support
        max_obj_size_tensor = paddle.empty([1], dtype="int64")
    stream.broadcast(max_obj_size_tensor, src)
    max_obj_size = int(max_obj_size_tensor.item())

    # Step 2: pad every payload to max_obj_size (np.resize repeats data to
    # fill, but the extra bytes are ignored at unpickle time) and scatter one
    # padded slot to each rank. Non-src ranks pass None as the input list.
    # resize to the same size
    in_tensor_list = []
    for tensor in in_obj_tensors:
        numpy_data = tensor.numpy()
        numpy_data = np.resize(numpy_data, [max_obj_size])
        in_tensor = paddle.to_tensor(numpy_data)
        in_tensor_list.append(in_tensor)
    out_tensor = paddle.empty([max_obj_size], dtype="uint8")
    scatter(out_tensor, in_tensor_list if rank == src else None, src)

    # Step 3: scatter each object's true (unpadded) byte length so the
    # receiver knows where its payload ends inside the padded slot.
    # NOTE: shape can be [] after 0D tensor support
    out_tensor_size = paddle.empty([1], dtype="int64")
    scatter(out_tensor_size, in_obj_sizes if rank == src else None, src)

    # Step 4: rebuild this rank's object and place it in the output list,
    # replacing anything the caller left there.
    out_object_list.clear()
    out_object_list.append(
        convert_tensor_to_object(out_tensor, out_tensor_size.item())
    )
Loading