Skip to content

Commit

Permalink
fix(//core/conversion/converters/impl): 1d case not working
Browse files Browse the repository at this point in the history
Signed-off-by: Abhiram Iyer <[email protected]>
Signed-off-by: Abhiram Iyer <[email protected]>
  • Loading branch information
abhi-iyer committed Jun 8, 2020
1 parent e4cb117 commit f42562b
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions core/conversion/converters/impl/interpolate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@ auto interpolate_registrations = RegisterNodeConversionPatterns()
.pattern({
"aten::upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node*n, args& args) -> bool {
std::cout << "GOT IN HERE!!!!!!" << std::endl;

auto in = args[0].ITensor();
auto in_shape = util::toVec(in->getDimensions());

// remove padding that TensorRt adds automatically
if (in_shape.size() >= 4) {
// remove first dimension
in_shape.erase(in_shape.begin());

auto shuffle = ctx->net->addShuffle(*in);
shuffle->setReshapeDimensions(util::toDims(in_shape));
shuffle->setName( (util::node_info(n) + " [Reshape to " + util::toStr(util::toDims(in_shape)) + "]").c_str() );
in = shuffle->getOutput(0);
}

// Case 1: user uses output size and not scales
if (!args[1].IValue()->isNone() && args[2].IValue()->isNone()) {
Expand All @@ -28,9 +37,6 @@ auto interpolate_registrations = RegisterNodeConversionPatterns()

auto out_shape = in_shape;
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

// remove padding that TensorRT adds automatically
// out_shape.erase(out_shape.begin(), out_shape.begin()+1);

auto resize_layer = ctx->net->addResize(*in);
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
Expand All @@ -39,8 +45,30 @@ auto interpolate_registrations = RegisterNodeConversionPatterns()
resize_layer->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
resize_layer->setName(util::node_info(n).c_str());

// auto out_tensor = resize_layer->getOutput(0);
// out_shape.erase(out_shape.begin());
// auto shuffle = ctx->net->addShuffle(*out_tensor);
// shuffle->setReshapeDimensions(util::toDims(out_shape));
// shuffle->setName( (util::node_info(n) + " [Reshape to " + util::toStr(util::toDims(out_shape)) + "]").c_str() );
// auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
// LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());

// std::cout << "PRINTING STUFF AT THE END!" << std::endl;
// auto final = util::toVec(shuffle->getOutput(0)->getDimensions());
// for (auto iter = final.begin(); iter != final.end(); iter++) {
// std::cout << *iter << std::endl;
// }

//std::raise(SIGABRT);

auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());

// std::cout << "PRINTING STUFF AT THE END!" << std::endl;
// auto final = util::toVec(resize_layer->getOutput(0)->getDimensions());
// for (auto iter = final.begin(); iter != final.end(); iter++) {
// std::cout << *iter << std::endl;
// }
} else {
LOG_DEBUG("scale factor parameter not supported yet.");
}
Expand Down

0 comments on commit f42562b

Please sign in to comment.