Skip to content

Commit

Permalink
[CUTLASS] Support more kernels: int8, tf32, and 3xtf32 (#9899)
Browse files Browse the repository at this point in the history
* add int8 type in library

* wip

* adding test and plumbing data and weight dtype

* adding 3xtf32 support and refactor tile description enum

* add 3xtf32 test

* update gemm generator too

* int8 test worked

* 3xtf32 also works

* int8 and 3xtf32 gemm works

* clean up test

* support int8 in sm75

* refined int8 alignment constraints

* black

* support 3xtf32 in default kernel

* remove log

* refine dtype check

* support tf32

* leave TODO for alignment modification on int8 kernels

* tf32 test working

* fix default kernel for tf32

* workaround for compilation failure

* lint
  • Loading branch information
masahi authored Jan 13, 2022
1 parent 920c380 commit ff2c434
Show file tree
Hide file tree
Showing 8 changed files with 445 additions and 116 deletions.
91 changes: 83 additions & 8 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,25 @@ def visit_call(self, call):


def select_gemm_kernel(
cutlass_profiler, op_type, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing
cutlass_profiler,
op_type,
MM,
KK,
NN,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
batched,
profile_all,
use_multiprocessing,
):
"""Run CUTLASS profiler to select the best kernel, or return the default one for dynamic
workloads."""
if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
out = cutlass_profiler.get_default(op_type, out_dtype, batched=batched)
out = cutlass_profiler.get_default(
op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32, batched=batched
)
name, cutlass_op_def = out["name"], out["opdef"]
logger.info("Picked the default kernel %s", name)
else:
Expand All @@ -109,6 +122,9 @@ def select_gemm_kernel(
NN,
KK,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
batched=batched,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
Expand All @@ -122,15 +138,35 @@ def select_gemm_kernel(


def handle_batch_matmul(
cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing
cutlass_profiler,
op_type,
arg0_shape,
arg1_shape,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
use_multiprocessing,
):
"""Profile and select a kernel for batch_matmul op workload."""
MM = arg0_shape[1]
KK = arg0_shape[2]
NN = arg1_shape[1]

name, cutlass_op_def = select_gemm_kernel(
cutlass_profiler, op_type, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing
cutlass_profiler,
op_type,
MM,
KK,
NN,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
True,
profile_all,
use_multiprocessing,
)

return {
Expand All @@ -147,15 +183,35 @@ def handle_batch_matmul(


def handle_dense(
cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing
cutlass_profiler,
op_type,
arg0_shape,
arg1_shape,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
use_multiprocessing,
):
"""Profile and select a kernel for dense op workload."""
MM = arg0_shape[0]
KK = arg0_shape[1]
NN = arg1_shape[0]

name, cutlass_op_def = select_gemm_kernel(
cutlass_profiler, op_type, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing
cutlass_profiler,
op_type,
MM,
KK,
NN,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
False,
profile_all,
use_multiprocessing,
)

assert "tn_align" in name, "Only supports (row_major, col_major) input layout for now."
Expand All @@ -178,12 +234,15 @@ def handle_conv2d(
strides,
dilation,
out_dtype,
data_dtype,
weight_dtype,
use_3xtf32,
profile_all,
use_multiprocessing,
):
"""Profile and select a kernel for conv2d op workload."""
if any(isinstance(s, tvm.tir.Any) for s in d_shape):
out = cutlass_profiler.get_default(op_type, out_dtype)
out = cutlass_profiler.get_default(op_type, out_dtype, data_dtype, weight_dtype, use_3xtf32)
name, cutlass_op_def = out["name"], out["opdef"]
logger.info("Picked the default kernel %s", name)
else:
Expand All @@ -195,6 +254,9 @@ def handle_conv2d(
strides,
dilation,
out_dtype,
data_dtype,
weight_dtype,
use_3xtf32,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
)
Expand All @@ -209,7 +271,9 @@ def handle_conv2d(
}


def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
def tune_cutlass_kernels(
mod, sm, use_3xtf32=True, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"
):
"""Given a module partitioned for CUTLASS offloading, profile each workload to select which
kernels to emit.
Expand Down Expand Up @@ -258,6 +322,8 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
new_attrs.update(func.attrs)
arg0_shape = new_attrs["arg0_shape"]
arg1_shape = new_attrs["arg1_shape"]
arg0_dtype = new_attrs["arg0_dtype"]
arg1_dtype = new_attrs["arg1_dtype"]

if "conv2d" in op_type:
new_attrs["padding"] = annotator.op_attrs.padding
Expand All @@ -273,6 +339,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
annotator.op_attrs.strides,
annotator.op_attrs.dilation,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
use_multiprocessing,
)
Expand All @@ -285,6 +354,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
arg0_shape,
arg1_shape,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
use_multiprocessing,
)
Expand All @@ -297,6 +369,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
arg0_shape,
arg1_shape,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
use_multiprocessing,
)
Expand Down
29 changes: 22 additions & 7 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,13 @@ def __init__(self, sm, cutlass_path, binary_path):
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
self.cache = {}

def get_default(self, op_type, out_dtype):
gemm_profile_result = self.gemm_profiler.get_default(op_type, out_dtype)
def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32):
"""Return the default kernel for the requested architecture.
For now, the default kernel was picked arbitrary.
"""
gemm_profile_result = self.gemm_profiler.get_default(
op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32
)
tile_description = gemm_profile_result["tile_description"]
alignment = gemm_profile_result["alignment"]
data_type = gemm_profile_result["data_type"]
Expand All @@ -165,9 +170,10 @@ def get_default(self, op_type, out_dtype):

def check_align(self, op_name, C, K):
"""Filter out kernels that cannot be supported."""
aligns = re.findall(r"align[1|2|4|8]", op_name)
assert len(aligns) == 1
align = int(aligns[0][-1])
match = re.match(".*_align([1-9]+)", op_name)
assert match is not None and len(match.groups()) == 1
# The same alignment is used for all axes
align = int(match.groups()[0])
return all([dim % align == 0 for dim in [C, K]])

def select_op(
Expand All @@ -178,6 +184,9 @@ def select_op(
stride,
dilation,
out_dtype,
data_dtype,
weight_dtype,
use_3xtf32,
profile_all=True,
use_multiprocessing=False,
):
Expand Down Expand Up @@ -207,9 +216,9 @@ def select_op(
return self.cache[workload]

ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype,
op_creator=enumerate_conv2d_operators,
out_dtype, data_dtype, weight_dtype, enumerate_conv2d_operators, use_3xtf32
)

ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))

if profile_all:
Expand Down Expand Up @@ -240,6 +249,9 @@ def profile(
stride,
dilation,
out_dtype,
data_dtype,
weight_dtype,
use_3xtf32=True,
profile_all=True,
use_multiprocessing=False,
):
Expand All @@ -254,6 +266,9 @@ def profile(
stride,
dilation,
out_dtype,
data_dtype,
weight_dtype,
use_3xtf32,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
)
Expand Down
66 changes: 53 additions & 13 deletions python/tvm/contrib/cutlass/gen_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,18 @@ def enumerate_gemm_operators(
# TODO(masahi): A sensible way to pick reasonable default kernels
DEFAULT_KERNELS = {
75: {
"float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
"float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
("float16", "float16"): "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
("float16", "float32"): "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
},
# align1 variants do not seem to be available for sm80
80: {
"float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
"float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
("float16", "float16"): "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
("float16", "float32"): "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
# two kernels for tf32 and 3xtf32
("float32", "float32"): (
"cutlass_tensorop_s1688gemm_128x64_32x3_tn_align1",
"cutlass_tensorop_s1688gemm_64x64_16x3_tn_align1",
),
},
}

Expand All @@ -147,21 +152,31 @@ def __init__(self, sm, cutlass_path, binary_path):

def check_align(self, op_name, M, N, K):
"""Filter out kernels that cannot be supported."""
aligns = re.findall(r"align[1|2|4|8]", op_name)
assert len(aligns) == 1
match = re.match(".*_align([1-9]+)", op_name)
assert match is not None and len(match.groups()) == 1
# The same alignment is used for all axes
align = int(aligns[0][-1])
align = int(match.groups()[0])
# TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
# See https://github.com/NVIDIA/cutlass/issues/362.
# When the above issue is resolved, we can remove the alignment check on M below.
return all([dim % align == 0 for dim in [M, N, K]])

def get_default(self, op_type, out_dtype, batched=False):
def get_default(
self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False
):
"""Return the default kernel for the requested architecture.
For now, the default kernel was picked arbitrary.
"""
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=enumerate_gemm_operators)
default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32
)
default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)]

if arg0_dtype == "float32":
default_kernel_name = (
default_kernel_name[0] if not use_3xtf32 else default_kernel_name[1]
)

filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
assert len(filtered) == 1
op = filtered[0]
Expand All @@ -176,7 +191,18 @@ def get_default(self, op_type, out_dtype, batched=False):
op.update({"name": name, "opdef": opdef})
return op

def select_op(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False):
def select_op(
self,
M,
N,
K,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all=True,
use_multiprocessing=False,
):
"""
Profile and select the best kernel from candidate kernels.
See the documentation for the profile method below.
Expand All @@ -187,7 +213,10 @@ def select_op(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=Fa

ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype,
op_creator=enumerate_gemm_operators,
arg0_dtype,
arg1_dtype,
enumerate_gemm_operators,
use_3xtf32=use_3xtf32,
)
ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops))

Expand All @@ -212,6 +241,9 @@ def profile(
N,
K,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32=True,
profile_all=True,
use_multiprocessing=False,
batched=False,
Expand All @@ -221,7 +253,15 @@ def profile(
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
op = self.select_op(
M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
M,
N,
K,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
)

name, opdef = create_gemm_operator_with_epilogue(
Expand Down
Loading

0 comments on commit ff2c434

Please sign in to comment.