diff --git a/xformers/benchmarks/benchmark_attn_decoding.py b/xformers/benchmarks/benchmark_attn_decoding.py
index 8f8a7a6783..21a9e82d5e 100644
--- a/xformers/benchmarks/benchmark_attn_decoding.py
+++ b/xformers/benchmarks/benchmark_attn_decoding.py
@@ -31,16 +31,6 @@
     )
     for i in range(8, 18)
     for hkv in (1, 2)
-# ] + [
-#     dict(
-#         B=2,
-#         Mq=1,
-#         Mkv=8448,
-#         Hq=8,
-#         Hkv=1,
-#         K=128,
-#         attn_bias_type=xops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
-#     )
 ]


@@ -215,7 +205,7 @@ def __init__(self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: boo
             [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
         )

-        num_groups = 4
+        num_groups = 1
         self.k = (
             quantize_kv_int4(self.k, num_groups=num_groups)
             .contiguous()
diff --git a/xformers/ops/fmha/triton_splitk.py b/xformers/ops/fmha/triton_splitk.py
index 609e48a1ba..9c544070f4 100644
--- a/xformers/ops/fmha/triton_splitk.py
+++ b/xformers/ops/fmha/triton_splitk.py
@@ -637,7 +637,7 @@ def dequantize(
         x_[:, :, None, :] >> offsets
     )  # (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)
-    quant_offset = tl.reshape(
+    quant_offset = tl.view(
         quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)
     )
     # Trick - instead of converting int4 to float16 we view it as float16
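For context, below is a minimal CPU-side sketch (plain PyTorch, not the Triton kernel itself) of the unpack-and-flatten step that the `triton_splitk.py` hunk touches. The tile sizes, the `& 0xF` mask, and the use of `reshape` in place of `tl.view` are illustrative assumptions only; the real kernel keeps the shifted values and recovers fp16 through the bit-view trick mentioned in its comment.

```python
import torch

# Illustrative tile sizes, not taken from the kernel.
BLOCK_N = 4                 # rows in the tile
BLOCK_DMODEL_PACKED = 16    # packed feature dim, i.e. D // PACKED_PER_VAL
PACKED_PER_VAL = 8          # int4 values packed into each int32 word

x_ = torch.randint(0, 2**31 - 1, (BLOCK_N, BLOCK_DMODEL_PACKED), dtype=torch.int32)
offsets = torch.arange(PACKED_PER_VAL, dtype=torch.int32) * 4   # bit offsets 0, 4, ..., 28

# (BLOCK_N, BLOCK_DMODEL_PACKED, PACKED_PER_VAL): expose one nibble per slot.
# The mask is only for illustration; the kernel handles the nibbles differently.
quant_offset = (x_[:, :, None] >> offsets) & 0xF

# Flatten to (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) -- the layout
# relabelling that the tl.reshape -> tl.view change expresses in the kernel.
quant_offset = quant_offset.reshape(BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)
print(quant_offset.shape)   # torch.Size([4, 128])
```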