From 54be8a0be2819340ce7c2d7993382559597f5665 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=BA=8F?= Date: Fri, 15 Mar 2024 04:56:57 +0800 Subject: [PATCH] Fix assertion failure in Qwen 1.5 with prefix caching enabled (#3373) Co-authored-by: Cade Daniel --- tests/test_config.py | 43 +++++++++++++++++++++++++++++++++++++++++++ vllm/config.py | 14 ++++++++++++-- 2 files changed, 55 insertions(+), 2 deletions(-) create mode 100644 tests/test_config.py diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000000000..13a9f76212679 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,43 @@ +from vllm.config import ModelConfig + + +def test_get_sliding_window(): + TEST_SLIDING_WINDOW = 4096 + # Test that the sliding window is correctly computed. + # For Qwen1.5/Qwen2, get_sliding_window() should be None + # when use_sliding_window is False. + qwen2_model_config = ModelConfig( + "Qwen/Qwen1.5-7B", + "Qwen/Qwen1.5-7B", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + ) + + qwen2_model_config.hf_config.use_sliding_window = False + qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW + assert qwen2_model_config.get_sliding_window() is None + + qwen2_model_config.hf_config.use_sliding_window = True + assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW + + mistral_model_config = ModelConfig( + "mistralai/Mistral-7B-v0.1", + "mistralai/Mistral-7B-v0.1", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + ) + mistral_model_config.hf_config.sliding_window = None + assert mistral_model_config.get_sliding_window() is None + + mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW + assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index 319c1569f5e98..de687395a0001 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -103,6 +103,7 @@ def __init__( # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C + if not os.path.exists(model): model_path = snapshot_download(model_id=model, cache_dir=download_dir, @@ -139,7 +140,7 @@ def _verify_load_format(self) -> None: if (f not in rocm_not_supported_load_format) ] raise ValueError( - f"load format \'{load_format}\' is not supported in ROCm. " + f"load format '{load_format}' is not supported in ROCm. " f"Supported load format are " f"{rocm_supported_load_format}") @@ -232,6 +233,15 @@ def verify_with_parallel_config( f"({pipeline_parallel_size}).") def get_sliding_window(self) -> Optional[int]: + """Get the sliding window size, or None if disabled. + """ + + # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in + # addition to sliding window size. We check if that field is present + # and if it's False, return None. + if (hasattr(self.hf_config, "use_sliding_window") + and not self.hf_config.use_sliding_window): + return None return getattr(self.hf_config, "sliding_window", None) def get_vocab_size(self) -> int: @@ -624,7 +634,7 @@ def _get_and_verify_dtype( k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() if (k not in _ROCM_NOT_SUPPORTED_DTYPE) ] - raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. " + raise ValueError(f"dtype '{dtype}' is not supported in ROCm. " f"Supported dtypes are {rocm_supported_dtypes}") # Verify the dtype.