From 259bc5bdb1993ee49cc52d354ad4a7cc747bcbbb Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 28 Oct 2024 20:38:44 +0000 Subject: [PATCH 1/6] Fix prefix strings for quantized VLMs --- vllm/model_executor/models/blip2.py | 5 +- vllm/model_executor/models/gemma.py | 58 +++++++++++++------ vllm/model_executor/models/internvl.py | 5 +- vllm/model_executor/models/llava.py | 20 +++++-- vllm/model_executor/models/llava_next.py | 10 +++- .../model_executor/models/llava_next_video.py | 10 +++- vllm/model_executor/models/llava_onevision.py | 10 +++- vllm/model_executor/models/minicpmv.py | 34 ++++++++--- vllm/model_executor/models/mllama.py | 1 + vllm/model_executor/models/paligemma.py | 7 ++- vllm/model_executor/models/phi3v.py | 19 ++++-- vllm/model_executor/models/pixtral.py | 5 +- vllm/model_executor/models/ultravox.py | 5 +- vllm/model_executor/models/utils.py | 15 +++++ 14 files changed, 155 insertions(+), 49 deletions(-) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index cd2013e91514d..c3b3cc8a4ddb6 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -507,7 +507,10 @@ def __init__(self, ) self.language_model = init_vllm_registered_model( - config.text_config, cache_config, quant_config) + config.text_config, + cache_config, + quant_config, + prefix="language_model") self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 436bd45d53f35..57b2b43c82f89 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -43,7 +43,8 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) logger = init_logger(__name__) @@ -83,16 +84,23 @@ def __init__( hidden_act: 
Optional[str] = None, hidden_activation: Optional[str] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation) def forward(self, x): @@ -104,15 +112,18 @@ def forward(self, x): class GemmaAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - max_position_embeddings: int = 8192, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int = 8192, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -142,12 +153,14 @@ def __init__(self, self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( @@ -186,6 +199,7 @@ def __init__( config: GemmaConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> 
None: super().__init__() self.hidden_size = config.hidden_size @@ -198,6 +212,7 @@ def __init__( rope_theta=config.rope_theta, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) self.mlp = GemmaMLP( hidden_size=self.hidden_size, @@ -205,6 +220,7 @@ def __init__( hidden_act=config.hidden_act, hidden_activation=getattr(config, "hidden_activation", None), quant_config=quant_config, + prefix=f"{prefix}.mlp", ) self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -259,8 +275,8 @@ def __init__( ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: GemmaDecoderLayer(config, cache_config, quant_config - ), + lambda prefix: GemmaDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -366,6 +382,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ) -> None: super().__init__() @@ -375,7 +392,10 @@ def __init__( self.lora_config = lora_config self.quant_config = quant_config - self.model = GemmaModel(config, cache_config, quant_config) + self.model = GemmaModel(config, + cache_config, + quant_config, + prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 3ae37d9fe5d85..1c1fde5b30983 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -439,7 +439,10 @@ def __init__(self, ) self.language_model = init_vllm_registered_model( - config.text_config, cache_config, quant_config) + config.text_config, + cache_config, + quant_config, + prefix="language_model") self.mlp1 = 
self._init_mlp1(config) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index b005d83c17f90..eda99c029881f 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -210,6 +210,7 @@ def init_vision_tower_for_llava( quant_config: Optional[QuantizationConfig], *, require_post_norm: Optional[bool] = None, + prefix: str = "", ): vision_config = hf_config.vision_config @@ -224,23 +225,26 @@ def init_vision_tower_for_llava( if isinstance(vision_config, CLIPVisionConfig): return CLIPVisionModel( vision_config, - quant_config, + quant_config=quant_config, num_hidden_layers_override=num_hidden_layers, require_post_norm=require_post_norm, + prefix=prefix, ) elif isinstance(vision_config, SiglipVisionConfig): return SiglipVisionModel( vision_config, - quant_config, + quant_config=quant_config, num_hidden_layers_override=num_hidden_layers, require_post_norm=require_post_norm, + prefix=prefix, ) elif isinstance(vision_config, PixtralVisionConfig): return PixtralHFVisionModel( vision_config, - quant_config, + quant_config=quant_config, num_hidden_layers_override=num_hidden_layers, require_post_norm=require_post_norm, + prefix=prefix, ) msg = f"Unsupported vision config: {type(vision_config)}" @@ -274,14 +278,20 @@ def __init__(self, # TODO: Optionally initializes this for supporting embeddings. 
self.vision_tower = init_vision_tower_for_llava( - config, quant_config, require_post_norm=False) + config, + quant_config, + require_post_norm=False, + prefix="vision_tower") self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( - config.text_config, cache_config, quant_config) + config.text_config, + cache_config, + quant_config, + prefix="language_model") self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 2a582deeaa2c9..f85129b206919 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -293,7 +293,10 @@ def __init__(self, # TODO: Optionally initializes this for supporting embeddings. self.vision_tower = init_vision_tower_for_llava( - config, quant_config, require_post_norm=False) + config, + quant_config, + require_post_norm=False, + prefix="vision_tower") self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) self.multi_modal_projector = LlavaMultiModalProjector( @@ -302,7 +305,10 @@ def __init__(self, projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( - config.text_config, cache_config, quant_config) + config.text_config, + cache_config, + quant_config, + prefix="language_model") # The same model class supports both language generation and embedding # because the architecture name is the same diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 43eec43d56643..b8051d5fc6ae2 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -257,14 +257,20 @@ def __init__(self, # 
Initialize the vision tower only up to the required feature layer self.vision_tower = init_vision_tower_for_llava( - config, quant_config, require_post_norm=False) + config, + quant_config, + require_post_norm=False, + prefix="vision_tower") self.vision_resampler = LlavaNextVideoPooler(config) self.multi_modal_projector = LlavaNextMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( - config.text_config, cache_config, quant_config) + config.text_config, + cache_config, + quant_config, + prefix="language_model") self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 47e62409072e5..214a032d65fb8 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -401,10 +401,16 @@ def __init__(self, # Initialize the vision tower only up to the required feature layer self.vision_tower = init_vision_tower_for_llava( - config, quant_config, require_post_norm=False) + config, + quant_config, + require_post_norm=False, + prefix="vision_tower") self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.language_model = init_vllm_registered_model( - config.text_config, cache_config, quant_config) + config.text_config, + cache_config, + quant_config, + prefix="language_model") self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 2ec51dc4647f5..a270282d87bc8 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -394,8 +394,11 @@ def __init__( self.multimodal_config = multimodal_config self.version = 
get_version_by_config(self.config) - self.llm = self.init_llm(config, cache_config, quant_config) - self.vpm = self.init_vision_module(config, quant_config) + self.llm = self.init_llm(config, + cache_config, + quant_config, + prefix="llm") + self.vpm = self.init_vision_module(config, quant_config, prefix="vpm") param_dtype = torch.get_default_dtype() self.vpm.to(dtype=param_dtype) self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else @@ -403,9 +406,11 @@ def __init__( self.embed_dim = self.config.hidden_size self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) self.resampler.to(device="cuda", dtype=param_dtype) + # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix="llm.lm_head") self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -644,6 +649,7 @@ def init_llm( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> nn.Module: raise NotImplementedError @@ -651,6 +657,7 @@ def init_vision_module( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig], + prefix: str = "", ) -> nn.Module: raise NotImplementedError @@ -690,17 +697,20 @@ def init_llm( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> nn.Module: return LLMWrapper(MiniCPMModel(config, cache_config=cache_config, - quant_config=quant_config), + quant_config=quant_config, + prefix=prefix), name="model") def init_vision_module( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig], + prefix: str = "", ) -> nn.Module: # TODO :refactor this vision model try: @@ -819,19 +829,23 @@ def init_llm( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, 
quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> nn.Module: return LLMWrapper(LlamaModel(config, cache_config=cache_config, - quant_config=quant_config), + quant_config=quant_config, + prefix=prefix), name="model") def init_vision_module( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig], + prefix: str = "", ) -> nn.Module: model = Idefics2VisionTransformer(config.vision_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=prefix) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model @@ -935,20 +949,24 @@ def init_llm( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> nn.Module: return LLMWrapper(Qwen2Model(config, cache_config=cache_config, - quant_config=quant_config), + quant_config=quant_config, + prefix=prefix), name="model") def init_vision_module( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig], + prefix: str = "", ) -> nn.Module: model = Idefics2VisionTransformer(config.vision_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=prefix) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 44ef49729c969..27e83b487748c 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -1002,6 +1002,7 @@ def __init__( org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, + prefix=f"{prefix}.lm_head", ) def forward( diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 7a62a098a4525..8e29c6079b994 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -143,14 +143,17 @@ def __init__(self, 
self.multimodal_config = multimodal_config self.vision_tower = SiglipVisionModel(config.vision_config, - quant_config) + quant_config, + prefix="vision_tower") self.multi_modal_projector = PaliGemmaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, projection_dim=config.vision_config.projection_dim) self.quant_config = quant_config self.language_model = GemmaForCausalLM(config.text_config, - cache_config, quant_config) + cache_config, + quant_config, + prefix="language_model") logit_scale = getattr(config, "logit_scale", 1.0) self.language_model.logits_processor.scale *= logit_scale diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 855a9b17585a4..0962d3d3847c9 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -71,7 +71,8 @@ def _init_img_processor(hf_config: PretrainedConfig, - quant_config: Optional[QuantizationConfig]): + quant_config: Optional[QuantizationConfig], + prefix: str = "") -> CLIPVisionModel: clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG layer_idx = hf_config.img_processor.get('layer_idx', -2) @@ -86,6 +87,7 @@ def _init_img_processor(hf_config: PretrainedConfig, clip_config, quant_config, num_hidden_layers_override=num_hidden_layers, + prefix=prefix, ) return img_processor @@ -152,15 +154,18 @@ def get_img_features(self, class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): """Phi3 Image embedding with HD transform.""" - def __init__(self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig]) -> None: + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "") -> None: super().__init__() # n_embed or hidden_size hidden_size = config.n_embd if hasattr( config, 'n_embd') else config.hidden_size - self.img_processor = _init_img_processor(config, quant_config) + self.img_processor = _init_img_processor( + config, quant_config, prefix=f"{prefix}.img_processor") image_dim_out 
= config.img_processor['image_dim_out'] self.num_img_tokens = config.img_processor['num_img_tokens'] @@ -537,11 +542,15 @@ def __init__(self, config.hidden_size, org_num_embeddings=config.vocab_size, quant_config=quant_config, + prefix="model.embed_tokens", ) # TODO: Optionally initializes this for supporting input embeddings. - self.vision_embed_tokens = Phi3HDImageEmbedding(config, quant_config) + self.vision_embed_tokens = Phi3HDImageEmbedding( + config, quant_config, prefix="model.vision_embed_tokens") + # The prefix is empty intentionally because default prefix of + # LlamaForCausalLM is "model" self.language_model = LlamaForCausalLM(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index a9dbb3823743a..6b53bf5660096 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -164,7 +164,10 @@ def __init__(self, # init MistralForCausalLM self.language_model = init_vllm_registered_model( - config.text_config, cache_config, quant_config) + config.text_config, + cache_config, + quant_config, + prefix="language_model") self.vision_encoder = VisionTransformer(self.vision_args) self.vision_language_adapter = VisionLanguageAdapter( diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 5f33b872beecb..f08e4aa355086 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -357,7 +357,10 @@ def __init__(self, )) self.multi_modal_projector = UltravoxProjector(config) self.language_model = init_vllm_registered_model( - config.text_config, cache_config, quant_config) + config.text_config, + cache_config, + quant_config, + prefix="language_model") if config.text_model_id is not None: self.secondary_weights.append( DefaultModelLoader.Source(model_or_path=config.text_model_id, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 
6995f5805c5e1..0aecb5d151a45 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -242,6 +242,7 @@ def init_vllm_registered_model( lora_config: Optional[LoRAConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, scheduler_config: Optional[SchedulerConfig] = None, + prefix: str = "", ) -> nn.Module: """ Helper function to initialize an inner model registered to vLLM, @@ -257,6 +258,7 @@ def init_vllm_registered_model( lora_config=lora_config, multimodal_config=multimodal_config, scheduler_config=scheduler_config, + prefix=prefix, ) @@ -610,3 +612,16 @@ def get_vit_attn_backend() -> _Backend: else: selected_backend = _Backend.XFORMERS return selected_backend + + +def maybe_prefix(prefix: str, name: str) -> str: + """Add a prefix to a name if the prefix is non-empty. + + Args: + prefix: The prefix to add. If empty, no prefix will be added. + name: The name to potentially prefix. + + Returns: + The string "prefix.name" if prefix was non-empty, otherwise just "name". 
+ """ + return name if not prefix else f"{prefix}.{name}" From 61c49b7c2e1e65810f2dc19d30c91aedd6c75818 Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 28 Oct 2024 21:05:41 +0000 Subject: [PATCH 2/6] Add prefix to language models --- vllm/model_executor/models/internlm2.py | 59 +++++++++++++++++-------- vllm/model_executor/models/llama.py | 7 ++- vllm/model_executor/models/opt.py | 34 +++++++++++--- vllm/model_executor/models/qwen2.py | 50 +++++++++++++++------ 4 files changed, 110 insertions(+), 40 deletions(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 9a77e48626ca5..582532ebfe188 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -30,7 +30,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class InternLM2MLP(nn.Module): @@ -41,16 +42,23 @@ def __init__( intermediate_size: int, hidden_act: str, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.w2 = RowParallelLinear( + intermediate_size, + hidden_size, bias=False, - quant_config=quant_config) - self.w2 = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.w2", + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -75,6 +83,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -108,12 +117,14 @@ def __init__( self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=f"{prefix}.wqkv", ) self.wo = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.wo", ) self.rotary_emb = get_rope( @@ -123,12 +134,15 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def split_qkv(self, qkv: torch.Tensor): seq_len = qkv.shape[0] @@ -176,6 +190,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -192,15 +207,18 @@ def __init__( max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.attention", ) self.feed_forward = InternLM2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=f"{prefix}.feed_forward", ) self.attention_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + eps=config.rms_norm_eps, + prefix=f"{prefix}.attention_norm") self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -251,8 +269,8 @@ def __init__( ) self.start_layer, 
self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: InternLMDecoderLayer(config, cache_config, - quant_config), + lambda prefix: InternLMDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( @@ -306,14 +324,19 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = InternLM2Model(config, cache_config, quant_config) + self.model = InternLM2Model(config, + cache_config, + quant_config, + prefix=maybe_prefix(prefix, "model")) self.output = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix(prefix, "output")) if self.config.tie_word_embeddings: self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b0ca1fe006239..98c53bdaae811 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -55,7 +55,8 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class LlamaMLP(nn.Module): @@ -500,6 +501,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ) -> None: super().__init__() @@ -510,7 +512,7 @@ def __init__( cache_config, quant_config, lora_config=lora_config, - prefix="model") + prefix=maybe_prefix(prefix, 
"model")) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -526,6 +528,7 @@ def __init__( if not lora_config else lora_config.lora_vocab_padding_size), quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights( diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 37c3fa919124e..10cca8b56268a 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -43,7 +43,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class OPTLearnedPositionalEmbedding(nn.Embedding): @@ -68,6 +69,7 @@ def __init__( bias: bool = True, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.embed_dim = embed_dim @@ -85,18 +87,21 @@ def __init__( total_num_heads, bias=bias, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) self.out_proj = RowParallelLinear( embed_dim, embed_dim, bias=bias, quant_config=quant_config, + prefix=f"{prefix}.out_proj", ) self.attn = Attention(self.num_heads, self.head_dim, scale=self.scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -118,6 +123,7 @@ def __init__( config: OPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -128,6 +134,7 @@ def __init__( bias=config.enable_bias, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) self.do_layer_norm_before = config.do_layer_norm_before @@ -139,6 +146,7 @@ def __init__( config.ffn_dim, bias=config.enable_bias, 
quant_config=quant_config, + prefix=f"{prefix}.fc1", ) self.activation_fn = get_act_fn(config.activation_function, quant_config, config.ffn_dim) @@ -147,6 +155,7 @@ def __init__( self.embed_dim, bias=config.enable_bias, quant_config=quant_config, + prefix=f"{prefix}.fc2", ) self.final_layer_norm = nn.LayerNorm( self.embed_dim, @@ -214,7 +223,8 @@ def __init__( self.project_out = ReplicatedLinear(config.hidden_size, config.word_embed_proj_dim, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.project_out") else: self.project_out = None @@ -222,7 +232,8 @@ def __init__( self.project_in = ReplicatedLinear(config.word_embed_proj_dim, config.hidden_size, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.project_in") else: self.project_in = None @@ -239,7 +250,8 @@ def __init__( self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: OPTDecoderLayer(config, cache_config, quant_config), + lambda prefix: OPTDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -288,9 +300,13 @@ def __init__( config: OPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() - self.decoder = OPTDecoder(config, cache_config, quant_config) + self.decoder = OPTDecoder(config, + cache_config, + quant_config, + prefix=f"{prefix}.decoder") self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) @@ -335,11 +351,15 @@ def __init__( config: OPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config self.quant_config = quant_config - self.model = OPTModel(config, cache_config, quant_config) + 
self.model = OPTModel(config, + cache_config, + quant_config, + prefix=maybe_prefix(prefix, "model")) if self.config.tie_word_embeddings: self.lm_head = self.model.decoder.embed_tokens else: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 23eb1482ffef1..db1029345a8ac 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -49,7 +49,8 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class Qwen2MLP(nn.Module): @@ -60,16 +61,23 @@ def __init__( intermediate_size: int, hidden_act: str, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -92,7 +100,8 @@ def __init__(self, rope_theta: float = 10000, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[Tuple] = None) -> None: + rope_scaling: Optional[Tuple] = None, + prefix: str = "") -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -122,12 +131,14 @@ def __init__(self, self.total_num_kv_heads, bias=True, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( @@ -142,7 +153,8 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -166,6 +178,7 @@ def __init__( config: Qwen2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -180,12 +193,15 @@ def __init__( rope_theta=rope_theta, cache_config=cache_config, quant_config=quant_config, - rope_scaling=rope_scaling) + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + ) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -241,6 +257,7 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", ) else: self.embed_tokens = PPMissingLayer() @@ -249,7 +266,8 @@ def __init__( config.num_hidden_layers, lambda prefix: Qwen2DecoderLayer(config=config, cache_config=cache_config, - quant_config=quant_config), + 
quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers", ) @@ -393,6 +411,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ) -> None: # TODO (@robertgshaw2): see if this can be moved out if (cache_config.sliding_window is not None @@ -412,14 +431,19 @@ def __init__( self.lora_config = lora_config self.quant_config = quant_config - self.model = Qwen2Model(config, cache_config, quant_config) + self.model = Qwen2Model(config, + cache_config, + quant_config, + prefix=maybe_prefix(prefix, "model")) if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() From a41cff0cae87f5990a46294b20b11439e0f411e6 Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 29 Oct 2024 16:06:54 +0000 Subject: [PATCH 3/6] Fix qwen2-vl --- vllm/model_executor/models/qwen2_vl.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 4e60fe70b25f1..633d66b4af31a 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -938,7 +938,10 @@ def __init__(self, quant_config=None, ) - self.model = Qwen2Model(config, cache_config, quant_config) + self.model = Qwen2Model(config, + cache_config, + quant_config, + prefix="model") if get_pp_group().is_last_rank: if config.tie_word_embeddings: @@ -946,7 +949,8 @@ def __init__(self, else: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix="lm_head") else: self.lm_head = PPMissingLayer() From 
75a4a30eeede0e06884ce2a9752d0909365d56f4 Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 29 Oct 2024 16:23:27 +0000 Subject: [PATCH 4/6] Add optional prefix to model_loader build_model Signed-off-by: mgoin --- vllm/model_executor/model_loader/loader.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 813f58339da37..944e9319b3280 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -147,15 +147,20 @@ def _get_model_initialization_kwargs( return extra_kwargs -def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig, +def build_model(model_class: Type[nn.Module], + hf_config: PretrainedConfig, cache_config: Optional[CacheConfig], - quant_config: Optional[QuantizationConfig], *, + quant_config: Optional[QuantizationConfig], + *, lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], - scheduler_config: Optional[SchedulerConfig]) -> nn.Module: + scheduler_config: Optional[SchedulerConfig], + prefix: Optional[str] = None) -> nn.Module: extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config, multimodal_config, scheduler_config) + if prefix: + extra_kwargs["prefix"] = prefix return model_class(config=hf_config, cache_config=cache_config, From 2b7105f338bedd4f6c91f11163cb1c4588de6d2c Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 29 Oct 2024 17:54:00 +0000 Subject: [PATCH 5/6] Fix internlm2 Signed-off-by: mgoin --- vllm/model_executor/models/internlm2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 582532ebfe188..313d98b649b48 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -217,8 +217,7 @@ def __init__( prefix=f"{prefix}.feed_forward", ) self.attention_norm = 
RMSNorm(config.hidden_size, - eps=config.rms_norm_eps, - prefix=f"{prefix}.attention_norm") + eps=config.rms_norm_eps) self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( From 1e7e2363a8bffc8bf4861c432af18f0a09b510e1 Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 29 Oct 2024 23:15:49 +0000 Subject: [PATCH 6/6] Support quantization of Qwen2VisionTransformer for Qwen2-VL --- vllm/model_executor/models/qwen2_vl.py | 58 ++++++++++++++++---------- 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 633d66b4af31a..1e12c2332b65e 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -126,15 +126,18 @@ def __init__( hidden_features: int = None, act_layer: Type[nn.Module] = QuickGELU, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.fc1 = ColumnParallelLinear(in_features, hidden_features, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.fc1") self.act = act_layer() self.fc2 = RowParallelLinear(hidden_features, in_features, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.fc2") def forward(self, x: torch.Tensor) -> torch.Tensor: x_parallel, _ = self.fc1(x) @@ -196,6 +199,7 @@ def __init__( num_heads: Optional[int] = None, projection_size: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() # Per attention head and per partition values. 
@@ -207,10 +211,12 @@ def __init__( self.qkv = ColumnParallelLinear(input_size=embed_dim, output_size=3 * projection_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.qkv") self.proj = RowParallelLinear(input_size=projection_size, output_size=embed_dim, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.proj") # Detect attention implementation. self.attn_backend: _Backend = get_vit_attn_backend() @@ -310,6 +316,7 @@ def __init__( act_layer: Type[nn.Module] = QuickGELU, norm_layer: Type[nn.Module] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() if norm_layer is None: @@ -321,11 +328,13 @@ def __init__( self.attn = Qwen2VisionAttention(embed_dim=dim, num_heads=num_heads, projection_size=dim, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") self.mlp = Qwen2VisionMLP(dim, mlp_hidden_dim, act_layer=act_layer, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.mlp") def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor) -> torch.Tensor: @@ -374,6 +383,7 @@ def __init__( norm_layer: Type[nn.Module] = None, spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) @@ -384,12 +394,14 @@ def __init__( ColumnParallelLinear(self.hidden_size, self.hidden_size, bias=True, - quant_config=quant_config), + quant_config=quant_config, + prefix=f"{prefix}.mlp.0"), nn.GELU(), RowParallelLinear(self.hidden_size, d_model, bias=True, - quant_config=quant_config), + quant_config=quant_config, + prefix=f"{prefix}.mlp.2"), ]) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -440,6 +452,7 @@ def __init__( vision_config: Qwen2VLVisionConfig, norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> 
None: super().__init__() @@ -467,28 +480,29 @@ def __init__( self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList([ - Qwen2VisionBlock( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config, - ) for _ in range(depth) + Qwen2VisionBlock(dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) ]) self.merger = Qwen2VisionPatchMerger( d_model=hidden_size, context_dim=embed_dim, norm_layer=norm_layer, quant_config=quant_config, + prefix=f"{prefix}.merger", ) @property def dtype(self) -> torch.dtype: - return self.blocks[0].mlp.fc2.weight.dtype + return self.patch_embed.proj.weight.dtype @property def device(self) -> torch.device: - return self.blocks[0].mlp.fc2.weight.device + return self.patch_embed.proj.weight.device def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: pos_ids = [] @@ -932,10 +946,8 @@ def __init__(self, self.visual = Qwen2VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), - - # NOTE: Qwen2-VL vision encoder does not support any - # quantization method now. 
- quant_config=None, + quant_config=quant_config, + prefix="visual", ) self.model = Qwen2Model(config, @@ -1175,7 +1187,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: - if "visual" in name and "qkv.weight" in name: + if "visual" in name and name.endswith("qkv.weight"): visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim head_size = visual_embed_dim // visual_num_heads @@ -1184,7 +1196,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): visual_embed_dim) loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) - elif "visual" in name and "qkv.bias" in name: + elif "visual" in name and name.endswith("qkv.bias"): visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim head_size = visual_embed_dim // visual_num_heads