From ae62c8233661dd2706b3eb6b6e7a0b3c1e4b173b Mon Sep 17 00:00:00 2001
From: jzhou
Date: Fri, 15 Nov 2024 13:33:13 +0800
Subject: [PATCH] use CK FA

On Navi3 (gfx11) GPUs, route the GLM-4v visual encoder attention through
the CK (Composable Kernel) flash-attention kernels from the
howiejay/navi_support branch of ROCm/flash-attention, falling back to the
existing attention path when the module is not installed.
---
 .../models/glm4_vision_encoder.py | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py
index 025615b0920fd..fd32e5a285ccf 100644
--- a/vllm/model_executor/models/glm4_vision_encoder.py
+++ b/vllm/model_executor/models/glm4_vision_encoder.py
@@ -1,3 +1,4 @@
+# coding=utf-8
 # Adapted from
 # https://github.com/THUDM/GLM-4
 """Inference-only GLM-4v model visual encoder compatible with THUDM weights."""
@@ -79,6 +80,31 @@ def __init__(
         self.output_dropout = torch.nn.Dropout(config.dropout_prob)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        _ON_NAVI3 = "gfx11" in torch.cuda.get_device_properties("cuda").gcnArchName
+        if _ON_NAVI3:
+            try:
+                # git clone -b howiejay/navi_support https://github.com/ROCm/flash-attention.git
+                from flash_attn import flash_attn_func
+                B, L, _ = x.shape
+                qkv, _ = self.query_key_value(x)  # B, L, 3 * H * D
+                q, k, v = qkv.chunk(3, dim=-1)
+
+                q = q.reshape(B, L, self.num_heads_per_rank,
+                              self.head_dim)  # B, L, H, D
+                k = k.reshape(B, L, self.num_heads_per_rank,
+                              self.head_dim)  # B, L, H, D
+                v = v.reshape(B, L, self.num_heads_per_rank,
+                              self.head_dim)  # B, L, H, D
+
+                out = flash_attn_func(q, k, v)
+
+                output, _ = self.dense(out.view(B, L, -1))
+                output = self.output_dropout(output)
+
+                return output
+            except ModuleNotFoundError:
+                pass
+
         B, L, _ = x.shape
         qkv, _ = self.query_key_value(x)  # B, L, 3 * H * D
         q, k, v = qkv.chunk(3, dim=-1)
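
Note (not part of the commit): a minimal sketch for checking the gfx11
detection and the CK flash-attention install used above, assuming a ROCm
build of PyTorch (gcnArchName is only exposed there) and the
howiejay/navi_support branch of ROCm/flash-attention; the tensor shapes
are illustrative only:

    import torch

    # Same check as in forward(): Navi3 parts report a "gfx11*" arch.
    arch = torch.cuda.get_device_properties(0).gcnArchName
    print("arch:", arch, "-> Navi3:", "gfx11" in arch)

    try:
        from flash_attn import flash_attn_func
        # flash_attn_func takes (B, L, H, D) fp16/bf16 tensors, the same
        # layout the patch reshapes q/k/v into before calling it.
        q = k = v = torch.randn(1, 16, 4, 64,
                                dtype=torch.float16, device="cuda")
        print("FA output shape:", flash_attn_func(q, k, v).shape)
    except ModuleNotFoundError:
        print("flash_attn missing; the encoder falls back to the "
              "existing attention path")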