Skip to content

Commit

Permalink
profile_all -> find_first_valid
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 26, 2022
1 parent ba1bbb9 commit d15a995
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 28 deletions.
30 changes: 15 additions & 15 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def select_gemm_kernel(
arg1_dtype,
use_3xtf32,
batched,
profile_all,
find_first_valid,
use_multiprocessing,
):
"""Run CUTLASS profiler to select the best kernel, or return the default one for dynamic
Expand All @@ -126,10 +126,10 @@ def select_gemm_kernel(
arg1_dtype,
use_3xtf32,
batched=batched,
profile_all=profile_all,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
)
if profile_all:
if not find_first_valid:
logger.info("The best kernel is %s", name)
else:
logger.info("Picked the first kernel found %s", name)
Expand All @@ -146,7 +146,7 @@ def handle_batch_matmul(
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
find_first_valid,
use_multiprocessing,
):
"""Profile and select a kernel for batch_matmul op workload."""
Expand All @@ -165,7 +165,7 @@ def handle_batch_matmul(
arg1_dtype,
use_3xtf32,
True,
profile_all,
find_first_valid,
use_multiprocessing,
)

Expand All @@ -191,7 +191,7 @@ def handle_dense(
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
find_first_valid,
use_multiprocessing,
):
"""Profile and select a kernel for dense op workload."""
Expand All @@ -210,7 +210,7 @@ def handle_dense(
arg1_dtype,
use_3xtf32,
False,
profile_all,
find_first_valid,
use_multiprocessing,
)

Expand Down Expand Up @@ -238,7 +238,7 @@ def handle_conv2d(
weight_dtype,
use_3xtf32,
profile_all_alignments,
profile_all,
find_first_valid,
use_multiprocessing,
):
"""Profile and select a kernel for conv2d op workload."""
Expand All @@ -259,10 +259,10 @@ def handle_conv2d(
weight_dtype,
use_3xtf32,
profile_all_alignments,
profile_all=profile_all,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
)
if profile_all:
if not find_first_valid:
logger.info("The best kernel is %s", name)
else:
logger.info("Picked the first kernel found %s", name)
Expand All @@ -278,7 +278,7 @@ def tune_cutlass_kernels(
sm,
use_3xtf32=True,
profile_all_alignments=False,
profile_all=True,
find_first_valid=False,
use_multiprocessing=False,
tmp_dir="./tmp",
):
Expand All @@ -301,7 +301,7 @@ def tune_cutlass_kernels(
profile_all_alignments : bool
When True, profile all kernel variants with smaller alignments than the largest possible.
profile_all : bool
find_first_valid : bool
When True, stop profiling after the first applicable kernel is found,
instead of profiling all candidate kernels.
Expand Down Expand Up @@ -358,7 +358,7 @@ def tune_cutlass_kernels(
arg1_dtype,
use_3xtf32,
profile_all_alignments,
profile_all,
find_first_valid,
use_multiprocessing,
)
)
Expand All @@ -373,7 +373,7 @@ def tune_cutlass_kernels(
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
find_first_valid,
use_multiprocessing,
)
)
Expand All @@ -388,7 +388,7 @@ def tune_cutlass_kernels(
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
find_first_valid,
use_multiprocessing,
)
)
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def select_op(
weight_dtype,
use_3xtf32,
profile_all_alignments=False,
profile_all=True,
find_first_valid=False,
use_multiprocessing=False,
):
"""
Expand Down Expand Up @@ -217,7 +217,7 @@ def select_op(
profile_all_alignments,
)

if profile_all:
if find_first_valid:
self.engine.compile_all(ops, use_multiprocessing)

args = (
Expand All @@ -228,7 +228,7 @@ def select_op(
for op in ops:
out = self.engine.evaluate(op, args.split(" "))
op["runtime"] = out
if out < float("inf") and not profile_all:
if out < float("inf") and find_first_valid:
self.cache[workload] = op
return op

Expand All @@ -249,11 +249,11 @@ def profile(
weight_dtype,
use_3xtf32=True,
profile_all_alignments=False,
profile_all=True,
find_first_valid=False,
use_multiprocessing=False,
):
"""Profile and select the best kernel from candidate kernels.
If profile_all is False, return immediately after the first applicable kernel is found.
If find_first_valid is True, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
op = self.select_op(
Expand All @@ -267,7 +267,7 @@ def profile(
weight_dtype,
use_3xtf32,
profile_all_alignments,
profile_all,
find_first_valid,
use_multiprocessing,
)

Expand Down
12 changes: 6 additions & 6 deletions python/tvm/contrib/cutlass/gen_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def select_op(
arg1_dtype,
use_3xtf32,
profile_all_alignments=False,
profile_all=True,
find_first_valid=False,
use_multiprocessing=False,
):
"""
Expand All @@ -222,13 +222,13 @@ def select_op(
profile_all_alignments=profile_all_alignments,
)

if profile_all:
if find_first_valid:
self.engine.compile_all(ops, use_multiprocessing)

for op in ops:
out = self.engine.evaluate(op, [M, N, K])
op["runtime"] = out
if out < float("inf") and not profile_all:
if out < float("inf") and find_first_valid:
self.cache[(M, N, K)] = op
return op

Expand All @@ -247,12 +247,12 @@ def profile(
arg1_dtype,
use_3xtf32=True,
profile_all_alignments=False,
profile_all=True,
find_first_valid=False,
use_multiprocessing=False,
batched=False,
):
"""Profile and select the best kernel from candidate kernels.
If profile_all is False, return immediately after the first applicable kernel is found.
If find_first_valid is True, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
op = self.select_op(
Expand All @@ -264,7 +264,7 @@ def profile(
arg1_dtype,
use_3xtf32,
profile_all_alignments=profile_all_alignments,
profile_all=profile_all,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/python/contrib/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def profile_and_build(
sm,
use_3xtf32=use_3xtf32,
profile_all_alignments=False,
profile_all=False,
find_first_valid=True,
use_multiprocessing=False,
tmp_dir=tmp_dir,
)
Expand Down

0 comments on commit d15a995

Please sign in to comment.