Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add thrust support for nms #5116

Merged
merged 11 commits into from
Mar 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ set(USE_NNPACK OFF)
# Possible values:
# - ON: enable tflite with cmake's find search
# - OFF: disable tflite
# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library
# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library
set(USE_TFLITE OFF)

# /path/to/tensorflow: tensorflow root path when use tflite library
Expand Down
90 changes: 62 additions & 28 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <dlpack/dlpack.h>
#include <algorithm>
#include <vector>
#include <functional>

namespace tvm {
namespace contrib {
Expand All @@ -39,7 +40,8 @@ template<typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input,
DLTensor* out_values,
DLTensor* out_indices,
bool is_ascend) {
bool is_ascend,
const std::function<int(int)> &get_sort_len) {
thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));
Expand All @@ -53,6 +55,7 @@ void thrust_sort(DLTensor* input,
thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);

for (int i = 0 ; i < n_iter; ++i) {
n_values = get_sort_len(i);
thrust::sequence(indices_ptr, indices_ptr + n_values);
if (is_ascend) {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
Expand All @@ -65,69 +68,100 @@ void thrust_sort(DLTensor* input,
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_GE(args.num_args, 4);
DLTensor* input = args[0];
DLTensor* values_out = args[1];
DLTensor* indices_out = args[2];
bool is_ascend = args[3];

auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);

void thrust_sort_common(DLTensor* input,
DLTensor* values_out,
DLTensor* indices_out,
bool is_ascend,
const std::function<int(int)> &get_sort_len,
std::string data_dtype,
std::string out_dtype) {
if (data_dtype == "float32") {
if (out_dtype == "int32") {
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend);
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "int64") {
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend);
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float32") {
thrust_sort<float, float>(input, values_out, indices_out, is_ascend);
thrust_sort<float, float>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float64") {
thrust_sort<float, double>(input, values_out, indices_out, is_ascend);
thrust_sort<float, double>(input, values_out, indices_out, is_ascend, get_sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend);
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "int64") {
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend);
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float32") {
thrust_sort<double, float>(input, values_out, indices_out, is_ascend);
thrust_sort<double, float>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float64") {
thrust_sort<double, double>(input, values_out, indices_out, is_ascend);
thrust_sort<double, double>(input, values_out, indices_out, is_ascend, get_sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend);
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "int64") {
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend);
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float32") {
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend);
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float64") {
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend);
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend);
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "int64") {
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend);
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float32") {
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend);
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float64") {
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend);
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_GE(args.num_args, 5);
DLTensor* input = args[0];
DLTensor* valid_count = args[1];
DLTensor* values_out = args[2];
DLTensor* indices_out = args[3];
bool is_ascend = args[4];

auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);

thrust::device_ptr<int> valid_count_ptr(static_cast<int *>(valid_count->data));
auto get_sort_len = [&valid_count_ptr](int i) { return valid_count_ptr[i]; };
thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
data_dtype, out_dtype);
});

masahi marked this conversation as resolved.
Show resolved Hide resolved

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_GE(args.num_args, 4);
DLTensor* input = args[0];
DLTensor* values_out = args[1];
DLTensor* indices_out = args[2];
bool is_ascend = args[3];

auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);

int n_values = input->shape[input->ndim - 1];
auto get_sort_len = [=](int i) { return n_values; };
thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
data_dtype, out_dtype);
});
} // namespace contrib
} // namespace tvm
10 changes: 7 additions & 3 deletions topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm import te

from tvm.tir import if_then_else
from .sort import argsort
from .sort import argsort, argsort_thrust
from .. import tag


Expand Down Expand Up @@ -668,8 +668,12 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
score_shape = (batch_size, num_anchors)
score_tensor = te.compute(
score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
sort_tensor = argsort(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True):
sort_tensor = argsort_thrust(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
else:
sort_tensor = argsort(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False)

sort_tensor_buf = tvm.tir.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
"sort_tensor_buf", data_alignment=8)
Expand Down
73 changes: 65 additions & 8 deletions topi/python/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
from ..transform import strided_slice, transpose
from .. import tag

def swap(arr, axis):
""" swap arr[axis] and arr[-1] """
return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]

def _schedule_sort(outs):
"""Schedule for argsort operator.

Expand Down Expand Up @@ -237,6 +241,64 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):

return ib.get()

def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.

Parameters
----------
data: tvm.te.Tensor
The input array.

valid_count : tvm.te.Tensor, optional
The number of valid elements to be sorted.

axis : int, optional
Axis long which to sort the input tensor.

is_ascend : boolean, optional
Whether to sort in ascending or descending order.

dtype : string, optional
DType of the output indices.

Returns
-------
out : tvm.te.Tensor
The output of this function.
"""
ndim = len(data.shape)
if axis < 0:
axis = ndim + axis
if axis != ndim - 1:
# Prepare for sorting along axis -1.
axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf",
data_alignment=8)
valid_count_buf = tvm.tir.decl_buffer(valid_count.shape, valid_count.dtype,
"valid_count_buf", data_alignment=4)
out_bufs = [
tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
tvm.tir.decl_buffer(data.shape, "int32", "indices_buf", data_alignment=8)
]
out = te.extern([data.shape, data.shape],
[data, valid_count],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.thrust.sort_nms", ins[0], ins[1], outs[0], outs[1], is_ascend),
Laurawly marked this conversation as resolved.
Show resolved Hide resolved
in_buffers=[data_buf, valid_count_buf],
out_buffers=out_bufs,
dtype=[data.dtype, "int32"],
name="nms_argsort_gpu",
tag="nms_argsort_gpu")

if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
out = [transpose(o, axes) for o in out]

return out[1]

def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.
Expand Down Expand Up @@ -318,8 +380,7 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"
The output of this function.
"""
if valid_count is not None:
# TODO: implement argsort_nms with Thrust
out = argsort(data, valid_count, axis, is_ascend, dtype)
out = argsort_nms_thrust(data, valid_count, axis, is_ascend, dtype)
else:
out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
return out
Expand Down Expand Up @@ -453,13 +514,9 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
ndim = len(data.shape)
axis = ndim + axis if axis < 0 else axis

def swap(arr):
""" swap arr[axis] and arr[-1] """
return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]

if axis != ndim - 1:
# Prepare for sorting along axis -1.
axes = swap(list(range(ndim)))
axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
Expand All @@ -483,7 +540,7 @@ def swap(arr):
out = [strided_slice(o, beg, end) for o in out]

if axis != ndim - 1:
axes = swap(list(range(ndim)))
axes = swap(list(range(ndim)), axis)
out = [transpose(o, axes) for o in out]

if ret_type == "values":
Expand Down