Skip to content

Commit

Permalink
add fused conv2d pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 12, 2021
1 parent 1c0bbb2 commit 81bf9e6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
4 changes: 4 additions & 0 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ def handle_conv2d(

if op_type == "cutlass.conv2d":
cutlass_op_def = out["opdef"]
elif op_type == "cutlass.conv2d_bias":
cutlass_op_def = out["opdef_bias"]
elif op_type == "cutlass.conv2d_bias_relu":
cutlass_op_def = out["opdef_bias_relu"]
else:
raise ValueError("%s pattern is not implemented." % op_type)

Expand Down
25 changes: 22 additions & 3 deletions python/tvm/relay/op/contrib/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,22 @@ def make_batch_matmul_pattern():
return is_op("nn.batch_matmul")(wildcard(), wildcard())


def make_conv2d_pattern():
return is_op("nn.conv2d")(wildcard(), wildcard())
def make_conv2d_pattern(with_bias=True, with_act=None, out_dtype="float16"):
"""Create a pattern for dense op followed by activations."""
data = wildcard()
weight = wildcard()
bias = wildcard()
conv2d = is_op("nn.conv2d")(data, weight)
if with_bias:
add_or_bias_add = is_op("add") | is_op("nn.bias_add")
conv2d_out = add_or_bias_add(conv2d, bias)
else:
conv2d_out = conv2d

if with_act is not None and with_act == "relu":
return is_op("nn.relu")(conv2d_out)

return conv2d_out


def check_dtype(lhs, rhs):
Expand Down Expand Up @@ -130,7 +144,12 @@ def partition_for_cutlass(mod):
dense_bias_pat,
dense_pat,
("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
# TODO(masahi): Add more conv2d patterns
(
"cutlass.conv2d_bias_relu",
make_conv2d_pattern(with_bias=True, with_act="relu"),
check_conv2d,
),
("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True), check_conv2d),
("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
]
seq = Sequential(
Expand Down

0 comments on commit 81bf9e6

Please sign in to comment.