From 534910b1f25d083d810229039d38fc71637428e3 Mon Sep 17 00:00:00 2001 From: ericperfect Date: Tue, 5 Nov 2024 12:14:25 +0800 Subject: [PATCH 1/8] Support lora of Qwen2VLForConditionalGeneration Signed-off-by: ericperfect --- vllm/model_executor/models/qwen2_vl.py | 31 ++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d801903f8f9fe..c7c7e83805a88 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -41,7 +41,7 @@ from vllm.attention import AttentionMetadata from vllm.attention.selector import _Backend -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import CacheConfig, MultiModalConfig, LoRAConfig from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, @@ -66,7 +66,7 @@ from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.processor import cached_get_processor -from .interfaces import SupportsMultiModal, SupportsPP +from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA from .utils import (PPMissingLayer, get_vit_attn_backend, is_pp_missing_parameter, make_empty_intermediate_tensors_factory) @@ -928,14 +928,37 @@ def input_processor_for_qwen2_vl( "video", get_max_qwen2_vl_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] def __init__(self, config: Qwen2VLConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None) -> None: + super().__init__() assert not cache_config.enable_prefix_caching, \ From ff743794a5c73855436cc843c2339236ad9c608e Mon Sep 17 00:00:00 2001 From: ericperfect Date: Tue, 5 Nov 2024 18:18:36 +0800 Subject: [PATCH 2/8] format code Signed-off-by: ericperfect --- vllm/model_executor/models/qwen2_vl.py | 376 ++++++++++--------------- 1 file changed, 152 insertions(+), 224 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index c7c7e83805a88..bae9a8eb840d1 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" -from functools import partial +from functools import lru_cache, partial from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, Tuple, Type, TypedDict, Union) @@ -41,11 +41,11 @@ from vllm.attention import AttentionMetadata from vllm.attention.selector import _Backend -from vllm.config import CacheConfig, MultiModalConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU @@ -61,25 +61,25 @@ MultiModalInputs) from vllm.multimodal.base import MultiModalData from vllm.multimodal.image import cached_get_image_processor -from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.config import uses_mrope -from vllm.transformers_utils.processor import cached_get_processor +from vllm.transformers_utils.processor import get_processor -from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (PPMissingLayer, get_vit_attn_backend, is_pp_missing_parameter, make_empty_intermediate_tensors_factory) logger = init_logger(__name__) + # === Vision Inputs === # class Qwen2VLImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: + """Shape: `(num_patches, num_channels * patch_size * patch_size)` """ @@ -98,13 +98,13 @@ class Qwen2VLImageEmbeddingInputs(TypedDict): Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, - Qwen2VLImageEmbeddingInputs] +Qwen2VLImageEmbeddingInputs] class Qwen2VLVideoInputs(TypedDict): pixel_values_videos: torch.Tensor - """Shape: - `(num_patches, + """Shape: + `(num_patches, num_channels * temporal_patch_size * patch_size * patch_size)` """ @@ -121,23 +121,20 @@ class Qwen2VLVideoInputs(TypedDict): class Qwen2VisionMLP(nn.Module): def __init__( - self, - in_features: int, - hidden_features: int = None, - act_layer: Type[nn.Module] = QuickGELU, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + in_features: int, + hidden_features: int = None, + act_layer: Type[nn.Module] = QuickGELU, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.fc1 = ColumnParallelLinear(in_features, hidden_features, - quant_config=quant_config, - prefix=f"{prefix}.fc1") + quant_config=quant_config) self.act = act_layer() self.fc2 = RowParallelLinear(hidden_features, in_features, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + quant_config=quant_config) def forward(self, x: torch.Tensor) -> torch.Tensor: x_parallel, _ = self.fc1(x) @@ -194,12 +191,11 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, class Qwen2VisionAttention(nn.Module): def __init__( - self, - embed_dim: Optional[int] = None, - num_heads: Optional[int] = None, - projection_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + embed_dim: Optional[int] = None, + num_heads: Optional[int] = None, + projection_size: Optional[int] = None, + 
quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -211,26 +207,24 @@ def __init__( self.qkv = ColumnParallelLinear(input_size=embed_dim, output_size=3 * projection_size, - quant_config=quant_config, - prefix=f"{prefix}.qkv") + quant_config=quant_config) self.proj = RowParallelLinear(input_size=projection_size, output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj") + quant_config=quant_config) # Detect attention implementation. self.attn_backend: _Backend = get_vit_attn_backend() if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS }: raise RuntimeError( f"Qwen2-VL does not support {self.attn_backend} backend now.") def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor = None, + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor = None, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -282,7 +276,7 @@ def forward( dtype=torch.bool) for i in range(1, len(cu_seqlens)): attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], - cu_seqlens[i - 1]:cu_seqlens[i]] = True + cu_seqlens[i - 1]:cu_seqlens[i]] = True output = F.scaled_dot_product_attention(q, k, v, @@ -309,14 +303,13 @@ def forward( class Qwen2VisionBlock(nn.Module): def __init__( - self, - dim: int, - num_heads: int, - mlp_ratio: float, - act_layer: Type[nn.Module] = QuickGELU, - norm_layer: Type[nn.Module] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + dim: int, + num_heads: int, + mlp_ratio: float, + act_layer: Type[nn.Module] = QuickGELU, + norm_layer: Type[nn.Module] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() if norm_layer is None: @@ -328,13 +321,11 @@ def __init__( self.attn = Qwen2VisionAttention(embed_dim=dim, num_heads=num_heads, projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") + quant_config=quant_config) self.mlp = Qwen2VisionMLP(dim, mlp_hidden_dim, act_layer=act_layer, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + quant_config=quant_config) def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor) -> torch.Tensor: @@ -348,11 +339,11 @@ def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, class Qwen2VisionPatchEmbed(nn.Module): def __init__( - self, - patch_size: int = 14, - temporal_patch_size: int = 2, - in_chans: int = 3, - embed_dim: int = 1152, + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_chans: int = 3, + embed_dim: int = 1152, ) -> None: super().__init__() self.patch_size = patch_size @@ -377,16 +368,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen2VisionPatchMerger(nn.Module): def __init__( - self, - d_model: int, - context_dim: int, - norm_layer: Type[nn.Module] = None, - spatial_merge_size: int = 2, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + d_model: int, + context_dim: int, + norm_layer: Type[nn.Module] = None, + spatial_merge_size: int = 2, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.hidden_size = context_dim * (spatial_merge_size**2) + self.hidden_size = context_dim * (spatial_merge_size ** 2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) @@ 
-394,14 +384,12 @@ def __init__( ColumnParallelLinear(self.hidden_size, self.hidden_size, bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.0"), + quant_config=quant_config), nn.GELU(), RowParallelLinear(self.hidden_size, d_model, bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2"), + quant_config=quant_config), ]) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -422,7 +410,7 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: self.dim = dim self.theta = theta inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -431,9 +419,9 @@ def update_freqs_cache(self, seqlen: int) -> None: if seqlen > self._seq_len_cached: seqlen *= 2 self._seq_len_cached = seqlen - self.inv_freq = 1.0 / (self.theta**(torch.arange( + self.inv_freq = 1.0 / (self.theta ** (torch.arange( 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) - / self.dim)) + / self.dim)) seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) @@ -448,11 +436,10 @@ def forward(self, seqlen: int) -> torch.Tensor: class Qwen2VisionTransformer(nn.Module): def __init__( - self, - vision_config: Qwen2VLVisionConfig, - norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + vision_config: Qwen2VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -480,29 +467,28 @@ def __init__( self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList([ - Qwen2VisionBlock(dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(depth) + Qwen2VisionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + ) for _ in range(depth) ]) self.merger = Qwen2VisionPatchMerger( d_model=hidden_size, context_dim=embed_dim, norm_layer=norm_layer, quant_config=quant_config, - prefix=f"{prefix}.merger", ) @property def dtype(self) -> torch.dtype: - return self.patch_embed.proj.weight.dtype + return self.blocks[0].mlp.fc2.weight.dtype @property def device(self) -> torch.device: - return self.patch_embed.proj.weight.device + return self.blocks[0].mlp.fc2.weight.device def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: pos_ids = [] @@ -530,9 +516,9 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: return rotary_pos_emb def forward( - self, - x: torch.Tensor, - grid_thw: torch.Tensor, + self, + x: torch.Tensor, + grid_thw: torch.Tensor, ) -> torch.Tensor: # patchify x = x.to(device=self.device, dtype=self.dtype) @@ -544,7 +530,7 @@ def forward( # compute cu_seqlens cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32) + dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) # transformers @@ -559,14 +545,13 @@ def forward( # === Vision input helpers === # +cached_get_processor = lru_cache(get_processor) + def mm_input_mapper_for_qwen2_vl( - ctx: InputContext, - data: MultiModalData[object], - data_type_key: str, - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, + ctx: InputContext, + data: 
MultiModalData[object], + data_type_key: str, ) -> MultiModalInputs: """Input mapper for Qwen2-VL.""" if data_type_key == "image" and isinstance(data, dict): @@ -575,19 +560,8 @@ def mm_input_mapper_for_qwen2_vl( "image_grid_thw": data.get("image_grid_thw"), }) model_config = ctx.model_config - # Handle mm processor kwargs; we pass these at creation time - # because preprocess() in transformers doesn't expose them - mm_processor_kwargs = {} - if min_pixels: - mm_processor_kwargs["min_pixels"] = min_pixels - if max_pixels: - mm_processor_kwargs["max_pixels"] = max_pixels - image_processor = cached_get_image_processor( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - **mm_processor_kwargs, - ) + model_config.model, trust_remote_code=model_config.trust_remote_code) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") @@ -618,14 +592,14 @@ def mm_input_mapper_for_qwen2_vl( def _get_vision_info( - image_processor, - height: int, - width: int, - min_pixels: int, - max_pixels: int, - do_resize: bool = True, - data_type_key: str = "image", - mm_count: int = 1, + image_processor, + height: int, + width: int, + min_pixels: int, + max_pixels: int, + do_resize: bool = True, + data_type_key: str = "image", + mm_count: int = 1, ): """Get information (resized height / width and number of vision tokens) of input image / video frame.""" @@ -657,39 +631,28 @@ def _get_vision_info( def _get_max_image_info( - image_processor, - data_type_key: str = "image", - mm_count: int = 1, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, + image_processor, + data_type_key: str = "image", + mm_count: int = 1, ): - # Limit min / max pixels unless they're explicitly provided - if min_pixels is None: - min_pixels = max(image_processor.min_pixels, 28 * 28) - if max_pixels is None: - max_pixels = min(image_processor.max_pixels, 1280 * 28 * 28) - return _get_vision_info( image_processor, height=9999999, width=9999999, - min_pixels=min_pixels, - max_pixels=max_pixels, + + # Limit min / max pixels. 
+ min_pixels=max(image_processor.min_pixels, 28 * 28), + max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28), data_type_key=data_type_key, mm_count=mm_count, ) -def get_max_qwen2_vl_mm_tokens(ctx: InputContext, - data_type_key: str, - *, - min_pixels=None, - max_pixels=None) -> int: +def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int: image_processor = cached_get_image_processor(ctx.model_config.model) max_resized_height, max_resized_width, max_llm_image_tokens = \ _get_max_image_info(image_processor, data_type_key=data_type_key, - mm_count=1, min_pixels=min_pixels, - max_pixels=max_pixels) + mm_count=1) return max_llm_image_tokens @@ -700,20 +663,14 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext, def dummy_data_for_qwen2_vl( - ctx: InputContext, - seq_len: int, - mm_counts: Mapping[str, int], - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None + ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: image_processor = cached_get_image_processor(ctx.model_config.model) num_images = mm_counts["image"] max_resized_height, max_resized_width, max_llm_image_tokens = \ _get_max_image_info(image_processor, data_type_key="image", - mm_count=num_images, min_pixels=min_pixels, - max_pixels=max_pixels) + mm_count=num_images) if seq_len - max_llm_image_tokens - 2 < 0: raise RuntimeError( f"Qwen2-VL cannot process {num_images} images in a prompt, " @@ -724,11 +681,10 @@ def dummy_data_for_qwen2_vl( num_videos = mm_counts["video"] max_resized_height, max_resized_width, max_llm_video_tokens = \ _get_max_image_info(image_processor, data_type_key="video", - mm_count=num_videos, min_pixels=min_pixels, - max_pixels=max_pixels) + mm_count=num_videos) if seq_len - max_llm_video_tokens - 2 < 0: raise RuntimeError( - f"Qwen2-VL cannot process {num_videos} videos in a prompt, " + f"Qwen2-VL cannot process {num_images} videos in a prompt, " "please increase max_model_len or reduce video limit by " "--limit-mm-per-prompt.") @@ -744,18 +700,15 @@ def dummy_data_for_qwen2_vl( dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), color=0) - return DummyData(dummy_seqdata, { - "image": - dummy_image if num_images == 1 else [dummy_image] * num_images - }) + return dummy_seqdata, { + "image": dummy_image if num_images == 1 else [dummy_image] * num_images + } def _get_llm_num_vision_tokens( - mm_inputs: list, - data_type_key: str, - image_processor, - min_pixels: int, - max_pixels: int, + mm_inputs: list, + data_type_key: str, + image_processor, ): """Get number of vision tokens of multimodal inputs. 
@@ -765,13 +718,12 @@ def _get_llm_num_vision_tokens( image = to_numpy_array(mm_inputs[0]) input_data_format = infer_channel_dimension_format(image) height, width = get_image_size(image, channel_dim=input_data_format) - _, _, llm_num_vision_tokens = _get_vision_info( image_processor, height=height, width=width, - min_pixels=min_pixels, - max_pixels=max_pixels, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, do_resize=image_processor.do_resize, data_type_key=data_type_key, mm_count=len(mm_inputs), @@ -781,8 +733,7 @@ def _get_llm_num_vision_tokens( def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, data_type_key: str, image_processor: Any, - prompt_token_ids: List[int], min_pixels: Optional[int], - max_pixels: Optional[int]) -> List[int]: + prompt_token_ids: List[int]) -> List[int]: """ Expand pad tokens for multi-modal inputs (e.g., images or videos). @@ -793,8 +744,6 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, data_type_key (str): The type of the multi-modal input. image_processor (Any): The image processor used to process the inputs. prompt_token_ids (List[int]): The list of token IDs in the prompt. - min_pixels (int): min pixels to used for img processing - max_pixels (int): max pixels to be used for img processing Returns: List[int]: The list of token IDs for the multi-modal inputs. @@ -811,8 +760,6 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, [data] if data_type_key == "image" else data, data_type_key=data_type_key, image_processor=image_processor, - min_pixels=min_pixels, - max_pixels=max_pixels, ) if cnt == 0: end_idx = indices[cnt] @@ -827,13 +774,10 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, def input_processor_for_qwen2_vl( - ctx: InputContext, - inputs: DecoderOnlyInputs, - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, + ctx: InputContext, + inputs: DecoderOnlyInputs, ) -> DecoderOnlyInputs: - multi_modal_data = inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data", None) if multi_modal_data is None: return inputs @@ -842,11 +786,6 @@ def input_processor_for_qwen2_vl( processor = cached_get_processor(ctx.model_config.model) image_processor = processor.image_processor - # Apply processor kwarg overrides for image processor options - min_pixels = min_pixels if min_pixels else image_processor.min_pixels - max_pixels = max_pixels if max_pixels else image_processor.max_pixels - - model_config = ctx.model_config hf_config = ctx.get_hf_config(Qwen2VLConfig) # To avoid redundant processing of vision objects (resize, rescale, etc.), @@ -862,11 +801,14 @@ def input_processor_for_qwen2_vl( # return_tensors="pt") # prompt_token_ids = inputs["input_ids"][0].tolist() - tokenizer = cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code) - - prompt_token_ids = inputs["prompt_token_ids"] + prompt_token_ids = inputs.get("prompt_token_ids", None) + if prompt_token_ids is None: + prompt = inputs["prompt"] + prompt_token_ids = processor.tokenizer( + prompt, + padding=True, + return_tensors=None, + )["input_ids"] # Expand image pad tokens. 
@@ -891,30 +833,20 @@ def input_processor_for_qwen2_vl( else: prompt_token_ids = _expand_pad_tokens(image_inputs, hf_config.image_token_id, - make_batched_images, - "image", + make_batched_images, "image", image_processor, - prompt_token_ids, - min_pixels=min_pixels, - max_pixels=max_pixels) + prompt_token_ids) if video_inputs is not None: prompt_token_ids = _expand_pad_tokens(video_inputs, hf_config.video_token_id, - make_batched_videos, - "video", + make_batched_videos, "video", image_processor, - prompt_token_ids, - min_pixels=min_pixels, - max_pixels=max_pixels) - - prompt = inputs.get("prompt") - if prompt is None: - prompt = tokenizer.decode(prompt_token_ids) + prompt_token_ids) return token_inputs( prompt_token_ids=prompt_token_ids, - prompt=prompt, + prompt=inputs["prompt"], multi_modal_data=multi_modal_data, ) @@ -928,8 +860,8 @@ def input_processor_for_qwen2_vl( "video", get_max_qwen2_vl_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, - SupportsPP): +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -958,7 +890,6 @@ def __init__(self, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None) -> None: - super().__init__() assert not cache_config.enable_prefix_caching, \ @@ -966,18 +897,16 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config - self.visual = Qwen2VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix="visual", + + # NOTE: Qwen2-VL vision encoder does not support any + # quantization method now. + quant_config=None, ) - self.model = Qwen2Model(config, - cache_config, - quant_config, - prefix="model") + self.model = Qwen2Model(config, cache_config, quant_config) if get_pp_group().is_last_rank: if config.tie_word_embeddings: @@ -985,8 +914,7 @@ def __init__(self, else: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config, - prefix="lm_head") + quant_config=quant_config) else: self.lm_head = PPMissingLayer() @@ -998,7 +926,7 @@ def __init__(self, def _validate_and_reshape_mm_tensor(self, mm_input: Union[torch.Tensor, - List[torch.Tensor]], + List[torch.Tensor]], name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): raise ValueError(f"Incorrect type of {name}. 
" @@ -1083,24 +1011,24 @@ def _process_video_input(self, return video_embeds def _merge_multimodal_embeddings( - self, - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - multimodal_embeddings: torch.Tensor, - placeholder_token_id: int, + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + multimodal_embeddings: torch.Tensor, + placeholder_token_id: int, ) -> torch.Tensor: mask = (input_ids == placeholder_token_id) inputs_embeds[mask, :] = multimodal_embeddings return inputs_embeds def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs: object, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Qwen2-VL. @@ -1175,9 +1103,9 @@ def compute_logits(self, hidden_states: torch.Tensor, return logits def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens @@ -1211,7 +1139,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: - if "visual" in name and name.endswith("qkv.weight"): + if "visual" in name and "qkv.weight" in name: visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim head_size = visual_embed_dim // visual_num_heads @@ -1220,7 +1148,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): visual_embed_dim) loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) - elif "visual" in name and name.endswith("qkv.bias"): + elif "visual" in name and "qkv.bias" in name: visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim head_size = visual_embed_dim // visual_num_heads From c1eeb5d863a62657d7c46f323b38105f18883f5d Mon Sep 17 00:00:00 2001 From: ericperfect Date: Tue, 5 Nov 2024 19:37:04 +0800 Subject: [PATCH 3/8] format code Signed-off-by: ericperfect --- vllm/model_executor/models/qwen2_vl.py | 167 ++++++++++++------------- 1 file changed, 83 insertions(+), 84 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index bae9a8eb840d1..f4fd46e0e96c4 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -72,7 +72,6 @@ logger = init_logger(__name__) - # === Vision Inputs === # @@ -98,7 +97,7 @@ class Qwen2VLImageEmbeddingInputs(TypedDict): Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, -Qwen2VLImageEmbeddingInputs] + Qwen2VLImageEmbeddingInputs] class Qwen2VLVideoInputs(TypedDict): @@ -121,11 +120,11 @@ class Qwen2VLVideoInputs(TypedDict): class Qwen2VisionMLP(nn.Module): def __init__( - self, - in_features: int, - hidden_features: int = None, - act_layer: Type[nn.Module] = QuickGELU, - quant_config: Optional[QuantizationConfig] = None, + self, + in_features: int, + hidden_features: int = None, + act_layer: Type[nn.Module] = QuickGELU, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.fc1 = 
ColumnParallelLinear(in_features, @@ -191,11 +190,11 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, class Qwen2VisionAttention(nn.Module): def __init__( - self, - embed_dim: Optional[int] = None, - num_heads: Optional[int] = None, - projection_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, + self, + embed_dim: Optional[int] = None, + num_heads: Optional[int] = None, + projection_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -215,16 +214,16 @@ def __init__( # Detect attention implementation. self.attn_backend: _Backend = get_vit_attn_backend() if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS }: raise RuntimeError( f"Qwen2-VL does not support {self.attn_backend} backend now.") def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor = None, + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor = None, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -276,7 +275,7 @@ def forward( dtype=torch.bool) for i in range(1, len(cu_seqlens)): attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], - cu_seqlens[i - 1]:cu_seqlens[i]] = True + cu_seqlens[i - 1]:cu_seqlens[i]] = True output = F.scaled_dot_product_attention(q, k, v, @@ -303,13 +302,13 @@ def forward( class Qwen2VisionBlock(nn.Module): def __init__( - self, - dim: int, - num_heads: int, - mlp_ratio: float, - act_layer: Type[nn.Module] = QuickGELU, - norm_layer: Type[nn.Module] = None, - quant_config: Optional[QuantizationConfig] = None, + self, + dim: int, + num_heads: int, + mlp_ratio: float, + act_layer: Type[nn.Module] = QuickGELU, + norm_layer: Type[nn.Module] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() if norm_layer is None: @@ -339,11 +338,11 @@ def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, class Qwen2VisionPatchEmbed(nn.Module): def __init__( - self, - patch_size: int = 14, - temporal_patch_size: int = 2, - in_chans: int = 3, - embed_dim: int = 1152, + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_chans: int = 3, + embed_dim: int = 1152, ) -> None: super().__init__() self.patch_size = patch_size @@ -368,15 +367,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen2VisionPatchMerger(nn.Module): def __init__( - self, - d_model: int, - context_dim: int, - norm_layer: Type[nn.Module] = None, - spatial_merge_size: int = 2, - quant_config: Optional[QuantizationConfig] = None, + self, + d_model: int, + context_dim: int, + norm_layer: Type[nn.Module] = None, + spatial_merge_size: int = 2, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.hidden_size = context_dim * (spatial_merge_size ** 2) + self.hidden_size = context_dim * (spatial_merge_size**2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) @@ -410,7 +409,7 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: self.dim = dim self.theta = theta inv_freq = 1.0 / (theta - ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -419,9 +418,9 @@ def 
update_freqs_cache(self, seqlen: int) -> None: if seqlen > self._seq_len_cached: seqlen *= 2 self._seq_len_cached = seqlen - self.inv_freq = 1.0 / (self.theta ** (torch.arange( + self.inv_freq = 1.0 / (self.theta**(torch.arange( 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) - / self.dim)) + / self.dim)) seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) @@ -436,10 +435,10 @@ def forward(self, seqlen: int) -> torch.Tensor: class Qwen2VisionTransformer(nn.Module): def __init__( - self, - vision_config: Qwen2VLVisionConfig, - norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, + self, + vision_config: Qwen2VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -516,9 +515,9 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: return rotary_pos_emb def forward( - self, - x: torch.Tensor, - grid_thw: torch.Tensor, + self, + x: torch.Tensor, + grid_thw: torch.Tensor, ) -> torch.Tensor: # patchify x = x.to(device=self.device, dtype=self.dtype) @@ -530,7 +529,7 @@ def forward( # compute cu_seqlens cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32) + dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) # transformers @@ -549,9 +548,9 @@ def forward( def mm_input_mapper_for_qwen2_vl( - ctx: InputContext, - data: MultiModalData[object], - data_type_key: str, + ctx: InputContext, + data: MultiModalData[object], + data_type_key: str, ) -> MultiModalInputs: """Input mapper for Qwen2-VL.""" if data_type_key == "image" and isinstance(data, dict): @@ -592,14 +591,14 @@ def mm_input_mapper_for_qwen2_vl( def _get_vision_info( - image_processor, - height: int, - width: int, - min_pixels: int, - max_pixels: int, - do_resize: bool = True, - data_type_key: str = "image", - mm_count: int = 1, + image_processor, + height: int, + width: int, + min_pixels: int, + max_pixels: int, + do_resize: bool = True, + data_type_key: str = "image", + mm_count: int = 1, ): """Get information (resized height / width and number of vision tokens) of input image / video frame.""" @@ -631,9 +630,9 @@ def _get_vision_info( def _get_max_image_info( - image_processor, - data_type_key: str = "image", - mm_count: int = 1, + image_processor, + data_type_key: str = "image", + mm_count: int = 1, ): return _get_vision_info( image_processor, @@ -663,7 +662,7 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int: def dummy_data_for_qwen2_vl( - ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] + ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: image_processor = cached_get_image_processor(ctx.model_config.model) @@ -706,9 +705,9 @@ def dummy_data_for_qwen2_vl( def _get_llm_num_vision_tokens( - mm_inputs: list, - data_type_key: str, - image_processor, + mm_inputs: list, + data_type_key: str, + image_processor, ): """Get number of vision tokens of multimodal inputs. 
@@ -774,8 +773,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, def input_processor_for_qwen2_vl( - ctx: InputContext, - inputs: DecoderOnlyInputs, + ctx: InputContext, + inputs: DecoderOnlyInputs, ) -> DecoderOnlyInputs: multi_modal_data = inputs.get("multi_modal_data", None) if multi_modal_data is None: @@ -926,7 +925,7 @@ def __init__(self, def _validate_and_reshape_mm_tensor(self, mm_input: Union[torch.Tensor, - List[torch.Tensor]], + List[torch.Tensor]], name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): raise ValueError(f"Incorrect type of {name}. " @@ -1011,24 +1010,24 @@ def _process_video_input(self, return video_embeds def _merge_multimodal_embeddings( - self, - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - multimodal_embeddings: torch.Tensor, - placeholder_token_id: int, + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + multimodal_embeddings: torch.Tensor, + placeholder_token_id: int, ) -> torch.Tensor: mask = (input_ids == placeholder_token_id) inputs_embeds[mask, :] = multimodal_embeddings return inputs_embeds def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs: object, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Qwen2-VL. @@ -1103,9 +1102,9 @@ def compute_logits(self, hidden_states: torch.Tensor, return logits def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens From 394fa07708de939afd623656b5ccb4adc05e4868 Mon Sep 17 00:00:00 2001 From: ericperfect Date: Wed, 6 Nov 2024 09:31:01 +0800 Subject: [PATCH 4/8] format code Signed-off-by: ericperfect --- vllm/model_executor/models/qwen2_vl.py | 211 +++++++++++++++++-------- 1 file changed, 142 insertions(+), 69 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index f4fd46e0e96c4..c7c7e83805a88 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" -from functools import lru_cache, partial +from functools import partial from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, Tuple, Type, TypedDict, Union) @@ -41,11 +41,11 @@ from vllm.attention import AttentionMetadata from vllm.attention.selector import _Backend -from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig +from vllm.config import CacheConfig, MultiModalConfig, LoRAConfig from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU @@ -61,11 +61,12 @@ MultiModalInputs) from vllm.multimodal.base import MultiModalData from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.config import uses_mrope -from vllm.transformers_utils.processor import get_processor +from vllm.transformers_utils.processor import cached_get_processor -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA from .utils import (PPMissingLayer, get_vit_attn_backend, is_pp_missing_parameter, make_empty_intermediate_tensors_factory) @@ -78,7 +79,7 @@ class Qwen2VLImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: + """Shape: `(num_patches, num_channels * patch_size * patch_size)` """ @@ -102,8 +103,8 @@ class Qwen2VLImageEmbeddingInputs(TypedDict): class Qwen2VLVideoInputs(TypedDict): pixel_values_videos: torch.Tensor - """Shape: - `(num_patches, + """Shape: + `(num_patches, num_channels * temporal_patch_size * patch_size * patch_size)` """ @@ -125,15 +126,18 @@ def __init__( hidden_features: int = None, act_layer: Type[nn.Module] = QuickGELU, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.fc1 = ColumnParallelLinear(in_features, hidden_features, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.fc1") self.act = act_layer() self.fc2 = RowParallelLinear(hidden_features, in_features, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.fc2") def forward(self, x: torch.Tensor) -> torch.Tensor: x_parallel, _ = self.fc1(x) @@ -195,6 +199,7 @@ def __init__( num_heads: Optional[int] = None, projection_size: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() # Per attention head and per partition values. @@ -206,10 +211,12 @@ def __init__( self.qkv = ColumnParallelLinear(input_size=embed_dim, output_size=3 * projection_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.qkv") self.proj = RowParallelLinear(input_size=projection_size, output_size=embed_dim, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.proj") # Detect attention implementation. 
self.attn_backend: _Backend = get_vit_attn_backend() @@ -309,6 +316,7 @@ def __init__( act_layer: Type[nn.Module] = QuickGELU, norm_layer: Type[nn.Module] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() if norm_layer is None: @@ -320,11 +328,13 @@ def __init__( self.attn = Qwen2VisionAttention(embed_dim=dim, num_heads=num_heads, projection_size=dim, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") self.mlp = Qwen2VisionMLP(dim, mlp_hidden_dim, act_layer=act_layer, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.mlp") def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor) -> torch.Tensor: @@ -373,6 +383,7 @@ def __init__( norm_layer: Type[nn.Module] = None, spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) @@ -383,12 +394,14 @@ def __init__( ColumnParallelLinear(self.hidden_size, self.hidden_size, bias=True, - quant_config=quant_config), + quant_config=quant_config, + prefix=f"{prefix}.mlp.0"), nn.GELU(), RowParallelLinear(self.hidden_size, d_model, bias=True, - quant_config=quant_config), + quant_config=quant_config, + prefix=f"{prefix}.mlp.2"), ]) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -439,6 +452,7 @@ def __init__( vision_config: Qwen2VLVisionConfig, norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() @@ -466,28 +480,29 @@ def __init__( self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList([ - Qwen2VisionBlock( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config, - ) for _ in range(depth) + Qwen2VisionBlock(dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) ]) self.merger = Qwen2VisionPatchMerger( d_model=hidden_size, context_dim=embed_dim, norm_layer=norm_layer, quant_config=quant_config, + prefix=f"{prefix}.merger", ) @property def dtype(self) -> torch.dtype: - return self.blocks[0].mlp.fc2.weight.dtype + return self.patch_embed.proj.weight.dtype @property def device(self) -> torch.device: - return self.blocks[0].mlp.fc2.weight.device + return self.patch_embed.proj.weight.device def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: pos_ids = [] @@ -544,13 +559,14 @@ def forward( # === Vision input helpers === # -cached_get_processor = lru_cache(get_processor) - def mm_input_mapper_for_qwen2_vl( ctx: InputContext, data: MultiModalData[object], data_type_key: str, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, ) -> MultiModalInputs: """Input mapper for Qwen2-VL.""" if data_type_key == "image" and isinstance(data, dict): @@ -559,8 +575,19 @@ def mm_input_mapper_for_qwen2_vl( "image_grid_thw": data.get("image_grid_thw"), }) model_config = ctx.model_config + # Handle mm processor kwargs; we pass these at creation time + # because preprocess() in transformers doesn't expose them + mm_processor_kwargs = {} + if min_pixels: + mm_processor_kwargs["min_pixels"] = min_pixels + if max_pixels: + mm_processor_kwargs["max_pixels"] = max_pixels + image_processor = cached_get_image_processor( - model_config.model, 
trust_remote_code=model_config.trust_remote_code) + model_config.model, + trust_remote_code=model_config.trust_remote_code, + **mm_processor_kwargs, + ) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") @@ -633,25 +660,36 @@ def _get_max_image_info( image_processor, data_type_key: str = "image", mm_count: int = 1, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, ): + # Limit min / max pixels unless they're explicitly provided + if min_pixels is None: + min_pixels = max(image_processor.min_pixels, 28 * 28) + if max_pixels is None: + max_pixels = min(image_processor.max_pixels, 1280 * 28 * 28) + return _get_vision_info( image_processor, height=9999999, width=9999999, - - # Limit min / max pixels. - min_pixels=max(image_processor.min_pixels, 28 * 28), - max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28), + min_pixels=min_pixels, + max_pixels=max_pixels, data_type_key=data_type_key, mm_count=mm_count, ) -def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int: +def get_max_qwen2_vl_mm_tokens(ctx: InputContext, + data_type_key: str, + *, + min_pixels=None, + max_pixels=None) -> int: image_processor = cached_get_image_processor(ctx.model_config.model) max_resized_height, max_resized_width, max_llm_image_tokens = \ _get_max_image_info(image_processor, data_type_key=data_type_key, - mm_count=1) + mm_count=1, min_pixels=min_pixels, + max_pixels=max_pixels) return max_llm_image_tokens @@ -662,14 +700,20 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int: def dummy_data_for_qwen2_vl( - ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: image_processor = cached_get_image_processor(ctx.model_config.model) num_images = mm_counts["image"] max_resized_height, max_resized_width, max_llm_image_tokens = \ _get_max_image_info(image_processor, data_type_key="image", - mm_count=num_images) + mm_count=num_images, min_pixels=min_pixels, + max_pixels=max_pixels) if seq_len - max_llm_image_tokens - 2 < 0: raise RuntimeError( f"Qwen2-VL cannot process {num_images} images in a prompt, " @@ -680,10 +724,11 @@ def dummy_data_for_qwen2_vl( num_videos = mm_counts["video"] max_resized_height, max_resized_width, max_llm_video_tokens = \ _get_max_image_info(image_processor, data_type_key="video", - mm_count=num_videos) + mm_count=num_videos, min_pixels=min_pixels, + max_pixels=max_pixels) if seq_len - max_llm_video_tokens - 2 < 0: raise RuntimeError( - f"Qwen2-VL cannot process {num_images} videos in a prompt, " + f"Qwen2-VL cannot process {num_videos} videos in a prompt, " "please increase max_model_len or reduce video limit by " "--limit-mm-per-prompt.") @@ -699,15 +744,18 @@ def dummy_data_for_qwen2_vl( dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), color=0) - return dummy_seqdata, { - "image": dummy_image if num_images == 1 else [dummy_image] * num_images - } + return DummyData(dummy_seqdata, { + "image": + dummy_image if num_images == 1 else [dummy_image] * num_images + }) def _get_llm_num_vision_tokens( mm_inputs: list, data_type_key: str, image_processor, + min_pixels: int, + max_pixels: int, ): """Get number of vision tokens of multimodal inputs. 
@@ -717,12 +765,13 @@ def _get_llm_num_vision_tokens( image = to_numpy_array(mm_inputs[0]) input_data_format = infer_channel_dimension_format(image) height, width = get_image_size(image, channel_dim=input_data_format) + _, _, llm_num_vision_tokens = _get_vision_info( image_processor, height=height, width=width, - min_pixels=image_processor.min_pixels, - max_pixels=image_processor.max_pixels, + min_pixels=min_pixels, + max_pixels=max_pixels, do_resize=image_processor.do_resize, data_type_key=data_type_key, mm_count=len(mm_inputs), @@ -732,7 +781,8 @@ def _get_llm_num_vision_tokens( def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, data_type_key: str, image_processor: Any, - prompt_token_ids: List[int]) -> List[int]: + prompt_token_ids: List[int], min_pixels: Optional[int], + max_pixels: Optional[int]) -> List[int]: """ Expand pad tokens for multi-modal inputs (e.g., images or videos). @@ -743,6 +793,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, data_type_key (str): The type of the multi-modal input. image_processor (Any): The image processor used to process the inputs. prompt_token_ids (List[int]): The list of token IDs in the prompt. + min_pixels (int): min pixels to used for img processing + max_pixels (int): max pixels to be used for img processing Returns: List[int]: The list of token IDs for the multi-modal inputs. @@ -759,6 +811,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, [data] if data_type_key == "image" else data, data_type_key=data_type_key, image_processor=image_processor, + min_pixels=min_pixels, + max_pixels=max_pixels, ) if cnt == 0: end_idx = indices[cnt] @@ -775,8 +829,11 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, def input_processor_for_qwen2_vl( ctx: InputContext, inputs: DecoderOnlyInputs, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, ) -> DecoderOnlyInputs: - multi_modal_data = inputs.get("multi_modal_data", None) + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None: return inputs @@ -785,6 +842,11 @@ def input_processor_for_qwen2_vl( processor = cached_get_processor(ctx.model_config.model) image_processor = processor.image_processor + # Apply processor kwarg overrides for image processor options + min_pixels = min_pixels if min_pixels else image_processor.min_pixels + max_pixels = max_pixels if max_pixels else image_processor.max_pixels + + model_config = ctx.model_config hf_config = ctx.get_hf_config(Qwen2VLConfig) # To avoid redundant processing of vision objects (resize, rescale, etc.), @@ -800,14 +862,11 @@ def input_processor_for_qwen2_vl( # return_tensors="pt") # prompt_token_ids = inputs["input_ids"][0].tolist() - prompt_token_ids = inputs.get("prompt_token_ids", None) - if prompt_token_ids is None: - prompt = inputs["prompt"] - prompt_token_ids = processor.tokenizer( - prompt, - padding=True, - return_tensors=None, - )["input_ids"] + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) + + prompt_token_ids = inputs["prompt_token_ids"] # Expand image pad tokens. 
@@ -832,20 +891,30 @@ def input_processor_for_qwen2_vl( else: prompt_token_ids = _expand_pad_tokens(image_inputs, hf_config.image_token_id, - make_batched_images, "image", + make_batched_images, + "image", image_processor, - prompt_token_ids) + prompt_token_ids, + min_pixels=min_pixels, + max_pixels=max_pixels) if video_inputs is not None: prompt_token_ids = _expand_pad_tokens(video_inputs, hf_config.video_token_id, - make_batched_videos, "video", + make_batched_videos, + "video", image_processor, - prompt_token_ids) + prompt_token_ids, + min_pixels=min_pixels, + max_pixels=max_pixels) + + prompt = inputs.get("prompt") + if prompt is None: + prompt = tokenizer.decode(prompt_token_ids) return token_inputs( prompt_token_ids=prompt_token_ids, - prompt=inputs["prompt"], + prompt=prompt, multi_modal_data=multi_modal_data, ) @@ -859,8 +928,8 @@ def input_processor_for_qwen2_vl( "video", get_max_qwen2_vl_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP): +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, + SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -889,6 +958,7 @@ def __init__(self, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None) -> None: + super().__init__() assert not cache_config.enable_prefix_caching, \ @@ -896,16 +966,18 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config + self.visual = Qwen2VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), - - # NOTE: Qwen2-VL vision encoder does not support any - # quantization method now. 
- quant_config=None, + quant_config=quant_config, + prefix="visual", ) - self.model = Qwen2Model(config, cache_config, quant_config) + self.model = Qwen2Model(config, + cache_config, + quant_config, + prefix="model") if get_pp_group().is_last_rank: if config.tie_word_embeddings: @@ -913,7 +985,8 @@ def __init__(self, else: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix="lm_head") else: self.lm_head = PPMissingLayer() @@ -1138,7 +1211,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: - if "visual" in name and "qkv.weight" in name: + if "visual" in name and name.endswith("qkv.weight"): visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim head_size = visual_embed_dim // visual_num_heads @@ -1147,7 +1220,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): visual_embed_dim) loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) - elif "visual" in name and "qkv.bias" in name: + elif "visual" in name and name.endswith("qkv.bias"): visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim head_size = visual_embed_dim // visual_num_heads From 3e15502992acfb6a67af743534f0b644965ddbbe Mon Sep 17 00:00:00 2001 From: ericperfect Date: Wed, 6 Nov 2024 09:31:26 +0800 Subject: [PATCH 5/8] format code Signed-off-by: ericperfect --- vllm/model_executor/models/qwen2_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index c7c7e83805a88..2e58f91e96a98 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -928,8 +928,8 @@ def input_processor_for_qwen2_vl( "video", get_max_qwen2_vl_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, - SupportsPP): +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", From 2c6f7fc040b494fda60d84828d2179f4ed2111be Mon Sep 17 00:00:00 2001 From: ericperfect Date: Wed, 6 Nov 2024 09:41:56 +0800 Subject: [PATCH 6/8] format code Signed-off-by: ericperfect --- vllm/model_executor/models/qwen2_vl.py | 767 +++++++++++++++---------- 1 file changed, 467 insertions(+), 300 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 2e58f91e96a98..8597c970f46ba 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,43 +22,72 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" + from functools import partial -from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, - Tuple, Type, TypedDict, Union) +from typing import ( + Any, + Callable, + Iterable, + List, + Literal, + Mapping, + Optional, + Tuple, + Type, + TypedDict, + Union, +) import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from PIL import Image -from transformers.image_utils import (get_image_size, - infer_channel_dimension_format, - to_numpy_array) +from transformers.image_utils import ( + get_image_size, + infer_channel_dimension_format, + to_numpy_array, +) from transformers.models.qwen2_vl.configuration_qwen2_vl import ( - Qwen2VLConfig, Qwen2VLVisionConfig) + Qwen2VLConfig, + Qwen2VLVisionConfig, +) from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( - make_batched_images, make_batched_videos, smart_resize) + make_batched_images, + make_batched_videos, + smart_resize, +) from vllm.attention import AttentionMetadata from vllm.attention.selector import _Backend from vllm.config import CacheConfig, MultiModalConfig, LoRAConfig from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.inputs import ( + INPUT_REGISTRY, + DecoderOnlyInputs, + DummyData, + InputContext, + token_inputs, +) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2Model -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, - MultiModalInputs) +from vllm.multimodal import ( + MULTIMODAL_REGISTRY, + MultiModalDataDict, + MultiModalInputs, +) from vllm.multimodal.base import MultiModalData from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer @@ -67,9 +96,12 @@ from vllm.transformers_utils.processor import cached_get_processor from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA -from .utils import (PPMissingLayer, get_vit_attn_backend, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory) +from .utils import ( + PPMissingLayer, + get_vit_attn_backend, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, +) logger = init_logger(__name__) @@ -97,8 +129,7 @@ class Qwen2VLImageEmbeddingInputs(TypedDict): """ -Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, - Qwen2VLImageEmbeddingInputs] +Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, Qwen2VLImageEmbeddingInputs] class Qwen2VLVideoInputs(TypedDict): @@ -119,7 +150,6 @@ class Qwen2VLVideoInputs(TypedDict): class Qwen2VisionMLP(nn.Module): - def __init__( self, in_features: int, @@ -129,15 +159,19 @@ def __init__( prefix: str = "", ): 
super().__init__() - self.fc1 = ColumnParallelLinear(in_features, - hidden_features, - quant_config=quant_config, - prefix=f"{prefix}.fc1") + self.fc1 = ColumnParallelLinear( + in_features, + hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) self.act = act_layer() - self.fc2 = RowParallelLinear(hidden_features, - in_features, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + self.fc2 = RowParallelLinear( + hidden_features, + in_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x_parallel, _ = self.fc1(x) @@ -152,15 +186,17 @@ def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), - "... d two -> ... (d two)", - two=2) + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) -def apply_rotary_emb_torch(x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False) -> torch.Tensor: +def apply_rotary_emb_torch( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False, +) -> torch.Tensor: """ x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) @@ -169,21 +205,25 @@ def apply_rotary_emb_torch(x: torch.Tensor, assert ro_dim <= x.shape[-1] cos = repeat( cos, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)", + ) sin = repeat( sin, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)", + ) return torch.cat( [ - x[..., :ro_dim] * cos + - rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], ], dim=-1, ) -def apply_rotary_pos_emb_vision(t: torch.Tensor, - freqs: torch.Tensor) -> torch.Tensor: +def apply_rotary_pos_emb_vision( + t: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: t_ = t.float() cos = freqs.cos() sin = freqs.sin() @@ -192,7 +232,6 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, class Qwen2VisionAttention(nn.Module): - def __init__( self, embed_dim: Optional[int] = None, @@ -205,26 +244,35 @@ def __init__( # Per attention head and per partition values. world_size = parallel_state.get_tensor_model_parallel_world_size() self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) + projection_size, num_heads + ) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, world_size) + num_heads, world_size + ) - self.qkv = ColumnParallelLinear(input_size=embed_dim, - output_size=3 * projection_size, - quant_config=quant_config, - prefix=f"{prefix}.qkv") - self.proj = RowParallelLinear(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj") + self.qkv = ColumnParallelLinear( + input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) # Detect attention implementation. 
self.attn_backend: _Backend = get_vit_attn_backend() if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, }: raise RuntimeError( - f"Qwen2-VL does not support {self.attn_backend} backend now.") + f"Qwen2-VL does not support {self.attn_backend} backend now." + ) def forward( self, @@ -261,53 +309,58 @@ def forward( q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0, - causal=False) - - context_layer = rearrange(output, - "(b s) ... -> b s ...", - b=batch_size) + output = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0, + causal=False, + ) + + context_layer = rearrange( + output, "(b s) ... -> b s ...", b=batch_size + ) elif self.attn_backend == _Backend.TORCH_SDPA: seq_length = q.size(1) q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]] - attention_mask = torch.zeros([1, seq_length, seq_length], - device=q.device, - dtype=torch.bool) + attention_mask = torch.zeros( + [1, seq_length, seq_length], device=q.device, dtype=torch.bool + ) for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], - cu_seqlens[i - 1]:cu_seqlens[i]] = True - output = F.scaled_dot_product_attention(q, - k, - v, - attention_mask, - dropout_p=0.0) + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = True + output = F.scaled_dot_product_attention( + q, k, v, attention_mask, dropout_p=0.0 + ) context_layer = rearrange(output, "b h s d -> b s h d ") elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None) + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None + ) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() output, _ = self.proj(context_layer) return output class Qwen2VisionBlock(nn.Module): - def __init__( self, dim: int, @@ -325,28 +378,35 @@ def __init__( self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.attn = Qwen2VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - self.mlp = Qwen2VisionMLP(dim, - mlp_hidden_dim, - act_layer=act_layer, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - - def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor) -> torch.Tensor: - x = x + self.attn(self.norm1(x), - cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb) + self.attn = Qwen2VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.mlp = Qwen2VisionMLP( + dim, + mlp_hidden_dim, + act_layer=act_layer, + 
quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + ) x = x + self.mlp(self.norm2(x)) return x class Qwen2VisionPatchEmbed(nn.Module): - def __init__( self, patch_size: int = 14, @@ -360,22 +420,24 @@ def __init__( self.embed_dim = embed_dim kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d(in_chans, - embed_dim, - kernel_size=kernel_size, - stride=kernel_size, - bias=False) + self.proj = nn.Conv3d( + in_chans, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, - self.patch_size) + x = x.view( + L, -1, self.temporal_patch_size, self.patch_size, self.patch_size + ) x = self.proj(x).view(L, self.embed_dim) return x class Qwen2VisionPatchMerger(nn.Module): - def __init__( self, d_model: int, @@ -390,19 +452,25 @@ def __init__( if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) - self.mlp = nn.ModuleList([ - ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.0"), - nn.GELU(), - RowParallelLinear(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2"), - ]) + self.mlp = nn.ModuleList( + [ + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0", + ), + nn.GELU(), + RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2", + ), + ] + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.ln_q(x) @@ -416,13 +484,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen2VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta - inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + inv_freq = 1.0 / ( + theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -431,12 +499,22 @@ def update_freqs_cache(self, seqlen: int) -> None: if seqlen > self._seq_len_cached: seqlen *= 2 self._seq_len_cached = seqlen - self.inv_freq = 1.0 / (self.theta**(torch.arange( - 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) - / self.dim)) - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + self.inv_freq = 1.0 / ( + self.theta + ** ( + torch.arange( + 0, + self.dim, + 2, + dtype=torch.float, + device=self.inv_freq.device, + ) + / self.dim + ) + ) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) self._freqs_cached = freqs @@ -446,7 +524,6 @@ def forward(self, seqlen: int) -> torch.Tensor: class Qwen2VisionTransformer(nn.Module): - def __init__( self, vision_config: Qwen2VLVisionConfig, @@ -479,15 +556,19 @@ def __init__( head_dim = embed_dim // num_heads self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList([ - Qwen2VisionBlock(dim=embed_dim, - num_heads=num_heads, - 
mlp_ratio=mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(depth) - ]) + self.blocks = nn.ModuleList( + [ + Qwen2VisionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + ) + for layer_idx in range(depth) + ] + ) self.merger = Qwen2VisionPatchMerger( d_model=hidden_size, context_dim=embed_dim, @@ -509,20 +590,29 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() + hpos_ids = ( + hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + wpos_ids = ( + wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) + ) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) @@ -542,9 +632,9 @@ def forward( rotary_pos_emb = self.rot_pos_emb(grid_thw) # compute cu_seqlens - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32) + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) # transformers @@ -570,10 +660,12 @@ def mm_input_mapper_for_qwen2_vl( ) -> MultiModalInputs: """Input mapper for Qwen2-VL.""" if data_type_key == "image" and isinstance(data, dict): - return MultiModalInputs({ - "image_embeds": data.get("image_embeds"), - "image_grid_thw": data.get("image_grid_thw"), - }) + return MultiModalInputs( + { + "image_embeds": data.get("image_embeds"), + "image_grid_thw": data.get("image_grid_thw"), + } + ) model_config = ctx.model_config # Handle mm processor kwargs; we pass these at creation time # because preprocess() in transformers doesn't expose them @@ -589,8 +681,10 @@ def mm_input_mapper_for_qwen2_vl( **mm_processor_kwargs, ) if image_processor is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the image object") + raise RuntimeError( + "No HuggingFace processor is available " + "to process the image object" + ) images = None videos = None @@ -601,9 +695,9 @@ def mm_input_mapper_for_qwen2_vl( videos = data try: - batch_data = image_processor \ - .preprocess(images=images, videos=videos, return_tensors="pt") \ - .data + batch_data = image_processor.preprocess( + images=images, videos=videos, return_tensors="pt" + ).data except Exception: logger.error("Failed to process image (%s)", data) raise @@ -611,10 +705,12 @@ def mm_input_mapper_for_qwen2_vl( return MultiModalInputs(batch_data) -image_input_mapper_for_qwen2_vl = 
partial(mm_input_mapper_for_qwen2_vl, - data_type_key="image") -video_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, - data_type_key="video") +image_input_mapper_for_qwen2_vl = partial( + mm_input_mapper_for_qwen2_vl, data_type_key="image" +) +video_input_mapper_for_qwen2_vl = partial( + mm_input_mapper_for_qwen2_vl, data_type_key="video" +) def _get_vision_info( @@ -650,8 +746,11 @@ def _get_vision_info( grid_h = resized_height // image_processor.patch_size grid_w = resized_width // image_processor.patch_size vision_tokens = grid_t * grid_h * grid_w - llm_num_vision_tokens = (vision_tokens // image_processor.merge_size // - image_processor.merge_size) + llm_num_vision_tokens = ( + vision_tokens + // image_processor.merge_size + // image_processor.merge_size + ) return resized_height, resized_width, llm_num_vision_tokens @@ -680,23 +779,28 @@ def _get_max_image_info( ) -def get_max_qwen2_vl_mm_tokens(ctx: InputContext, - data_type_key: str, - *, - min_pixels=None, - max_pixels=None) -> int: +def get_max_qwen2_vl_mm_tokens( + ctx: InputContext, data_type_key: str, *, min_pixels=None, max_pixels=None +) -> int: image_processor = cached_get_image_processor(ctx.model_config.model) - max_resized_height, max_resized_width, max_llm_image_tokens = \ - _get_max_image_info(image_processor, data_type_key=data_type_key, - mm_count=1, min_pixels=min_pixels, - max_pixels=max_pixels) + max_resized_height, max_resized_width, max_llm_image_tokens = ( + _get_max_image_info( + image_processor, + data_type_key=data_type_key, + mm_count=1, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + ) return max_llm_image_tokens -get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens, - data_type_key="image") -get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens, - data_type_key="video") +get_max_qwen2_vl_image_tokens = partial( + get_max_qwen2_vl_mm_tokens, data_type_key="image" +) +get_max_qwen2_vl_video_tokens = partial( + get_max_qwen2_vl_mm_tokens, data_type_key="video" +) def dummy_data_for_qwen2_vl( @@ -705,32 +809,44 @@ def dummy_data_for_qwen2_vl( mm_counts: Mapping[str, int], *, min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None + max_pixels: Optional[int] = None, ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: image_processor = cached_get_image_processor(ctx.model_config.model) num_images = mm_counts["image"] - max_resized_height, max_resized_width, max_llm_image_tokens = \ - _get_max_image_info(image_processor, data_type_key="image", - mm_count=num_images, min_pixels=min_pixels, - max_pixels=max_pixels) + max_resized_height, max_resized_width, max_llm_image_tokens = ( + _get_max_image_info( + image_processor, + data_type_key="image", + mm_count=num_images, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + ) if seq_len - max_llm_image_tokens - 2 < 0: raise RuntimeError( f"Qwen2-VL cannot process {num_images} images in a prompt, " "please increase max_model_len or reduce image limit by " - "--limit-mm-per-prompt.") + "--limit-mm-per-prompt." + ) # Check video counts. 
num_videos = mm_counts["video"] - max_resized_height, max_resized_width, max_llm_video_tokens = \ - _get_max_image_info(image_processor, data_type_key="video", - mm_count=num_videos, min_pixels=min_pixels, - max_pixels=max_pixels) + max_resized_height, max_resized_width, max_llm_video_tokens = ( + _get_max_image_info( + image_processor, + data_type_key="video", + mm_count=num_videos, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + ) if seq_len - max_llm_video_tokens - 2 < 0: raise RuntimeError( f"Qwen2-VL cannot process {num_videos} videos in a prompt, " "please increase max_model_len or reduce video limit by " - "--limit-mm-per-prompt.") + "--limit-mm-per-prompt." + ) hf_config = ctx.get_hf_config(Qwen2VLConfig) @@ -741,13 +857,18 @@ def dummy_data_for_qwen2_vl( (0, seq_len - max_llm_image_tokens - 2), ) - dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), - color=0) + dummy_image = Image.new( + "RGB", (max_resized_width, max_resized_height), color=0 + ) - return DummyData(dummy_seqdata, { - "image": - dummy_image if num_images == 1 else [dummy_image] * num_images - }) + return DummyData( + dummy_seqdata, + { + "image": dummy_image + if num_images == 1 + else [dummy_image] * num_images + }, + ) def _get_llm_num_vision_tokens( @@ -779,10 +900,16 @@ def _get_llm_num_vision_tokens( return llm_num_vision_tokens -def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, - data_type_key: str, image_processor: Any, - prompt_token_ids: List[int], min_pixels: Optional[int], - max_pixels: Optional[int]) -> List[int]: +def _expand_pad_tokens( + inputs: list, + token_id: int, + make_batched_fn: Callable, + data_type_key: str, + image_processor: Any, + prompt_token_ids: List[int], + min_pixels: Optional[int], + max_pixels: Optional[int], +) -> List[int]: """ Expand pad tokens for multi-modal inputs (e.g., images or videos). 
@@ -818,11 +945,12 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, end_idx = indices[cnt] non_data_tokens = prompt_token_ids[:end_idx] else: - non_data_tokens = prompt_token_ids[indices[cnt - 1] + - 1:indices[cnt]] + non_data_tokens = prompt_token_ids[ + indices[cnt - 1] + 1 : indices[cnt] + ] prompt_token_ids_with_data.extend(non_data_tokens) prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens)) - prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:]) + prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1 :]) return prompt_token_ids_with_data @@ -863,8 +991,8 @@ def input_processor_for_qwen2_vl( # prompt_token_ids = inputs["input_ids"][0].tolist() tokenizer = cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code) + model_config.tokenizer, trust_remote_code=model_config.trust_remote_code + ) prompt_token_ids = inputs["prompt_token_ids"] @@ -874,39 +1002,43 @@ def input_processor_for_qwen2_vl( if isinstance(image_inputs, dict): prompt_token_ids_with_image = [] image_indices = [ - idx for idx, token in enumerate(prompt_token_ids) + idx + for idx, token in enumerate(prompt_token_ids) if token == hf_config.image_token_id ] image_cnt = len(image_indices) - embed_dim = image_inputs.get('image_embeds').size(0) + embed_dim = image_inputs.get("image_embeds").size(0) assert embed_dim % image_cnt == 0 num_pad_tokens = embed_dim // image_cnt for idx, token in enumerate(prompt_token_ids): if idx in image_indices: - prompt_token_ids_with_image.extend([token] * - num_pad_tokens) + prompt_token_ids_with_image.extend([token] * num_pad_tokens) else: prompt_token_ids_with_image.append(token) prompt_token_ids = prompt_token_ids_with_image else: - prompt_token_ids = _expand_pad_tokens(image_inputs, - hf_config.image_token_id, - make_batched_images, - "image", - image_processor, - prompt_token_ids, - min_pixels=min_pixels, - max_pixels=max_pixels) + prompt_token_ids = _expand_pad_tokens( + image_inputs, + hf_config.image_token_id, + make_batched_images, + "image", + image_processor, + prompt_token_ids, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) if video_inputs is not None: - prompt_token_ids = _expand_pad_tokens(video_inputs, - hf_config.video_token_id, - make_batched_videos, - "video", - image_processor, - prompt_token_ids, - min_pixels=min_pixels, - max_pixels=max_pixels) + prompt_token_ids = _expand_pad_tokens( + video_inputs, + hf_config.video_token_id, + make_batched_videos, + "video", + image_processor, + prompt_token_ids, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) prompt = inputs.get("prompt") if prompt is None: @@ -920,16 +1052,20 @@ def input_processor_for_qwen2_vl( @MULTIMODAL_REGISTRY.register_image_input_mapper( - image_input_mapper_for_qwen2_vl) -@MULTIMODAL_REGISTRY.register_input_mapper("video", - video_input_mapper_for_qwen2_vl) + image_input_mapper_for_qwen2_vl +) +@MULTIMODAL_REGISTRY.register_input_mapper( + "video", video_input_mapper_for_qwen2_vl +) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "video", get_max_qwen2_vl_video_tokens) + "video", get_max_qwen2_vl_video_tokens +) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP): +class Qwen2VLForConditionalGeneration( + nn.Module, 
SupportsMultiModal, SupportsLoRA, SupportsPP +): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -952,17 +1088,19 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, embedding_modules = {} embedding_padding_modules = [] - def __init__(self, - config: Qwen2VLConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None) -> None: - + def __init__( + self, + config: Qwen2VLConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: super().__init__() - assert not cache_config.enable_prefix_caching, \ - "Qwen2-VL currently does not support prefix caching" + assert ( + not cache_config.enable_prefix_caching + ), "Qwen2-VL currently does not support prefix caching" self.config = config self.multimodal_config = multimodal_config @@ -974,19 +1112,20 @@ def __init__(self, prefix="visual", ) - self.model = Qwen2Model(config, - cache_config, - quant_config, - prefix="model") + self.model = Qwen2Model( + config, cache_config, quant_config, prefix="model" + ) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix="lm_head") + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix="lm_head", + ) else: self.lm_head = PPMissingLayer() @@ -994,27 +1133,32 @@ def __init__(self, self.sampler = Sampler() self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + ["hidden_states", "residual"], config.hidden_size + ) + ) - def _validate_and_reshape_mm_tensor(self, - mm_input: Union[torch.Tensor, - List[torch.Tensor]], - name: str) -> torch.Tensor: + def _validate_and_reshape_mm_tensor( + self, mm_input: Union[torch.Tensor, List[torch.Tensor]], name: str + ) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") + raise ValueError( + f"Incorrect type of {name}. " f"Got type: {type(mm_input)}" + ) if isinstance(mm_input, torch.Tensor): if mm_input.ndim == 2: return mm_input if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim}") + raise ValueError( + f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim}" + ) return torch.concat(list(mm_input)) else: return torch.concat(mm_input) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Qwen2VLImageInputs]: + self, **kwargs: object + ) -> Optional[Qwen2VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1024,30 +1168,41 @@ def _parse_and_validate_image_input( if pixel_values is not None: pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") + pixel_values, "image pixel values" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. 
" - f"Got type: {type(pixel_values)}") + raise ValueError( + "Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}" + ) - return Qwen2VLImagePixelInputs(type="pixel_values", - data=pixel_values, - image_grid_thw=image_grid_thw) + return Qwen2VLImagePixelInputs( + type="pixel_values", + data=pixel_values, + image_grid_thw=image_grid_thw, + ) if image_embeds is not None: image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") + image_embeds, "image embeds" + ) if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return Qwen2VLImageEmbeddingInputs(type="image_embeds", - data=image_embeds) + raise ValueError( + "Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}" + ) + return Qwen2VLImageEmbeddingInputs( + type="image_embeds", data=image_embeds + ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]: + self, **kwargs: object + ) -> Optional[Qwen2VLVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1055,31 +1210,38 @@ def _parse_and_validate_video_input( return None pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") + pixel_values_videos, "video pixel values" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) return Qwen2VLVideoInputs( pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, ) - def _process_image_input(self, - image_input: Qwen2VLImageInputs) -> torch.Tensor: + def _process_image_input( + self, image_input: Qwen2VLImageInputs + ) -> torch.Tensor: if image_input["type"] == "image_embeds": return image_input["data"].type(self.visual.dtype) pixel_values = image_input["data"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, - grid_thw=image_input["image_grid_thw"]) + image_embeds = self.visual( + pixel_values, grid_thw=image_input["image_grid_thw"] + ) return image_embeds - def _process_video_input(self, - video_input: Qwen2VLVideoInputs) -> torch.Tensor: + def _process_video_input( + self, video_input: Qwen2VLVideoInputs + ) -> torch.Tensor: pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, - grid_thw=video_input["video_grid_thw"]) + self.visual.dtype + ) + video_embeds = self.visual( + pixel_values_videos, grid_thw=video_input["video_grid_thw"] + ) return video_embeds def _merge_multimodal_embeddings( @@ -1089,7 +1251,7 @@ def _merge_multimodal_embeddings( multimodal_embeddings: torch.Tensor, placeholder_token_id: int, ) -> torch.Tensor: - mask = (input_ids == placeholder_token_id) + mask = input_ids == placeholder_token_id inputs_embeds[mask, :] = multimodal_embeddings return inputs_embeds @@ -1134,7 +1296,8 @@ def forward( if uses_mrope(self.config): assert positions.ndim == 2 and positions.size(0) == 3, ( "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") + f"(3, seq_len) positions, but got {positions.size()}" + ) inputs_embeds = self.model.embed_tokens(input_ids) @@ -1168,10 +1331,12 @@ def forward( ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, 
hidden_states, - sampling_metadata) + def compute_logits( + self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata + ) -> torch.Tensor: + logits = self.logits_processor( + self.lm_head, hidden_states, sampling_metadata + ) return logits def sample( @@ -1197,7 +1362,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if self.config.tie_word_embeddings and "lm_head.weight" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -1215,17 +1380,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim head_size = visual_embed_dim // visual_num_heads - loaded_weight = loaded_weight.view(3, visual_num_heads, - head_size, - visual_embed_dim) + loaded_weight = loaded_weight.view( + 3, visual_num_heads, head_size, visual_embed_dim + ) loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) elif "visual" in name and name.endswith("qkv.bias"): visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim head_size = visual_embed_dim // visual_num_heads - loaded_weight = loaded_weight.view(3, visual_num_heads, - head_size) + loaded_weight = loaded_weight.view( + 3, visual_num_heads, head_size + ) loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1) try: @@ -1238,6 +1404,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): except KeyError: raise ValueError(f"Unexpected weight: {name}") from None - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) From 21c5aa47fa1294934af7fa0237b823aebf76ba46 Mon Sep 17 00:00:00 2001 From: ericperfect Date: Wed, 6 Nov 2024 10:06:09 +0800 Subject: [PATCH 7/8] format code Signed-off-by: ericperfect --- vllm/model_executor/models/qwen2_vl.py | 771 ++++++++++--------------- 1 file changed, 302 insertions(+), 469 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 8597c970f46ba..cd9e53d889c18 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,72 +22,43 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" - from functools import partial -from typing import ( - Any, - Callable, - Iterable, - List, - Literal, - Mapping, - Optional, - Tuple, - Type, - TypedDict, - Union, -) +from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, + Tuple, Type, TypedDict, Union) import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from PIL import Image -from transformers.image_utils import ( - get_image_size, - infer_channel_dimension_format, - to_numpy_array, -) +from transformers.image_utils import (get_image_size, + infer_channel_dimension_format, + to_numpy_array) from transformers.models.qwen2_vl.configuration_qwen2_vl import ( - Qwen2VLConfig, - Qwen2VLVisionConfig, -) + Qwen2VLConfig, Qwen2VLVisionConfig) from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( - make_batched_images, - make_batched_videos, - smart_resize, -) + make_batched_images, make_batched_videos, smart_resize) from vllm.attention import AttentionMetadata from vllm.attention.selector import _Backend -from vllm.config import CacheConfig, MultiModalConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils -from vllm.inputs import ( - INPUT_REGISTRY, - DecoderOnlyInputs, - DummyData, - InputContext, - token_inputs, -) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - RowParallelLinear, -) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2Model -from vllm.multimodal import ( - MULTIMODAL_REGISTRY, - MultiModalDataDict, - MultiModalInputs, -) +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, + MultiModalInputs) from vllm.multimodal.base import MultiModalData from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer @@ -95,13 +66,10 @@ from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.processor import cached_get_processor -from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA -from .utils import ( - PPMissingLayer, - get_vit_attn_backend, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, -) +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .utils import (PPMissingLayer, get_vit_attn_backend, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory) logger = init_logger(__name__) @@ -129,7 +97,8 @@ class Qwen2VLImageEmbeddingInputs(TypedDict): """ -Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, Qwen2VLImageEmbeddingInputs] +Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, + Qwen2VLImageEmbeddingInputs] class Qwen2VLVideoInputs(TypedDict): @@ 
-150,6 +119,7 @@ class Qwen2VLVideoInputs(TypedDict): class Qwen2VisionMLP(nn.Module): + def __init__( self, in_features: int, @@ -159,19 +129,15 @@ def __init__( prefix: str = "", ): super().__init__() - self.fc1 = ColumnParallelLinear( - in_features, - hidden_features, - quant_config=quant_config, - prefix=f"{prefix}.fc1", - ) + self.fc1 = ColumnParallelLinear(in_features, + hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1") self.act = act_layer() - self.fc2 = RowParallelLinear( - hidden_features, - in_features, - quant_config=quant_config, - prefix=f"{prefix}.fc2", - ) + self.fc2 = RowParallelLinear(hidden_features, + in_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2") def forward(self, x: torch.Tensor) -> torch.Tensor: x_parallel, _ = self.fc1(x) @@ -186,17 +152,15 @@ def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 - ) + return rearrange(torch.stack((-x2, x1), dim=-1), + "... d two -> ... (d two)", + two=2) -def apply_rotary_emb_torch( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False, -) -> torch.Tensor: +def apply_rotary_emb_torch(x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False) -> torch.Tensor: """ x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) @@ -205,25 +169,21 @@ def apply_rotary_emb_torch( assert ro_dim <= x.shape[-1] cos = repeat( cos, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)", - ) + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") sin = repeat( sin, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)", - ) + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") return torch.cat( [ - x[..., :ro_dim] * cos - + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] ], dim=-1, ) -def apply_rotary_pos_emb_vision( - t: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: +def apply_rotary_pos_emb_vision(t: torch.Tensor, + freqs: torch.Tensor) -> torch.Tensor: t_ = t.float() cos = freqs.cos() sin = freqs.sin() @@ -232,6 +192,7 @@ def apply_rotary_pos_emb_vision( class Qwen2VisionAttention(nn.Module): + def __init__( self, embed_dim: Optional[int] = None, @@ -244,35 +205,26 @@ def __init__( # Per attention head and per partition values. 
world_size = parallel_state.get_tensor_model_parallel_world_size() self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads - ) + projection_size, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, world_size - ) + num_heads, world_size) - self.qkv = ColumnParallelLinear( - input_size=embed_dim, - output_size=3 * projection_size, - quant_config=quant_config, - prefix=f"{prefix}.qkv", - ) - self.proj = RowParallelLinear( - input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj", - ) + self.qkv = ColumnParallelLinear(input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj") # Detect attention implementation. self.attn_backend: _Backend = get_vit_attn_backend() if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS }: raise RuntimeError( - f"Qwen2-VL does not support {self.attn_backend} backend now." - ) + f"Qwen2-VL does not support {self.attn_backend} backend now.") def forward( self, @@ -309,58 +261,53 @@ def forward( q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - output = flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0, - causal=False, - ) - - context_layer = rearrange( - output, "(b s) ... -> b s ...", b=batch_size - ) + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0, + causal=False) + + context_layer = rearrange(output, + "(b s) ... 
-> b s ...", + b=batch_size) elif self.attn_backend == _Backend.TORCH_SDPA: seq_length = q.size(1) q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]] - attention_mask = torch.zeros( - [1, seq_length, seq_length], device=q.device, dtype=torch.bool - ) + attention_mask = torch.zeros([1, seq_length, seq_length], + device=q.device, + dtype=torch.bool) for i in range(1, len(cu_seqlens)): - attention_mask[ - ..., - cu_seqlens[i - 1] : cu_seqlens[i], - cu_seqlens[i - 1] : cu_seqlens[i], - ] = True - output = F.scaled_dot_product_attention( - q, k, v, attention_mask, dropout_p=0.0 - ) + attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], + cu_seqlens[i - 1]:cu_seqlens[i]] = True + output = F.scaled_dot_product_attention(q, + k, + v, + attention_mask, + dropout_p=0.0) context_layer = rearrange(output, "b h s d -> b s h d ") elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - attn_bias = BlockDiagonalMask.from_seqlens( - q_seqlen=seqlens, kv_seqlen=None - ) + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None - ) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() + q, k, v, attn_bias=attn_bias, p=0, scale=None) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output class Qwen2VisionBlock(nn.Module): + def __init__( self, dim: int, @@ -378,35 +325,28 @@ def __init__( self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.attn = Qwen2VisionAttention( - embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn", - ) - self.mlp = Qwen2VisionMLP( - dim, - mlp_hidden_dim, - act_layer=act_layer, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - ) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb - ) + self.attn = Qwen2VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.mlp = Qwen2VisionMLP(dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb) x = x + self.mlp(self.norm2(x)) return x class Qwen2VisionPatchEmbed(nn.Module): + def __init__( self, patch_size: int = 14, @@ -420,24 +360,22 @@ def __init__( self.embed_dim = embed_dim kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d( - in_chans, - embed_dim, - kernel_size=kernel_size, - stride=kernel_size, - bias=False, - ) + self.proj = nn.Conv3d(in_chans, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape - x = x.view( - L, -1, self.temporal_patch_size, self.patch_size, self.patch_size - ) + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, + self.patch_size) x = self.proj(x).view(L, self.embed_dim) return x class Qwen2VisionPatchMerger(nn.Module): + 
def __init__( self, d_model: int, @@ -452,25 +390,19 @@ def __init__( if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) - self.mlp = nn.ModuleList( - [ - ColumnParallelLinear( - self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.0", - ), - nn.GELU(), - RowParallelLinear( - self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2", - ), - ] - ) + self.mlp = nn.ModuleList([ + ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0"), + nn.GELU(), + RowParallelLinear(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2"), + ]) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.ln_q(x) @@ -484,13 +416,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen2VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta - inv_freq = 1.0 / ( - theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim) - ) + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -499,22 +431,12 @@ def update_freqs_cache(self, seqlen: int) -> None: if seqlen > self._seq_len_cached: seqlen *= 2 self._seq_len_cached = seqlen - self.inv_freq = 1.0 / ( - self.theta - ** ( - torch.arange( - 0, - self.dim, - 2, - dtype=torch.float, - device=self.inv_freq.device, - ) - / self.dim - ) - ) - seq = torch.arange( - seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype - ) + self.inv_freq = 1.0 / (self.theta**(torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) + / self.dim)) + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) freqs = torch.outer(seq, self.inv_freq) self._freqs_cached = freqs @@ -524,6 +446,7 @@ def forward(self, seqlen: int) -> torch.Tensor: class Qwen2VisionTransformer(nn.Module): + def __init__( self, vision_config: Qwen2VLVisionConfig, @@ -556,19 +479,15 @@ def __init__( head_dim = embed_dim // num_heads self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList( - [ - Qwen2VisionBlock( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", - ) - for layer_idx in range(depth) - ] - ) + self.blocks = nn.ModuleList([ + Qwen2VisionBlock(dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) self.merger = Qwen2VisionPatchMerger( d_model=hidden_size, context_dim=embed_dim, @@ -590,29 +509,20 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = ( - hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - .permute(0, 2, 1, 3) - .flatten() - ) - wpos_ids = ( - wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - .permute(0, 2, 1, 3) - 
.flatten() - ) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) - ) + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) @@ -632,9 +542,9 @@ def forward( rotary_pos_emb = self.rot_pos_emb(grid_thw) # compute cu_seqlens - cu_seqlens = torch.repeat_interleave( - grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] - ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) # transformers @@ -660,12 +570,10 @@ def mm_input_mapper_for_qwen2_vl( ) -> MultiModalInputs: """Input mapper for Qwen2-VL.""" if data_type_key == "image" and isinstance(data, dict): - return MultiModalInputs( - { - "image_embeds": data.get("image_embeds"), - "image_grid_thw": data.get("image_grid_thw"), - } - ) + return MultiModalInputs({ + "image_embeds": data.get("image_embeds"), + "image_grid_thw": data.get("image_grid_thw"), + }) model_config = ctx.model_config # Handle mm processor kwargs; we pass these at creation time # because preprocess() in transformers doesn't expose them @@ -681,10 +589,8 @@ def mm_input_mapper_for_qwen2_vl( **mm_processor_kwargs, ) if image_processor is None: - raise RuntimeError( - "No HuggingFace processor is available " - "to process the image object" - ) + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") images = None videos = None @@ -695,9 +601,9 @@ def mm_input_mapper_for_qwen2_vl( videos = data try: - batch_data = image_processor.preprocess( - images=images, videos=videos, return_tensors="pt" - ).data + batch_data = image_processor \ + .preprocess(images=images, videos=videos, return_tensors="pt") \ + .data except Exception: logger.error("Failed to process image (%s)", data) raise @@ -705,12 +611,10 @@ def mm_input_mapper_for_qwen2_vl( return MultiModalInputs(batch_data) -image_input_mapper_for_qwen2_vl = partial( - mm_input_mapper_for_qwen2_vl, data_type_key="image" -) -video_input_mapper_for_qwen2_vl = partial( - mm_input_mapper_for_qwen2_vl, data_type_key="video" -) +image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, + data_type_key="image") +video_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, + data_type_key="video") def _get_vision_info( @@ -746,11 +650,8 @@ def _get_vision_info( grid_h = resized_height // image_processor.patch_size grid_w = resized_width // image_processor.patch_size vision_tokens = grid_t * grid_h * grid_w - llm_num_vision_tokens = ( - vision_tokens - // image_processor.merge_size - // image_processor.merge_size - ) + llm_num_vision_tokens = (vision_tokens // image_processor.merge_size // + image_processor.merge_size) return resized_height, resized_width, llm_num_vision_tokens @@ -779,28 +680,23 @@ def _get_max_image_info( ) -def get_max_qwen2_vl_mm_tokens( - ctx: InputContext, data_type_key: str, *, min_pixels=None, max_pixels=None -) -> int: +def get_max_qwen2_vl_mm_tokens(ctx: InputContext, + data_type_key: 
str, + *, + min_pixels=None, + max_pixels=None) -> int: image_processor = cached_get_image_processor(ctx.model_config.model) - max_resized_height, max_resized_width, max_llm_image_tokens = ( - _get_max_image_info( - image_processor, - data_type_key=data_type_key, - mm_count=1, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - ) + max_resized_height, max_resized_width, max_llm_image_tokens = \ + _get_max_image_info(image_processor, data_type_key=data_type_key, + mm_count=1, min_pixels=min_pixels, + max_pixels=max_pixels) return max_llm_image_tokens -get_max_qwen2_vl_image_tokens = partial( - get_max_qwen2_vl_mm_tokens, data_type_key="image" -) -get_max_qwen2_vl_video_tokens = partial( - get_max_qwen2_vl_mm_tokens, data_type_key="video" -) +get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens, + data_type_key="image") +get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens, + data_type_key="video") def dummy_data_for_qwen2_vl( @@ -809,44 +705,32 @@ def dummy_data_for_qwen2_vl( mm_counts: Mapping[str, int], *, min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, + max_pixels: Optional[int] = None ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: image_processor = cached_get_image_processor(ctx.model_config.model) num_images = mm_counts["image"] - max_resized_height, max_resized_width, max_llm_image_tokens = ( - _get_max_image_info( - image_processor, - data_type_key="image", - mm_count=num_images, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - ) + max_resized_height, max_resized_width, max_llm_image_tokens = \ + _get_max_image_info(image_processor, data_type_key="image", + mm_count=num_images, min_pixels=min_pixels, + max_pixels=max_pixels) if seq_len - max_llm_image_tokens - 2 < 0: raise RuntimeError( f"Qwen2-VL cannot process {num_images} images in a prompt, " "please increase max_model_len or reduce image limit by " - "--limit-mm-per-prompt." - ) + "--limit-mm-per-prompt.") # Check video counts. num_videos = mm_counts["video"] - max_resized_height, max_resized_width, max_llm_video_tokens = ( - _get_max_image_info( - image_processor, - data_type_key="video", - mm_count=num_videos, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - ) + max_resized_height, max_resized_width, max_llm_video_tokens = \ + _get_max_image_info(image_processor, data_type_key="video", + mm_count=num_videos, min_pixels=min_pixels, + max_pixels=max_pixels) if seq_len - max_llm_video_tokens - 2 < 0: raise RuntimeError( f"Qwen2-VL cannot process {num_videos} videos in a prompt, " "please increase max_model_len or reduce video limit by " - "--limit-mm-per-prompt." 
- ) + "--limit-mm-per-prompt.") hf_config = ctx.get_hf_config(Qwen2VLConfig) @@ -857,18 +741,13 @@ def dummy_data_for_qwen2_vl( (0, seq_len - max_llm_image_tokens - 2), ) - dummy_image = Image.new( - "RGB", (max_resized_width, max_resized_height), color=0 - ) + dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), + color=0) - return DummyData( - dummy_seqdata, - { - "image": dummy_image - if num_images == 1 - else [dummy_image] * num_images - }, - ) + return DummyData(dummy_seqdata, { + "image": + dummy_image if num_images == 1 else [dummy_image] * num_images + }) def _get_llm_num_vision_tokens( @@ -900,16 +779,10 @@ def _get_llm_num_vision_tokens( return llm_num_vision_tokens -def _expand_pad_tokens( - inputs: list, - token_id: int, - make_batched_fn: Callable, - data_type_key: str, - image_processor: Any, - prompt_token_ids: List[int], - min_pixels: Optional[int], - max_pixels: Optional[int], -) -> List[int]: +def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, + data_type_key: str, image_processor: Any, + prompt_token_ids: List[int], min_pixels: Optional[int], + max_pixels: Optional[int]) -> List[int]: """ Expand pad tokens for multi-modal inputs (e.g., images or videos). @@ -945,12 +818,11 @@ def _expand_pad_tokens( end_idx = indices[cnt] non_data_tokens = prompt_token_ids[:end_idx] else: - non_data_tokens = prompt_token_ids[ - indices[cnt - 1] + 1 : indices[cnt] - ] + non_data_tokens = prompt_token_ids[indices[cnt - 1] + + 1:indices[cnt]] prompt_token_ids_with_data.extend(non_data_tokens) prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens)) - prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1 :]) + prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:]) return prompt_token_ids_with_data @@ -991,8 +863,8 @@ def input_processor_for_qwen2_vl( # prompt_token_ids = inputs["input_ids"][0].tolist() tokenizer = cached_get_tokenizer( - model_config.tokenizer, trust_remote_code=model_config.trust_remote_code - ) + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) prompt_token_ids = inputs["prompt_token_ids"] @@ -1002,43 +874,39 @@ def input_processor_for_qwen2_vl( if isinstance(image_inputs, dict): prompt_token_ids_with_image = [] image_indices = [ - idx - for idx, token in enumerate(prompt_token_ids) + idx for idx, token in enumerate(prompt_token_ids) if token == hf_config.image_token_id ] image_cnt = len(image_indices) - embed_dim = image_inputs.get("image_embeds").size(0) + embed_dim = image_inputs.get('image_embeds').size(0) assert embed_dim % image_cnt == 0 num_pad_tokens = embed_dim // image_cnt for idx, token in enumerate(prompt_token_ids): if idx in image_indices: - prompt_token_ids_with_image.extend([token] * num_pad_tokens) + prompt_token_ids_with_image.extend([token] * + num_pad_tokens) else: prompt_token_ids_with_image.append(token) prompt_token_ids = prompt_token_ids_with_image else: - prompt_token_ids = _expand_pad_tokens( - image_inputs, - hf_config.image_token_id, - make_batched_images, - "image", - image_processor, - prompt_token_ids, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) + prompt_token_ids = _expand_pad_tokens(image_inputs, + hf_config.image_token_id, + make_batched_images, + "image", + image_processor, + prompt_token_ids, + min_pixels=min_pixels, + max_pixels=max_pixels) if video_inputs is not None: - prompt_token_ids = _expand_pad_tokens( - video_inputs, - hf_config.video_token_id, - make_batched_videos, - "video", - image_processor, - 
prompt_token_ids, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) + prompt_token_ids = _expand_pad_tokens(video_inputs, + hf_config.video_token_id, + make_batched_videos, + "video", + image_processor, + prompt_token_ids, + min_pixels=min_pixels, + max_pixels=max_pixels) prompt = inputs.get("prompt") if prompt is None: @@ -1052,20 +920,16 @@ def input_processor_for_qwen2_vl( @MULTIMODAL_REGISTRY.register_image_input_mapper( - image_input_mapper_for_qwen2_vl -) -@MULTIMODAL_REGISTRY.register_input_mapper( - "video", video_input_mapper_for_qwen2_vl -) + image_input_mapper_for_qwen2_vl) +@MULTIMODAL_REGISTRY.register_input_mapper("video", + video_input_mapper_for_qwen2_vl) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "video", get_max_qwen2_vl_video_tokens -) + "video", get_max_qwen2_vl_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) -class Qwen2VLForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP -): +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1088,19 +952,17 @@ class Qwen2VLForConditionalGeneration( embedding_modules = {} embedding_padding_modules = [] - def __init__( - self, - config: Qwen2VLConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: + def __init__(self, + config: Qwen2VLConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None) -> None: + super().__init__() - assert ( - not cache_config.enable_prefix_caching - ), "Qwen2-VL currently does not support prefix caching" + assert not cache_config.enable_prefix_caching, \ + "Qwen2-VL currently does not support prefix caching" self.config = config self.multimodal_config = multimodal_config @@ -1112,20 +974,19 @@ def __init__( prefix="visual", ) - self.model = Qwen2Model( - config, cache_config, quant_config, prefix="model" - ) + self.model = Qwen2Model(config, + cache_config, + quant_config, + prefix="model") if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix="lm_head", - ) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix="lm_head") else: self.lm_head = PPMissingLayer() @@ -1133,32 +994,27 @@ def __init__( self.sampler = Sampler() self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) - ) + ["hidden_states", "residual"], config.hidden_size)) - def _validate_and_reshape_mm_tensor( - self, mm_input: Union[torch.Tensor, List[torch.Tensor]], name: str - ) -> torch.Tensor: + def _validate_and_reshape_mm_tensor(self, + mm_input: Union[torch.Tensor, + List[torch.Tensor]], + name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of {name}. " f"Got type: {type(mm_input)}" - ) + raise ValueError(f"Incorrect type of {name}. 
" + f"Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): if mm_input.ndim == 2: return mm_input if mm_input.ndim != 3: - raise ValueError( - f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim}" - ) + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim}") return torch.concat(list(mm_input)) else: return torch.concat(mm_input) def _parse_and_validate_image_input( - self, **kwargs: object - ) -> Optional[Qwen2VLImageInputs]: + self, **kwargs: object) -> Optional[Qwen2VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1168,41 +1024,30 @@ def _parse_and_validate_image_input( if pixel_values is not None: pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values" - ) + pixel_values, "image pixel values") image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw" - ) + image_grid_thw, "image grid_thw") if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}" - ) + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") - return Qwen2VLImagePixelInputs( - type="pixel_values", - data=pixel_values, - image_grid_thw=image_grid_thw, - ) + return Qwen2VLImagePixelInputs(type="pixel_values", + data=pixel_values, + image_grid_thw=image_grid_thw) if image_embeds is not None: image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds" - ) + image_embeds, "image embeds") if not isinstance(image_embeds, torch.Tensor): - raise ValueError( - "Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}" - ) - return Qwen2VLImageEmbeddingInputs( - type="image_embeds", data=image_embeds - ) + raise ValueError("Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}") + return Qwen2VLImageEmbeddingInputs(type="image_embeds", + data=image_embeds) def _parse_and_validate_video_input( - self, **kwargs: object - ) -> Optional[Qwen2VLVideoInputs]: + self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1210,38 +1055,31 @@ def _parse_and_validate_video_input( return None pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values" - ) + pixel_values_videos, "video pixel values") video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw" - ) + video_grid_thw, "video grid_thw") return Qwen2VLVideoInputs( pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, ) - def _process_image_input( - self, image_input: Qwen2VLImageInputs - ) -> torch.Tensor: + def _process_image_input(self, + image_input: Qwen2VLImageInputs) -> torch.Tensor: if image_input["type"] == "image_embeds": return image_input["data"].type(self.visual.dtype) pixel_values = image_input["data"].type(self.visual.dtype) - image_embeds = self.visual( - pixel_values, grid_thw=image_input["image_grid_thw"] - ) + image_embeds = self.visual(pixel_values, + grid_thw=image_input["image_grid_thw"]) return image_embeds - def _process_video_input( - self, video_input: Qwen2VLVideoInputs - ) -> torch.Tensor: + def _process_video_input(self, + video_input: Qwen2VLVideoInputs) -> torch.Tensor: pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype - ) - video_embeds = self.visual( - pixel_values_videos, grid_thw=video_input["video_grid_thw"] - ) + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, + grid_thw=video_input["video_grid_thw"]) return video_embeds def _merge_multimodal_embeddings( @@ -1251,7 +1089,7 @@ def _merge_multimodal_embeddings( multimodal_embeddings: torch.Tensor, placeholder_token_id: int, ) -> torch.Tensor: - mask = input_ids == placeholder_token_id + mask = (input_ids == placeholder_token_id) inputs_embeds[mask, :] = multimodal_embeddings return inputs_embeds @@ -1296,8 +1134,7 @@ def forward( if uses_mrope(self.config): assert positions.ndim == 2 and positions.size(0) == 3, ( "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}" - ) + f"(3, seq_len) positions, but got {positions.size()}") inputs_embeds = self.model.embed_tokens(input_ids) @@ -1331,12 +1168,10 @@ def forward( ) return hidden_states - def compute_logits( - self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata - ) -> torch.Tensor: - logits = self.logits_processor( - self.lm_head, hidden_states, sampling_metadata - ) + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) return logits def sample( @@ -1362,7 +1197,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if self.config.tie_word_embeddings and "lm_head.weight" in name: continue - for param_name, weight_name, shard_id in stacked_params_mapping: + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -1380,18 +1215,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): visual_num_heads = self.config.vision_config.num_heads 
visual_embed_dim = self.config.vision_config.embed_dim head_size = visual_embed_dim // visual_num_heads - loaded_weight = loaded_weight.view( - 3, visual_num_heads, head_size, visual_embed_dim - ) + loaded_weight = loaded_weight.view(3, visual_num_heads, + head_size, + visual_embed_dim) loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) elif "visual" in name and name.endswith("qkv.bias"): visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim head_size = visual_embed_dim // visual_num_heads - loaded_weight = loaded_weight.view( - 3, visual_num_heads, head_size - ) + loaded_weight = loaded_weight.view(3, visual_num_heads, + head_size) loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1) try: @@ -1404,7 +1238,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): except KeyError: raise ValueError(f"Unexpected weight: {name}") from None - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) From bf0dc7af5fdc6e35db01adcab548808cccbcc9e8 Mon Sep 17 00:00:00 2001 From: ericperfect Date: Wed, 6 Nov 2024 18:16:08 +0800 Subject: [PATCH 8/8] add TODO(TODO Support LoRA for the visual encoder in the future.) message and update documentation. Signed-off-by: ericperfect --- docs/source/models/supported_models.rst | 2 +- vllm/model_executor/models/qwen2_vl.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 55835d945b00c..713838f802822 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -534,7 +534,7 @@ Text Generation - Qwen2-VL - T + I\ :sup:`E+` + V\ :sup:`+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - - + - ✅︎ - ✅︎ * - :code:`UltravoxModel` - Ultravox diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index cd9e53d889c18..1443876219eb0 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -943,6 +943,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, } # LoRA specific attributes + # TODO Support LoRA for the visual encoder in the future. supported_lora_modules = [ "qkv_proj", "o_proj",
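
A quick usage note for reviewers: with this series applied, a Qwen2-VL LoRA adapter can be loaded through vLLM's regular LoRA path. The sketch below is only an illustration of that flow, not part of the patch; the adapter path, adapter name, and image file are placeholders, the prompt string just follows Qwen2-VL's chat template by hand, and, per the TODO above, the adapter is applied to the language-model projections listed in supported_lora_modules (qkv_proj, o_proj, gate_up_proj, down_proj), not to the visual encoder.

    # Illustrative only: offline inference with a (hypothetical) Qwen2-VL LoRA adapter.
    from PIL import Image

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(
        model="Qwen/Qwen2-VL-7B-Instruct",
        enable_lora=True,      # exercises the SupportsLoRA path added in this series
        max_lora_rank=64,      # set to at least the rank of the adapter being loaded
        max_model_len=4096,
    )

    # Hand-written Qwen2-VL chat-template prompt; the <|image_pad|> placeholder is
    # expanded to the required number of vision tokens by the input processor.
    prompt = ("<|im_start|>user\n"
              "<|vision_start|><|image_pad|><|vision_end|>"
              "Describe this image.<|im_end|>\n"
              "<|im_start|>assistant\n")

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": Image.open("demo.jpg")},  # placeholder image
        },
        SamplingParams(temperature=0.0, max_tokens=128),
        # Name and path are placeholders; any adapter trained on the listed
        # language-model projections should load.
        lora_request=LoRARequest("qwen2-vl-lora", 1, "/path/to/qwen2_vl_lora"),
    )
    print(outputs[0].outputs[0].text)

The same adapter should also be usable through the OpenAI-compatible server via --enable-lora --lora-modules qwen2-vl-lora=/path/to/qwen2_vl_lora. Automatic prefix caching has to stay disabled, as asserted in the model's __init__.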