Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SCHEDULE] Fix schedule for big array #1340

Merged
merged 1 commit into from
Jun 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions topi/python/topi/cuda/extern.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,7 @@
"""Schedule for cudnn and miopen extern op"""
import tvm
from .. import generic

def _schedule_output(op, sch):
    """Bind the first output of *op* to CUDA thread/block indices.

    Fuses the whole output iteration space into a single axis, splits it
    by the target's max thread count, and binds the pieces to
    ``blockIdx.x`` / ``threadIdx.x`` on *sch*.  Returns the same *sch*.
    """
    out = op.output(0)
    stage = sch[out]
    # Collapse all output axes into one flat iteration axis.
    flat = stage.fuse(*stage.op.axis)
    threads = tvm.target.current_target(allow_none=False).max_num_threads
    block_ax, thread_ax = stage.split(flat, factor=threads)
    stage.bind(block_ax, tvm.thread_axis("blockIdx.x"))
    stage.bind(thread_ax, tvm.thread_axis("threadIdx.x"))
    return sch
from .injective import _schedule_injective


@generic.schedule_extern.register(["cuda", "gpu"])
Expand All @@ -36,5 +28,5 @@ def schedule_extern(outs):
for out in outs:
if isinstance(out.op, tvm.tensor.ExternOp):
continue
_schedule_output(out.op, s)
_schedule_injective(out.op, s)
return s
25 changes: 21 additions & 4 deletions topi/python/topi/cuda/injective.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,32 @@
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
import tvm
from .. import generic
from .. import generic, util

def _schedule_injective(op, sch):
    """Schedule an injective (elementwise) op for the GPU.

    Fuses the output's axes into one, then binds them to
    ``blockIdx.x`` / ``threadIdx.x``.  When the (static) element count
    exceeds ``max_block * num_thread``, an extra outer split makes each
    thread loop over several elements so the launched grid stays
    bounded — this is the fix for very large arrays.

    Parameters
    ----------
    op : tvm.tensor.Operation
        The injective operation to schedule.
    sch : tvm.schedule.Schedule
        The schedule updated in place.

    Returns
    -------
    sch : tvm.schedule.Schedule
        The updated schedule.
    """
    x = op.output(0)
    fused = sch[x].fuse(*sch[x].op.axis)
    num_thread = tvm.target.current_target(allow_none=False).max_num_threads
    # Upper bound on the number of thread blocks launched at once.
    # (Assigned once here; the original duplicated this inside the try.)
    max_block = 256

    try:
        const_size = util.get_const_int(util.prod(x.shape))
        need_block_split = const_size > max_block * num_thread
    except ValueError:
        # Shape is not fully static, so we cannot bound the grid;
        # fall back to the simple one-level split.
        need_block_split = False

    if need_block_split:
        # Outer axis xo is iterated inside each thread, capping the grid
        # at max_block blocks of num_thread threads.
        xo, xi = sch[x].split(fused, factor=num_thread * max_block)
        bx, tx = sch[x].split(xi, factor=num_thread)
        sch[x].reorder(bx, tx, xo)
        sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
        sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
    else:
        bx, tx = sch[x].split(fused, factor=num_thread)
        sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
        sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))

    return sch


Expand Down
22 changes: 22 additions & 0 deletions topi/python/topi/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,28 @@
from __future__ import absolute_import as _abs
import tvm


def prod(x):
    """Get the product of every item in the tuple.

    Parameters
    ----------
    x : tuple
        Input tuple; items may be ints or symbolic tvm expressions.

    Returns
    -------
    value : Expr
        The product of all items.  For an empty tuple, returns
        ``tvm.const(1, "int32")``.
    """
    if not x:
        return tvm.const(1, "int32")
    res = x[0]
    # Iterate the remaining items directly instead of indexing by
    # range(1, len(x)); works for both plain ints and tvm exprs.
    for item in x[1:]:
        res = res * item
    return res


def get_const_int(expr):
"""Verifies expr is integer and get the constant value.

Expand Down
5 changes: 5 additions & 0 deletions topi/tests/python/test_topi_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,18 @@ def _prelu_numpy(x, W):
def test_relu():
    # Basic smoke test: a small 10x128 input.
    verify_relu(10, 128)

def test_schedule_big_array():
    """Regression test: the schedule must handle arrays too large for a
    single fixed-size grid (more elements than max_block * num_thread)."""
    # PEP 8: no whitespace before the comma (original had `1024 * 100 ,`).
    verify_relu(1024 * 100, 512)


def test_leaky_relu():
    # Leaky ReLU with negative-slope alpha = 0.1 on a length-100 input.
    verify_leaky_relu(100, 0.1)

def test_prelu():
    # PReLU on an NCHW-shaped input with a per-channel (3,) slope.
    verify_prelu((1, 3, 2, 2), (3,))

if __name__ == "__main__":
    # Run every test case, preserving the original execution order.
    for case in (test_schedule_big_array, test_relu,
                 test_leaky_relu, test_prelu):
        case()