From 5c976a7e1a1bec875bf6474824b7dff39e38de18 Mon Sep 17 00:00:00 2001
From: Roy
Date: Tue, 13 Feb 2024 16:09:23 +0800
Subject: [PATCH] Refactor llama family models (#2637)

---
 vllm/model_executor/layers/layernorm.py     |  25 ++
 vllm/model_executor/models/__init__.py      |   9 +-
 vllm/model_executor/models/aquila.py        | 342 -------------------
 vllm/model_executor/models/baichuan.py      | 335 ++-----------------
 vllm/model_executor/models/internlm.py      | 299 -----------------
 vllm/model_executor/models/internlm2.py     | 285 ++--------------
 vllm/model_executor/models/llama.py         | 162 +++++----
 vllm/model_executor/models/mistral.py       | 352 --------------------
 vllm/model_executor/models/qwen.py          | 267 ++-------------
 vllm/model_executor/models/stablelm.py      | 283 +---------------
 vllm/model_executor/models/yi.py            | 330 ------------------
 vllm/transformers_utils/config.py           |   4 -
 vllm/transformers_utils/configs/__init__.py |   8 -
 vllm/transformers_utils/configs/aquila.py   |  69 ----
 vllm/transformers_utils/configs/baichuan.py |  62 ----
 vllm/transformers_utils/configs/qwen.py     |  60 ----
 vllm/transformers_utils/configs/yi.py       |  64 ----
 17 files changed, 236 insertions(+), 2720 deletions(-)
 delete mode 100644 vllm/model_executor/models/aquila.py
 delete mode 100644 vllm/model_executor/models/internlm.py
 delete mode 100644 vllm/model_executor/models/mistral.py
 delete mode 100644 vllm/model_executor/models/yi.py
 delete mode 100644 vllm/transformers_utils/configs/aquila.py
 delete mode 100644 vllm/transformers_utils/configs/baichuan.py
 delete mode 100644 vllm/transformers_utils/configs/qwen.py
 delete mode 100644 vllm/transformers_utils/configs/yi.py

diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index cb3cee2bad5ad..83bc189ee4d90 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -7,6 +7,31 @@
 from vllm._C import ops
 
 
+class LayerNorm(nn.LayerNorm):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__(hidden_size, eps=eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """normalization."""
+        if residual is not None:
+            x = x + residual
+            residual = x
+        x = super().forward(x)
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
+
 class RMSNorm(nn.Module):
     """Root mean square normalization.
 
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index fb519b3c0cf92..7c1ce1bf1fc60 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -10,8 +10,8 @@
 
 # Architecture -> (module, class).
 _MODELS = {
-    "AquilaModel": ("aquila", "AquilaForCausalLM"),
-    "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"),  # AquilaChat2
+    "AquilaModel": ("llama", "LlamaForCausalLM"),
+    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
     "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
     "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
@@ -24,12 +24,12 @@
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
-    "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
+    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
-    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
+    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     # transformers's mpt class has lower case
@@ -41,7 +41,6 @@
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
-    "YiForCausalLM": ("yi", "YiForCausalLM")
 }
 
 # Models not supported by ROCm.
diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py
deleted file mode 100644
index 2f2bd5ffb4a63..0000000000000
--- a/vllm/model_executor/models/aquila.py
+++ /dev/null
@@ -1,342 +0,0 @@
-# coding=utf-8
-# Adapted from
-# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The vLLM team.
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple - -import torch -from torch import nn - -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) -from vllm.sequence import SamplerOutput -from vllm.transformers_utils.configs.aquila import AquilaConfig - -KVCache = Tuple[torch.Tensor, torch.Tensor] - - -class AquilaMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class AquilaRMSNorm(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - """ - AquilaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, - keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + - self.variance_epsilon) - - return (self.weight * hidden_states).to(input_dtype) - - -class AquilaAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - max_position_embeddings: int = 8192, - rope_scaling: Optional[Dict[str, Any]] = None, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - assert self.total_num_kv_heads % tp_size == 0 - self.num_kv_heads = self.total_num_kv_heads // tp_size - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - linear_method=linear_method, - ) - 
self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - linear_method=linear_method, - ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class AquilaDecoderLayer(nn.Module): - - def __init__( - self, - config: AquilaConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - self.self_attn = AquilaAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - max_position_embeddings=max_position_embeddings, - rope_scaling=rope_scaling, - linear_method=linear_method, - ) - self.mlp = AquilaMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - linear_method=linear_method, - ) - self.input_layernorm = AquilaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - # Self Attention - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - -class AquilaModel(nn.Module): - - def __init__( - self, - config: AquilaConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.layers = nn.ModuleList([ - AquilaDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) - self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states = layer( - positions, - hidden_states, - kv_caches[i], - input_metadata, - ) - hidden_states = 
self.norm(hidden_states) - - return hidden_states - - -class AquilaForCausalLM(nn.Module): - - def __init__( - self, - config, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.linear_method = linear_method - self.model = AquilaModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) - return hidden_states - - def sample( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) - return next_tokens - - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): - if "rotary_emb.inv_freq" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index f08c3c8d257ff..e0f826bf7e29a 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -18,305 +18,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only BaiChuan model compatible with HuggingFace weights.""" -import math -from typing import List, Optional, Tuple +from typing import Optional import torch -from torch import nn +from transformers import PretrainedConfig +from vllm.config import LoRAConfig -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) -from vllm.sequence import SamplerOutput -from vllm.transformers_utils.configs.baichuan import BaiChuanConfig -KVCache = Tuple[torch.Tensor, torch.Tensor] - -def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) - base = torch.tensor( - 2**(-(2**-(math.log2(closest_power_of_2) - 3))), - dtype=torch.float32, - ) - powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) - slopes = torch.pow(base, powers) - - if closest_power_of_2 != total_num_heads: - extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), - dtype=torch.float32, - ) - num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) - extra_powers = torch.arange(start=1, - end=1 + 2 * num_remaining_heads, - step=2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, torch.pow(extra_base, extra_powers)], dim=0) - return slopes - - -class BaiChuanMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. 
" - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class BaiChuanAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - hidden_size: int, - num_heads: int, - position_embedding: str, - rope_theta: float = 10000, - max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.hidden_size = hidden_size - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( - ) - self.total_num_heads = num_heads - assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) - self.head_dim = hidden_size // self.total_num_heads - self.postion_embedding = position_embedding - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - # pylint: disable=invalid-name - self.W_pack = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_heads, - bias=False, - linear_method=linear_method, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - linear_method=linear_method, - ) - # Create the alibi slopes and slice them. - if self.postion_embedding == "ALIBI": - tp_rank = get_tensor_model_parallel_rank() - head_start = tp_rank * self.num_heads - head_end = (tp_rank + 1) * self.num_heads - alibi_slopes = _get_alibi_slopes(self.total_num_heads) - alibi_slopes = alibi_slopes[head_start:head_end].tolist() - - scaling = self.head_dim**-0.5 - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes) - else: - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, - ) - self.scaling = self.head_dim**-0.5 - self.attn = PagedAttention(self.num_heads, self.head_dim, - self.scaling) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - qkv, _ = self.W_pack(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) - if self.postion_embedding != "ALIBI": - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class BaiChuanDecoderLayer(nn.Module): - - def __init__(self, - config: BaiChuanConfig, - position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): - super().__init__() - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - self.self_attn = BaiChuanAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - position_embedding=position_embedding, - rope_theta=rope_theta, - max_position_embeddings=max_position_embeddings, - linear_method=linear_method, - ) - self.mlp = BaiChuanMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - linear_method=linear_method, - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def 
forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -class BaiChuanModel(nn.Module): - - def __init__(self, - config: BaiChuanConfig, - position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding, linear_method) - for _ in range(config.num_hidden_layers) - ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i], - input_metadata, - residual, - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class BaiChuanBaseForCausalLM(nn.Module): - - def __init__(self, - config, - position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): - super().__init__() - self.config = config - self.linear_method = linear_method - self.model = BaiChuanModel(config, position_embedding, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) - return hidden_states - - def sample( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) - return next_tokens +class BaiChuanBaseForCausalLM(LlamaForCausalLM): def load_weights(self, model_name_or_path: str, @@ -328,9 +42,15 @@ def load_weights(self, ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] + param_weight_map = [ + ("qkv_proj", "W_pack"), + ] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + for (param_name, weight_name) in param_weight_map: + name = name.replace(weight_name, param_name) + if "rotary_emb.inv_freq" in name: continue if name == "lm_head.weight": @@ -368,19 +88,28 @@ def load_weights(self, class BaichuanForCausalLM(BaiChuanBaseForCausalLM): """Baichuan 13B and Baichuan2 7B/13B.""" - def __init__(self, - config, - linear_method: Optional[LinearMethodBase] = 
None): - if config.hidden_size == 4096: # baichuan2 7b - super().__init__(config, "ROPE", linear_method) - else: # baichuan 13b, baichuan2 13b - super().__init__(config, "ALIBI", linear_method) + def __init__( + self, + config: Optional[PretrainedConfig] = None, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + if config.hidden_size != 4096: # baichuan 13b, baichuan2 13b + config.postion_embedding = "ALIBI" + super().__init__(config=config, + linear_method=linear_method, + lora_config=lora_config) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): """Baichuan 7B.""" - def __init__(self, - config, - linear_method: Optional[LinearMethodBase] = None): - super().__init__(config, "ROPE", linear_method) + def __init__( + self, + config: Optional[PretrainedConfig] = None, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__(config=config, + linear_method=linear_method, + lora_config=lora_config) diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py deleted file mode 100644 index 5d0b93793c89d..0000000000000 --- a/vllm/model_executor/models/internlm.py +++ /dev/null @@ -1,299 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Any, Dict, List, Optional, Tuple - -import torch -from torch import nn -from transformers import LlamaConfig - -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) -from vllm.sequence import SamplerOutput - -KVCache = Tuple[torch.Tensor, torch.Tensor] - - -class InternLMMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. 
" - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class InternLMAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - bias: bool, - rope_theta: float = 10000, - max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, - rope_scaling: Optional[Dict[str, Any]] = None, - ): - super().__init__() - self.hidden_size = hidden_size - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) - self.total_num_heads = num_heads - assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) - self.head_dim = hidden_size // self.total_num_heads - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - bias=bias, - linear_method=linear_method, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=bias, - linear_method=linear_method, - ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class InternLMDecoderLayer(nn.Module): - - def __init__( - self, - config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - self.self_attn = InternLMAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - bias=config.bias, - rope_theta=rope_theta, - max_position_embeddings=max_position_embeddings, - linear_method=linear_method, - rope_scaling=getattr(config, "rope_scaling", None), - ) - self.mlp = InternLMMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - linear_method=linear_method, - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) - - # Fully Connected - hidden_states, residual = 
self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -class InternLMModel(nn.Module): - - def __init__( - self, - config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.embed_tokens = VocabParallelEmbedding( - vocab_size, - config.hidden_size, - ) - self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i], - input_metadata, - residual, - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class InternLMForCausalLM(nn.Module): - - def __init__( - self, - config, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.linear_method = linear_method - self.model = InternLMModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) - return hidden_states - - def sample( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) - return next_tokens - - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): - if "rotary_emb.inv_freq" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index ebf1d8a89a022..bffe3ae2408cc 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,276 +1,27 @@ # -*- coding: utf-8 -*- -from typing import Any, Dict, List, Optional, Tuple +from typing import Optional import torch -from torch import nn from transformers import PretrainedConfig +from vllm.config import LoRAConfig -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) -from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - -class InternLM2MLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - linear_method=linear_method) - self.w2 = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.w2(x) - return x - - -class InternLM2Attention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.wqkv = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - linear_method=linear_method, - ) - self.wo = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - linear_method=linear_method, - ) - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - qkv, _ = self.wqkv(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output, _ = self.wo(attn_output) - return output - - -class InternLMDecoderLayer(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - self.attention = InternLM2Attention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - linear_method=linear_method, - ) - self.feed_forward = InternLM2MLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - linear_method=linear_method, - ) - self.attention_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.attention_norm(hidden_states) - else: - hidden_states, residual = self.attention_norm( - hidden_states, residual) - hidden_states = self.attention( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) - - # Fully Connected - hidden_states, residual = self.ffn_norm(hidden_states, residual) - hidden_states = self.feed_forward(hidden_states) - return hidden_states, residual - - -class InternLM2Model(nn.Module): +class InternLM2ForCausalLM(LlamaForCausalLM): def __init__( self, - config: PretrainedConfig, + config: Optional[PretrainedConfig] = None, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__() - self.config = config - self.padding_idx = 
config.pad_token_id - self.vocab_size = config.vocab_size - self.tok_embeddings = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.tok_embeddings(input_ids) - residual = None - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i], - input_metadata, - residual, - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class InternLM2ForCausalLM(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.config = config - self.linear_method = linear_method - self.model = InternLM2Model(config, linear_method) - self.output = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) - return hidden_states - - def sample( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.output.weight, hidden_states, - sampling_metadata) - return next_tokens + super().__init__(config=config, + linear_method=linear_method, + lora_config=lora_config) def load_weights(self, model_name_or_path: str, @@ -282,9 +33,23 @@ def load_weights(self, ("gate_up_proj", "w1", 0), ("gate_up_proj", "w3", 1), ] + param_weight_map = [ + ("qkv_proj", "wqkv"), + ("o_proj", "wo"), + ("down_proj", "w2"), + ("input_layernorm", "attention_norm"), + ("post_attention_layernorm", "ffn_norm"), + ("embed_tokens", "tok_embeddings"), + (".self_attn.", ".attention."), + ("mlp", "feed_forward"), + ("lm_head", "output"), + ] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + for (param_name, weight_name) in param_weight_map: + name = name.replace(weight_name, param_name) + if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -303,7 +68,7 @@ def load_weights(self, if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - if "wqkv" in name: + if "qkv_proj" in name: config = self.config kv_groups = config.num_attention_heads // config.num_key_value_heads head_dim = config.hidden_size // config.num_attention_heads diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e5a1abebf1420..462469ef79d1c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -21,8 +21,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional, Tuple +import math import torch from torch import nn from transformers import LlamaConfig @@ -40,34 +41,60 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.config import LoRAConfig +from copy import deepcopy + KVCache = Tuple[torch.Tensor, torch.Tensor] +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + class LlamaMLP(nn.Module): def __init__( self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, + config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + config.hidden_size, [config.intermediate_size] * 2, bias=False, linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, + self.down_proj = RowParallelLinear(config.intermediate_size, + config.hidden_size, bias=False, linear_method=linear_method) + hidden_act = getattr(config, "hidden_act", "silu") if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -84,21 +111,19 @@ class LlamaAttention(nn.Module): def __init__( self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, + config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() - self.hidden_size = hidden_size + self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads + self.total_num_heads = getattr(config, "num_attention_heads", None) assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads + + # defaut to mha + self.total_num_kv_heads = getattr(config, "num_key_value_heads", + self.total_num_heads) if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. @@ -108,39 +133,68 @@ def __init__( # the KV heads across multiple tensor parallel GPUs. 
assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.max_position_embeddings = config.max_position_embeddings + # internlm + bias = getattr(config, "bias", False) + + # stablelm + qkv_bias = getattr(config, "use_qkv_bias", False) self.qkv_proj = QKVParallelLinear( - hidden_size, + self.hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, - bias=False, + bias=bias or qkv_bias, linear_method=linear_method, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, - hidden_size, - bias=False, + self.hidden_size, + bias=bias, linear_method=linear_method, ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + # mistral + sliding_window = getattr(config, "sliding_window", None) + + self.postion_embedding = getattr(config, "postion_embedding", "ROPE") + # Create the alibi slopes and slice them. + if self.postion_embedding == "ALIBI": + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + alibi_slopes=alibi_slopes, + sliding_window=sliding_window) + else: + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + # stablelm + rope_pct = getattr(config, "rope_pct", 1) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=int(self.head_dim * rope_pct), + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=sliding_window) def forward( self, @@ -151,7 +205,8 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) + if self.postion_embedding != "ALIBI": + q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) output, _ = self.o_proj(attn_output) @@ -164,32 +219,20 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + norm: Optional[torch.Tensor] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) self.self_attn = LlamaAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - 
max_position_embeddings=max_position_embeddings, + config, linear_method=linear_method, ) self.mlp = LlamaMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, + config, linear_method=linear_method, ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = deepcopy(norm) + self.post_attention_layernorm = deepcopy(norm) def forward( self, @@ -226,6 +269,7 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + norm: Optional[torch.Tensor] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -241,10 +285,10 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method) + LlamaDecoderLayer(config, linear_method, norm) for _ in range(config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = norm def forward( self, @@ -275,12 +319,18 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + norm: Optional[torch.Tensor] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = LlamaModel(config, linear_method, lora_config=lora_config) + if norm is None: + norm = RMSNorm(config.hidden_size, config.rms_norm_eps) + self.model = LlamaModel(config, + linear_method, + norm=norm, + lora_config=lora_config) unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py deleted file mode 100644 index 01cde67844122..0000000000000 --- a/vllm/model_executor/models/mistral.py +++ /dev/null @@ -1,352 +0,0 @@ -# coding=utf-8 -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Inference-only Mistral model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple - -import torch -from torch import nn -from transformers import MistralConfig - -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) -from vllm.sequence import SamplerOutput -from vllm.config import LoRAConfig - -KVCache = Tuple[torch.Tensor, torch.Tensor] - - -class MistralMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class MistralAttention(nn.Module): - - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, - sliding_window: Optional[int] = None) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.sliding_window = sliding_window - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - linear_method=linear_method, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - linear_method=linear_method, - ) - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position, - base=self.rope_theta, - ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class MistralDecoderLayer(nn.Module): - - def __init__( - self, - config: MistralConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 - rope_theta = getattr(config, "rope_theta", 10000) - self.self_attn = MistralAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - max_position=config.max_position_embeddings, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - linear_method=linear_method, - sliding_window=config.sliding_window) - self.mlp = MistralMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - linear_method=linear_method, - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -class MistralModel(nn.Module): - - def __init__( - self, - config: MistralConfig, - linear_method: Optional[LinearMethodBase] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + 
lora_vocab - self.org_vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - self.layers = nn.ModuleList([ - MistralDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i], - input_metadata, - residual, - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class MistralForCausalLM(nn.Module): - supports_lora = True - - def __init__( - self, - config: MistralConfig, - linear_method: Optional[LinearMethodBase] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - super().__init__() - self.config = config - self.linear_method = linear_method - self.model = MistralModel(config, - linear_method, - lora_config=lora_config) - unpadded_vocab_size = config.vocab_size - if lora_config: - unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - ) - self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) - return hidden_states - - def sample( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) - return next_tokens - - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): - if "rotary_emb.inv_freq" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index fbc7320fb45a4..79ad182b1c06b 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -4,253 +4,33 @@ # Copyright (c) Alibaba Cloud. # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Optional -import torch -from torch import nn +from transformers import PretrainedConfig +from vllm.config import LoRAConfig -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) -from vllm.sequence import SamplerOutput -from vllm.transformers_utils.configs.qwen import QWenConfig -KVCache = Tuple[torch.Tensor, torch.Tensor] - -class QWenMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str = "silu", - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - linear_method=linear_method) - self.c_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. 
" - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.c_proj(x) - return x - - -class QWenAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - max_position_embeddings: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.hidden_size = hidden_size - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( - ) - self.total_num_heads = num_heads - assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) - self.head_dim = hidden_size // self.total_num_heads - self.c_attn = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - bias=True, - linear_method=linear_method, - ) - self.c_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - linear_method=linear_method, - ) - self.scaling = self.head_dim**-0.5 - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - qkv, _ = self.c_attn(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - - output, _ = self.c_proj(attn_output) - return output - - -class QWenBlock(nn.Module): - - def __init__( - self, - config: QWenConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - self.attn = QWenAttention(config.hidden_size, - config.num_attention_heads, - config.max_position_embeddings, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - linear_method=linear_method) - - self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - - self.mlp = QWenMLP(config.hidden_size, - config.intermediate_size // 2, - linear_method=linear_method) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - else: - hidden_states, residual = self.ln_1(hidden_states, residual) - hidden_states = self.attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) - - # Fully Connected - hidden_states, residual = self.ln_2(hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -class QWenModel(nn.Module): +class QWenLMHeadModel(LlamaForCausalLM): def __init__( self, - config: QWenConfig, + config: Optional[PretrainedConfig] = None, linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.vocab_size = config.vocab_size - - self.wte 
= VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.h = nn.ModuleList([ - QWenBlock(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) - self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.wte(input_ids) - residual = None - for i in range(len(self.h)): - layer = self.h[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i], - input_metadata, - residual, - ) - hidden_states, _ = self.ln_f(hidden_states, residual) - return hidden_states - - -class QWenLMHeadModel(nn.Module): - - def __init__( - self, - config: QWenConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.linear_method = linear_method - self.transformer = QWenModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) - return hidden_states - - def sample( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) - return next_tokens + lora_config: Optional[LoRAConfig] = None, + ) -> None: + norm = RMSNorm(config.hidden_size, config.layer_norm_epsilon) + config.use_qkv_bias = True + config.intermediate_size = config.intermediate_size // 2 + super().__init__(config=config, + linear_method=linear_method, + norm=norm, + lora_config=lora_config) def load_weights(self, model_name_or_path: str, @@ -262,9 +42,24 @@ def load_weights(self, ("gate_up_proj", "w2", 0), ("gate_up_proj", "w1", 1), ] + param_weight_map = [ + ("model", "transformer"), + (".self_attn.", ".attn."), + (".layers.", ".h."), + ("qkv_proj", "c_attn"), + (".self_attn.o_proj", ".self_attn.c_proj"), + ("norm", "ln_f"), + ("mlp.down_proj", "mlp.c_proj"), + ("input_layernorm", "ln_1"), + ("post_attention_layernorm", "ln_2"), + ("embed_tokens", "wte"), + ] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + for (param_name, weight_name) in param_weight_map: + name = name.replace(weight_name, param_name) + if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 95e5ad8ede63e..6845384768129 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -17,283 +17,26 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import Optional -import torch -from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.input_metadata import InputMetadata -from 
vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) -from vllm.sequence import SamplerOutput +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.layernorm import LayerNorm +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.config import LoRAConfig -KVCache = Tuple[torch.Tensor, torch.Tensor] - -class StablelmMLP(nn.Module): - - def __init__(self, - config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_up_proj = MergedColumnParallelLinear( - config.hidden_size, [config.intermediate_size] * 2, - bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=False) - self.act_fn = SiluAndMul() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class StablelmAttention(nn.Module): - - def __init__(self, - config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = config.num_attention_heads - self.num_heads = self.total_num_heads // tp_size - - self.total_num_key_value_heads = config.num_key_value_heads - if self.total_num_key_value_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_key_value_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_key_value_heads == 0 - self.num_key_value_heads = max( - 1, self.total_num_key_value_heads // tp_size) - self.head_dim = self.hidden_size // self.total_num_heads - self.max_position_embeddings = config.max_position_embeddings - self.rotary_ndims = int(self.head_dim * self.config.rope_pct) - self.scaling = self.head_dim**-0.5 - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_key_value_heads * self.head_dim - self.qkv_bias = getattr(config, "use_qkv_bias", False) - if (self.head_dim * self.num_heads * tp_size) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") - - self.qkv_proj = QKVParallelLinear(self.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_key_value_heads, - self.qkv_bias, - linear_method=linear_method) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - self.hidden_size, - bias=False, - linear_method=linear_method) - self.rotary_ndims = int(self.head_dim * self.config.rope_pct) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.rotary_ndims, - max_position=self.config.max_position_embeddings, - base=self.config.rope_theta, - ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_key_value_heads) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class StablelmDecoderLayer(nn.Module): +class StablelmForCausalLM(LlamaForCausalLM): def __init__( self, - config: PretrainedConfig, + config: Optional[PretrainedConfig] = None, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__() - self.self_attn = StablelmAttention(config) - self.mlp = StablelmMLP(config, linear_method) - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.norm_eps) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states, residual - - -class StableLMEpochModel(nn.Module): - - def __init__(self, - config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None) -> None: - super().__init__() - # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.layers = nn.ModuleList([ - 
StablelmDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i], - input_metadata, - ) - hidden_states = self.norm(hidden_states) - return hidden_states - - -class StablelmForCausalLM(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.config = config - self.linear_method = linear_method - self.model = StableLMEpochModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) - return hidden_states - - def sample( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) - return next_tokens - - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + norm = LayerNorm(config.hidden_size, config.norm_eps) + super().__init__(config=config, + linear_method=linear_method, + norm=norm, + lora_config=lora_config) diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py deleted file mode 100644 index 53daa6c4cd939..0000000000000 --- a/vllm/model_executor/models/yi.py +++ /dev/null @@ -1,330 +0,0 @@ -# coding=utf-8 -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. 
All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only Yi model (https://01.ai) compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple - -import torch -from torch import nn -from vllm.transformers_utils.configs.yi import YiConfig - -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) -from vllm.sequence import SamplerOutput - -KVCache = Tuple[torch.Tensor, torch.Tensor] - - -class YiMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. 
" - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class YiAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - linear_method=linear_method, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - linear_method=linear_method, - ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=self.rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class YiDecoderLayer(nn.Module): - - def __init__( - self, - config: YiConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - self.self_attn = YiAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - linear_method=linear_method, - ) - self.mlp = YiMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - linear_method=linear_method, - ) - self.ln1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.ln2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( 
- self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.ln1(hidden_states) - else: - hidden_states, residual = self.ln1(hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) - - # Fully Connected - hidden_states, residual = self.ln2(hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -class YiModel(nn.Module): - - def __init__( - self, - config: YiConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.layers = nn.ModuleList([ - YiDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i], - input_metadata, - residual, - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class YiForCausalLM(nn.Module): - - def __init__( - self, - config: YiConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.config = config - self.linear_method = linear_method - self.model = YiModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) - return hidden_states - - def sample( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) - return next_tokens - - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): - if "rotary_emb.inv_freq" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 8b16e559b24f2..bc27784087aa7 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -5,14 +5,10 @@ from vllm.transformers_utils.configs import * _CONFIG_REGISTRY = { - "aquila": AquilaConfig, - "baichuan": BaiChuanConfig, "chatglm": ChatGLMConfig, "mpt": MPTConfig, - "qwen": QWenConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) - "yi": YiConfig, } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 284867414e0ed..ef955f75cedaa 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,20 +1,12 @@ -from vllm.transformers_utils.configs.aquila import AquilaConfig -from vllm.transformers_utils.configs.baichuan import BaiChuanConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.mpt import MPTConfig -from vllm.transformers_utils.configs.qwen import QWenConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig -from vllm.transformers_utils.configs.yi import YiConfig __all__ = [ - "AquilaConfig", - "BaiChuanConfig", "ChatGLMConfig", "MPTConfig", - "QWenConfig", "RWConfig", - "YiConfig", ] diff --git a/vllm/transformers_utils/configs/aquila.py b/vllm/transformers_utils/configs/aquila.py deleted file mode 100644 index 86a6f2ba304af..0000000000000 --- a/vllm/transformers_utils/configs/aquila.py +++ /dev/null @@ -1,69 +0,0 @@ -# coding=utf-8 -# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" Aquila model configuration""" - -from transformers import PretrainedConfig - - -class AquilaConfig(PretrainedConfig): - model_type = "aquila" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=100008, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.006, - rms_norm_eps=1e-5, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - tie_word_embeddings=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/vllm/transformers_utils/configs/baichuan.py b/vllm/transformers_utils/configs/baichuan.py deleted file mode 100644 index 869817525c11a..0000000000000 --- a/vllm/transformers_utils/configs/baichuan.py +++ /dev/null @@ -1,62 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from transformers.configuration_utils import PretrainedConfig - - -class BaiChuanConfig(PretrainedConfig): - model_type = "baichuan" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=64000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - hidden_act="silu", - max_position_embeddings=4096, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - tie_word_embeddings=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/vllm/transformers_utils/configs/qwen.py b/vllm/transformers_utils/configs/qwen.py deleted file mode 100644 index bb033a337ad04..0000000000000 --- a/vllm/transformers_utils/configs/qwen.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Alibaba Cloud. -# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE - -from transformers import PretrainedConfig - - -class QWenConfig(PretrainedConfig): - model_type = "qwen" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=151936, - hidden_size=4096, - num_hidden_layers=32, - num_attention_heads=32, - emb_dropout_prob=0.0, - attn_dropout_prob=0.0, - layer_norm_epsilon=1e-6, - initializer_range=0.02, - max_position_embeddings=8192, - scale_attn_weights=True, - use_cache=True, - bf16=False, - fp16=False, - fp32=False, - kv_channels=128, - rotary_pct=1.0, - rotary_emb_base=10000, - use_dynamic_ntk=True, - use_logn_attn=True, - use_flash_attn="auto", - intermediate_size=22016, - no_bias=True, - tie_word_embeddings=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.emb_dropout_prob = emb_dropout_prob - self.attn_dropout_prob = attn_dropout_prob - self.layer_norm_epsilon = layer_norm_epsilon - self.initializer_range = initializer_range - self.scale_attn_weights = scale_attn_weights - self.use_cache = use_cache - self.max_position_embeddings = max_position_embeddings - self.bf16 = bf16 - self.fp16 = fp16 - self.fp32 = fp32 - self.kv_channels = kv_channels - self.rotary_pct = rotary_pct - self.rotary_emb_base = rotary_emb_base - self.use_dynamic_ntk = use_dynamic_ntk - self.use_logn_attn = use_logn_attn - self.use_flash_attn = use_flash_attn - self.no_bias = no_bias - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/vllm/transformers_utils/configs/yi.py b/vllm/transformers_utils/configs/yi.py deleted file mode 100644 index 359922ed26952..0000000000000 --- a/vllm/transformers_utils/configs/yi.py +++ /dev/null @@ -1,64 +0,0 @@ -""" Yi model configuration""" -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - -Yi_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class YiConfig(PretrainedConfig): - r""" - 
Reference:
-    https://huggingface.co/01-ai/Yi-6B/blob/main/configuration_yi.py
-    """
-    model_type = "Yi"
-    keys_to_ignore_at_inference = ["past_key_values"]
-
-    def __init__(
-        self,
-        vocab_size=64000,
-        hidden_size=4096,
-        intermediate_size=11008,
-        num_hidden_layers=32,
-        num_attention_heads=32,
-        num_key_value_heads=4,
-        hidden_act="silu",
-        max_position_embeddings=4096,
-        initializer_range=0.02,
-        rms_norm_eps=1e-5,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=1,
-        eos_token_id=2,
-        tie_word_embeddings=False,
-        output_attentions=False,
-        rope_theta=5000000.0,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.output_attentions = output_attentions
-        self.rope_theta = rope_theta
-
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )