Skip to content

Commit

Permalink
[CUTLASS] More robust support for pattern matching and alignment (#9698)
Browse files Browse the repository at this point in the history
* bug fix in im2col encoding

* skip legalize when batch size is dynamic

* add sm75 kernels to sm80 profilings

* add dtype and layout check in parttern match

* use align1 kernel for unusual channel cases (IC = 3 etc)

* test IC=3 convolution

* fixed check functions for fused cases, run infer type before mergecomposite

* check align on N dim

* add comment on IC == 3 case

* lint fix

* do not offload depthwise conv2d

* lint

* trigger CI
  • Loading branch information
masahi authored Dec 14, 2021
1 parent 4e70931 commit d1dafbd
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 54 deletions.
12 changes: 6 additions & 6 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,16 @@ def profile(
If profile_all is False, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
B, H, W, C = d_shape
K, R, S, _ = w_shape
B, _, _, IC = d_shape
OC, R, S, _ = w_shape
_, P, Q, _ = out_shape

M = B * H * W
K = R * S * C
N = B * P * Q
M = B * P * Q
N = OC
K = R * S * IC

gemm_profile_result = self.gemm_profiler.profile(
M, K, N, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
)

tile_description = gemm_profile_result["tile_description"]
Expand Down
21 changes: 12 additions & 9 deletions python/tvm/contrib/cutlass/gen_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,13 @@ def create_gemm_operator(
# TODO(masahi): A sensible way to pick reasonable default kernels
DEFAULT_KERNELS = {
75: {
"float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4",
"float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4",
"float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
"float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
},
# align1 variants do not seem to be available for sm80
80: {
"float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4",
"float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4",
"float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
"float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
},
}

Expand All @@ -160,14 +161,16 @@ def __init__(self, sm, cutlass_path, binary_path):
self.sm = sm
self.cache = {}

def check_align(self, op_name, M):
def check_align(self, op_name, M, N, K):
"""Filter out kernels that cannot be supported."""
aligns = re.findall(r"align[1|2|4|8]", op_name)
assert len(aligns) == 1
# The same alignment is used for all axes
align = int(aligns[0][-1])
if M % align != 0:
return False
return True
# TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
# See https://github.com/NVIDIA/cutlass/issues/362.
# When the above issue is resolved, we can remove the alignment check on M below.
return all([dim % align == 0 for dim in [M, N, K]])

def get_default(self, out_dtype, batched=False):
"""Return the default kernel for the requested architecture.
Expand All @@ -194,7 +197,7 @@ def profile(
ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
)
ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops))

for op in ops:
op["runtime"] = -1
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/contrib/cutlass/gen_tensor_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,11 @@ def get_tile_descriptions(math_inst):
TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
]

return generate_tensor_op_common(
sm75_kernels = generate_sm75_tensor_op_1688(out_dtype, op_creator)
sm80_kernels = generate_tensor_op_common(
math_instructions, alignment_constraints, get_tile_descriptions, op_creator
)
return sm75_kernels + sm80_kernels


class ProfilerEngine:
Expand Down
81 changes: 70 additions & 11 deletions python/tvm/relay/op/contrib/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Patterns supported CUTLASS."""
from tvm.ir.transform import Sequential
from tvm.relay import transform
from ...dataflow_pattern import wildcard, is_op, is_constant

Expand Down Expand Up @@ -56,31 +58,88 @@ def make_batch_matmul_pattern():


def make_conv2d_pattern():
# TODO(masahi): Check layout and alignment
return is_op("nn.conv2d")(wildcard(), wildcard())


def check_dtype(lhs, rhs):
"""Check if dtypes in the given workload are supported by CUTLASS."""
# Only fp16 inputs are supported for now.
return lhs.dtype == rhs.dtype and lhs.dtype == "float16" and rhs.dtype == "float16"


def get_root_call(call, root_op_name):
if str(call.op) == root_op_name:
return call
return get_root_call(call.args[0], root_op_name)


def check_gemm(call):
"""Check if the given dense workload can be offloaded to CUTLASS."""
dense = get_root_call(call, "nn.dense")
lhs = dense.args[0].checked_type
rhs = dense.args[1].checked_type
return check_dtype(lhs, rhs)


def check_batch_matmul(call):
"""Check if the given batch_matmul workload can be offloaded to CUTLASS."""
batch_matmul = get_root_call(call, "nn.batch_matmul")
lhs = batch_matmul.args[0].checked_type
rhs = batch_matmul.args[1].checked_type
transpose_a = batch_matmul.attrs.transpose_a
transpose_b = batch_matmul.attrs.transpose_b
return check_dtype(lhs, rhs) and not transpose_a and transpose_b


def is_depthwise_conv2d(ic, oc, groups):
return ic == oc == groups


def check_conv2d(call):
"""Check if the given conv2d workload can be offloaded to CUTLASS."""
conv2d = get_root_call(call, "nn.conv2d")
data_layout = conv2d.attrs.data_layout
kernel_layout = conv2d.attrs.kernel_layout
data = conv2d.args[0].checked_type
weight = conv2d.args[1].checked_type
if data_layout != "NHWC" or kernel_layout != "OHWI" or not check_dtype(data, weight):
return False
IC = data.shape[3]
OC = weight.shape[0]
return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)


def partition_for_cutlass(mod):
"""Partition the input module into CUTLASS-supported subgraphs."""
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None))
dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None))
dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"))
dense_bias_gelu_fp16_pat = ("cutlass.dense_bias_gelu_fp16", make_gemm_pattern(True, "gelu"))
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm)
dense_bias_gelu_fp16_pat = (
"cutlass.dense_bias_gelu_fp16",
make_gemm_pattern(True, "gelu"),
check_gemm,
)
dense_bias_gelu_fp32_pat = (
"cutlass.dense_bias_gelu_fp32",
make_gemm_pattern(True, "gelu", out_dtype="float32"),
check_gemm,
)
cutlass_patterns = [
dense_bias_gelu_fp16_pat,
dense_bias_gelu_fp32_pat,
dense_bias_relu_pat,
dense_bias_pat,
dense_pat,
("cutlass.batch_matmul", make_batch_matmul_pattern()),
("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
# TODO(masahi): Add more conv2d patterns
("cutlass.conv2d", make_conv2d_pattern()),
("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
]
mod = transform.MergeComposite(cutlass_patterns)(mod)
mod = transform.AnnotateTarget(["cutlass"])(mod)
mod = transform.PartitionGraph()(mod)
return mod
seq = Sequential(
[
transform.InferType(),
transform.MergeComposite(cutlass_patterns),
transform.AnnotateTarget(["cutlass"]),
transform.PartitionGraph(),
]
)
return seq(mod)
4 changes: 4 additions & 0 deletions python/tvm/topi/cuda/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,10 @@ def _conv2d_legalize(attrs, inputs, arg_types):

elif data_dtype in ["float16"]:
if data_layout == "NHWC" and kernel_layout == "HWIO":
if isinstance(data_tensor.shape[0], tvm.tir.expr.Any):
# Skip legalize when the batch size is dynamic
return None

batch = data_tensor.shape[0].value
in_channel = data_tensor.shape[3].value
out_channel = kernel_tensor.shape[3].value
Expand Down
71 changes: 44 additions & 27 deletions tests/python/contrib/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ def verify_batch_matmul(
def test_dense():
verify_dense(get_dense(M, N, K), M, N, K)
verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K)
# Test align1 case
verify_dense(get_dense_bias(M, N + 1, K), M, N + 1, K)


def test_dense_bias():
Expand Down Expand Up @@ -312,13 +314,14 @@ def convert_conv2d_layout(mod, desired_layouts):


def verify_conv2d(
mod_nchw,
mod_ref,
mod_nchw, # can be dynamic batch
mod_ref, # always static batch
d_shape,
w_shape,
sm=80,
atol=1e-5,
rtol=1e-5,
use_cudnn_ref=False,
run_benchmark=False,
):
if not has_cutlass():
Expand All @@ -332,52 +335,66 @@ def verify_conv2d(
typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type
use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)

mod_weight_ohwi = convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]})

if use_vm:
rt_mod, dev, num_cutlass_partition = profile_and_build_vm(
convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}), params, sm
)
rt_mod, _, num_cutlass_partition = profile_and_build_vm(mod_weight_ohwi, params, sm)
out = get_output_vm(rt_mod, ["data"], [np_data])
else:
rt_mod, dev, num_cutlass_partition = profile_and_build(
convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}),
rt_mod, _, num_cutlass_partition = profile_and_build(
mod_weight_ohwi,
params,
sm,
)
out = get_output(rt_mod, ["data"], [np_data])

assert num_cutlass_partition > 0

rt_mod_ref, _ = get_ref_rt_mod(
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
params,
target="cuda",
)
ref_out = get_output(rt_mod_ref, ["data"], [np_data])
if use_cudnn_ref:
rt_mod_ref, dev = get_ref_rt_mod(
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "OHWI"]}),
params,
target="cuda -libs=cudnn",
)
else:
rt_mod_ref, dev = get_ref_rt_mod(
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
params,
target="cuda",
)

np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
ref_out = get_output(rt_mod_ref, ["data"], [np_data])

if run_benchmark:
print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))

np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)


def test_conv2d():
for IC in [3, 16]:
d_shape = (16, IC, 32, 32)
w_shape = (32, IC, 3, 3)
mod_nchw = get_conv2d_nchw(d_shape, w_shape)

verify_conv2d(
mod_nchw,
mod_nchw,
d_shape,
w_shape,
sm=80,
atol=1e-5,
rtol=1e-5,
use_cudnn_ref=(IC == 3), # The autotvm kernel has an accuracy issue with IC == 3 case
run_benchmark=False,
)

d_shape = (16, 16, 32, 32)
w_shape = (32, 16, 3, 3)
mod_nchw = get_conv2d_nchw(d_shape, w_shape)

verify_conv2d(
mod_nchw,
mod_nchw,
d_shape,
w_shape,
sm=80,
atol=1e-5,
rtol=1e-5,
run_benchmark=False,
)

dyn_batch_shape = (relay.Any(),) + d_shape[1:]

mod_nchw = get_conv2d_nchw(d_shape, w_shape)
mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape)

verify_conv2d(
Expand Down

0 comments on commit d1dafbd

Please sign in to comment.