Migrate MistralForCausalLM to LlamaForCausalLM (vllm-project#2868)
esmeetu authored and jimpang committed Mar 4, 2024
1 parent eaefc80 commit 031fd41
Showing 3 changed files with 6 additions and 379 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/models/__init__.py
@@ -30,7 +30,7 @@
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
-    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
+    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     # transformers's mpt class has lower case
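For context, this registry maps the architecture string reported by a model's Hugging Face config to a (module, class) pair that is imported lazily; after this commit, "MistralForCausalLM" resolves to the Llama implementation. A minimal sketch of that lookup pattern, with load_model_cls as a hypothetical helper name rather than vLLM's exact code:

import importlib
from typing import Dict, Optional, Tuple, Type

# Architecture name (from the HF config) -> (module name, class name).
# "MistralForCausalLM" now points at the Llama implementation, so the
# separate mistral module is no longer needed.
_MODELS: Dict[str, Tuple[str, str]] = {
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
}

def load_model_cls(arch: str) -> Optional[Type]:
    """Hypothetical helper: resolve an architecture string to a model class."""
    if arch not in _MODELS:
        return None
    module_name, cls_name = _MODELS[arch]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)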
6 changes: 5 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -92,6 +92,7 @@ def __init__(
         max_position_embeddings: int = 8192,
         linear_method: Optional[LinearMethodBase] = None,
         bias: bool = False,
+        sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -141,7 +142,8 @@ def __init__(
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    self.scaling,
-                                   num_kv_heads=self.num_kv_heads)
+                                   num_kv_heads=self.num_kv_heads,
+                                   sliding_window=sliding_window)

     def forward(
         self,
@@ -172,6 +174,7 @@ def __init__(
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
+        sliding_window = getattr(config, "sliding_window", None)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -182,6 +185,7 @@
             max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
             bias=getattr(config, "bias", False),
+            sliding_window=sliding_window,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
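The functional effect of sliding_window is that each query token attends only to the most recent window of key positions (Mistral-style sliding-window attention); vLLM's PagedAttention applies this internally when the argument is set. A minimal, illustrative sketch of the equivalent boolean mask in plain PyTorch (the function name and layout here are assumptions for illustration, not vLLM code):

import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # True where attention is allowed: causal (no future keys) and the key
    # lies within the last `window` positions relative to the query.
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]          # key index <= query index
    recent = idx[:, None] - idx[None, :] < window  # key within the window
    return causal & recent

# Example: with window=3, the token at position 5 can attend to 3, 4, 5 only.
print(sliding_window_mask(6, 3).int())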
(Diff for the third changed file not shown.)
