diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index bb828d6fc04fe..94f3f55636ed6 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -231,8 +231,9 @@ def __init__(
             self.attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
         else:
-            # if not using triton, navi3x not use flash-attn either
-            if torch.cuda.get_device_capability()[0] == 11:
+            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
+            # either
+            if torch.cuda.get_device_capability()[0] != 9:
                 self.use_naive_attn = True
             else:
                 try:
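
For context, the check works because on ROCm torch.cuda.get_device_capability() reports the GCN architecture version, e.g. (9, x) for gfx9 MI-series parts and (10, x)/(11, x) for Navi GPUs. Below is a minimal sketch of the new condition in isolation; the helper name is hypothetical and not part of the vLLM code:

    import torch

    def should_use_naive_attn() -> bool:
        # Hypothetical helper, not in vLLM. On ROCm,
        # torch.cuda.get_device_capability() returns (major, minor) of the
        # GCN architecture, e.g. (9, x) for MI-series and (10, x)/(11, x)
        # for navi10/navi21/navi3x. Only gfx9 devices keep flash-attn;
        # every other architecture falls back to naive attention.
        major, _minor = torch.cuda.get_device_capability()
        return major != 9

Compared to the old `== 11` test, inverting the condition to `!= 9` also routes navi21 (gfx10.3) and navi10 (gfx10.1) to the naive path instead of only navi3x (gfx11).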