diff --git a/vllm/config.py b/vllm/config.py index 160ec63d4d501..562bd6a4f69b5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -11,7 +11,8 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron +from vllm.utils import (get_cpu_memory, is_cpu, is_hip, is_neuron, + print_warning_once) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -133,6 +134,17 @@ def __init__( code_revision, rope_scaling) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + + if (not self.disable_sliding_window + and self.hf_text_config.model_type == "gemma2" + and self.hf_text_config.sliding_window is not None): + print_warning_once( + "Gemma 2 uses sliding window attention for every odd layer, " + "which is currently not supported by vLLM. Disabling sliding " + "window and capping the max length to the sliding window size " + f"({self.hf_text_config.sliding_window}).") + self.disable_sliding_window = True + self.max_model_len = _get_and_verify_max_len( hf_config=self.hf_text_config, max_model_len=max_model_len, @@ -1224,20 +1236,32 @@ def _get_and_verify_max_len( derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) - if rope_scaling is not None and rope_scaling["type"] != "su": - if disable_sliding_window: - # TODO(robertgshaw): Find a model that supports rope_scaling - # with sliding window to see if this case should be allowed. - raise NotImplementedError( - "Disabling sliding window is not supported for models " - "with rope_scaling. Please raise an issue so we can " - "investigate.") - assert "factor" in rope_scaling - scaling_factor = rope_scaling["factor"] - if rope_scaling["type"] == "yarn": - derived_max_model_len = rope_scaling[ - "original_max_position_embeddings"] - derived_max_model_len *= scaling_factor + if rope_scaling is not None: + if "type" in rope_scaling: + rope_type = rope_scaling["type"] + elif "rope_type" in rope_scaling: + rope_type = rope_scaling["rope_type"] + else: + raise ValueError( + "rope_scaling must have a 'type' or 'rope_type' key.") + + # The correct one should be "longrope", kept "su" here + # to be backward compatible + if rope_type not in ("su", "longrope", "llama3"): + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. Please raise an issue so we can " + "investigate.") + + assert "factor" in rope_scaling + scaling_factor = rope_scaling["factor"] + if rope_type == "yarn": + derived_max_model_len = rope_scaling[ + "original_max_position_embeddings"] + derived_max_model_len *= scaling_factor # If the user specified a max length, make sure it is smaller than the # derived length from the HF model config. diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index d03903d206d33..c15be15a9f0ce 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -503,6 +503,159 @@ def forward( return query.flatten(-2), key.flatten(-2) +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + device="cuda", + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + print("Cache shape", cache.shape) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( + positions.device) + cos_sin = self.cos_sin_cache[torch.add(positions, offsets) + if offsets is not None else positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key + + +class GemmaRotaryEmbedding(RotaryEmbedding): + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107 + inv_freq = 1.0 / (base**( + torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() / + self.rotary_dim)) + return inv_freq + + +class ExtendedRotaryEmbedding(RotaryEmbedding): + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + return self.apply_scaling(inv_freqs) + + def apply_scaling(self, freqs: torch.Tensor): + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor) + new_freqs.append((1 - smooth) * freq / scale_factor + + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} @@ -534,10 +687,17 @@ def get_rope( rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, dtype) else: - scaling_type = rope_scaling["type"] - if scaling_type != "su": + scaling_type = rope_scaling[ + "type"] if "type" in rope_scaling else rope_scaling["rope_type"] + # The correct one should be "longrope" but keep "su" here + # for backward compatible + if scaling_type not in {"su", "longrope", "llama3"}: scaling_factor = rope_scaling["factor"] - if scaling_type == "linear": + if scaling_type == "llama3": + rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype) + elif scaling_type == "linear": rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style,