diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 2c5d74e7abcbf..e8456357e6db1 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -305,7 +305,7 @@ steps:
 
 ##### models test #####
 
-- label: Basic Models Test # 3min
+- label: Basic Models Test # 10min
   source_file_dependencies:
   - vllm/
   - tests/models
@@ -314,23 +314,24 @@ steps:
     - pytest -v -s models/test_oot_registration.py # it needs a clean process
     - pytest -v -s models/*.py --ignore=models/test_oot_registration.py
 
-- label: Decoder-only Language Models Test (Standard) # 35min
+- label: Decoder-only Language Models Test (Standard) # 18min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/language
   commands:
-    - pytest -v -s models/decoder_only/language/test_models.py
+    - pytest -v -s models/decoder_only/language -m core_model
+    - pytest -v -s models/decoder_only/language -m quant_model
 
-- label: Decoder-only Language Models Test (Extended) # 1h20min
+- label: Decoder-only Language Models Test (Extended) # 46min
   nightly: true
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/language
   commands:
-    - pytest -v -s models/decoder_only/language --ignore=models/decoder_only/language/test_models.py
+    - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
 
-- label: Decoder-only Multi-Modal Models Test (Standard) # 26min
+- label: Decoder-only Multi-Modal Models Test (Standard) # 22min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
@@ -339,21 +340,24 @@ steps:
   commands:
     - pytest -v -s models/decoder_only/audio_language -m core_model
     - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model
+    # No tests under this group for now
+    # - pytest -v -s models/decoder_only/audio_language -m quant_model
+    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m quant_model
 
-- label: Decoder-only Multi-Modal Models Test (Extended)
+- label: Decoder-only Multi-Modal Models Test (Extended) # 1h10m
   nightly: true
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/audio_language
   - tests/models/decoder_only/vision_language
   commands:
-    - pytest -v -s models/decoder_only/audio_language -m 'not core_model'
+    - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
     # HACK - run phi3v tests separately to sidestep this transformers bug
     # https://github.com/huggingface/transformers/issues/34307
     - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
-    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model'
+    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
 
-- label: Other Models Test # 6min
+- label: Other Models Test # 20min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
diff --git a/pyproject.toml b/pyproject.toml
index 797e7a88ab31b..3c8c46cc8621e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -95,6 +95,7 @@ markers = [
     "skip_global_cleanup",
     "core_model: enable this model test in each PR instead of only nightly",
     "cpu_model: enable this model test in CPU tests",
+    "quant_model: run this model test under Quantized category",
     "distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
     "skip_v1: do not run this test with v1",
 ]
diff --git a/tests/models/decoder_only/language/test_aqlm.py b/tests/models/decoder_only/language/test_aqlm.py
index de46032113086..a8cb5bbf9349e 100644
--- a/tests/models/decoder_only/language/test_aqlm.py
+++ b/tests/models/decoder_only/language/test_aqlm.py
@@ -38,6 +38,7 @@
 ]
 
 
+@pytest.mark.quant_model
 @pytest.mark.skipif(not is_quant_method_supported("aqlm"),
                     reason="AQLM is not supported on this GPU type.")
 @pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
diff --git a/tests/models/decoder_only/language/test_fp8.py b/tests/models/decoder_only/language/test_fp8.py
index f874bf6c73142..53f23e24511b3 100644
--- a/tests/models/decoder_only/language/test_fp8.py
+++ b/tests/models/decoder_only/language/test_fp8.py
@@ -15,6 +15,7 @@
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
 
+@pytest.mark.quant_model
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="fp8 is not supported on this GPU type.")
 @pytest.mark.parametrize(
diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py
index 5dc83942632fd..2b8f5e2faa45e 100644
--- a/tests/models/decoder_only/language/test_gguf.py
+++ b/tests/models/decoder_only/language/test_gguf.py
@@ -17,26 +17,21 @@
 MAX_MODEL_LEN = 1024
 
-# FIXME: Move this to confest
-MODELS = [
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
-                     filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")),
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
-                     filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf")),
-    ("Qwen/Qwen2-1.5B-Instruct",
-     hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF",
-                     filename="qwen2-1_5b-instruct-q4_k_m.gguf")),
-    ("Qwen/Qwen2-1.5B-Instruct",
-     hf_hub_download("legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
-                     filename="Qwen2-1.5B-Instruct.IQ4_XS.gguf")),
-]
-
 
 @pytest.mark.skipif(not is_quant_method_supported("gguf"),
                     reason="gguf is not supported on this GPU type.")
-@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize(("original_model", "gguf_id", "gguf_path"), [
+    ("meta-llama/Llama-3.2-1B-Instruct",
+     "bartowski/Llama-3.2-1B-Instruct-GGUF",
+     "Llama-3.2-1B-Instruct-Q4_K_M.gguf"),
+    ("meta-llama/Llama-3.2-1B-Instruct",
+     "bartowski/Llama-3.2-1B-Instruct-GGUF",
+     "Llama-3.2-1B-Instruct-IQ4_XS.gguf"),
+    ("Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct-GGUF",
+     "qwen2-1_5b-instruct-q4_k_m.gguf"),
+    ("Qwen/Qwen2-1.5B-Instruct", "legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
+     "Qwen2-1.5B-Instruct.IQ4_XS.gguf"),
+])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
@@ -45,7 +40,9 @@ def test_models(
     num_gpus_available,
     vllm_runner,
     example_prompts,
-    model,
+    original_model,
+    gguf_id,
+    gguf_path,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
@@ -54,7 +51,7 @@ def test_models(
     if num_gpus_available < tp_size:
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
 
-    original_model, gguf_model = model
+    gguf_model = hf_hub_download(gguf_id, filename=gguf_path)
 
     tokenizer = AutoTokenizer.from_pretrained(original_model)
     messages = [[{
diff --git a/tests/models/decoder_only/language/test_gptq_marlin.py b/tests/models/decoder_only/language/test_gptq_marlin.py
index a896f145c11f1..037411a18c19f 100644
--- a/tests/models/decoder_only/language/test_gptq_marlin.py
+++ b/tests/models/decoder_only/language/test_gptq_marlin.py
@@ -33,6 +33,7 @@
 ]
 
 
+@pytest.mark.quant_model
 @pytest.mark.flaky(reruns=3)
 @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                     reason="gptq_marlin is not supported on this GPU type.")
diff --git a/tests/models/decoder_only/language/test_gptq_marlin_24.py b/tests/models/decoder_only/language/test_gptq_marlin_24.py
index aa63f9f36a3a8..26cb3ec310701 100644
--- a/tests/models/decoder_only/language/test_gptq_marlin_24.py
+++ b/tests/models/decoder_only/language/test_gptq_marlin_24.py
@@ -38,6 +38,7 @@ class ModelPair:
 ]
 
 
+@pytest.mark.quant_model
 @pytest.mark.flaky(reruns=2)
 @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"),
                     reason="Marlin24 is not supported on this GPU type.")
diff --git a/tests/models/decoder_only/language/test_granite.py b/tests/models/decoder_only/language/test_granite.py
index 0b71f0d49c70a..5e93842f46164 100644
--- a/tests/models/decoder_only/language/test_granite.py
+++ b/tests/models/decoder_only/language/test_granite.py
@@ -7,7 +7,9 @@
 from ...utils import check_logprobs_close
 
 MODELS = [
+    # TODO(sang): Sliding window should be tested separately.
     "ibm/PowerLM-3b",
+    "ibm/PowerMoE-3b",
 ]
 
 
@@ -24,7 +26,6 @@ def test_models(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
-    # TODO(sang): Sliding window should be tested separately.
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)
diff --git a/tests/models/decoder_only/language/test_granitemoe.py b/tests/models/decoder_only/language/test_granitemoe.py
deleted file mode 100644
index ba73375229eb3..0000000000000
--- a/tests/models/decoder_only/language/test_granitemoe.py
+++ /dev/null
@@ -1,39 +0,0 @@
-"""Compare the outputs of HF and vLLM for Granite models using greedy sampling.
-
-Run `pytest tests/models/test_granite.py`.
-"""
-import pytest
-
-from ...utils import check_logprobs_close
-
-MODELS = [
-    "ibm/PowerMoE-3b",
-]
-
-
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [64])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_models(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-) -> None:
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy_logprobs_limit(
-            example_prompts, max_tokens, num_logprobs)
-
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
diff --git a/tests/models/decoder_only/language/test_modelopt.py b/tests/models/decoder_only/language/test_modelopt.py
index e643b115d0ea8..077e50e3a4dfd 100644
--- a/tests/models/decoder_only/language/test_modelopt.py
+++ b/tests/models/decoder_only/language/test_modelopt.py
@@ -39,6 +39,7 @@
 @pytest.mark.skip(
     reason=
     "Prevent unstable test based on golden strings from breaking the build.")
+@pytest.mark.quant_model
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="fp8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_name", MODELS)
diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py
index d705909c24bf8..beb1ffb18436e 100644
--- a/tests/models/decoder_only/language/test_models.py
+++ b/tests/models/decoder_only/language/test_models.py
@@ -1,8 +1,5 @@
 """Compare the outputs of HF and vLLM when using greedy sampling.
 
-This test only tests small models. Big models such as 7B should be tested from
-test_big_models.py because it could use a larger instance to run tests.
-
 Run `pytest tests/models/test_models.py`.
 """
 import pytest
@@ -35,6 +32,7 @@
 target_dtype = "half"
 
 
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [32])
diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py
index c2d3fda6994f6..51c0085101dd0 100644
--- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py
+++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py
@@ -56,11 +56,13 @@ def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
     ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
     seq_len = 5000  # bigger than the max feature size for any image
 
-    seq_data, mm_data = dummy_data_for_llava_next(
+    dummy_data = dummy_data_for_llava_next(
         ctx,
         seq_len=seq_len,
         mm_counts={"image": 1},
     )
+    seq_data = dummy_data.seq_data
+    mm_data = dummy_data.multi_modal_data
 
     # The dummy data dims should match the gridpoint with the biggest feat size
     assert mm_data["image"].height == expected_size[0]
diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py
index d6a7b34fdde9f..60a8f63eb5faa 100644
--- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py
+++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py
@@ -131,12 +131,13 @@ def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
         mm_processor_kwargs=None,
     )
 
-    sequence_data, _, = dummy_data_for_phi3v(
+    dummy_data = dummy_data_for_phi3v(
         ctx=ctx,
         seq_len=8192,  # Should be bigger than num_imgs * toks_per_img
         mm_counts={"image": num_imgs},
         num_crops=num_crops,
     )
+    sequence_data = dummy_data.seq_data
     # Ensure we have the right number of placeholders per num_crops size
     img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
     assert img_tok_count == toks_per_img * num_imgs
diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py
index c23fbedf0c6ae..7e2bea130583e 100644
--- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py
+++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py
@@ -86,10 +86,17 @@ def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
 
     # NOTE: video value is required, but isn't actually used
     # when making the dummy data except for error handling currently
-    seq_data, mm_data = dummy_data_for_qwen2_vl(qwen2_vl_context, seq_len, {
-        "image": 1,
-        "video": 0
-    }, **mm_processor_kwargs)
+    dummy_data = dummy_data_for_qwen2_vl(
+        ctx=qwen2_vl_context,
+        seq_len=seq_len,
+        mm_counts={
+            "image": 1,
+            "video": 0
+        },
+        **mm_processor_kwargs,
+    )
+    seq_data = dummy_data.seq_data
+    mm_data = dummy_data.multi_modal_data
 
     # Ensure we have the right number of placeholders for min/max pixel values
     assert seq_data.get_token_ids().count(image_token_id) == token_count
diff --git a/tests/models/decoder_only/vision_language/test_internvl.py b/tests/models/decoder_only/vision_language/test_awq.py
similarity index 90%
rename from tests/models/decoder_only/vision_language/test_internvl.py
rename to tests/models/decoder_only/vision_language/test_awq.py
index 2fd1ac4bb08f7..6e6e5b40d6a35 100644
--- a/tests/models/decoder_only/vision_language/test_internvl.py
+++ b/tests/models/decoder_only/vision_language/test_awq.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Type
 
 import pytest
 import torch
@@ -19,7 +19,8 @@
 def run_awq_test(
     vllm_runner: Type[VllmRunner],
     image_assets: _ImageAssets,
-    models: Tuple[str, str],
+    source_model: str,
+    quant_model: str,
     *,
     size_factors: List[float],
     dtype: str,
@@ -28,8 +29,6 @@ def run_awq_test(
     tensor_parallel_size: int,
     distributed_executor_backend: Optional[str] = None,
 ):
-    source_model, quant_model = models
-
     images = [asset.pil_image for asset in image_assets]
 
     inputs_per_image = [(
@@ -84,8 +83,11 @@ def run_awq_test(
     )
 
 
+@pytest.mark.quant_model
 @pytest.mark.parametrize(
-    "models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
+    ("source_model", "quant_model"),
+    [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")],
+)
 @pytest.mark.parametrize(
     "size_factors",
     [
@@ -103,12 +105,13 @@ def run_awq_test(
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 @torch.inference_mode()
-def test_awq_models(vllm_runner, image_assets, models, size_factors,
-                    dtype: str, max_tokens: int, num_logprobs: int) -> None:
+def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
+                    size_factors, dtype, max_tokens, num_logprobs) -> None:
     run_awq_test(
         vllm_runner,
         image_assets,
-        models,
+        source_model,
+        quant_model,
         size_factors=size_factors,
         dtype=dtype,
         max_tokens=max_tokens,
diff --git a/tests/models/decoder_only/vision_language/test_intern_vit.py b/tests/models/decoder_only/vision_language/test_intern_vit.py
index 98f313eb9b9af..32fcb0bbc42f1 100644
--- a/tests/models/decoder_only/vision_language/test_intern_vit.py
+++ b/tests/models/decoder_only/vision_language/test_intern_vit.py
@@ -11,21 +11,17 @@
 # we use snapshot_download to prevent conflicts between
 # dynamic_module and trust_remote_code for hf_runner
 DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
-models = [
-    snapshot_download("OpenGVLab/InternViT-300M-448px",
-                      allow_patterns=DOWNLOAD_PATTERN),
-    snapshot_download("OpenGVLab/InternViT-6B-448px-V1-5",
-                      allow_patterns=DOWNLOAD_PATTERN),
-]
 
 
 def run_intern_vit_test(
     image_assets: _ImageAssets,
-    model: str,
+    model_id: str,
     *,
     dtype: str,
     distributed_executor_backend: Optional[str] = None,
 ):
+    model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
+
     img_processor = CLIPImageProcessor.from_pretrained(model)
     images = [asset.pil_image for asset in image_assets]
     pixel_values = [
@@ -67,12 +63,15 @@ def run_intern_vit_test(
         assert cos_similar(vllm_output, hf_output).mean() > 0.99
 
 
-@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("model_id", [
+    "OpenGVLab/InternViT-300M-448px",
+    "OpenGVLab/InternViT-6B-448px-V1-5",
+])
 @pytest.mark.parametrize("dtype", [torch.half])
 @torch.inference_mode()
-def test_models(dist_init, image_assets, model, dtype: str) -> None:
+def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
     run_intern_vit_test(
         image_assets,
-        model,
+        model_id,
         dtype=dtype,
     )
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 1ab42f8c126f8..3f6d8ef42cd5f 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -130,8 +130,8 @@
         max_num_seqs=2,
         auto_cls=AutoModelForVision2Seq,
         vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
-        marks=[pytest.mark.core_model, pytest.mark.cpu_model],
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
+        marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
     #### Extended model tests
     "blip2": VLMTestInfo(
@@ -159,9 +159,9 @@
         dtype="bfloat16",
         marks=[
             pytest.mark.skipif(
-                transformers.__version__.startswith("4.46"),
+                transformers.__version__ < "4.46.2",
                 reason="Model broken in HF, see huggingface/transformers#34379"
-            )
+            ),
         ]
     ),
     "fuyu": VLMTestInfo(
@@ -185,8 +185,8 @@
         max_num_seqs=2,
         dtype="bfloat16",
         get_stop_token_ids=lambda tok: [151329, 151336, 151338],
-        marks=[large_gpu_mark(min_gb=48)],
         patch_hf_runner=model_utils.glm_patch_hf_runner,
+        marks=[large_gpu_mark(min_gb=48)],
     ),
     "h2ovl": VLMTestInfo(
         models = [
@@ -205,6 +205,22 @@
         use_tokenizer_eos=True,
         patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
     ),
+    "idefics3": VLMTestInfo(
+        models=["HuggingFaceM4/Idefics3-8B-Llama3"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<image>",
+        max_model_len=8192,
+        max_num_seqs=2,
+        auto_cls=AutoModelForVision2Seq,
+        marks=[
+            pytest.mark.skipif(
+                transformers.__version__ < "4.46.0",
+                reason="Model introduced in HF >= 4.46.0"
+            ),
+            large_gpu_mark(min_gb=48),
+        ],
+    ),
     "intern_vl": VLMTestInfo(
         models=[
             "OpenGVLab/InternVL2-1B",
@@ -263,7 +279,6 @@
             runner_mm_key="videos",
         )],
     ),
-    # FIXME
     "llava_next_video": VLMTestInfo(
         models=["llava-hf/LLaVA-NeXT-Video-7B-hf"],
         test_type=VLMTestType.VIDEO,
@@ -275,7 +290,7 @@
         image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
         marks=[
             pytest.mark.skipif(
-                transformers.__version__.startswith("4.46"),
+                transformers.__version__ < "4.46.2",
                 reason="Model broken with changes in transformers 4.46"
             )
         ],
@@ -316,6 +331,7 @@
         max_model_len=8192,
         max_num_seqs=2,
         auto_cls=AutoModelForVision2Seq,
+        marks=[large_gpu_mark(min_gb=48)],
     ),
     "qwen": VLMTestInfo(
         models=["Qwen/Qwen-VL"],
@@ -327,22 +343,6 @@
         vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
         prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
     ),
-    "idefics3": VLMTestInfo(
-        models=["HuggingFaceM4/Idefics3-8B-Llama3"],
-        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
-        prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:",  # noqa: E501
-        img_idx_to_prompt=lambda idx: "<image>",
-        max_model_len=8192,
-        max_num_seqs=2,
-        auto_cls=AutoModelForVision2Seq,
-        marks=[
-            pytest.mark.skipif(
-                transformers.__version__ < "4.46.0",
-                reason="Model introduced in HF >= 4.46.0"
-            ),
-            large_gpu_mark(min_gb=48),
-        ],
-    ),
     ### Tensor parallel / multi-gpu broadcast tests
     "broadcast-chameleon": VLMTestInfo(
         models=["facebook/chameleon-7b"],
@@ -362,7 +362,7 @@
                 reason="Need at least 2 GPUs to run the test.",
             ),
             pytest.mark.skipif(
-                transformers.__version__.startswith("4.46"),
+                transformers.__version__ < "4.46.2",
                 reason="Model broken in HF, see huggingface/transformers#34379"
             )
         ],
diff --git a/vllm/config.py b/vllm/config.py
index b902499bf5bdc..f9b230e1bc688 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1,7 +1,8 @@
+import copy
 import enum
 import json
 import warnings
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
                     Mapping, Optional, Set, Tuple, Type, Union)
 
@@ -2078,6 +2079,12 @@ def _get_quantization_config(
             return quant_config
         return None
 
+    def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig":
+        model_config = copy.deepcopy(self.model_config)
+        model_config.hf_config = hf_config
+
+        return replace(self, model_config=model_config)
+
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
         """
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index cac10f505df67..37f38d4d76671 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -229,7 +229,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
@@ -246,9 +245,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
             quant_config=quant_config,
             gather_output=True,
         )
-        self.language_model = PersimmonForCausalLM(config.text_config,
-                                                   cache_config=cache_config,
-                                                   quant_config=quant_config)
+        self.language_model = PersimmonForCausalLM(
+            vllm_config.with_hf_config(config.text_config))
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py
index f7bc823574034..51e2c64d5552d 100644
--- a/vllm/model_executor/models/internlm2_ve.py
+++ b/vllm/model_executor/models/internlm2_ve.py
@@ -164,10 +164,12 @@
     def __init__(
        self,
         vllm_config: VllmConfig,
         prefix: str = "",
     ) -> None:
+        super().__init__(vllm_config, prefix=prefix)
+
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        super().__init__(config, cache_config, quant_config)
+
         self.model = InternLM2VEModel(config, cache_config, quant_config,
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 60eeceb18bcf0..ca4fc8ec952bf 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -241,11 +241,11 @@ def init_vllm_registered_model(
     based on the arguments passed to the outer vLLM model.
     """
     model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
 
-    import copy
-    copied_config = copy.deepcopy(vllm_config)
-    copied_config.model_config.hf_config = hf_config
-    return model_class(vllm_config=copied_config, prefix=prefix)
+    return model_class(
+        vllm_config=vllm_config.with_hf_config(hf_config),
+        prefix=prefix,
+    )
 
 
 @overload