Skip to content

Commit

Permalink
[fused seqpar] Support multi-query settings in Triton
Browse files Browse the repository at this point in the history
In allgather+matmul we required the strides of the weight and of the output matrices to be equal. When using multi-query, wq is larger than wk/wv, thus its strides are different. (This is actually not true, because wq/wk/wv are column-major, hence their stride is in_feats, not out_feats, but the outputs are row-major thus their strides pose a problem).

The solution is just to support arbitrary strides. It's very ugly and bulky since we need to pass tons of parameters around, but it works.

ghstack-source-id: d40a452c7121fdd9992773bfc2c14aeba503804e
Pull Request resolved: fairinternal/xformers#1007

__original_commit__ = fairinternal/xformers@cb71011
  • Loading branch information
lw authored and xFormers Bot committed Jan 26, 2024
1 parent e5e812a commit 8894c69
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 23 deletions.
8 changes: 4 additions & 4 deletions tests/test_seqpar.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,23 +97,23 @@ def my_chunk(t, *, dim):
).round()
weight1, weight2 = [
torch.testing.make_tensor(
(inner_dim, outer_dim),
(inner_dim * (idx + 1), outer_dim),
dtype=dtype,
device="cuda",
low=0,
high=1,
).round()
for _ in range(2)
for idx in range(2)
]
gradient1, gradient2 = [
torch.testing.make_tensor(
batch_dims + (inner_dim,),
batch_dims + (inner_dim * (idx + 1),),
dtype=dtype,
device="cuda",
low=0,
high=1,
).round()
for _ in range(2)
for idx in range(2)
]

# Non-fused reference code
Expand Down
136 changes: 117 additions & 19 deletions xformers/ops/_triton/sequence_parallel_fused_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,18 @@ def determine_tile(
world_size,
direction,
stride_am,
stride_cm,
stride_bk1,
stride_bk2,
stride_bk3,
stride_bn1,
stride_bn2,
stride_bn3,
stride_cm1,
stride_cm2,
stride_cm3,
stride_cn1,
stride_cn2,
stride_cn3,
BLOCK_M,
BLOCK_N,
GROUP_M,
Expand Down Expand Up @@ -173,6 +184,26 @@ def determine_tile(
C1_my_shard,
tl.where(pid_n < grid_n1 + grid_n2, C2_my_shard, C3_my_shard),
)
stride_bk = tl.where(
pid_n < grid_n1,
stride_bk1,
tl.where(pid_n < grid_n1 + grid_n2, stride_bk2, stride_bk3),
)
stride_bn = tl.where(
pid_n < grid_n1,
stride_bn1,
tl.where(pid_n < grid_n1 + grid_n2, stride_bn2, stride_bn3),
)
stride_cm = tl.where(
pid_n < grid_n1,
stride_cm1,
tl.where(pid_n < grid_n1 + grid_n2, stride_cm2, stride_cm3),
)
stride_cn = tl.where(
pid_n < grid_n1,
stride_cn1,
tl.where(pid_n < grid_n1 + grid_n2, stride_cn2, stride_cn3),
)
N = tl.where(pid_n < grid_n1, N1, tl.where(pid_n < grid_n1 + grid_n2, N2, N3))
pid_n = tl.where(
pid_n < grid_n1,
Expand All @@ -187,7 +218,21 @@ def determine_tile(
other_rank == my_rank, C_my_shard, C + other_rank * M_per_rank * stride_cm
)

return A, B, C, M_per_rank, N, pid_m, pid_n, other_rank, blocks_per_rank
return (
A,
B,
C,
M_per_rank,
N,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
pid_m,
pid_n,
other_rank,
blocks_per_rank,
)


@triton.jit
Expand Down Expand Up @@ -338,7 +383,21 @@ def our_estimate_matmul_time(B1, C1, N1, N2, N3, **kwargs):
}
)
@triton.jit(
do_not_specialize=[11, 12, 13, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
do_not_specialize=[
"wait_counters",
"blocks_done_counters",
"write_counters",
"do_wait",
"do_write",
"direction",
"stripe",
"seq_num",
"num_stripes",
"_wait",
"my_rank",
"world_size",
"timeout_ns",
],
debug=True, # To avoid stripping device asserts
)
def _xformers_seqpar_matmul_kernel(
Expand All @@ -363,10 +422,18 @@ def _xformers_seqpar_matmul_kernel(
K, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_bk1,
stride_bk2,
stride_bk3,
stride_bn1,
stride_bn2,
stride_bn3,
stride_cm1,
stride_cm2,
stride_cm3,
stride_cn1,
stride_cn2,
stride_cn3,
do_wait,
do_write,
direction,
Expand All @@ -385,7 +452,21 @@ def _xformers_seqpar_matmul_kernel(
EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
A, B, C, M, N, pid_m, pid_n, other_rank, num_blocks_2d = determine_tile(
(
A,
B,
C,
M,
N,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
pid_m,
pid_n,
other_rank,
num_blocks_2d,
) = determine_tile(
A,
B1,
B2,
Expand All @@ -405,7 +486,18 @@ def _xformers_seqpar_matmul_kernel(
world_size,
direction,
stride_am,
stride_cm,
stride_bk1,
stride_bk2,
stride_bk3,
stride_bn1,
stride_bn2,
stride_bn3,
stride_cm1,
stride_cm2,
stride_cm3,
stride_cn1,
stride_cn2,
stride_cn3,
BLOCK_M,
BLOCK_N,
GROUP_M,
Expand Down Expand Up @@ -504,13 +596,11 @@ def _launch_triton_matmul(
assert all(c.shape[0] == M for c in cs)
assert all(c.shape[1] == N for c, N in zip(cs, Ns))
stride_am, stride_ak = cast(Tuple[int, int], a.stride())
stride_bk, stride_bn = cast(Tuple[int, int], bs[0].stride())
stride_cm, stride_cn = cast(Tuple[int, int], cs[0].stride())
assert all(b.stride() == (stride_bk, stride_bn) for b in bs)
assert all(c.stride() == (stride_cm, stride_cn) for c in cs)
strides_bk, strides_bn = zip(*(cast(Tuple[int, int], b.stride()) for b in bs))
strides_cm, strides_cn = zip(*(cast(Tuple[int, int], c.stride()) for c in cs))
assert stride_am == 1 or stride_ak == 1
assert stride_bk == 1 or stride_bn == 1
assert stride_cm == 1 or stride_cn == 1
assert all(s == 1 for s in strides_bk) or all(s == 1 for s in strides_bn)
assert all(s == 1 for s in strides_cm) or all(s == 1 for s in strides_cn)

if a_my_shard is not None:
assert a_my_shard.ndim == 2
Expand Down Expand Up @@ -615,10 +705,18 @@ def grid(META):
kwargs = dict(
stride_am=stride_am,
stride_ak=stride_ak,
stride_bk=stride_bk,
stride_bn=stride_bn,
stride_cm=stride_cm,
stride_cn=stride_cn,
stride_bk1=strides_bk[0],
stride_bk2=strides_bk[min(1, len(strides_bk) - 1)],
stride_bk3=strides_bk[min(2, len(strides_bk) - 1)],
stride_bn1=strides_bn[0],
stride_bn2=strides_bn[min(1, len(strides_bn) - 1)],
stride_bn3=strides_bn[min(2, len(strides_bn) - 1)],
stride_cm1=strides_cm[0],
stride_cm2=strides_cm[min(1, len(strides_cm) - 1)],
stride_cm3=strides_cm[min(2, len(strides_cm) - 1)],
stride_cn1=strides_cn[0],
stride_cn2=strides_cn[min(1, len(strides_cn) - 1)],
stride_cn3=strides_cn[min(2, len(strides_cn) - 1)],
do_wait=do_wait,
do_write=do_write,
direction=direction,
Expand Down

0 comments on commit 8894c69

Please sign in to comment.