diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 6202e81fffa7c..1d0353d7d396e 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -175,7 +175,8 @@ def __init__(
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
+            num_kv_heads=getattr(config, "num_key_value_heads",
+                                 config.num_attention_heads),
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,