Skip to content

Commit

Permalink
feat(//core/conversion/converters/impl/reduce): Mean reduce converter
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Mar 24, 2020
1 parent 79c909c commit 259aa4c
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/conversion/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cc_library(
"impl/element_wise.cpp",
"impl/linear.cpp",
"impl/pooling.cpp",
"impl/reduce.cpp",
"impl/softmax.cpp",
"impl/unary.cpp",
],
Expand Down
38 changes: 38 additions & 0 deletions core/conversion/converters/impl/reduce.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {
auto reduced_registrations = RegisterNodeConversionPatterns()
.pattern({
"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in_tensor = args[0].ITensor();
auto dim = args[1].unwrapToIntList();
auto keepdim = args[2].unwrapToBool();

uint32_t axis_mask = 1 << dim[0];

LOG_WARNING("Mean converter disregards dtype");
auto mean_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kAVG, axis_mask, keepdim);
mean_layer->setName(util::node_info(n).c_str());

auto out_value = n->outputs()[0];
auto out_tensor = mean_layer->getOutput(0);
out_tensor->setName(out_value->debugName().c_str());
ctx->value_tensor_map[out_value] = out_tensor;

return true;
}
});
} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch

5 changes: 5 additions & 0 deletions tests/core/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ converter_test(
name = "test_conv"
)

converter_test(
name = "test_reduce"
)

test_suite(
name = "test_converters",
tests = [
Expand All @@ -38,6 +42,7 @@ test_suite(
":test_linear",
":test_element_wise",
":test_conv",
":test_reduce"
]
)

Expand Down
53 changes: 53 additions & 0 deletions tests/core/converters/test_reduce.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include <string>
#include "gtest/gtest.h"
#include "torch/csrc/jit/irparser.h"
#include "tests/util/util.h"
#include "core/compiler.h"

TEST(Converters, ATenMeanConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=1]()
%2 : int[] = prim::ListConstruct(%1)
%3 : bool = prim::Constant[value=0]()
%4 : None = prim::Constant()
%5 : Tensor = aten::mean(%0, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::script::parseIR(graph, &*g);

auto in = at::randint(-5, 5, {4, 4}, at::kCUDA);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
}

TEST(Converters, ATenMeanKeepDimsConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=1]()
%2 : int[] = prim::ListConstruct(%1)
%3 : bool = prim::Constant[value=1]()
%4 : None = prim::Constant()
%5 : Tensor = aten::mean(%0, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::script::parseIR(graph, &*g);

auto in = at::randint(-5, 5, {4, 4}, at::kCUDA);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
}

0 comments on commit 259aa4c

Please sign in to comment.