Skip to content

Commit

Permalink
[misc][cuda] use nvml to avoid accidentally cuda initialization (vllm…
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored and jimpang committed Jul 24, 2024
1 parent b0a3b5f commit b7abcfa
Show file tree
Hide file tree
Showing 13 changed files with 86 additions and 68 deletions.
3 changes: 2 additions & 1 deletion tests/kernels/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import torch

from vllm import _custom_ops as ops
from vllm.utils import get_device_capability_stateless

CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

capability = torch.cuda.get_device_capability()
capability = get_device_capability_stateless()
capability = capability[0] * 10 + capability[1]


Expand Down
3 changes: 2 additions & 1 deletion tests/quantization/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import torch

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import get_device_capability_stateless


def is_quant_method_supported(quant_method: str) -> bool:
# Currently, all quantization methods require Nvidia or AMD GPUs
if not torch.cuda.is_available():
return False

capability = torch.cuda.get_device_capability()
capability = get_device_capability_stateless()
capability = capability[0] * 10 + capability[1]
return (capability >=
QUANTIZATION_METHODS[quant_method].get_min_capability())
6 changes: 3 additions & 3 deletions vllm/attention/ops/blocksparse_attention/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import torch

from vllm.utils import is_cpu, is_hip
from vllm.utils import get_device_capability_stateless, is_cpu, is_hip

from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)

IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 8)
and get_device_capability_stateless()[0] >= 8)

if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
Expand Down Expand Up @@ -235,4 +235,4 @@ def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
v,
cu_seqlens_k,
cu_seqlens_q=cu_seqlens_q,
sm_scale=sm_scale)
sm_scale=sm_scale)
4 changes: 3 additions & 1 deletion vllm/attention/ops/prefix_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import triton
import triton.language as tl

from vllm.utils import get_device_capability_stateless

if triton.__version__ >= "2.1.0":

@triton.jit
Expand Down Expand Up @@ -683,7 +685,7 @@ def context_attention_fwd(q,
alibi_slopes=None,
sliding_window=None):

cap = torch.cuda.get_device_capability()
cap = get_device_capability_stateless()
BLOCK = 128 if cap[0] >= 8 else 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
Expand Down
58 changes: 5 additions & 53 deletions vllm/distributed/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,66 +11,18 @@
gpu_p2p_access_check)
from vllm.distributed.parallel_state import is_in_the_same_node
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless
from vllm.utils import cuda_device_count_stateless, is_full_nvlink

try:
import pynvml

# Simulate ImportError if custom_ar ops are not supported.
if not ops.is_custom_op_supported("_C_custom_ar::meta_size"):
raise ImportError("custom_ar", __file__)

assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
custom_ar = True

@contextmanager
def _nvml():
try:
pynvml.nvmlInit()
yield
finally:
pynvml.nvmlShutdown()

except ImportError:
# For AMD GPUs
except Exception:
# For AMD GPUs and CPUs
custom_ar = False
pynvml = None

@contextmanager
def _nvml():
try:
yield
finally:
pass


logger = init_logger(__name__)


@_nvml()
def _is_full_nvlink(device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
so it works on real physical device ids.
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError as error:
logger.error(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped.",
exc_info=error)
return False
return True


def _can_p2p(rank: int, world_size: int) -> bool:
for i in range(world_size):
if i == rank:
Expand Down Expand Up @@ -161,7 +113,7 @@ def __init__(self,
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
full_nvlink = _is_full_nvlink(physical_device_ids)
full_nvlink = is_full_nvlink(physical_device_ids)
if world_size > 2 and not full_nvlink:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
Expand Down
3 changes: 2 additions & 1 deletion vllm/lora/punica.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import torch

from vllm import _custom_ops as ops
from vllm.utils import get_device_capability_stateless


def _check_punica_support():
if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
return

if torch.cuda.get_device_capability() < (8, 0):
if get_device_capability_stateless() < (8, 0):
raise ImportError(
"punica LoRA kernels require compute capability >= 8.0")
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy,
find_first_name_or_class_match)
from vllm.utils import get_device_capability_stateless


class CompressedTensorsConfig(QuantizationConfig):
Expand Down Expand Up @@ -84,7 +85,7 @@ def get_config_filenames(cls) -> List[str]:
return []

def _check_gptq_and_marlin_can_run(self):
capability = torch.cuda.get_device_capability()
capability = get_device_capability_stateless()
capability = capability[0] * 10 + capability[1]
if capability < 80:
raise RuntimeError("The quantization config is not supported for ",
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once
from vllm.utils import get_device_capability_stateless, print_warning_once

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)


def cutlass_fp8_supported() -> bool:
capability = torch.cuda.get_device_capability()
capability = get_device_capability_stateless()
capability = capability[0] * 10 + capability[1]

return ops.cutlass_scaled_mm_supports_fp8(capability)
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import get_device_capability_stateless

logger = init_logger(__name__)

Expand Down Expand Up @@ -165,7 +166,7 @@ def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
return False

# If the capability of the device is too low, cannot convert.
major, minor = torch.cuda.get_device_capability()
major, minor = get_device_capability_stateless()
device_capability = major * 10 + minor
if device_capability < cls.get_min_capability():
return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_pack_factor, quantize_weights, sort_weights)
from vllm.utils import get_device_capability_stateless

__cuda_arch = torch.cuda.get_device_capability()
__cuda_arch = get_device_capability_stateless()

MARLIN_TILE = 16

Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_tpu
from vllm.utils import get_device_capability_stateless, is_tpu

logger = init_logger(__name__)

Expand All @@ -46,7 +46,7 @@ def _get_quantization_config(
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
capability = torch.cuda.get_device_capability()
capability = get_device_capability_stateless()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(
Expand Down
57 changes: 57 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,63 @@ def cuda_device_count_stateless() -> int:
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA

try:
import pynvml
except ImportError:
# For non-NV devices
pynvml = None


def with_nvml_context(fn):

@wraps(fn)
def wrapper(*args, **kwargs):
if pynvml is not None:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
finally:
if pynvml is not None:
pynvml.nvmlShutdown()

return wrapper


@with_nvml_context
def is_full_nvlink(device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError as error:
logger.error(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped.",
exc_info=error)
return False
return True


@lru_cache(maxsize=8)
@with_nvml_context
def get_device_capability_stateless(device_id: int = 0) -> Tuple[int, int]:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return pynvml.nvmlDeviceGetCudaComputeCapability(handle)


#From: https://stackoverflow.com/a/4104188/2749989
def run_once(f):

Expand Down
3 changes: 2 additions & 1 deletion vllm/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_device_capability_stateless
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
Expand Down Expand Up @@ -322,7 +323,7 @@ def init_worker_distributed_environment(
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
compute_capability = torch.cuda.get_device_capability()
compute_capability = get_device_capability_stateless()
if compute_capability[0] < 8:
gpu_name = torch.cuda.get_device_name()
raise ValueError(
Expand Down

0 comments on commit b7abcfa

Please sign in to comment.