Skip to content

Commit

Permalink
[Misc] Add a wrapper for torch.inference_mode (#6618)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Jul 22, 2024
1 parent c9eef37 commit 42de2ce
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 4 deletions.
9 changes: 7 additions & 2 deletions vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import torch

from .interface import Platform, PlatformEnum
from vllm.utils import is_tpu

from .interface import Platform, PlatformEnum, UnspecifiedPlatform

current_platform: Optional[Platform]

Expand All @@ -12,7 +14,10 @@
elif torch.version.hip is not None:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_tpu():
from .tpu import TpuPlatform
current_platform = TpuPlatform()
else:
current_platform = None
current_platform = UnspecifiedPlatform()

__all__ = ['Platform', 'PlatformEnum', 'current_platform']
21 changes: 21 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import enum
from typing import Tuple

import torch


class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
TPU = enum.auto()
UNSPECIFIED = enum.auto()


class Platform:
Expand All @@ -16,6 +20,23 @@ def is_cuda(self) -> bool:
def is_rocm(self) -> bool:
return self._enum == PlatformEnum.ROCM

def is_tpu(self) -> bool:
return self._enum == PlatformEnum.TPU

@staticmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
raise NotImplementedError

@staticmethod
def inference_mode():
"""A device-specific wrapper of `torch.inference_mode`.
This wrapper is recommended because some hardware backends such as TPU
do not support `torch.inference_mode`. In such a case, they will fall
back to `torch.no_grad` by overriding this method.
"""
return torch.inference_mode(mode=True)


class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
17 changes: 17 additions & 0 deletions vllm/platforms/tpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Tuple

import torch

from .interface import Platform, PlatformEnum


class TpuPlatform(Platform):
_enum = PlatformEnum.TPU

@staticmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
raise RuntimeError("TPU does not have device capability.")

@staticmethod
def inference_mode():
return torch.no_grad()
3 changes: 2 additions & 1 deletion vllm/worker/model_runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch

from vllm.platforms import current_platform
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)

Expand Down Expand Up @@ -163,7 +164,7 @@ def prepare_model_input(
"""
raise NotImplementedError

@torch.inference_mode()
@current_platform.inference_mode()
def execute_model(
self,
model_input: T,
Expand Down
3 changes: 2 additions & 1 deletion vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread,
Expand Down Expand Up @@ -53,7 +54,7 @@ def initialize_cache(self, num_gpu_blocks: int,
"""
raise NotImplementedError

@torch.inference_mode()
@current_platform.inference_mode()
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.
Expand Down

0 comments on commit 42de2ce

Please sign in to comment.