Skip to content

Commit

Permalink
fix(): trying to resolve interpolate plugin problems
Browse files Browse the repository at this point in the history
Signed-off-by: Abhiram Iyer <[email protected]>

Signed-off-by: Abhiram Iyer <[email protected]>
  • Loading branch information
abhi-iyer committed Jun 16, 2020
1 parent 58dbaef commit f0fefaa
Show file tree
Hide file tree
Showing 5 changed files with 319 additions and 206 deletions.
2 changes: 1 addition & 1 deletion core/conversion/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ config_setting(
cc_library(
name = "converters",
hdrs = [
"converters.h",
"converters.h"
],
srcs = [
"NodeConverterRegistry.cpp",
Expand Down
33 changes: 24 additions & 9 deletions core/conversion/converters/impl/interpolate.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "torch/torch.h"
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"
#include "NvInfer.h"
#include "plugins/interpolate_plugin.h"

#include <csignal>

Expand Down Expand Up @@ -108,7 +110,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
auto in = args[0].ITensor();
auto in_shape = util::toVec(in->getDimensions());

bool align_corners = args[2].IValue()->to<bool>();
bool align_corners = args[2].unwrapToBool();

// Case 1: user uses output size and not scales
if (!args[1].IValue()->isNone() && args[3].IValue()->isNone()) {
Expand All @@ -119,16 +121,29 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
auto out_shape = in_shape;
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

auto resize_layer = ctx->net->addResize(*in);
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
if (!align_corners) {
//auto creator = getPluginRegistry()->getPluginCreator("interpolate", "1");
//auto* plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);
auto creator = new plugins::InterpolatePluginCreator();

resize_layer->setOutputDimensions(util::toDims(out_shape));
resize_layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR);
resize_layer->setAlignCorners(align_corners);
resize_layer->setName(util::node_info(n).c_str());
auto plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);

auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(in), 1, *plugin);

auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
} else {
auto resize_layer = ctx->net->addResize(*in);
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);

resize_layer->setOutputDimensions(util::toDims(out_shape));
resize_layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR);
resize_layer->setAlignCorners(align_corners);
resize_layer->setName(util::node_info(n).c_str());

auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
}
} else {
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_linear1d not supported yet.");
}
Expand Down
6 changes: 4 additions & 2 deletions core/conversion/converters/impl/plugins/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ config_setting(

cc_library(
name = "plugins",
hdrs = [],
hdrs = [
"interpolate_plugin.h"
],
srcs = [
"interpolate_plugin.cpp"
],
Expand All @@ -29,5 +31,5 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")
pkg_tar(
name = "include",
package_dir = "core/conversion/converters/impl/plugins",
srcs = [],
srcs = ["interpolate_plugin.h"],
)
Loading

0 comments on commit f0fefaa

Please sign in to comment.