From 60d5f11154922e9780327bfb495d7773ecc93457 Mon Sep 17 00:00:00 2001
From: Shucai Xiao
Date: Wed, 8 May 2024 11:51:47 -0500
Subject: [PATCH] Integrate int4KV into benchmark_attn_decoding.py (#1029)

---
 .../benchmarks/benchmark_attn_decoding.py | 126 ++++++++++++++++++-
 1 file changed, 125 insertions(+), 1 deletion(-)

diff --git a/xformers/benchmarks/benchmark_attn_decoding.py b/xformers/benchmarks/benchmark_attn_decoding.py
index 7c2f2db8fb..da79b78801 100644
--- a/xformers/benchmarks/benchmark_attn_decoding.py
+++ b/xformers/benchmarks/benchmark_attn_decoding.py
@@ -34,6 +34,41 @@
 ]
 
 
+def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
+    """
+    Auxiliary int4 row quantization function used for benchmarking and tests.
+    Matches the behaviour of torch.ops.llama_cpp.dequantize_int4_cache:
+    quantization parameters (scale and offset) of each row along the last
+    dimension of the tensor are assumed to be packed into two float16 values
+    at the beginning of the row.
+    """
+    # Scale and shift are such that quantization linearly maps the int4 value
+    # range [0..15] to the input range min(k)..max(k) individually per row.
+    k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups)
+    max_vals = torch.max(k, dim=-1, keepdim=True).values
+    min_vals = torch.min(k, dim=-1, keepdim=True).values
+    scale_k: torch.Tensor = (max_vals - min_vals) / 15
+
+    shift_k = min_vals  # reuse the row minima computed above
+    scale_k = scale_k.to(torch.float16)
+    shift_k = shift_k.to(torch.float16)
+    in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5
+    in_bytes = in_bytes.to(torch.uint8)
+    in_int4 = in_bytes & 0xF
+    in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4)
+    scale_shift = torch.concat(
+        [scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1
+    )
+    k_quant = torch.concat(
+        [
+            scale_shift.flatten(start_dim=-2),
+            in_int4_packed.flatten(start_dim=-2),
+        ],
+        dim=-1,
+    ).view(torch.int16)
+    return k_quant
+
+
 class AttentionDecodingBase:
     OP: Any = None
 
@@ -141,6 +176,94 @@ class AttentionDecodingCKSplitKV(AttentionDecodingBase):
     OP = xops.fmha.ck_splitk.FwOp
 
 
+class AttentionDecodingSplitInt4KV(AttentionDecodingBase):
+    OP = xops.fmha.triton_splitk.FwOp
+
+    def __init__(
+        self,
+        B: int,
+        Mq: int,
+        Mkv: int,
+        Hq: int,
+        Hkv: int,
+        K: int,
+        bw: bool,
+        attn_bias_type,
+    ) -> None:
+        dtype = torch.float16
+        device = "cuda"
+        torch.manual_seed(10)
+        self.sub_label = (
+            f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K} TotalBytes="
+            f"{((B * Mkv * Hkv * K * 2) + (B * Mq * Hq * K) + (B * Mq * Hq * K)) * 2}"
+        )
+        self.label = "attn_decoding"
+        self.shapes = (B, Mq, Mkv, Hq, Hkv, K)
+
+        assert Hkv <= Hq
+        assert Hq % Hkv == 0
+        self.q = torch.randn(
+            [B, Mq, Hkv, Hq // Hkv, K], device=device, dtype=dtype, requires_grad=bw
+        )
+        self.k = torch.randn(
+            [B, Mkv, Hkv, 1, K], device=device, dtype=dtype, requires_grad=bw
+        )
+        self.v = torch.randn(
+            [B, Mkv, Hkv, 1, K], device=device, dtype=dtype, requires_grad=bw
+        )
+
+        num_groups = 1
+        self.k = (
+            quantize_kv_int4(self.k, num_groups=num_groups)
+            .contiguous()
+            .view(torch.int32)
+        ).expand(-1, -1, -1, Hq // Hkv, -1)
+        self.v = (
+            quantize_kv_int4(self.v, num_groups=num_groups)
+            .contiguous()
+            .view(torch.int32)
+        ).expand(-1, -1, -1, Hq // Hkv, -1)
+
+        if Hq == Hkv:
+            self.q = self.q[:, :, :, 0]
+            self.k = self.k[:, :, :, 0]
+            self.v = self.v[:, :, :, 0]
+        if Hkv == 1:
+            self.q = self.q[:, :, 0]
+            self.k = self.k[:, :, 0]
+            self.v = self.v[:, :, 0]
+
+        self.attn_bias = create_attn_bias(
+            attn_bias_type,
+            batch_size=B,
+            num_heads=Hq,
+            num_heads_groups=Hq // Hkv,
+            q_len=Mq,
+            kv_len=Mkv,
+            dtype=dtype,
+            device=device,
+            requires_grad=False,
+            fmt="BMHK",
+            op=self.OP,
+        )
+
+        if isinstance(
+            self.attn_bias,
+            xops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
+        ):
+            self.q = self.q.view(1, -1, *self.q.shape[2:])
+            self.k = self.k.view(1, -1, *self.k.shape[2:])
+            self.v = self.v.view(1, -1, *self.v.shape[2:])
+
+        if hasattr(self.OP, "not_supported_reasons"):
+            inp = xops.fmha.Inputs(
+                query=self.q, key=self.k, value=self.v, attn_bias=self.attn_bias
+            )
+            not_supported_reasons = self.OP.not_supported_reasons(inp)
+            if not_supported_reasons:
+                raise NotSupportedInputError(not_supported_reasons)
+
+
 class AttentionDecodingPyTorchRepeat(AttentionDecodingBase):
     def fw(self) -> None:
         B, Mq, Mkv, Hq, Hkv, K = self.shapes
@@ -148,7 +271,7 @@ def fw(self) -> None:
         q = self.q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
         k = self.k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
         v = self.v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
-        attn = (q @ k.transpose(-1, -2)).softmax(-1) * scale
+        attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
         return attn @ v
 
 
@@ -172,6 +295,7 @@ def fw(self) -> None:
 
 if (sys.version_info.major, sys.version_info.minor) >= (3, 9):
     BENCHMARKS["triton_splitK"] = AttentionDecodingSplitKV
+    BENCHMARKS["triton_int4KV"] = AttentionDecodingSplitInt4KV
 
 try:
     import flash_attn
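
For readers who want to inspect the packed format, a minimal sketch of the matching
dequantization is below. dequantize_kv_int4 is a hypothetical helper, not part of this
patch or of xformers; it simply inverts the layout produced by quantize_kv_int4 above,
operating on the int16 tensor that function returns (before the benchmark's
.view(torch.int32)/expand). Per row: 4 * num_groups bytes of fp16 (scale, shift) pairs,
then the nibble-packed values, with element 2j in the low nibble of byte j and element
2j+1 in the high nibble.

import torch

def dequantize_kv_int4(k_quant: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
    # Hypothetical inverse of quantize_kv_int4, for illustration only.
    raw = k_quant.view(torch.uint8)
    # First 4 * num_groups bytes per row: fp16 (scale, shift) for each group.
    params = raw[..., : 4 * num_groups].reshape(*raw.shape[:-1], num_groups, 4)
    scale = params[..., 0:2].contiguous().view(torch.float16)  # [..., num_groups, 1]
    shift = params[..., 2:4].contiguous().view(torch.float16)
    # Remaining bytes: two int4 values per byte, per group.
    data = raw[..., 4 * num_groups :].reshape(*raw.shape[:-1], num_groups, -1)
    low = data & 0xF   # elements 0, 2, 4, ...
    high = data >> 4   # elements 1, 3, 5, ...
    vals = torch.stack([low, high], dim=-1).flatten(start_dim=-2).to(torch.float16)
    return (vals * scale + shift).flatten(start_dim=-2)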
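
A quick round-trip smoke test of the sketch (again illustrative; assumes CPU tensors
and num_groups=1). The reconstruction error per row should be bounded by half a
quantization step plus a small slack for fp16 rounding of scale and shift:

k = torch.randn(2, 16, 4, 128, dtype=torch.float16)
k_rec = dequantize_kv_int4(quantize_kv_int4(k))
step = (k.amax(dim=-1, keepdim=True) - k.amin(dim=-1, keepdim=True)) / 15
assert ((k_rec - k).abs() <= step / 2 + 1e-2).all()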