diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py
index 098619cc7a821..57d67b24badf4 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -90,6 +90,15 @@ def check_batch_matmul(call):
     return check_dtype(lhs, rhs) and not transpose_a and transpose_b
 
 
+def is_depthwise_conv2d(ic, oc, groups):
+    """Return True if a conv2d is depthwise.
+
+    That is the case when the input channel count ``ic``, the output channel
+    count ``oc`` and the number of ``groups`` are all equal.
+    """
+    return ic == oc == groups
+
+
 def check_conv2d(call):
     """Check if the given conv2d workload can be offloaded to CUTLASS."""
     conv2d = get_root_call(call, "nn.conv2d")
@@ -97,7 +106,15 @@ def check_conv2d(call):
     kernel_layout = conv2d.attrs.kernel_layout
     data = conv2d.args[0].checked_type
     weight = conv2d.args[1].checked_type
-    return data_layout == "NHWC" and kernel_layout == "OHWI" and check_dtype(data, weight)
+    if data_layout != "NHWC" or kernel_layout != "OHWI" or not check_dtype(data, weight):
+        return False
+    # Channels are the last axis of NHWC data and the first axis of OHWI weights.
+    in_channels = data.shape[3]
+    out_channels = weight.shape[0]
+    # Depthwise convolutions are rejected. Read ``groups`` from the conv2d found
+    # by get_root_call rather than from ``call``: ``call`` is not necessarily the
+    # conv2d itself, in which case its attrs are not the conv2d attributes.
+    return not is_depthwise_conv2d(in_channels, out_channels, conv2d.attrs.groups)
 
 
 def partition_for_cutlass(mod):