From 5d5edd2fd8b891bb74681f83095d606739cadfcb Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 22 Jul 2024 12:36:06 -0700 Subject: [PATCH] [Relax] Integrate cuDNN attention (#17157) * [Relax] Integrate cuDNN attention * update cmake * lint * lint * cudnn frontend * lint * lint * fix test * skip test --- cmake/config.cmake | 7 + cmake/modules/CUDA.cmake | 16 ++ python/tvm/contrib/cutlass/build.py | 32 +-- python/tvm/contrib/cutlass/gen_tensor_op.py | 4 +- python/tvm/relax/backend/contrib/cudnn.py | 99 ++++++- python/tvm/relax/backend/contrib/cutlass.py | 18 +- python/tvm/relax/backend/patterns.py | 32 ++- python/tvm/relax/frontend/nn/op.py | 9 +- python/tvm/relax/testing/__init__.py | 1 + python/tvm/relax/testing/attention.py | 148 ++++++++++ python/tvm/topi/testing/__init__.py | 1 + python/tvm/topi/testing/attention_python.py | 122 ++++++++ src/relax/backend/contrib/cudnn/codegen.cc | 47 +++ src/relax/transform/allocate_workspace.cc | 9 +- src/relax/transform/fuse_ops.cc | 19 +- .../contrib/cudnn/cudnn_frontend/attention.cc | 124 ++++++++ .../contrib/cudnn/cudnn_frontend/attention.h | 83 ++++++ .../contrib/cudnn/cudnn_json_runtime.cc | 267 +++++++++++------- tests/python/relax/test_codegen_cudnn.py | 65 ++++- tests/python/relax/test_codegen_cutlass.py | 213 ++++---------- .../test_transform_allocate_workspace.py | 3 +- ...est_transform_merge_composite_functions.py | 5 +- 22 files changed, 1010 insertions(+), 314 deletions(-) create mode 100644 python/tvm/relax/testing/attention.py create mode 100644 python/tvm/topi/testing/attention_python.py create mode 100644 src/runtime/contrib/cudnn/cudnn_frontend/attention.cc create mode 100644 src/runtime/contrib/cudnn/cudnn_frontend/attention.h diff --git a/cmake/config.cmake b/cmake/config.cmake index 416eec0dcb81..26d50630f7d3 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -245,6 +245,13 @@ set(USE_EDGETPU OFF) # - /path/to/cudnn: use specific path to cuDNN path set(USE_CUDNN OFF) +# Whether use cuDNN frontend +# Possible values: +# - ON: enable cuDNN frontend +# - /path/to/cudnn_frontend: use specific path to cuDNN frontend +# - OFF: disable cuDNN frontend +set(USE_CUDNN_FRONTEND OFF) + # Whether use cuBLAS set(USE_CUBLAS OFF) diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index b7b405f82286..ad83ebe26b8c 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -77,6 +77,22 @@ if(USE_CUDA) list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIBRARY}) endif(USE_CUDNN) + if (USE_CUDNN_FRONTEND) + message(STATUS "Build with cuDNN Frontend support") + if (IS_DIRECTORY ${USE_CUDNN_FRONTEND}) + find_file(CUDNN_FRONTEND_HEADER cudnn_frontend.h HINTS ${USE_CUDNN_FRONTEND}/include) + include_directories(SYSTEM ${USE_CUDNN_FRONTEND}/include) + else() + find_file(CUDNN_FRONTEND_HEADER cudnn_frontend.h) + endif() + if (NOT CUDNN_FRONTEND_HEADER) + message(FATAL_ERROR "Cannot find cudnn_frontend.h, please set USE_CUDNN_FRONTEND to the path of the cuDNN frontend header") + endif() + tvm_file_glob(GLOB CONTRIB_CUDNN_FRONTEND_SRCS src/runtime/contrib/cudnn/cudnn_frontend/*.cc) + set_property(SOURCE ${CONTRIB_CUDNN_SRCS} APPEND PROPERTY COMPILE_DEFINITIONS TVM_USE_CUDNN_FRONTEND=1) + list(APPEND RUNTIME_SRCS ${CONTRIB_CUDNN_FRONTEND_SRCS}) + endif(USE_CUDNN_FRONTEND) + if(USE_CUBLAS) message(STATUS "Build with cuBLAS support") tvm_file_glob(GLOB CUBLAS_CONTRIB_SRC src/relay/backend/contrib/cublas/*.cc src/relax/backend/contrib/cublas/*.cc) diff --git a/python/tvm/contrib/cutlass/build.py 
b/python/tvm/contrib/cutlass/build.py index 1c0a30c62d91..5c09c79bd906 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -868,34 +868,26 @@ def handle_attention(self, f, op_type): signature = _extract_relax_function_signature(f) if _get_call_node(f.body, "relax.nn.attention") is not None: - op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs + attention_node = _get_call_node(f.body, "relax.nn.attention") + op_attrs = attention_node.attrs elif _get_call_node(f.body, "relax.nn.attention_bias") is not None: - op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs + attention_node = _get_call_node(f.body, "relax.nn.attention_bias") + op_attrs = attention_node.attrs elif _get_call_node(f.body, "relax.nn.attention_var_len") is not None: - op_attrs = _get_call_node(f.body, "relax.nn.attention_var_len").attrs + attention_node = _get_call_node(f.body, "relax.nn.attention_var_len") + op_attrs = attention_node.attrs else: raise ValueError("Cannot find call node for attention") arg = {} if "stacked_attention" in op_type: - arg["arg0_shape"] = signature["arg0_shape"] arg["arg0_dtype"] = signature["arg0_dtype"] - arg["arg1_shape"] = q_shape = signature["arg1_shape"] - - if "arg3_shape" not in signature: - # arg0: qkv, arg1: shape, arg2: workspace - arg["arg2_shape"] = k_shape = signature["arg1_shape"] - arg["arg3_shape"] = v_shape = signature["arg1_shape"] - else: - # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4: workspace - arg["arg2_shape"] = k_shape = signature["arg2_shape"] - arg["arg3_shape"] = v_shape = signature["arg3_shape"] - - if "arg5_dtype" in signature: - # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4: bias, arg5: workspace - arg["bias_dtype"] = signature["arg4_dtype"] - if "arg5_shape" in signature: - arg["bias_shape"] = signature["arg4_shape"] + q_shape = get_const_tuple(attention_node.args[0].struct_info.shape) + k_shape = get_const_tuple(attention_node.args[1].struct_info.shape) + v_shape = get_const_tuple(attention_node.args[2].struct_info.shape) + if len(attention_node.args) == 4: + arg["bias_shape"] = get_const_tuple(attention_node.args[3].struct_info.shape) + arg["bias_dtype"] = attention_node.args[3].struct_info.dtype qkv_layout = "qkv_stacked" else: diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 2f21a1d313e2..5d04cf13e693 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -745,8 +745,8 @@ def get_batch_on_arg(arg_name, arg_shape): attrs["qkv"] = func_args[0] attrs["num_queries"] = s = annotations["num_queries"] attrs["num_keys"] = annotations["num_keys"] - if len(func_args) > 5 and not is_var_len: # +1 for workspace, the last arg - attrs["bias"] = func_args[4] + if len(func_args) > 2 and not is_var_len: # +1 for workspace, the last arg + attrs["bias"] = func_args[1] else: raise NotImplementedError() diff --git a/python/tvm/relax/backend/contrib/cudnn.py b/python/tvm/relax/backend/contrib/cudnn.py index f730d4e5be0a..2f15e3a4fd19 100644 --- a/python/tvm/relax/backend/contrib/cudnn.py +++ b/python/tvm/relax/backend/contrib/cudnn.py @@ -16,11 +16,16 @@ # under the License. 
"""Pattern table for cuDNN backend""" -from tvm.relax import transform +import operator +from functools import partial, reduce + +import tvm +from tvm import relax +from tvm.relax import PyExprMutator, expr_functor, transform from tvm.relax.transform import PatternCheckContext from ..pattern_registry import get_patterns_with_prefix, register_patterns -from ..patterns import make_conv2d_pattern +from ..patterns import make_conv2d_pattern, make_stacked_attention_pattern from ..utils import has_leaking_intermediate_variables @@ -60,6 +65,29 @@ def _check_conv2d(context: PatternCheckContext) -> bool: return True +def _check_stacked_attention(context: PatternCheckContext, layout: str) -> bool: + """Check if the given stacked attention workload can be offloaded to cuDNN.""" + if has_leaking_intermediate_variables(context): + return False + if layout == "BS3NH": + if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 3: + return False + if "split" in context.annotated_expr: + split_op = context.annotated_expr["split"] + if not split_op.attrs.axis == 2: + return False + elif layout == "SBN3H": + if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 4: + return False + if "split" in context.annotated_expr: + split_op = context.annotated_expr["split"] + if not split_op.attrs.axis == 3: + return False + else: + raise NotImplementedError(f"Unsupported layout: {layout}") + return True + + register_patterns( [ ( @@ -84,6 +112,16 @@ def _check_conv2d(context: PatternCheckContext) -> bool: ), _check_conv2d, ), + ( + "cudnn.attention.BS3NH", + *make_stacked_attention_pattern(start_op="split", layout="BS3NH"), + partial(_check_stacked_attention, layout="BS3NH"), + ), + ( + "cudnn.attention.SBN3H", + *make_stacked_attention_pattern(start_op="split", layout="SBN3H"), + partial(_check_stacked_attention, layout="SBN3H"), + ), ] ) @@ -105,4 +143,59 @@ def partition_for_cudnn(mod): """ patterns = get_patterns_with_prefix("cudnn") - return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod) + return tvm.transform.Sequential( + [ + transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True), + annotate_workspace, + transform.AllocateWorkspace(), + ] + )(mod) + + +def _shape_1d(shape): + return reduce(operator.mul, shape, 1) + + +@expr_functor.mutator +class WorkspaceAnnotator(PyExprMutator): + """Annotate a workspace requirement for each cuDNN-offloaded function.""" + + def __init__(self, mod): + super().__init__(mod) + + def visit_function_(self, f): + if "Composite" not in f.attrs: + body = super().visit_expr(f.body) + new_f = relax.Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span) + + if "global_symbol" in f.attrs and "cudnn" in f.attrs["global_symbol"]: + composite_func = body.blocks[0].bindings[0].value + if "WorkspaceSize" in composite_func.attrs: + return new_f.with_attr("WorkspaceSize", composite_func.attrs["WorkspaceSize"]) + + return new_f + + if "attention" in f.attrs["Composite"] and "cudnn" in f.attrs["Composite"]: + # Workspace is needed only for larger head sizes, but for simplicity we always allocate. + out_dtype = f.ret_struct_info.dtype + out_size_1d = _shape_1d(f.ret_struct_info.shape) + # This needs to be in sync with the actual value that the kernel expects. + workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype] + if not isinstance(workspace_size_bytes, (int, tvm.tir.expr.IntImm)): + # Tempororay workaround for dynamic shape workload. 
Will be removed when + # workspace for dynamic shape workload is implemented. + workspace_size_bytes = 8 + return f.with_attr("WorkspaceSize", workspace_size_bytes) + + return f + + +@tvm.transform.module_pass(opt_level=0) +def annotate_workspace(mod, _): + """Pass to annotate a workspace requirement for each cuDNN-offloaded function.""" + annotator = WorkspaceAnnotator(mod) + for name, f in mod.functions_items(): + if isinstance(f, relax.Function): + new_f = annotator.visit_expr(f) + mod.update_func(name, new_f) + return mod diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 0d9f4ff8e923..80979bbe7e25 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -383,19 +383,25 @@ def _check_stacked_attention(context: PatternCheckContext) -> bool: if not split_op.attrs.axis == 2: return False else: + get_const_int_list = lambda tup: [int(e.value) for e in tup] last_end = 0 for name in ["query", "key", "value"]: assert f"strided_slice_{name}" in context.annotated_expr strided_slice_op = context.annotated_expr[f"strided_slice_{name}"] - if list(strided_slice_op.attrs.axes) != [2]: + axes = get_const_int_list(strided_slice_op.args[1]) + begins = get_const_int_list(strided_slice_op.args[2]) + ends = get_const_int_list(strided_slice_op.args[3]) + strides = get_const_int_list(strided_slice_op.args[4]) + + if axes != [2]: return False - if list(strided_slice_op.attrs.begin) != [last_end]: + if begins != [last_end]: return False - if not len(strided_slice_op.attrs.end) == 1: + if not len(ends) == 1: return False - last_end = strided_slice_op.attrs.end[0] - if list(strided_slice_op.attrs.strides) != [1]: + if strides != [1]: return False + last_end = ends[0] return True @@ -537,7 +543,7 @@ def visit_function_(self, f): return new_f - if "attention" in f.attrs["Composite"]: + if "attention" in f.attrs["Composite"] and "cutlass" in f.attrs["Composite"]: # Workspace is needed only for larger head sizes, but for simplicity we always allocate. out_dtype = f.ret_struct_info.dtype out_size_1d = _shape_1d(f.ret_struct_info.shape) diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index 8ec43f1f27f6..1faef9cceb05 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -260,7 +260,7 @@ def make_attention_pattern(with_bias: bool = False, var_len: bool = False): return out, annotations -def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): +def make_stacked_attention_pattern(start_op: str, with_bias: bool = False, layout="BS3NH"): """ Create pattern for fused multi head attention with stacked input. @@ -272,6 +272,9 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): with_bias: bool Whether or not to include bias addition + layout: str + The layout of the stacked input tensor. 
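+        Currently "BS3NH" and "SBN3H" are supported.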
+ Returns ------- pattern: DFPattern @@ -290,17 +293,28 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): key_raw = is_tuple_get_item(qkv_tuple, 1) value_raw = is_tuple_get_item(qkv_tuple, 2) elif start_op == "strided_slice": - ops["strided_slice_query"] = query_raw = is_op("relax.strided_slice")(stacked_qkv) - ops["strided_slice_key"] = key_raw = is_op("relax.strided_slice")(stacked_qkv) - ops["strided_slice_value"] = value_raw = is_op("relax.strided_slice")(stacked_qkv) + ops["strided_slice_query"] = query_raw = is_op("relax.strided_slice")( + stacked_qkv, varg_default_wildcard=True + ) + ops["strided_slice_key"] = key_raw = is_op("relax.strided_slice")( + stacked_qkv, varg_default_wildcard=True + ) + ops["strided_slice_value"] = value_raw = is_op("relax.strided_slice")( + stacked_qkv, varg_default_wildcard=True + ) else: raise NotImplementedError() query_reshape_list = wildcard() key_reshape_list = wildcard() value_reshape_list = wildcard() - query = is_op("relax.reshape")(query_raw, query_reshape_list) - key = is_op("relax.reshape")(key_raw, key_reshape_list) - value = is_op("relax.reshape")(value_raw, value_reshape_list) + if layout == "BS3NH": + query = is_op("relax.reshape")(query_raw, query_reshape_list) + key = is_op("relax.reshape")(key_raw, key_reshape_list) + value = is_op("relax.reshape")(value_raw, value_reshape_list) + elif layout == "SBN3H": + ops["q_transpose"] = query = is_op("relax.permute_dims")(query_raw) + ops["k_transpose"] = key = is_op("relax.permute_dims")(key_raw) + ops["v_transpose"] = value = is_op("relax.permute_dims")(value_raw) annotations = { "stacked_qkv": stacked_qkv, "query_reshape_list": query_reshape_list, @@ -314,6 +328,10 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): out = is_op("relax.nn.attention_bias")(query, key, value, bias) else: out = is_op("relax.nn.attention")(query, key, value) + + if layout == "SBN3H": + out = is_op("relax.permute_dims")(out) + return out, annotations diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 725a930fd680..ec072f663cd5 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1568,11 +1568,14 @@ def scaled_dot_product_attention( Parameters ---------- query : Tensor - Tensor representing current attention lookup. + Tensor representing current attention lookup of shape + [batch, seq_len, num_heads, head_size]. key : Tensor - Tensor representing cross attention mapping. + Tensor representing cross attention mapping of shape + [batch, seq_len_kv, num_heads_kv, head_size]. value : Tensor - Tensor representing embedded attention values. + Tensor representing embedded attention values of shape + [batch, seq_len_kv, num_heads_kv, head_size_value]. attn_mask : Optional[Tensor] Optional mask for attention, not yet supported. 
is_causal : Optional[bool] diff --git a/python/tvm/relax/testing/__init__.py b/python/tvm/relax/testing/__init__.py index 4256ebc3be89..dc43d6c1f8ee 100644 --- a/python/tvm/relax/testing/__init__.py +++ b/python/tvm/relax/testing/__init__.py @@ -21,3 +21,4 @@ from .relay_translator import * from .ast_printer import dump_ast from .matmul import * +from .attention import * diff --git a/python/tvm/relax/testing/attention.py b/python/tvm/relax/testing/attention.py new file mode 100644 index 000000000000..a00674394ba2 --- /dev/null +++ b/python/tvm/relax/testing/attention.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Relax script for attention module.""" +import tvm +from tvm.script import relax as R, tir as T +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import relax as relax_builder + + +def get_relax_attention_module( + q_shape, + k_shape, + v_shape, + *, + dtype, + bias_shape=None, + qk_scale=None, + causal_mask=None, + window_size=None, +): # pylint: disable=too-many-arguments, too-many-locals, invalid-name + """Get a relax module for attention.""" + + if qk_scale is not None: + qk_scale = T.FloatImm("float32", qk_scale) + + if window_size is not None: + window_size = T.IntImm("int32", window_size) + + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + q = R.arg("q", R.Tensor(q_shape, dtype)) + k = R.arg("k", R.Tensor(k_shape, dtype)) + v = R.arg("v", R.Tensor(v_shape, dtype)) + bias = None + if bias_shape is not None and bias_shape != "none": + bias = R.arg("bias", R.Tensor(bias_shape, dtype)) + + with R.dataflow() as frame: + result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal_mask, window_size)) + R.output(result) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def get_relax_stacked_attention_module( + qkv, + b, + s, + n, + h, + h_v, + op, + bias=None, + qk_scale=None, + single_shape=False, + layout="BS3NH", +): # pylint: disable=too-many-arguments, too-many-locals, too-many-branches, invalid-name + # pylint: disable=too-many-statements + """Get a relax module for stacked attention.""" + dtype = str(qkv.dtype) + assert layout in ["BS3NH", "SBN3H"] + + if qk_scale is not None: + qk_scale = T.FloatImm("float32", qk_scale) + + if single_shape: + if layout == "BS3NH": + qk_shape = R.shape([b, s, n, h]) + elif layout == "SBN3H": + qk_shape = R.shape([b, s, n, h]) + v_shape = qk_shape + else: + if layout == "BS3NH": + qk_shape = [b, s, n, h] + v_shape = [b, s, n, h_v] + elif layout == "SBN3H": + qk_shape = [s, b, n, h] + v_shape = [s, b, n, h_v] + + if layout == "BS3NH": + split_axis = 2 + split_sections = [n * h, n * h * 2] + elif layout == "SBN3H": + split_axis = 3 + split_sections = [h, h * 2] + 
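+    # Layout notes: for "BS3NH" the stacked qkv tensor has shape [b, s, n*h + n*h + n*h_v]
+    # and is split along axis 2 at [n*h, n*h*2]; for "SBN3H" it has shape [s, b, n, h + h + h_v]
+    # and is split along axis 3 at [h, h*2]. The Q/K/V pieces are then reshaped (BS3NH) or
+    # transposed (SBN3H) into the [b, s, n, h] layout expected by R.nn.attention below.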
+ with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + qkv = R.arg("qkv", R.Tensor(qkv.shape, dtype)) + if bias is not None: + bias = R.arg("bias", R.Tensor(bias.shape, dtype)) + with R.dataflow() as frame: + if op == "split": + qkv_tuple = R.split(qkv, split_sections, axis=split_axis) + q = qkv_tuple[0] + k = qkv_tuple[1] + v = qkv_tuple[2] + elif op == "strided_slice": + q = R.strided_slice(qkv, [split_axis], [0], [split_sections[0]], [1]) + k = R.strided_slice( + qkv, [split_axis], [split_sections[0]], [split_sections[1]], [1] + ) + v = R.strided_slice( + qkv, + [split_axis], + [split_sections[1]], + [int(qkv.struct_info.shape[split_axis])], + [1], + ) + else: + raise NotImplementedError() + if layout == "BS3NH": + q = R.reshape(q, qk_shape) + k = R.reshape(k, qk_shape) + v = R.reshape(v, v_shape) + elif layout == "SBN3H": + q = R.permute_dims(q, [1, 0, 2, 3]) + k = R.permute_dims(k, [1, 0, 2, 3]) + v = R.permute_dims(v, [1, 0, 2, 3]) + result = R.emit(R.nn.attention(q, k, v, bias, qk_scale)) + if layout == "SBN3H": + result = R.emit(R.permute_dims(result, [1, 0, 2, 3])) + R.output(result) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 72a7cedc491c..1486e9986e0e 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -84,3 +84,4 @@ from .searchsorted import searchsorted_ref from .conv2d_backcward_weight_python import conv2d_backward_weight_python from .lstm_python import lstm_python +from .attention_python import attention_python diff --git a/python/tvm/topi/testing/attention_python.py b/python/tvm/topi/testing/attention_python.py new file mode 100644 index 000000000000..856667aeddd1 --- /dev/null +++ b/python/tvm/topi/testing/attention_python.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Attention operator in python""" +from typing import Optional +import numpy as np +from .softmax_python import softmax_python + + +def attention_python( + q: np.ndarray, + k: np.ndarray, + v: np.ndarray, + bias: Optional[np.ndarray], + qk_scale: float, + causal: str, + window_size: Optional[int] = None, + layout: str = "BSNH", +): # pylint: disable=too-many-arguments, too-many-locals, invalid-name + """Attention operator in python + + Parameters + ---------- + q : np.ndarray + Query tensor with shape [batch, seq_length, num_heads, head_dim] in the layout specified by + `layout`. + k : np.ndarray + Key tensor with shape [batch, seq_length_kv, num_kv_heads, head_dim] in the layout specified + by `layout`. 
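+        When the query has more heads than `num_kv_heads`, the key (and value) heads are
+        repeated to match, i.e. grouped-query attention.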
+ v : np.ndarray + Value tensor with shape [batch, seq_length_kv, num_kv_heads, head_dim_v] in the layout + specified by `layout`. + bias : np.ndarray + Bias tensor with shape [batch, num_heads, seq_length, seq_length] + qk_scale : float + Scale factor for the query-key product. + causal : str + The type of causal mask to apply. Can be "none", "TopLeft", or "BottomRight". + window_size : Optional[int] + The window size for the causal mask. + layout : str + The layout of the input tensors, e.g. "BSNH" or "BNSH". + + Returns + ------- + np.ndarray + The output tensor with shape [batch, seq_length, num_heads, head_dim_v] in the layout + specified by `layout`. + """ + assert layout in ["BSNH", "BNSH", "SBNH"] + + dim_b = layout.find("B") + dim_s = layout.find("S") + dim_n = layout.find("N") + dim_h = layout.find("H") + + q = q.transpose(dim_b, dim_n, dim_s, dim_h) # b, n, s, h + k = k.transpose(dim_b, dim_n, dim_s, dim_h) # b, n, s_kv, h + kt = k.transpose(0, 1, 3, 2) # b, n, h, s_kv + v = v.transpose(dim_b, dim_n, dim_s, dim_h) + + num_heads = q.shape[1] + num_kv_heads = k.shape[1] + s = q.shape[2] + s_kv = k.shape[2] + + if num_heads != num_kv_heads: + assert num_heads % num_kv_heads == 0 + factor = num_heads // num_kv_heads + kt = np.repeat(kt, factor, axis=1) + v = np.repeat(v, factor, axis=1) + + if not qk_scale == "none": + score = q @ kt * qk_scale # b, n, s, s_kv + else: + score = q @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv + if bias is not None: + score = score + bias # b, n, s, s_kv + if causal == "none": + attn = softmax_python(score, -1) + else: + if causal == "TopLeft": + offset = 0 + elif causal == "BottomRight": + offset = abs(s - s_kv) + else: + raise ValueError(f"Unsupported causal type: {causal}") + score_masked = np.tril(score, k=offset) + + if window_size: + score_masked = np.triu( + score_masked, -window_size + 1 # pylint: disable=invalid-unary-operand-type + ) + + score_masked_exp = np.tril( + np.exp(score_masked - np.max(score_masked, axis=-1, keepdims=True)), k=offset + ) + + if window_size: + score_masked_exp = np.triu( + score_masked_exp, -window_size + 1 # pylint: disable=invalid-unary-operand-type + ) + + score_masked_sum = np.sum(score_masked_exp, axis=-1, keepdims=True) + attn = np.divide(score_masked_exp, score_masked_sum) + + out = attn @ v # b, n, s, h_v + return out.transpose(*np.argsort([dim_b, dim_n, dim_s, dim_h]).tolist()) diff --git a/src/relax/backend/contrib/cudnn/codegen.cc b/src/relax/backend/contrib/cudnn/codegen.cc index 812016b8eafa..d8ca5f4e97f4 100644 --- a/src/relax/backend/contrib/cudnn/codegen.cc +++ b/src/relax/backend/contrib/cudnn/codegen.cc @@ -55,6 +55,17 @@ class cuDNNJSONSerializer : public JSONSerializer { std::string composite_name = composite_opt.value(); + if (composite_name.find("cudnn.conv2d") != std::string::npos) { + return HandleConv2D(call_node, fn, composite_name); + } else if (composite_name.find("cudnn.attention") != std::string::npos) { + return HandleAttention(call_node, fn, composite_name); + } else { + LOG(FATAL) << "Unsupported composite function: " << composite_name; + } + } + + NodeEntries HandleConv2D(const CallNode* call_node, const Function& fn, + const std::string& composite_name) { NodeEntries inputs_tmp; for (const auto& arg : call_node->args) { auto res = VisitExpr(arg); @@ -80,6 +91,42 @@ class cuDNNJSONSerializer : public JSONSerializer { return AddNode(node, GetRef(call_node)); } + NodeEntries HandleAttention(const CallNode* call_node, const Function& fn, + const std::string& composite_name) { + 
std::string layout = composite_name.substr(composite_name.find_last_of(".") + 1); + NodeEntries inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + ICHECK_EQ(inputs.size(), 2); + auto node = std::make_shared(composite_name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + const CallNode* root_call = backend::GetOpInFunction(fn, "relax.nn.attention"); + auto q_shape = Downcast( + Downcast(root_call->args[0]->struct_info_.value())->shape.value()); + auto k_shape = Downcast( + Downcast(root_call->args[1]->struct_info_.value())->shape.value()); + auto v_shape = Downcast( + Downcast(root_call->args[2]->struct_info_.value())->shape.value()); + int num_heads = q_shape->values[2].as()->value; + int num_kv_heads = k_shape->values[2].as()->value; + int head_size = q_shape->values[3].as()->value; + int head_size_v = v_shape->values[3].as()->value; + SetCallNodeAttribute(node, root_call); + + auto to_str_array = [](int val) { + return std::vector{std::vector{std::to_string(val)}}; + }; + node->SetAttr("num_heads", to_str_array(num_heads)); + node->SetAttr("num_kv_heads", to_str_array(num_kv_heads)); + node->SetAttr("head_size", to_str_array(head_size)); + node->SetAttr("head_size_v", to_str_array(head_size_v)); + node->SetAttr("layout", std::vector{std::vector{layout}}); + return AddNode(node, GetRef(call_node)); + } + private: /*! \brief The bindings to look up composite functions. */ Map bindings_; diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index 1d4a0177126a..05aa8ce5528d 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -66,8 +66,10 @@ class ExternFunctionRewriter : ExprMutator { } new_params.push_back(workspace_param); + auto new_attrs = func_node->attrs; + new_attrs.CopyOnWrite()->dict.erase(attr::kWorkspaceSize); return Function(new_params, VisitExpr(func_node->body), func_node->ret_struct_info, - func_node->is_pure, func_node->attrs); + func_node->is_pure, new_attrs); } return ExprMutator::VisitExpr_(func_node); } @@ -122,6 +124,7 @@ class WorkspaceProvider : ExprMutator { builder_->UpdateFunction(new_gvar, WithAttr(f, tvm::attr::kGlobalSymbol, new_gvar->name_hint)); gvar_map_[gvar] = new_gvar; + new_gvars_.insert(new_gvar); builder_->GetContextIRModule()->Remove(GetRef(gvar)); } @@ -164,8 +167,7 @@ class WorkspaceProvider : ExprMutator { auto new_op = VisitExpr(call_node->op); if (auto gv = new_op.as()) { - auto callee = builder_->GetContextIRModule()->Lookup(gv.value()); - if (callee->HasNonzeroAttr(attr::kWorkspaceSize)) { + if (new_gvars_.count(gv.value())) { auto new_args = call_node->args; ICHECK(workspace_var_main_.defined()); new_args.push_back(workspace_var_main_); @@ -185,6 +187,7 @@ class WorkspaceProvider : ExprMutator { * the new ones that are transformed to take an additional workspace parameter. This is only * needed since the struct info of the global variables changes between transformation. 
*/ std::unordered_map gvar_map_; + std::unordered_set new_gvars_; }; } // namespace relax diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 2be7ad41f3e1..6030a28d93b6 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -595,8 +596,7 @@ class FunctionCreator : public ExprMutator { } StructInfo param_sinfo = GetStructInfo(expr); - // Exclude PrimValues from arg/params to make composite functions contain PrimValues. - if (!expr->IsInstance()) { + if (!IsInlinableConstants(expr)) { Var param(std::move(name), GetStructInfo(expr)); arguments_.push_back(expr); params_.push_back(param); @@ -621,6 +621,21 @@ class FunctionCreator : public ExprMutator { return ExprMutator::VisitExpr(expr); } + // Check if the expression is constant PrimValue or ShapeExpr or tuple of them that can be + // inlined in the composite functions and excluded from args/params. + bool IsInlinableConstants(const Expr& expr) { + if (const auto* tuple = expr.as()) { + return std::all_of(tuple->fields.begin(), tuple->fields.end(), + [this](const Expr& e) { return IsInlinableConstants(e); }); + } else if (const auto* prim_value = expr.as()) { + return tvm::tir::UndefinedVars(prim_value->value).empty(); + } else if (const auto* shape_expr = expr.as()) { + return std::all_of(shape_expr->values.begin(), shape_expr->values.end(), + [this](const PrimExpr& e) { return tvm::tir::UndefinedVars(e).empty(); }); + } + return false; + } + private: /*! \brief The variables defined in this function */ std::unordered_set defined_vars_; diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc new file mode 100644 index 000000000000..f8b170fe2052 --- /dev/null +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/runtime/contrib/cudnn/cudnn_frontend/attention.cc + * \brief cuDNN scale dot product attention implementation + */ + +#include "./attention.h" + +#include +#include + +#include "../../../cuda/cuda_common.h" +#include "../cudnn_utils.h" + +namespace tvm { +namespace contrib { + +void CuDNNSDPARunnerNode::Init(int64_t batch, int64_t seq_len, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, int64_t head_size_v, + double scale, const DLDataType& data_type, + const std::string& layout) { + graph_ = std::make_unique(); + + CHECK(data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 16) + << "Only float16 is supported"; + + graph_->set_io_data_type(cudnn_frontend::DataType_t::HALF) + .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + + auto q_desc = cudnn_frontend::graph::Tensor_attributes().set_name("Q").set_uid(kTensorIDQ); + auto k_desc = cudnn_frontend::graph::Tensor_attributes().set_name("K").set_uid(kTensorIDK); + auto v_desc = cudnn_frontend::graph::Tensor_attributes().set_name("V").set_uid(kTensorIDV); + auto o_desc = cudnn_frontend::graph::Tensor_attributes().set_name("Out").set_uid(kTensorIDOut); + + std::vector q_stride, k_stride, v_stride, + o_stride; // stride in the order of (batch, num_heads, seq_len, head_size) + + if (layout == "BS3NH") { + int64_t stride_H = 1; + int64_t q_stride_N = head_size; + int64_t k_stride_N = head_size; + int64_t v_stride_N = head_size_v; + int64_t stride_S = + num_heads * q_stride_N + num_kv_heads * k_stride_N + num_kv_heads * v_stride_N; + int64_t stride_B = stride_S * seq_len; + q_stride = {stride_B, q_stride_N, stride_S, stride_H}; + k_stride = {stride_B, k_stride_N, stride_S, stride_H}; + v_stride = {stride_B, v_stride_N, stride_S, stride_H}; + o_stride = {seq_len * num_heads * head_size_v, head_size_v, num_heads * head_size_v, 1}; + offset_k_ = num_heads * head_size; + offset_v_ = offset_k_ + num_kv_heads * head_size; + } else if (layout == "SBN3H") { + CHECK_EQ(num_kv_heads, num_heads); + int64_t stride_H = 1; + int64_t stride_N = head_size + head_size + head_size_v; + int64_t stride_B = num_heads * stride_N; + int64_t stride_S = stride_B * batch; + q_stride = k_stride = v_stride = {stride_B, stride_N, stride_S, stride_H}; + o_stride = {num_heads * head_size_v, head_size_v, num_heads * head_size_v * batch, 1}; + offset_k_ = head_size; + offset_v_ = offset_k_ * 2; + } else { + LOG(FATAL) << "Unsupported layout: " << layout; + } + + q_desc = q_desc.set_dim({batch, num_heads, seq_len, head_size}).set_stride(q_stride); + k_desc = k_desc.set_dim({batch, num_kv_heads, seq_len, head_size}).set_stride(k_stride); + v_desc = v_desc.set_dim({batch, num_kv_heads, seq_len, head_size_v}).set_stride(v_stride); + auto sdpa_options = cudnn_frontend::graph::SDPA_attributes() + .set_name("flash_attention") + .set_is_inference(true) + .set_alibi_mask(false) + .set_causal_mask(false) + .set_attn_scale(scale); + + auto q = graph_->tensor(q_desc); + auto k = graph_->tensor(k_desc); + auto v = graph_->tensor(v_desc); + auto [o, stats] = graph_->sdpa(q, k, v, sdpa_options); + CHECK(stats == nullptr); + o->set_output(true).set_dim({batch, num_heads, seq_len, head_size_v}).set_stride(o_stride); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + CUDNN_FRONTEND_CALL(graph_->build(entry_ptr->handle, {cudnn_frontend::HeurMode_t::A})); +} + +void CuDNNSDPARunnerNode::Run(const DLTensor* qkv, DLTensor* workspace, DLTensor* out) { + CUDNN_CALL( + 
cudnnSetStream(CuDNNThreadEntry::ThreadLocal()->handle, tvm::runtime::GetCUDAStream())); + auto* qkv_base = reinterpret_cast(qkv->data) + qkv->byte_offset; + auto* q_ptr = reinterpret_cast(qkv_base) + offset_q_; + auto* k_ptr = reinterpret_cast(qkv_base) + offset_k_; + auto* v_ptr = reinterpret_cast(qkv_base) + offset_v_; + auto* out_ptr = reinterpret_cast(out->data) + out->byte_offset; + + size_t workspace_size = graph_->get_workspace_size(); + CHECK_LE(workspace_size, workspace->shape[0]) << "Workspace size too small"; + std::unordered_map inputs = { + {kTensorIDQ, q_ptr}, {kTensorIDK, k_ptr}, {kTensorIDV, v_ptr}, {kTensorIDOut, out_ptr}}; + + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + CUDNN_FRONTEND_CALL(graph_->execute(entry_ptr->handle, inputs, workspace->data)); +} + +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h new file mode 100644 index 000000000000..4d0309fb3ba6 --- /dev/null +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/runtime/contrib/cudnn/cudnn_frontend/attention.h + * \brief cuDNN scale dot product attention implementation + */ + +#ifndef TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_ +#define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_ + +#include +#include + +#include +#include + +#define CUDNN_FRONTEND_CALL(func) \ + do { \ + auto status = (func); \ + CHECK(status.is_good()) << status.get_message(); \ + } while (0) + +namespace tvm { +namespace contrib { + +class CuDNNSDPARunnerNode : public tvm::runtime::Object { + public: + CuDNNSDPARunnerNode() {} + + ~CuDNNSDPARunnerNode() {} + + static constexpr const char* _type_key = "contrib.cudnn.SDPARunner"; + + void Init(int64_t batch, int64_t seq_len, int64_t num_heads, int64_t num_kv_heads, + int64_t head_size, int64_t head_size_v, double scale, const DLDataType& data_type, + const std::string& layout); + + void Run(const DLTensor* qkv, DLTensor* workspace, DLTensor* out); + + static constexpr int kTensorIDQ = 0; + static constexpr int kTensorIDK = 1; + static constexpr int kTensorIDV = 2; + static constexpr int kTensorIDOut = 4; + + private: + std::unique_ptr graph_{nullptr}; + int64_t offset_q_{0}; + int64_t offset_k_{0}; + int64_t offset_v_{0}; +}; + +class CuDNNSDPARunner : public tvm::runtime::ObjectRef { + public: + static CuDNNSDPARunner Create() { + auto n = make_object(); + return CuDNNSDPARunner(n); + } + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CuDNNSDPARunner, tvm::runtime::ObjectRef, + CuDNNSDPARunnerNode); +}; + +} // namespace contrib +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_ diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index 7d701396d0ca..3f4b659275d4 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -31,6 +31,10 @@ #include "../json/json_node.h" #include "../json/json_runtime.h" + +#ifdef TVM_USE_CUDNN_FRONTEND +#include "./cudnn_frontend/attention.h" +#endif #include "cudnn_utils.h" namespace tvm { @@ -47,78 +51,19 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { : JSONRuntimeBase(symbol_name, graph_json, const_names) {} void Init(const Array& consts) override { - auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal(); - auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); - ICHECK(func != nullptr); - stream = static_cast((*func)().operator void*()); - - auto attr_in_name = [](const std::string& op_name, const std::string& attr_name) { - return op_name.find(attr_name) != std::string::npos; - }; - - auto vstr2vint = [](const JSONGraphNode& node, const std::string& attrStr) { - auto string_to_int = [](const std::string& str) { return std::stoi(str); }; - auto string_vec = node.GetAttr>(attrStr); - std::vector int_vec(string_vec.size()); - std::transform(string_vec.begin(), string_vec.end(), int_vec.begin(), string_to_int); - return int_vec; - }; + op_execs_.resize(nodes_.size()); // get some config from the graph for (size_t i = 0; i < nodes_.size(); ++i) { const auto& node = nodes_[i]; if (node.GetOpType() == "kernel") { - op_name = node.GetOpName(); - std::vector input_dims, kernel_dims, output_dims; - auto input_node = nodes_[0]; - auto input_shapes = input_node.GetOpShape()[0]; - auto kernel_node = nodes_[1]; - auto kernel_shapes = kernel_node.GetOpShape()[0]; - auto output_shapes = node.GetOpShape()[0]; - for (const auto& _i : input_shapes) { - input_dims.emplace_back(static_cast(_i)); - } - for 
(const auto& _i : kernel_shapes) { - kernel_dims.emplace_back(static_cast(_i)); + std::string op_name = node.GetOpName(); + if (op_name.find("conv2d") != std::string::npos) { + op_execs_[i] = GetConv2DExec(node); + } else if (op_name.find("attention") != std::string::npos) { + op_execs_[i] = GetAttentionExec(node); + } else { + LOG(FATAL) << "Unsupported op: " << op_name; } - for (const auto& _i : output_shapes) { - output_dims.emplace_back(static_cast(_i)); - } - has_bias = attr_in_name(op_name, "bias"); - groups = std::stoi(node.GetAttr>("groups")[0]); - padding = vstr2vint(node, "padding"); - strides = vstr2vint(node, "strides"); - dilation = vstr2vint(node, "dilation"); - conv_dtype = node.GetAttr>("out_dtype")[0]; - std::string layout = node.GetAttr>("out_layout")[0]; - dims = layout.size() - 2; // remove O and I dims - - if (layout == "NCHW") - format = CUDNN_TENSOR_NCHW; - else if (layout == "NHWC") - format = CUDNN_TENSOR_NHWC; - else - LOG(FATAL) << "Unsupported layout: " << layout; - - if (attr_in_name(op_name, "relu")) { - act = CUDNN_ACTIVATION_RELU; - } else if (attr_in_name(op_name, "relu6")) { - act = CUDNN_ACTIVATION_CLIPPED_RELU; - coef = 6.0; - } else if (attr_in_name(op_name, "leaky_relu")) { - act = CUDNN_ACTIVATION_RELU; - coef = 0.1; - } - this->handle = entry_ptr->handle; - this->kernel_node = node; - - // find best algo - TVMRetValue best_algo; - - tvm::contrib::FindAlgo(format, dims, groups, padding.data(), strides.data(), - dilation.data(), input_dims.data(), kernel_dims.data(), - output_dims.data(), conv_dtype, conv_dtype, false, &best_algo); - - this->algo = best_algo.operator int(); } } } @@ -126,27 +71,10 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { const char* type_key() const override { return "cudnn_json"; } // May be overridden void Run() override { - auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) { - const DLTensor* bias = nullptr; - if (has_bias) { - bias = GetInput(node, 2); + for (const auto& f : op_execs_) { + if (f != nullptr) { + f(); } - return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias); - }; - - auto [a_ptr, b_ptr, bias_ptr] = get_inputs(kernel_node, has_bias); - uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - - if (this->has_bias) { - tvm::contrib::ConvolutionBiasActivationForward( - this->mode, this->format, this->algo, this->dims, this->groups, this->act, this->coef, - this->padding.data(), this->strides.data(), this->dilation.data(), a_ptr, b_ptr, out_ptr, - bias_ptr, this->conv_dtype); - } else { - tvm::contrib::ConvolutionForward( - this->mode, this->format, this->algo, this->dims, this->groups, this->padding.data(), - this->strides.data(), this->dilation.data(), a_ptr, b_ptr, out_ptr, this->conv_dtype); } } @@ -157,27 +85,150 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { ICHECK(eid < data_entry_.size()); return data_entry_[eid]; } - /*conv op name*/ - std::string op_name; - /*conv mode: CUDNN_CROSS_CORRELATION by default*/ - int mode = CUDNN_CROSS_CORRELATION; - /*algo: by default we select the implicit gemm algo, will be tuned in the initial pass.*/ - int algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; - /*if has bias*/ - bool has_bias = false; - /*args for function call*/ - int act = CUDNN_ACTIVATION_IDENTITY; - double coef = 1.0; - int format = CUDNN_TENSOR_NHWC; - int dims = 2; - int groups = 1; - std::vector padding; - std::vector strides; - std::vector dilation; - std::string conv_dtype; - cudaStream_t stream; - cudnnHandle_t handle; - 
tvm::runtime::json::JSONGraphNode kernel_node; + + bool attr_in_name(const std::string& op_name, const std::string& attr_name) { + return op_name.find(attr_name) != std::string::npos; + } + + std::vector vstr2vint(const JSONGraphNode& node, const std::string& attrStr) { + auto string_to_int = [](const std::string& str) { return std::stoi(str); }; + auto string_vec = node.GetAttr>(attrStr); + std::vector int_vec(string_vec.size()); + std::transform(string_vec.begin(), string_vec.end(), int_vec.begin(), string_to_int); + return int_vec; + } + + std::function GetConv2DExec(const JSONGraphNode& node) { + auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal(); + auto op_name = node.GetOpName(); + + std::vector input_dims, kernel_dims, output_dims; + auto input_node = nodes_[0]; + auto input_shapes = input_node.GetOpShape()[0]; + auto kernel_shapes = nodes_[1].GetOpShape()[0]; + auto output_shapes = node.GetOpShape()[0]; + for (const auto& _i : input_shapes) { + input_dims.emplace_back(static_cast(_i)); + } + for (const auto& _i : kernel_shapes) { + kernel_dims.emplace_back(static_cast(_i)); + } + for (const auto& _i : output_shapes) { + output_dims.emplace_back(static_cast(_i)); + } + bool has_bias = attr_in_name(op_name, "bias"); + int groups = std::stoi(node.GetAttr>("groups")[0]); + std::vector padding = vstr2vint(node, "padding"); + std::vector strides = vstr2vint(node, "strides"); + std::vector dilation = vstr2vint(node, "dilation"); + auto conv_dtype = node.GetAttr>("out_dtype")[0]; + std::string layout = node.GetAttr>("out_layout")[0]; + int dims = layout.size() - 2; // remove O and I dims + + int format = CUDNN_TENSOR_NHWC; + if (layout == "NCHW") { + format = CUDNN_TENSOR_NCHW; + } else if (layout == "NHWC") { + format = CUDNN_TENSOR_NHWC; + } else { + LOG(FATAL) << "Unsupported layout: " << layout; + } + + int act = CUDNN_ACTIVATION_IDENTITY; + double coef = 1.0; + if (attr_in_name(op_name, "relu")) { + act = CUDNN_ACTIVATION_RELU; + } else if (attr_in_name(op_name, "relu6")) { + act = CUDNN_ACTIVATION_CLIPPED_RELU; + coef = 6.0; + } else if (attr_in_name(op_name, "leaky_relu")) { + act = CUDNN_ACTIVATION_RELU; + coef = 0.1; + } + + /*conv mode: CUDNN_CROSS_CORRELATION by default*/ + int mode = CUDNN_CROSS_CORRELATION; + + // find best algo + TVMRetValue best_algo; + + tvm::contrib::FindAlgo(format, dims, groups, padding.data(), strides.data(), dilation.data(), + input_dims.data(), kernel_dims.data(), output_dims.data(), conv_dtype, + conv_dtype, false, &best_algo); + + int algo = best_algo.operator int(); + std::function op_exec = [=]() { + auto stream = static_cast(GetCUDAStream()); + CUDNN_CALL(cudnnSetStream(entry_ptr->handle, stream)); + + auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) { + const DLTensor* bias = nullptr; + if (has_bias) { + bias = GetInput(node, 2); + } + return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias); + }; + + auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, has_bias); + uint32_t output_eid = EntryID(outputs_[0]); + auto out_ptr = data_entry_[output_eid]; + if (has_bias) { + tvm::contrib::ConvolutionBiasActivationForward( + mode, format, algo, dims, groups, act, coef, padding.data(), strides.data(), + dilation.data(), a_ptr, b_ptr, out_ptr, bias_ptr, conv_dtype); + } else { + tvm::contrib::ConvolutionForward(mode, format, algo, dims, groups, padding.data(), + strides.data(), dilation.data(), a_ptr, b_ptr, out_ptr, + conv_dtype); + } + }; + return op_exec; + } + + std::function GetAttentionExec(const JSONGraphNode& 
node) { +#ifdef TVM_USE_CUDNN_FRONTEND + auto dtype = node.GetOpDataType()[0]; + int num_heads = vstr2vint(node, "num_heads")[0]; + int num_kv_heads = vstr2vint(node, "num_kv_heads")[0]; + int head_size = vstr2vint(node, "head_size")[0]; + int head_size_v = vstr2vint(node, "head_size_v")[0]; + std::string layout = node.GetAttr>("layout")[0]; + const auto& input_qkv_node = nodes_[EntryID(node.GetInputs()[0])]; + auto qkv_shapes = input_qkv_node.GetOpShape()[0]; + + int64_t batch, seq_len; + if (layout == "BS3NH") { + ICHECK_EQ(qkv_shapes.size(), 3); + batch = qkv_shapes[0]; + seq_len = qkv_shapes[1]; + } else if (layout == "SBN3H") { + ICHECK_EQ(qkv_shapes.size(), 4); + batch = qkv_shapes[1]; + seq_len = qkv_shapes[0]; + } else { + LOG(FATAL) << "Unsupported layout: " << layout; + } + double scale = 1 / std::sqrt(head_size); + std::string scale_attr = node.GetAttr>("scale")[0]; + if (scale_attr.size()) { + scale = std::stod(scale_attr); + } + + auto runner = tvm::contrib::CuDNNSDPARunner::Create(); + runner->Init(batch, seq_len, num_heads, num_kv_heads, head_size, head_size_v, scale, dtype, + layout); + return [=]() { + auto qkv = GetInput(node, 0); + auto workspace = const_cast(GetInput(node, 1)); + auto out = const_cast(data_entry_[EntryID(outputs_[0])]); + runner->Run(qkv, workspace, out); + }; +#else + LOG(FATAL) << "Please build with CUDNN frontend to use attention op"; +#endif + } + + std::vector> op_execs_; }; runtime::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json, diff --git a/tests/python/relax/test_codegen_cudnn.py b/tests/python/relax/test_codegen_cudnn.py index 0f911905f820..59f49bfde889 100644 --- a/tests/python/relax/test_codegen_cudnn.py +++ b/tests/python/relax/test_codegen_cudnn.py @@ -22,7 +22,8 @@ import tvm.topi.testing from tvm import relax from tvm.relax.backend.contrib.cudnn import partition_for_cudnn -from tvm.relax.testing import get_relax_matmul_module +from tvm.relax.testing import get_relax_matmul_module, get_relax_stacked_attention_module +from tvm.contrib.pickle_memoize import memoize from tvm.script import relax as R from tvm.script.ir_builder import IRBuilder @@ -99,7 +100,7 @@ def get_relax_conv2d_module( def get_result_with_relax_cudnn_offload(mod, np_inputs, cuda_graph=False): mod = partition_for_cudnn(mod) mod = relax.transform.RunCodegen()(mod) - return build_and_run(mod, np_inputs, "cuda", cuda_graph) + return build_and_run(mod, np_inputs, "cuda", cuda_graph=cuda_graph) def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False): @@ -244,5 +245,65 @@ def test_conv2d_nchw_oihw_offload(data_shape, weight_shape, dtype, with_bias, ac tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) +@memoize("topi.tests.test_codegen_cudnn.test_stacked_attention_offload") +def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, qk_scale, dtype, layout): + if layout == "BS3NH": + qkv = np.random.randn(b, s, n * h * 2 + n * h_v).astype(dtype) + split_qkv = np.split(qkv, [n * h, n * h * 2], axis=2) + q = split_qkv[0].reshape(b, s, n, h) + k = split_qkv[1].reshape(b, s, n, h) + v = split_qkv[2].reshape(b, s, n, h_v) + layout = "BSNH" + elif layout == "SBN3H": + qkv = np.random.randn(s, b, n, h * 2 + h_v).astype(dtype) + q, k, v = np.split(qkv, [h, h * 2], axis=3) + layout = "SBNH" + else: + raise ValueError("Unsupported layout: {}".format(layout)) + if not bias_shape == "none": + bias = np.random.randn(*bias_shape).astype(dtype) + score = score + bias # b, n, s, s + else: + bias = None + ref = 
tvm.topi.testing.attention_python(q, k, v, bias, qk_scale, "none", None, layout) + return qkv, bias, ref + + +@pytest.fixture( + params=[ + # B, S, N, H, bias_shape scale, single_shape, layout + (4, 8, 32, (64, 32), "none", 1.0, False, "BS3NH"), + (4, 8, 32, (64, 64), "none", "none", True, "BS3NH"), + (4, 8, 32, (64, 32), "none", 1.0, False, "SBN3H"), + (4, 8, 32, (64, 64), "none", "none", True, "SBN3H"), + ] +) +def stacked_attention_size(request): + return request.param + + +@pytest.mark.skip(reason="require cudnn frontend") +def test_stacked_attention_split_offload(stacked_attention_size): + b, s, n, (h, h_v), bias_shape, scale, single_shape, layout = stacked_attention_size + qkv, bias, ref = get_numpy_stacked_attention_ref( + b, s, n, h, h_v, bias_shape, scale, "float16", layout + ) + if scale == "none": + mod = get_relax_stacked_attention_module( + qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape, layout=layout + ) + scale = 1.0 / np.sqrt(h) + else: + mod = get_relax_stacked_attention_module( + qkv, b, s, n, h, h_v, "split", bias, scale, single_shape=single_shape, layout=layout + ) + + if bias is None: + out = get_result_with_relax_cudnn_offload(mod, [qkv]) + else: + out = get_result_with_relax_cudnn_offload(mod, [qkv, bias]) + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=2e-2) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 969651f72fd4..3fa3f2d914d7 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -24,7 +24,11 @@ from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul from tvm.contrib.pickle_memoize import memoize from tvm.relax.backend.contrib.cutlass import partition_for_cutlass -from tvm.relax.testing import get_relax_matmul_module +from tvm.relax.testing import ( + get_relax_matmul_module, + get_relax_attention_module, + get_relax_stacked_attention_module, +) from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T @@ -594,47 +598,6 @@ def attention_size(request): return request.param -def get_relax_attention_module( - q_shape, - k_shape, - v_shape, - *, - dtype, - bias_shape=None, - qk_scale=None, - causal_mask=None, - window_size=None, -): - from tvm.script.ir_builder import IRBuilder - from tvm.script.ir_builder import relax as relax_builder - from tvm.script.ir_builder import tir as T - - if qk_scale is not None: - qk_scale = T.FloatImm("float32", qk_scale) - - if window_size is not None: - window_size = T.IntImm("int32", window_size) - - with IRBuilder() as builder: - with relax_builder.function(): - R.func_name("main") - q = R.arg("q", R.Tensor(q_shape, dtype)) - k = R.arg("k", R.Tensor(k_shape, dtype)) - v = R.arg("v", R.Tensor(v_shape, dtype)) - bias = None - if bias_shape is not None and bias_shape != "none": - bias = R.arg("bias", R.Tensor(bias_shape, dtype)) - - with R.dataflow() as frame: - result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal_mask, window_size)) - R.output(result) - - R.func_ret_value(frame.output_vars[0]) - - func = builder.get() - return tvm.IRModule({"main": func}) - - def get_numpy_attention_ref( b, s, @@ -649,59 +612,20 @@ def get_numpy_attention_ref( window_size=None, num_kv_head=None, ): - if num_kv_head is None: - num_kv_head = n - + num_kv_head = num_kv_head or n q = np.random.randn(b, s, n, h).astype(dtype) - k_orig = np.random.randn(b, s_kv, num_kv_head, h).astype(dtype) - v_orig = 
np.random.randn(b, s_kv, num_kv_head, h_v).astype(dtype) - - if num_kv_head is None: - k = k_orig - v = v_orig - else: - factor = n // num_kv_head - k = np.repeat(k_orig, factor, axis=2) - v = np.repeat(v_orig, factor, axis=2) - - qt = q.transpose(0, 2, 1, 3) # b, n, s, h - kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv - if not qk_scale == "none": - score = qt @ kt * qk_scale # b, n, s, s_kv - else: - score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv - if not bias_shape == "none": - bias = np.random.randn(*bias_shape).astype(dtype) - score = score + bias # b, n, s, s_kv - else: + k = np.random.randn(b, s_kv, num_kv_head, h).astype(dtype) + v = np.random.randn(b, s_kv, num_kv_head, h_v).astype(dtype) + if bias_shape == "none": bias = None - if causal == "none": - attn = tvm.topi.testing.softmax_python(score, -1) else: - if causal == "TopLeft": - offset = 0 - elif causal == "BottomRight": - offset = abs(s - s_kv) - else: - raise NotImplementedError() - score_masked = np.tril(score, k=offset) - - if window_size: - score_masked = np.triu(score_masked, -window_size + 1) - - score_masked_exp = np.tril( - np.exp(score_masked - np.max(score_masked, axis=-1, keepdims=True)), k=offset - ) - - if window_size: - score_masked_exp = np.triu(score_masked_exp, -window_size + 1) + bias = np.random.randn(*bias_shape).astype(dtype) - score_masked_sum = np.sum(score_masked_exp, axis=-1, keepdims=True) - attn = np.divide(score_masked_exp, score_masked_sum) + ref = tvm.topi.testing.attention_python( + q, k, v, bias, qk_scale, causal=causal, window_size=window_size, layout="BSNH" + ) - vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v - ref = attn @ vt # b, n, s, h_v - return q, k_orig, v_orig, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v + return q, k, v, bias, ref def test_attention_offload(attention_size, attention_dtype): @@ -844,69 +768,14 @@ def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, qk_scale, dtype q = np.reshape(split_qkv[0], (b, s, n, h)) k = np.reshape(split_qkv[1], (b, s, n, h)) v = np.reshape(split_qkv[2], (b, s, n, h_v)) - qt = q.transpose(0, 2, 1, 3) # b, n, s, h - kt = k.transpose(0, 2, 3, 1) # b, n, h, s - if not qk_scale == "none": - score = qt @ kt * qk_scale # b, n, s, s - else: - score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s if not bias_shape == "none": bias = np.random.randn(*bias_shape).astype(dtype) - score = score + bias # b, n, s, s else: bias = None - attn = tvm.topi.testing.softmax_python(score, -1) - vt = v.transpose(0, 2, 1, 3) # b, n, s, h_v - ref = attn @ vt # b, n, s, h_v - return qkv, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v - - -def get_relax_stacked_attention_module( - qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None, single_shape=False -): - dtype = str(qkv.dtype) - - from tvm.script.ir_builder import IRBuilder - from tvm.script.ir_builder import relax as relax_builder - from tvm.script.ir_builder import tir as T - - if qk_scale is not None: - qk_scale = T.FloatImm("float32", qk_scale) - - if single_shape: - qk_shape = R.shape([b, s, n, h]) - v_shape = qk_shape - else: - qk_shape = [b, s, n, h] - v_shape = [b, s, n, h_v] - - with IRBuilder() as builder: - with relax_builder.function(): - R.func_name("main") - qkv = R.arg("qkv", R.Tensor(qkv.shape, dtype)) - if bias is not None: - bias = R.arg("bias", R.Tensor(bias.shape, dtype)) - with R.dataflow() as frame: - if op == "split": - qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2) - q = R.reshape(qkv_tuple[0], qk_shape) - k = R.reshape(qkv_tuple[1], qk_shape) - v = 
R.reshape(qkv_tuple[2], v_shape) - elif op == "strided_slice": - q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), qk_shape) - k = R.reshape(R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), qk_shape) - v = R.reshape( - R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n * h_v], [1]), v_shape - ) - else: - raise NotImplementedError() - result = R.emit(R.nn.attention(q, k, v, bias, qk_scale)) - R.output(result) - - R.func_ret_value(frame.output_vars[0]) - - func = builder.get() - return tvm.IRModule({"main": func}) + ref = tvm.topi.testing.attention_python( + q, k, v, bias, qk_scale, causal="none", window_size=None, layout="BSNH" + ) + return qkv, bias, ref @pytest.fixture( @@ -926,11 +795,30 @@ def test_stacked_attention_split_offload(stacked_attention_size): qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float16") if scale == "none": mod = get_relax_stacked_attention_module( - qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape + qkv, + b, + s, + n, + h, + h_v, + "split", + bias, + single_shape=single_shape, + layout="BS3NH", ) else: mod = get_relax_stacked_attention_module( - qkv, b, s, n, h, h_v, "split", bias, scale, single_shape=single_shape + qkv, + b, + s, + n, + h, + h_v, + "split", + bias, + scale, + single_shape=single_shape, + layout="BS3NH", ) if bias is None: @@ -945,11 +833,30 @@ def test_stacked_attention_strided_slice_offload(stacked_attention_size): qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float32") if scale == "none": mod = get_relax_stacked_attention_module( - qkv, b, s, n, h, h_v, "strided_slice", bias, single_shape=single_shape + qkv, + b, + s, + n, + h, + h_v, + "strided_slice", + bias, + single_shape=single_shape, + layout="BS3NH", ) else: mod = get_relax_stacked_attention_module( - qkv, b, s, n, h, h_v, "strided_slice", bias, scale, single_shape=single_shape + qkv, + b, + s, + n, + h, + h_v, + "strided_slice", + bias, + scale, + single_shape=single_shape, + layout="BS3NH", ) if bias is None: out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=2) diff --git a/tests/python/relax/test_transform_allocate_workspace.py b/tests/python/relax/test_transform_allocate_workspace.py index 1198642d3f35..248d195d654b 100644 --- a/tests/python/relax/test_transform_allocate_workspace.py +++ b/tests/python/relax/test_transform_allocate_workspace.py @@ -95,7 +95,6 @@ def fused_relax_nn_attention_cutlass1( R.func_attr( { "Codegen": "cutlass", - "WorkspaceSize": 65536, "global_symbol": "fused_relax_nn_attention_cutlass1", } ) @@ -107,7 +106,7 @@ def gv( v_1: R.Tensor((32, 8, 16, 8), dtype="float16"), workspace_1: R.Tensor((65536,), dtype="uint8"), ) -> R.Tensor((32, 8, 16, 8), dtype="float16"): - R.func_attr({"Composite": "cutlass.attention", "Primitive": 1, "WorkspaceSize": 65536}) + R.func_attr({"Composite": "cutlass.attention", "Primitive": 1}) with R.dataflow(): gv_2: R.Tensor((32, 8, 16, 8), dtype="float16") = R.nn.attention( q_1, k_1, v_1, scale=None diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py index 6a36314a7444..cff832a21ff9 100644 --- a/tests/python/relax/test_transform_merge_composite_functions.py +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -1053,7 +1053,6 @@ class Expected: @R.function def fused_relax_reshape_relax_matmul_tensorrt( inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"), - param_0: R.Shape([1, 784]), lv1: 
R.Tensor((784, 512), dtype="float32"),
         ) -> R.Tensor((1, 512), dtype="float32"):
             R.func_attr({"Codegen": "tensorrt"})
@@ -1069,7 +1068,7 @@ def lv_1(
                     R.output(gv)
                     return gv
-            lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0, param_0)
+            lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0, R.shape([1, 784]))
@@ -1100,7 +1099,7 @@ def main(
                 )
                 gv: R.Tensor(
                     (1, 512), dtype="float32"
-                ) = cls.fused_relax_reshape_relax_matmul_tensorrt(inp_0, R.shape([1, 784]), lv1)
+                ) = cls.fused_relax_reshape_relax_matmul_tensorrt(inp_0, lv1)
                 R.output(gv)
                 return gv
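Below is a minimal usage sketch (not part of the patch itself) of how the cuDNN attention offload introduced above is exercised end to end, mirroring the flow of tests/python/relax/test_codegen_cudnn.py. The helper names come from this patch; the concrete shapes, the "cuda" target, and the surrounding driver code are illustrative assumptions only, and a TVM build with USE_CUDNN_FRONTEND enabled plus a supported GPU is assumed.

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.relax.backend.contrib.cudnn import partition_for_cudnn
    from tvm.relax.testing import get_relax_stacked_attention_module

    # Illustrative sizes: batch=4, seq_len=8, num_heads=32, head_size=64, head_size_v=32.
    b, s, n, h, h_v = 4, 8, 32, 64, 32
    qkv = np.random.randn(b, s, n * h * 2 + n * h_v).astype("float16")

    # Build a stacked-attention module in the BS3NH layout added by this patch.
    mod = get_relax_stacked_attention_module(
        qkv, b, s, n, h, h_v, "split", bias=None, qk_scale=None, layout="BS3NH"
    )

    # Partition the attention pattern onto the cuDNN BYOC backend (this also annotates and
    # allocates the workspace) and generate the cuDNN JSON runtime module.
    mod = partition_for_cudnn(mod)
    mod = relax.transform.RunCodegen()(mod)

    # Build for CUDA and run through the Relax VM.
    ex = relax.build(mod, target="cuda")
    vm = relax.VirtualMachine(ex, tvm.cuda())
    out = vm["main"](tvm.nd.array(qkv, tvm.cuda()))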