Skip to content

Commit

Permalink
[CUDA] Improve injective schedule to enable half2 (apache#8457)
Browse files Browse the repository at this point in the history
* [CUDA] Improve injective schedule to enable half2

* lint

* fix

* trigger ci
  • Loading branch information
comaniac authored Jul 14, 2021
1 parent 73b38e8 commit 5c1a1cf
Showing 1 changed file with 33 additions and 3 deletions.
36 changes: 33 additions & 3 deletions python/tvm/topi/cuda/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
import numpy as np

import tvm
from tvm import te
from .. import utils
Expand All @@ -36,13 +38,21 @@ def schedule_injective_from_existing(sch, out):
sch: Schedule
The updated schedule.
"""

def find_nearest_small_factor(num, target):
    """Return the largest factor of ``num`` that does not exceed ``target``.

    Candidates are tried in descending order from ``target`` down to 1, so
    ``target`` itself is returned when it divides ``num`` evenly.

    Parameters
    ----------
    num : int
        The number whose factor is searched for.
    target : int
        The inclusive upper bound of the search.

    Returns
    -------
    factor : int
        The first candidate that divides ``num``, or -1 when there are no
        candidates at all (``target`` below 1). For ``target >= 1`` a result
        is always found because 1 divides every integer.
    """
    descending_candidates = range(target, 0, -1)
    return next((cand for cand in descending_candidates if num % cand == 0), -1)

fused = sch[out].fuse(*sch[out].op.axis)
num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
max_block = 256

# vectorize on fp16 data type. This allows to better utilize the memory
# bandwidth.
vector_width = 4 if out.dtype == "float16" else 1
# Vectorize on fp16 data type to enable half2 for better memory bandwidth utilization.
vector_width = 2 if out.dtype == "float16" else 1

is_dynamic_output = False
for dim in out.shape:
Expand All @@ -54,6 +64,26 @@ def schedule_injective_from_existing(sch, out):

try:
const_size = utils.get_const_int(out_len)

# Adjust the block and thread counts so that they evenly divide the vectorized length,
# which ensures vectorize can be correctly applied.
if vector_width > 1 and const_size % vector_width == 0:
remain_total_size = const_size // vector_width
cand_sizes = []
for max_size in [num_thread, max_block]:
cand_sizes.append(
max_size
if remain_total_size % max_size == 0
else find_nearest_small_factor(remain_total_size, max_size)
)
remain_total_size //= cand_sizes[-1]

# If the product of the candidate divisible sizes (block * thread) is too small,
# then performance may be worse even if half2 is enabled. Note that 0.7
# is just a heuristic ratio and may not be optimal for all workloads.
if np.prod(cand_sizes) / (max_block * num_thread) >= 0.7:
num_thread, max_block = cand_sizes

need_block_split = const_size > max_block * num_thread * vector_width
except ValueError:
need_block_split = False
Expand Down

0 comments on commit 5c1a1cf

Please sign in to comment.