From 5c1a1cf7289b439b0042a85b63b0007dc1d9b98a Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 13 Jul 2021 17:57:19 -0700 Subject: [PATCH] [CUDA] Improve injective schedule to enable half2 (#8457) * [CUDA] Improve injective schedule to enable half2 * lint * fix * trigger ci --- python/tvm/topi/cuda/injective.py | 36 ++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/injective.py b/python/tvm/topi/cuda/injective.py index cce56b796cea..0faddc31c25a 100644 --- a/python/tvm/topi/cuda/injective.py +++ b/python/tvm/topi/cuda/injective.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name, unused-variable, """Schedule for composition of injective operator""" +import numpy as np + import tvm from tvm import te from .. import utils @@ -36,13 +38,21 @@ def schedule_injective_from_existing(sch, out): sch: Schedule The updated schedule. """ + + def find_nearest_small_factor(num, target): + """Find the nearest factor of the given number that is smaller than the target.""" + for i in range(target, 0, -1): + if num % i == 0: + return i + # Unreachable because i=1 must hold. + return -1 + fused = sch[out].fuse(*sch[out].op.axis) num_thread = tvm.target.Target.current(allow_none=False).max_num_threads max_block = 256 - # vectorize on fp16 data type. This allows to better utilize the memory - # bandwidth. - vector_width = 4 if out.dtype == "float16" else 1 + # Vectorize on fp16 data type to enable half2 for better memory bandwidth utilization. + vector_width = 2 if out.dtype == "float16" else 1 is_dynamic_output = False for dim in out.shape: @@ -54,6 +64,26 @@ def schedule_injective_from_existing(sch, out): try: const_size = utils.get_const_int(out_len) + + # Adjust block and thread to make sure they are dividable so that vectorize can be + # correctly applied. + if vector_width > 1 and const_size % vector_width == 0: + remain_total_size = const_size // vector_width + cand_sizes = [] + for max_size in [num_thread, max_block]: + cand_sizes.append( + max_size + if remain_total_size % max_size == 0 + else find_nearest_small_factor(remain_total_size, max_size) + ) + remain_total_size //= cand_sizes[-1] + + # If the product of candidate dividable (block * thread) is too small, + # then the performance may be worse even half2 is enabled. Note that 0.7 + # is just a heuristic ratio and may not be optimal for all workloads. + if np.prod(cand_sizes) / (max_block * num_thread) >= 0.7: + num_thread, max_block = cand_sizes + need_block_split = const_size > max_block * num_thread * vector_width except ValueError: need_block_split = False