From e8f9cc58cedf09f7bd17fa057594d0c62c03f20b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 21 Feb 2024 18:28:23 -0800 Subject: [PATCH] Use Llama RMSNorm custom op for Gemma (#2974) --- vllm/model_executor/models/gemma.py | 60 +++++++++++++---------------- 1 file changed, 27 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index affe54c448a2c..03bd149c001d3 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -22,6 +22,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -40,21 +41,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] -class GemmaRMSNorm(nn.Module): - - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.zeros(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * (1 + self.weight) - - class GemmaMLP(nn.Module): def __init__( @@ -185,10 +171,10 @@ def __init__( intermediate_size=config.intermediate_size, linear_method=linear_method, ) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def forward( self, @@ -196,25 +182,27 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, + residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, ) - hidden_states = residual + hidden_states # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states + return hidden_states, residual class GemmaModel(nn.Module): @@ -235,7 +223,7 @@ def __init__( GemmaDecoderLayer(config, linear_method) for _ in range(config.num_hidden_layers) ]) - self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -246,17 +234,19 @@ def forward( ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Normalize the embedding by sqrt(hidden_size) - hidden_states = hidden_states * (self.config.hidden_size**0.5) + hidden_states *= self.config.hidden_size**0.5 + residual = None for i in range(len(self.layers)): layer = self.layers[i] - hidden_states = layer( + hidden_states, residual = layer( positions, hidden_states, kv_caches[i], input_metadata, + residual, ) - hidden_states = self.norm(hidden_states) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -321,6 +311,10 @@ def load_weights(self, # Skip loading extra layer for lora models. if "lm_head" in name: continue + # GemmaRMSNorm is different from Llama's in that it multiplies + # (1 + weight) to the output, instead of just weight. + if "norm.weight" in name: + loaded_weight += 1.0 param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -329,5 +323,5 @@ def load_weights(self, unloaded_params = params_dict.keys() - loaded_params if unloaded_params: raise RuntimeError( - f"Some weights are not initialized from checkpoints: {unloaded_params}" - ) + "Some weights are not initialized from checkpoints: " + f"{unloaded_params}")