diff --git a/tests/prefix_caching/test_disable_sliding_window.py b/tests/prefix_caching/test_disable_sliding_window.py
new file mode 100644
index 0000000000000..eeac6ab43c05f
--- /dev/null
+++ b/tests/prefix_caching/test_disable_sliding_window.py
@@ -0,0 +1,44 @@
+"""Check that disable_sliding_window caps max_model_len to the sliding window.
+
+Run `pytest tests/prefix_caching/test_disable_sliding_window.py`.
+"""
+import pytest
+
+from tests.conftest import cleanup
+from vllm import LLM
+
+MODEL_LEN_LEN = [
+    # Example models with sliding window.
+    ("bigcode/starcoder2-3b", 4096, 16384),
+    # ("mistralai/Mistral-7B-v0.1", 4096, 32768), << OOM in CI
+
+    # Confirm models without sliding window still work.
+    # config has "use_sliding_window": false
+    ("Qwen/Qwen1.5-0.5B-Chat", 32768, 32768),
+    # config has no sliding window attribute.
+    ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 2048, 2048),
+]
+
+
+@pytest.mark.parametrize("model_len_len", MODEL_LEN_LEN)
+def test_disable_sliding_window(model_len_len, ):
+    model, sliding_len, full_len = model_len_len
+    vllm_disabled_model = LLM(model, disable_sliding_window=True)
+    vllm_disabled_model.generate("Hi my name is")
+    model_config = vllm_disabled_model.llm_engine.model_config
+    assert model_config.max_model_len == sliding_len, (
+        "Max len expected to equal sliding_len of %s, but got %s", sliding_len,
+        model_config.max_model_len)
+
+    del vllm_disabled_model
+    cleanup()
+
+    vllm_enabled_model = LLM(model, disable_sliding_window=False)
+    vllm_enabled_model.generate("Hi my name is")
+    model_config = vllm_enabled_model.llm_engine.model_config
+    assert model_config.max_model_len == full_len, (
+        "Max len expected to equal full_len of %s, but got %s", full_len,
+        model_config.max_model_len)
+
+    del vllm_enabled_model
+    cleanup()
diff --git a/tests/test_config.py b/tests/test_config.py
index 6bc51a53dc07c..7cbdaeca9c4d4 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,5 +1,29 @@
+import pytest
+
 from vllm.config import ModelConfig
 
+MODEL_IDS_EXPECTED = [
+    ("Qwen/Qwen1.5-7B", 32768),
+    ("mistralai/Mistral-7B-v0.1", 4096),
+    ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
+]
+
+
+@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
+def test_disable_sliding_window(model_id_expected):
+    model_id, expected = model_id_expected
+    model_config = ModelConfig(
+        model_id,
+        model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+        disable_sliding_window=True,
+    )
+    assert model_config.max_model_len == expected
+
 
 def test_get_sliding_window():
     TEST_SLIDING_WINDOW = 4096
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index dc7b3940bc9b7..addee160694b0 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -30,7 +30,6 @@ def __init__(
         scale: float,
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -38,9 +37,11 @@ def __init__(
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
+            sliding_window = cache_config.sliding_window
         else:
             kv_cache_dtype = "auto"
             block_size = 16
+            sliding_window = None
         if num_kv_heads is None:
             num_kv_heads = num_heads
diff --git a/vllm/config.py b/vllm/config.py
index b245a1a3ee6d3..4b256d00a32df 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -69,6 +69,10 @@ class ModelConfig:
         max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode
+        disable_sliding_window: Whether to disable sliding window. If True,
+            we will disable the sliding window functionality of the model.
+            If the model does not support sliding window, this argument is
+            ignored.
         skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer.
         served_model_name: The model name used in metrics tag `model_name`,
@@ -96,6 +100,7 @@ def __init__(
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: Optional[int] = None,
         max_logprobs: int = 5,
+        disable_sliding_window: bool = False,
         skip_tokenizer_init: bool = False,
         served_model_name: Optional[Union[str, List[str]]] = None,
     ) -> None:
@@ -118,14 +123,18 @@ def __init__(
         self.max_seq_len_to_capture = (max_seq_len_to_capture
                                        or max_context_len_to_capture)
         self.max_logprobs = max_logprobs
+        self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
                                     code_revision, rope_scaling)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
-        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
-                                                     max_model_len)
+        self.max_model_len = _get_and_verify_max_len(
+            hf_config=self.hf_text_config,
+            max_model_len=max_model_len,
+            disable_sliding_window=self.disable_sliding_window,
+            sliding_window_len=self.get_hf_config_sliding_window())
         self.served_model_name = get_served_model_name(model,
                                                        served_model_name)
         if not self.skip_tokenizer_init:
@@ -220,7 +229,7 @@ def verify_with_parallel_config(
                 "must be divisible by pipeline parallel size "
                 f"({pipeline_parallel_size}).")
 
-    def get_sliding_window(self) -> Optional[int]:
+    def get_hf_config_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled.
         """
@@ -232,6 +241,15 @@ def get_sliding_window(self) -> Optional[int]:
             return None
         return getattr(self.hf_text_config, "sliding_window", None)
 
+    def get_sliding_window(self) -> Optional[int]:
+        """Get the sliding window size, or None if disabled.
+        """
+        # If user disables sliding window, return None.
+        if self.disable_sliding_window:
+            return None
+        # Otherwise get the value from the hf config.
+        return self.get_hf_config_sliding_window()
+
     def get_vocab_size(self) -> int:
         return self.hf_text_config.vocab_size
 
@@ -336,6 +354,7 @@ def __init__(
         self.enable_prefix_caching = enable_prefix_caching
         self._verify_args()
         self._verify_cache_dtype()
+        self._verify_prefix_caching()
 
         # Will be set after profiling.
         self.num_gpu_blocks = None
@@ -364,6 +383,19 @@ def _verify_cache_dtype(self) -> None:
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
 
+    def _verify_prefix_caching(self) -> None:
+        if not self.enable_prefix_caching:
+            return
+
+        if self.sliding_window is not None:
+            raise NotImplementedError(
+                "Prefix caching is not supported with sliding window. "
+                "Run with --disable-sliding-window to use prefix caching.")
+        if self.cache_dtype == "fp8":
+            raise NotImplementedError(
+                "Prefix caching is not supported for fp8 cache_dtype. "
+                "Run with --kv-cache-dtype auto to use prefix caching.")
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
@@ -1116,6 +1148,8 @@ def _get_and_verify_dtype(
 def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
     max_model_len: Optional[int],
+    disable_sliding_window: bool,
+    sliding_window_len: Optional[int],
 ) -> int:
     """Get and verify the model's maximum length."""
     derived_max_model_len = float("inf")
@@ -1135,6 +1169,7 @@ def _get_and_verify_max_len(
         "max_seq_length",
         "seq_len",
     ]
+    # Choose the smallest "max_length" from the possible keys.
     max_len_key = None
     for key in possible_keys:
         max_len = getattr(hf_config, key, None)
@@ -1142,6 +1177,16 @@ def _get_and_verify_max_len(
             max_len_key = key if max_len < derived_max_model_len \
                 else max_len_key
             derived_max_model_len = min(derived_max_model_len, max_len)
+
+    # If sliding window is manually disabled, cap the max length to the
+    # sliding window length in the model config.
+    if disable_sliding_window and sliding_window_len is not None:
+        max_len_key = "sliding_window" \
+            if sliding_window_len < derived_max_model_len else max_len_key
+        derived_max_model_len = min(derived_max_model_len, sliding_window_len)
+
+    # If none of the keys were found in the config, use a default and
+    # log a warning.
     if derived_max_model_len == float("inf"):
         if max_model_len is not None:
             # If max_model_len is specified, we use it.
@@ -1157,6 +1202,13 @@ def _get_and_verify_max_len(
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None and rope_scaling["type"] != "su":
+        if disable_sliding_window:
+            # TODO(robertgshaw): Find a model that supports rope_scaling
+            # with sliding window to see if this case should be allowed.
+            raise NotImplementedError(
+                "Disabling sliding window is not supported for models "
+                "with rope_scaling. Please raise an issue so we can "
+                "investigate.")
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
         if rope_scaling["type"] == "yarn":
@@ -1164,6 +1216,8 @@ def _get_and_verify_max_len(
                 "original_max_position_embeddings"]
         derived_max_model_len *= scaling_factor
 
+    # If the user specified a max length, make sure it is smaller than the
+    # derived length from the HF model config.
     if max_model_len is None:
         max_model_len = int(derived_max_model_len)
     elif max_model_len > derived_max_model_len:
@@ -1172,6 +1226,13 @@ def _get_and_verify_max_len(
         # with model_max_length and allow this override when it's smaller.
         model_max_length = getattr(hf_config, "model_max_length", None)
         if model_max_length is not None and max_model_len <= model_max_length:
+            if disable_sliding_window:
+                # TODO(robertgshaw): Find a model that has model_max_length
+                # with sliding window to see if this case should be allowed.
+                raise NotImplementedError(
+                    "Disabling sliding window is not supported for models "
+                    "with model_max_length in the config. Please raise an "
+                    "issue so we can investigate.")
             pass
         else:
             raise ValueError(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 538e3427e37fb..3267c8c9f44d2 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -41,6 +41,7 @@ class EngineArgs:
     max_parallel_loading_workers: Optional[int] = None
     block_size: int = 16
     enable_prefix_caching: bool = False
+    disable_sliding_window: bool = False
     use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
@@ -267,6 +268,10 @@ def add_cli_args(
         parser.add_argument('--enable-prefix-caching',
                             action='store_true',
                             help='Enables automatic prefix caching.')
+        parser.add_argument('--disable-sliding-window',
+                            action='store_true',
+                            help='Disables sliding window, '
+                            'capping max_model_len to sliding window size')
         parser.add_argument('--use-v2-block-manager',
                             action='store_true',
                             help='Use BlockSpaceMangerV2.')
@@ -558,8 +563,8 @@ def create_engine_config(self, ) -> EngineConfig:
             self.max_model_len, self.quantization,
             self.quantization_param_path, self.enforce_eager,
             self.max_context_len_to_capture, self.max_seq_len_to_capture,
-            self.max_logprobs, self.skip_tokenizer_init,
-            self.served_model_name)
+            self.max_logprobs, self.disable_sliding_window,
+            self.skip_tokenizer_init, self.served_model_name)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
@@ -645,7 +650,8 @@ def create_engine_config(self, ) -> EngineConfig:
         if (model_config.get_sliding_window() is not None
                 and scheduler_config.chunked_prefill_enabled):
             raise ValueError(
-                "Chunked prefill is not supported with sliding window.")
+                "Chunked prefill is not supported with sliding window. "
+                "Set --disable-sliding-window to disable sliding window.")
 
         return EngineConfig(model_config=model_config,
                             cache_config=cache_config,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index f43a40a0bfd34..c35f8cee4db23 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -93,7 +93,6 @@ def __init__(
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
-        sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
@@ -145,7 +144,6 @@ def __init__(
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
 
@@ -182,7 +180,6 @@ def __init__(
                 config.original_max_position_embeddings)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        sliding_window = getattr(config, "sliding_window", None)
         # Support abacusai/Smaug-72B-v0.1 with attention_bias
         # Support internlm/internlm-7b with bias
         attention_bias = getattr(config, "attention_bias", False) or getattr(
@@ -197,7 +194,6 @@ def __init__(
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             bias=attention_bias,
-            sliding_window=sliding_window,
             cache_config=cache_config,
         )
         self.mlp = LlamaMLP(
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index ea95cf7380d54..d6dd7fa1fe9e2 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -246,15 +246,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 class MixtralAttention(nn.Module):
 
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 max_position: int = 4096 * 32,
-                 rope_theta: float = 10000,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 sliding_window: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -276,7 +277,6 @@ def __init__(self,
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window
 
         if isinstance(
                 quant_config,
@@ -312,7 +312,6 @@ def __init__(self,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
 
@@ -349,7 +348,6 @@ def __init__(
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
         self.block_sparse_moe = MixtralMoE(
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index 9b99ff729aadd..1894c05e167d6 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -166,7 +166,6 @@ def __init__(
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
-        sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
@@ -190,7 +189,6 @@ def __init__(
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -217,7 +215,6 @@ def __init__(
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
 
@@ -254,7 +251,6 @@ def __init__(
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
         self.block_sparse_moe = MixtralMoE(config=config,
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index ec203c3b9001a..9a4829a27873e 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -86,10 +86,8 @@ def __init__(self,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
-                 use_sliding_window: bool = False,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
-                 sliding_window: Optional[int] = None,
                 rope_scaling: Optional[Tuple] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -112,7 +110,6 @@ def __init__(self,
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window if use_sliding_window else None
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -140,7 +137,6 @@ def __init__(self,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
 
@@ -164,7 +160,6 @@ class Qwen2DecoderLayer(nn.Module):
     def __init__(
         self,
         config: Qwen2Config,
-        layer_idx: int,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -173,18 +168,14 @@ def __init__(
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        use_sliding_window = (config.use_sliding_window
-                              and layer_idx < config.max_window_layers)
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            use_sliding_window=use_sliding_window,
             cache_config=cache_config,
             quant_config=quant_config,
-            sliding_window=config.sliding_window,
             rope_scaling=rope_scaling)
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
@@ -244,8 +235,8 @@ def __init__(
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config)
-            for layer_idx in range(config.num_hidden_layers)
+            Qwen2DecoderLayer(config, cache_config, quant_config)
+            for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -302,6 +293,18 @@ def __init__(
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         del lora_config
+        # TODO (@robertgshaw2): see if this can be moved out
+        if (cache_config.sliding_window is not None
+                and hasattr(config, "max_window_layers")):
+            raise ValueError("Sliding window for some but not all layers is "
+                             "not supported. This model uses sliding window "
+                             "but `max_window_layers` = %s is less than "
+                             "`num_hidden_layers` = %s. Please open an issue "
+                             "to discuss this feature." % (
+                                 config.max_window_layers,
+                                 config.num_hidden_layers,
+                             ))
+
         super().__init__()
         self.config = config
         self.quant_config = quant_config
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 91ffd0861c39d..4324bf50d4ad1 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -74,7 +74,6 @@ def __init__(self,
         self.rope_theta = config.rope_theta
         self.max_position_embeddings = config.max_position_embeddings
         self.use_bias = config.use_bias
-        self.sliding_window = config.sliding_window
 
         self.qkv_proj = QKVParallelLinear(
             self.hidden_size,
@@ -101,7 +100,6 @@ def __init__(self,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
 
diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py
index dda13d83f89a3..1e5280dde3ff9 100644
--- a/vllm/model_executor/models/xverse.py
+++ b/vllm/model_executor/models/xverse.py
@@ -88,7 +88,6 @@ def __init__(
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
-        sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
@@ -134,7 +133,6 @@ def __init__(
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
 
@@ -167,7 +165,6 @@ def __init__(
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        sliding_window = getattr(config, "sliding_window", None)
         self.self_attn = XverseAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -178,7 +175,6 @@ def __init__(
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             bias=getattr(config, "bias", False),
-            sliding_window=sliding_window,
             cache_config=cache_config,
         )
         self.mlp = XverseMLP(
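
Reviewer note (not part of the patch): a minimal usage sketch of the new flag, mirroring tests/prefix_caching/test_disable_sliding_window.py. The model name is only an example, and the expected value of 4096 assumes Mistral-7B-v0.1's `sliding_window: 4096` entry in its HF config.

# Illustrative only -- exercises the disable_sliding_window flag added above.
from vllm import LLM, SamplingParams

# Mistral-7B-v0.1 sets sliding_window=4096 in its HF config, so disabling the
# sliding window caps max_model_len to 4096 and makes prefix caching legal.
llm = LLM(
    model="mistralai/Mistral-7B-v0.1",  # example model with a sliding window
    disable_sliding_window=True,
    enable_prefix_caching=True,
)
assert llm.llm_engine.model_config.max_model_len == 4096

outputs = llm.generate(["Hi my name is"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)

The same behavior is available on the server side via the new --disable-sliding-window flag added in vllm/engine/arg_utils.py above.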