Skip to content

Commit

Permalink
use CK FA
Browse files Browse the repository at this point in the history
  • Loading branch information
jzhou committed Nov 15, 2024
1 parent 5362727 commit ae62c82
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions vllm/model_executor/models/glm4_vision_encoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# coding=utf-8

Check failure on line 1 in vllm/model_executor/models/glm4_vision_encoder.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (UP009)

vllm/model_executor/models/glm4_vision_encoder.py:1:1: UP009 UTF-8 encoding declaration is unnecessary
# Adapted from
# https://github.com/THUDM/GLM-4
"""Inference-only GLM-4v model visual encoder compatible with THUDM weights."""
Expand Down Expand Up @@ -79,6 +80,31 @@ def __init__(
self.output_dropout = torch.nn.Dropout(config.dropout_prob)

def forward(self, x: torch.Tensor) -> torch.Tensor:
_ON_NAVI3 = "gfx11" in torch.cuda.get_device_properties("cuda").gcnArchName

Check failure on line 83 in vllm/model_executor/models/glm4_vision_encoder.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/model_executor/models/glm4_vision_encoder.py:83:81: E501 Line too long (83 > 80)
if _ON_NAVI3:
try:
# git clone -b howiejay/navi_support https://github.com/ROCm/flash-attention.git
from flash_attn import flash_attn_func
B, L, _ = x.shape
qkv, _ = self.query_key_value(x) # B, L, 3 * H * D
q, k, v = qkv.chunk(3, dim=-1)

q = q.reshape(B, L, self.num_heads_per_rank,
self.head_dim) # B, L, H, D
k = k.reshape(B, L, self.num_heads_per_rank,
self.head_dim) # B, L, H, D
v = v.reshape(B, L, self.num_heads_per_rank,
self.head_dim) # B, L, H, D

out = flash_attn_func(q, k, v)

output, _ = self.dense(out.view(B, L, -1))
output = self.output_dropout(output)

return output
except ModuleNotFoundError:
pass

B, L, _ = x.shape
qkv, _ = self.query_key_value(x) # B, L, 3 * H * D
q, k, v = qkv.chunk(3, dim=-1)
Expand Down

0 comments on commit ae62c82

Please sign in to comment.