export: splitk fix and fa3 bwd use (fairinternal/xformers#1214)
__original_commit__ = fairinternal/xformers@81491ed
bottler authored and xFormers Bot committed Sep 3, 2024
1 parent 2dcb363 commit 67c5055
Showing 2 changed files with 42 additions and 8 deletions.
1 change: 1 addition & 0 deletions xformers/ops/fmha/__init__.py
@@ -865,6 +865,7 @@ def backward(
ALL_BW_OPS: List[Type[AttentionBwOpBase]] = [
cutlass.BwOp if torch.version.cuda else ck.BwOp,
flash.BwOp,
flash3.BwOp,
]

__all__ = [
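With flash3.BwOp now registered in ALL_BW_OPS, the FlashAttention-3 backward becomes a candidate during backward-op dispatch. A minimal usage sketch (not part of the commit), assuming a device where FA3 is available and using illustrative shapes and dtype; the explicit op= pair mirrors how the other fmha backends are selected:

import torch
import xformers.ops as xops
from xformers.ops import fmha

# Illustrative shapes: (batch, seqlen, heads, head_dim), bf16 on CUDA.
q, k, v = (
    torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)

# Request the FlashAttention-3 forward/backward pair explicitly; without op=,
# the dispatcher chooses among the registered ops (which now include flash3.BwOp).
out = xops.memory_efficient_attention(q, k, v, op=(fmha.flash3.FwOp, fmha.flash3.BwOp))
out.sum().backward()  # gradient computation goes through flash3.BwOp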
49 changes: 41 additions & 8 deletions xformers/ops/fmha/triton_splitk.py
@@ -1336,11 +1336,11 @@ class FwOp(AttentionFwOpBase):
NUM_GROUPS = 1 # Default quantization is row-wise
NUM_GROUPS_VALUES = [1, 2, 4, 8]

# values used when autotune=False
BLOCK_M: int = 16
# Values below are used when autotune=False.
# Note that under certain conditions different values might be used, see the code just before the kernel launch.
BLOCK_M: int = 16 # When M > 1, different BLOCK_M can be used.
BLOCK_N: int = 64
# On AMD these two values are overwritten depending on input shapes, see the code just before the kernel launch
# This might change once we get autotuning working on AMD
# On AMD or for M > 1 different NUM_STAGES and NUM_WARPS can be used.
NUM_STAGES: int = 1
NUM_WARPS: int = 2

@@ -1429,7 +1429,7 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
return reasons

@classmethod
def get_split_k(cls, B: int, G: int, H: int, Mk: int) -> int:
def get_split_k(cls, B: int, G: int, H: int, Mk: int, Mq: int) -> int:
"""Heuristic for the number of splits"""
bh = max(B * H, 1) # NOTE: Handle B*h=0 case
if torch.version.hip:
@@ -1450,6 +1450,8 @@ def get_split_k(cls, B: int, G: int, H: int, Mk: int) -> int:

split_k_upper_bound = 512
else:
if Mq > 1 and B * G * H > 64:
return 1
split_k = max(Mk, 1024) // bh
max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
split_k_stop_val = Mk / max_chunk_size
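For context on the split-k fix: on the CUDA path, get_split_k now bails out with a single split when the query length is greater than one and there is already ample batch * group * head parallelism. A simplified, standalone model of that branch (the name split_k_cuda_sketch and the final halving loop are assumptions for illustration; the constants are copied from the diff):

def split_k_cuda_sketch(B: int, G: int, H: int, Mk: int, Mq: int) -> int:
    """Simplified model of the CUDA branch of get_split_k after this commit."""
    bh = max(B * H, 1)  # guard against B * H == 0, as in the real code
    # New early exit: with Mq > 1 and enough batch * group * head parallelism,
    # splitting along the K/V sequence is not worth it.
    if Mq > 1 and B * G * H > 64:
        return 1
    split_k = max(Mk, 1024) // bh
    max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
    split_k_stop_val = Mk / max_chunk_size
    # Assumption: the remaining (elided) logic shrinks split_k toward
    # split_k_stop_val; modelled here as repeated halving.
    while split_k > split_k_stop_val:
        split_k //= 2
    return max(split_k, 1)

# E.g. split_k_cuda_sketch(B=64, G=1, H=2, Mk=4096, Mq=4) -> 1 (new early exit),
# while the same shapes with Mq=1 still yield a split of 32.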
@@ -1637,7 +1639,9 @@ apply(
split_k = cls.SPLIT_K
else:
# Use heuristics
split_k = cls.get_split_k(B, G, H, Mk) if attn_bias_tensor is None else 1
split_k = (
cls.get_split_k(B, G, H, Mk, Mq) if attn_bias_tensor is None else 1
)

# M_ceil = Mqq rounded up to a multiple of MAX_BLOCK_M
M_ceil = (Mqq + cls.MAX_BLOCK_M - 1) // cls.MAX_BLOCK_M * cls.MAX_BLOCK_M
@@ -1694,6 +1698,8 @@ def grid(META):
use_seq_len = seq_len is not None

kernel = cls.get_kernel()
BLOCK_M = cls.BLOCK_M
BLOCK_N = cls.BLOCK_N
if cls.AUTOTUNE:
extra_args = {}
else:
@@ -1704,15 +1710,42 @@ def grid(META):
if B == 1:
num_warps = 4
num_stages = 1 # TODO num_stages = 0 gives better perf on AMD, but sometimes produces NaNs
BLOCK_N = 32
elif B <= 4 and split_k <= 128:
num_warps = 2
num_stages = 1
BLOCK_N = 32
elif B <= 16:
if M < 16:
num_warps = 2
num_stages = 1
else:
num_warps = 1
num_stages = 1
BLOCK_N = 32
else:
num_warps = 1
num_stages = 1
BLOCK_N = 64
else:
should_modify_warp_and_block = (
Kkv == 128
and Kq == 128
and torch.cuda.get_device_capability() >= (8, 9)
)
if should_modify_warp_and_block:
if Mq > 1:
num_warps = 4
# Choose minimal round block size which covers M.
if M > 16:
BLOCK_M = 32
if M > 32:
BLOCK_M = 64
if M > 64:
BLOCK_M = 128
extra_args = {
"BLOCK_M": cls.BLOCK_M,
"BLOCK_N": cls.BLOCK_N,
"BLOCK_M": BLOCK_M,
"BLOCK_N": BLOCK_N,
"num_warps": num_warps,
"num_stages": num_stages,
}
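On the non-AMD path, when both head dims are 128 and the device is sm89 or newer, the launch now uses num_warps = 4 for Mq > 1 and grows BLOCK_M to the smallest of 16/32/64/128 that covers M, instead of always passing the class defaults. A standalone sketch of that selection (the function name and the capability parameter are illustrative; the defaults match the class attributes shown earlier in the diff):

def tune_non_amd(M: int, Mq: int, Kq: int, Kkv: int,
                 capability: tuple,
                 block_m: int = 16, num_warps: int = 2):
    """Sketch of the new non-AMD tuning branch: returns (BLOCK_M, num_warps)."""
    # Only specialize for 128-dim heads on sm89+ (Ada/Hopper), as in the diff.
    if Kkv == 128 and Kq == 128 and capability >= (8, 9):
        if Mq > 1:
            num_warps = 4
            # Choose the minimal round block size that covers M (capped at 128).
            if M > 16:
                block_m = 32
            if M > 32:
                block_m = 64
            if M > 64:
                block_m = 128
    return block_m, num_warps

# E.g. tune_non_amd(M=48, Mq=48, Kq=128, Kkv=128, capability=(9, 0)) -> (64, 4).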
