From 66a0790d96200611ad26a84ff630db360c1c7223 Mon Sep 17 00:00:00 2001
From: Joshua Rosenkranz
Date: Mon, 6 May 2024 17:18:47 -0400
Subject: [PATCH] added mlp and attn bias option to flash and paged llama models (#85)

#### Motivation

The `Calico` models currently set the MLP and attention bias to true, but both were hard-coded to false in the flash and paged llama implementations. This change uses the config params introduced in https://github.com/huggingface/transformers/pull/30031 to set those values properly.

#### Modifications

- added `attention_bias` and `mlp_bias` to the config for the Flash and Paged Llama implementations (default is False)
- set the bias in attention and mlp to the config value

#### Result

Models that contain attention and MLP bias should now load properly.

---------

Signed-off-by: Joshua Rosenkranz
Signed-off-by: Joe Runde
Co-authored-by: Joe Runde
---
 .../inference_engine/tgis_native.py                |  5 +++++
 .../models/custom_modeling/flash_llama_modeling.py | 14 +++++++++-----
 .../models/custom_modeling/paged_llama_modeling.py | 14 +++++++++-----
 3 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/server/text_generation_server/inference_engine/tgis_native.py b/server/text_generation_server/inference_engine/tgis_native.py
index c11c6153..86291815 100644
--- a/server/text_generation_server/inference_engine/tgis_native.py
+++ b/server/text_generation_server/inference_engine/tgis_native.py
@@ -101,6 +101,11 @@ def __init__(
             model_class = FlashRWForCausalLM
 
         elif model_type == "llama":
+            # See: https://github.com/ibm-granite/vllm_granite/blob/main/vllm/model_executor/models/llama.py#L353-L354
+            if self._config.tie_word_embeddings:
+                aliases = {
+                    "lm_head.weight": ["model.embed_tokens.weight"]
+                }
             if PAGED_ATTENTION:
                 from text_generation_server.models.custom_modeling.paged_llama_modeling import PagedLlamaForCausalLM
                 model_class = PagedLlamaForCausalLM
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index d56178ad..fe7be06c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -64,6 +64,8 @@ def __init__(
         tie_word_embeddings=False,
         rope_scaling=None,
         rope_theta=10000.0,
+        attention_bias=False,
+        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -85,6 +87,8 @@ def __init__(
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
+        self.attention_bias = attention_bias
+        self.mlp_bias = mlp_bias
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -169,7 +173,7 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
+    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))
 
 
 class FlashLlamaAttention(torch.nn.Module):
@@ -220,13 +224,13 @@ def __init__(
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
-            bias=False,
+            bias=config.attention_bias,
         )
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
-            bias=False,
+            bias=config.attention_bias,
         )
 
     def forward(
@@ -309,13 +313,13 @@ def __init__(self, prefix, config, weights):
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
-            bias=False,
+            bias=config.mlp_bias,
         )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
-            bias=False,
+            bias=config.mlp_bias,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
diff --git a/server/text_generation_server/models/custom_modeling/paged_llama_modeling.py b/server/text_generation_server/models/custom_modeling/paged_llama_modeling.py
index 636c983d..05d9eda7 100644
--- a/server/text_generation_server/models/custom_modeling/paged_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/paged_llama_modeling.py
@@ -64,6 +64,8 @@ def __init__(
         tie_word_embeddings=False,
         rope_scaling=None,
         rope_theta=10000.0,
+        attention_bias=False,
+        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -85,6 +87,8 @@ def __init__(
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
+        self.attention_bias = attention_bias
+        self.mlp_bias = mlp_bias
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -169,7 +173,7 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
+    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))
 
 
 class PagedLlamaAttention(torch.nn.Module):
@@ -207,13 +211,13 @@ def __init__(
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
-            bias=False,
+            bias=config.attention_bias,
         )
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
-            bias=False,
+            bias=config.attention_bias,
         )
 
     def forward(
@@ -280,13 +284,13 @@ def __init__(self, prefix, config, weights):
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
-            bias=False,
+            bias=config.mlp_bias,
         )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
-            bias=False,
+            bias=config.mlp_bias,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
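
For reference, here is a minimal sketch (not part of the patch) of how the new `attention_bias` / `mlp_bias` config fields, which follow the naming in https://github.com/huggingface/transformers/pull/30031, are intended to drive bias loading. `SimpleLlamaConfig` and `load_proj` below are hypothetical stand-ins for the real config class and the `TensorParallelColumnLinear` / `TensorParallelRowLinear` loaders used in the diff above.

```python
# Minimal sketch only -- not part of the patch. SimpleLlamaConfig and
# load_proj are hypothetical stand-ins for the real config class and the
# TensorParallelColumnLinear / TensorParallelRowLinear loaders.
import torch.nn as nn


class SimpleLlamaConfig:
    """Carries the two new flags with the same defaults as the patch (False)."""

    def __init__(self, hidden_size=16, attention_bias=False, mlp_bias=False):
        self.hidden_size = hidden_size
        self.attention_bias = attention_bias
        self.mlp_bias = mlp_bias


def load_proj(config, in_features, out_features, is_attention):
    # Bias is taken from the config instead of being hard-coded to False,
    # mirroring the change made in the patch.
    use_bias = config.attention_bias if is_attention else config.mlp_bias
    return nn.Linear(in_features, out_features, bias=use_bias)


if __name__ == "__main__":
    cfg = SimpleLlamaConfig(attention_bias=True, mlp_bias=True)
    q_proj = load_proj(cfg, cfg.hidden_size, cfg.hidden_size, is_attention=True)
    gate_proj = load_proj(cfg, cfg.hidden_size, 4 * cfg.hidden_size, is_attention=False)
    # Both projections now carry bias parameters because the config enables them.
    print(q_proj.bias is not None, gate_proj.bias is not None)  # True True
```

With both flags left at their defaults, the loaded layers stay bias-free, matching the previous hard-coded behaviour.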