diff --git a/third_party/flash-attention b/third_party/flash-attention
index 6c9e60de56..85881f547f 160000
--- a/third_party/flash-attention
+++ b/third_party/flash-attention
@@ -1 +1 @@
-Subproject commit 6c9e60de566800538fedad2ad5e6b7b55ca7f0c5
+Subproject commit 85881f547fd1053a7b4a2c3faad6690cca969279
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index 1d5e8b20fa..20bddb8642 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -49,8 +49,8 @@
     from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention

     FLASH_VERSION = flash_attn.__version__
-    FLASH_VER_MIN = (2, 5, 2)
-    FLASH_VER_LAST = (2, 5, 6)  # last supported, inclusive
+    FLASH_VER_MIN = (2, 5, 7)
+    FLASH_VER_LAST = (2, 5, 7)  # last supported, inclusive
     flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
     if (
         flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST
@@ -145,6 +145,7 @@ def _flash_fwd(
             cu_seq_lens_q,
             cu_seq_lens_k,
             seqused_k,
+            None,  # block_table
             None,  # alibi_slopes
             max_seq_len_q,
             max_seq_len_k,
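
For context, the version gate that these constants feed into is sketched below as standalone Python. This is a minimal approximation of the check visible in the first hunk, not the verbatim xformers code (the surrounding import/fallback machinery and the exact error message are omitted). With both bounds set to (2, 5, 7), only flash-attention 2.5.7, matching the pinned submodule commit, passes the check.

# Hedged sketch of the version gate shown in the hunk above: parse
# flash_attn.__version__ into a (major, minor, patch) tuple and require it
# to fall inside [FLASH_VER_MIN, FLASH_VER_LAST].
import flash_attn

FLASH_VERSION = flash_attn.__version__
FLASH_VER_MIN = (2, 5, 7)
FLASH_VER_LAST = (2, 5, 7)  # last supported, inclusive

flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
if flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST:
    # Error text here is illustrative only, not the message xformers raises.
    raise ImportError(
        f"flash-attn {FLASH_VERSION} is outside the supported range "
        f"{FLASH_VER_MIN}..{FLASH_VER_LAST}"
    )

The `None,  # block_table` line in the second hunk mirrors this pin: the flash-attention build at the new submodule commit takes an additional block_table argument (presumably for paged KV caches) positioned before alibi_slopes in the varlen forward call, and xformers passes None to keep the existing non-paged behavior.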