diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index 92f6053cc6d7e..7acea6087fdfd 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -3,8 +3,8 @@ import torch from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - seed_everything) +from vllm.platforms import current_platform +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser @torch.inference_mode() @@ -16,7 +16,7 @@ def main(num_tokens: int, do_profile: bool = False, num_warmup_iters: int = 5, num_iters: int = 100) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device("cuda") layer = RMSNorm(hidden_size).to(dtype=dtype) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 4f88e8e6eb1a6..8f538c21f7f7e 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -10,7 +10,8 @@ from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import * -from vllm.utils import FlexibleArgumentParser, seed_everything +from vllm.platforms import current_platform +from vllm.utils import FlexibleArgumentParser class BenchmarkConfig(TypedDict): @@ -167,7 +168,7 @@ class BenchmarkWorker: def __init__(self, seed: int) -> None: torch.set_default_device("cuda") - seed_everything(seed) + current_platform.seed_everything(seed) self.seed = seed def benchmark( @@ -181,7 +182,7 @@ def benchmark( use_fp8_w8a8: bool, use_int8_w8a16: bool, ) -> Tuple[Dict[str, int], float]: - seed_everything(self.seed) + current_platform.seed_everything(self.seed) dtype_str = get_config_dtype_str(dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 87864d038d593..14eef00b855ac 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -5,8 +5,9 @@ import torch from vllm import _custom_ops as ops +from vllm.platforms import current_platform from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - create_kv_caches_with_random, seed_everything) + create_kv_caches_with_random) NUM_BLOCKS = 1024 PARTITION_SIZE = 512 @@ -28,7 +29,7 @@ def main( device: str = "cuda", kv_cache_dtype: Optional[str] = None, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) scale = float(1.0 / (head_size**0.5)) query = torch.empty(num_seqs, diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index 743a5744e8614..1d62483448946 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -3,8 +3,8 @@ import torch from vllm import _custom_ops as ops -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - seed_everything) +from vllm.platforms import current_platform +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser @torch.inference_mode() @@ -17,7 +17,7 @@ def main(num_tokens: int, do_profile: bool = False, num_warmup_iters: int = 5, num_iters: int = 100) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device("cuda") x = torch.randn(num_tokens, hidden_size, dtype=dtype) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 784b1cf9844e4..250d505168d09 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -6,7 +6,8 @@ from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, get_rope) -from vllm.utils import FlexibleArgumentParser, seed_everything +from vllm.platforms import current_platform +from vllm.utils import FlexibleArgumentParser def benchmark_rope_kernels_multi_lora( @@ -22,7 +23,7 @@ def benchmark_rope_kernels_multi_lora( max_position: int = 8192, base: int = 10000, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 0e3d3c3a2e987..057a11746014c 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -8,7 +8,7 @@ from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, GeluAndMul, NewGELU, QuickGELU, SiluAndMul) -from vllm.utils import seed_everything +from vllm.platforms import current_platform from .allclose_default import get_default_atol, get_default_rtol @@ -37,7 +37,7 @@ def test_act_and_mul( seed: int, device: str, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) x = torch.randn(num_tokens, 2 * d, dtype=dtype) if activation == "silu": @@ -85,7 +85,7 @@ def test_activation( seed: int, device: str, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) x = torch.randn(num_tokens, d, dtype=dtype) layer = activation[0]() diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 1604aa4d2d6e5..4ecd0fc1a21ad 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -7,7 +7,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils import get_max_shared_memory_bytes, seed_everything +from vllm.utils import get_max_shared_memory_bytes from .allclose_default import get_default_atol, get_default_rtol @@ -144,7 +144,7 @@ def test_paged_attention( or (version == "rocm" and head_size not in (64, 128))): pytest.skip() - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -382,7 +382,7 @@ def test_multi_query_kv_attention( seed: int, device: str, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. # As the xformers library is already tested with its own tests, we can use diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/test_awq_triton.py index e95e5bd948212..406a0c8dd8080 100644 --- a/tests/kernels/test_awq_triton.py +++ b/tests/kernels/test_awq_triton.py @@ -7,7 +7,7 @@ from vllm.model_executor.layers.quantization.awq_triton import ( AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton) -from vllm.utils import seed_everything +from vllm.platforms import current_platform device = "cuda" @@ -80,7 +80,7 @@ def test_dequantize(qweight_rows, qweight_cols, group_size): zeros_cols = qweight_cols zeros_dtype = torch.int32 - seed_everything(0) + current_platform.seed_everything(0) qweight = torch.randint(0, torch.iinfo(torch.int32).max, @@ -134,7 +134,7 @@ def test_gemm(N, K, M, splitK, group_size): qzeros_rows = scales_rows qzeros_cols = qweight_cols - seed_everything(0) + current_platform.seed_everything(0) input = torch.rand((input_rows, input_cols), dtype=input_dtype, diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index b65efb3abc230..fb601852dd523 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -8,7 +8,7 @@ from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn) from vllm.platforms import current_platform -from vllm.utils import get_max_shared_memory_bytes, seed_everything +from vllm.utils import get_max_shared_memory_bytes from .allclose_default import get_default_atol, get_default_rtol @@ -173,7 +173,7 @@ def test_paged_attention( blocksparse_block_size: int, blocksparse_head_sliding_step: int, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -384,7 +384,7 @@ def test_varlen_blocksparse_attention_prefill( seed: int, device: str, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. # As the xformers library is already tested with its own tests, we can use diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index b0e7097fdfbd4..5b8311a33c361 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -6,7 +6,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from vllm import _custom_ops as ops -from vllm.utils import seed_everything +from vllm.platforms import current_platform COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -56,7 +56,7 @@ def test_copy_blocks( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) # Generate random block mappings where each source block is mapped to two # destination blocks. @@ -132,7 +132,7 @@ def test_reshape_and_cache( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) # Create a random slot mapping. num_slots = block_size * num_blocks @@ -224,7 +224,7 @@ def test_reshape_and_cache_flash( device: str, kv_cache_dtype: str, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) # Create a random slot mapping. @@ -339,7 +339,7 @@ def test_swap_blocks( if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - seed_everything(seed) + current_platform.seed_everything(seed) src_device = device if direction[0] == "cuda" else 'cpu' dst_device = device if direction[1] == "cuda" else 'cpu' @@ -408,7 +408,7 @@ def test_fp8_e4m3_conversion( seed: int, device: str, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) low = -224.0 high = 224.0 diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 277d7e4977d73..96bfe06d74ae5 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -9,7 +9,7 @@ from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) -from vllm.utils import seed_everything +from vllm.platforms import current_platform def causal_conv1d_ref( @@ -70,7 +70,7 @@ def causal_conv1d_update_ref(x, bias: (dim,) cache_seqlens: (batch,), dtype int32. If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the + The conv_state will be updated by copying x to the conv_state starting at the index @cache_seqlens % state_len before performing the convolution. @@ -161,7 +161,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed - seed_everything(0) + current_platform.seed_everything(0) x = torch.randn(batch, dim, seqlen, device=device, dtype=itype).contiguous() @@ -223,7 +223,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed - seed_everything(0) + current_platform.seed_everything(0) batch = 2 x = torch.randn(batch, dim, seqlen, device=device, dtype=itype) x_ref = x.clone() @@ -270,7 +270,7 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, rtol, atol = 1e-2, 5e-2 # set seed - seed_everything(0) + current_platform.seed_everything(0) batch_size = 3 padding = 5 if with_padding else 0 @@ -343,7 +343,7 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed - seed_everything(0) + current_platform.seed_everything(0) seqlens = [] batch_size = 4 if seqlen < 10: diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 35c29c5bd1028..a20c73345218f 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -3,7 +3,7 @@ import pytest import torch -from vllm.utils import seed_everything +from vllm.platforms import current_platform from vllm.vllm_flash_attn import (flash_attn_varlen_func, flash_attn_with_kvcache) @@ -91,7 +91,7 @@ def test_flash_attn_with_paged_kv( sliding_window: Optional[int], ) -> None: torch.set_default_device("cuda") - seed_everything(0) + current_platform.seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] @@ -161,7 +161,7 @@ def test_varlen_with_paged_kv( num_blocks: int, ) -> None: torch.set_default_device("cuda") - seed_everything(0) + current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index 80a388db6530e..a2c8f71665737 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -4,7 +4,7 @@ import pytest import torch -from vllm.utils import seed_everything +from vllm.platforms import current_platform NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)] HEAD_SIZES = [128, 256] @@ -84,7 +84,7 @@ def test_flashinfer_decode_with_paged_kv( soft_cap: Optional[float], ) -> None: torch.set_default_device("cuda") - seed_everything(0) + current_platform.seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] @@ -170,7 +170,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], block_size: int, soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") - seed_everything(0) + current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] @@ -268,7 +268,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( head_size: int, dtype: torch.dtype, block_size: int, soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") - seed_everything(0) + current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] @@ -381,7 +381,7 @@ def test_flashinfer_decode_with_paged_fp8_kv( ) -> None: # test doesn't work for num_heads = (16,16) torch.set_default_device("cuda") - seed_everything(0) + current_platform.seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index c18f5f468dc5a..ebaaae2321885 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -6,7 +6,7 @@ ref_dynamic_per_tensor_fp8_quant, ref_dynamic_per_token_quant) from tests.kernels.utils import opcheck -from vllm.utils import seed_everything +from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, @@ -46,7 +46,7 @@ def opcheck_fp8_quant(output, def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6 # avoid nans @@ -76,7 +76,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") @@ -95,7 +95,7 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() @pytest.mark.parametrize("seed", SEEDS) def test_fp8_quant_large(seed: int) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) num_tokens = 1024000 # Mistral-Nemo's max_position_embeddings hidden_size = 1152 # Smallest hidden_size to reproduce the error diff --git a/tests/kernels/test_gguf.py b/tests/kernels/test_gguf.py index 1513fc196153c..893af99ba4977 100644 --- a/tests/kernels/test_gguf.py +++ b/tests/kernels/test_gguf.py @@ -7,7 +7,7 @@ from huggingface_hub import snapshot_download import vllm._custom_ops as ops -from vllm.utils import seed_everything +from vllm.platforms import current_platform GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample") @@ -75,7 +75,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype, @torch.inference_mode() def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType): - seed_everything(0) + current_platform.seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((1, hidden_size), dtype=dtype, device="cuda") @@ -111,7 +111,7 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype, @torch.inference_mode() def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType): - seed_everything(0) + current_platform.seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda") diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 41e103e1d09f9..8db6a0d0d9fa4 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -4,7 +4,7 @@ from tests.kernels.quant_utils import ref_dynamic_per_token_quant from tests.kernels.utils import opcheck from vllm._custom_ops import scaled_int8_quant -from vllm.utils import seed_everything +from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -45,7 +45,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True): @torch.inference_mode() def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 @@ -68,7 +68,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, @@ -112,7 +112,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 @@ -138,7 +138,7 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float, azp: int) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 382079d472ee9..9dfa2cbe45e94 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -3,7 +3,7 @@ from tests.kernels.utils import opcheck from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.utils import seed_everything +from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing @@ -31,7 +31,7 @@ def test_rms_norm( seed: int, device: str, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) layer = RMSNorm(hidden_size).to(dtype=dtype) layer.weight.data.normal_(mean=1.0, std=0.1) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index e92d401368a7b..bf7ff3b5c59b8 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -8,7 +8,7 @@ from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) -from vllm.utils import seed_everything +from vllm.platforms import current_platform def selective_state_update_ref(state, @@ -235,7 +235,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, rtolw = max(rtolw, rtol) atolw = max(atolw, atol) # set seed - seed_everything(0) + current_platform.seed_everything(0) batch_size = 1 dim = 4 dstate = 8 @@ -358,7 +358,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): if torch.version.hip: atol *= 2 # set seed - seed_everything(0) + current_platform.seed_everything(0) batch_size = 1 state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) x = torch.randn(batch_size, dim, device=device, dtype=itype) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 70906ab2187bc..19c3fc1e1fe3a 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -19,7 +19,6 @@ from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -from vllm.utils import seed_everything @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @@ -115,7 +114,7 @@ def test_fused_marlin_moe( num_bits: int, is_k_full: bool, ): - seed_everything(7) + current_platform.seed_everything(7) # Filter act_order if act_order: diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 94da00915d40e..b408559cc0b07 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -5,7 +5,7 @@ import torch from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.utils import seed_everything +from vllm.platforms import current_platform from .allclose_default import get_default_atol, get_default_rtol @@ -48,7 +48,7 @@ def test_rotary_embedding( if rotary_dim is None: rotary_dim = head_size - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size @@ -100,7 +100,7 @@ def test_batched_rotary_embedding( max_position: int = 8192, base: int = 10000, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size @@ -160,7 +160,7 @@ def test_batched_rotary_embedding_multi_lora( max_position: int = 8192, base: int = 10000, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 3181d92562399..a8a187ebaede4 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -9,7 +9,8 @@ from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.ops.prefix_prefill import context_attention_fwd -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything +from vllm.platforms import current_platform +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 8, 64] @@ -39,7 +40,7 @@ def test_contexted_kv_attention( kv_cache_dtype: str, device: str, ) -> None: - seed_everything(0) + current_platform.seed_everything(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process @@ -234,7 +235,7 @@ def test_contexted_kv_attention_alibi( kv_cache_dtype: str, device: str, ) -> None: - seed_everything(0) + current_platform.seed_everything(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index db877219a285c..eb882faf3974a 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) from vllm.model_executor.utils import set_random_seed -from vllm.utils import seed_everything +from vllm.platforms import current_platform from .utils import DummyLoRAManager @@ -923,7 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, seq_len) -> None: dtype = torch.float16 seed = 0 - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 41c37a4813c68..e756544d96e98 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -1,5 +1,5 @@ """ -This script is mainly used to tests various hidden_sizes. We have collected the +This script is mainly used to tests various hidden_sizes. We have collected the hidden_sizes included in the LoRA models currently supported by vLLM. It tests whether the corresponding Triton kernel can run normally when tensor parallelism is set to [1, 2, 4, 8, 16, 32, 64]. @@ -15,8 +15,8 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.platforms import current_platform from vllm.triton_utils.libentry import LibEntry -from vllm.utils import seed_everything from .utils import (generate_data, generate_data_for_expand_nslices, ref_torch_groupgemm) @@ -146,7 +146,7 @@ def test_punica_sgmv( device: str, ): torch.set_default_device(device) - seed_everything(seed) + current_platform.seed_everything(seed) seq_length = 128 ( @@ -239,7 +239,7 @@ def test_punica_bgmv( from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel torch.set_default_device(device) - seed_everything(seed) + current_platform.seed_everything(seed) seq_length = 1 ( @@ -327,7 +327,7 @@ def test_punica_expand_nslices( from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel torch.set_default_device(device) - seed_everything(seed) + current_platform.seed_everything(seed) seq_length = 128 if op_type == "sgmv" else 1 ( diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 185da6399a06a..dc0edeb10ef46 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -1,6 +1,6 @@ """ -This script is mainly used to test whether trtion kernels can run normally -under different conditions, including various batches, numbers of LoRA , and +This script is mainly used to test whether trtion kernels can run normally +under different conditions, including various batches, numbers of LoRA , and maximum ranks. """ from unittest.mock import patch @@ -14,8 +14,8 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.platforms import current_platform from vllm.triton_utils.libentry import LibEntry -from vllm.utils import seed_everything from .utils import (generate_data, generate_data_for_expand_nslices, ref_torch_groupgemm) @@ -61,7 +61,7 @@ def test_punica_sgmv( device: str, ): torch.set_default_device(device) - seed_everything(seed) + current_platform.seed_everything(seed) seq_length = 128 ( @@ -154,7 +154,7 @@ def test_punica_bgmv( from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel torch.set_default_device(device) - seed_everything(seed) + current_platform.seed_everything(seed) seq_length = 1 ( @@ -242,7 +242,7 @@ def test_punica_expand_nslices( from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel torch.set_default_device(device) - seed_everything(seed) + current_platform.seed_everything(seed) seq_length = 128 if op_type == "sgmv" else 1 ( diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index c27b1cf6ac7b9..39ead08c238ce 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -4,11 +4,10 @@ import torch from vllm.platforms import current_platform -from vllm.utils import seed_everything def set_random_seed(seed: int) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) def set_weight_attrs( diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 7c933385d6ff6..c3a3e7a284457 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,6 +1,8 @@ import enum +import random from typing import NamedTuple, Optional, Tuple, Union +import numpy as np import torch @@ -111,6 +113,18 @@ def inference_mode(cls): """ return torch.inference_mode(mode=True) + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/utils.py b/vllm/utils.py index c3f9a6bdd8b80..fea318ebcdf41 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -7,7 +7,6 @@ import inspect import ipaddress import os -import random import socket import subprocess import sys @@ -331,22 +330,6 @@ def get_cpu_memory() -> int: return psutil.virtual_memory().total -def seed_everything(seed: int) -> None: - """ - Set the seed of each random module. - - Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 - """ - random.seed(seed) - np.random.seed(seed) - - if current_platform.is_cuda_alike(): - torch.cuda.manual_seed_all(seed) - - if current_platform.is_xpu(): - torch.xpu.manual_seed_all(seed) - - def random_uuid() -> str: return str(uuid.uuid4().hex) @@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash( seed: int = 0, device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - seed_everything(seed) + current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) @@ -685,7 +668,7 @@ def create_kv_caches_with_random( f"Does not support key cache of type fp8 with head_size {head_size}" ) - seed_everything(seed) + current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)