Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add thrust support for nms #5116

Merged
merged 11 commits into from
Mar 23, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ set(USE_NNPACK OFF)
# Possible values:
# - ON: enable tflite with cmake's find search
# - OFF: disable tflite
# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library
# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library
set(USE_TFLITE OFF)

# /path/to/tensorflow: tensorflow root path when use tflite library
Expand Down
99 changes: 99 additions & 0 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,105 @@ void thrust_sort(DLTensor* input,
}
}

// Performs sorting along axis -1 and returns both sorted values and indices.
template<typename DataType, typename IndicesType>
void thrust_sort_nms(DLTensor* input,
DLTensor* valid_count,
DLTensor* out_values,
DLTensor* out_indices,
bool is_ascend) {
thrust::device_ptr<IndicesType> valid_count_ptr(static_cast<IndicesType *>(valid_count->data));
thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));

int n_values = input->shape[input->ndim - 1];
int n_iter = 1;
for (int i = 0; i < input->ndim - 1; ++i) {
n_iter *= input->shape[i];
}

thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);

for (int i = 0 ; i < n_iter; ++i) {
int current_values = valid_count_ptr[i];
masahi marked this conversation as resolved.
Show resolved Hide resolved
thrust::sequence(indices_ptr, indices_ptr + current_values);
if (is_ascend) {
thrust::sort_by_key(values_ptr, values_ptr + current_values, indices_ptr);
} else {
thrust::sort_by_key(values_ptr, values_ptr + current_values, indices_ptr,
thrust::greater<DataType>());
}
values_ptr += current_values;
indices_ptr += current_values;
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_GE(args.num_args, 5);
DLTensor* input = args[0];
DLTensor* valid_count = args[1];
DLTensor* values_out = args[2];
DLTensor* indices_out = args[3];
bool is_ascend = args[4];

auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);

if (data_dtype == "float32") {
if (out_dtype == "int32") {
thrust_sort_nms<float, int32_t>(input, valid_count, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort_nms<float, int64_t>(input, valid_count, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort_nms<float, float>(input, valid_count, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort_nms<float, double>(input, valid_count, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
thrust_sort_nms<double, int32_t>(input, valid_count, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort_nms<double, int64_t>(input, valid_count, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort_nms<double, float>(input, valid_count, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort_nms<double, double>(input, valid_count, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
thrust_sort_nms<int32_t, int32_t>(input, valid_count, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort_nms<int32_t, int64_t>(input, valid_count, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort_nms<int32_t, float>(input, valid_count, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort_nms<int32_t, double>(input, valid_count, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
thrust_sort_nms<int64_t, int32_t>(input, valid_count, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort_nms<int64_t, int64_t>(input, valid_count, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort_nms<int64_t, float>(input, valid_count, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort_nms<int64_t, double>(input, valid_count, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
});

masahi marked this conversation as resolved.
Show resolved Hide resolved
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_GE(args.num_args, 4);
Expand Down
10 changes: 7 additions & 3 deletions topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm import te

from tvm.tir import if_then_else
from .sort import argsort
from .sort import argsort, argsort_thrust
from .. import tag


Expand Down Expand Up @@ -668,8 +668,12 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
score_shape = (batch_size, num_anchors)
score_tensor = te.compute(
score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
sort_tensor = argsort(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True):
sort_tensor = argsort_thrust(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
else:
sort_tensor = argsort(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False)

sort_tensor_buf = tvm.tir.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
"sort_tensor_buf", data_alignment=8)
Expand Down
50 changes: 48 additions & 2 deletions topi/python/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,53 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):

return ib.get()

def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.

Parameters
----------
data: tvm.te.Tensor
The input array.

valid_count : tvm.te.Tensor, optional
The number of valid elements to be sorted.

axis : int, optional
Axis long which to sort the input tensor.

is_ascend : boolean, optional
Whether to sort in ascending or descending order.

dtype : string, optional
DType of the output indices.

Returns
-------
out : tvm.te.Tensor
The output of this function.
"""
if axis < 0:
axis = len(data.shape) + axis
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add check for the axis. For example,

assert axis==len(data.shape)-1, "Supports sorting along the last axis only"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In original sort_nms, we do support axis other than -1, I'll add your swap operations in this operator.

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf",
data_alignment=8)
valid_count_buf = tvm.tir.decl_buffer(valid_count.shape, valid_count.dtype,
"valid_count_buf", data_alignment=4)
out_bufs = [
tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
tvm.tir.decl_buffer(data.shape, "int32", "indices_buf", data_alignment=8)
]
out = te.extern([data.shape, data.shape],
[data, valid_count],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.thrust.sort_nms", ins[0], ins[1], outs[0], outs[1], is_ascend),
Laurawly marked this conversation as resolved.
Show resolved Hide resolved
in_buffers=[data_buf, valid_count_buf],
out_buffers=out_bufs,
dtype=[data.dtype, "int32"],
name="nms_argsort_gpu",
tag="nms_argsort_gpu")
return out[1]

def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.
Expand Down Expand Up @@ -318,8 +365,7 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"
The output of this function.
"""
if valid_count is not None:
# TODO: implement argsort_nms with Thrust
out = argsort(data, valid_count, axis, is_ascend, dtype)
out = argsort_nms_thrust(data, valid_count, axis, is_ascend, dtype)
else:
out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
return out
Expand Down