Skip to content

Commit

Permalink
Add thrust support for nms (apache#5116)
Browse files Browse the repository at this point in the history
* add argsort_nms_thrust

* consider valid count in thrust nms sort

* make thrust optional

* typo

* typo

* fix pylint

* address some of the comments

* address more comments

* fix lint

* address more comments

* address more comments
  • Loading branch information
Laurawly authored and Trevor Morris committed Apr 16, 2020
1 parent 9b4eb10 commit 064029b
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 40 deletions.
2 changes: 1 addition & 1 deletion cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ set(USE_NNPACK OFF)
# Possible values:
# - ON: enable tflite with cmake's find search
# - OFF: disable tflite
# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library
# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library
set(USE_TFLITE OFF)

# /path/to/tensorflow: tensorflow root path when use tflite library
Expand Down
90 changes: 62 additions & 28 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <dlpack/dlpack.h>
#include <algorithm>
#include <vector>
#include <functional>

namespace tvm {
namespace contrib {
Expand All @@ -39,7 +40,8 @@ template<typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input,
DLTensor* out_values,
DLTensor* out_indices,
bool is_ascend) {
bool is_ascend,
const std::function<int(int)> &get_sort_len) {
thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));
Expand All @@ -53,6 +55,7 @@ void thrust_sort(DLTensor* input,
thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);

for (int i = 0 ; i < n_iter; ++i) {
n_values = get_sort_len(i);
thrust::sequence(indices_ptr, indices_ptr + n_values);
if (is_ascend) {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
Expand All @@ -65,69 +68,100 @@ void thrust_sort(DLTensor* input,
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_GE(args.num_args, 4);
DLTensor* input = args[0];
DLTensor* values_out = args[1];
DLTensor* indices_out = args[2];
bool is_ascend = args[3];

auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);

void thrust_sort_common(DLTensor* input,
DLTensor* values_out,
DLTensor* indices_out,
bool is_ascend,
const std::function<int(int)> &get_sort_len,
std::string data_dtype,
std::string out_dtype) {
if (data_dtype == "float32") {
if (out_dtype == "int32") {
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend);
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "int64") {
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend);
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float32") {
thrust_sort<float, float>(input, values_out, indices_out, is_ascend);
thrust_sort<float, float>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float64") {
thrust_sort<float, double>(input, values_out, indices_out, is_ascend);
thrust_sort<float, double>(input, values_out, indices_out, is_ascend, get_sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend);
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "int64") {
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend);
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float32") {
thrust_sort<double, float>(input, values_out, indices_out, is_ascend);
thrust_sort<double, float>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float64") {
thrust_sort<double, double>(input, values_out, indices_out, is_ascend);
thrust_sort<double, double>(input, values_out, indices_out, is_ascend, get_sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend);
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "int64") {
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend);
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float32") {
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend);
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float64") {
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend);
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend);
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "int64") {
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend);
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float32") {
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend);
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
} else if (out_dtype == "float64") {
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend);
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_GE(args.num_args, 5);
DLTensor* input = args[0];
DLTensor* valid_count = args[1];
DLTensor* values_out = args[2];
DLTensor* indices_out = args[3];
bool is_ascend = args[4];

auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);

thrust::device_ptr<int> valid_count_ptr(static_cast<int *>(valid_count->data));
auto get_sort_len = [&valid_count_ptr](int i) { return valid_count_ptr[i]; };
thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
data_dtype, out_dtype);
});


TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_GE(args.num_args, 4);
DLTensor* input = args[0];
DLTensor* values_out = args[1];
DLTensor* indices_out = args[2];
bool is_ascend = args[3];

auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);

int n_values = input->shape[input->ndim - 1];
auto get_sort_len = [=](int i) { return n_values; };
thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
data_dtype, out_dtype);
});
} // namespace contrib
} // namespace tvm
10 changes: 7 additions & 3 deletions topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm import te

from tvm.tir import if_then_else
from .sort import argsort
from .sort import argsort, argsort_thrust
from .. import tag


Expand Down Expand Up @@ -668,8 +668,12 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
score_shape = (batch_size, num_anchors)
score_tensor = te.compute(
score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
sort_tensor = argsort(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True):
sort_tensor = argsort_thrust(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
else:
sort_tensor = argsort(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False)

sort_tensor_buf = tvm.tir.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
"sort_tensor_buf", data_alignment=8)
Expand Down
73 changes: 65 additions & 8 deletions topi/python/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
from ..transform import strided_slice, transpose
from .. import tag

def swap(arr, axis):
""" swap arr[axis] and arr[-1] """
return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]

def _schedule_sort(outs):
"""Schedule for argsort operator.
Expand Down Expand Up @@ -237,6 +241,64 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):

return ib.get()

def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.
Parameters
----------
data: tvm.te.Tensor
The input array.
valid_count : tvm.te.Tensor, optional
The number of valid elements to be sorted.
axis : int, optional
Axis long which to sort the input tensor.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
DType of the output indices.
Returns
-------
out : tvm.te.Tensor
The output of this function.
"""
ndim = len(data.shape)
if axis < 0:
axis = ndim + axis
if axis != ndim - 1:
# Prepare for sorting along axis -1.
axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf",
data_alignment=8)
valid_count_buf = tvm.tir.decl_buffer(valid_count.shape, valid_count.dtype,
"valid_count_buf", data_alignment=4)
out_bufs = [
tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
tvm.tir.decl_buffer(data.shape, "int32", "indices_buf", data_alignment=8)
]
out = te.extern([data.shape, data.shape],
[data, valid_count],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.thrust.sort_nms", ins[0], ins[1], outs[0], outs[1], is_ascend),
in_buffers=[data_buf, valid_count_buf],
out_buffers=out_bufs,
dtype=[data.dtype, "int32"],
name="nms_argsort_gpu",
tag="nms_argsort_gpu")

if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
out = [transpose(o, axes) for o in out]

return out[1]

def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.
Expand Down Expand Up @@ -318,8 +380,7 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"
The output of this function.
"""
if valid_count is not None:
# TODO: implement argsort_nms with Thrust
out = argsort(data, valid_count, axis, is_ascend, dtype)
out = argsort_nms_thrust(data, valid_count, axis, is_ascend, dtype)
else:
out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
return out
Expand Down Expand Up @@ -453,13 +514,9 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
ndim = len(data.shape)
axis = ndim + axis if axis < 0 else axis

def swap(arr):
""" swap arr[axis] and arr[-1] """
return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]

if axis != ndim - 1:
# Prepare for sorting along axis -1.
axes = swap(list(range(ndim)))
axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
Expand All @@ -483,7 +540,7 @@ def swap(arr):
out = [strided_slice(o, beg, end) for o in out]

if axis != ndim - 1:
axes = swap(list(range(ndim)))
axes = swap(list(range(ndim)), axis)
out = [transpose(o, axes) for o in out]

if ret_type == "values":
Expand Down

0 comments on commit 064029b

Please sign in to comment.