Skip to content

Commit

Permalink
test IC=3 convolution
Browse files: browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 11, 2021
1 parent ffce47d commit 8d6a1bf
Showing 1 changed file with 42 additions and 27 deletions.
69 changes: 42 additions & 27 deletions tests/python/contrib/test_cutlass.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -312,13 +312,14 @@ def convert_conv2d_layout(mod, desired_layouts):


def verify_conv2d(
mod_nchw,
mod_ref,
mod_nchw, # can be dynamic batch
mod_ref, # always static batch
d_shape,
w_shape,
sm=80,
atol=1e-5,
rtol=1e-5,
use_cudnn_ref=False,
run_benchmark=False,
):
if not has_cutlass():
Expand All @@ -332,52 +333,66 @@ def verify_conv2d(
typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type
use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)

mod_weight_ohwi = convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]})

if use_vm:
rt_mod, dev, num_cutlass_partition = profile_and_build_vm(
convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}), params, sm
)
rt_mod, _, num_cutlass_partition = profile_and_build_vm(mod_weight_ohwi, params, sm)
out = get_output_vm(rt_mod, ["data"], [np_data])
else:
rt_mod, dev, num_cutlass_partition = profile_and_build(
convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}),
rt_mod, _, num_cutlass_partition = profile_and_build(
mod_weight_ohwi,
params,
sm,
)
out = get_output(rt_mod, ["data"], [np_data])

assert num_cutlass_partition > 0

rt_mod_ref, _ = get_ref_rt_mod(
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
params,
target="cuda",
)
ref_out = get_output(rt_mod_ref, ["data"], [np_data])
if use_cudnn_ref:
rt_mod_ref, dev = get_ref_rt_mod(
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "OHWI"]}),
params,
target="cuda -libs=cudnn",
)
else:
rt_mod_ref, dev = get_ref_rt_mod(
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
params,
target="cuda",
)

np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
ref_out = get_output(rt_mod_ref, ["data"], [np_data])

if run_benchmark:
print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))

np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)


def test_conv2d():
for IC in [3, 16]:
d_shape = (16, IC, 32, 32)
w_shape = (32, IC, 3, 3)
mod_nchw = get_conv2d_nchw(d_shape, w_shape)

verify_conv2d(
mod_nchw,
mod_nchw,
d_shape,
w_shape,
sm=80,
atol=1e-5,
rtol=1e-5,
use_cudnn_ref=IC == 3,
run_benchmark=False,
)

d_shape = (16, 16, 32, 32)
w_shape = (32, 16, 3, 3)
mod_nchw = get_conv2d_nchw(d_shape, w_shape)

verify_conv2d(
mod_nchw,
mod_nchw,
d_shape,
w_shape,
sm=80,
atol=1e-5,
rtol=1e-5,
run_benchmark=False,
)

dyn_batch_shape = (relay.Any(),) + d_shape[1:]

mod_nchw = get_conv2d_nchw(d_shape, w_shape)
mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape)

verify_conv2d(
Expand Down

0 comments on commit 8d6a1bf

Please sign in to comment.