diff --git a/src/sfast/libs/xformers/xformers_attention.py b/src/sfast/libs/xformers/xformers_attention.py index f4942d2..deb4ba9 100644 --- a/src/sfast/libs/xformers/xformers_attention.py +++ b/src/sfast/libs/xformers/xformers_attention.py @@ -4,17 +4,21 @@ from xformers import ops from sfast.utils.custom_python_operator import register_custom_python_operator -OP_STR_MAP = { - ops.MemoryEfficientAttentionCutlassFwdFlashBwOp: +OP_STR_MAP = {} + +for attr_name in [ 'MemoryEfficientAttentionCutlassFwdFlashBwOp', - ops.MemoryEfficientAttentionCutlassOp: 'MemoryEfficientAttentionCutlassOp', - ops.MemoryEfficientAttentionFlashAttentionOp: + 'MemoryEfficientAttentionCutlassOp', 'MemoryEfficientAttentionFlashAttentionOp', - ops.MemoryEfficientAttentionOp: 'MemoryEfficientAttentionOp', - ops.MemoryEfficientAttentionTritonFwdFlashBwOp: + 'MemoryEfficientAttentionOp', 'MemoryEfficientAttentionTritonFwdFlashBwOp', - ops.TritonFlashAttentionOp: 'TritonFlashAttentionOp', -} + 'TritonFlashAttentionOp', + 'MemoryEfficientAttentionCkOp', + 'MemoryEfficientAttentionSplitKCkOp' +]: + op_attr = getattr(ops, attr_name, None) + if op_attr is not None: + OP_STR_MAP[op_attr] = attr_name STR_OP_MAP = {v: k for k, v in OP_STR_MAP.items()}