From d33ec82cfe797e736e3e1070442b3f0b10cf8e15 Mon Sep 17 00:00:00 2001 From: Abhiram Iyer Date: Thu, 25 Jun 2020 16:56:52 -0700 Subject: [PATCH] test(aten::select.int): added test for aten::select.int Signed-off-by: Abhiram Iyer Signed-off-by: Abhiram Iyer --- tests/core/converters/BUILD | 5 ++++ tests/core/converters/test_select.cpp | 33 +++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100755 tests/core/converters/test_select.cpp diff --git a/tests/core/converters/BUILD b/tests/core/converters/BUILD index 1d81eeb552..6de31757a4 100644 --- a/tests/core/converters/BUILD +++ b/tests/core/converters/BUILD @@ -59,6 +59,10 @@ converter_test( name = "test_interpolate" ) +converter_test( + name = "test_select" +) + test_suite( name = "test_converters", tests = [ @@ -74,6 +78,7 @@ test_suite( ":test_softmax", ":test_unary", ":test_interpolate", + ":test_select" ] ) diff --git a/tests/core/converters/test_select.cpp b/tests/core/converters/test_select.cpp new file mode 100755 index 0000000000..652a45d6ed --- /dev/null +++ b/tests/core/converters/test_select.cpp @@ -0,0 +1,33 @@ +#include +#include "gtest/gtest.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "tests/util/util.h" +#include "core/compiler.h" + +TEST(Converters, ATenSelectIntTwiceConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=0]() + %3 : int = prim::Constant[value=3]() + %4 : Tensor = aten::select(%0, %2, %2) + %5 : Tensor = aten::select(%4, %2, %3) + return (%5))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, &*g); + + auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); + + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} \ No newline at end of file