From 0de5ebdb2e318129a1a4d3a2e568849a993525c4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 12 Dec 2021 18:39:08 +0900 Subject: [PATCH] partition working --- python/tvm/contrib/cutlass/build.py | 3 +++ python/tvm/relay/op/contrib/cutlass.py | 2 +- tests/python/contrib/test_cutlass.py | 16 ++++++++++++---- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 558056852d5ec..4923bdf8d168f 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -87,6 +87,9 @@ def visit_call(self, call): if str(op) == "nn.conv2d": self.op_attrs = call.attrs + for arg in call.args: + self.visit(arg) + def select_gemm_kernel( cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index b05c52cadc023..2d65c3ecb7524 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -56,7 +56,7 @@ def make_batch_matmul_pattern(): return is_op("nn.batch_matmul")(wildcard(), wildcard()) -def make_conv2d_pattern(with_bias=True, with_act=None, out_dtype="float16"): +def make_conv2d_pattern(with_bias=False, with_act=None, out_dtype="float16"): """Create a pattern for dense op followed by activations.""" data = wildcard() weight = wildcard() diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 672eadc96944d..d0f9d78a48be7 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -385,16 +385,24 @@ def test_conv2d(): ) +d_shape = (16, 16, 32, 32) +w_shape = (32, 16, 3, 3) +padding = (1, 1) + + def test_conv2d_bias(): - d_shape = (16, 16, 32, 32) - w_shape = (32, 16, 3, 3) - padding = (1, 1) mod_nchw = get_conv2d_nchw_bias(d_shape, w_shape, padding) + verify_conv2d( + mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False + ) + +def test_conv2d_bias_relu(): + mod_nchw = get_conv2d_nchw_bias_relu(d_shape, w_shape, padding) verify_conv2d( mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False ) if __name__ == "__main__": - test_conv2d_bias() + test_conv2d_bias_relu()