Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CodeLlama #854

Merged
merged 2 commits into from
Aug 25, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
):
super().__init__()
self.hidden_size = hidden_size
Expand All @@ -99,6 +100,7 @@ def __init__(
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta

self.qkv_proj = ColumnParallelLinear(
hidden_size,
Expand All @@ -118,6 +120,7 @@ def __init__(
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
base=self.rope_theta,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)

Expand All @@ -143,10 +146,16 @@ class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
try:
rope_theta = config.rope_theta
except AttributeError:
rope_theta = 10000
Yard1 marked this conversation as resolved.
Show resolved Hide resolved
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
Expand Down
Loading