From 9f669a9a7c2b2d0a7963a6e29253280e57680adb Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Fri, 3 Nov 2023 14:12:48 -0700
Subject: [PATCH] Support YaRN models (#1264)

Signed-off-by: Antoni Baum
Co-authored-by: Viktor Ferenczi
Co-authored-by: Woosuk Kwon
---
 csrc/activation_kernels.cu                    |  12 +-
 csrc/pos_encoding_kernels.cu                  |   2 +-
 vllm/config.py                                |   3 +
 vllm/model_executor/layers/attention.py       |  15 ++-
 .../model_executor/layers/rotary_embedding.py | 104 ++++++++++++++++++
 5 files changed, 128 insertions(+), 8 deletions(-)

diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 581525e970ce5..89d1ba2d37dd8 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -16,8 +16,8 @@ __global__ void silu_and_mul_kernel(
   scalar_t* __restrict__ out,               // [..., d]
   const scalar_t* __restrict__ input,       // [..., 2, d]
   const int d) {
-  const int token_idx = blockIdx.x;
-  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
     const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
     out[token_idx * d + idx] = silu(x) * y;
@@ -30,7 +30,7 @@ void silu_and_mul(
   torch::Tensor& out,      // [..., d]
   torch::Tensor& input)    // [..., 2 * d]
 {
-  int num_tokens = input.numel() / input.size(-1);
+  int64_t num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1) / 2;
 
   dim3 grid(num_tokens);
@@ -55,8 +55,8 @@ __global__ void activation_kernel(
   scalar_t* __restrict__ out,               // [..., d]
   const scalar_t* __restrict__ input,       // [..., d]
   const int d) {
-  const int token_idx = blockIdx.x;
-  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = __ldg(&input[token_idx * d + idx]);
     out[token_idx * d + idx] = ACT_FN(x);
   }
@@ -67,7 +67,7 @@ __global__ void activation_kernel(
 // Launch element-wise activation kernel.
 #define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                       \
   int d = input.size(-1);                                                      \
-  int num_tokens = input.numel() / d;                                          \
+  int64_t num_tokens = input.numel() / d;                                      \
   dim3 grid(num_tokens);                                                       \
   dim3 block(std::min(d, 1024));                                               \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu
index 41001ba64746f..0a5ec95f8c0d4 100644
--- a/csrc/pos_encoding_kernels.cu
+++ b/csrc/pos_encoding_kernels.cu
@@ -84,7 +84,7 @@ void rotary_embedding(
   int head_size,
   torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
   bool is_neox) {
-  int num_tokens = query.numel() / query.size(-1);
+  int64_t num_tokens = query.numel() / query.size(-1);
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(-1) / head_size;
   int num_kv_heads = key.size(-1) / head_size;
diff --git a/vllm/config.py b/vllm/config.py
index 6e19491083d44..a9e86c24b2733 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -390,6 +390,9 @@ def _get_and_verify_max_len(
     if rope_scaling is not None:
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
+        if rope_scaling["type"] == "yarn":
+            derived_max_model_len = rope_scaling[
+                "original_max_position_embeddings"]
         derived_max_model_len *= scaling_factor
 
     if max_model_len is None:
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index 7aa01ffe14bab..c1259a1b11ea5 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -12,7 +12,7 @@
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.rotary_embedding import (
     DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
-    RotaryEmbedding)
+    RotaryEmbedding, YaRNScalingRotaryEmbedding)
 
 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -334,6 +334,19 @@ def __init__(
             self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                 head_size, rotary_dim, max_position, base, is_neox_style,
                 scaling_factor)
+        elif scaling_type == "yarn":
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            assert max_position == original_max_position * scaling_factor
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor",
+                         "beta_fast", "beta_slow")
+            }
+            self.rotary_emb = YaRNScalingRotaryEmbedding(
+                head_size, rotary_dim, original_max_position, base,
+                is_neox_style, scaling_factor, **extra_kwargs)
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 4ecde07562faf..2cbd3b584c06e 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -21,6 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Rotary Positional Embeddings."""
+import math
 from typing import Tuple, Union
 
 import torch
@@ -167,3 +168,106 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
         return cache
+
+
+# Inverse dim formula to find dim based on number of rotations
+def _yarn_find_correction_dim(num_rotations: int,
+                              dim: int,
+                              base: float = 10000,
+                              max_position_embeddings: int = 2048) -> float:
+    return (dim * math.log(max_position_embeddings /
+                           (num_rotations * 2 * math.pi))) / (2 *
+                                                              math.log(base))
+
+
+# Find dim range bounds based on rotations
+def _yarn_find_correction_range(low_rot: int,
+                                high_rot: int,
+                                dim: int,
+                                base: float = 10000,
+                                max_position_embeddings: int = 2048) -> int:
+    low = math.floor(
+        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+    high = math.ceil(
+        _yarn_find_correction_dim(high_rot, dim, base,
+                                  max_position_embeddings))
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
+                           dtype: torch.dtype,
+                           device: torch.device) -> torch.Tensor:
+    if low == high:
+        high += 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=dtype, device=device) -
+                   low) / (high - low)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+def _yarn_get_mscale(scale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * math.log(scale) + 1.0
+
+
+class YaRNScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with YaRN method.
+
+    Credits to Peng et al. github.com/jquesnelle/yarn
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: float = 32,
+        beta_slow: float = 1,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation
+        self.mscale = float(
+            _yarn_get_mscale(self.scaling_factor) * attn_factor)
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base**(torch.arange(
+            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
+                                self.rotary_dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
+                                                self.rotary_dim, self.base,
+                                                self.max_position_embeddings)
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
+            low, high, self.rotary_dim // 2, dtype=torch.float,
+            device="cuda")) * self.extrapolation_factor
+        inv_freq = inv_freq_interpolation * (
+            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
+                         device="cuda",
+                         dtype=torch.float32)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = (freqs.cos() * self.mscale)
+        sin = (freqs.sin() * self.mscale)
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache