diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py
index 9cc18a0ea4611..d409df34ee5e5 100644
--- a/tests/kernels/test_cutlass.py
+++ b/tests/kernels/test_cutlass.py
@@ -8,13 +8,13 @@
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
 
-capability = get_device_capability_stateless()
+capability = current_platform.get_device_capability()
 capability = capability[0] * 10 + capability[1]
 
 
diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py
index 5c1b5ad9b9c07..65bb80ed70c6a 100644
--- a/tests/quantization/utils.py
+++ b/tests/quantization/utils.py
@@ -1,7 +1,7 @@
 import torch
 
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
 
 def is_quant_method_supported(quant_method: str) -> bool:
@@ -9,7 +9,7 @@ def is_quant_method_supported(quant_method: str) -> bool:
     if not torch.cuda.is_available():
         return False
 
-    capability = get_device_capability_stateless()
+    capability = current_platform.get_device_capability()
     capability = capability[0] * 10 + capability[1]
     return (capability >=
             QUANTIZATION_METHODS[quant_method].get_min_capability())
diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py
index 637cfda214964..e870a8e614d12 100644
--- a/vllm/attention/ops/blocksparse_attention/interface.py
+++ b/vllm/attention/ops/blocksparse_attention/interface.py
@@ -2,13 +2,14 @@
 
 import torch
 
-from vllm.utils import get_device_capability_stateless, is_cpu, is_hip
+from vllm.platforms import current_platform
+from vllm.utils import is_cpu, is_hip
 
 from .utils import (dense_to_crow_col, get_head_sliding_step,
                     get_sparse_attn_mask)
 
 IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
-                         and get_device_capability_stateless()[0] >= 8)
+                         and current_platform.get_device_capability()[0] >= 8)
 
 if IS_COMPUTE_8_OR_ABOVE:
     from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index ca9f28fcb7013..4cd4976ade729 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -5,7 +5,7 @@
 import triton
 import triton.language as tl
 
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
 if triton.__version__ >= "2.1.0":
 
@@ -685,7 +685,7 @@ def context_attention_fwd(q,
                               alibi_slopes=None,
                               sliding_window=None):
 
-        cap = get_device_capability_stateless()
+        cap = current_platform.get_device_capability()
         BLOCK = 128 if cap[0] >= 8 else 64
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py
index f30b2c13f28b8..64f87a4b2c69d 100644
--- a/vllm/lora/punica.py
+++ b/vllm/lora/punica.py
@@ -5,14 +5,14 @@
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
 
 def _check_punica_support():
     if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
         return
 
-    if get_device_capability_stateless() < (8, 0):
+    if current_platform.get_device_capability() < (8, 0):
         raise ImportError(
             "punica LoRA kernels require compute capability >= 8.0")
     else:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 491396c3da6f7..e88bbc361a5e0 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -14,7 +14,7 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     find_first_name_or_class_match)
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
 
 class CompressedTensorsConfig(QuantizationConfig):
@@ -85,7 +85,7 @@ def get_config_filenames(cls) -> List[str]:
         return []
 
     def _check_gptq_and_marlin_can_run(self):
-        capability = get_device_capability_stateless()
+        capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         if capability < 80:
             raise RuntimeError("The quantization config is not supported for ",
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index dc2ca35c6d2c0..6d942fa611c7b 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -12,7 +12,8 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import get_device_capability_stateless, print_warning_once
+from vllm.platforms import current_platform
+from vllm.utils import print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -20,7 +21,7 @@
 
 
 def cutlass_fp8_supported() -> bool:
-    capability = get_device_capability_stateless()
+    capability = current_platform.get_device_capability()
     capability = capability[0] * 10 + capability[1]
 
     return ops.cutlass_scaled_mm_supports_fp8(capability)
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 97aae33f133be..a6284d0ed7b1b 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -12,7 +12,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
@@ -173,7 +173,7 @@ def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
             return False
 
         # If the capability of the device is too low, cannot convert.
-        major, minor = get_device_capability_stateless()
+        major, minor = current_platform.get_device_capability()
         device_capability = major * 10 + minor
         if device_capability < cls.get_min_capability():
             return False
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index 6f4aa2d77e680..ecd29a80e6c03 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -12,9 +12,9 @@
     marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     get_pack_factor, quantize_weights, sort_weights)
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
-__cuda_arch = get_device_capability_stateless()
+__cuda_arch = current_platform.get_device_capability()
 
 MARLIN_TILE = 16
 
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index b61ac7490d1f6..6f4dcf4a03c35 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -35,7 +35,8 @@
 from vllm.model_executor.models.interfaces import (supports_lora,
                                                    supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import get_device_capability_stateless, is_tpu
+from vllm.platforms import current_platform
+from vllm.utils import is_tpu
 
 logger = init_logger(__name__)
 
@@ -46,7 +47,7 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        capability = get_device_capability_stateless()
+        capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         if capability < quant_config.get_min_capability():
             raise ValueError(
diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
new file mode 100644
index 0000000000000..7309f7bf795d6
--- /dev/null
+++ b/vllm/platforms/__init__.py
@@ -0,0 +1,18 @@
+from typing import Optional
+
+import torch
+
+from .interface import Platform, PlatformEnum
+
+current_platform: Optional[Platform]
+
+if torch.version.cuda is not None:
+    from .cuda import CudaPlatform
+    current_platform = CudaPlatform()
+elif torch.version.hip is not None:
+    from .rocm import RocmPlatform
+    current_platform = RocmPlatform()
+else:
+    current_platform = None
+
+__all__ = ['Platform', 'PlatformEnum', 'current_platform']
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
new file mode 100644
index 0000000000000..b2ca758131e92
--- /dev/null
+++ b/vllm/platforms/cuda.py
@@ -0,0 +1,34 @@
+"""Code inside this file can safely assume cuda platform, e.g. importing
+pynvml. However, it should not initialize cuda context.
+"""
+
+from functools import lru_cache, wraps
+from typing import Tuple
+
+import pynvml
+
+from .interface import Platform, PlatformEnum
+
+
+def with_nvml_context(fn):
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        pynvml.nvmlInit()
+        try:
+            return fn(*args, **kwargs)
+        finally:
+            pynvml.nvmlShutdown()
+
+    return wrapper
+
+
+class CudaPlatform(Platform):
+    _enum = PlatformEnum.CUDA
+
+    @staticmethod
+    @lru_cache(maxsize=8)
+    @with_nvml_context
+    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+        return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
new file mode 100644
index 0000000000000..2ac092c258d15
--- /dev/null
+++ b/vllm/platforms/interface.py
@@ -0,0 +1,21 @@
+import enum
+from typing import Tuple
+
+
+class PlatformEnum(enum.Enum):
+    CUDA = enum.auto()
+    ROCM = enum.auto()
+
+
+class Platform:
+    _enum: PlatformEnum
+
+    def is_cuda(self) -> bool:
+        return self._enum == PlatformEnum.CUDA
+
+    def is_rocm(self) -> bool:
+        return self._enum == PlatformEnum.ROCM
+
+    @staticmethod
+    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+        raise NotImplementedError
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
new file mode 100644
index 0000000000000..36b3ba8f7d1bb
--- /dev/null
+++ b/vllm/platforms/rocm.py
@@ -0,0 +1,15 @@
+from functools import lru_cache
+from typing import Tuple
+
+import torch
+
+from .interface import Platform, PlatformEnum
+
+
+class RocmPlatform(Platform):
+    _enum = PlatformEnum.ROCM
+
+    @staticmethod
+    @lru_cache(maxsize=8)
+    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+        return torch.cuda.get_device_capability(device_id)
diff --git a/vllm/utils.py b/vllm/utils.py
index 1977bc05d75e2..763b0b91c8646 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -866,13 +866,6 @@ def is_full_nvlink(device_ids: List[int]) -> bool:
     return True
 
 
-@lru_cache(maxsize=8)
-@with_nvml_context
-def get_device_capability_stateless(device_id: int = 0) -> Tuple[int, int]:
-    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
-    return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
-
-
 #From: https://stackoverflow.com/a/4104188/2749989
 def run_once(f):
 
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 5b57282909914..b25f29f485d95 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -15,8 +15,8 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import get_device_capability_stateless
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.embedding_model_runner import EmbeddingModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -333,7 +333,7 @@ def init_worker_distributed_environment(
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
-        compute_capability = get_device_capability_stateless()
+        compute_capability = current_platform.get_device_capability()
         if compute_capability[0] < 8:
             gpu_name = torch.cuda.get_device_name()
             raise ValueError(