Revert MptConfig to MPTConfig (vllm-project#1668)
megha95 authored Nov 16, 2023
1 parent e3d1abb commit 638cecd
Showing 6 changed files with 260 additions and 26 deletions.
4 changes: 2 additions & 2 deletions vllm/model_executor/model_loader.py
@@ -29,8 +29,8 @@
    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
    "MistralForCausalLM": MistralForCausalLM,
    # transformers's mpt class has lower case
-    "MptForCausalLM": MptForCausalLM,
-    "MPTForCausalLM": MptForCausalLM,
+    "MptForCausalLM": MPTForCausalLM,
+    "MPTForCausalLM": MPTForCausalLM,
    "OPTForCausalLM": OPTForCausalLM,
    "QWenLMHeadModel": QWenLMHeadModel,
    "RWForCausalLM": FalconForCausalLM,
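Both spellings of the architecture name stay in the registry on purpose: transformers names its class MptForCausalLM, while MosaicML MPT checkpoints declare MPTForCausalLM in config.json, and after this commit both keys resolve to vLLM's MPTForCausalLM. A rough sketch of how an architecture string could be resolved through such a registry; the helper name and flow below are illustrative, not vLLM's actual loader code.

    # Illustrative sketch only; resolve_model_cls is a hypothetical helper.
    from typing import Dict, List, Type

    import torch.nn as nn

    _MODEL_REGISTRY: Dict[str, Type[nn.Module]] = {}  # filled as in the diff above

    def resolve_model_cls(architectures: List[str]) -> Type[nn.Module]:
        # "MptForCausalLM" and "MPTForCausalLM" point at the same class, so a
        # checkpoint exported with either spelling loads identically.
        for arch in architectures:
            if arch in _MODEL_REGISTRY:
                return _MODEL_REGISTRY[arch]
        raise ValueError(f"Unsupported architectures: {architectures}")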
4 changes: 2 additions & 2 deletions vllm/model_executor/models/__init__.py
@@ -10,7 +10,7 @@
from vllm.model_executor.models.internlm import InternLMForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.mistral import MistralForCausalLM
-from vllm.model_executor.models.mpt import MptForCausalLM
+from vllm.model_executor.models.mpt import MPTForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel
from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
@@ -29,7 +29,7 @@
    "GPTNeoXForCausalLM",
    "InternLMForCausalLM",
    "LlamaForCausalLM",
-    "MptForCausalLM",
+    "MPTForCausalLM",
    "OPTForCausalLM",
    "QWenLMHeadModel",
    "MistralForCausalLM",
40 changes: 20 additions & 20 deletions vllm/model_executor/models/mpt.py
@@ -5,7 +5,6 @@

import torch
import torch.nn as nn
-from transformers import MptConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
@@ -22,6 +21,7 @@
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs.mpt import MPTConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -39,21 +39,21 @@ def _get_alibi_slopes(
    return slopes


-class MptAttention(nn.Module):
+class MPTAttention(nn.Module):

    def __init__(
        self,
-        config: MptConfig,
+        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
-        self.clip_qkv = config.attn_config.clip_qkv
-        self.qk_ln = config.attn_config.qk_ln
-        self.alibi_bias_max = config.attn_config.alibi_bias_max
-        assert not config.attn_config.prefix_lm
-        assert config.attn_config.alibi
+        self.clip_qkv = config.attn_config["clip_qkv"]
+        self.qk_ln = config.attn_config["qk_ln"]
+        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        assert not config.attn_config["prefix_lm"]
+        assert config.attn_config["alibi"]

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
@@ -113,11 +113,11 @@ def forward(
        return output


-class MptMLP(nn.Module):
+class MPTMLP(nn.Module):

    def __init__(
        self,
-        config: MptConfig,
+        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
@@ -145,19 +145,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


-class MptBlock(nn.Module):
+class MPTBlock(nn.Module):

    def __init__(
        self,
-        config: MptConfig,
+        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MptAttention(config, linear_method)
+        self.attn = MPTAttention(config, linear_method)
        self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MptMLP(config, linear_method)
+        self.ffn = MPTMLP(config, linear_method)

    def forward(
        self,
Expand All @@ -182,11 +182,11 @@ def forward(
return hidden_states


class MptModel(nn.Module):
class MPTModel(nn.Module):

def __init__(
self,
config: MptConfig,
config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
@@ -198,7 +198,7 @@ def __init__(
            config.d_model,
        )
        self.blocks = nn.ModuleList(
-            [MptBlock(config, linear_method) for _ in range(config.n_layers)])
+            [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
@@ -233,19 +233,19 @@ def forward(
        return hidden_states


-class MptForCausalLM(nn.Module):
+class MPTForCausalLM(nn.Module):

    def __init__(
        self,
-        config: MptConfig,
+        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings
        self.linear_method = linear_method

-        self.transformer = MptModel(config, linear_method)
+        self.transformer = MPTModel(config, linear_method)
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)
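Beyond the class renames, the substantive change in this file is that attention settings are read from attn_config by dictionary key rather than attribute access: vLLM's own MPTConfig keeps attn_config as a plain dict, whereas transformers' MptConfig wraps the same settings in a sub-config object. A minimal sketch of the difference, assuming MPTConfig can be constructed with its defaults.

    # Minimal sketch; assumes MPTConfig() is constructible with default values.
    from vllm.transformers_utils.configs.mpt import MPTConfig

    config = MPTConfig()
    # vLLM's MPTConfig stores attn_config as a plain dict, so entries are indexed:
    clip_qkv = config.attn_config["clip_qkv"]
    qk_ln = config.attn_config["qk_ln"]
    # transformers' MptConfig exposes the same settings as attributes instead
    # (config.attn_config.clip_qkv), which is what the reverted code relied on.
    print(clip_qkv, qk_ln)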
4 changes: 2 additions & 2 deletions vllm/transformers_utils/config.py
@@ -1,14 +1,14 @@
from typing import Optional

-from transformers import AutoConfig, MptConfig, PretrainedConfig
+from transformers import AutoConfig, PretrainedConfig

from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import

_CONFIG_REGISTRY = {
    "aquila": AquilaConfig,
    "baichuan": BaiChuanConfig,
    "chatglm": ChatGLMConfig,
-    "mpt": MptConfig,
+    "mpt": MPTConfig,
    "qwen": QWenConfig,
    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
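The registry maps a checkpoint's model_type to a bespoke config class, so with this change "mpt" resolves to vLLM's MPTConfig again. A rough sketch of how a registry like this is typically consulted, with AutoConfig as the fallback; the function name and flow are illustrative, not vLLM's actual get_config implementation.

    # Illustrative sketch only; load_config is a hypothetical helper.
    from transformers import AutoConfig, PretrainedConfig
    from vllm.transformers_utils.configs import MPTConfig

    _CONFIG_REGISTRY = {"mpt": MPTConfig}  # trimmed-down copy of the dict above

    def load_config(model: str) -> PretrainedConfig:
        config_dict, _ = PretrainedConfig.get_config_dict(model)
        model_type = config_dict.get("model_type", "")
        if model_type in _CONFIG_REGISTRY:
            return _CONFIG_REGISTRY[model_type].from_pretrained(model)
        return AutoConfig.from_pretrained(model)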
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
@@ -1,6 +1,7 @@
from vllm.transformers_utils.configs.aquila import AquilaConfig
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
+from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.qwen import QWenConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
@@ -12,6 +13,7 @@
    "AquilaConfig",
    "BaiChuanConfig",
    "ChatGLMConfig",
+    "MPTConfig",
    "QWenConfig",
    "RWConfig",
    "YiConfig",
