diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 80c8ebffeb0f9..bf17572f25009 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Patterns supported CUTLASS.""" +from tvm.ir.transform import Sequential from tvm.relay import transform from ...dataflow_pattern import wildcard, is_op, is_constant @@ -61,29 +62,41 @@ def make_conv2d_pattern(): def check_dtype(lhs, rhs): """Check if dtypes in the given workload are supported by CUTLASS.""" + # Only fp16 inputs are supported for now. return lhs.dtype == rhs.dtype and lhs.dtype == "float16" and rhs.dtype == "float16" +def get_root_call(call, root_op_name): + if str(call.op) == root_op_name: + return call + return get_root_call(call.args[0], root_op_name) + + def check_gemm(call): """Check if the given dense workload can be offloaded to CUTLASS.""" - lhs = call.args[0].checked_type - rhs = call.args[1].checked_type + dense = get_root_call(call, "nn.dense") + lhs = dense.args[0].checked_type + rhs = dense.args[1].checked_type return check_dtype(lhs, rhs) def check_batch_matmul(call): """Check if the given batch_matmul workload can be offloaded to CUTLASS.""" - transpose_a = call.attrs.transpose_a - transpose_b = call.attrs.transpose_b - return check_gemm(call) and transpose_a == False and transpose_b == True + batch_matmul = get_root_call(call, "nn.batch_matmul") + lhs = batch_matmul.args[0].checked_type + rhs = batch_matmul.args[1].checked_type + transpose_a = batch_matmul.attrs.transpose_a + transpose_b = batch_matmul.attrs.transpose_b + return check_dtype(lhs, rhs) and transpose_a == False and transpose_b == True def check_conv2d(call): """Check if the given conv2d workload can be offloaded to CUTLASS.""" - data_layout = call.attrs.data_layout - kernel_layout = call.attrs.kernel_layout - data = call.args[0].checked_type - weight = call.args[1].checked_type + conv2d = get_root_call(call, "nn.conv2d") + data_layout = conv2d.attrs.data_layout + kernel_layout = conv2d.attrs.kernel_layout + data = conv2d.args[0].checked_type + weight = conv2d.args[1].checked_type return data_layout == "NHWC" and kernel_layout == "OHWI" and check_dtype(data, weight) @@ -112,7 +125,12 @@ def partition_for_cutlass(mod): # TODO(masahi): Add more conv2d patterns ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d), ] - mod = transform.MergeComposite(cutlass_patterns)(mod) - mod = transform.AnnotateTarget(["cutlass"])(mod) - mod = transform.PartitionGraph()(mod) - return mod + seq = Sequential( + [ + transform.InferType(), + transform.MergeComposite(cutlass_patterns), + transform.AnnotateTarget(["cutlass"]), + transform.PartitionGraph(), + ] + ) + return seq(mod)