From ca6e4822a4a54c3dd12a9ae7a5d731d82ad8c399 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 22 Feb 2024 00:59:44 +0000
Subject: [PATCH 1/3] Use RMSNorm for Gemma

---
 vllm/model_executor/models/gemma.py | 58 +++++++++++++----------------
 1 file changed, 26 insertions(+), 32 deletions(-)

diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index affe54c448a2c..485733ebf7eac 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -22,6 +22,7 @@
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
@@ -40,21 +41,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class GemmaRMSNorm(nn.Module):
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.zeros(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (1 + self.weight)
-
-
 class GemmaMLP(nn.Module):
 
     def __init__(
@@ -185,10 +171,10 @@ def __init__(
             intermediate_size=config.intermediate_size,
             linear_method=linear_method,
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
-                                            eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                     eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -196,25 +182,27 @@ def forward(
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        return hidden_states
+        return hidden_states, residual
 
 
 class GemmaModel(nn.Module):
@@ -235,7 +223,7 @@ def __init__(
             GemmaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -248,15 +236,17 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         # Normalize the embedding by sqrt(hidden_size)
         hidden_states = hidden_states * (self.config.hidden_size**0.5)
 
+        residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -322,6 +312,10 @@ def load_weights(self,
             if "lm_head" in name:
                 continue
             param = params_dict[name]
+            # GemmaRMSNorm is different from Llama's in that it multiplies
+            # (1 + weight) to the output, instead of just weight.
+            if "norm.weight" in name:
+                loaded_weight += 1.0
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
@@ -329,5 +323,5 @@ def load_weights(self,
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
             raise RuntimeError(
-                f"Some weights are not initialized from checkpoints: {unloaded_params}"
-            )
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}")

From ef29c951b2ac97ed915ae0bea0706a8ea9b38715 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 22 Feb 2024 01:00:01 +0000
Subject: [PATCH 2/3] Minor

---
 vllm/model_executor/models/gemma.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 485733ebf7eac..4b778cb5c71ee 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -311,11 +311,11 @@ def load_weights(self,
             # Skip loading extra layer for lora models.
             if "lm_head" in name:
                 continue
-            param = params_dict[name]
             # GemmaRMSNorm is different from Llama's in that it multiplies
             # (1 + weight) to the output, instead of just weight.
             if "norm.weight" in name:
                 loaded_weight += 1.0
+            param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

From 1986eb95c82b9eab27832181d80f54663fb55760 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 22 Feb 2024 02:22:12 +0000
Subject: [PATCH 3/3] Address comment

---
 vllm/model_executor/models/gemma.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 4b778cb5c71ee..03bd149c001d3 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -234,7 +234,7 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         # Normalize the embedding by sqrt(hidden_size)
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        hidden_states *= self.config.hidden_size**0.5
 
         residual = None
         for i in range(len(self.layers)):
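
The `loaded_weight += 1.0` adjustment in load_weights rests on a simple identity: Gemma's reference RMSNorm scales the normalized activations by (1 + weight), while a conventional RMSNorm scales by weight alone, so shifting each norm weight by 1.0 at load time makes the shared RMSNorm layer reproduce the removed GemmaRMSNorm exactly. The following standalone sketch (not part of the patches above) checks that equivalence numerically; it assumes only PyTorch, and the helper names are illustrative rather than vLLM APIs.

import torch


def _rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps), the same formula as the removed GemmaRMSNorm._norm.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def gemma_style_rmsnorm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Gemma's reference norm scales the normalized input by (1 + weight).
    return _rms_normalize(x) * (1 + weight)


def standard_rmsnorm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # A conventional RMSNorm scales by weight alone.
    return _rms_normalize(x) * weight


x = torch.randn(2, 16)
w = torch.randn(16) * 0.02  # stand-in for a checkpoint's norm weight

# Shifting the loaded weight by 1.0, as load_weights now does for "norm.weight"
# tensors, makes the standard norm match the Gemma-style norm.
assert torch.allclose(gemma_style_rmsnorm(x, w), standard_rmsnorm(x, w + 1.0))

Beyond the weight shift, the switch lets GemmaDecoderLayer.forward call the shared norm with a residual argument and receive (hidden_states, residual) back, which is the fused residual-add pattern visible in the first patch.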