diff --git a/vllm/config.py b/vllm/config.py index a9e86c24b2733..6ae8ccb8bba95 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -166,6 +166,10 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: if getattr(self.hf_config, "num_key_value_heads", None) is not None: return (self.hf_config.num_key_value_heads // parallel_config.tensor_parallel_size) + # For ChatGLM-2: + if getattr(self.hf_config, "multi_query_group_num", None) is not None: + return (self.hf_config.multi_query_group_num // + parallel_config.tensor_parallel_size) total_num_attention_heads = self.hf_config.num_attention_heads return total_num_attention_heads // parallel_config.tensor_parallel_size diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 6ae20795160ef..7fca7ebdbf09e 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -18,6 +18,7 @@ "BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b "BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b "BloomForCausalLM": BloomForCausalLM, + "ChatGLMModel": ChatGLMForCausalLM, "FalconForCausalLM": FalconForCausalLM, "GPT2LMHeadModel": GPT2LMHeadModel, "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM, diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index e04864a430258..e6ac45d8f01ac 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -13,6 +13,7 @@ from vllm.model_executor.models.mpt import MptForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel +from vllm.model_executor.models.chatglm import ChatGLMForCausalLM from vllm.model_executor.models.yi import YiForCausalLM __all__ = [ @@ -20,6 +21,7 @@ "BaiChuanForCausalLM", "BaichuanForCausalLM", "BloomForCausalLM", + "ChatGLMForCausalLM", "FalconForCausalLM", "GPT2LMHeadModel", "GPTBigCodeForCausalLM", diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py new file mode 100644 index 0000000000000..8acc8e468b652 --- /dev/null +++ b/vllm/model_executor/models/chatglm.py @@ -0,0 +1,408 @@ +# coding=utf-8 +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +"""Inference-only ChatGLM model compatible with THUDM weights. + +The input of the model is flattened to a 1D tensor of tokens. The model uses +InputMetadata to extract the original 2D shape of the input. 
+""" +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import LayerNorm + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.weight_utils import ( + hf_model_weights_iterator, + load_tensor_parallel_weights, +) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding +from vllm.model_executor.parallel_utils.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from vllm.sequence import SequenceOutputs + +from vllm.transformers_utils.configs import ChatGLMConfig + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class GLMAttention(nn.Module): + + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.multi_query_attention = config.multi_query_attention + self.total_num_kv_heads = (config.multi_query_group_num + if config.multi_query_attention else + config.num_attention_heads) + assert self.total_num_kv_heads % tp_size == 0 + self.num_kv_heads = self.total_num_kv_heads // tp_size + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.query_key_value = ColumnParallelLinear( + config.hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_dim, + bias=config.add_qkv_bias, + gather_output=False, + ) + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=config.add_bias_linear, + input_is_parallel=True, + ) + + self.attn = PagedAttentionWithRoPE( + self.num_heads, + self.head_dim, + self.scaling, + rotary_dim=self.head_dim // 2, + num_kv_heads=self.num_kv_heads, + is_neox_style=False, + # is_glm_style=True + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + key_cache, value_cache = kv_cache + + context_layer = self.attn( + position_ids, + q, + k, + v, + key_cache, + value_cache, + input_metadata, + cache_event, + ) + + attn_output, _ = self.dense(context_layer) + + return attn_output + + +class GLMMLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config): + super().__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=config.add_bias_linear, + gather_output=False, + ) + + self.activation_func = SiluAndMul() + + # Project back to h. 
+ self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + bias=config.add_bias_linear, + input_is_parallel=True, + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel, _ = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output, _ = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__( + self, + config, + ): + super().__init__() + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm) + + self.fp32_residual_connection = config.fp32_residual_connection + + layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = layer_norm_func(config.hidden_size, + eps=config.layernorm_epsilon) + + # Self attention. + self.self_attention = GLMAttention(config) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon) + + # MLP + self.mlp = GLMMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + # hidden_states: [num_tokens, h] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output = self.self_attention( + hidden_states=layernorm_output, + position_ids=position_ids, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = residual + attention_output + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = self.mlp(layernorm_output) + residual + + return output + + +class GLMTransformer(nn.Module): + """Transformer class.""" + + def __init__(self, config): + super().__init__() + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + self.layers = nn.ModuleList( + [GLMBlock(config) for i in range(self.num_layers)]) + + if self.post_layer_norm: + layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + for i in range(self.num_layers): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.layers[i] + hidden_states = layer( + hidden_states=hidden_states, + position_ids=position_ids, + kv_cache=kv_caches[i], + input_metadata=input_metadata, + cache_event=cache_event, + ) + # Final layer norm. 
+ if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class ChatGLMModel(nn.Module): + + def __init__(self, config): + super().__init__() + + self.embedding = VocabParallelEmbedding(config.padded_vocab_size, + config.hidden_size) + + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + self.encoder = GLMTransformer(config) + + self.output_layer = ColumnParallelLinear( + config.hidden_size, + config.padded_vocab_size, + bias=False, + gather_output=False, + params_dtype=config.torch_dtype, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ): + inputs_embeds = self.embedding(input_ids) + + # Run encoder. + hidden_states = self.encoder( + hidden_states=inputs_embeds, + position_ids=position_ids, + kv_caches=kv_caches, + input_metadata=input_metadata, + cache_events=cache_events, + ) + + return hidden_states + + +class ChatGLMForCausalLM(nn.Module): + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.config: ChatGLMConfig = config + self.transformer = ChatGLMModel(config) + self.lm_head_weight = self.transformer.output_layer.weight + self.sampler = Sampler(config.padded_vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata, cache_events) + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + input_metadata) + return next_tokens + + _column_parallel_weights = [ + "output_layer.weight", + "embedding.weight", + ] + _row_parallel_weights = ["dense_4h_to_h", "self_attention.dense"] + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + q_proj_shard_size = self.config.hidden_size // tp_size + kv_proj_shard_size = (self.config.hidden_size // + self.config.num_attention_heads * + self.config.multi_query_group_num // tp_size) + + mlp_hidden_shard_size = self.config.ffn_hidden_size // tp_size + + state_dict = self.state_dict() + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "word_embeddings" in name: + name = name.replace(".word_embeddings", "") + + if name in state_dict: + param = state_dict[name] + if "query_key_value" in name: + q_offset = q_proj_shard_size * tp_rank + k_offset = (q_proj_shard_size * tp_size + + kv_proj_shard_size * tp_rank) + v_offset = (q_proj_shard_size * tp_size + + kv_proj_shard_size * (tp_size + tp_rank)) + wq = loaded_weight[q_offset:q_offset + q_proj_shard_size] + wk = loaded_weight[k_offset:k_offset + kv_proj_shard_size] + wv = loaded_weight[v_offset:v_offset + kv_proj_shard_size] + loaded_weight = torch.cat([wq, wk, wv], dim=0) + param.data.copy_(loaded_weight) + continue + + if "dense_h_to_4h" in name: + w_gate = loaded_weight[mlp_hidden_shard_size * + tp_rank:mlp_hidden_shard_size * + (tp_rank + 1)] + w_proj = loaded_weight[mlp_hidden_shard_size * + (tp_size + + tp_rank):mlp_hidden_shard_size * + (tp_size + tp_rank + 1)] 
+ loaded_weight = torch.cat([w_gate, w_proj], dim=0) + param.data.copy_(loaded_weight) + continue + + load_tensor_parallel_weights( + param, + loaded_weight, + name, + self._column_parallel_weights, + self._row_parallel_weights, + tp_rank, + ) + elif name == "transformer.rotary_pos_emb.inv_freq": + continue + else: + print("Warning never found tensor's name:", name) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 56c7c09b4cab6..97c55f8a16f2f 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -5,9 +5,10 @@ from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import _CONFIG_REGISTRY = { - "mpt": MptConfig, - "baichuan": BaiChuanConfig, "aquila": AquilaConfig, + "baichuan": BaiChuanConfig, + "chatglm": ChatGLMConfig, + "mpt": MptConfig, "qwen": QWenConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 77f2df6e9b6b9..fb90babf06e5b 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,5 +1,6 @@ -from vllm.transformers_utils.configs.baichuan import BaiChuanConfig from vllm.transformers_utils.configs.aquila import AquilaConfig +from vllm.transformers_utils.configs.baichuan import BaiChuanConfig +from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.qwen import QWenConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the @@ -8,8 +9,9 @@ from vllm.transformers_utils.configs.yi import YiConfig __all__ = [ - "BaiChuanConfig", "AquilaConfig", + "BaiChuanConfig", + "ChatGLMConfig", "QWenConfig", "RWConfig", "YiConfig", diff --git a/vllm/transformers_utils/configs/chatglm.py b/vllm/transformers_utils/configs/chatglm.py new file mode 100644 index 0000000000000..c4244f8c77f44 --- /dev/null +++ b/vllm/transformers_utils/configs/chatglm.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + attribute_map = { + "num_hidden_layers": "num_layers", + "n_head_kv": "multi_query_group_num", + } + + def __init__(self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + interleaved_qkv=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + 
+        self.rmsnorm = rmsnorm
+        self.apply_residual_connection_post_layernorm = (
+            apply_residual_connection_post_layernorm)
+        self.post_layer_norm = post_layer_norm
+        self.add_bias_linear = add_bias_linear
+        self.add_qkv_bias = add_qkv_bias
+        self.bias_dropout_fusion = bias_dropout_fusion
+        self.multi_query_attention = multi_query_attention
+        self.multi_query_group_num = multi_query_group_num
+        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+        self.fp32_residual_connection = fp32_residual_connection
+        self.quantization_bit = quantization_bit
+        self.pre_seq_len = pre_seq_len
+        self.prefix_projection = prefix_projection
+        self.interleaved_qkv = interleaved_qkv
+        super().__init__(**kwargs)
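
A note for reviewers on GLMMLP above: dense_h_to_4h projects to 2 * ffn_hidden_size because SiluAndMul expects the gate and up projections stacked along the last dimension, which is also why load_weights concatenates the w_gate and w_proj slices for dense_h_to_4h. A minimal standalone sketch of that gating, using plain torch and toy sizes rather than the actual vLLM kernel:

import torch
import torch.nn.functional as F

# dense_h_to_4h output for a batch of tokens: first half is the gate
# projection, second half the up projection (toy sizes, illustration only).
ffn_hidden_size = 4
x = torch.randn(3, 2 * ffn_hidden_size)   # [num_tokens, 2 * ffn_hidden_size]
gate, up = x.chunk(2, dim=-1)
out = F.silu(gate) * up                    # what SiluAndMul returns
print(out.shape)                           # torch.Size([3, 4])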
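
Likewise, the query_key_value branch of load_weights slices a fused [Wq; Wk; Wv] checkpoint tensor into per-tensor-parallel-rank shards. The following self-contained sketch uses made-up toy dimensions (the ChatGLMConfig defaults above are hidden_size=4096 with 32 attention heads) and mirrors the same offset arithmetic:

import torch

# Toy dimensions only for illustration.
hidden_size, num_heads, num_kv_heads, tp_size = 8, 4, 2, 2
head_dim = hidden_size // num_heads                # 2
q_size = num_heads * head_dim                      # rows of Q in the fused weight
kv_size = num_kv_heads * head_dim                  # rows each for K and V

# Fused checkpoint layout: [Wq; Wk; Wv] stacked along dim 0.
fused = torch.arange((q_size + 2 * kv_size) * hidden_size,
                     dtype=torch.float32).reshape(-1, hidden_size)

q_shard = q_size // tp_size                        # q_proj_shard_size
kv_shard = kv_size // tp_size                      # kv_proj_shard_size

for tp_rank in range(tp_size):
    # Same offsets as load_weights: the Q block comes first, then all K
    # shards, then all V shards.
    q_off = q_shard * tp_rank
    k_off = q_shard * tp_size + kv_shard * tp_rank
    v_off = q_shard * tp_size + kv_shard * (tp_size + tp_rank)
    wq = fused[q_off:q_off + q_shard]
    wk = fused[k_off:k_off + kv_shard]
    wv = fused[v_off:v_off + kv_shard]
    shard = torch.cat([wq, wk, wv], dim=0)         # each rank's query_key_value weight
    print(tp_rank, tuple(shard.shape))             # prints "0 (8, 8)" then "1 (8, 8)"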
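
Finally, a hedged end-to-end sketch of how the new registry entries are expected to be exercised once this patch is applied. The model name, prompt, and sampling parameters are illustrative only; trust_remote_code=True is assumed because the upstream ChatGLM2 repository ships custom config and tokenizer code:

from vllm import LLM, SamplingParams

# Any checkpoint whose config reports model_type "chatglm" and architecture
# "ChatGLMModel" should be routed through the classes added in this diff.
llm = LLM(model="THUDM/chatglm2-6b", trust_remote_code=True)
sampling = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)

outputs = llm.generate(["What is the capital of France?"], sampling)
print(outputs[0].outputs[0].text)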