Skip to content

Commit

Permalink
feat(aten::flatten): Adds a converter for aten flatten since MM is the
Browse files Browse the repository at this point in the history
preferred path now

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 1, 2020
1 parent 4acc3fd commit d945eb9
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 17 deletions.
54 changes: 37 additions & 17 deletions core/conversion/converters/impl/shuffle.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "core/conversion/converters/converters.h"

#include "torch/torch.h"

namespace trtorch {
namespace core {
namespace conversion {
Expand All @@ -8,23 +10,41 @@ namespace impl {
namespace {

static auto shuffle_registrations = RegisterNodeConversionPatterns()
.pattern({
"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto new_shape = util::toDimsPad(args[1].unwrapToIntList(), 2);

auto shuffle = ctx->net->addShuffle(*in);
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
shuffle->setReshapeDimensions(new_shape);
shuffle->setName(util::node_info(n).c_str());

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

return true;
}
});
.pattern({
"aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto start_dim = args[1].unwrapToInt();
auto end_dim = args[2].unwrapToInt();
auto in_shape = util::toVec(in->getDimensions());
auto out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes();

auto shuffle = ctx->net->addShuffle(*in);
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
shuffle->setReshapeDimensions(util::toDims(out_shape));
shuffle->setName(util::node_info(n).c_str());

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
}
}).pattern({
"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto new_shape = util::toDimsPad(args[1].unwrapToIntList(), 2);

auto shuffle = ctx->net->addShuffle(*in);
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
shuffle->setReshapeDimensions(new_shape);
shuffle->setName(util::node_info(n).c_str());

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

return true;
}
});
} // namespace
} // namespace impl
} // namespace converters
Expand Down
48 changes: 48 additions & 0 deletions tests/core/converters/test_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,54 @@
#include "tests/util/util.h"
#include "core/compiler.h"

// TODO: IR Parser doesnt work well with neg numbers
TEST(Converters, ATenFlattenConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=1]()
%3 : Tensor = aten::flatten(%0, %1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
auto trt = trt_results[0].reshape_as(jit_results[0]);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

// TODO: IR Parser doesnt work well with neg numbers
TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%3 : Tensor = aten::flatten(%0, %1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(0, 5, {2, 3, 3}, {at::kCUDA});
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
auto trt = trt_results[0].reshape_as(jit_results[0]);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenReshapeConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
Expand Down

0 comments on commit d945eb9

Please sign in to comment.