Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extending @uni19's work on supporting dynamic shape input and scale_factor in the interpolate layer #295

Merged
merged 5 commits into from
Jan 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
641 changes: 575 additions & 66 deletions core/conversion/converters/impl/interpolate.cpp

Large diffs are not rendered by default.

123 changes: 107 additions & 16 deletions core/conversion/converters/impl/plugins/interpolate_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,40 @@ InterpolatePlugin::InterpolatePlugin(
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::vector<double> scales,
std::string mode,
bool align_corners)
: in_shape_(in_shape), out_shape_(out_shape), size_(size), mode_(mode), align_corners_(align_corners) {}
bool align_corners,
bool use_scales)
: in_shape_(in_shape),
out_shape_(out_shape),
size_(size),
scales_(scales),
mode_(mode),
align_corners_(align_corners),
use_scales_(use_scales) {
if (use_scales) {
TRTORCH_ASSERT(mode_ != "adaptive_pool2d", "use_scales is not valid for adaptive_pool2d");
TRTORCH_ASSERT(
scales_.size() != 0, "Attempted to use interpolate plugin without providing scales while use_scales=true");
at::Tensor input = at::randint(1, 10, in_shape, {at::kCUDA});
at::Tensor output;

if (mode_ == "linear") {
output = at::upsample_linear1d(input, c10::nullopt, align_corners_, scales_[0]);
} else if (mode_ == "bilinear") {
output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
std::cout << output.sizes() << std::endl;
} else if (mode_ == "trilinear") {
output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
}

out_shape_ = output.sizes().vec();
} else {
TRTORCH_ASSERT(
(size_.size() != 0 && out_shape_.size() != 0),
"Attempted to use interpolate plugin without providing output size while use_scales=false");
}
}

InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
std::istringstream data_stream(std::string(data, length));
Expand All @@ -42,6 +73,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
input_archive.read("size", value);
size_ = value.toIntVector();
}
{
torch::IValue value;
input_archive.read("scales", value);
scales_ = value.toDoubleVector();
}
{
torch::IValue value;
input_archive.read("mode", value);
Expand All @@ -52,6 +88,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
input_archive.read("align_corners", value);
align_corners_ = value.toBool();
}
{
torch::IValue value;
input_archive.read("use_scales", value);
use_scales_ = value.toBool();
}
}

std::vector<int64_t> InterpolatePlugin::getInputShape() {
Expand Down Expand Up @@ -83,7 +124,7 @@ const char* InterpolatePlugin::getPluginNamespace() const {
}

// Returns a newly allocated InterpolatePlugin constructed from this plugin's member values.
// NOTE(review): this view looks like a rendered diff in which removed and added lines appear together.
// The 5-argument call is presumably the pre-change line and the 7-argument call (which also forwards
// scales_ and use_scales_) its replacement; as written, only the first return would execute.
// Confirm against the actual source file.
nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const {
return new InterpolatePlugin(in_shape_, out_shape_, size_, mode_, align_corners_);
return new InterpolatePlugin(in_shape_, out_shape_, size_, scales_, mode_, align_corners_, use_scales_);
}

nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
Expand All @@ -93,9 +134,30 @@ nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
nvinfer1::IExprBuilder& exprBuilder) {
nvinfer1::DimsExprs output(inputs[0]);

// TODO: This should enable the case of using this plugin with dynamic shape, scale factor and align corners == true
// to cover the different implementations between PyTorch and TRT. However, TRT currently does not support doubles for
// ExprBuilder constants. Once that is possible, enable this code and remove the code in the constructor.
// if (use_scales_) {
// auto input_dimsexprs = inputs[0];
// output.d[0] = exprBuilder.operation(DimensionOperation::kMAX, *input_dimsexprs.d[0], *exprBuilder.constant(0));
// if (mode_ == "linear") {
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1],
// *exprBuilder.constant(scales_[1]));
// } else if (mode_ == "bilinear") {
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1],
// *exprBuilder.constant(scales_[1])); output.d[2] = exprBuilder.operation(DimensionOperation::kPROD,
// *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2]));
// } else if (mode_ == "trilinear") {
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1],
// *exprBuilder.constant(scales_[1])); output.d[2] = exprBuilder.operation(DimensionOperation::kPROD,
// *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2])); output.d[3] =
// exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[3], *exprBuilder.constant(scales_[3]));
// }
// } else {
for (unsigned int i = 0; i < out_shape_.size(); i++) {
output.d[i] = exprBuilder.constant(out_shape_[i]);
}
//}

return output;
}
Expand Down Expand Up @@ -131,8 +193,10 @@ std::string InterpolatePlugin::serializeToString() const {
output_archive.write("in_shape", torch::IValue(in_shape_));
output_archive.write("out_shape", torch::IValue(out_shape_));
output_archive.write("size", torch::IValue(size_));
output_archive.write("scales", torch::IValue(scales_));
output_archive.write("mode", torch::IValue(mode_));
output_archive.write("align_corners", torch::IValue(align_corners_));
output_archive.write("use_scales", torch::IValue(use_scales_));

std::ostringstream data_str;
output_archive.save_to(data_str);
Expand Down Expand Up @@ -201,14 +265,24 @@ int InterpolatePlugin::enqueue(

cudaStreamWaitEvent(torch_stream.stream(), event, 0);

if (mode_ == "linear") {
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
} else if (mode_ == "bilinear") {
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
} else if (mode_ == "trilinear") {
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode_ == "adaptive_pool2d") {
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
if (use_scales_) {
if (mode_ == "linear") {
at::upsample_linear1d_out(output, input, {}, align_corners_, scales_[0]);
} else if (mode_ == "bilinear") {
at::upsample_bilinear2d_out(output, input, {}, align_corners_, scales_[0], scales_[1]);
} else if (mode_ == "trilinear") {
at::upsample_trilinear3d_out(output, input, {}, align_corners_, scales_[0], scales_[1], scales_[2]);
}
} else {
if (mode_ == "linear") {
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
} else if (mode_ == "bilinear") {
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
} else if (mode_ == "trilinear") {
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode_ == "adaptive_pool2d") {
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
}
}

cudaEvent_t torch_event;
Expand All @@ -235,10 +309,25 @@ int InterpolatePlugin::enqueue(
cudaStreamSynchronize(stream);

at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);

at::Tensor output;
if (mode_ == "adaptive_pool2d") {
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
if (use_scales_) {
if (mode_ == "linear") {
output = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]});
} else if (mode_ == "bilinear") {
output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
} else if (mode_ == "trilinear") {
output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
}
} else {
if (mode_ == "linear") {
output = at::upsample_linear1d(input, {size_[0]}, align_corners_);
} else if (mode_ == "bilinear") {
output = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
} else if (mode_ == "trilinear") {
output = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode_ == "adaptive_pool2d") {
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
}
}

cudaMemcpyAsync(
Expand Down Expand Up @@ -277,10 +366,12 @@ InterpolatePlugin* InterpolatePluginCreator::createPlugin(
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::vector<double> scales,
std::string mode,
bool align_corners) {
bool align_corners,
bool use_scales) {
name_ = name;
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
return new InterpolatePlugin(in_shape, out_shape, size, scales, mode, align_corners, use_scales);
}

nvinfer1::IPluginV2* InterpolatePluginCreator::deserializePlugin(
Expand Down
10 changes: 8 additions & 2 deletions core/conversion/converters/impl/plugins/interpolate_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
std::vector<int64_t> in_shape_;
std::vector<int64_t> out_shape_;
std::vector<int64_t> size_;
std::vector<double> scales_;
std::string mode_;
bool align_corners_;
bool use_scales_;

protected:
// To prevent compiler warnings
Expand All @@ -49,8 +51,10 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::vector<double> scales,
std::string mode,
bool align_corners);
bool align_corners,
bool use_scales);

InterpolatePlugin(const char* data, size_t length);

Expand Down Expand Up @@ -140,8 +144,10 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
std::vector<int64_t> in_shape,
std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::vector<double> scales,
std::string mode,
bool align_corners);
bool align_corners,
bool use_scales);

nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;

Expand Down
9 changes: 8 additions & 1 deletion core/conversion/converters/impl/pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,14 @@ auto pooling_registrations TRTORCH_UNUSED =

auto creator = new plugins::InterpolatePluginCreator();
auto plugin = creator->createPlugin(
"adaptive_pool2d", in_shape, out_shape, out_size, std::string("adaptive_pool2d"), false);
"adaptive_pool2d",
in_shape,
out_shape,
out_size,
{},
std::string("adaptive_pool2d"),
false,
false);

auto pooling_layer =
ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
Expand Down
Loading