Add support for CodeLlama (vllm-project#854)
Yard1 authored and randxie committed Aug 29, 2023
1 parent 2bdeb00 commit 99662c2
Showing 1 changed file with 6 additions and 0 deletions.
vllm/model_executor/models/llama.py (6 additions, 0 deletions)
@@ -85,6 +85,7 @@ def __init__(
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        rope_theta: float = 10000,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -99,6 +100,7 @@ def __init__(
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
 
         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
@@ -118,6 +120,7 @@ def __init__(
         self.attn = PagedAttentionWithRoPE(self.num_heads,
                                            self.head_dim,
                                            self.scaling,
+                                           base=self.rope_theta,
                                            rotary_dim=self.head_dim,
                                            num_kv_heads=self.num_kv_heads)
 
@@ -143,10 +146,13 @@ class LlamaDecoderLayer(nn.Module):
     def __init__(self, config: LlamaConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
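
The change threads a per-model RoPE base (rope_theta) from the Hugging Face config into the attention layer instead of relying on Llama's hard-coded default of 10000; CodeLlama checkpoints ship a much larger base (commonly 1e6), which is what this commit picks up. Below is a minimal, self-contained sketch (not vLLM code) of the same pattern: read rope_theta from a config with a 10000 fallback and use it as the base of the standard RoPE inverse-frequency formula. DummyConfig, rotary_inv_freq, and the 1e6 value are illustrative assumptions, not taken from this commit.

# Minimal sketch: how a configurable RoPE base changes the rotary-embedding
# frequencies. Names here are illustrative, not part of vLLM or transformers.

class DummyConfig:
    # CodeLlama configs commonly set rope_theta to 1e6 (assumption for
    # illustration); older configs may omit the field entirely.
    rope_theta = 1_000_000.0

def rotary_inv_freq(head_dim: int, base: float):
    # Standard RoPE inverse frequencies: base ** (-2i / head_dim) for each
    # rotary pair i. A larger base stretches the wavelengths, which is what
    # long-context models rely on.
    return [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]

config = DummyConfig()
# Same pattern as the diff: fall back to 10000 when the config
# (e.g. from transformers <= 4.32.0) has no rope_theta attribute.
rope_theta = getattr(config, "rope_theta", 10000)

print(rotary_inv_freq(8, 10000))       # default Llama base
print(rotary_inv_freq(8, rope_theta))  # larger base -> longer wavelengths

In vLLM itself the value ends up as the base argument of PagedAttentionWithRoPE, as shown in the diff above.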
