[CUTLASS] Profile only the largest-possible alignment by default (#10036)

* introduce profile_all_alignments option

* add profile_all_alignment option to API

* wip

* fixed dynamic case

* black

* update gen_gemm too

* minor improvement

* fix

* all tests work

* add doc

* fixed for sm = 75 case

* fix typo

* remove unused import

* profile_all -> find_first_valid

* fix
masahi authored Jan 26, 2022
1 parent 2830c96 commit 1b9b05e
Showing 6 changed files with 124 additions and 66 deletions.
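After this commit, only kernel variants whose alignment is the largest one the problem size admits are profiled by default; profile_all_alignments=True restores exhaustive profiling, and find_first_valid=True (the inverse of the old profile_all=False) stops at the first kernel that runs. A minimal usage sketch, assuming `mod` is a module already partitioned for CUTLASS offloading and an sm_80 target; the handling of the return value is an assumption, not shown in this diff:

# Hedged usage sketch of the options introduced by this commit.
from tvm.contrib.cutlass import tune_cutlass_kernels

# Default after this commit: profile only the largest-possible alignment.
tuned_mod = tune_cutlass_kernels(mod, sm=80)

# Opt back into exhaustive profiling, or stop at the first applicable kernel.
tuned_mod = tune_cutlass_kernels(
    mod,
    sm=80,
    profile_all_alignments=True,  # also profile smaller-alignment variants
    find_first_valid=True,        # replaces the old profile_all=False
    use_multiprocessing=True,     # compile profiler binaries in parallel
)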
46 changes: 31 additions & 15 deletions python/tvm/contrib/cutlass/build.py
@@ -104,7 +104,7 @@ def select_gemm_kernel(
arg1_dtype,
use_3xtf32,
batched,
- profile_all,
+ find_first_valid,
use_multiprocessing,
):
"""Run CUTLASS profiler to select the best kernel, or return the default one for dynamic
@@ -126,10 +126,10 @@
arg1_dtype,
use_3xtf32,
batched=batched,
- profile_all=profile_all,
+ find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
)
- if profile_all:
+ if not find_first_valid:
logger.info("The best kernel is %s", name)
else:
logger.info("Picked the first kernel found %s", name)
@@ -146,7 +146,7 @@ def handle_batch_matmul(
arg0_dtype,
arg1_dtype,
use_3xtf32,
- profile_all,
+ find_first_valid,
use_multiprocessing,
):
"""Profile and select a kernel for batch_matmul op workload."""
@@ -165,7 +165,7 @@
arg1_dtype,
use_3xtf32,
True,
- profile_all,
+ find_first_valid,
use_multiprocessing,
)

@@ -191,7 +191,7 @@ def handle_dense(
arg0_dtype,
arg1_dtype,
use_3xtf32,
- profile_all,
+ find_first_valid,
use_multiprocessing,
):
"""Profile and select a kernel for dense op workload."""
@@ -210,7 +210,7 @@
arg1_dtype,
use_3xtf32,
False,
- profile_all,
+ find_first_valid,
use_multiprocessing,
)

@@ -237,7 +237,8 @@ def handle_conv2d(
data_dtype,
weight_dtype,
use_3xtf32,
- profile_all,
+ profile_all_alignments,
+ find_first_valid,
use_multiprocessing,
):
"""Profile and select a kernel for conv2d op workload."""
@@ -257,10 +258,11 @@
data_dtype,
weight_dtype,
use_3xtf32,
- profile_all=profile_all,
+ profile_all_alignments,
+ find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
)
- if profile_all:
+ if not find_first_valid:
logger.info("The best kernel is %s", name)
else:
logger.info("Picked the first kernel found %s", name)
@@ -272,7 +274,13 @@


def tune_cutlass_kernels(
- mod, sm, use_3xtf32=True, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"
+ mod,
+ sm,
+ use_3xtf32=True,
+ profile_all_alignments=False,
+ find_first_valid=False,
+ use_multiprocessing=False,
+ tmp_dir="./tmp",
):
"""Given a module partitioned for CUTLASS offloading, profile each workload to select which
kernels to emit.
@@ -286,7 +294,14 @@
An integer specifying the compute capability. For example, 75 for Turing and
80 or 86 for Ampere.
- profile_all : bool
+ use_3xtf32 : bool
+     Whether or not to use the slower but very accurate (compared to tf32) 3xtf32 mode for
+     fp32 inputs on tensorcore.
+ profile_all_alignments : bool
+     When True, profile all kernel variants with smaller alignments than the largest possible.
+ find_first_valid : bool
Whether or not to profile all candidate kernels, or stop profiling after
the first applicable kernel is found.
@@ -342,7 +357,8 @@
arg0_dtype,
arg1_dtype,
use_3xtf32,
- profile_all,
+ profile_all_alignments,
+ find_first_valid,
use_multiprocessing,
)
)
@@ -357,7 +373,7 @@
arg0_dtype,
arg1_dtype,
use_3xtf32,
- profile_all,
+ find_first_valid,
use_multiprocessing,
)
)
@@ -372,7 +388,7 @@
arg0_dtype,
arg1_dtype,
use_3xtf32,
- profile_all,
+ find_first_valid,
use_multiprocessing,
)
)
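The rename in build.py flips the polarity of the old profile_all flag. A self-contained sketch of the selection loop both profilers implement, where `evaluate` is a hypothetical stand-in for the profiler engine rather than its real interface:

# Hedged sketch: `evaluate` returns a runtime, or float("inf") for kernels
# that do not apply to the workload.
def select_kernel(ops, evaluate, find_first_valid=False):
    for op in ops:
        op["runtime"] = evaluate(op)
        if op["runtime"] < float("inf") and find_first_valid:
            return op  # old profile_all=False behavior: first valid kernel wins
    # Otherwise every candidate was profiled; pick the fastest.
    return min(ops, key=lambda op: op["runtime"])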
36 changes: 17 additions & 19 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -16,7 +16,6 @@
# under the License.
# pylint: disable=invalid-name
"""Conv2d kernel generator and profiler for CUTLASS."""
- import re
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
from .gen_gemm import CutlassGemmProfiler
from .conv2d_profiler import Conv2dProfilerEmitter
@@ -168,14 +167,6 @@ def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32):
)
return {"name": name, "opdef": opdef}

- def check_align(self, op_name, C, K):
-     """Filter out kernels that cannot be supported."""
-     match = re.match(".*_align([1-9]+)", op_name)
-     assert match is not None and len(match.groups()) == 1
-     # The same alignment is used for all axes
-     align = int(match.groups()[0])
-     return all([dim % align == 0 for dim in [C, K]])
-
def select_op(
self,
d_shape,
@@ -187,7 +178,8 @@ def select_op(
data_dtype,
weight_dtype,
use_3xtf32,
- profile_all=True,
+ profile_all_alignments=False,
+ find_first_valid=False,
use_multiprocessing=False,
):
"""
@@ -216,12 +208,16 @@
return self.cache[workload]

ops = GENERATOR_FUNC_TABLE[self.sm](
- out_dtype, data_dtype, weight_dtype, enumerate_conv2d_operators, use_3xtf32
+ out_dtype,
+ data_dtype,
+ weight_dtype,
+ enumerate_conv2d_operators,
+ lambda align: all([dim % align == 0 for dim in [IC, OC]]),
+ use_3xtf32,
+ profile_all_alignments,
)

- ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))
-
- if profile_all:
+ if not find_first_valid:
self.engine.compile_all(ops, use_multiprocessing)

args = (
@@ -232,7 +228,7 @@
for op in ops:
out = self.engine.evaluate(op, args.split(" "))
op["runtime"] = out
- if out < float("inf") and not profile_all:
+ if out < float("inf") and find_first_valid:
self.cache[workload] = op
return op

@@ -252,11 +248,12 @@ def profile(
data_dtype,
weight_dtype,
use_3xtf32=True,
- profile_all=True,
+ profile_all_alignments=False,
+ find_first_valid=False,
use_multiprocessing=False,
):
"""Profile and select the best kernel from candidate kernels.
- If profile_all is False, return immediately after the first applicable kernel is found.
+ If find_first_valid is True, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
op = self.select_op(
@@ -269,8 +266,9 @@
data_dtype,
weight_dtype,
use_3xtf32,
- profile_all=profile_all,
- use_multiprocessing=use_multiprocessing,
+ profile_all_alignments,
+ find_first_valid,
+ use_multiprocessing,
)

name, opdef = create_conv2d_operator_with_epilogue(
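The deleted check_align method parsed the alignment back out of each generated kernel name; its replacement passes an alignment predicate into the generator, so infeasible variants are pruned before any kernel is emitted. A standalone sketch of the predicate's effect, where the candidate alignment values are an assumption:

# Hedged sketch of alignment pruning for conv2d: the same predicate the diff
# passes to GENERATOR_FUNC_TABLE, applied to assumed candidate alignments.
def feasible_alignments(IC, OC, candidates=(8, 4, 2, 1), profile_all_alignments=False):
    valid = [a for a in candidates if all(dim % a == 0 for dim in (IC, OC))]
    # Default after this commit: keep only the largest feasible alignment.
    return valid if profile_all_alignments else valid[:1]

print(feasible_alignments(64, 60))                               # [4]: 8 does not divide 60
print(feasible_alignments(64, 60, profile_all_alignments=True))  # [4, 2, 1]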
50 changes: 27 additions & 23 deletions python/tvm/contrib/cutlass/gen_gemm.py
@@ -16,7 +16,6 @@
# under the License.
# pylint: disable=invalid-name
"""GEMM kernel generator and profiler for CUTLASS."""
- import re
from .gemm_operation import GemmOperation, EmitGemmInstance
from .gemm_profiler import GemmProfilerEmitter
from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP
@@ -63,8 +62,9 @@ def create_gemm_operator_with_epilogue(
swizzling_functor,
)

- return op.procedural_name(), EmitGemmInstance().emit(
-     op, no_beta_scaling=no_beta_scaling, batched=batched
+ return (
+     op.procedural_name(),
+     EmitGemmInstance().emit(op, no_beta_scaling=no_beta_scaling, batched=batched),
)


@@ -150,26 +150,22 @@ def __init__(self, sm, cutlass_path, binary_path):
self.sm = sm
self.cache = {}

- def check_align(self, op_name, M, N, K):
-     """Filter out kernels that cannot be supported."""
-     match = re.match(".*_align([1-9]+)", op_name)
-     assert match is not None and len(match.groups()) == 1
-     # The same alignment is used for all axes
-     align = int(match.groups()[0])
-     # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
-     # See https://github.com/NVIDIA/cutlass/issues/362.
-     # When the above issue is resolved, we can remove the alignment check on M below.
-     return all([dim % align == 0 for dim in [M, N, K]])
-
def get_default(
self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False
):
"""Return the default kernel for the requested architecture.
For now, the default kernel is picked arbitrarily.
"""
ops = GENERATOR_FUNC_TABLE[self.sm](
- out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32
+ out_dtype,
+ arg0_dtype,
+ arg1_dtype,
+ enumerate_gemm_operators,
+ lambda align: align == 1,  # Only request align1 kernels
+ use_3xtf32,
+ profile_all_alignments=True,  # To include all align1 kernels
)

default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)]

if arg0_dtype == "float32":
@@ -200,7 +196,8 @@ def select_op(
arg0_dtype,
arg1_dtype,
use_3xtf32,
- profile_all=True,
+ profile_all_alignments=False,
+ find_first_valid=False,
use_multiprocessing=False,
):
"""
@@ -211,22 +208,27 @@
op = self.cache[(M, N, K)]
return op

+ # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
+ # See https://github.com/NVIDIA/cutlass/issues/362.
+ # When the above issue is resolved, we can remove the alignment check on M below.
+
ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype,
arg0_dtype,
arg1_dtype,
enumerate_gemm_operators,
- use_3xtf32=use_3xtf32,
+ lambda align: all([dim % align == 0 for dim in [M, N, K]]),
+ use_3xtf32,
+ profile_all_alignments=profile_all_alignments,
)
- ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops))

- if profile_all:
+ if not find_first_valid:
self.engine.compile_all(ops, use_multiprocessing)

for op in ops:
out = self.engine.evaluate(op, [M, N, K])
op["runtime"] = out
- if out < float("inf") and not profile_all:
+ if out < float("inf") and find_first_valid:
self.cache[(M, N, K)] = op
return op

@@ -244,12 +246,13 @@ def profile(
arg0_dtype,
arg1_dtype,
use_3xtf32=True,
- profile_all=True,
+ profile_all_alignments=False,
+ find_first_valid=False,
use_multiprocessing=False,
batched=False,
):
"""Profile and select the best kernel from candidate kernels.
- If profile_all is False, return immediately after the first applicable kernel is found.
+ If find_first_valid is True, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
op = self.select_op(
@@ -260,7 +263,8 @@
arg0_dtype,
arg1_dtype,
use_3xtf32,
- profile_all=profile_all,
+ profile_all_alignments=profile_all_alignments,
+ find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
)

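Two details worth noting in gen_gemm.py: get_default now requests only align1 kernels, which impose no divisibility constraint and therefore always apply (the safe fallback for dynamic shapes), while select_op still checks M against the alignment because of NVIDIA/cutlass#362. A sketch of both predicates, using hypothetical op dicts and names:

# Hedged sketch: filtering kernels by an alignment predicate, as the
# generator now does internally. The op dicts here are hypothetical.
ops = [
    {"name": "gemm_align8", "align": 8},
    {"name": "gemm_align4", "align": 4},
    {"name": "gemm_align1", "align": 1},
]

def generate(ops, align_valid):
    return [op for op in ops if align_valid(op["align"])]

# get_default: only universally applicable align1 kernels survive.
assert generate(ops, lambda align: align == 1) == [ops[2]]

# select_op for static shapes: the alignment must divide M, N, and K
# (M included only until the CUTLASS issue above is resolved).
M, N, K = 512, 768, 100
print(generate(ops, lambda a: all(d % a == 0 for d in (M, N, K))))
# -> the align4 and align1 kernels (8 does not divide 100)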
(Diffs for the remaining 3 changed files not shown.)
