Skip to content

Commit

Permalink
partition working
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 12, 2021
1 parent c08bb38 commit 0de5ebd
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
3 changes: 3 additions & 0 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def visit_call(self, call):
if str(op) == "nn.conv2d":
self.op_attrs = call.attrs

for arg in call.args:
self.visit(arg)


def select_gemm_kernel(
cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/contrib/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def make_batch_matmul_pattern():
return is_op("nn.batch_matmul")(wildcard(), wildcard())


def make_conv2d_pattern(with_bias=True, with_act=None, out_dtype="float16"):
def make_conv2d_pattern(with_bias=False, with_act=None, out_dtype="float16"):
"""Create a pattern for dense op followed by activations."""
data = wildcard()
weight = wildcard()
Expand Down
16 changes: 12 additions & 4 deletions tests/python/contrib/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,16 +385,24 @@ def test_conv2d():
)


d_shape = (16, 16, 32, 32)
w_shape = (32, 16, 3, 3)
padding = (1, 1)


def test_conv2d_bias():
d_shape = (16, 16, 32, 32)
w_shape = (32, 16, 3, 3)
padding = (1, 1)
mod_nchw = get_conv2d_nchw_bias(d_shape, w_shape, padding)
verify_conv2d(
mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
)


def test_conv2d_bias_relu():
mod_nchw = get_conv2d_nchw_bias_relu(d_shape, w_shape, padding)
verify_conv2d(
mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
)


if __name__ == "__main__":
test_conv2d_bias()
test_conv2d_bias_relu()

0 comments on commit 0de5ebd

Please sign in to comment.