Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI][OP] Use Thrust sort for argsort and topk #5097

Merged
merged 10 commits into from
Mar 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ tvm_option(USE_MKL_PATH "MKL root path when use MKL blas" none)
tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
tvm_option(USE_THRUST "Build with Thrust" OFF)
tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
tvm_option(USE_SORT "Build with sort support" OFF)
Expand Down Expand Up @@ -101,9 +102,11 @@ else(MSVC)
message("Build in Debug mode")
set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
else()
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "-O2 -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
if (HIDE_PRIVATE_SYMBOLS)
message(STATUS "Hide private symbols...")
set(CMAKE_C_FLAGS "-fvisibility=hidden ${CMAKE_C_FLAGS}")
Expand Down Expand Up @@ -262,6 +265,7 @@ if(NOT MSVC)
check_cxx_compiler_flag("-std=c++14" SUPPORT_CXX14)
message(STATUS "Build with c++14")
set(CMAKE_CXX_FLAGS "-std=c++14 ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_STANDARD 14)
endif()

add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
Expand Down
3 changes: 3 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,6 @@ set(USE_VTA_FPGA OFF)

# Whether to build the example external runtime module
set(USE_EXAMPLE_EXT_RUNTIME OFF)

# Whether use Thrust
set(USE_THRUST OFF)
8 changes: 8 additions & 0 deletions cmake/modules/CUDA.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ if(USE_CUDA)
endif()
endif(USE_CUBLAS)

if(USE_THRUST)
message(STATUS "Build with Thrust support")
cmake_minimum_required(VERSION 3.13) # to compile CUDA code
enable_language(CUDA)
file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
endif(USE_THRUST)

else(USE_CUDA)
list(APPEND COMPILER_SRCS src/target/opt/build_cuda_off.cc)
endif(USE_CUDA)
11 changes: 11 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tvm.te import SpecializedCondition
from .generic import *
from .. import op as _op
from .... import get_global_func

@schedule_injective.register(["cuda", "gpu"])
def schedule_injective_cuda(attrs, outs, target):
Expand Down Expand Up @@ -328,6 +329,11 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_argsort(topi.cuda.argsort),
wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort.cuda")
if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
strategy.add_implementation(wrap_compute_argsort(topi.cuda.argsort_thrust),
wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort_thrust.cuda",
plevel=15)
return strategy

@topk_strategy.register(["cuda", "gpu"])
Expand All @@ -337,6 +343,11 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
strategy.add_implementation(wrap_compute_topk(topi.cuda.topk),
wrap_topi_schedule(topi.cuda.schedule_topk),
name="topk.cuda")
if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
strategy.add_implementation(wrap_compute_topk(topi.cuda.topk_thrust),
wrap_topi_schedule(topi.cuda.schedule_topk),
name="topk_thrust.cuda",
plevel=15)
return strategy

@multibox_prior_strategy.register(["cuda", "gpu"])
Expand Down
133 changes: 133 additions & 0 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file Use external Thrust library call
*/

#include <thrust/device_ptr.h>
#include <thrust/sort.h>

#include <tvm/runtime/registry.h>
#include <dlpack/dlpack.h>
#include <algorithm>
#include <vector>

namespace tvm {
namespace contrib {

using namespace runtime;

// Performs sorting along axis -1 and returns both sorted values and indices.
template<typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input,
DLTensor* out_values,
DLTensor* out_indices,
bool is_ascend) {
thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));

int n_values = input->shape[input->ndim - 1];
int n_iter = 1;
for (int i = 0; i < input->ndim - 1; ++i) {
n_iter *= input->shape[i];
}

thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);

for (int i = 0 ; i < n_iter; ++i) {
thrust::sequence(indices_ptr, indices_ptr + n_values);
if (is_ascend) {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
} else {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
thrust::greater<DataType>());
}
values_ptr += n_values;
indices_ptr += n_values;
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_GE(args.num_args, 4);
DLTensor* input = args[0];
DLTensor* values_out = args[1];
DLTensor* indices_out = args[2];
bool is_ascend = args[3];

auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);

if (data_dtype == "float32") {
if (out_dtype == "int32") {
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort<float, float>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort<float, double>(input, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort<double, float>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort<double, double>(input, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
});

} // namespace contrib
} // namespace tvm
1 change: 1 addition & 0 deletions tests/lint/check_file_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"pxi",
"pyd",
"pyx",
"cu",
# relay text format
"rly",
# configurations
Expand Down
112 changes: 111 additions & 1 deletion topi/python/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .injective import schedule_injective_from_existing
from ..math import identity
from ..transform import strided_slice
from ..transform import strided_slice, transpose
from .. import tag

def _schedule_sort(outs):
Expand Down Expand Up @@ -291,6 +291,40 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
tag="argsort_gpu")[1]
return out

def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.

Parameters
----------
data: tvm.te.Tensor
The input array.

valid_count : tvm.te.Tensor, optional
The number of valid elements to be sorted.

axis : int, optional
Axis long which to sort the input tensor.

is_ascend : boolean, optional
Whether to sort in ascending or descending order.

dtype : string, optional
DType of the output indices.

Returns
-------
out : tvm.te.Tensor
The output of this function.
"""
if valid_count is not None:
# TODO: implement argsort_nms with Thrust
out = argsort(data, valid_count, axis, is_ascend, dtype)
else:
out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
return out


def schedule_argsort(outs):
"""Schedule for argsort operator.

Expand Down Expand Up @@ -384,6 +418,82 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
return output


def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
"""Get the top k elements in an input tensor along the given axis.

Parameters
----------
data : tvm.te.Tensor
The input tensor.

k : int, optional
Number of top elements to select. Return all elements if k < 1.

axis : int, optional
Axis long which to sort the input tensor.

ret_type: str, optional
The return type [both, values, indices].
"both": return both top k data and indices.
"values": return top k data only.
"indices": return top k indices only.

is_ascend : boolean, optional
Whether to sort in ascending or descending order.

dtype : string, optional
The data type of the indices output.

Returns
-------
out : tvm.te.Tensor or List[tvm.te.Tensor]
The computed result.
"""
assert ret_type in ["both", "values", "indices"]
ndim = len(data.shape)
axis = ndim + axis if axis < 0 else axis

def swap(arr):
""" swap arr[axis] and arr[-1] """
return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]

if axis != ndim - 1:
# Prepare for sorting along axis -1.
axes = swap(list(range(ndim)))
data = transpose(data, axes)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
out_bufs = [
tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
]

out = te.extern([data.shape, data.shape],
[data],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend),
in_buffers=[data_buf],
out_buffers=out_bufs,
name="topk_gpu",
tag="topk_gpu")

if k > 0:
beg = [0] * ndim
end = data.shape[:-1] + [k]
out = [strided_slice(o, beg, end) for o in out]

if axis != ndim - 1:
axes = swap(list(range(ndim)))
out = [transpose(o, axes) for o in out]

if ret_type == "values":
out = out[0]
elif ret_type == "indices":
out = out[1]

return out


def schedule_topk(outs):
"""Schedule for argsort operator.

Expand Down