Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

loop unroll for hstu attn bwd #143

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 44 additions & 27 deletions generative_recommenders/ops/triton/triton_ragged_hstu_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,6 +1622,7 @@ def _ragged_hstu_attn_bwd_one_col_block( # noqa C901
BLOCK_D_V: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
UNROLL: tl.constexpr,
ATOMIC_ADD: tl.constexpr,
):
# Work on the subsequence dv[start_n, start_n + BLOCK_N, :]
Expand Down Expand Up @@ -1752,7 +1753,7 @@ def _ragged_hstu_attn_bwd_one_col_block( # noqa C901
ATOMIC_ADD=ATOMIC_ADD,
)
# pyre-ignore[61]
for start_m in range(low, high, BLOCK_M):
for start_m in tl.range(low, high, BLOCK_M, loop_unroll_factor=UNROLL):
start_m = tl.multiple_of(start_m, BLOCK_M)
dk, dv = _ragged_hstu_attn_bwd_one_block(
start_m=start_m,
Expand Down Expand Up @@ -1841,6 +1842,7 @@ def _get_bw_configs() -> List[triton.Config]:
"matrix_instr_nonkdim": matrix_instr_nonkdim,
"waves_per_eu": waves_per_eu,
"SEQUENCE_PARALLEL": sp,
"UNROLL": 1,
},
num_stages=num_stages,
num_warps=num_warps,
Expand All @@ -1851,157 +1853,169 @@ def _get_bw_configs() -> List[triton.Config]:

configs = [
triton.Config(
{"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=2,
num_warps=2,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 16, "BLOCK_N": 16, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 16, "BLOCK_N": 16, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=2,
num_warps=2,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=2,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=1,
num_warps=8,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 16, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 16, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=1,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=1,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=2,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=1,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=2,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=1,
num_warps=8,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=2,
num_warps=8,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=1,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=2,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=1,
num_warps=8,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=2,
num_warps=8,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=2,
num_warps=8,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
{"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False, "UNROLL": 1},
num_stages=3,
num_warps=8,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": True},
{"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False, "UNROLL": 2},
num_stages=2,
num_warps=8,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False, "UNROLL": 4},
num_stages=2,
num_warps=8,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": True, "UNROLL": 1},
num_stages=2,
num_warps=2,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": True},
{"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": True, "UNROLL": 1},
num_stages=1,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": True},
{"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": True, "UNROLL": 1},
num_stages=2,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True},
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1},
num_stages=1,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True},
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1},
num_stages=2,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True},
{"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1},
num_stages=1,
num_warps=8,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True},
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1},
num_stages=1,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True},
{"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1},
num_stages=2,
num_warps=4,
pre_hook=_bwd_pre_hook,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
{"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True, "UNROLL": 1},
num_stages=3,
num_warps=8,
pre_hook=_bwd_pre_hook,
Expand Down Expand Up @@ -2088,6 +2102,7 @@ def _ragged_hstu_attn_bwd( # noqa C901
SEQUENCE_PARALLEL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
UNROLL: tl.constexpr,
HAS_SORT_BY_LENGTH_INDICES: tl.constexpr,
):
off_hz = tl.program_id(0)
Expand Down Expand Up @@ -2181,6 +2196,7 @@ def _ragged_hstu_attn_bwd( # noqa C901
BLOCK_D_V=BLOCK_D_V,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
UNROLL=UNROLL,
ATOMIC_ADD=True,
)
else:
Expand Down Expand Up @@ -2234,6 +2250,7 @@ def _ragged_hstu_attn_bwd( # noqa C901
BLOCK_D_V=BLOCK_D_V,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
UNROLL=UNROLL,
ATOMIC_ADD=False,
)

Expand Down