Skip to content

Commit

Permalink
feat(//core/plugins): Add adaptive_max_pool2d plugin, enable the plug…
Browse files Browse the repository at this point in the history
…ins to run on GPU

Signed-off-by: Dheeraj Peri <[email protected]>

Signed-off-by: Dheeraj Peri <[email protected]>

Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed May 19, 2021
1 parent 03a6ca4 commit 6f4aa40
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 177 deletions.
98 changes: 39 additions & 59 deletions core/conversion/converters/impl/pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,77 +60,53 @@ bool AdaptivePoolingConverter(
auto in_shape = util::toVec(in->getDimensions());
nvinfer1::ILayer* new_layer = nullptr;

if (ctx->input_is_dynamic) {
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
LOG_WARNING(
"Adaptive pooling layer will be run through ATen, via not TensorRT, performace will be lower than expected. Consider switching either to static input shape or moving to non adaptive pooling if this is an issue");
#else
LOG_WARNING(
"Adaptive pooling layer will be run through ATen (on CPU), via not TensorRT, performace will suffer. Consider switching either to static input shape or moving to non adaptive pooling");
#endif
/*======CONFIGURE PLUGIN PARAMETERS======*/
nvinfer1::PluginFieldCollection fc;
std::vector<nvinfer1::PluginField> f;

TRTORCH_CHECK(
pool_type == nvinfer1::PoolingType::kAVERAGE,
"Unable to create MAX pooling (interpolation) plugin from node" << *n);

nvinfer1::PluginFieldCollection fc;
std::vector<nvinfer1::PluginField> f;

auto out_shape = in_shape;
std::copy_n(out_size.d, out_size.nbDims, out_shape.begin() + (in_shape.size() - out_size.nbDims));
auto out_shape = in_shape;
auto out_size_vec = util::toVec(out_size);

std::vector<int32_t> in_shape_casted(in_shape.begin(), in_shape.end());
f.emplace_back(
nvinfer1::PluginField("in_shape", in_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, in_shape.size()));
std::copy(out_size_vec.begin(), out_size_vec.end(), out_shape.begin() + (in_shape.size() - out_size_vec.size()));

std::vector<int32_t> out_shape_casted(out_shape.begin(), out_shape.end());
f.emplace_back(nvinfer1::PluginField(
"out_shape", out_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, out_shape.size()));
std::vector<int32_t> in_shape_casted(in_shape.begin(), in_shape.end());
f.emplace_back(
nvinfer1::PluginField("in_shape", in_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, in_shape.size()));

auto out_size_vec = util::toVec(out_size);
std::vector<int32_t> out_size_casted(out_size_vec.begin(), out_size_vec.end());
f.emplace_back(nvinfer1::PluginField(
"out_size", out_size_casted.data(), nvinfer1::PluginFieldType::kINT32, out_size_vec.size()));
std::vector<int32_t> out_shape_casted(out_shape.begin(), out_shape.end());
f.emplace_back(
nvinfer1::PluginField("out_shape", out_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, out_shape.size()));

f.emplace_back(nvinfer1::PluginField("scales", nullptr, nvinfer1::PluginFieldType::kFLOAT64, 0));
std::vector<int32_t> out_size_casted(out_size_vec.begin(), out_size_vec.end());
f.emplace_back(nvinfer1::PluginField(
"out_size", out_size_casted.data(), nvinfer1::PluginFieldType::kINT32, out_size_vec.size()));

std::string mode = "adaptive_pool2d";
f.emplace_back(nvinfer1::PluginField("mode", &mode, nvinfer1::PluginFieldType::kCHAR, 1));
f.emplace_back(nvinfer1::PluginField("scales", nullptr, nvinfer1::PluginFieldType::kFLOAT64, 0));

int32_t align_corners_casted = 0;
f.emplace_back(nvinfer1::PluginField("align_corners", &align_corners_casted, nvinfer1::PluginFieldType::kINT32, 1));
int32_t align_corners_casted = 0;
f.emplace_back(nvinfer1::PluginField("align_corners", &align_corners_casted, nvinfer1::PluginFieldType::kINT32, 1));

int32_t use_scales_casted = 0;
f.emplace_back(nvinfer1::PluginField("use_scales", &use_scales_casted, nvinfer1::PluginFieldType::kINT32, 1));
int32_t use_scales_casted = 0;
f.emplace_back(nvinfer1::PluginField("use_scales", &use_scales_casted, nvinfer1::PluginFieldType::kINT32, 1));

fc.nbFields = f.size();
fc.fields = f.data();
auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch");
auto interpolate_plugin = creator->createPlugin("adaptive_pool2d", &fc);

new_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin);
TRTORCH_CHECK(new_layer, "Unable to create pooling (interpolation) plugin from node" << *n);
std::string mode = "adaptive_avg_pool2d";
if (pool_type == nvinfer1::PoolingType::kMAX) {
mode = "adaptive_max_pool2d";
}
f.emplace_back(nvinfer1::PluginField("mode", &mode, nvinfer1::PluginFieldType::kCHAR, 1));

} else {
std::vector<int64_t> stride(out_size.nbDims);
for (int64_t i = 0; i < out_size.nbDims; i++) {
stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_size.d[(out_size.nbDims - 1) - i];
}
LOG_DEBUG("Stride: " << util::toDims(stride));
fc.nbFields = f.size();
fc.fields = f.data();
/*====== PLUGIN PARAMETERS CONFIGURATION COMPLETED ======*/

std::vector<int64_t> window(out_size.nbDims);
for (int64_t i = 0; i < out_size.nbDims; i++) {
window[window.size() - 1 - i] =
in_shape[in_shape.size() - 1 - i] - (out_size.d[out_size.nbDims - 1 - i] - 1) * stride[stride.size() - 1 - i];
}
LOG_WARNING(
"Adaptive pooling layer will be using Aten library kernels in pytorch for execution. TensorRT does not support adaptive pooling natively. Consider switching to non-adaptive pooling if this is an issue");

LOG_DEBUG("Window: " << util::toDims(window));
auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch");
auto interpolate_plugin = creator->createPlugin(mode.c_str(), &fc);

auto pooling_layer = ctx->net->addPoolingNd(*in, pool_type, util::toDims(window));
TRTORCH_CHECK(pooling_layer, "Unable to create average pooling layer from node: " << *n);
pooling_layer->setStrideNd(util::toDims(stride));
new_layer = pooling_layer;
}
new_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin);
TRTORCH_CHECK(new_layer, "Unable to create pooling (interpolation) plugin from node" << *n);

new_layer->setName(util::node_info(n).c_str());
auto layer_output = addUnpadding(ctx, n, new_layer->getOutput(0), orig_dims.nbDims, false, false);
Expand All @@ -156,7 +132,7 @@ bool PoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& args,
auto padding = util::toDims(args[3].unwrapToIntList());
auto stride = util::toDims(args[2].unwrapToIntList());
if (stride.nbDims == 0) {
LOG_DEBUG("Stride not providied, using kernel_size as stride");
LOG_DEBUG("Stride not provided, using kernel_size as stride");
stride = util::toDims(args[1].unwrapToIntList());
}

Expand Down Expand Up @@ -265,6 +241,10 @@ auto pooling_registrations TRTORCH_UNUSED =
.pattern({"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kAVERAGE);
}})
.pattern({"aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kMAX);
}});
} // namespace
} // namespace impl
Expand Down
99 changes: 30 additions & 69 deletions core/plugins/impl/interpolate_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ InterpolatePlugin::InterpolatePlugin(
align_corners_(align_corners),
use_scales_(use_scales) {
if (use_scales) {
TRTORCH_ASSERT(mode_ != "adaptive_pool2d", "use_scales is not valid for adaptive_pool2d");
TRTORCH_ASSERT(mode_ != "adaptive_avg_pool2d", "use_scales is not valid for adaptive_avg_pool2d");
TRTORCH_ASSERT(
scales_.size() != 0, "Attempted to use interpolate plugin without providing scales while use_scales=true");
at::Tensor input = at::randint(1, 10, in_shape, {at::kCUDA});
Expand Down Expand Up @@ -106,7 +106,11 @@ std::vector<int64_t> InterpolatePlugin::getOutputSize() {
}

int InterpolatePlugin::getNbOutputs() const {
return 1;
if (mode_ == "adaptive_max_pool2d") {
return 2;
} else {
return 1;
}
}

const char* InterpolatePlugin::getPluginType() const {
Expand Down Expand Up @@ -166,15 +170,6 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer
}

int InterpolatePlugin::initialize() {
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
tensor_options_ = tensor_options_.device(c10::kCUDA);
#else
tensor_options_ = tensor_options_.device(c10::kCPU);
#endif

// c10::kFloat = FLOAT32
tensor_options_ = tensor_options_.dtype(c10::kFloat);

return 0;
}

Expand Down Expand Up @@ -211,9 +206,15 @@ bool InterpolatePlugin::supportsFormatCombination(
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) {
TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output");
TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to interpolate plugin");
TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to interpolate plugin");

if (mode_ == "adaptive_max_pool2d") {
TRTORCH_ASSERT(nbOutputs == 2, "Expected 2 tensors as output to interpolate plugin");
TRTORCH_ASSERT(0 <= pos && pos <= 2, "There should be exactly 3 connections to the plugin - 1 input, 2 output");
} else {
TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to interpolate plugin");
TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output");
}

const nvinfer1::PluginTensorDesc& in = inOut[0];

Expand Down Expand Up @@ -250,10 +251,10 @@ int InterpolatePlugin::enqueue(
void* const* outputs,
void* workspace,
cudaStream_t stream) {
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
at::Tensor input = at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, tensor_options_);
at::Tensor output = at::from_blob(
outputs[0], util::volume(outputDesc->dims), [](void*) {}, tensor_options_);
at::Tensor input =
at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat);
at::Tensor output =
at::from_blob(outputs[0], util::toVec(outputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat);

at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
at::cuda::CUDAStreamGuard torch_guard(torch_stream);
Expand All @@ -263,27 +264,30 @@ int InterpolatePlugin::enqueue(
cudaEventRecord(event, stream);

cudaStreamWaitEvent(torch_stream.stream(), event, 0);

at::Tensor out;
if (use_scales_) {
if (mode_ == "linear") {
at::upsample_linear1d_out(output, input, {}, align_corners_, scales_[0]);
out = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]});
} else if (mode_ == "bilinear") {
at::upsample_bilinear2d_out(output, input, {}, align_corners_, scales_[0], scales_[1]);
out = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
} else if (mode_ == "trilinear") {
at::upsample_trilinear3d_out(output, input, {}, align_corners_, scales_[0], scales_[1], scales_[2]);
out = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
}
} else {
if (mode_ == "linear") {
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
out = at::upsample_linear1d(input, {size_[0]}, align_corners_);
} else if (mode_ == "bilinear") {
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
out = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
} else if (mode_ == "trilinear") {
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode_ == "adaptive_pool2d") {
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
out = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode_ == "adaptive_avg_pool2d") {
out = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
} else if (mode_ == "adaptive_max_pool2d") {
out = std::get<0>(at::adaptive_max_pool2d(input, {size_[0], size_[1]}));
}
}

output.copy_(out);
cudaEvent_t torch_event;
cudaEventCreate(&torch_event);
cudaEventRecord(torch_event, torch_stream.stream());
Expand All @@ -294,49 +298,6 @@ int InterpolatePlugin::enqueue(
cudaEventDestroy(torch_event);

return 0;
#else
// TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen
// kernels HACK: WAR because there is a segfault if you try to create a CUDA
// Tensor in the context of TensorRT execution
float* input_blob = (float*)malloc(util::volume(inputDesc->dims) * sizeof(float));
cudaMemcpyAsync(
input_blob,
static_cast<const void*>(inputs[0]),
util::volume(inputDesc->dims) * sizeof(float),
cudaMemcpyDeviceToHost,
stream);
cudaStreamSynchronize(stream);

at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);
at::Tensor output;
if (use_scales_) {
if (mode_ == "linear") {
output = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]});
} else if (mode_ == "bilinear") {
output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
} else if (mode_ == "trilinear") {
output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
}
} else {
if (mode_ == "linear") {
output = at::upsample_linear1d(input, {size_[0]}, align_corners_);
} else if (mode_ == "bilinear") {
output = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
} else if (mode_ == "trilinear") {
output = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode_ == "adaptive_pool2d") {
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
}
}

cudaMemcpyAsync(
outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);

free(input_blob);

return 0;
#endif
}

/*
Expand Down
1 change: 0 additions & 1 deletion core/plugins/impl/interpolate_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ namespace impl {

class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
private:
at::TensorOptions tensor_options_;
nvinfer1::DataType dtype_;

std::vector<int64_t> in_shape_;
Expand Down
45 changes: 7 additions & 38 deletions core/plugins/impl/normalize_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,6 @@ nvinfer1::DataType NormalizePlugin::getOutputDataType(int index, const nvinfer1:
}

int NormalizePlugin::initialize() {
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
tensor_options_ = tensor_options_.device(c10::kCUDA);
#else
tensor_options_ = tensor_options_.device(c10::kCPU);
#endif

// c10::kFloat = FLOAT32
tensor_options_ = tensor_options_.dtype(c10::kFloat);

return 0;
}

Expand Down Expand Up @@ -181,11 +172,10 @@ int NormalizePlugin::enqueue(
void* const* outputs,
void* workspace,
cudaStream_t stream) {
// TRT <= 7.0
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
at::Tensor input = at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, tensor_options_);
at::Tensor output = at::from_blob(
outputs[0], util::volume(outputDesc->dims), [](void*) {}, tensor_options_);
at::Tensor input =
at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat);
at::Tensor output =
at::from_blob(outputs[0], util::toVec(outputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat);

at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
at::cuda::CUDAStreamGuard torch_guard(torch_stream);
Expand All @@ -195,7 +185,9 @@ int NormalizePlugin::enqueue(
cudaEventRecord(event, stream);

cudaStreamWaitEvent(torch_stream.stream(), event, 0);
at::Tensor result = at::norm(input, order_, axes_, keep_dims_);

std::vector<int64_t> axes_double(axes_.begin(), axes_.end());
at::Tensor result = at::norm(input, (int64_t)order_, axes_double, (bool)keep_dims_);
output.copy_(result);
cudaEvent_t torch_event;
cudaEventCreate(&torch_event);
Expand All @@ -206,29 +198,6 @@ int NormalizePlugin::enqueue(
cudaEventDestroy(event);
cudaEventDestroy(torch_event);
return 0;
#else
// TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen
// kernels HACK: WAR because there is a segfault if you try to create a CUDA
// Tensor in the context of TensorRT execution
float* input_blob = (float*)malloc(util::volume(inputDesc->dims) * sizeof(float));
cudaMemcpyAsync(
input_blob,
static_cast<const void*>(inputs[0]),
util::volume(inputDesc->dims) * sizeof(float),
cudaMemcpyDeviceToHost,
stream);
cudaStreamSynchronize(stream);

at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);
std::vector<int64_t> axes_new(axes_.begin(), axes_.end());
at::Tensor output = at::norm(input, (int64_t)order_, axes_new, (bool)keep_dims_);
cudaMemcpyAsync(
outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);

free(input_blob);
return 0;
#endif
}

/*
Expand Down
1 change: 0 additions & 1 deletion core/plugins/impl/normalize_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ namespace impl {

class NormalizePlugin : public nvinfer1::IPluginV2DynamicExt {
private:
at::TensorOptions tensor_options_;
nvinfer1::DataType dtype_;
int32_t order_;
std::vector<int32_t> axes_;
Expand Down
Loading

0 comments on commit 6f4aa40

Please sign in to comment.