diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index a258da3c5d78..5ffc66bb0acb 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -312,13 +312,14 @@ def convert_conv2d_layout(mod, desired_layouts):
 
 
 def verify_conv2d(
-    mod_nchw,
-    mod_ref,
+    mod_nchw,  # can be dynamic batch
+    mod_ref,  # always static batch
     d_shape,
     w_shape,
     sm=80,
     atol=1e-5,
     rtol=1e-5,
+    use_cudnn_ref=False,
     run_benchmark=False,
 ):
     if not has_cutlass():
@@ -332,14 +333,14 @@ def verify_conv2d(
     typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type
     use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)
 
+    mod_weight_ohwi = convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]})
+
     if use_vm:
-        rt_mod, dev, num_cutlass_partition = profile_and_build_vm(
-            convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}), params, sm
-        )
+        rt_mod, _, num_cutlass_partition = profile_and_build_vm(mod_weight_ohwi, params, sm)
         out = get_output_vm(rt_mod, ["data"], [np_data])
     else:
-        rt_mod, dev, num_cutlass_partition = profile_and_build(
-            convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}),
+        rt_mod, _, num_cutlass_partition = profile_and_build(
+            mod_weight_ohwi,
             params,
             sm,
         )
@@ -347,37 +348,51 @@ def verify_conv2d(
     assert num_cutlass_partition > 0
 
-    rt_mod_ref, _ = get_ref_rt_mod(
-        convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
-        params,
-        target="cuda",
-    )
-    ref_out = get_output(rt_mod_ref, ["data"], [np_data])
+    if use_cudnn_ref:
+        rt_mod_ref, dev = get_ref_rt_mod(
+            convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "OHWI"]}),
+            params,
+            target="cuda -libs=cudnn",
+        )
+    else:
+        rt_mod_ref, dev = get_ref_rt_mod(
+            convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
+            params,
+            target="cuda",
+        )
 
-    np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
+    ref_out = get_output(rt_mod_ref, ["data"], [np_data])
 
     if run_benchmark:
         print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
         print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))
 
+    np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
+
 
 def test_conv2d():
+    for IC in [3, 16]:
+        d_shape = (16, IC, 32, 32)
+        w_shape = (32, IC, 3, 3)
+        mod_nchw = get_conv2d_nchw(d_shape, w_shape)
+
+        verify_conv2d(
+            mod_nchw,
+            mod_nchw,
+            d_shape,
+            w_shape,
+            sm=80,
+            atol=1e-5,
+            rtol=1e-5,
+            use_cudnn_ref=IC == 3,
+            run_benchmark=False,
+        )
+
     d_shape = (16, 16, 32, 32)
     w_shape = (32, 16, 3, 3)
-    mod_nchw = get_conv2d_nchw(d_shape, w_shape)
-
-    verify_conv2d(
-        mod_nchw,
-        mod_nchw,
-        d_shape,
-        w_shape,
-        sm=80,
-        atol=1e-5,
-        rtol=1e-5,
-        run_benchmark=False,
-    )
-
     dyn_batch_shape = (relay.Any(),) + d_shape[1:]
+
+    mod_nchw = get_conv2d_nchw(d_shape, w_shape)
     mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape)
 
     verify_conv2d(