From 63cd94348dbe6304807daa549430ab1adb94053e Mon Sep 17 00:00:00 2001 From: ki6an Date: Mon, 1 Apr 2024 04:41:08 +0000 Subject: [PATCH 1/2] fix attn bias in llama --- vllm/model_executor/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 57857deb9eb86..6062d65e2aff8 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -172,6 +172,7 @@ def __init__( max_position_embeddings = getattr(config, "max_position_embeddings", 8192) sliding_window = getattr(config, "sliding_window", None) + attention_bias = getattr(config, "attention_bias", False) self.self_attn = LlamaAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -181,7 +182,7 @@ def __init__( rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, linear_method=linear_method, - bias=getattr(config, "bias", False), + bias=attention_bias, sliding_window=sliding_window, ) self.mlp = LlamaMLP( From 5a179d0b10412b95a71759b678252974647b5ee0 Mon Sep 17 00:00:00 2001 From: roy Date: Mon, 8 Apr 2024 22:27:25 +0800 Subject: [PATCH 2/2] add bias support and comments --- vllm/model_executor/models/llama.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 6062d65e2aff8..96d229701c32d 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -172,7 +172,10 @@ def __init__( max_position_embeddings = getattr(config, "max_position_embeddings", 8192) sliding_window = getattr(config, "sliding_window", None) - attention_bias = getattr(config, "attention_bias", False) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) self.self_attn = LlamaAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads,