Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix assertion failure in Qwen 1.5 with prefix caching enabled #3373

Merged
merged 7 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from vllm.config import ModelConfig


def test_get_sliding_window():
TEST_SLIDING_WINDOW = 4096
# Test that the sliding window is correctly computed.
# For Qwen1.5/Qwen2, get_sliding_window() should be None
# when use_sliding_window is False.
qwen2_model_config = ModelConfig(
"Qwen/Qwen1.5-7B",
"Qwen/Qwen1.5-7B",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
)

qwen2_model_config.hf_config.use_sliding_window = False
qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
assert qwen2_model_config.get_sliding_window() is None

qwen2_model_config.hf_config.use_sliding_window = True
assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW

mistral_model_config = ModelConfig(
"mistralai/Mistral-7B-v0.1",
"mistralai/Mistral-7B-v0.1",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
)
mistral_model_config.hf_config.sliding_window = None
assert mistral_model_config.get_sliding_window() is None

mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
10 changes: 8 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C

if not os.path.exists(model):
model_path = snapshot_download(model_id=model,
cache_dir=download_dir,
Expand Down Expand Up @@ -139,7 +140,7 @@ def _verify_load_format(self) -> None:
if (f not in rocm_not_supported_load_format)
]
raise ValueError(
f"load format \'{load_format}\' is not supported in ROCm. "
f"load format '{load_format}' is not supported in ROCm. "
f"Supported load format are "
f"{rocm_supported_load_format}")

Expand Down Expand Up @@ -232,6 +233,11 @@ def verify_with_parallel_config(
f"({pipeline_parallel_size}).")

def get_sliding_window(self) -> Optional[int]:
# For Qwen2 and Qwen1.5
# sliding_window is not None, but use_sliding_window is disabled.
chenxu2048 marked this conversation as resolved.
Show resolved Hide resolved
if (hasattr(self.hf_config, "use_sliding_window")
and not self.hf_config.use_sliding_window):
return None
chenxu2048 marked this conversation as resolved.
Show resolved Hide resolved
return getattr(self.hf_config, "sliding_window", None)

def get_vocab_size(self) -> int:
Expand Down Expand Up @@ -624,7 +630,7 @@ def _get_and_verify_dtype(
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
]
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
f"Supported dtypes are {rocm_supported_dtypes}")

# Verify the dtype.
Expand Down
Loading