From f00bc33a3b53215bbaabe5068863fc19bb0e8250 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 3 Mar 2022 17:48:15 +0800 Subject: [PATCH 1/2] add tensorrt dcn support --- .../tensorrt/common/common_cuda_helper.hpp | 1 + .../tensorrt/common_impl/trt_cuda_helper.cu | 10 +- .../tensorrt/deform_conv/trt_deform_conv.cpp | 257 ++++++++++++++++++ .../tensorrt/deform_conv/trt_deform_conv.hpp | 81 ++++++ .../deform_conv/trt_deform_conv_kernel.cu | 165 +++++++++++ .../deform_conv/trt_deform_conv_kernel.cuh | 145 ++++++++++ .../deform_conv/trt_deform_conv_kernel.hpp | 20 ++ mmdeploy/mmcv/ops/deform_conv.py | 31 +++ tests/test_ops/test_ops.py | 46 ++++ 9 files changed, 751 insertions(+), 5 deletions(-) create mode 100644 csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp create mode 100644 csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp create mode 100644 csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu create mode 100644 csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh create mode 100644 csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp diff --git a/csrc/backend_ops/tensorrt/common/common_cuda_helper.hpp b/csrc/backend_ops/tensorrt/common/common_cuda_helper.hpp index 02c57c62e6..68f884dd2b 100644 --- a/csrc/backend_ops/tensorrt/common/common_cuda_helper.hpp +++ b/csrc/backend_ops/tensorrt/common/common_cuda_helper.hpp @@ -4,6 +4,7 @@ #include #include +#include #include diff --git a/csrc/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu b/csrc/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu index 065218958f..092e712def 100644 --- a/csrc/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu +++ b/csrc/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu @@ -5,13 +5,13 @@ using mmdeploy::TensorDesc; template -__global__ void copy_permute_kernel(scalar_t *dst, const scalar_t *src, int n, - TensorDesc ts_src_stride, TensorDesc ts_dst_stride, +__global__ void copy_permute_kernel(scalar_t *__restrict__ dst, const scalar_t *__restrict__ src, + int n, TensorDesc ts_src_stride, TensorDesc ts_dst_stride, TensorDesc ts_permute) { const int src_dim = ts_src_stride.dim; - int *src_stride = &(ts_src_stride.stride[0]); - int *dst_stride = &(ts_dst_stride.stride[0]); - int *permute = &(ts_permute.shape[0]); + const auto src_stride = ts_src_stride.stride; + const auto dst_stride = ts_dst_stride.stride; + const auto permute = ts_permute.shape; CUDA_1D_KERNEL_LOOP(index, n) { size_t dst_index = index; size_t src_index = 0; diff --git a/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp new file mode 100644 index 0000000000..95dd27ba83 --- /dev/null +++ b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp @@ -0,0 +1,257 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "trt_deform_conv.hpp" + +#include + +#include + +#include "trt_deform_conv_kernel.hpp" +#include "trt_serialize.hpp" + +using namespace nvinfer1; + +namespace mmdeploy { +namespace { +static const char *PLUGIN_VERSION{"1"}; +static const char *PLUGIN_NAME{"MMCVDeformConv2d"}; +} // namespace + +DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string &name, + const nvinfer1::Dims stride, + const nvinfer1::Dims padding, + const nvinfer1::Dims dilation, + const int deformableGroup, const int group) + : TRTPluginBase(name), + mStride(stride), + mPadding(padding), + mDilation(dilation), + mDeformableGroup(deformableGroup), + mGroup(group) {} + +DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name, const void *data, + size_t length) + : TRTPluginBase(name) { + deserialize_value(&data, &length, &mStride); + deserialize_value(&data, &length, &mPadding); + deserialize_value(&data, &length, &mDilation); + deserialize_value(&data, &length, &mDeformableGroup); + deserialize_value(&data, &length, &mGroup); +} +DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {} + +nvinfer1::IPluginV2DynamicExt *DeformableConvPluginDynamic::clone() const TRT_NOEXCEPT { + DeformableConvPluginDynamic *plugin = new DeformableConvPluginDynamic( + mLayerName, mStride, mPadding, mDilation, mDeformableGroup, mGroup); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; +} + +nvinfer1::DimsExprs DeformableConvPluginDynamic::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, + nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { + // input[0] == input + // input[1] == offset + // input[2] == weight + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[2].d[0]; + + ret.d[2] = inputs[1].d[2]; + ret.d[3] = inputs[1].d[3]; + + return ret; +} + +bool DeformableConvPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { + if (pos == 0) { + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } else { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } +} + +void DeformableConvPluginDynamic::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *outputs, + int nbOutputs) TRT_NOEXCEPT {} + +size_t DeformableConvPluginDynamic::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, + int nbOutputs) const TRT_NOEXCEPT { + int sizeof_dtype = mmdeploy::getElementSize(outputs[0].type); + + int batch_size = inputs[0].dims.d[0]; + int nInputPlane = inputs[0].dims.d[1]; + int inputHeight = inputs[0].dims.d[2]; + int inputWidth = inputs[0].dims.d[3]; + + int nOutputPlane = outputs[0].dims.d[1]; + int outputHeight = outputs[0].dims.d[2]; + int outputWidth = outputs[0].dims.d[3]; + + int kW = inputs[2].dims.d[2]; + int kH = inputs[2].dims.d[3]; + int im2col_step = std::min(32, batch_size); + + size_t col_size = mmdeploy::getAlignedSize(nInputPlane * kW * kH * im2col_step * outputHeight * + outputWidth * sizeof_dtype); + + size_t out_size = 0; + if (im2col_step != 1) + out_size = mmdeploy::getAlignedSize(batch_size * nOutputPlane * outputHeight * outputWidth * + sizeof_dtype); + + return col_size + out_size; +} + +int DeformableConvPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, + const void *const *inputs, void *const *outputs, + void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { + int batch = inputDesc[0].dims.d[0]; + int channels = inputDesc[0].dims.d[1]; + int height = inputDesc[0].dims.d[2]; + int width = inputDesc[0].dims.d[3]; + int channels_out = outputDesc[0].dims.d[1]; + int kernel_h = inputDesc[2].dims.d[2]; + int kernel_w = inputDesc[2].dims.d[3]; + + const void *x = inputs[0]; + const void *offset = inputs[1]; + const void *weight = inputs[2]; + void *output = outputs[0]; + int im2col_step = std::min(batch, 32); + + auto data_type = inputDesc[0].type; + switch (data_type) { + case nvinfer1::DataType::kFLOAT: + deform_conv((float *)x, (float *)weight, (float *)offset, (float *)output, workSpace, + batch, channels, height, width, channels_out, kernel_w, kernel_h, + mStride.d[0], mStride.d[1], mPadding.d[0], mPadding.d[1], mDilation.d[0], + mDilation.d[1], mGroup, mDeformableGroup, im2col_step, m_cublas_handle, + stream); + break; + default: + return 1; + break; + } + + return 0; +} + +nvinfer1::DataType DeformableConvPluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT { + return inputTypes[0]; +} + +// IPluginV2 Methods +const char *DeformableConvPluginDynamic::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } + +const char *DeformableConvPluginDynamic::getPluginVersion() const TRT_NOEXCEPT { + return PLUGIN_VERSION; +} + +int DeformableConvPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; } + +size_t DeformableConvPluginDynamic::getSerializationSize() const TRT_NOEXCEPT { + return serialized_size(mStride) + serialized_size(mPadding) + serialized_size(mDilation) + + serialized_size(mDeformableGroup) + serialized_size(mGroup); +} + +void DeformableConvPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT { + serialize_value(&buffer, mStride); + serialize_value(&buffer, mPadding); + serialize_value(&buffer, mDilation); + serialize_value(&buffer, mDeformableGroup); + serialize_value(&buffer, mGroup); +} + +void DeformableConvPluginDynamic::attachToContext( + cudnnContext *cudnnContext, cublasContext *cublasContext, + nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT { + m_cublas_handle = cublasContext; +} + +void DeformableConvPluginDynamic::detachFromContext() TRT_NOEXCEPT {} + +////////////////////// creator ///////////////////////////// + +DeformableConvPluginDynamicCreator::DeformableConvPluginDynamicCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("stride")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("padding")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("groups")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +const char *DeformableConvPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT { + return PLUGIN_NAME; +} + +const char *DeformableConvPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT { + return PLUGIN_VERSION; +} + +nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::createPlugin( + const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { + nvinfer1::Dims stride{2, {1, 1}}; + nvinfer1::Dims padding{2, {0, 0}}; + nvinfer1::Dims dilation{2, {1, 1}}; + int deformableGroup = 1; + int group = 1; + + for (int i = 0; i < fc->nbFields; i++) { + if (fc->fields[i].data == nullptr) { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("deform_groups") == 0) { + deformableGroup = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("groups") == 0) { + group = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("stride") == 0) { + stride.nbDims = 2; + stride.d[0] = static_cast(fc->fields[i].data)[0]; + stride.d[1] = static_cast(fc->fields[i].data)[1]; + } + + if (field_name.compare("padding") == 0) { + padding.nbDims = 2; + padding.d[0] = static_cast(fc->fields[i].data)[0]; + padding.d[1] = static_cast(fc->fields[i].data)[1]; + } + + if (field_name.compare("dilation") == 0) { + dilation.nbDims = 2; + dilation.d[0] = static_cast(fc->fields[i].data)[0]; + dilation.d[1] = static_cast(fc->fields[i].data)[1]; + } + } + + DeformableConvPluginDynamic *plugin = + new DeformableConvPluginDynamic(name, stride, padding, dilation, deformableGroup, group); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} + +nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::deserializePlugin( + const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { + auto plugin = new DeformableConvPluginDynamic(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} +REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator); +} // namespace mmdeploy diff --git a/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp new file mode 100644 index 0000000000..3ea0ccbefe --- /dev/null +++ b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp @@ -0,0 +1,81 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef TRT_DEFORM_CONV_HPP +#define TRT_DEFORM_CONV_HPP +#include + +#include +#include +#include + +#include "trt_plugin_base.hpp" + +namespace mmdeploy { +class DeformableConvPluginDynamic : public TRTPluginBase { + public: + DeformableConvPluginDynamic(const std::string &name, const nvinfer1::Dims stride, + const nvinfer1::Dims padding, const nvinfer1::Dims dilation, + const int deformableGroup, const int group); + + DeformableConvPluginDynamic(const std::string name, const void *data, size_t length); + + DeformableConvPluginDynamic() = delete; + + ~DeformableConvPluginDynamic() TRT_NOEXCEPT override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, + int nbInputs, nvinfer1::IExprBuilder &exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, + int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, + int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, + void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; + void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, + nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override; + void detachFromContext() TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char *getPluginType() const TRT_NOEXCEPT override; + const char *getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void *buffer) const TRT_NOEXCEPT override; + + private: + nvinfer1::Dims mStride; + nvinfer1::Dims mPadding; + nvinfer1::Dims mDilation; + int mDeformableGroup; + int mGroup; + + cublasHandle_t m_cublas_handle; +}; + +class DeformableConvPluginDynamicCreator : public TRTPluginCreatorBase { + public: + DeformableConvPluginDynamicCreator(); + + const char *getPluginName() const TRT_NOEXCEPT override; + + const char *getPluginVersion() const TRT_NOEXCEPT override; + + nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) + TRT_NOEXCEPT override; + + nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, + size_t serialLength) TRT_NOEXCEPT override; +}; +} // namespace mmdeploy +#endif // TRT_DEFORM_CONV_HPP diff --git a/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu new file mode 100644 index 0000000000..0ae82d82a1 --- /dev/null +++ b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu @@ -0,0 +1,165 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer + ***************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE + *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer + ********************* + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#include "common_cuda_helper.hpp" +#include "trt_deform_conv_kernel.cuh" +#include "trt_deform_conv_kernel.hpp" +#include "trt_plugin_helper.hpp" + +template +void deform_conv_im2col(const scalar_t* input, const scalar_t* offset, scalar_t* column, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, cudaStream_t stream) { + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + deformable_im2col_gpu_kernel<<>>( + num_kernels, input, offset, height, width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs, channels, + deformable_group, height_col, width_col, column); + + cudaCheckError(); +} + +template +void deform_conv(const scalar_t* input, const scalar_t* weight, const scalar_t* offset, + scalar_t* output, void* workspace, int batchSize, int nInputPlane, int inputHeight, + int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW, + int padH, int dilationW, int dilationH, int group, int deformable_group, + int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream) { + size_t word_size = sizeof(scalar_t); + + im2col_step = std::min(int(batchSize), im2col_step); + long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + long outputHW = outputHeight * outputWidth; + long kHW = kH * kW; + long columns_size = + mmdeploy::getAlignedSize(nInputPlane * kHW * im2col_step * outputHW * word_size); + + // column buffer for img2col + char* workspace_ptr = reinterpret_cast(workspace); + scalar_t* columns = reinterpret_cast(workspace_ptr); + workspace_ptr = workspace_ptr + columns_size; + + scalar_t* output_buffer; + if (im2col_step == 1) { + output_buffer = output; + } else { + // output need permute when im2col_step!=1 + output_buffer = reinterpret_cast(workspace_ptr); + } + + long input_elt_step = im2col_step * nInputPlane * inputHeight * inputWidth; + long offset_elt_step = im2col_step * deformable_group * 2 * kHW * outputHW; + long out_buffer_step = nOutputPlane * im2col_step * outputHW; + long col_g_step = nInputPlane * kHW * im2col_step * outputHW / group; + long weight_g_step = nOutputPlane * nInputPlane * kHW / (group * group); + long out_buffer_g_step = out_buffer_step / group; + int m = nOutputPlane / group; + int n = im2col_step * outputHW; + int k = nInputPlane * kHW / group; + scalar_t alpha = 1.f; + scalar_t beta = 0.f; + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + const scalar_t* input_start = input + elt * input_elt_step; + const scalar_t* offset_start = offset + elt * offset_elt_step; + + deform_conv_im2col(input_start, offset_start, columns, nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, + im2col_step, deformable_group, stream); + + for (int g = 0; g < group; ++g) { + const scalar_t* weight_start = weight + g * weight_g_step; + scalar_t* col_start = columns + g * col_g_step; + scalar_t* out_buffer_start = output_buffer + elt * out_buffer_step + g * out_buffer_g_step; + + cublasGemmWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, col_start, + n, weight_start, k, &beta, out_buffer_start, n); + cudaCheckError(); + } + } + + if (im2col_step != 1) { + int output_buffer_shape[5] = {batchSize / im2col_step, nOutputPlane, im2col_step, + static_cast(outputHeight), static_cast(outputWidth)}; + int output_buffer_permute[5] = {0, 2, 1, 3, 4}; + memcpyPermute(output, output_buffer, &output_buffer_shape[0], + &output_buffer_permute[0], 5, stream); + } +} + +template void deform_conv(const float* input, const float* weight, const float* offset, + float* output, void* workspace, int batchSize, int nInputPlane, + int inputHeight, int inputWidth, int nOutputPlane, int kW, int kH, + int dW, int dH, int padW, int padH, int dilationW, int dilationH, + int group, int deformable_group, int im2col_step, + cublasHandle_t cublas_handle, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh new file mode 100644 index 0000000000..6514efa82c --- /dev/null +++ b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh @@ -0,0 +1,145 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer + ***************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE + *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer + ********************* + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#include "common_cuda_helper.hpp" + +template +__device__ __forceinline__ scalar_t deformable_im2col_bilinear(const scalar_t* __restrict__ input, + const int height, const int width, + scalar_t h, scalar_t w) { + if (h <= -1.f || height <= h || w <= -1.f || width <= w) { + return 0; + } + + const int h_low = floorf(h); + const int w_low = floorf(w); + + input += h_low * width; + const scalar_t v1 = (h_low >= 0 && w_low >= 0) ? input[w_low] : static_cast(0.0f); + const int w_high = w_low + 1; + const scalar_t v2 = + (h_low >= 0 && w_high <= width - 1) ? input[w_high] : static_cast(0.0f); + const scalar_t lw = w - w_low; + const scalar_t v_low = fmaf(v2 - v1, lw, v1); + input += width; + const scalar_t v3 = + (h_low <= height - 2 && w_low >= 0) ? input[w_low] : static_cast(0.0f); + const scalar_t v4 = + (h_low <= height - 2 && w_high <= width - 1) ? input[w_high] : static_cast(0.0f); + const scalar_t v_high = fmaf(v4 - v3, lw, v3); + const scalar_t lh = h - h_low; + const scalar_t val = fmaf(v_high - v_low, lh, v_low); + return val; +} + +template +__global__ void deformable_im2col_gpu_kernel( + const int n, const scalar_t* __restrict__ data_im, const scalar_t* __restrict__ data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int channel_per_deformable_group, const int batch_size, + const int num_channels, const int deformable_group, const int height_col, const int width_col, + scalar_t* __restrict__ data_col) { + const int hw_col = height_col * width_col; + const int data_col_step = batch_size * hw_col; + + CUDA_1D_KERNEL_LOOP(index, n) { + // index index of output matrix + int tmp_index = index; + const int w_col = tmp_index % width_col; + tmp_index /= width_col; + const int h_col = tmp_index % height_col; + tmp_index /= height_col; + const int b_col = tmp_index % batch_size; + const int c_im = tmp_index / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t* __restrict__ data_col_ptr = data_col + c_col * data_col_step + index % data_col_step; + const scalar_t* __restrict__ data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t* __restrict__ data_offset_ptr = + data_offset + + ((b_col * deformable_group + deformable_group_index) << 1) * kernel_h * kernel_w * hw_col + + h_col * width_col + w_col; + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h = (i * kernel_w + j) * hw_col << 1; + const scalar_t offset_h = data_offset_ptr[data_offset_h]; + const int data_offset_w = data_offset_h + hw_col; + const scalar_t offset_w = data_offset_ptr[data_offset_w]; + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + const scalar_t val = deformable_im2col_bilinear(data_im_ptr, height, width, h_im, w_im); + *data_col_ptr = val; + data_col_ptr += data_col_step; + } + } + } +} diff --git a/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp new file mode 100644 index 0000000000..ceb0fe4147 --- /dev/null +++ b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp @@ -0,0 +1,20 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef TRT_DEFORM_CONV_KERNEL_HPP +#define TRT_DEFORM_CONV_KERNEL_HPP +#include +#include + +template +void deform_conv_im2col(const scalar_t* input, const scalar_t* offset, scalar_t* column, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, cudaStream_t stream); + +template +void deform_conv(const scalar_t* input, const scalar_t* weight, const scalar_t* offset, + scalar_t* output, void* workspace, int batchSize, int nInputPlane, int inputHeight, + int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW, + int padH, int dilationW, int dilationH, int group, int deformable_group, + int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream); +#endif // TRT_DEFORM_CONV_KERNEL_HPP \ No newline at end of file diff --git a/mmdeploy/mmcv/ops/deform_conv.py b/mmdeploy/mmcv/ops/deform_conv.py index c7bdbc43cb..388e5f6cbd 100644 --- a/mmdeploy/mmcv/ops/deform_conv.py +++ b/mmdeploy/mmcv/ops/deform_conv.py @@ -2,6 +2,37 @@ from mmdeploy.core import SYMBOLIC_REWRITER +@SYMBOLIC_REWRITER.register_symbolic( + 'mmcv.ops.deform_conv.DeformConv2dFunction') +def deform_conv__default(ctx, + g, + input, + offset, + weight, + stride, + padding, + dilation, + groups, + deform_groups, + bias=False, + im2col_step=32): + """Rewrite symbolic function for default backend.""" + assert not bias, 'The "bias" parameter should be False.' + assert groups == 1, 'The "groups" parameter should be 1.' + domain = 'mmdeploy' + op_name = 'MMCVDeformConv2d' + return g.op( + f'{domain}::{op_name}', + input, + offset, + weight, + stride_i=stride, + padding_i=[p for pair in zip(padding, padding) for p in pair], + dilation_i=dilation, + groups_i=groups, + deformable_groups_i=deform_groups) + + @SYMBOLIC_REWRITER.register_symbolic( 'mmcv.ops.deform_conv.DeformConv2dFunction', backend='openvino') def deform_conv_openvino(ctx, diff --git a/tests/test_ops/test_ops.py b/tests/test_ops/test_ops.py index 60143d878f..7a5784b0ac 100644 --- a/tests/test_ops/test_ops.py +++ b/tests/test_ops/test_ops.py @@ -214,6 +214,52 @@ def test_modulated_deform_conv(backend, save_dir=save_dir) +@pytest.mark.parametrize('backend', [TEST_TENSORRT]) +@pytest.mark.parametrize('in_channels,out_channels,stride,padding,' + 'dilation,groups,deform_groups,kernel_size', + [(3, 64, 1, 0, 1, 1, 1, 3), + (1, 32, 3, 2, 1, 1, 1, 3)]) +def test_deform_conv(backend, + in_channels, + out_channels, + stride, + padding, + dilation, + groups, + deform_groups, + kernel_size, + input_list=None, + save_dir=None): + backend.check_env() + + if input_list is None: + input = torch.rand( + 1, in_channels, 28, 28, requires_grad=False) # (n, c, h, w) + else: + input = torch.tensor(input_list[0]) + conv_offset = nn.Conv2d( + in_channels=in_channels, + out_channels=deform_groups * 2 * kernel_size * kernel_size, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True) + offset = conv_offset(input) + + from mmcv.ops import DeformConv2d + model = DeformConv2d(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, deform_groups).eval() + + with RewriterContext(cfg={}, backend=backend.backend_name, opset=11): + backend.run_and_validate( + model, [input, offset], + 'deform_conv', + input_names=['input', 'offset'], + output_names=['output'], + save_dir=save_dir) + + @pytest.mark.parametrize('backend', [TEST_TENSORRT]) @pytest.mark.parametrize('dynamic_export', [True, False]) @pytest.mark.parametrize('fp16_mode', [True, False]) From d3f9a6c4856360445cf2ceb74b21bfe674d999e6 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 3 Mar 2022 19:13:06 +0800 Subject: [PATCH 2/2] fix lint --- .../tensorrt/deform_conv/trt_deform_conv_kernel.cu | 2 +- .../tensorrt/deform_conv/trt_deform_conv_kernel.hpp | 2 +- mmdeploy/mmcv/ops/deform_conv.py | 6 +----- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu index 0ae82d82a1..5ddb905a42 100644 --- a/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu +++ b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu @@ -162,4 +162,4 @@ template void deform_conv(const float* input, const float* weight, const int inputHeight, int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, int group, int deformable_group, int im2col_step, - cublasHandle_t cublas_handle, cudaStream_t stream); \ No newline at end of file + cublasHandle_t cublas_handle, cudaStream_t stream); diff --git a/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp index ceb0fe4147..3d8f6dfc45 100644 --- a/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp +++ b/csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp @@ -17,4 +17,4 @@ void deform_conv(const scalar_t* input, const scalar_t* weight, const scalar_t* int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, int group, int deformable_group, int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream); -#endif // TRT_DEFORM_CONV_KERNEL_HPP \ No newline at end of file +#endif // TRT_DEFORM_CONV_KERNEL_HPP diff --git a/mmdeploy/mmcv/ops/deform_conv.py b/mmdeploy/mmcv/ops/deform_conv.py index 388e5f6cbd..ccd0542678 100644 --- a/mmdeploy/mmcv/ops/deform_conv.py +++ b/mmdeploy/mmcv/ops/deform_conv.py @@ -17,12 +17,8 @@ def deform_conv__default(ctx, bias=False, im2col_step=32): """Rewrite symbolic function for default backend.""" - assert not bias, 'The "bias" parameter should be False.' - assert groups == 1, 'The "groups" parameter should be 1.' - domain = 'mmdeploy' - op_name = 'MMCVDeformConv2d' return g.op( - f'{domain}::{op_name}', + 'mmdeploy::MMCVDeformConv2d', input, offset, weight,