Skip to content

Commit

Permalink
[Bugfix / Core] Prefix Caching Guards (merged with main) (vllm-projec…
Browse files Browse the repository at this point in the history
…t#4846)

Co-authored-by: rsnm2 <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
  • Loading branch information
3 people authored and dtrifiro committed May 31, 2024
1 parent 3409f65 commit 4ee6996
Show file tree
Hide file tree
Showing 11 changed files with 167 additions and 44 deletions.
44 changes: 44 additions & 0 deletions tests/prefix_caching/test_disable_sliding_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Compare the with and without prefix caching.
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest

from tests.conftest import cleanup
from vllm import LLM

MODEL_LEN_LEN = [
# Example models with sliding window.
("bigcode/starcoder2-3b", 4096, 16384),
# ("mistralai/Mistral-7B-v0.1", 4096, 32768), << OOM in CI

# Confirm model with sliding window works.
# config has "use_sliding_window": false
("Qwen/Qwen1.5-0.5B-Chat", 32768, 32768),
# config has no sliding window attribute.
("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 2048, 2048),
]


@pytest.mark.parametrize("model_len_len", MODEL_LEN_LEN)
def test_disable_sliding_window(model_len_len, ):
model, sliding_len, full_len = model_len_len
vllm_disabled_model = LLM(model, disable_sliding_window=True)
vllm_disabled_model.generate("Hi my name is")
model_config = vllm_disabled_model.llm_engine.model_config
assert model_config.max_model_len == sliding_len, (
"Max len expected to equal sliding_len of %s, but got %s", sliding_len,
model_config.max_model_len)

del vllm_disabled_model
cleanup()

vllm_enabled_model = LLM(model, disable_sliding_window=False)
vllm_enabled_model.generate("Hi my name is")
model_config = vllm_enabled_model.llm_engine.model_config
assert model_config.max_model_len == full_len, (
"Max len expected to equal full_len of %s, but got %s", full_len,
model_config.max_model_len)

del vllm_enabled_model
cleanup()
24 changes: 24 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,29 @@
import pytest

from vllm.config import ModelConfig

MODEL_IDS_EXPECTED = [
("Qwen/Qwen1.5-7B", 32768),
("mistralai/Mistral-7B-v0.1", 4096),
("mistralai/Mistral-7B-Instruct-v0.2", 32768),
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
model_id, expected = model_id_expected
model_config = ModelConfig(
model_id,
model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
disable_sliding_window=True,
)
assert model_config.max_model_len == expected


def test_get_sliding_window():
TEST_SLIDING_WINDOW = 4096
Expand Down
3 changes: 2 additions & 1 deletion vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
Expand All @@ -39,9 +38,11 @@ def __init__(
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
sliding_window = cache_config.sliding_window
else:
kv_cache_dtype = "auto"
block_size = 16
sliding_window = None
if num_kv_heads is None:
num_kv_heads = num_heads

Expand Down
67 changes: 64 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ class ModelConfig:
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode
disable_sliding_window: Whether to disable sliding window. If True,
we will disable the sliding window functionality of the model.
If the model does not support sliding window, this argument is
ignored.
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
served_model_name: The model name used in metrics tag `model_name`,
Expand Down Expand Up @@ -96,6 +100,7 @@ def __init__(
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
) -> None:
Expand All @@ -118,14 +123,18 @@ def __init__(
self.max_seq_len_to_capture = (max_seq_len_to_capture
or max_context_len_to_capture)
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init

self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
max_model_len)
self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window())
self.served_model_name = get_served_model_name(model,
served_model_name)
if not self.skip_tokenizer_init:
Expand Down Expand Up @@ -220,7 +229,7 @@ def verify_with_parallel_config(
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")

def get_sliding_window(self) -> Optional[int]:
def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""

Expand All @@ -232,6 +241,15 @@ def get_sliding_window(self) -> Optional[int]:
return None
return getattr(self.hf_text_config, "sliding_window", None)

def get_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""
# If user disables sliding window, return None.
if self.disable_sliding_window:
return None
# Otherwise get the value from the hf config.
return self.get_hf_config_sliding_window()

def get_vocab_size(self) -> int:
return self.hf_text_config.vocab_size

Expand Down Expand Up @@ -336,6 +354,7 @@ def __init__(
self.enable_prefix_caching = enable_prefix_caching
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()

# Will be set after profiling.
self.num_gpu_blocks = None
Expand Down Expand Up @@ -364,6 +383,19 @@ def _verify_cache_dtype(self) -> None:
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

def _verify_prefix_caching(self) -> None:
if not self.enable_prefix_caching:
return

if self.sliding_window is not None:
raise NotImplementedError(
"Prefix caching is not supported with sliding window. "
"Run with --disable-sliding-window to use prefix caching.")
if self.cache_dtype == "fp8":
raise NotImplementedError(
"Prefix caching is not supported for fp8 cache_dtype. "
"Run with --kv-cache-dtype auto to use prefix caching.")

def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
Expand Down Expand Up @@ -1116,6 +1148,8 @@ def _get_and_verify_dtype(
def _get_and_verify_max_len(
hf_config: PretrainedConfig,
max_model_len: Optional[int],
disable_sliding_window: bool,
sliding_window_len: Optional[int],
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
Expand All @@ -1135,13 +1169,24 @@ def _get_and_verify_max_len(
"max_seq_length",
"seq_len",
]
# Choose the smallest "max_length" from the possible keys.
max_len_key = None
for key in possible_keys:
max_len = getattr(hf_config, key, None)
if max_len is not None:
max_len_key = key if max_len < derived_max_model_len \
else max_len_key
derived_max_model_len = min(derived_max_model_len, max_len)

# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
if disable_sliding_window and sliding_window_len is not None:
max_len_key = "sliding_window" \
if sliding_window_len < derived_max_model_len else max_len_key
derived_max_model_len = min(derived_max_model_len, sliding_window_len)

# If none of the keys were found in the config, use a default and
# log a warning.
if derived_max_model_len == float("inf"):
if max_model_len is not None:
# If max_model_len is specified, we use it.
Expand All @@ -1157,13 +1202,22 @@ def _get_and_verify_max_len(

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None and rope_scaling["type"] != "su":
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"with rope_scaling. Please raise an issue so we can "
"investigate.")
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor

# If the user specified a max length, make sure it is smaller than the
# derived length from the HF model config.
if max_model_len is None:
max_model_len = int(derived_max_model_len)
elif max_model_len > derived_max_model_len:
Expand All @@ -1172,6 +1226,13 @@ def _get_and_verify_max_len(
# with model_max_length and allow this override when it's smaller.
model_max_length = getattr(hf_config, "model_max_length", None)
if model_max_length is not None and max_model_len <= model_max_length:
if disable_sliding_window:
# TODO(robertgshaw): Find a model that has model_max_length
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"model_max_length in the config. Please raise an issue "
"so we can investigate.")
pass
else:
raise ValueError(
Expand Down
12 changes: 9 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class EngineArgs:
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
enable_prefix_caching: bool = False
disable_sliding_window: bool = False
use_v2_block_manager: bool = False
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
Expand Down Expand Up @@ -267,6 +268,10 @@ def add_cli_args(
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='Enables automatic prefix caching.')
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, '
'capping to sliding window size')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2.')
Expand Down Expand Up @@ -558,8 +563,8 @@ def create_engine_config(self, ) -> EngineConfig:
self.max_model_len, self.quantization,
self.quantization_param_path, self.enforce_eager,
self.max_context_len_to_capture, self.max_seq_len_to_capture,
self.max_logprobs, self.skip_tokenizer_init,
self.served_model_name)
self.max_logprobs, self.disable_sliding_window,
self.skip_tokenizer_init, self.served_model_name)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
Expand Down Expand Up @@ -645,7 +650,8 @@ def create_engine_config(self, ) -> EngineConfig:
if (model_config.get_sliding_window() is not None
and scheduler_config.chunked_prefill_enabled):
raise ValueError(
"Chunked prefill is not supported with sliding window.")
"Chunked prefill is not supported with sliding window. "
"Set --disable-sliding-window to disable sliding window.")

return EngineConfig(model_config=model_config,
cache_config=cache_config,
Expand Down
4 changes: 0 additions & 4 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def __init__(
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -146,7 +145,6 @@ def __init__(
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window,
cache_config=cache_config,
quant_config=quant_config)

Expand Down Expand Up @@ -183,7 +181,6 @@ def __init__(
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
sliding_window = getattr(config, "sliding_window", None)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
Expand All @@ -198,7 +195,6 @@ def __init__(
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
sliding_window=sliding_window,
cache_config=cache_config,
)
self.mlp = LlamaMLP(
Expand Down
22 changes: 10 additions & 12 deletions vllm/model_executor/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,15 +246,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

class MixtralAttention(nn.Module):

def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
Expand All @@ -276,7 +277,6 @@ def __init__(self,
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window

if isinstance(
quant_config,
Expand Down Expand Up @@ -312,7 +312,6 @@ def __init__(self,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)

Expand Down Expand Up @@ -349,7 +348,6 @@ def __init__(
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(
Expand Down
Loading

0 comments on commit 4ee6996

Please sign in to comment.