[Bugfix] Fix sort changing original input data issue #3212

Status: merged (6 commits, May 22, 2019)
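What this PR does, as I read the diff: the low-level GPU sorting IR (`sort_ir` / `sort_nms_ir`) sorts its input buffer in place, so calling `argsort` on the GPU also reordered the caller's original tensor. The fix copies the input with `topi.math.identity` and hands the copy to `tvm.extern`, tags the copy-producing computes as elementwise so the schedules can find them, and loosens the extern-op shape check that the copy would otherwise trip. A minimal sketch of the core pattern (shapes invented; TVM 0.5-era API as used in this diff):

import tvm
from topi.math import identity

data = tvm.placeholder((1, 8), name="data")   # caller's tensor, must stay intact
sorted_data = identity(data)                  # elementwise copy of the input
# tvm.extern then consumes sorted_data instead of data, so the in-place
# sorting IR overwrites the copy and leaves `data` untouched.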
src/op/extern_op.cc (4 additions & 1 deletion)

@@ -72,7 +72,10 @@ Operation ExternOpNode::make(std::string name,
   CHECK_EQ(inputs.size(), input_placeholders.size());
   for (size_t i = 0; i < inputs.size(); ++i) {
     CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
-    CHECK(inputs[i]->shape.same_as(input_placeholders[i]->shape));
+    CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size());
+    for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) {
+      CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim]));
+    }
     CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
   }
   n->inputs = std::move(inputs);
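Why the whole-shape check had to go, as far as I can tell: `same_as` on TVM nodes is reference equality, not structural equality. Once `argsort_gpu` feeds `identity(data)` into `tvm.extern`, the input tensor's shape Array and the declared buffer's shape Array can be distinct objects even though every dimension Expr is shared, so the old whole-Array check could fail spuriously. Comparing rank and each dimension keeps the safety check while accepting such inputs. A small illustration (hypothetical tensors):

import tvm

n = tvm.var("n")
a = tvm.placeholder((n, 4), name="a")
b = tvm.compute(a.shape, lambda i, j: a[i, j], name="b")  # a copy, like identity(a)
buf = tvm.decl_buffer(a.shape, a.dtype, "buf")
# buf.shape and b.shape need not be the same Array object, but each dimension
# Expr (n and the constant 4) is shared, so a per-dimension same_as succeeds
# where a whole-Array same_as may fail.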
topi/python/topi/cuda/nms.py (2 additions & 1 deletion)

@@ -24,6 +24,7 @@
 from tvm.intrin import if_then_else, log, power
 from topi.vision import non_max_suppression, get_valid_counts
 from .sort import argsort
+from .. import tag


 def get_valid_counts_pre(data, flag, idx, score_threshold):
@@ -730,7 +731,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
                                       "valid_count_buf", data_alignment=4)
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
+    score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
     sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)

     sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
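The added `tag=tag.ELEMWISE` matters because the reworked schedules in this PR only schedule ops whose tag passes `tag.is_broadcast` (or an explicit allow-list); a `tvm.compute` with the default empty tag would be skipped by the traversal, leaving the score-gathering stage unscheduled. A sketch of the effect (shapes invented):

import tvm
from topi import tag

data = tvm.placeholder((1, 8, 6), name="data")
score = tvm.compute((1, 8), lambda i, j: data[i, j, 1], tag=tag.ELEMWISE)
# tag.is_broadcast accepts elementwise and broadcast tags, so the traversal in
# schedule_argsort / _default_schedule hands this op to _schedule_injective.
print(tag.is_broadcast(score.op.tag))  # True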
topi/python/topi/cuda/sort.py (40 additions & 7 deletions)

@@ -20,6 +20,10 @@

 from tvm import api
 from topi.sort import argsort
+from topi.math import identity
+from .. import generic
+from .. import tag
+

 def sort_ir(data, output, axis, is_ascend):
     """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
@@ -104,8 +108,6 @@ def sort_ir(data, output, axis, is_ascend):

     return ib.get()

-
-
 def sort_nms_ir(data, valid_count, output, axis, is_ascend):
     """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.

@@ -221,29 +223,60 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0
     out : tvm.Tensor
         The output of this function.
     """
-    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8)
+    sorted_data = identity(data)
     if flag:
         valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
                                           "valid_count_buf", data_alignment=4)
         out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)
         out = tvm.extern([data.shape],
-                         [data, valid_count],
+                         [sorted_data, valid_count],
                          lambda ins, outs: sort_nms_ir(
                              ins[0], ins[1], outs[0], axis, is_ascend),
                          dtype="int32",
-                         in_buffers=[data_buf, valid_count_buf],
+                         in_buffers=[sorted_data_buf, valid_count_buf],
                          out_buffers=[out_buf],
                          name="argsort_nms_gpu",
                          tag="argsort_nms_gpu")
     else:
         out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
         out = tvm.extern([data.shape],
-                         [data],
+                         [sorted_data],
                          lambda ins, outs: sort_ir(
                              ins[0], outs[0], axis, is_ascend),
                          dtype=dtype,
-                         in_buffers=[data_buf],
+                         in_buffers=[sorted_data_buf],
                          out_buffers=[out_buf],
                          name="argsort_gpu",
                          tag="argsort_gpu")
     return out
+
+@generic.schedule_argsort.register(["cuda", "gpu"])
+def schedule_argsort(outs):
+    """Schedule for argsort operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of argsort
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+    from .injective import _schedule_injective
+    def traverse(op):
+        if tag.is_broadcast(op.tag):
+            _schedule_injective(op, s)
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                traverse(tensor.op)
+        scheduled_ops.append(op)
+    traverse(outs[0].op)
+
+    return s
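A hypothetical end-to-end check of the new behavior (input values and shapes invented; assumes a CUDA-enabled build of this revision): after this change, reading back the input should show it unchanged, where previously it came back sorted.

import numpy as np
import tvm
from topi.cuda.sort import argsort_gpu, schedule_argsort

data = tvm.placeholder((1, 5), name="data")
indices = argsort_gpu(data, None, axis=-1, is_ascend=False)
with tvm.target.cuda():
    s = schedule_argsort(indices)
    f = tvm.build(s, [data, indices], "cuda")

ctx = tvm.gpu(0)
a = tvm.nd.array(np.array([[1., 3., 2., 5., 4.]], dtype="float32"), ctx)
out = tvm.nd.array(np.zeros((1, 5), dtype="float32"), ctx)
f(a, out)
print(a.asnumpy())    # still [[1. 3. 2. 5. 4.]]: the input is no longer mutated
print(out.asnumpy())  # [[3. 4. 1. 2. 0.]]: indices in descending order of score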
topi/python/topi/cuda/vision.py (7 additions & 70 deletions)

@@ -25,41 +25,17 @@

 def _default_schedule(outs):
     """Default schedule for gpu."""
-    target = tvm.target.current_target()
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []

+    from .injective import _schedule_injective
     def traverse(op):
-        """inline all one-to-one-mapping operators except the last stage (output)"""
-        if op.tag in ["nms", "invalid_to_bottom"]:
-            if op.tag == "nms":
-                sort = op.input_tensors[1]
-            else:
-                out = op.input_tensors[0]
-                sort = s[out].op.input_tensors[1]
-            score = s[sort].op.input_tensors[0]
-            fused = s[score].fuse(*s[score].op.axis)
-            num_thread = int(tvm.target.current_target(allow_none=False).max_num_threads)
-            bx, tx = s[score].split(fused, factor=num_thread)
-            s[score].bind(bx, tvm.thread_axis("blockIdx.x"))
-            s[score].bind(tx, tvm.thread_axis("threadIdx.x"))
-        if tag.is_broadcast(op.tag):
-            if op not in s.outputs:
-                s[op].compute_inline()
-            else:
-                x = op.output(0)
-                fused = s[x].fuse(*s[x].op.axis)
-                num_thread = tvm.target.current_target(allow_none=False).max_num_threads
-                bx, tx = s[x].split(fused, factor=num_thread)
-                s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
-                s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
-        for tensor in op.input_tensors:
-            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                traverse(tensor.op)
-
+        if tag.is_broadcast(op.tag) or op.tag in ['bbox_score', 'sorted_bbox']:
+            _schedule_injective(op, s)
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                traverse(tensor.op)
         scheduled_ops.append(op)

     traverse(outs[0].op)
     return s

@@ -173,19 +149,7 @@ def schedule_proposal(outs):
     s: Schedule
         The computation schedule for the op.
     """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
-    from .injective import _schedule_injective
-    def traverse(op):
-        if op.tag in ['bbox_score', 'sorted_bbox']:
-            _schedule_injective(op, s)
-        for tensor in op.input_tensors:
-            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                traverse(tensor.op)
-        scheduled_ops.append(op)
-    traverse(outs[0].op)
-    return s
+    return _default_schedule(outs)

 @generic.schedule_get_valid_counts.register(["cuda", "gpu"])
 def schedule_get_valid_counts(outs):
@@ -203,30 +167,3 @@ def schedule_get_valid_counts(outs):
         The computation schedule for the op.
     """
     return _default_schedule(outs)
-
-@generic.schedule_argsort.register(["cuda", "gpu"])
-def schedule_argsort(outs):
-    """Schedule for argsort operator.
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-        The computation graph description of argsort
-        in the format of an array of tensors.
-
-    Returns
-    -------
-    s: Schedule
-        The computation schedule for the op.
-    """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
-    from .injective import _schedule_injective
-    def traverse(op):
-        for tensor in op.input_tensors:
-            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                traverse(tensor.op)
-        scheduled_ops.append(op)
-    traverse(outs[0].op)
-    return s
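Note that `schedule_argsort` is not deleted outright: it moves to topi/python/topi/cuda/sort.py (above) and gains the `tag.is_broadcast` branch, which appears necessary so that the `identity` copy introduced by `argsort_gpu` actually receives an injective GPU schedule; the version removed here traversed the graph without scheduling anything.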