Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Support quantization of Qwen2VisionTransformer for Qwen2-VL #9817

Merged
merged 7 commits into from
Oct 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 35 additions & 23 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,18 @@ def __init__(
hidden_features: int = None,
act_layer: Type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.fc1 = ColumnParallelLinear(in_features,
hidden_features,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.act = act_layer()
self.fc2 = RowParallelLinear(hidden_features,
in_features,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.fc2")

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_parallel, _ = self.fc1(x)
Expand Down Expand Up @@ -196,6 +199,7 @@ def __init__(
num_heads: Optional[int] = None,
projection_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
# Per attention head and per partition values.
Expand All @@ -207,10 +211,12 @@ def __init__(

self.qkv = ColumnParallelLinear(input_size=embed_dim,
output_size=3 * projection_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.qkv")
self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.proj")

# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend()
Expand Down Expand Up @@ -310,6 +316,7 @@ def __init__(
act_layer: Type[nn.Module] = QuickGELU,
norm_layer: Type[nn.Module] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
if norm_layer is None:
Expand All @@ -321,11 +328,13 @@ def __init__(
self.attn = Qwen2VisionAttention(embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.mlp = Qwen2VisionMLP(dim,
mlp_hidden_dim,
act_layer=act_layer,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.mlp")

def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -374,6 +383,7 @@ def __init__(
norm_layer: Type[nn.Module] = None,
spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
Expand All @@ -384,12 +394,14 @@ def __init__(
ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config),
quant_config=quant_config,
prefix=f"{prefix}.mlp.0"),
nn.GELU(),
RowParallelLinear(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config),
quant_config=quant_config,
prefix=f"{prefix}.mlp.2"),
])

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -440,6 +452,7 @@ def __init__(
vision_config: Qwen2VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()

Expand Down Expand Up @@ -467,28 +480,29 @@ def __init__(
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

self.blocks = nn.ModuleList([
Qwen2VisionBlock(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
) for _ in range(depth)
Qwen2VisionBlock(dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}")
for layer_idx in range(depth)
])
self.merger = Qwen2VisionPatchMerger(
d_model=hidden_size,
context_dim=embed_dim,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.merger",
)

@property
def dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype
return self.patch_embed.proj.weight.dtype

@property
def device(self) -> torch.device:
return self.blocks[0].mlp.fc2.weight.device
return self.patch_embed.proj.weight.device

def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = []
Expand Down Expand Up @@ -932,10 +946,8 @@ def __init__(self,
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),

# NOTE: Qwen2-VL vision encoder does not support any
# quantization method now.
quant_config=None,
quant_config=quant_config,
prefix="visual",
)

self.model = Qwen2Model(config,
Expand Down Expand Up @@ -1175,7 +1187,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader(param, loaded_weight, shard_id)
break
else:
if "visual" in name and "qkv.weight" in name:
if "visual" in name and name.endswith("qkv.weight"):
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads
Expand All @@ -1184,7 +1196,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
visual_embed_dim)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
elif "visual" in name and "qkv.bias" in name:
elif "visual" in name and name.endswith("qkv.bias"):
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads
Expand Down