feat(interpolate): Addressing the linear, scale factor, align corners edge case

This commit adds support, in some cases, for the edge case in handling
torch.nn.functional.interpolate where the user performs some form of
linear upsampling, uses a scale factor to compute the new tensor size
at runtime, and sets align_corners to true (as of PyTorch 1.5 this is
no longer the default behavior). Support is added for this case when
users choose to construct static-input-size engines via the interpolate
plugin, which will run the function from ATen on CPU.
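
For reference, the triggering pattern is an upsample call that supplies only a
scale factor, so the output extent is derived from that factor at runtime. The
snippet below is a minimal, illustrative sketch of that pattern using the ATen
C++ API (the Python equivalent being torch.nn.functional.interpolate(x,
scale_factor=1.5, mode="linear", align_corners=True)); the tensor shape and the
1.5 factor are assumptions, not values taken from this commit.

// Illustrative only: the kind of call that hits this edge case. A 1D linear
// upsample driven purely by a non-integer scale factor with align_corners=true;
// the output width is computed from the scale at runtime.
#include <torch/torch.h>
#include <iostream>

int main() {
  at::Tensor x = at::randn({1, 3, 10});  // (N, C, W); shape is an assumption
  at::Tensor y = at::upsample_linear1d(
      x, c10::nullopt, /*align_corners=*/true, /*scales=*/1.5);
  std::cout << y.sizes() << std::endl;   // expected [1, 3, 15]
  return 0;
}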

In the case of dynamic input shapes combined with these three
conditions, compilation will terminate with an error. The ultimate
solution is to find the root cause of the discrepancy between PyTorch
and TensorRT. Barring that, we would need to use the dimension
calculation primitives for TensorRT plugins. However, there is a
limitation that static values in that computation cannot be floats,
which PyTorch scale factors are. Therefore it does not currently seem
possible to support this use case with dynamic shapes.
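
To make the limitation concrete: for the dynamic-shape path, getOutputDimensions
would need to express the output extent as roughly floor(input_extent *
scale_factor), but nvinfer1::IExprBuilder::constant() only accepts integer
values, so a non-integer factor such as 1.5 cannot be encoded. The sketch below
is a plain C++ illustration of that sizing rule; the helper is hypothetical and
not part of this commit, and it assumes PyTorch's floor(size * scale) sizing
convention.

// Illustrative only: how a scale-factor upsample sizes its output. The
// truncation depends on the floating point scale, which is exactly what the
// integer-only TensorRT dimension expressions cannot represent.
#include <cmath>
#include <cstdint>
#include <iostream>

int64_t upsampled_extent(int64_t in_extent, double scale_factor) {
  return static_cast<int64_t>(std::floor(in_extent * scale_factor));
}

int main() {
  std::cout << upsampled_extent(10, 1.5) << std::endl;  // 15
  std::cout << upsampled_extent(7, 1.5) << std::endl;   // 10
  return 0;
}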

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed Jan 27, 2021
1 parent 0cda1cc commit 92e3818
Showing 6 changed files with 307 additions and 198 deletions.
167 changes: 85 additions & 82 deletions core/conversion/converters/impl/interpolate.cpp

Large diffs are not rendered by default.
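
Since the interpolate.cpp diff is collapsed, the following is a hypothetical
sketch of how a converter could construct the plugin in scale-factor mode with
the new arguments. It mirrors the createPlugin signature added in this commit
and the adaptive_pool2d call site in pooling.cpp below; the variable names
(in, in_shape, ctx) and the 2.0 scale are assumptions, not the actual converter
code.

// Hypothetical call-site sketch (the real implementation lives in the
// collapsed interpolate.cpp diff). Builds the plugin for a 1D linear upsample
// driven by a scale factor with align_corners=true; out_shape and size are
// left empty because the plugin derives the output shape from the scales.
std::vector<double> scales = {2.0};  // scale factor read from the graph (assumed)
auto creator = new plugins::InterpolatePluginCreator();
auto plugin = creator->createPlugin(
    "linear1d", in_shape, /*out_shape=*/{}, /*size=*/{}, scales,
    std::string("linear"), /*align_corners=*/true, /*use_scales=*/true);
auto resize_layer = ctx->net->addPluginV2(
    reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);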

118 changes: 99 additions & 19 deletions core/conversion/converters/impl/plugins/interpolate_plugin.cpp
@@ -17,9 +17,31 @@ InterpolatePlugin::InterpolatePlugin(
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::vector<double> scales,
std::string mode,
bool align_corners)
: in_shape_(in_shape), out_shape_(out_shape), size_(size), mode_(mode), align_corners_(align_corners) {}
bool align_corners,
bool use_scales)
: in_shape_(in_shape), out_shape_(out_shape), size_(size), scales_(scales), mode_(mode), align_corners_(align_corners), use_scales_(use_scales) {
if (use_scales) {
TRTORCH_ASSERT(mode_ != "adaptive_pool2d", "use_scales is not valid for adaptive_pool2d");
TRTORCH_ASSERT(scales_.size() != 0, "Attempted to use interpolate plugin without providing scales while use_scales=true");
at::Tensor input = at::randint(1, 10, in_shape, {at::kCUDA});
at::Tensor output;

if (mode_ == "linear") {
output = at::upsample_linear1d(input, c10::nullopt, align_corners_, scales_[0]);
} else if (mode_ == "bilinear") {
output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
std::cout << output.sizes() << std::endl;
} else if (mode_ == "trilinear") {
output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
}

out_shape_ = output.sizes().vec();
} else {
TRTORCH_ASSERT((size_.size() != 0 && out_shape_.size() != 0), "Attempted to use interpolate plugin without providing output size while use_scales=false");
}
}

InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
std::istringstream data_stream(std::string(data, length));
@@ -42,6 +64,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
input_archive.read("size", value);
size_ = value.toIntVector();
}
{
torch::IValue value;
input_archive.read("scales", value);
scales_ = value.toDoubleVector();
}
{
torch::IValue value;
input_archive.read("mode", value);
@@ -52,6 +79,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
input_archive.read("align_corners", value);
align_corners_ = value.toBool();
}
{
torch::IValue value;
input_archive.read("use_scales", value);
use_scales_ = value.toBool();
}
}

std::vector<int64_t> InterpolatePlugin::getInputShape() {
@@ -83,7 +115,7 @@ const char* InterpolatePlugin::getPluginNamespace() const {
}

nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const {
return new InterpolatePlugin(in_shape_, out_shape_, size_, mode_, align_corners_);
return new InterpolatePlugin(in_shape_, out_shape_, size_, scales_, mode_, align_corners_, use_scales_);
}

nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
@@ -93,9 +125,27 @@ nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
nvinfer1::IExprBuilder& exprBuilder) {
nvinfer1::DimsExprs output(inputs[0]);

for (unsigned int i = 0; i < out_shape_.size(); i++) {
output.d[i] = exprBuilder.constant(out_shape_[i]);
}
// TODO: This should enable the case of using this plugin with dynamic shape, scale factor and align corners == true to cover
// the different implementations between PyTorch and TRT. However TRT currently does not support doubles
// for ExprBuilder constants. Once that is possible enable this code and remove the code in the constructor
// if (use_scales_) {
// auto input_dimsexprs = inputs[0];
// output.d[0] = exprBuilder.operation(DimensionOperation::kMAX, *input_dimsexprs.d[0], *exprBuilder.constant(0));
// if (mode_ == "linear") {
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1], *exprBuilder.constant(scales_[1]));
// } else if (mode_ == "bilinear") {
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1], *exprBuilder.constant(scales_[1]));
// output.d[2] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2]));
// } else if (mode_ == "trilinear") {
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1], *exprBuilder.constant(scales_[1]));
// output.d[2] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2]));
// output.d[3] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[3], *exprBuilder.constant(scales_[3]));
// }
// } else {
for (unsigned int i = 0; i < out_shape_.size(); i++) {
output.d[i] = exprBuilder.constant(out_shape_[i]);
}
//}

return output;
}
@@ -131,8 +181,10 @@ std::string InterpolatePlugin::serializeToString() const {
output_archive.write("in_shape", torch::IValue(in_shape_));
output_archive.write("out_shape", torch::IValue(out_shape_));
output_archive.write("size", torch::IValue(size_));
output_archive.write("scales", torch::IValue(scales_));
output_archive.write("mode", torch::IValue(mode_));
output_archive.write("align_corners", torch::IValue(align_corners_));
output_archive.write("use_scales", torch::IValue(use_scales_));

std::ostringstream data_str;
output_archive.save_to(data_str);
@@ -201,14 +253,24 @@ int InterpolatePlugin::enqueue(

cudaStreamWaitEvent(torch_stream.stream(), event, 0);

if (mode_ == "linear") {
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
} else if (mode_ == "bilinear") {
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
} else if (mode_ == "trilinear") {
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode_ == "adaptive_pool2d") {
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
if (use_scales_) {
if (mode_ == "linear") {
at::upsample_linear1d_out(output, input, {}, align_corners_, scales_[0]);
} else if (mode_ == "bilinear") {
at::upsample_bilinear2d_out(output, input, {}, align_corners_, scales_[0], scales_[1]);
} else if (mode_ == "trilinear") {
at::upsample_trilinear3d_out(output, input, {}, align_corners_, scales_[0], scales_[1], scales_[2]);
}
} else {
if (mode_ == "linear") {
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
} else if (mode_ == "bilinear") {
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
} else if (mode_ == "trilinear") {
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode_ == "adaptive_pool2d") {
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
}
}

cudaEvent_t torch_event;
@@ -234,11 +296,27 @@ int InterpolatePlugin::enqueue(
stream);
cudaStreamSynchronize(stream);

at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);

at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);
at::Tensor output;
if (mode_ == "adaptive_pool2d") {
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
if (use_scales_) {
if (mode_ == "linear") {
output = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]});
} else if (mode_ == "bilinear") {
output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
} else if (mode_ == "trilinear") {
output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
}
} else {
if (mode_ == "linear") {
output = at::upsample_linear1d(input, {size_[0]}, align_corners_);
} else if (mode_ == "bilinear") {
output = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
} else if (mode_ == "trilinear") {
output = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode_ == "adaptive_pool2d") {
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
}
}

cudaMemcpyAsync(
@@ -277,10 +355,12 @@ InterpolatePlugin* InterpolatePluginCreator::createPlugin(
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::vector<double> scales,
std::string mode,
bool align_corners) {
bool align_corners,
bool use_scales) {
name_ = name;
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
return new InterpolatePlugin(in_shape, out_shape, size, scales, mode, align_corners, use_scales);
}

nvinfer1::IPluginV2* InterpolatePluginCreator::deserializePlugin(
20 changes: 13 additions & 7 deletions core/conversion/converters/impl/plugins/interpolate_plugin.h
@@ -31,8 +31,10 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
std::vector<int64_t> in_shape_;
std::vector<int64_t> out_shape_;
std::vector<int64_t> size_;
std::vector<double> scales_;
std::string mode_;
bool align_corners_;
bool use_scales_;

protected:
// To prevent compiler warnings
@@ -49,8 +51,10 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::vector<double> scales,
std::string mode,
bool align_corners);
bool align_corners,
bool use_scales);

InterpolatePlugin(const char* data, size_t length);

@@ -136,12 +140,14 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) override;

InterpolatePlugin* createPlugin(
const char* name,
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::string mode,
bool align_corners);
const char* name,
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::vector<double> scales,
std::string mode,
bool align_corners,
bool use_scales);

nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;

2 changes: 1 addition & 1 deletion core/conversion/converters/impl/pooling.cpp
@@ -317,7 +317,7 @@ auto pooling_registrations TRTORCH_UNUSED =

auto creator = new plugins::InterpolatePluginCreator();
auto plugin = creator->createPlugin(
"adaptive_pool2d", in_shape, out_shape, out_size, std::string("adaptive_pool2d"), false);
"adaptive_pool2d", in_shape, out_shape, out_size, {}, std::string("adaptive_pool2d"), false, false);

auto pooling_layer =
ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);