Integrate int4kv to benchmark_attn_decode.py (#1029)
scxiao authored May 8, 2024
1 parent 8294eba commit 60d5f11
Showing 1 changed file with 127 additions and 1 deletion.
128 changes: 127 additions & 1 deletion xformers/benchmarks/benchmark_attn_decoding.py
@@ -34,6 +34,43 @@
]


def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
    """
    Auxiliary int4 row quantization function used for benchmarking and tests.
    Matches the behaviour of torch.ops.llama_cpp.dequantize_int4_cache -
    quantization parameters (scale and shift) of each row along the last
    dimension of the tensor are assumed to be packed into two float16 values
    at the beginning of the row.
    """
    # Scale and shift are chosen so that quantization linearly maps the int4
    # value range [0..15] to the input range min(k)..max(k) for every row.
    k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups)
    max_vals = torch.max(k, dim=-1, keepdim=True).values
    min_vals = torch.min(k, dim=-1, keepdim=True).values
    scale_k: torch.Tensor = (max_vals - min_vals) / 15

    shift_k = min_vals
    scale_k = scale_k.to(torch.float16)
    shift_k = shift_k.to(torch.float16)
    in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5
    in_bytes = in_bytes.to(torch.uint8)
    in_int4 = in_bytes & 0xF
    in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4)
    scale_shift = torch.concat(
        [scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1
    )
    k_quant = torch.concat(
        [
            scale_shift.flatten(start_dim=-2),
            in_int4_packed.flatten(start_dim=-2),
        ],
        dim=-1,
    ).view(torch.int16)
    return k_quant
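A minimal round-trip sketch (an editorial illustration, not part of this commit), assuming the packing layout described in the docstring: each output row holds two float16 values (scale, then shift) followed by the int4 data packed two values per byte, low nibble first.

import torch

k = torch.randn(2, 128, 1, 16, dtype=torch.float16)
k_quant = quantize_kv_int4(k, num_groups=1)   # int16, last dim = 2 + 16 // 4 = 6
rows = k_quant.view(torch.uint8)              # 12 bytes per row
scale = rows[..., 0:2].view(torch.float16)    # bytes 0-1: scale
shift = rows[..., 2:4].view(torch.float16)    # bytes 2-3: shift
packed = rows[..., 4:]                        # 8 bytes holding 16 int4 values
lo = (packed & 0xF).to(torch.float16)         # even elements (low nibbles)
hi = (packed >> 4).to(torch.float16)          # odd elements (high nibbles)
dequant = torch.stack([lo, hi], dim=-1).flatten(start_dim=-2) * scale + shift
print((dequant - k).abs().max())              # roughly bounded by scale / 2 per row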


class AttentionDecodingBase:
    OP: Any = None

@@ -141,14 +178,102 @@ class AttentionDecodingCKSplitKV(AttentionDecodingBase):
    OP = xops.fmha.ck_splitk.FwOp


class AttentionDecodingSplitInt4KV(AttentionDecodingBase):
    OP = xops.fmha.triton_splitk.FwOp

    def __init__(
        self,
        B: int,
        Mq: int,
        Mkv: int,
        Hq: int,
        Hkv: int,
        K: int,
        bw: bool,
        attn_bias_type,
    ) -> None:
        # Re-implements AttentionDecodingBase.__init__ so that K and V can be
        # int4-quantized before the benchmark runs.
        dtype = torch.float16
        torch.manual_seed(10)
        self.sub_label = (
            f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K} TotalBytes="
            f"{((B * Mkv * Hkv * K * 2) + (B * Mq * Hq * K) + (B * Mq * Hq * K)) * 2}"
        )
        self.label = "attn_decoding"
        self.shapes = (B, Mq, Mkv, Hq, Hkv, K)

        assert Hkv <= Hq
        assert Hq % Hkv == 0
        self.q = torch.randn(
            [B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=bw
        )
        self.k = torch.randn(
            [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
        )
        self.v = torch.randn(
            [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
        )

        num_groups = 1
        self.k = (
            quantize_kv_int4(self.k, num_groups=num_groups)
            .contiguous()
            .view(torch.int32)
        ).expand(-1, -1, -1, Hq // Hkv, -1)
        self.v = (
            quantize_kv_int4(self.v, num_groups=num_groups)
            .contiguous()
            .view(torch.int32)
        ).expand(-1, -1, -1, Hq // Hkv, -1)

        if Hq == Hkv:
            self.q = self.q[:, :, :, 0]
            self.k = self.k[:, :, :, 0]
            self.v = self.v[:, :, :, 0]
        if Hkv == 1:
            self.q = self.q[:, :, 0]
            self.k = self.k[:, :, 0]
            self.v = self.v[:, :, 0]

        self.attn_bias = create_attn_bias(
            attn_bias_type,
            batch_size=B,
            num_heads=Hq,
            num_heads_groups=Hq // Hkv,
            q_len=Mq,
            kv_len=Mkv,
            dtype=dtype,
            device=device,
            requires_grad=False,
            fmt="BMHK",
            op=self.OP,
        )

        if isinstance(
            self.attn_bias,
            xops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
        ):
            self.q = self.q.view(1, -1, *self.q.shape[2:])
            self.k = self.k.view(1, -1, *self.k.shape[2:])
            self.v = self.v.view(1, -1, *self.v.shape[2:])

        if hasattr(self.OP, "not_supported_reasons"):
            inp = xops.fmha.Inputs(
                query=self.q, key=self.k, value=self.v, attn_bias=self.attn_bias
            )
            not_supported_reasons = self.OP.not_supported_reasons(inp)
            if not_supported_reasons:
                raise NotSupportedInputError(not_supported_reasons)
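For orientation, a rough shape check (illustrative sizes, not from the commit): with num_groups=1 each quantized row packs one (scale, shift) float16 pair (4 bytes) plus K int4 values (K/2 bytes), so for K=128 that is 68 bytes per row, i.e. 17 elements after the .view(torch.int32) above.

import torch

k = torch.randn(2, 4096, 2, 1, 128, dtype=torch.float16)  # [B, Mkv, Hkv, 1, K]
k_q = quantize_kv_int4(k, num_groups=1).contiguous().view(torch.int32)
assert k_q.shape == (2, 4096, 2, 1, 17)                    # K // 8 + 1 = 17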


class AttentionDecodingPyTorchRepeat(AttentionDecodingBase):
    def fw(self) -> None:
        B, Mq, Mkv, Hq, Hkv, K = self.shapes
        scale = 1 / K**0.5
        q = self.q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
        k = self.k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
        v = self.v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
-       attn = (q @ k.transpose(-1, -2)).softmax(-1) * scale
+       attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
        return attn @ v
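The one deleted line above applied the 1/sqrt(K) scale after the softmax, which only rescales the output; the replacement applies it to the logits before the softmax, the standard attention scaling, which changes the attention weights themselves. A tiny illustration (editorial, not part of the commit):

import torch

logits = torch.tensor([[1.0, 2.0, 3.0]])
scale = 0.5
print(logits.softmax(-1) * scale)     # weights now sum to 0.5: output merely rescaled
print((logits * scale).softmax(-1))   # weights sum to 1.0: a different distribution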


@@ -172,6 +297,7 @@ def fw(self) -> None:

if (sys.version_info.major, sys.version_info.minor) >= (3, 9):
BENCHMARKS["triton_splitK"] = AttentionDecodingSplitKV
BENCHMARKS["triton_int4KV"] = AttentionDecodingSplitInt4KV

try:
    import flash_attn
