Skip to content

Commit

Permalink
int4kv test passed
Browse files Browse the repository at this point in the history
  • Loading branch information
scxiao committed Apr 19, 2024
1 parent 39a3f0d commit 7fe8aa0
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 12 deletions.
12 changes: 1 addition & 11 deletions xformers/benchmarks/benchmark_attn_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,6 @@
)
for i in range(8, 18)
for hkv in (1, 2)
# ] + [
# dict(
# B=2,
# Mq=1,
# Mkv=8448,
# Hq=8,
# Hkv=1,
# K=128,
# attn_bias_type=xops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
# )
]


Expand Down Expand Up @@ -215,7 +205,7 @@ def __init__(self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: boo
[B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
)

num_groups = 4
num_groups = 1
self.k = (
quantize_kv_int4(self.k, num_groups=num_groups)
.contiguous()
Expand Down
2 changes: 1 addition & 1 deletion xformers/ops/fmha/triton_splitk.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def dequantize(
x_[:, :, None, :] >> offsets
) # (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)

quant_offset = tl.reshape(
quant_offset = tl.view(
quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)
)
# Trick - instead of converting int4 to float16 we view it as float16
Expand Down

0 comments on commit 7fe8aa0

Please sign in to comment.