[Bugfix / Core] Prefix Caching Guards (merged with main) #4846

Merged on May 27, 2024 (32 commits)

Commits
56680e7
added guards for prefix-caching. added ability to disable sliding window
robertgshaw2-neuralmagic Apr 7, 2024
64aac2e
format.sh
robertgshaw2-neuralmagic Apr 7, 2024
28ae0cc
added tests
robertgshaw2-neuralmagic Apr 7, 2024
2a01ae6
Merge remote-tracking branch 'upstream/main' into prefix-caching-guards
robertgshaw2-neuralmagic Apr 28, 2024
cd0f666
merge
robertgshaw2-neuralmagic Apr 28, 2024
f30c3de
removed images
robertgshaw2-neuralmagic Apr 28, 2024
da5a982
fixed bad merge
robertgshaw2-neuralmagic Apr 28, 2024
8502b6a
./format
robertgshaw2-neuralmagic Apr 28, 2024
1bef541
validated that prefix caching working on turing with recent update
robertgshaw2-neuralmagic Apr 28, 2024
6620e53
Merge branch 'main' into prefix-caching-guards
zhuohan123 May 16, 2024
8efc774
Merge branch 'main' into prefix-caching-guards-new
robertgshaw2-neuralmagic May 23, 2024
033c2c5
updated to remove sliding window usage in models
robertgshaw2-neuralmagic May 23, 2024
0638960
removed spurious changes
robertgshaw2-neuralmagic May 24, 2024
63c0097
revert change to make PR easier to read
robertgshaw2-neuralmagic May 24, 2024
4a3630c
more cleanup
robertgshaw2-neuralmagic May 24, 2024
3f73426
more cleanup
robertgshaw2-neuralmagic May 24, 2024
6f754c3
stash
robertgshaw2-neuralmagic May 24, 2024
a497b7b
cleanup prints and comments to match current for easier review
robertgshaw2-neuralmagic May 24, 2024
9fd64fe
more cleanup for PR readibility
robertgshaw2-neuralmagic May 24, 2024
1126c5a
more cleanup for PR readibility
robertgshaw2-neuralmagic May 24, 2024
8a53180
more cleanup for PR readibility
robertgshaw2-neuralmagic May 24, 2024
034bbde
removed from mixtral, need to fix qwen
robertgshaw2-neuralmagic May 24, 2024
b56352b
updated models to remove sliding window. Big update to qwen to preven…
robertgshaw2-neuralmagic May 24, 2024
8225c3f
format
robertgshaw2-neuralmagic May 24, 2024
13797c1
added test and fixed requirements dev
robertgshaw2-neuralmagic May 24, 2024
37efe98
added test
robertgshaw2-neuralmagic May 24, 2024
7b186c2
format
robertgshaw2-neuralmagic May 27, 2024
93bce37
updated comment
robertgshaw2-neuralmagic May 27, 2024
7c8a9d0
updated test
robertgshaw2-neuralmagic May 27, 2024
4285763
format
robertgshaw2-neuralmagic May 27, 2024
84253fd
fixed logging
robertgshaw2-neuralmagic May 27, 2024
7a61f51
Update test_disable_sliding_window.py
robertgshaw2-neuralmagic May 27, 2024
44 changes: 44 additions & 0 deletions tests/prefix_caching/test_disable_sliding_window.py
@@ -0,0 +1,44 @@
"""Compare the with and without prefix caching.

Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest

from tests.conftest import cleanup
from vllm import LLM

MODEL_LEN_LEN = [
# Example models with sliding window.
("bigcode/starcoder2-3b", 4096, 16384),
# ("mistralai/Mistral-7B-v0.1", 4096, 32768), << OOM in CI

# Confirm models without an active sliding window are unaffected.
# config has "use_sliding_window": false
("Qwen/Qwen1.5-0.5B-Chat", 32768, 32768),
# config has no sliding window attribute.
("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 2048, 2048),
]


@pytest.mark.parametrize("model_len_len", MODEL_LEN_LEN)
def test_disable_sliding_window(model_len_len, ):
model, sliding_len, full_len = model_len_len
vllm_disabled_model = LLM(model, disable_sliding_window=True)
vllm_disabled_model.generate("Hi my name is")
model_config = vllm_disabled_model.llm_engine.model_config
assert model_config.max_model_len == sliding_len, (
f"Max len expected to equal sliding_len of {sliding_len}, "
f"but got {model_config.max_model_len}")

del vllm_disabled_model
cleanup()

vllm_enabled_model = LLM(model, disable_sliding_window=False)
vllm_enabled_model.generate("Hi my name is")
model_config = vllm_enabled_model.llm_engine.model_config
assert model_config.max_model_len == full_len, (
f"Max len expected to equal full_len of {full_len}, "
f"but got {model_config.max_model_len}")

del vllm_enabled_model
cleanup()
24 changes: 24 additions & 0 deletions tests/test_config.py
@@ -1,5 +1,29 @@
import pytest

from vllm.config import ModelConfig

MODEL_IDS_EXPECTED = [
("Qwen/Qwen1.5-7B", 32768),
("mistralai/Mistral-7B-v0.1", 4096),
("mistralai/Mistral-7B-Instruct-v0.2", 32768),
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
model_id, expected = model_id_expected
model_config = ModelConfig(
model_id,
model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
disable_sliding_window=True,
)
assert model_config.max_model_len == expected


def test_get_sliding_window():
TEST_SLIDING_WINDOW = 4096
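Not part of the diff: a minimal sketch contrasting the two sliding-window accessors that the vllm/config.py changes below introduce, reusing the exact ModelConfig constructor call from the test above. It assumes network access to fetch the Mistral config and that its sliding_window field is still 4096.

```python
from vllm.config import ModelConfig

model_id = "mistralai/Mistral-7B-v0.1"
model_config = ModelConfig(
    model_id,
    model_id,
    tokenizer_mode="auto",
    trust_remote_code=False,
    seed=0,
    dtype="float16",
    revision=None,
    disable_sliding_window=True,
)

# The raw HF-config value is still reported ...
assert model_config.get_hf_config_sliding_window() == 4096
# ... but the effective sliding window is disabled, and the model length is
# capped to the sliding-window size.
assert model_config.get_sliding_window() is None
assert model_config.max_model_len == 4096
```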
3 changes: 2 additions & 1 deletion vllm/attention/layer.py
@@ -30,17 +30,18 @@ def __init__(
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
sliding_window = cache_config.sliding_window
else:
kv_cache_dtype = "auto"
block_size = 16
sliding_window = None
if num_kv_heads is None:
num_kv_heads = num_heads

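The change above removes the per-layer sliding_window keyword and lets Attention read the value from cache_config, which is why the llama.py and mixtral.py diffs further down simply delete the argument. A schematic sketch of the pattern follows; ToyCacheConfig and ToyAttention are illustrative stand-ins, not vLLM's real CacheConfig and Attention.

```python
# Schematic sketch of the refactor: cache-related settings, including
# sliding_window, live on one config object passed to every attention layer
# instead of a per-layer keyword argument.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyCacheConfig:
    block_size: int = 16
    cache_dtype: str = "auto"
    sliding_window: Optional[int] = None


class ToyAttention:

    def __init__(self,
                 num_heads: int,
                 cache_config: Optional[ToyCacheConfig] = None) -> None:
        # Mirrors the logic in vllm/attention/layer.py after this PR: pull
        # kv-cache settings from cache_config, with the same fallbacks.
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            sliding_window = cache_config.sliding_window
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            sliding_window = None
        self.num_heads = num_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.block_size = block_size
        self.sliding_window = sliding_window


attn = ToyAttention(num_heads=32,
                    cache_config=ToyCacheConfig(sliding_window=4096))
assert attn.sliding_window == 4096
```

Centralizing the value means one engine-level setting controls every layer, rather than each model file threading its own copy through its constructors.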
67 changes: 64 additions & 3 deletions vllm/config.py
@@ -69,6 +69,10 @@ class ModelConfig:
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode
disable_sliding_window: Whether to disable sliding window. If True,
we will disable the sliding window functionality of the model.
If the model does not support sliding window, this argument is
ignored.
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
served_model_name: The model name used in metrics tag `model_name`,
@@ -96,6 +100,7 @@ def __init__(
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
) -> None:
@@ -118,14 +123,18 @@ def __init__(
self.max_seq_len_to_capture = (max_seq_len_to_capture
or max_context_len_to_capture)
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init

self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
max_model_len)
self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window())
self.served_model_name = get_served_model_name(model,
served_model_name)
if not self.skip_tokenizer_init:
@@ -220,7 +229,7 @@ def verify_with_parallel_config(
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")

def get_sliding_window(self) -> Optional[int]:
def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""

@@ -232,6 +241,15 @@ def get_sliding_window(self) -> Optional[int]:
return None
return getattr(self.hf_text_config, "sliding_window", None)

def get_sliding_window(self) -> Optional[int]:
Review thread on this method:

Collaborator: How does it work for a model that already has a sliding window, like Mistral?

Collaborator: I'm not sure what you mean? If the user does not specify --disable-sliding-window, then we use the sliding window if the model supports it.

Collaborator: Oh, maybe it is a dumb question, but my question is: for models that have a sliding window by default (https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/26bca36bde8333b5d7f72e9ed20ccda6a618af24/config.json#L18), does it work properly if we use --disable-sliding-window?

Collaborator: Yes, handling a case like Mistral is specifically what this does. --disable-sliding-window means we turn off the sliding window and set max_model_len=sliding_window. So in the case of Mistral, we would then treat the model as a 4096-context-length model with no sliding window. The reason for this feature is that if we want to use features that are incompatible with sliding window (e.g. APC or chunked prefill), there is a pathway to disable it.

Collaborator: I see, that makes sense! Thanks for the explanation.
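Not part of the diff: a minimal sketch of the Mistral case described in the reply above, using the same LLM API as the new test file in this PR. It assumes a GPU with enough memory for Mistral-7B (the reason the equivalent entry is commented out of the CI test).

```python
from vllm import LLM

# Default: Mistral keeps its 4096-token sliding window and the full 32768
# context length from its HF config ("full_len" in the test above).
llm = LLM("mistralai/Mistral-7B-v0.1")
assert llm.llm_engine.model_config.max_model_len == 32768
del llm  # the CI test calls tests.conftest.cleanup() here to free GPU memory

# With the new flag, the sliding window is turned off and the model is treated
# as a plain 4096-context model ("sliding_len" in the test above).
llm = LLM("mistralai/Mistral-7B-v0.1", disable_sliding_window=True)
assert llm.llm_engine.model_config.max_model_len == 4096
```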

"""Get the sliding window size, or None if disabled.
"""
# If user disables sliding window, return None.
if self.disable_sliding_window:
return None
# Otherwise get the value from the hf config.
return self.get_hf_config_sliding_window()

def get_vocab_size(self) -> int:
return self.hf_text_config.vocab_size

Expand Down Expand Up @@ -336,6 +354,7 @@ def __init__(
self.enable_prefix_caching = enable_prefix_caching
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()

# Will be set after profiling.
self.num_gpu_blocks = None
@@ -364,6 +383,19 @@ def _verify_cache_dtype(self) -> None:
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

def _verify_prefix_caching(self) -> None:
if not self.enable_prefix_caching:
return

if self.sliding_window is not None:
raise NotImplementedError(
"Prefix caching is not supported with sliding window. "
"Run with --disable-sliding-window to use prefix caching.")
if self.cache_dtype == "fp8":
raise NotImplementedError(
"Prefix caching is not supported for fp8 cache_dtype. "
"Run with --kv-cache-dtype auto to use prefix caching.")

def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
@@ -1116,6 +1148,8 @@ def _get_and_verify_dtype(
def _get_and_verify_max_len(
hf_config: PretrainedConfig,
max_model_len: Optional[int],
disable_sliding_window: bool,
sliding_window_len: Optional[int],
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
@@ -1135,13 +1169,24 @@
"max_seq_length",
"seq_len",
]
# Choose the smallest "max_length" from the possible keys.
max_len_key = None
for key in possible_keys:
max_len = getattr(hf_config, key, None)
if max_len is not None:
max_len_key = key if max_len < derived_max_model_len \
else max_len_key
derived_max_model_len = min(derived_max_model_len, max_len)

# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
if disable_sliding_window and sliding_window_len is not None:
max_len_key = "sliding_window" \
if sliding_window_len < derived_max_model_len else max_len_key
derived_max_model_len = min(derived_max_model_len, sliding_window_len)

# If none of the keys were found in the config, use a default and
# log a warning.
if derived_max_model_len == float("inf"):
if max_model_len is not None:
# If max_model_len is specified, we use it.
@@ -1157,13 +1202,22 @@

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None and rope_scaling["type"] != "su":
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"with rope_scaling. Please raise an issue so we can "
"investigate.")
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor

# If the user specified a max length, make sure it is smaller than the
# derived length from the HF model config.
if max_model_len is None:
max_model_len = int(derived_max_model_len)
elif max_model_len > derived_max_model_len:
@@ -1172,6 +1226,13 @@
# with model_max_length and allow this override when it's smaller.
model_max_length = getattr(hf_config, "model_max_length", None)
if model_max_length is not None and max_model_len <= model_max_length:
if disable_sliding_window:
# TODO(robertgshaw): Find a model that has model_max_length
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"model_max_length in the config. Please raise an issue "
"so we can investigate.")
pass
else:
raise ValueError(
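Not part of the diff: a hedged sketch of how the new CacheConfig._verify_prefix_caching guard above surfaces through the LLM API. It assumes enable_prefix_caching is forwarded from LLM keyword arguments to EngineArgs the same way disable_sliding_window is in the new test, and it needs a GPU environment for the second call, which actually loads the model.

```python
from vllm import LLM

try:
    # Mistral has a 4096-token sliding window, so automatic prefix caching is
    # rejected up front with the new, actionable error.
    LLM("mistralai/Mistral-7B-v0.1", enable_prefix_caching=True)
except NotImplementedError as err:
    print(err)  # "Prefix caching is not supported with sliding window. ..."

# Disabling the sliding window (capping max_model_len to 4096) makes the
# combination legal.
llm = LLM("mistralai/Mistral-7B-v0.1",
          enable_prefix_caching=True,
          disable_sliding_window=True)
```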
12 changes: 9 additions & 3 deletions vllm/engine/arg_utils.py
@@ -41,6 +41,7 @@ class EngineArgs:
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
enable_prefix_caching: bool = False
disable_sliding_window: bool = False
use_v2_block_manager: bool = False
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
@@ -267,6 +268,10 @@ def add_cli_args(
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='Enables automatic prefix caching.')
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, '
'capping to sliding window size')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceManagerV2.')
@@ -558,8 +563,8 @@ def create_engine_config(self, ) -> EngineConfig:
self.max_model_len, self.quantization,
self.quantization_param_path, self.enforce_eager,
self.max_context_len_to_capture, self.max_seq_len_to_capture,
self.max_logprobs, self.skip_tokenizer_init,
self.served_model_name)
self.max_logprobs, self.disable_sliding_window,
self.skip_tokenizer_init, self.served_model_name)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
@@ -645,7 +650,8 @@ def create_engine_config(self, ) -> EngineConfig:
if (model_config.get_sliding_window() is not None
and scheduler_config.chunked_prefill_enabled):
raise ValueError(
"Chunked prefill is not supported with sliding window.")
"Chunked prefill is not supported with sliding window. "
"Set --disable-sliding-window to disable sliding window.")

return EngineConfig(model_config=model_config,
cache_config=cache_config,
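The guards can also be exercised without loading any weights by building an engine config directly from EngineArgs. A sketch under stated assumptions: enable_chunked_prefill is an existing EngineArgs field not shown in this diff, and create_engine_config() needs only the HF config plus a CUDA-visible environment for device detection.

```python
from vllm.engine.arg_utils import EngineArgs

# Chunked prefill with a sliding-window model now fails with a hint that
# points at the new flag.
try:
    EngineArgs(model="mistralai/Mistral-7B-v0.1",
               enable_chunked_prefill=True).create_engine_config()
except ValueError as err:
    print(err)  # "Chunked prefill is not supported with sliding window. ..."

# Equivalent to passing --disable-sliding-window on the command line: the
# config builds and the model length is capped to the sliding window.
engine_config = EngineArgs(model="mistralai/Mistral-7B-v0.1",
                           enable_chunked_prefill=True,
                           disable_sliding_window=True).create_engine_config()
assert engine_config.model_config.max_model_len == 4096
```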
4 changes: 0 additions & 4 deletions vllm/model_executor/models/llama.py
@@ -93,7 +93,6 @@ def __init__(
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
@@ -145,7 +144,6 @@ def __init__(
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window,
cache_config=cache_config,
quant_config=quant_config)

@@ -182,7 +180,6 @@ def __init__(
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
sliding_window = getattr(config, "sliding_window", None)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
@@ -197,7 +194,6 @@ def __init__(
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
sliding_window=sliding_window,
cache_config=cache_config,
)
self.mlp = LlamaMLP(
22 changes: 10 additions & 12 deletions vllm/model_executor/models/mixtral.py
@@ -246,15 +246,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

class MixtralAttention(nn.Module):

def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -276,7 +277,6 @@ def __init__(self,
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window

if isinstance(
quant_config,
@@ -312,7 +312,6 @@ def __init__(self,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)

@@ -349,7 +348,6 @@ def __init__(
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(