Skip to content

Commit

Permalink
feat(//core/conversion/converters/impl/shuffle): Implement aten::resize
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Apr 29, 2020
1 parent a51c7b6 commit 353f2d2
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 8 deletions.
1 change: 1 addition & 0 deletions core/conversion/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ cc_library(
"impl/linear.cpp",
"impl/pooling.cpp",
"impl/reduce.cpp",
"impl/shuffle.cpp",
"impl/softmax.cpp",
"impl/unary.cpp",
],
Expand Down
1 change: 1 addition & 0 deletions core/conversion/converters/converters.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "torch/csrc/jit/runtime/custom_operator.h"
#include "ATen/core/function_schema.h"

#include "core/util/prelude.h"
#include "core/conversion/conversionctx/ConversionCtx.h"

namespace trtorch {
Expand Down
33 changes: 33 additions & 0 deletions core/conversion/converters/impl/shuffle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "core/conversion/converters/converters.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

static auto shuffle_registrations = RegisterNodeConversionPatterns()
.pattern({
"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto new_shape = util::toDimsPad(args[1].unwrapToIntList(), 2);

auto shuffle = ctx->net->addShuffle(*in);
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
shuffle->setReshapeDimensions(new_shape);
shuffle->setName(util::node_info(n).c_str());

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

return true;
}
});
} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
10 changes: 2 additions & 8 deletions core/conversion/converters/impl/softmax.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"

namespace trtorch {
Expand Down Expand Up @@ -29,12 +28,7 @@ static auto softmax_registrations = RegisterNodeConversionPatterns()
auto softmax = ctx->net->addSoftMax(*in);

TRTORCH_CHECK(softmax, "Unable to create softmax layer from node: " << *n);

if (!softmax) {
LOG_ERROR("Unable to create softmax layer from node: " << *n);
return false;
}
LOG_WARNING("Disregarding dtype argument, please verify");
LOG_DEBUG("Disregarding dtype argument");

if (shape.size() > 3) {
softmax->setAxes(1 << (dim));
Expand Down Expand Up @@ -69,4 +63,4 @@ static auto softmax_registrations = RegisterNodeConversionPatterns()
} // namespace converters
} // namespace conversion
} // namespace core
} // trtorch
} // namespace trtorch
23 changes: 23 additions & 0 deletions core/util/trt_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,29 @@ nvinfer1::Dims toDims(c10::List<int64_t> l) {
return dims;
}

nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to) {
if (l.size() > pad_to) {
LOG_DEBUG("Requested padding of dimensions to " << pad_to << " but found " << l.size() << " dimensions, not going to pad");
return toDims(l);
}

if (pad_to > nvinfer1::Dims::MAX_DIMS) {
//TODO: Handle this with exceptions or whatever
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
}

nvinfer1::Dims dims;
dims.nbDims = pad_to;
for (size_t i = 0; i < pad_to - l.size(); i++) {
dims.d[i] = 1;
}

for (size_t i = pad_to - l.size(); i < pad_to; i++) {
dims.d[i] = l[i - (pad_to - l.size())];
}
return dims;
}

std::vector<int64_t> toVec(nvinfer1::Dims d) {
std::vector<int64_t> dims;
for (int i = 0; i < d.nbDims; i++) {
Expand Down
1 change: 1 addition & 0 deletions core/util/trt_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ namespace util {
int64_t volume(const nvinfer1::Dims& d);

nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to);
nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
nvinfer1::Dims toDims(c10::IntArrayRef l);
nvinfer1::Dims toDims(c10::List<int64_t> l);
nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l);
Expand Down
5 changes: 5 additions & 0 deletions tests/core/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ converter_test(
name = "test_softmax"
)

converter_test(
name = "test_shuffle"
)

converter_test(
name = "test_activation"
)
Expand Down Expand Up @@ -36,6 +40,7 @@ test_suite(
name = "test_converters",
tests = [
":test_softmax",
":test_shuffle",
":test_activation",
":test_pooling",
":test_unary",
Expand Down
29 changes: 29 additions & 0 deletions tests/core/converters/test_shuffle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include <string>
#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "tests/util/util.h"
#include "core/compiler.h"

TEST(Converters, ATenReshapeConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=3]()
%2 : int = prim::Constant[value=2]()
%3 : int[] = prim::ListConstruct(%1, %2)
%4 : Tensor = aten::reshape(%0, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
auto trt = trt_results[0].reshape_as(jit_results[0]);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

0 comments on commit 353f2d2

Please sign in to comment.