From 1b0e54be247a00ebbec9d84e9a194c7dd6935921 Mon Sep 17 00:00:00 2001
From: jzhou
Date: Tue, 19 Nov 2024 10:41:31 +0800
Subject: [PATCH] refine

---
 vllm/model_executor/models/glm4_vision_encoder.py | 4 ++--
 vllm/utils.py                                     | 8 ++++++++
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py
index fd32e5a285ccf..0d67366c150cc 100644
--- a/vllm/model_executor/models/glm4_vision_encoder.py
+++ b/vllm/model_executor/models/glm4_vision_encoder.py
@@ -18,6 +18,7 @@
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.utils import is_navi3
 
 
 class PatchEmbedding(nn.Module):
@@ -80,8 +81,7 @@ def __init__(
         self.output_dropout = torch.nn.Dropout(config.dropout_prob)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        _ON_NAVI3 = "gfx11" in torch.cuda.get_device_properties("cuda").gcnArchName
-        if _ON_NAVI3:
+        if is_navi3():
             try:
                 # git clone -b howiejay/navi_support https://github.com/ROCm/flash-attention.git
                 from flash_attn import flash_attn_func
diff --git a/vllm/utils.py b/vllm/utils.py
index 211d3e86c8b05..5c2a4b96775e3 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1641,6 +1641,14 @@ def is_navi() -> bool:
     archName = torch.cuda.get_device_properties('cuda').gcnArchName
     return archName is not None and "gfx1" in archName
 
+@lru_cache(maxsize=None)
+def is_navi3() -> bool:
+    if not current_platform.is_rocm() or not torch.cuda.is_available():
+        return False
+    # All (visible) GPUs must be of the same type,
+    # otherwise FP8 results can't be guaranteed.
+    archName = torch.cuda.get_device_properties('cuda').gcnArchName
+    return archName is not None and "gfx11" in archName
 
 def weak_ref_tensors(
     tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
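
For reference, a minimal standalone sketch (not part of the patch) of why the
new helper is wrapped in @lru_cache: is_navi3() is now called from the vision
encoder's forward() on every step, and caching makes the repeated
torch.cuda.get_device_properties() probe a one-time cost. The name
is_navi3_sketch below is a hypothetical stand-in for vllm.utils.is_navi3; the
getattr fallback replaces the patch's vLLM-internal current_platform.is_rocm()
guard, since gcnArchName is only populated on ROCm builds of PyTorch:

    # sketch.py -- illustration only; is_navi3_sketch is hypothetical,
    # standing in for vllm.utils.is_navi3 from the patch above.
    from functools import lru_cache

    import torch


    @lru_cache(maxsize=None)
    def is_navi3_sketch() -> bool:
        """True iff the visible GPU reports an RDNA3 (gfx11xx) arch."""
        if not torch.cuda.is_available():
            return False
        props = torch.cuda.get_device_properties("cuda")
        # gcnArchName is only present/populated on ROCm builds of
        # PyTorch; getattr keeps the sketch safe on CUDA builds, where
        # the patch's current_platform.is_rocm() guard would apply.
        arch = getattr(props, "gcnArchName", None)
        return arch is not None and "gfx11" in arch


    if __name__ == "__main__":
        print(is_navi3_sketch())  # first call probes the device
        print(is_navi3_sketch())  # answered from the lru_cache, no probe

Centralizing the probe in vllm/utils.py also removes the per-call string check
that the old _ON_NAVI3 local performed inside the attention forward pass.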