forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Bugfix / Core] Prefix Caching Guards (merged with main) (vllm-projec…
…t#4846) Co-authored-by: rsnm2 <[email protected]> Co-authored-by: Robert Shaw <[email protected]>
- Loading branch information
Showing
11 changed files
with
167 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
"""Compare the with and without prefix caching. | ||
Run `pytest tests/prefix_caching/test_prefix_caching.py`. | ||
""" | ||
import pytest | ||
|
||
from tests.conftest import cleanup | ||
from vllm import LLM | ||
|
||
# Each entry is (model_name, max_model_len_with_sliding_window_disabled,
# max_model_len_with_sliding_window_enabled). For models without a sliding
# window the two lengths are identical.
MODEL_LEN_LEN = [
    # Example models with sliding window.
    ("bigcode/starcoder2-3b", 4096, 16384),
    # ("mistralai/Mistral-7B-v0.1", 4096, 32768), << OOM in CI

    # Confirm models WITHOUT a sliding window also work
    # (previous comment wrongly said "with sliding window").
    # config has "use_sliding_window": false
    ("Qwen/Qwen1.5-0.5B-Chat", 32768, 32768),
    # config has no sliding window attribute.
    ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 2048, 2048),
]
@pytest.mark.parametrize("model_len_len", MODEL_LEN_LEN)
def test_disable_sliding_window(model_len_len, ):
    """Verify disable_sliding_window controls the engine's max_model_len.

    With disable_sliding_window=True, max_model_len must be capped to the
    model's sliding-window length; with False, it must equal the model's
    full context length. (For models without a sliding window the two
    values coincide, so both branches assert the same number.)
    """
    model, sliding_len, full_len = model_len_len

    vllm_disabled_model = LLM(model, disable_sliding_window=True)
    vllm_disabled_model.generate("Hi my name is")
    model_config = vllm_disabled_model.llm_engine.model_config
    # Bug fix: the message was previously a tuple containing unformatted
    # %s placeholders, which pytest would print verbatim on failure.
    # Use an f-string so the failure output is readable.
    assert model_config.max_model_len == sliding_len, (
        f"Max len expected to equal sliding_len of {sliding_len}, "
        f"but got {model_config.max_model_len}")

    # Release engine resources before constructing the next LLM instance.
    del vllm_disabled_model
    cleanup()

    vllm_enabled_model = LLM(model, disable_sliding_window=False)
    vllm_enabled_model.generate("Hi my name is")
    model_config = vllm_enabled_model.llm_engine.model_config
    assert model_config.max_model_len == full_len, (
        f"Max len expected to equal full_len of {full_len}, "
        f"but got {model_config.max_model_len}")

    del vllm_enabled_model
    cleanup()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.