Skip to content

Commit

Permalink
feat(//core/conversion/converters/impl): Non dimensional reduce
Browse files Browse the repository at this point in the history
converter

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Mar 27, 2020
1 parent de8659b commit ccab7b9
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 8 deletions.
80 changes: 78 additions & 2 deletions core/conversion/converters/impl/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,23 @@ namespace impl {
namespace {
auto reduced_registrations = RegisterNodeConversionPatterns()
.pattern({
// Converter for aten::mean with no explicit dim list: performs a global
// average over every dimension of the input, never keeping reduced dims.
"aten::mean(Tensor self, *, ScalarType? dtype=None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in_tensor = args[0].ITensor();
auto in_dims = util::toVec(in_tensor->getDimensions());
// The optional dtype argument is ignored; the output keeps the input's type.
LOG_WARNING("Mean Converter disregards dtype");

// Bitmask with one bit set per input dimension so the reduce layer
// averages across all axes. Shift in 64-bit first so a rank-32 input
// cannot overflow, then narrow. (static_cast over C-style casts.)
uint32_t axis_mask = static_cast<uint32_t>((static_cast<uint64_t>(1) << in_dims.size()) - 1);

// keepDimensions=false: the reduced output drops all axes.
auto mean_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kAVG, axis_mask, false);

TRTORCH_CHECK(mean_layer, "Unable to create mean layer from node: " << *n);

mean_layer->setName(util::node_info(n).c_str());
ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0));
return true;
}
}).pattern({
"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in_tensor = args[0].ITensor();
Expand All @@ -23,7 +40,7 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
TRTORCH_CHECK(mean_layer, "Unable to create mean layer from node: " << *n);

mean_layer->setName(util::node_info(n).c_str());
associate_value_and_tensor(ctx, n->outputs()[0], mean_layer->getOutput(0));
ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0));
return true;
}
});
Expand All @@ -32,5 +49,64 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
} // namespace trtorch

// #include "core/util/prelude.h"
// #include "core/conversion/converters/converters.h"

// namespace trtorch {
// namespace core {
// namespace conversion {
// namespace converters {
// namespace impl {
// namespace {

// #define convert(unary, trt_type) \
// auto unary##_registrations TRTORCH_UNUSED = \
// RegisterNodeConversionPatterns().pattern( \
// {"aten::" #unary "(Tensor self) -> Tensor", \
// [](ConversionCtx *ctx, const torch::jit::Node *n, \
// args &args) -> bool { \
// auto in = args[0].ITensor(); \
// auto unary = \
// ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type); \
// \
// TRTORCH_CHECK( \
// unary, \
// "Unable to create " #unary " layer from node: " << *n); \
// \
// unary->setName(util::node_info(n).c_str()); \
// auto out_tensor = ctx->AssociateValueAndTensor( \
// n->outputs()[0], \
// unary->getOutput(0)); \
// LOG_DEBUG( \
// "Output tensor shape: " << out_tensor->getDimensions()); \
// \
// return true; \
// }});

// convert(cos, kCOS);
// convert(acos, kACOS);
// convert(cosh, kCOSH);
// convert(sin, kSIN);
// convert(asin, kASIN);
// convert(sinh, kSINH);
// convert(tan, kTAN);
// convert(atan, kATAN);
// convert(abs, kABS);
// convert(floor, kFLOOR);
// convert(reciprocal, kRECIP);
// convert(log, kLOG);
// convert(ceil, kCEIL);
// convert(sqrt, kSQRT);
// convert(exp, kEXP);
// convert(neg, kNEG);

// #undef convert

// } // namespace
// } // namespace impl
// } // namespace converters
// } // namespace conversion
// } // namespace core
// } // namespace trtorch
54 changes: 48 additions & 6 deletions tests/core/converters/test_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,55 @@
#include "core/compiler.h"

TEST(Converters, ATenMeanConvertsCorrectly) {
  // Global mean (no dim argument) on a 2D input: JIT and TRT paths
  // must produce (almost) identical scalars.
  const auto graph = R"IR(
graph(%0 : Tensor):
%4 : None = prim::Constant()
%5 : Tensor = aten::mean(%0, %4)
return (%5))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::script::parseIR(graph, parsed_g.get());

  auto input = at::randint(-5, 5, {4, 4}, at::kCUDA);
  auto named_params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {});
  auto jit_out = trtorch::tests::util::RunGraph(parsed_g, named_params, {input});

  // Re-clone the input and rebuild the params so the TRT run starts fresh.
  input = at::clone(input);
  named_params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {});
  auto trt_out = trtorch::tests::util::RunGraphEngine(parsed_g, named_params, {input});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0], trt_out[0]));
}

TEST(Converters, ATenMeanHigherDimensionConvertsCorrectly) {
  // Same global-mean graph as the 2D test, but on a rank-4 input to
  // exercise the all-axes reduction mask on higher-rank tensors.
  const auto graph = R"IR(
graph(%0 : Tensor):
%4 : None = prim::Constant()
%5 : Tensor = aten::mean(%0, %4)
return (%5))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::script::parseIR(graph, parsed_g.get());

  auto input = at::randint(-5, 5, {4, 4, 4, 4}, at::kCUDA);
  auto named_params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {});
  auto jit_out = trtorch::tests::util::RunGraph(parsed_g, named_params, {input});

  // Re-clone the input and rebuild the params so the TRT run starts fresh.
  input = at::clone(input);
  named_params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {});
  auto trt_out = trtorch::tests::util::RunGraphEngine(parsed_g, named_params, {input});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0], trt_out[0]));
}

TEST(Converters, ATenMeanRowConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=1]()
%2 : int[] = prim::ListConstruct(%1)
%2 : int[] = prim::ListConstruct(%1)
%3 : bool = prim::Constant[value=0]()
%4 : None = prim::Constant()
%5 : Tensor = aten::mean(%0, %2, %3, %4)
%5 : Tensor = aten::mean(%0, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();
Expand All @@ -24,18 +66,18 @@ TEST(Converters, ATenMeanConvertsCorrectly) {
in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
}

TEST(Converters, ATenMeanKeepDimsConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=1]()
%2 : int[] = prim::ListConstruct(%1)
%2 : int[] = prim::ListConstruct(%1)
%3 : bool = prim::Constant[value=1]()
%4 : None = prim::Constant()
%5 : Tensor = aten::mean(%0, %2, %3, %4)
%5 : Tensor = aten::mean(%0, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();
Expand All @@ -48,6 +90,6 @@ TEST(Converters, ATenMeanKeepDimsConvertsCorrectly) {
in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
}

0 comments on commit ccab7b9

Please sign in to comment.