Skip to content

Commit

Permalink
do not offload depthwise conv2d
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 13, 2021
1 parent 5ae0fc8 commit 8e025c4
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion python/tvm/relay/op/contrib/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,22 @@ def check_batch_matmul(call):
return check_dtype(lhs, rhs) and not transpose_a and transpose_b


def is_depthwise_conv2d(ic, oc, groups):
return ic == oc == groups


def check_conv2d(call):
"""Check if the given conv2d workload can be offloaded to CUTLASS."""
conv2d = get_root_call(call, "nn.conv2d")
data_layout = conv2d.attrs.data_layout
kernel_layout = conv2d.attrs.kernel_layout
data = conv2d.args[0].checked_type
weight = conv2d.args[1].checked_type
return data_layout == "NHWC" and kernel_layout == "OHWI" and check_dtype(data, weight)
if data_layout != "NHWC" or kernel_layout != "OHWI" or not check_dtype(data, weight):
return False
IC = data.shape[3]
OC = weight.shape[0]
return not is_depthwise_conv2d(IC, OC, call.attrs.groups)


def partition_for_cutlass(mod):
Expand Down

0 comments on commit 8e025c4

Please sign in to comment.