[SCHEDULE] Fix schedule for big array (apache#1340)
tqchen authored and sergei-mironov committed Aug 8, 2018
1 parent 38ecc93 commit 72be668
Showing 4 changed files with 50 additions and 14 deletions.
12 changes: 2 additions & 10 deletions topi/python/topi/cuda/extern.py
@@ -2,15 +2,7 @@
 """Schedule for cudnn and miopen extern op"""
 import tvm
 from .. import generic
-
-def _schedule_output(op, sch):
-    x = op.output(0)
-    fused = sch[x].fuse(*sch[x].op.axis)
-    num_thread = tvm.target.current_target(allow_none=False).max_num_threads
-    bx, tx = sch[x].split(fused, factor=num_thread)
-    sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
-    sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
-    return sch
+from .injective import _schedule_injective
 
 
 @generic.schedule_extern.register(["cuda", "gpu"])
@@ -36,5 +28,5 @@ def schedule_extern(outs):
     for out in outs:
         if isinstance(out.op, tvm.tensor.ExternOp):
             continue
-        _schedule_output(out.op, s)
+        _schedule_injective(out.op, s)
     return s
25 changes: 21 additions & 4 deletions topi/python/topi/cuda/injective.py
@@ -1,15 +1,32 @@
 # pylint: disable=invalid-name, unused-variable,
 """Schedule for composition of injective operator"""
 import tvm
-from .. import generic
+from .. import generic, util
 
 def _schedule_injective(op, sch):
     x = op.output(0)
     fused = sch[x].fuse(*sch[x].op.axis)
     num_thread = tvm.target.current_target(allow_none=False).max_num_threads
-    bx, tx = sch[x].split(fused, factor=num_thread)
-    sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
-    sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
+    max_block = 256
+
+    try:
+        const_size = util.get_const_int(util.prod(x.shape))
+        max_block = 256
+        need_block_split = const_size > max_block * num_thread
+    except ValueError:
+        need_block_split = False
+
+    if need_block_split:
+        xo, xi = sch[x].split(fused, factor=num_thread * max_block)
+        bx, tx = sch[x].split(xi, factor=num_thread)
+        sch[x].reorder(bx, tx, xo)
+        sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
+        sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
+    else:
+        bx, tx = sch[x].split(fused, factor=num_thread)
+        sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
+        sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
+
     return sch
 
 
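A rough sketch of the block-split decision introduced above (illustration only, not part of the commit; num_thread is an assumed value here, since on a real target it comes from tvm.target.current_target(allow_none=False).max_num_threads):

# Illustration of the new heuristic with assumed values.
num_thread = 1024                          # assumed CUDA threads per block
max_block = 256                            # cap introduced by this commit

const_size = 1024 * 100 * 512              # elements in the new relu test case
need_block_split = const_size > max_block * num_thread   # 52428800 > 262144 -> True

if need_block_split:
    # The fused axis is split twice: at most max_block blocks of num_thread
    # threads each, and after the reorder every thread loops over the
    # remaining outer iterations.
    outer_iters = -(-const_size // (max_block * num_thread))   # ceil division
    print(outer_iters)                     # 200 outer iterations per thread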
22 changes: 22 additions & 0 deletions topi/python/topi/util.py
@@ -2,6 +2,28 @@
 from __future__ import absolute_import as _abs
 import tvm
 
+
+def prod(x):
+    """Get the product of every items in the tuple.
+
+    Parameters
+    ----------
+    x: tuple
+        Input tuple
+
+    Returns
+    -------
+    value : Expr
+        The result value
+    """
+    if not x:
+        return tvm.const(1, "int32")
+    res = x[0]
+    for i in range(1, len(x)):
+        res = res * x[i]
+    return res
+
+
 def get_const_int(expr):
     """Verifies expr is integer and get the constant value.
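A minimal usage sketch of the new helper (assumed, not part of the commit): prod folds the shape tuple into a single expression by repeated multiplication, so for a static shape get_const_int can recover the total element count, while a symbolic shape leaves the product non-constant and get_const_int raises ValueError, which the injective schedule above catches to skip the block split.

# Pure-Python equivalent of the left-to-right fold util.prod performs
# (illustration only; equivalent to prod for a non-empty tuple of ints).
shape = (1024 * 100, 512)          # static shape used by the new test below
size = 1
for dim in shape:
    size = size * dim
assert size == 52428800            # the constant the CUDA schedule compares
                                   # against max_block * num_thread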
5 changes: 5 additions & 0 deletions topi/tests/python/test_topi_relu.py
@@ -71,13 +71,18 @@ def _prelu_numpy(x, W):
 def test_relu():
     verify_relu(10, 128)
 
+def test_schedule_big_array():
+    verify_relu(1024 * 100 , 512)
+
+
 def test_leaky_relu():
     verify_leaky_relu(100, 0.1)
 
 def test_prelu():
     verify_prelu((1, 3, 2, 2), (3,))
 
 if __name__ == "__main__":
+    test_schedule_big_array()
     test_relu()
     test_leaky_relu()
     test_prelu()
