Skip to content

Commit

Permalink
test(aten::stack): Added test for aten::stack
Browse files Browse the repository at this point in the history
Signed-off-by: Abhiram Iyer <[email protected]>

Signed-off-by: Abhiram Iyer <[email protected]>
  • Loading branch information
abhi-iyer committed Jun 26, 2020
1 parent 6659b44 commit 97c8f52
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tests/core/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ converter_test(
name = "test_select"
)

converter_test(
name = "test_stack"
)

test_suite(
name = "test_converters",
tests = [
Expand All @@ -78,7 +82,8 @@ test_suite(
":test_softmax",
":test_unary",
":test_interpolate",
":test_select"
":test_select",
":test_stack"
]
)

Expand Down
53 changes: 53 additions & 0 deletions tests/core/converters/test_stack.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include <string>
#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "tests/util/util.h"
#include "core/compiler.h"

TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor):
%2 : Tensor[] = prim::ListConstruct(%0, %1)
%3 : int = prim::Constant[value=3]()
%4 : Tensor = aten::stack(%2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2});

params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Float(4, 4, 4)):
%2 : Tensor[] = prim::ListConstruct(%0, %1)
%3 : int = prim::Constant[value=1]()
%4 : Tensor = aten::stack(%2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});

params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

0 comments on commit 97c8f52

Please sign in to comment.