Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
scxiao committed May 3, 2024
1 parent 2ea2b49 commit 50d747c
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions xformers/benchmarks/benchmark_attn_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,17 @@ class AttentionDecodingCKSplitKV(AttentionDecodingBase):

class AttentionDecodingSplitInt4KV(AttentionDecodingBase):
OP = xops.fmha.triton_splitk.FwOp
def __init__(self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: bool,
attn_bias_type

def __init__(
self,
B: int,
Mq: int,
Mkv: int,
Hq: int,
Hkv: int,
K: int,
bw: bool,
attn_bias_type,
) -> None:
# super(AttentionDecodingSplitInt4KV, self).__init__(B, Mq, Mkv, Hq, Hkv, K, bw, attn_bias_type)
dtype = torch.float16
Expand Down Expand Up @@ -255,7 +264,7 @@ def __init__(self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: boo
not_supported_reasons = self.OP.not_supported_reasons(inp)
if not_supported_reasons:
raise NotSupportedInputError(not_supported_reasons)


class AttentionDecodingPyTorchRepeat(AttentionDecodingBase):
def fw(self) -> None:
Expand Down

0 comments on commit 50d747c

Please sign in to comment.