[Relax] Integrate cuDNN attention
vinx13 committed Jul 15, 2024
1 parent 820f1b6 commit abbcce9
Showing 18 changed files with 763 additions and 315 deletions.
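
For context, a minimal end-to-end sketch (module structure, shapes, and dtypes are assumed, not taken from the commit) of the offload path this change enables: a stacked-QKV attention in the BS3NH layout is partitioned with partition_for_cudnn, so the split/reshape/attention region is fused into a "cudnn.attention.BS3NH" composite function and a workspace is allocated for it.

import tvm
from tvm.relax.backend.contrib.cudnn import partition_for_cudnn
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class Module:
    @R.function
    def main(
        qkv: R.Tensor((4, 16, 1536), "float16"),  # BS3NH: batch=4, seq=16, 3*num_heads*head_size = 3*8*64
    ) -> R.Tensor((4, 16, 8, 64), "float16"):
        with R.dataflow():
            qkv_tuple = R.split(qkv, indices_or_sections=3, axis=2)
            q_raw = qkv_tuple[0]
            k_raw = qkv_tuple[1]
            v_raw = qkv_tuple[2]
            q = R.reshape(q_raw, (4, 16, 8, 64))
            k = R.reshape(k_raw, (4, 16, 8, 64))
            v = R.reshape(v_raw, (4, 16, 8, 64))
            out = R.nn.attention(q, k, v)
            R.output(out)
        return out


mod = partition_for_cudnn(Module)
mod.show()  # the split/reshape/attention region becomes a "cudnn.attention.BS3NH" composite
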
16 changes: 16 additions & 0 deletions cmake/modules/CUDA.cmake
@@ -57,6 +57,22 @@ if(USE_CUDA)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIBRARY})
endif(USE_CUDNN)

if (USE_CUDNN_FRONTEND)
message(STATUS "Build with cuDNN Frontend support")
if (IS_DIRECTORY ${USE_CUDNN_FRONTEND})
find_file(CUDNN_FRONTEND_HEADER cudnn_frontend.h HINTS ${USE_CUDNN_FRONTEND}/include)
include_directories(SYSTEM ${USE_CUDNN_FRONTEND}/include)
else()
find_file(CUDNN_FRONTEND_HEADER cudnn_frontend.h)
endif()
if (NOT CUDNN_FRONTEND_HEADER)
message(FATAL_ERROR "Cannot find cudnn_frontend.h, please set USE_CUDNN_FRONTEND to the path of the cuDNN frontend header")
endif()
tvm_file_glob(GLOB CONTRIB_CUDNN_FRONTEND_SRCS src/runtime/contrib/cudnn/cudnn_frontend/*.cc)
set_property(SOURCE ${CONTRIB_CUDNN_SRCS} APPEND PROPERTY COMPILE_DEFINITIONS TVM_USE_CUDNN_FRONTEND=1)
list(APPEND RUNTIME_SRCS ${CONTRIB_CUDNN_FRONTEND_SRCS})
endif(USE_CUDNN_FRONTEND)

if(USE_CUBLAS)
message(STATUS "Build with cuBLAS support")
tvm_file_glob(GLOB CUBLAS_CONTRIB_SRC src/relay/backend/contrib/cublas/*.cc src/relax/backend/contrib/cublas/*.cc)
32 changes: 12 additions & 20 deletions python/tvm/contrib/cutlass/build.py
@@ -868,34 +868,26 @@ def handle_attention(self, f, op_type):
signature = _extract_relax_function_signature(f)

if _get_call_node(f.body, "relax.nn.attention") is not None:
op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs
attention_node = _get_call_node(f.body, "relax.nn.attention")
op_attrs = attention_node.attrs
elif _get_call_node(f.body, "relax.nn.attention_bias") is not None:
op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs
attention_node = _get_call_node(f.body, "relax.nn.attention_bias")
op_attrs = attention_node.attrs
elif _get_call_node(f.body, "relax.nn.attention_var_len") is not None:
op_attrs = _get_call_node(f.body, "relax.nn.attention_var_len").attrs
attention_node = _get_call_node(f.body, "relax.nn.attention_var_len")
op_attrs = attention_node.attrs
else:
raise ValueError("Cannot find call node for attention")
arg = {}

if "stacked_attention" in op_type:
arg["arg0_shape"] = signature["arg0_shape"]
arg["arg0_dtype"] = signature["arg0_dtype"]
arg["arg1_shape"] = q_shape = signature["arg1_shape"]

if "arg3_shape" not in signature:
# arg0: qkv, arg1: shape, arg2: workspace
arg["arg2_shape"] = k_shape = signature["arg1_shape"]
arg["arg3_shape"] = v_shape = signature["arg1_shape"]
else:
# arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4: workspace
arg["arg2_shape"] = k_shape = signature["arg2_shape"]
arg["arg3_shape"] = v_shape = signature["arg3_shape"]

if "arg5_dtype" in signature:
# arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4: bias, arg5: workspace
arg["bias_dtype"] = signature["arg4_dtype"]
if "arg5_shape" in signature:
arg["bias_shape"] = signature["arg4_shape"]
q_shape = get_const_tuple(attention_node.args[0].struct_info.shape)
k_shape = get_const_tuple(attention_node.args[1].struct_info.shape)
v_shape = get_const_tuple(attention_node.args[2].struct_info.shape)
if len(attention_node.args) == 4:
arg["bias_shape"] = get_const_tuple(attention_node.args[3].struct_info.shape)
arg["bias_dtype"] = attention_node.args[3].struct_info.dtype

qkv_layout = "qkv_stacked"
else:
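
As an aside, a minimal sketch (tensor shapes are assumed) of the struct_info-based shape extraction that handle_attention now relies on, instead of parsing the outer function signature:

from tvm import relax
from tvm.topi.utils import get_const_tuple

# Hypothetical q/k/v arguments with static shapes attached as struct info.
q = relax.Var("q", relax.TensorStructInfo((4, 16, 8, 64), "float16"))
k = relax.Var("k", relax.TensorStructInfo((4, 16, 8, 64), "float16"))
v = relax.Var("v", relax.TensorStructInfo((4, 16, 8, 64), "float16"))
attention_node = relax.op.nn.attention(q, k, v)

q_shape = get_const_tuple(attention_node.args[0].struct_info.shape)
print(q_shape)  # (4, 16, 8, 64)
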
4 changes: 2 additions & 2 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -745,8 +745,8 @@ def get_batch_on_arg(arg_name, arg_shape):
attrs["qkv"] = func_args[0]
attrs["num_queries"] = s = annotations["num_queries"]
attrs["num_keys"] = annotations["num_keys"]
if len(func_args) > 5 and not is_var_len: # +1 for workspace, the last arg
attrs["bias"] = func_args[4]
if len(func_args) > 2 and not is_var_len: # +1 for workspace, the last arg
attrs["bias"] = func_args[1]
else:
raise NotImplementedError()

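
Purely for illustration (argument names invented, layout inferred from the updated indices): with the separate shape arguments gone, the partitioned stacked-attention function appears to receive [qkv, optional bias, workspace], so a bias, when present, sits at index 1 just before the trailing workspace argument.

# Hypothetical argument layouts; not taken from the diff.
without_bias = ["qkv", "workspace"]        # len == 2 -> no bias
with_bias = ["qkv", "bias", "workspace"]   # len == 3 -> bias is func_args[1]

for func_args in (without_bias, with_bias):
    if len(func_args) > 2:                 # the +1 accounts for the trailing workspace arg
        print("bias arg:", func_args[1])
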
99 changes: 96 additions & 3 deletions python/tvm/relax/backend/contrib/cudnn.py
@@ -16,11 +16,16 @@
# under the License.

"""Pattern table for cuDNN backend"""
from tvm.relax import transform
import operator
from functools import partial, reduce

import tvm
from tvm import relax
from tvm.relax import PyExprMutator, expr_functor, transform
from tvm.relax.transform import PatternCheckContext

from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import make_conv2d_pattern
from ..patterns import make_conv2d_pattern, make_stacked_attention_pattern
from ..utils import has_leaking_intermediate_variables


@@ -60,6 +65,29 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
return True


def _check_stacked_attention(context: PatternCheckContext, layout: str) -> bool:
"""Check if the given stacked attention workload can be offloaded to cuDNN."""
if has_leaking_intermediate_variables(context):
return False
if layout == "BS3NH":
if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 3:
return False
if "split" in context.annotated_expr:
split_op = context.annotated_expr["split"]
if not split_op.attrs.axis == 2:
return False
elif layout == "SBN3H":
if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 4:
return False
if "split" in context.annotated_expr:
split_op = context.annotated_expr["split"]
if not split_op.attrs.axis == 3:
return False
else:
raise NotImplementedError(f"Unsupported layout: {layout}")
return True


register_patterns(
[
(
@@ -84,6 +112,16 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
),
_check_conv2d,
),
(
"cudnn.attention.BS3NH",
*make_stacked_attention_pattern(start_op="split", layout="BS3NH"),
partial(_check_stacked_attention, layout="BS3NH"),
),
(
"cudnn.attention.SBN3H",
*make_stacked_attention_pattern(start_op="split", layout="SBN3H"),
partial(_check_stacked_attention, layout="SBN3H"),
),
]
)

@@ -105,4 +143,59 @@ def partition_for_cudnn(mod):
"""

patterns = get_patterns_with_prefix("cudnn")
return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod)
return tvm.transform.Sequential(
[
transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True),
annotate_workspace,
transform.AllocateWorkspace(),
]
)(mod)


def _shape_1d(shape):
return reduce(operator.mul, shape, 1)


@expr_functor.mutator
class WorkspaceAnnotator(PyExprMutator):
"""Annotate a workspace requirement for each cuDNN-offloaded function."""

def __init__(self, mod):
super().__init__(mod)

def visit_function_(self, f):
if "Composite" not in f.attrs:
body = super().visit_expr(f.body)
new_f = relax.Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span)

if "global_symbol" in f.attrs and "cudnn" in f.attrs["global_symbol"]:
composite_func = body.blocks[0].bindings[0].value
if "WorkspaceSize" in composite_func.attrs:
return new_f.with_attr("WorkspaceSize", composite_func.attrs["WorkspaceSize"])

return new_f

if "attention" in f.attrs["Composite"] and "cudnn" in f.attrs["Composite"]:
# Workspace is needed only for larger head sizes, but for simplicity we always allocate.
out_dtype = f.ret_struct_info.dtype
out_size_1d = _shape_1d(f.ret_struct_info.shape)
# This needs to be in sync with the actual value that the kernel expects.
workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype]
if not isinstance(workspace_size_bytes, (int, tvm.tir.expr.IntImm)):
# Temporary workaround for dynamic shape workload. Will be removed when
# workspace for dynamic shape workload is implemented.
workspace_size_bytes = 8
return f.with_attr("WorkspaceSize", workspace_size_bytes)

return f


@tvm.transform.module_pass(opt_level=0)
def annotate_workspace(mod, _):
"""Pass to annotate a workspace requirement for each cuDNN-offloaded function."""
annotator = WorkspaceAnnotator(mod)
for name, f in mod.functions_items():
if isinstance(f, relax.Function):
new_f = annotator.visit_expr(f)
mod.update_func(name, new_f)
return mod
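
As a worked example (values assumed), the workspace size that WorkspaceAnnotator would attach for a float16 attention output of shape (4, 16, 8, 64):

import operator
from functools import reduce

out_shape = (4, 16, 8, 64)                        # assumed static output shape
out_dtype = "float16"
out_size_1d = reduce(operator.mul, out_shape, 1)  # 32768 elements, as computed by _shape_1d
workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype]
print(workspace_size_bytes)                       # 65536
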
18 changes: 12 additions & 6 deletions python/tvm/relax/backend/contrib/cutlass.py
@@ -383,19 +383,25 @@ def _check_stacked_attention(context: PatternCheckContext) -> bool:
if not split_op.attrs.axis == 2:
return False
else:
get_const_int_list = lambda tup: [int(e.value) for e in tup]
last_end = 0
for name in ["query", "key", "value"]:
assert f"strided_slice_{name}" in context.annotated_expr
strided_slice_op = context.annotated_expr[f"strided_slice_{name}"]
if list(strided_slice_op.attrs.axes) != [2]:
axes = get_const_int_list(strided_slice_op.args[1])
begins = get_const_int_list(strided_slice_op.args[2])
ends = get_const_int_list(strided_slice_op.args[3])
strides = get_const_int_list(strided_slice_op.args[4])

if axes != [2]:
return False
if list(strided_slice_op.attrs.begin) != [last_end]:
if begins != [last_end]:
return False
if not len(strided_slice_op.attrs.end) == 1:
if not len(ends) == 1:
return False
last_end = strided_slice_op.attrs.end[0]
if list(strided_slice_op.attrs.strides) != [1]:
if strides != [1]:
return False
last_end = ends[0]
return True
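
A small pure-Python simulation (slice bounds are hypothetical) of the contiguity check above, which now reads the constant axes/begin/end/strides from the strided_slice call arguments rather than from attrs:

# Each entry mimics (axes, begins, ends, strides) as recovered via get_const_int_list.
slices = {
    "query": ([2], [0], [512], [1]),
    "key": ([2], [512], [1024], [1]),
    "value": ([2], [1024], [1536], [1]),
}

last_end = 0
for name in ("query", "key", "value"):
    axes, begins, ends, strides = slices[name]
    # q/k/v must be adjacent, stride-1 slices along axis 2 of the stacked tensor.
    assert axes == [2] and begins == [last_end] and strides == [1] and len(ends) == 1
    last_end = ends[0]
print("stacked q/k/v slices are contiguous")
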


@@ -537,7 +543,7 @@ def visit_function_(self, f):

return new_f

if "attention" in f.attrs["Composite"]:
if "attention" in f.attrs["Composite"] and "cutlass" in f.attrs["Composite"]:
# Workspace is needed only for larger head sizes, but for simplicity we always allocate.
out_dtype = f.ret_struct_info.dtype
out_size_1d = _shape_1d(f.ret_struct_info.shape)
32 changes: 25 additions & 7 deletions python/tvm/relax/backend/patterns.py
@@ -260,7 +260,7 @@ def make_attention_pattern(with_bias: bool = False, var_len: bool = False):
return out, annotations


def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
def make_stacked_attention_pattern(start_op: str, with_bias: bool = False, layout="BS3NH"):
"""
Create pattern for fused multi head attention with stacked input.
@@ -272,6 +272,9 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
with_bias: bool
Whether or not to include bias addition
layout: str
The layout of the stacked input tensor.
Returns
-------
pattern: DFPattern
@@ -290,17 +293,28 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
key_raw = is_tuple_get_item(qkv_tuple, 1)
value_raw = is_tuple_get_item(qkv_tuple, 2)
elif start_op == "strided_slice":
ops["strided_slice_query"] = query_raw = is_op("relax.strided_slice")(stacked_qkv)
ops["strided_slice_key"] = key_raw = is_op("relax.strided_slice")(stacked_qkv)
ops["strided_slice_value"] = value_raw = is_op("relax.strided_slice")(stacked_qkv)
ops["strided_slice_query"] = query_raw = is_op("relax.strided_slice")(
stacked_qkv, varg_default_wildcard=True
)
ops["strided_slice_key"] = key_raw = is_op("relax.strided_slice")(
stacked_qkv, varg_default_wildcard=True
)
ops["strided_slice_value"] = value_raw = is_op("relax.strided_slice")(
stacked_qkv, varg_default_wildcard=True
)
else:
raise NotImplementedError()
query_reshape_list = wildcard()
key_reshape_list = wildcard()
value_reshape_list = wildcard()
query = is_op("relax.reshape")(query_raw, query_reshape_list)
key = is_op("relax.reshape")(key_raw, key_reshape_list)
value = is_op("relax.reshape")(value_raw, value_reshape_list)
if layout == "BS3NH":
query = is_op("relax.reshape")(query_raw, query_reshape_list)
key = is_op("relax.reshape")(key_raw, key_reshape_list)
value = is_op("relax.reshape")(value_raw, value_reshape_list)
elif layout == "SBN3H":
ops["q_transpose"] = query = is_op("relax.permute_dims")(query_raw)
ops["k_transpose"] = key = is_op("relax.permute_dims")(key_raw)
ops["v_transpose"] = value = is_op("relax.permute_dims")(value_raw)
annotations = {
"stacked_qkv": stacked_qkv,
"query_reshape_list": query_reshape_list,
Expand All @@ -314,6 +328,10 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
out = is_op("relax.nn.attention_bias")(query, key, value, bias)
else:
out = is_op("relax.nn.attention")(query, key, value)

if layout == "SBN3H":
out = is_op("relax.permute_dims")(out)

return out, annotations


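
A short sketch of how the new layout parameter is consumed (mirroring the cuDNN pattern registration earlier in this commit); the layout-specific split-axis checks live in the _check_stacked_attention added to contrib/cudnn.py:

from tvm.relax.backend.patterns import make_stacked_attention_pattern

# SBN3H variant: split -> permute_dims on q/k/v -> nn.attention -> permute_dims.
pattern, annotations = make_stacked_attention_pattern(start_op="split", layout="SBN3H")
# `annotations` exposes "stacked_qkv", "q_transpose", "k_transpose", "v_transpose", ...
# for the corresponding check function to inspect.
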
9 changes: 6 additions & 3 deletions python/tvm/relax/frontend/nn/op.py
@@ -1568,11 +1568,14 @@ def scaled_dot_product_attention(
Parameters
----------
query : Tensor
Tensor representing current attention lookup.
Tensor representing current attention lookup of shape
[batch, seq_len, num_heads, head_size].
key : Tensor
Tensor representing cross attention mapping.
Tensor representing cross attention mapping of shape
[batch, seq_len_kv, num_heads_kv, head_size].
value : Tensor
Tensor representing embedded attention values.
Tensor representing embedded attention values of shape
[batch, seq_len_kv, num_heads_kv, head_size_value].
attn_mask : Optional[Tensor]
Optional mask for attention, not yet supported.
is_causal : Optional[bool]
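
A minimal usage sketch (module name and tensor shapes assumed, not part of this change) of the frontend op whose shape documentation is clarified here:

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op


class SDPA(nn.Module):
    def forward(self, q: nn.Tensor, k: nn.Tensor, v: nn.Tensor) -> nn.Tensor:
        # q: [batch, seq_len, num_heads, head_size]
        # k: [batch, seq_len_kv, num_heads_kv, head_size]
        # v: [batch, seq_len_kv, num_heads_kv, head_size_value]
        return op.scaled_dot_product_attention(q, k, v)
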
1 change: 1 addition & 0 deletions python/tvm/relax/testing/__init__.py
@@ -21,3 +21,4 @@
from .relay_translator import *
from .ast_printer import dump_ast
from .matmul import *
from .attention import *