Skip to content

Commit

Permalink
fix(): added test cases to explicitly check hidden/cell state outputs
Browse files Browse the repository at this point in the history
Signed-off-by: Abhiram Iyer <[email protected]>

Signed-off-by: Abhiram Iyer <[email protected]>
  • Loading branch information
abhi-iyer authored and narendasan committed Aug 28, 2020
1 parent 8c61248 commit d7c3164
Showing 1 changed file with 94 additions and 2 deletions.
96 changes: 94 additions & 2 deletions tests/core/converters/test_lstm_cell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "tests/util/util.h"
#include "core/compiler.h"

TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBias) {
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBiasCheckHidden) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor,
Expand Down Expand Up @@ -53,7 +53,56 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBias) {
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBias) {
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBiasCheckCell) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor,
%2 : Tensor,
%3 : Tensor,
%4 : Tensor,
%5 : Tensor,
%6 : Tensor):
%7 : Tensor[] = prim::ListConstruct(%1, %2)
%8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
return (%9))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto input = at::randn({50, 10}, {at::kCUDA});
auto h0 = at::randn({50, 20}, {at::kCUDA});
auto c0 = at::randn({50, 20}, {at::kCUDA});
auto w_ih = at::randn({4*20, 10}, {at::kCUDA});
auto w_hh = at::randn({4*20, 20}, {at::kCUDA});
auto b_ih = at::randn({4*20}, {at::kCUDA});
auto b_hh = at::randn({4*20}, {at::kCUDA});

auto jit_input = at::clone(input);
auto jit_h0 = at::clone(h0);
auto jit_c0 = at::clone(c0);
auto jit_w_ih = at::clone(w_ih);
auto jit_w_hh = at::clone(w_hh);
auto jit_b_ih = at::clone(b_ih);
auto jit_b_hh = at::clone(b_hh);

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh, jit_b_ih, jit_b_hh});

auto trt_input = at::clone(input);
auto trt_h0 = at::clone(h0);
auto trt_c0 = at::clone(c0);
auto trt_w_ih = at::clone(w_ih);
auto trt_w_hh = at::clone(w_hh);
auto trt_b_ih = at::clone(b_ih);
auto trt_b_hh = at::clone(b_hh);

params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh, trt_b_ih, trt_b_hh});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckHidden) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor,
Expand Down Expand Up @@ -93,5 +142,48 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBias) {
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckCell) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor,
%2 : Tensor,
%3 : Tensor,
%4 : Tensor):
%5 : None = prim::Constant()
%6 : None = prim::Constant()
%7 : Tensor[] = prim::ListConstruct(%1, %2)
%8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
return (%9))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto input = at::randn({50, 10}, {at::kCUDA});
auto h0 = at::randn({50, 20}, {at::kCUDA});
auto c0 = at::randn({50, 20}, {at::kCUDA});
auto w_ih = at::randn({4*20, 10}, {at::kCUDA});
auto w_hh = at::randn({4*20, 20}, {at::kCUDA});

auto jit_input = at::clone(input);
auto jit_h0 = at::clone(h0);
auto jit_c0 = at::clone(c0);
auto jit_w_ih = at::clone(w_ih);
auto jit_w_hh = at::clone(w_hh);

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh});

auto trt_input = at::clone(input);
auto trt_h0 = at::clone(h0);
auto trt_c0 = at::clone(c0);
auto trt_w_ih = at::clone(w_ih);
auto trt_w_hh = at::clone(w_hh);

params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

0 comments on commit d7c3164

Please sign in to comment.