Skip to content

Commit

Permalink
[Platforms] refactor xpu code
Browse files Browse the repository at this point in the history
Signed-off-by: MengqingCao <[email protected]>
  • Loading branch information
MengqingCao committed Nov 20, 2024
1 parent d200972 commit 85b99b1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 27 deletions.
27 changes: 0 additions & 27 deletions vllm/executor/xpu_executor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from typing import Callable, List, Optional, Tuple, Type, Union

import torch

from vllm.config import ModelConfig, ParallelConfig
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
Expand All @@ -23,7 +20,6 @@ def _init_executor(self) -> None:
assert self.speculative_config is None, (
"Speculative decoding not yet supported for XPU backend")

self.model_config = _verify_and_get_model_config(self.model_config)
GPUExecutor._init_executor(self)

def _get_worker_module_and_class(
Expand Down Expand Up @@ -53,26 +49,3 @@ async def execute_model_async(
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req)
return output


def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
if config.dtype == torch.bfloat16:
logger.warning(
"bfloat16 is not fully supported on XPU, casting to float16.")
config.dtype = torch.float16
if not config.enforce_eager:
logger.warning(
"CUDA graph is not supported on XPU, fallback to the eager "
"mode.")
config.enforce_eager = True
return config


def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig:
if (config.distributed_executor_backend is not None
and config.distributed_executor_backend != "ray"):
logger.warning(
"%s is not supported on XPU, fallback to ray distributed executor "
"backend.", config.distributed_executor_backend)
config.distributed_executor_backend = "ray"
return config
21 changes: 21 additions & 0 deletions vllm/platforms/xpu.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from typing import TYPE_CHECKING

import torch

from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None

logger = init_logger(__name__)


Expand Down Expand Up @@ -34,3 +41,17 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
@staticmethod
def inference_mode():
return torch.no_grad()

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# check and update model config
model_config = vllm_config.model_config
if model_config.dtype == torch.bfloat16:
logger.warning(
"bfloat16 is not fully supported on XPU, casting to float16.")
model_config.dtype = torch.float16
if not model_config.enforce_eager:
logger.warning(
"CUDA graph is not supported on XPU, fallback to the eager "
"mode.")
model_config.enforce_eager = True

0 comments on commit 85b99b1

Please sign in to comment.