From abe5398c9ed1ffdd4faf9aab9a9825b97dbe6078 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 23 Jul 2024 15:07:25 +0800 Subject: [PATCH 1/2] fix --- python/paddle/distributed/collective.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index aae5b3c37f4587..478db1f4d7a4cd 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -11,10 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import datetime import hashlib import os +from typing import ( + TYPE_CHECKING, + Literal, +) + +from typing_extensions import TypeAlias import paddle @@ -38,6 +45,8 @@ split, ) +if TYPE_CHECKING: + _BackendList: TypeAlias = Literal["gloo", "nccl", "xccl", "bkcl"] __all__ = [] _global_env = None @@ -184,11 +193,11 @@ def _set_custom_gid(gid): def new_group( - ranks=None, - backend=None, - timeout=_default_timeout, - nccl_comm_init_option=0, -): + ranks: list | None = None, + backend: str | None = None, + timeout: datetime.timedelta = _default_timeout, + nccl_comm_init_option: int = 0, +) -> Group: """ Creates a new distributed communication group. @@ -320,7 +329,7 @@ def new_group( return gp -def is_available(): +def is_available() -> bool: """ Check whether the distributed package is available. @@ -337,7 +346,7 @@ def is_available(): return core.is_compiled_with_dist() -def _init_parallel_env(backend): +def _init_parallel_env(backend: _BackendList) -> None: store = core.create_or_get_global_tcp_store() global_env = _get_global_env() rank = global_env.rank From 7ac745a06d497dc41d10fc092c4a7b728859d076 Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 25 Jul 2024 15:21:01 +0800 Subject: [PATCH 2/2] fix --- python/paddle/distributed/collective.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 478db1f4d7a4cd..3a8af8be35c6bc 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -193,8 +193,8 @@ def _set_custom_gid(gid): def new_group( - ranks: list | None = None, - backend: str | None = None, + ranks: list[int] | None = None, + backend: Literal['nccl'] | None = None, timeout: datetime.timedelta = _default_timeout, nccl_comm_init_option: int = 0, ) -> Group: