diff --git a/topi/python/topi/cuda/extern.py b/topi/python/topi/cuda/extern.py index 34dae092cf69..94046196c074 100644 --- a/topi/python/topi/cuda/extern.py +++ b/topi/python/topi/cuda/extern.py @@ -2,15 +2,7 @@ """Schedule for cudnn and miopen extern op""" import tvm from .. import generic - -def _schedule_output(op, sch): - x = op.output(0) - fused = sch[x].fuse(*sch[x].op.axis) - num_thread = tvm.target.current_target(allow_none=False).max_num_threads - bx, tx = sch[x].split(fused, factor=num_thread) - sch[x].bind(bx, tvm.thread_axis("blockIdx.x")) - sch[x].bind(tx, tvm.thread_axis("threadIdx.x")) - return sch +from .injective import _schedule_injective @generic.schedule_extern.register(["cuda", "gpu"]) @@ -36,5 +28,5 @@ def schedule_extern(outs): for out in outs: if isinstance(out.op, tvm.tensor.ExternOp): continue - _schedule_output(out.op, s) + _schedule_injective(out.op, s) return s diff --git a/topi/python/topi/cuda/injective.py b/topi/python/topi/cuda/injective.py index 0143aec36a7b..4ca89fb3ecd3 100644 --- a/topi/python/topi/cuda/injective.py +++ b/topi/python/topi/cuda/injective.py @@ -1,15 +1,32 @@ # pylint: disable=invalid-name, unused-variable, """Schedule for composition of injective operator""" import tvm -from .. import generic +from .. import generic, util def _schedule_injective(op, sch): x = op.output(0) fused = sch[x].fuse(*sch[x].op.axis) num_thread = tvm.target.current_target(allow_none=False).max_num_threads - bx, tx = sch[x].split(fused, factor=num_thread) - sch[x].bind(bx, tvm.thread_axis("blockIdx.x")) - sch[x].bind(tx, tvm.thread_axis("threadIdx.x")) + max_block = 256 + + try: + const_size = util.get_const_int(util.prod(x.shape)) + max_block = 256 + need_block_split = const_size > max_block * num_thread + except ValueError: + need_block_split = False + + if need_block_split: + xo, xi = sch[x].split(fused, factor=num_thread * max_block) + bx, tx = sch[x].split(xi, factor=num_thread) + sch[x].reorder(bx, tx, xo) + sch[x].bind(bx, tvm.thread_axis("blockIdx.x")) + sch[x].bind(tx, tvm.thread_axis("threadIdx.x")) + else: + bx, tx = sch[x].split(fused, factor=num_thread) + sch[x].bind(tx, tvm.thread_axis("threadIdx.x")) + sch[x].bind(bx, tvm.thread_axis("blockIdx.x")) + return sch diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index a093bfd8a35c..3625f6aaefaa 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -2,6 +2,28 @@ from __future__ import absolute_import as _abs import tvm + +def prod(x): + """Get the product of every items in the tuple. + + Parameters + ---------- + x: tuple + Input tuple + + Returns + ------- + value : Expr + The result value + """ + if not x: + return tvm.const(1, "int32") + res = x[0] + for i in range(1, len(x)): + res = res * x[i] + return res + + def get_const_int(expr): """Verifies expr is integer and get the constant value. diff --git a/topi/tests/python/test_topi_relu.py b/topi/tests/python/test_topi_relu.py index 6360155c5012..2f7898ff242a 100644 --- a/topi/tests/python/test_topi_relu.py +++ b/topi/tests/python/test_topi_relu.py @@ -71,6 +71,10 @@ def _prelu_numpy(x, W): def test_relu(): verify_relu(10, 128) +def test_schedule_big_array(): + verify_relu(1024 * 100 , 512) + + def test_leaky_relu(): verify_leaky_relu(100, 0.1) @@ -78,6 +82,7 @@ def test_prelu(): verify_prelu((1, 3, 2, 2), (3,)) if __name__ == "__main__": + test_schedule_big_array() test_relu() test_leaky_relu() test_prelu()