Do not import flash_attn3 directly (#105)
feifeibear authored Nov 19, 2024
1 parent 856fdd0 commit 2bdbcc3
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions yunchang/kernels/__init__.py
@@ -3,8 +3,8 @@
     flash_attn_backward,
     flash_attn3_func_forward,
     flash_attn3_func_backward,
-    flash3_attn_func,
-    torch_attn
+    torch_attn,
+    HAS_FLASH_ATTN_HOPPER
 )
 from enum import Enum, auto
 from flash_attn import flash_attn_func
@@ -47,8 +47,10 @@ def fn(q,
            causal=False,
            *args, **kwargs
     ):
+        assert HAS_FLASH_ATTN_HOPPER, "FlashAttention3 is not available! install it from https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release"
         # (q, k, v, softmax_scale=None, causal=False, window_size=(-1, -1),
         # deterministic=False, descale_q=None, descale_k=None, descale_v=None, gqa_parallel=False)
+        from .attention import flash3_attn_func
         assert softmax_scale is not None, f"softmax_scale is required for FA3"
         assert dropout_p == 0.0, f"dropout_p: {dropout_p} is not supported for FA3"
         return flash3_attn_func(q, k, v, softmax_scale=softmax_scale, causal=causal)
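The change makes FlashAttention-3 an optional dependency: yunchang/kernels/__init__.py no longer pulls in flash3_attn_func at module import time, but only inside fn, after HAS_FLASH_ATTN_HOPPER confirms the Hopper build is installed. Below is a minimal sketch of that pattern, assuming yunchang/kernels/attention.py sets the flag with a guarded import; the flash_attn_interface module name and the fa3_forward helper are illustrative, not taken from this diff.

# Sketch only (not the repository's actual attention.py): detect the optional
# FlashAttention-3 Hopper build and expose a flag plus a lazily guarded entry point.
try:
    # Assumption: the FA3 beta wheel installs a standalone interface module.
    from flash_attn_interface import flash_attn_func as flash3_attn_func
    HAS_FLASH_ATTN_HOPPER = True
except ImportError:
    flash3_attn_func = None
    HAS_FLASH_ATTN_HOPPER = False


def fa3_forward(q, k, v, softmax_scale, causal=False):
    # Fail with an actionable message at call time instead of raising
    # ImportError when the package is merely imported.
    assert HAS_FLASH_ATTN_HOPPER, (
        "FlashAttention3 is not available! install it from "
        "https://github.com/Dao-AILab/flash-attention"
        "?tab=readme-ov-file#flashattention-3-beta-release"
    )
    return flash3_attn_func(q, k, v, softmax_scale=softmax_scale, causal=causal)

With this guard, importing yunchang.kernels succeeds on machines without the Hopper kernels, and the hard requirement only surfaces when the FA3 code path is actually selected.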
