From 91e0c6733db733164b671ce22a5f9a00ac04b85b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Jul 2024 18:43:11 -0700 Subject: [PATCH] [Misc] Add a wrapper for torch.inference_mode (#6618) --- vllm/platforms/__init__.py | 9 +++++++-- vllm/platforms/interface.py | 21 +++++++++++++++++++++ vllm/platforms/tpu.py | 17 +++++++++++++++++ vllm/worker/model_runner_base.py | 3 ++- vllm/worker/worker_base.py | 3 ++- 5 files changed, 49 insertions(+), 4 deletions(-) create mode 100644 vllm/platforms/tpu.py diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 7309f7bf795d6..eac917786bd6b 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -2,7 +2,9 @@ import torch -from .interface import Platform, PlatformEnum +from vllm.utils import is_tpu + +from .interface import Platform, PlatformEnum, UnspecifiedPlatform current_platform: Optional[Platform] @@ -12,7 +14,10 @@ elif torch.version.hip is not None: from .rocm import RocmPlatform current_platform = RocmPlatform() +elif is_tpu(): + from .tpu import TpuPlatform + current_platform = TpuPlatform() else: - current_platform = None + current_platform = UnspecifiedPlatform() __all__ = ['Platform', 'PlatformEnum', 'current_platform'] diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 2ac092c258d15..0760f9554fb78 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,10 +1,14 @@ import enum from typing import Tuple +import torch + class PlatformEnum(enum.Enum): CUDA = enum.auto() ROCM = enum.auto() + TPU = enum.auto() + UNSPECIFIED = enum.auto() class Platform: @@ -16,6 +20,23 @@ def is_cuda(self) -> bool: def is_rocm(self) -> bool: return self._enum == PlatformEnum.ROCM + def is_tpu(self) -> bool: + return self._enum == PlatformEnum.TPU + @staticmethod def get_device_capability(device_id: int = 0) -> Tuple[int, int]: raise NotImplementedError + + @staticmethod + def inference_mode(): + """A device-specific wrapper of `torch.inference_mode`. + + This wrapper is recommended because some hardware backends such as TPU + do not support `torch.inference_mode`. In such a case, they will fall + back to `torch.no_grad` by overriding this method. + """ + return torch.inference_mode(mode=True) + + +class UnspecifiedPlatform(Platform): + _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py new file mode 100644 index 0000000000000..5e32bee1c5511 --- /dev/null +++ b/vllm/platforms/tpu.py @@ -0,0 +1,17 @@ +from typing import Tuple + +import torch + +from .interface import Platform, PlatformEnum + + +class TpuPlatform(Platform): + _enum = PlatformEnum.TPU + + @staticmethod + def get_device_capability(device_id: int = 0) -> Tuple[int, int]: + raise RuntimeError("TPU does not have device capability.") + + @staticmethod + def inference_mode(): + return torch.no_grad() diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index bc7a6a73b17c4..5fb97025af5c0 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -5,6 +5,7 @@ import torch +from vllm.platforms import current_platform from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) @@ -163,7 +164,7 @@ def prepare_model_input( """ raise NotImplementedError - @torch.inference_mode() + @current_platform.inference_mode() def execute_model( self, model_input: T, diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 8e5c0ededba15..03e3857e23c4b 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -9,6 +9,7 @@ from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SamplerOutput) from vllm.utils import (enable_trace_function_call_for_thread, @@ -53,7 +54,7 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise NotImplementedError - @torch.inference_mode() + @current_platform.inference_mode() def start_worker_execution_loop(self) -> None: """Execute model loop in parallel worker.