Skip to content

Commit

Permalink
Merge pull request #295 from NVIDIA/dynamic_interpolation
Browse files Browse the repository at this point in the history
Extending @uni19's work on supporting dynamic shape input and scale_factor in the interpolate layer
  • Loading branch information
narendasan authored Jan 28, 2021
2 parents 08b2455 + 1781f25 commit 8fb390d
Show file tree
Hide file tree
Showing 6 changed files with 1,072 additions and 352 deletions.
641 changes: 575 additions & 66 deletions core/conversion/converters/impl/interpolate.cpp

Large diffs are not rendered by default.

123 changes: 107 additions & 16 deletions core/conversion/converters/impl/plugins/interpolate_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,40 @@ InterpolatePlugin::InterpolatePlugin(
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::vector<double> scales,
std::string mode,
bool align_corners)
: in_shape_(in_shape), out_shape_(out_shape), size_(size), mode_(mode), align_corners_(align_corners) {}
bool align_corners,
bool use_scales)
: in_shape_(in_shape),
out_shape_(out_shape),
size_(size),
scales_(scales),
mode_(mode),
align_corners_(align_corners),
use_scales_(use_scales) {
if (use_scales) {
TRTORCH_ASSERT(mode_ != "adaptive_pool2d", "use_scales is not valid for adaptive_pool2d");
TRTORCH_ASSERT(
scales_.size() != 0, "Attempted to use interpolate plugin without providing scales while use_scales=true");
at::Tensor input = at::randint(1, 10, in_shape, {at::kCUDA});
at::Tensor output;

if (mode_ == "linear") {
output = at::upsample_linear1d(input, c10::nullopt, align_corners_, scales_[0]);
} else if (mode_ == "bilinear") {
output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
std::cout << output.sizes() << std::endl;
} else if (mode_ == "trilinear") {
output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
}

out_shape_ = output.sizes().vec();
} else {
TRTORCH_ASSERT(
(size_.size() != 0 && out_shape_.size() != 0),
"Attempted to use interpolate plugin without providing output size while use_scales=false");
}
}

InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
std::istringstream data_stream(std::string(data, length));
Expand All @@ -42,6 +73,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
input_archive.read("size", value);
size_ = value.toIntVector();
}
{
torch::IValue value;
input_archive.read("scales", value);
scales_ = value.toDoubleVector();
}
{
torch::IValue value;
input_archive.read("mode", value);
Expand All @@ -52,6 +88,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
input_archive.read("align_corners", value);
align_corners_ = value.toBool();
}
{
torch::IValue value;
input_archive.read("use_scales", value);
use_scales_ = value.toBool();
}
}

std::vector<int64_t> InterpolatePlugin::getInputShape() {
Expand Down Expand Up @@ -83,7 +124,7 @@ const char* InterpolatePlugin::getPluginNamespace() const {
}

nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const {
return new InterpolatePlugin(in_shape_, out_shape_, size_, mode_, align_corners_);
return new InterpolatePlugin(in_shape_, out_shape_, size_, scales_, mode_, align_corners_, use_scales_);
}

nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
Expand All @@ -93,9 +134,30 @@ nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
nvinfer1::IExprBuilder& exprBuilder) {
nvinfer1::DimsExprs output(inputs[0]);

// TODO: This should enable the case of using this plugin with dynamic shape, scale factor and align corners == true
// to cover the different implementations between PyTorch and TRT. However TRT currently does not support doubles for
// ExprBuilder constants. Once that is possible, enable this code and remove the code in the constructor.
// if (use_scales_) {
// auto input_dimsexprs = inputs[0];
// output.d[0] = exprBuilder.operation(DimensionOperation::kMAX, *input_dimsexprs.d[0], *exprBuilder.constant(0));
// if (mode_ == "linear") {
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1],
// *exprBuilder.constant(scales_[1]));
// } else if (mode_ == "bilinear") {
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1],
// *exprBuilder.constant(scales_[1])); output.d[2] = exprBuilder.operation(DimensionOperation::kPROD,
// *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2]));
// } else if (mode_ == "trilinear") {
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1],
// *exprBuilder.constant(scales_[1])); output.d[2] = exprBuilder.operation(DimensionOperation::kPROD,
// *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2])); output.d[3] =
// exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[3], *exprBuilder.constant(scales_[3]));
// }
// } else {
for (unsigned int i = 0; i < out_shape_.size(); i++) {
output.d[i] = exprBuilder.constant(out_shape_[i]);
}
//}

return output;
}
Expand Down Expand Up @@ -131,8 +193,10 @@ std::string InterpolatePlugin::serializeToString() const {
output_archive.write("in_shape", torch::IValue(in_shape_));
output_archive.write("out_shape", torch::IValue(out_shape_));
output_archive.write("size", torch::IValue(size_));
output_archive.write("scales", torch::IValue(scales_));
output_archive.write("mode", torch::IValue(mode_));
output_archive.write("align_corners", torch::IValue(align_corners_));
output_archive.write("use_scales", torch::IValue(use_scales_));

std::ostringstream data_str;
output_archive.save_to(data_str);
Expand Down Expand Up @@ -201,14 +265,24 @@ int InterpolatePlugin::enqueue(

cudaStreamWaitEvent(torch_stream.stream(), event, 0);

if (mode_ == "linear") {
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
} else if (mode_ == "bilinear") {
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
} else if (mode_ == "trilinear") {
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode_ == "adaptive_pool2d") {
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
if (use_scales_) {
if (mode_ == "linear") {
at::upsample_linear1d_out(output, input, {}, align_corners_, scales_[0]);
} else if (mode_ == "bilinear") {
at::upsample_bilinear2d_out(output, input, {}, align_corners_, scales_[0], scales_[1]);
} else if (mode_ == "trilinear") {
at::upsample_trilinear3d_out(output, input, {}, align_corners_, scales_[0], scales_[1], scales_[2]);
}
} else {
if (mode_ == "linear") {
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
} else if (mode_ == "bilinear") {
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
} else if (mode_ == "trilinear") {
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode_ == "adaptive_pool2d") {
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
}
}

cudaEvent_t torch_event;
Expand All @@ -235,10 +309,25 @@ int InterpolatePlugin::enqueue(
cudaStreamSynchronize(stream);

at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);

at::Tensor output;
if (mode_ == "adaptive_pool2d") {
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
if (use_scales_) {
if (mode_ == "linear") {
output = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]});
} else if (mode_ == "bilinear") {
output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
} else if (mode_ == "trilinear") {
output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
}
} else {
if (mode_ == "linear") {
output = at::upsample_linear1d(input, {size_[0]}, align_corners_);
} else if (mode_ == "bilinear") {
output = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
} else if (mode_ == "trilinear") {
output = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode_ == "adaptive_pool2d") {
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
}
}

cudaMemcpyAsync(
Expand Down Expand Up @@ -277,10 +366,12 @@ InterpolatePlugin* InterpolatePluginCreator::createPlugin(
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::vector<double> scales,
std::string mode,
bool align_corners) {
bool align_corners,
bool use_scales) {
name_ = name;
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
return new InterpolatePlugin(in_shape, out_shape, size, scales, mode, align_corners, use_scales);
}

nvinfer1::IPluginV2* InterpolatePluginCreator::deserializePlugin(
Expand Down
10 changes: 8 additions & 2 deletions core/conversion/converters/impl/plugins/interpolate_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
std::vector<int64_t> in_shape_;
std::vector<int64_t> out_shape_;
std::vector<int64_t> size_;
std::vector<double> scales_;
std::string mode_;
bool align_corners_;
bool use_scales_;

protected:
// To prevent compiler warnings
Expand All @@ -49,8 +51,10 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::vector<double> scales,
std::string mode,
bool align_corners);
bool align_corners,
bool use_scales);

InterpolatePlugin(const char* data, size_t length);

Expand Down Expand Up @@ -140,8 +144,10 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::vector<double> scales,
std::string mode,
bool align_corners);
bool align_corners,
bool use_scales);

nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;

Expand Down
9 changes: 8 additions & 1 deletion core/conversion/converters/impl/pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,14 @@ auto pooling_registrations TRTORCH_UNUSED =

auto creator = new plugins::InterpolatePluginCreator();
auto plugin = creator->createPlugin(
"adaptive_pool2d", in_shape, out_shape, out_size, std::string("adaptive_pool2d"), false);
"adaptive_pool2d",
in_shape,
out_shape,
out_size,
{},
std::string("adaptive_pool2d"),
false,
false);

auto pooling_layer =
ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
Expand Down
Loading

0 comments on commit 8fb390d

Please sign in to comment.