diff --git a/CMakeLists.txt b/CMakeLists.txt
index db63903d4395..74288029d020 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -55,6 +55,7 @@ tvm_option(USE_MKL_PATH "MKL root path when use MKL blas" none)
 tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
 tvm_option(USE_CUDNN "Build with cuDNN" OFF)
 tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
+tvm_option(USE_THRUST "Build with Thrust" OFF)
 tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
 tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
 tvm_option(USE_SORT "Build with sort support" OFF)
@@ -101,9 +102,11 @@ else(MSVC)
     message("Build in Debug mode")
     set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}")
     set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
+    set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
   else()
     set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}")
     set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC ${CMAKE_CXX_FLAGS}")
+    set(CMAKE_CUDA_FLAGS "-O2 -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
     if (HIDE_PRIVATE_SYMBOLS)
       message(STATUS "Hide private symbols...")
       set(CMAKE_C_FLAGS "-fvisibility=hidden ${CMAKE_C_FLAGS}")
@@ -262,6 +265,7 @@ if(NOT MSVC)
   check_cxx_compiler_flag("-std=c++14" SUPPORT_CXX14)
   message(STATUS "Build with c++14")
   set(CMAKE_CXX_FLAGS "-std=c++14 ${CMAKE_CXX_FLAGS}")
+  set(CMAKE_CUDA_STANDARD 14)
 endif()
 
 add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 7e40f0d17869..fd295aa35b34 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -201,3 +201,6 @@ set(USE_VTA_FPGA OFF)
 
 # Whether to build the example external runtime module
 set(USE_EXAMPLE_EXT_RUNTIME OFF)
+
+# Whether to use Thrust
+set(USE_THRUST OFF)
diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index 250a65c591e1..936bb681b7ff 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -55,6 +55,14 @@ if(USE_CUDA)
     endif()
   endif(USE_CUBLAS)
 
+  if(USE_THRUST)
+    message(STATUS "Build with Thrust support")
+    cmake_minimum_required(VERSION 3.13) # to compile CUDA code
+    enable_language(CUDA)
+    file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
+    list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
+  endif(USE_THRUST)
+
 else(USE_CUDA)
   list(APPEND COMPILER_SRCS src/target/opt/build_cuda_off.cc)
 endif(USE_CUDA)
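With `USE_THRUST` turned on (and CMake >= 3.13 so the `.cu` source can be compiled), the Thrust kernel below is linked into libtvm. As a quick post-build sanity check, one can look the packed function up from Python; this snippet is illustrative and not part of the patch:

```python
import tvm

# The Thrust kernel registers itself as the global packed function
# "tvm.contrib.thrust.sort"; it is only present in builds configured
# with USE_THRUST=ON.
thrust_sort = tvm.get_global_func("tvm.contrib.thrust.sort", allow_missing=True)
print("Thrust sort available:", thrust_sort is not None)
```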
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index e5eff1c6b790..8ccd6bf51508 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -20,6 +20,7 @@
 from tvm.te import SpecializedCondition
 from .generic import *
 from .. import op as _op
+from .... import get_global_func
 
 @schedule_injective.register(["cuda", "gpu"])
 def schedule_injective_cuda(attrs, outs, target):
@@ -328,6 +329,11 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
                                 wrap_compute_argsort(topi.cuda.argsort),
                                 wrap_topi_schedule(topi.cuda.schedule_argsort),
                                 name="argsort.cuda")
+    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
+        strategy.add_implementation(wrap_compute_argsort(topi.cuda.argsort_thrust),
+                                    wrap_topi_schedule(topi.cuda.schedule_argsort),
+                                    name="argsort_thrust.cuda",
+                                    plevel=15)
     return strategy
 
 @topk_strategy.register(["cuda", "gpu"])
@@ -337,6 +343,11 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
     strategy.add_implementation(wrap_compute_topk(topi.cuda.topk),
                                 wrap_topi_schedule(topi.cuda.schedule_topk),
                                 name="topk.cuda")
+    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
+        strategy.add_implementation(wrap_compute_topk(topi.cuda.topk_thrust),
+                                    wrap_topi_schedule(topi.cuda.schedule_topk),
+                                    name="topk_thrust.cuda",
+                                    plevel=15)
     return strategy
 
 @multibox_prior_strategy.register(["cuda", "gpu"])
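Both registrations pass `plevel=15`, which outranks the default priority of `add_implementation` (10), so the Thrust-backed kernels win whenever the packed function exists in the build, while the generic CUDA kernels remain the fallback. A hedged sketch of how this surfaces to a user, assuming a Thrust-enabled build and the `relay.build` API of this era (which returns a graph/lib/params triple):

```python
import tvm
from tvm import relay

# Compiling argsort for CUDA; on a Thrust-enabled build the strategy
# selects "argsort_thrust.cuda" (plevel 15), otherwise "argsort.cuda".
x = relay.var("x", shape=(1, 64), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.argsort(x)))
graph, lib, params = relay.build(mod, target="cuda")
```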
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
new file mode 100644
index 000000000000..fc9deacd371b
--- /dev/null
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file Use external Thrust library call
+ */
+
+#include <thrust/device_ptr.h>
+#include <thrust/sort.h>
+
+#include <thrust/copy.h>
+#include <thrust/functional.h>
+#include <thrust/sequence.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace contrib {
+
+using namespace runtime;
+
+// Performs sorting along axis -1 and returns both sorted values and indices.
+template<typename DataType, typename IndicesType>
+void thrust_sort(DLTensor* input,
+                 DLTensor* out_values,
+                 DLTensor* out_indices,
+                 bool is_ascend) {
+  thrust::device_ptr<DataType> data_ptr(static_cast<DataType*>(input->data));
+  thrust::device_ptr<DataType> values_ptr(static_cast<DataType*>(out_values->data));
+  thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType*>(out_indices->data));
+
+  int n_values = input->shape[input->ndim - 1];
+  int n_iter = 1;
+  for (int i = 0; i < input->ndim - 1; ++i) {
+    n_iter *= input->shape[i];
+  }
+
+  thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);
+
+  for (int i = 0; i < n_iter; ++i) {
+    thrust::sequence(indices_ptr, indices_ptr + n_values);
+    if (is_ascend) {
+      thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
+    } else {
+      thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
+                          thrust::greater<DataType>());
+    }
+    values_ptr += n_values;
+    indices_ptr += n_values;
+  }
+}
+
+TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  CHECK_GE(args.num_args, 4);
+  DLTensor* input = args[0];
+  DLTensor* values_out = args[1];
+  DLTensor* indices_out = args[2];
+  bool is_ascend = args[3];
+
+  auto data_dtype = DLDataType2String(input->dtype);
+  auto out_dtype = DLDataType2String(indices_out->dtype);
+
+  if (data_dtype == "float32") {
+    if (out_dtype == "int32") {
+      thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend);
+    } else if (out_dtype == "int64") {
+      thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend);
+    } else if (out_dtype == "float32") {
+      thrust_sort<float, float>(input, values_out, indices_out, is_ascend);
+    } else if (out_dtype == "float64") {
+      thrust_sort<float, double>(input, values_out, indices_out, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  } else if (data_dtype == "float64") {
+    if (out_dtype == "int32") {
+      thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend);
+    } else if (out_dtype == "int64") {
+      thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend);
+    } else if (out_dtype == "float32") {
+      thrust_sort<double, float>(input, values_out, indices_out, is_ascend);
+    } else if (out_dtype == "float64") {
+      thrust_sort<double, double>(input, values_out, indices_out, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  } else if (data_dtype == "int32") {
+    if (out_dtype == "int32") {
+      thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend);
+    } else if (out_dtype == "int64") {
+      thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend);
+    } else if (out_dtype == "float32") {
+      thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend);
+    } else if (out_dtype == "float64") {
+      thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  } else if (data_dtype == "int64") {
+    if (out_dtype == "int32") {
+      thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend);
+    } else if (out_dtype == "int64") {
+      thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend);
+    } else if (out_dtype == "float32") {
+      thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend);
+    } else if (out_dtype == "float64") {
+      thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  } else {
+    LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
+  }
+});
+
+}  // namespace contrib
+}  // namespace tvm
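The registered function can also be exercised directly, without going through TOPI. The sketch below follows the `(input, out_values, out_indices, is_ascend)` calling convention of the registration above and allocates the output tensors itself; the shapes and dtypes are illustrative:

```python
import numpy as np
import tvm

f = tvm.get_global_func("tvm.contrib.thrust.sort")

ctx = tvm.gpu(0)
data = np.random.rand(4, 8).astype("float32")
x = tvm.nd.array(data, ctx)
values = tvm.nd.array(np.zeros((4, 8), dtype="float32"), ctx)
indices = tvm.nd.array(np.zeros((4, 8), dtype="int32"), ctx)

f(x, values, indices, True)  # ascending sort along the last axis
np.testing.assert_allclose(values.asnumpy(), np.sort(data, axis=-1))
```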
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index 42cdff714534..a2bb260e27ad 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -42,6 +42,7 @@
     "pxi",
     "pyd",
     "pyx",
+    "cu",
     # relay text format
     "rly",
     # configurations
diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py
index f9e535e133fa..5499683f9d23 100644
--- a/topi/python/topi/cuda/sort.py
+++ b/topi/python/topi/cuda/sort.py
@@ -21,7 +21,7 @@
 from .injective import schedule_injective_from_existing
 from ..math import identity
-from ..transform import strided_slice
+from ..transform import strided_slice, transpose
 from .. import tag
 
 def _schedule_sort(outs):
@@ -291,6 +291,40 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
                      tag="argsort_gpu")[1]
     return out
 
+def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
+    """Performs sorting along the given axis and returns an array of indices
+    with the same shape as the input array, indexing the data in sorted order.
+
+    Parameters
+    ----------
+    data: tvm.te.Tensor
+        The input array.
+
+    valid_count : tvm.te.Tensor, optional
+        The number of valid elements to be sorted.
+
+    axis : int, optional
+        Axis along which to sort the input tensor.
+
+    is_ascend : boolean, optional
+        Whether to sort in ascending or descending order.
+
+    dtype : string, optional
+        DType of the output indices.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        The output of this function.
+    """
+    if valid_count is not None:
+        # TODO: implement argsort_nms with Thrust
+        out = argsort(data, valid_count, axis, is_ascend, dtype)
+    else:
+        out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
+    return out
+
+
 def schedule_argsort(outs):
     """Schedule for argsort operator.
@@ -384,6 +418,82 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
 
     return output
 
+def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
+    """Get the top k elements in an input tensor along the given axis.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input tensor.
+
+    k : int, optional
+        Number of top elements to select. Return all elements if k < 1.
+
+    axis : int, optional
+        Axis along which to sort the input tensor.
+
+    ret_type: str, optional
+        The return type [both, values, indices].
+        "both": return both top k data and indices.
+        "values": return top k data only.
+        "indices": return top k indices only.
+
+    is_ascend : boolean, optional
+        Whether to sort in ascending or descending order.
+
+    dtype : string, optional
+        The data type of the indices output.
+
+    Returns
+    -------
+    out : tvm.te.Tensor or List[tvm.te.Tensor]
+        The computed result.
+    """
+    assert ret_type in ["both", "values", "indices"]
+    ndim = len(data.shape)
+    axis = ndim + axis if axis < 0 else axis
+
+    def swap(arr):
+        """ swap arr[axis] and arr[-1] """
+        return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]
+
+    if axis != ndim - 1:
+        # Prepare for sorting along axis -1.
+        axes = swap(list(range(ndim)))
+        data = transpose(data, axes)
+
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    out_bufs = [
+        tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
+        tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
+    ]
+
+    out = te.extern([data.shape, data.shape],
+                    [data],
+                    lambda ins, outs: tvm.tir.call_packed(
+                        "tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend),
+                    in_buffers=[data_buf],
+                    out_buffers=out_bufs,
+                    name="topk_gpu",
+                    tag="topk_gpu")
+
+    if k > 0:
+        beg = [0] * ndim
+        end = data.shape[:-1] + [k]
+        out = [strided_slice(o, beg, end) for o in out]
+
+    if axis != ndim - 1:
+        axes = swap(list(range(ndim)))
+        out = [transpose(o, axes) for o in out]
+
+    if ret_type == "values":
+        out = out[0]
+    elif ret_type == "indices":
+        out = out[1]
+
+    return out
+
+
 def schedule_topk(outs):
     """Schedule for topk operator.
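Putting the pieces together, an end-to-end check through Relay might look like the following sketch. It assumes a CUDA device and a Thrust-enabled build, and compares the top-k values (rather than indices, which can legitimately differ under ties) against numpy:

```python
import numpy as np
import tvm
from tvm import relay

# relay.topk defaults to is_ascend=False, so "values" are the k largest
# entries per row; on a Thrust-enabled build this lowers to topk_thrust.
x = relay.var("x", shape=(2, 16), dtype="float32")
func = relay.Function([x], relay.topk(x, k=3, ret_type="values"))

data = np.random.rand(2, 16).astype("float32")
ex = relay.create_executor("graph", tvm.IRModule.from_expr(func),
                           ctx=tvm.gpu(0), target="cuda")
res = ex.evaluate()(data).asnumpy()

np.testing.assert_allclose(res, np.sort(data, axis=-1)[:, ::-1][:, :3], rtol=1e-5)
```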