diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 90b7f9320c7f9..558056852d5ec 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -215,6 +215,10 @@ def handle_conv2d( if op_type == "cutlass.conv2d": cutlass_op_def = out["opdef"] + elif op_type == "cutlass.conv2d_bias": + cutlass_op_def = out["opdef_bias"] + elif op_type == "cutlass.conv2d_bias_relu": + cutlass_op_def = out["opdef_bias_relu"] else: raise ValueError("%s pattern is not implemented." % op_type) diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index a4371442b5571..b05c52cadc023 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -56,8 +56,22 @@ def make_batch_matmul_pattern(): return is_op("nn.batch_matmul")(wildcard(), wildcard()) -def make_conv2d_pattern(): - return is_op("nn.conv2d")(wildcard(), wildcard()) +def make_conv2d_pattern(with_bias=True, with_act=None, out_dtype="float16"): + """Create a pattern for dense op followed by activations.""" + data = wildcard() + weight = wildcard() + bias = wildcard() + conv2d = is_op("nn.conv2d")(data, weight) + if with_bias: + add_or_bias_add = is_op("add") | is_op("nn.bias_add") + conv2d_out = add_or_bias_add(conv2d, bias) + else: + conv2d_out = conv2d + + if with_act is not None and with_act == "relu": + return is_op("nn.relu")(conv2d_out) + + return conv2d_out def check_dtype(lhs, rhs): @@ -130,7 +144,12 @@ def partition_for_cutlass(mod): dense_bias_pat, dense_pat, ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul), - # TODO(masahi): Add more conv2d patterns + ( + "cutlass.conv2d_bias_relu", + make_conv2d_pattern(with_bias=True, with_act="relu"), + check_conv2d, + ), + ("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True), check_conv2d), ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d), ] seq = Sequential(