From 2ae21c00eedcefc482543301e07e9d8b4a783ca4 Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Wed, 18 Sep 2024 21:54:30 +0800
Subject: [PATCH 1/7] add support for custom_op check

---
 vllm/distributed/parallel_state.py | 48 ++++++++++++++++--------------
 vllm/utils.py                      |  7 +++++
 2 files changed, 32 insertions(+), 23 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index df07842edfa56..f311b45812408 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -36,6 +36,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import supports_custom_op


 @dataclass
@@ -95,32 +96,33 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)  # type: ignore


-@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
-def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
-    assert group_name in _groups, f"Group {group_name} is not found."
-    group = _groups[group_name]()
-    if group is None:
-        raise ValueError(f"Group {group_name} is destroyed.")
-    group._all_reduce(tensor)
+if supports_custom_op():

+    @torch.library.custom_op("vllm::inplace_all_reduce",
+                             mutates_args=["tensor"])
+    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        group._all_reduce(tensor)

-@inplace_all_reduce.register_fake
-def _(tensor: torch.Tensor, group_name: str) -> None:
-    return
-
-
-@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
-def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
-    assert group_name in _groups, f"Group {group_name} is not found."
-    group = _groups[group_name]()
-    if group is None:
-        raise ValueError(f"Group {group_name} is destroyed.")
-    return group._all_reduce(tensor)
-
+    @inplace_all_reduce.register_fake
+    def _(tensor: torch.Tensor, group_name: str) -> None:
+        return

-@outplace_all_reduce.register_fake
-def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
-    return torch.empty_like(tensor)
+    @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
+    def outplace_all_reduce(tensor: torch.Tensor,
+                            group_name: str) -> torch.Tensor:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        return group._all_reduce(tensor)
+
+    @outplace_all_reduce.register_fake
+    def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+        return torch.empty_like(tensor)


 class GroupCoordinator:
diff --git a/vllm/utils.py b/vllm/utils.py
index 060b387ec7834..3f7367d1133e1 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1245,6 +1245,13 @@ def supports_dynamo() -> bool:
     return base_torch_version >= Version("2.4.0")


+# Some backends use a PyTorch version < 2.4.0, which doesn't
+# support `torch.library.custom_op`.
+def supports_custom_op() -> bool:
+    base_torch_version = Version(Version(torch.__version__).base_version)
+    return base_torch_version >= Version("2.4.0")
+
+
 class AtomicCounter:
     """An atomic, thread-safe counter"""

From 374918047a74be06556d6a2bc1fd21777609c37f Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Thu, 19 Sep 2024 17:48:07 +0800
Subject: [PATCH 2/7] fix

---
 vllm/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/utils.py b/vllm/utils.py
index 3f7367d1133e1..033a21e16e354 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1248,8 +1248,8 @@ def supports_dynamo() -> bool:
 # Some backends use a PyTorch version < 2.4.0, which doesn't
 # support `torch.library.custom_op`.
 def supports_custom_op() -> bool:
-    base_torch_version = Version(Version(torch.__version__).base_version)
-    return base_torch_version >= Version("2.4.0")
+    # using the Version lib as in `supports_dynamo` breaks the doc build.
+    return not is_xpu()


 class AtomicCounter:

From 9e5fdcebb08a386520454b48b91093799a0750cf Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Thu, 19 Sep 2024 18:55:41 +0800
Subject: [PATCH 3/7] fix all reduce call

---
 vllm/distributed/parallel_state.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index f311b45812408..bc499109c91cf 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -345,6 +345,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
             return torch.ops.vllm.outplace_all_reduce(
                 input_, group_name=self.unique_name)
+        elif not supports_custom_op():
+            return self._all_reduce(input_)
         else:
             torch.ops.vllm.inplace_all_reduce(input_,
                                               group_name=self.unique_name)

From b6f524c55c8dae92b721497ba5a5805ed70bb1ea Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Fri, 20 Sep 2024 16:28:10 +0800
Subject: [PATCH 4/7] add neuron

---
 vllm/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/utils.py b/vllm/utils.py
index 033a21e16e354..403e6db8a7113 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1249,7 +1249,7 @@ def supports_dynamo() -> bool:
 # support `torch.library.custom_op`.
 def supports_custom_op() -> bool:
     # using the Version lib as in `supports_dynamo` breaks the doc build.
-    return not is_xpu()
+    return not (is_xpu() and is_neuron())


 class AtomicCounter:

From 87c0c95c1002a9cc90e1369e120a02d62d6b3d04 Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Fri, 20 Sep 2024 18:04:55 +0800
Subject: [PATCH 5/7] fix

---
 vllm/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/utils.py b/vllm/utils.py
index 403e6db8a7113..3a19ccdc5afe4 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1249,7 +1249,7 @@ def supports_dynamo() -> bool:
 # support `torch.library.custom_op`.
 def supports_custom_op() -> bool:
     # using the Version lib as in `supports_dynamo` breaks the doc build.
-    return not (is_xpu() and is_neuron())
+    return not (is_xpu() or is_neuron())


 class AtomicCounter:

From 3a38075b7f520e82daea82c4e559c5ba97058b34 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 20 Sep 2024 13:45:20 -0700
Subject: [PATCH 6/7] change implementation

---
 vllm/utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/vllm/utils.py b/vllm/utils.py
index 3a19ccdc5afe4..43b64263d645a 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1248,8 +1248,7 @@ def supports_dynamo() -> bool:
 # Some backends use a PyTorch version < 2.4.0, which doesn't
 # support `torch.library.custom_op`.
 def supports_custom_op() -> bool:
-    # using the Version lib as in `supports_dynamo` breaks the doc build.
-    return not (is_xpu() or is_neuron())
+    return hasattr(torch.library, "custom_op")


 class AtomicCounter:

From d4f215fdeeb2e1988d4dbc8cc44ba55b4c49fdff Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 20 Sep 2024 13:48:49 -0700
Subject: [PATCH 7/7] polish condition

---
 vllm/distributed/parallel_state.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index bc499109c91cf..d3ac4eb78b155 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -337,6 +337,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         if self.world_size == 1:
             return input_

+        if not supports_custom_op():
+            return self._all_reduce(input_)
+
         if self.tpu_communicator is not None and \
             not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
@@ -345,8 +348,6 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
             return torch.ops.vllm.outplace_all_reduce(
                 input_, group_name=self.unique_name)
-        elif not supports_custom_op():
-            return self._all_reduce(input_)
         else:
             torch.ops.vllm.inplace_all_reduce(input_,
                                               group_name=self.unique_name)
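
Note for readers: below is a minimal, self-contained sketch of the pattern this series converges on: feature-detect `torch.library.custom_op` (added in PyTorch 2.4) instead of comparing version strings or hard-coding platforms, and fall back to a direct eager call when the op cannot be registered. The `demo::scale_inplace` op and the `scale` wrapper are hypothetical stand-ins for `vllm::inplace_all_reduce` and `GroupCoordinator.all_reduce`; they are not part of the patches above.

    # Hypothetical demo of the final supports_custom_op() pattern.
    # Assumes only that torch is installed; on torch < 2.4 the custom-op
    # branch is skipped and the eager fallback runs instead.
    import torch


    def supports_custom_op() -> bool:
        # Capability probe rather than a version compare: also covers
        # backend builds whose torch lacks torch.library.custom_op.
        return hasattr(torch.library, "custom_op")


    def _scale(tensor: torch.Tensor, factor: float) -> None:
        # The real work; plays the role of GroupCoordinator._all_reduce.
        tensor.mul_(factor)


    if supports_custom_op():

        @torch.library.custom_op("demo::scale_inplace",
                                 mutates_args=["tensor"])
        def scale_inplace(tensor: torch.Tensor, factor: float) -> None:
            _scale(tensor, factor)

        @scale_inplace.register_fake
        def _(tensor: torch.Tensor, factor: float) -> None:
            # Fake (meta) impl for tracing: in-place mutation, no return.
            return


    def scale(tensor: torch.Tensor, factor: float) -> torch.Tensor:
        if not supports_custom_op():
            # Early eager fallback, as patch 7 does at the top of
            # all_reduce().
            _scale(tensor, factor)
            return tensor
        torch.ops.demo.scale_inplace(tensor, factor)
        return tensor


    if __name__ == "__main__":
        x = torch.ones(4)
        print(scale(x, 2.0))  # tensor([2., 2., 2., 2.])

The `hasattr` probe is why patch 6 replaced both the `Version` compare (which broke the doc build) and the XPU/Neuron platform blacklist: it keeps `supports_custom_op()` correct for any backend that ships an older torch, without enumerating backends.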