first_seqpos and seqpos inputs for rope_padded (fairinternal/xformers…
bottler authored and xFormers Bot committed Apr 4, 2024
1 parent 3fea61a commit b3fe601
Showing 3 changed files with 85 additions and 1 deletion.
46 changes: 46 additions & 0 deletions tests/test_rope_padded.py
@@ -235,3 +235,49 @@ def test_rope_prefill(seqlen) -> None:
expected_out = _slow_rope2(xq, seqpos=seqpos, adjacents=adjacents)
atol, rtol = ROPE_ATOL_RTOL["bf16"]
assert_allclose(out, expected_out, atol=atol, rtol=rtol)


@cuda_sm80_only
def test_rope_seqpos() -> None:
heads, kvheads = 2, 1
dim = 32
device = "cuda"
adjacents = True
dtype = torch.bfloat16
seqlen = 723

attn_bias = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=[seqlen], kv_padding=seqlen + 1, kv_seqlen=[seqlen]
)
cache_k = torch.rand(1, seqlen + 1, kvheads, dim, device=device, dtype=dtype)
cache_v = torch.randn_like(cache_k)
xq = torch.rand(1, seqlen, heads, dim, device=device, dtype=dtype)
xk = torch.rand(1, seqlen, kvheads, dim, device=device, dtype=dtype)
xv = torch.rand(1, seqlen, kvheads, dim, device=device, dtype=dtype)

def inner(seqpos, *, first_seqpos_input=None, seqpos_input=None):
out = rope_padded(
xq,
xk,
xv,
cache_k,
cache_v,
attn_bias,
adjacents=adjacents,
first_seqpos=first_seqpos_input,
seqpos=seqpos_input,
)

expected_out = _slow_rope2(xq, seqpos=seqpos, adjacents=adjacents)
atol, rtol = ROPE_ATOL_RTOL["bf16"]
assert_allclose(out, expected_out, atol=atol, rtol=rtol)

inner(torch.arange(start=0, end=seqlen, device=device))
inner(
torch.arange(start=4, end=seqlen + 4, device=device),
first_seqpos_input=torch.tensor([4], device=device),
)
custom_seqpos = torch.arange(start=0, end=seqlen, device=device)
custom_seqpos[231] = 934
custom_seqpos[423] = 134
inner(custom_seqpos, seqpos_input=custom_seqpos)
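
For readers who want the rotation written out, here is a rough standalone reference for RoPE with adjacent pairing, in the spirit of the _slow_rope2 helper the test compares against. It is only a sketch: the function name, signature, and exact conventions are assumptions, not the repo's implementation.

    import torch

    def reference_rope_adjacent(x: torch.Tensor, seqpos: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
        # x: (batch, seqlen, heads, dim) with even dim; seqpos: (seqlen,) integer positions.
        dim = x.shape[-1]
        inv_freq = theta ** (-torch.arange(0, dim, 2, device=x.device, dtype=torch.float32) / dim)
        angles = seqpos[:, None].to(torch.float32) * inv_freq[None, :]  # (seqlen, dim // 2)
        cos = angles.cos()[None, :, None, :]  # broadcast over batch and heads
        sin = angles.sin()[None, :, None, :]
        x_even, x_odd = x.float()[..., 0::2], x.float()[..., 1::2]  # adjacent pairs rotated together
        out = torch.empty_like(x, dtype=torch.float32)
        out[..., 0::2] = x_even * cos - x_odd * sin
        out[..., 1::2] = x_even * sin + x_odd * cos
        return out.to(x.dtype)
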
10 changes: 9 additions & 1 deletion xformers/ops/_triton/rope_padded_kernels.py
@@ -26,6 +26,8 @@ def _rope_padded_kernel(
seqstartk,
seqlenk,
theta,
first_seqpos,
seqpos,
k_start: tl.constexpr,
v_start: tl.constexpr,
n_groups,
@@ -51,6 +53,7 @@
stride_outqM,
stride_outqG,
stride_outqH,
stride_seqpos,
internal_dtype: tl.constexpr,
# If True, seqstartq and seqstartk are not used but rather we
# assume that every batch element has the same number of
@@ -130,7 +133,12 @@
)

cache_pos = end_of_batch_elt_cache - (end_query_pos - query_pos)
seq_pos = cache_pos - cache_start
if seqpos is not None:
seq_pos = tl.load(seqpos + query_pos * stride_seqpos)
else:
seq_pos = cache_pos - cache_start
if first_seqpos is not None:
seq_pos += tl.load(first_seqpos + batch_elt * stride_seqpos)
cache_k += (
(head_idx - k_start) * stride_cachekH
+ cache_pos * stride_cachekM
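
In plain Python, the position selection the kernel now performs looks roughly like the sketch below. The names mirror the kernel arguments, but this is only an illustration of the logic, not the Triton code; note that the Python wrapper forbids passing both seqpos and first_seqpos at once.

    def rotation_position(query_pos, cache_pos, cache_start, batch_elt, seqpos=None, first_seqpos=None):
        # An explicit per-query `seqpos` wins outright; otherwise the position is the
        # query's slot in the cache, optionally shifted by a per-batch-element
        # `first_seqpos` offset.
        if seqpos is not None:
            return int(seqpos[query_pos])
        seq_pos = cache_pos - cache_start
        if first_seqpos is not None:
            seq_pos += int(first_seqpos[batch_elt])
        return seq_pos
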
30 changes: 30 additions & 0 deletions xformers/ops/rope_padded.py
Expand Up @@ -23,6 +23,8 @@ def rope_padded(
*,
theta: float = 10000.0,
out_q: Optional[torch.Tensor] = None,
first_seqpos: Optional[torch.Tensor] = None,
seqpos: Optional[torch.Tensor] = None,
adjacents: bool = True,
internal_dtype: str = "",
):
@@ -57,6 +59,15 @@
Used to determine frequencies for the
RoPE calculation as well as the locations in cache_k and cache_v
to write to. Must be on the device.
first_seqpos: Optionally a tensor containing the sequence position of the
beginning of the cache for each batch element.
Providing a tensor of zeros is the same as providing None.
This affects the numerical calculation but not which memory
locations are read or written.
seqpos: Optionally a 1D tensor containing the sequence position of each
query. This should have length equal to xq.shape[1].
This affects the numerical calculation but not which memory
locations are read or written.
adjacents: If True, the inputs are in adjacent pairs along the final dim axis.
This is like the released LLaMA model.
If False, the dim axis is split in two equal pieces.
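
A minimal usage sketch of the two new arguments, modelled on the new test above; the import paths and the exact shapes here are assumptions chosen for illustration.

    import torch
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
    from xformers.ops.rope_padded import rope_padded

    n, heads, kvheads, dim = 8, 2, 1, 32
    device, dtype = "cuda", torch.bfloat16
    attn_bias = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=[n], kv_padding=n + 1, kv_seqlen=[n]
    )
    xq = torch.rand(1, n, heads, dim, device=device, dtype=dtype)
    xk = torch.rand(1, n, kvheads, dim, device=device, dtype=dtype)
    xv = torch.rand(1, n, kvheads, dim, device=device, dtype=dtype)
    cache_k = torch.rand(1, n + 1, kvheads, dim, device=device, dtype=dtype)
    cache_v = torch.rand_like(cache_k)

    # Rotate as if the first cached token sits at sequence position 4 ...
    out = rope_padded(
        xq, xk, xv, cache_k, cache_v, attn_bias,
        first_seqpos=torch.tensor([4], device=device),
    )
    # ... or, for this simple case, equivalently give every query its position explicitly.
    out = rope_padded(
        xq, xk, xv, cache_k, cache_v, attn_bias,
        seqpos=torch.arange(4, n + 4, device=device),
    )

Passing both first_seqpos and seqpos raises a ValueError, as the validation code in the next hunk enforces.
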
@@ -180,6 +191,22 @@

logical_bsz = len(attn_bias.q_seqinfo.seqstart_py) - 1

if first_seqpos is not None and seqpos is not None:
raise ValueError("seqpos and first_seqpos may not both be provided")
stride_seqpos = 0
if first_seqpos is not None:
if first_seqpos.shape != (logical_bsz,):
shape = tuple(first_seqpos.shape)
raise ValueError(
f"first_seqpos.shape {shape} but ({logical_bsz},) expected."
)
stride_seqpos = first_seqpos.stride(0)
elif seqpos is not None:
if seqpos.shape != (n_total_queries,):
shape = tuple(seqpos.shape)
raise ValueError(f"seqpos.shape {shape} but ({n_total_queries},) expected.")
stride_seqpos = seqpos.stride(0)

# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // xq.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(dim))
@@ -210,6 +237,8 @@
seqstartk,
seqlenk,
theta,
first_seqpos,
seqpos,
k_start,
v_start,
n_groups,
@@ -235,6 +264,7 @@
out_q_stride[1],
out_q_stride[2] if ndim == 5 else 0,
out_q_stride[-2],
stride_seqpos,
internal_dtype,
const_batch_strides=False,
cache_padding_length=0,
