Skip to content

Commit

Permalink
Refactor llama family models (vllm-project#2637)
Browse files Browse the repository at this point in the history
  • Loading branch information
esmeetu authored Feb 13, 2024
1 parent da5d9fb commit a630978
Show file tree
Hide file tree
Showing 17 changed files with 236 additions and 2,720 deletions.
25 changes: 25 additions & 0 deletions vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,31 @@
from vllm._C import ops


class LayerNorm(nn.LayerNorm):

def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__(hidden_size, eps=eps)

def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""normalization."""
if residual is not None:
x = x + residual
residual = x
x = super().forward(x)
if residual is None:
return x
else:
return x, residual


class RMSNorm(nn.Module):
"""Root mean square normalization.
Expand Down
9 changes: 4 additions & 5 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

# Architecture -> (module, class).
_MODELS = {
"AquilaModel": ("aquila", "AquilaForCausalLM"),
"AquilaForCausalLM": ("aquila", "AquilaForCausalLM"), # AquilaChat2
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
Expand All @@ -24,12 +24,12 @@
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("mistral", "MistralForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
# transformers's mpt class has lower case
Expand All @@ -41,7 +41,6 @@
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"YiForCausalLM": ("yi", "YiForCausalLM")
}

# Models not supported by ROCm.
Expand Down
Loading

0 comments on commit a630978

Please sign in to comment.