Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
jzhou committed Nov 19, 2024
1 parent ae62c82 commit 1b0e54b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
4 changes: 2 additions & 2 deletions vllm/model_executor/models/glm4_vision_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import is_navi3


class PatchEmbedding(nn.Module):
Expand Down Expand Up @@ -80,8 +81,7 @@ def __init__(
self.output_dropout = torch.nn.Dropout(config.dropout_prob)

def forward(self, x: torch.Tensor) -> torch.Tensor:
_ON_NAVI3 = "gfx11" in torch.cuda.get_device_properties("cuda").gcnArchName
if _ON_NAVI3:
if is_navi3():
try:
# git clone -b howiejay/navi_support https://github.com/ROCm/flash-attention.git
from flash_attn import flash_attn_func
Expand Down
8 changes: 8 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,6 +1641,14 @@ def is_navi() -> bool:
archName = torch.cuda.get_device_properties('cuda').gcnArchName
return archName is not None and "gfx1" in archName

@lru_cache(maxsize=None)
def is_navi3() -> bool:
if not current_platform.is_rocm() or not torch.cuda.is_available():
return False
# All (visible) GPUs must be of the same type,
# otherwise FP8 results can't be guaranteed.
archName = torch.cuda.get_device_properties('cuda').gcnArchName
return archName is not None and "gfx11" in archName

def weak_ref_tensors(
tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
Expand Down

0 comments on commit 1b0e54b

Please sign in to comment.