Skip to content

Commit

Permalink
feat(//core/conversion/converters): LSTMCell converter
Browse files Browse the repository at this point in the history
Signed-off-by: Abhiram Iyer <[email protected]>

Signed-off-by: Abhiram Iyer <[email protected]>
  • Loading branch information
abhi-iyer authored and narendasan committed Aug 28, 2020
1 parent a3e1093 commit 8c61248
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 5 deletions.
11 changes: 7 additions & 4 deletions core/conversion/converters/impl/lstm_cell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ nvinfer1::ITensor* add_bias(nvinfer1::ITensor* a, nvinfer1::ITensor* b, std::str
auto shuffle = ctx->net->addShuffle(*b);
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
shuffle->setReshapeDimensions(util::toDimsPad(util::toVec(b_dim), a_dim.nbDims));

b = shuffle->getOutput(0);
}

LOG_DEBUG(b_name << "'s shape: " << b->getDimensions());

auto add = ctx->net->addElementWise(*a, *b, nvinfer1::ElementWiseOperation::kSUM);
TRTORCH_CHECK(add, "Unable to create ElementWise layer from node: " << *n);

Expand Down Expand Up @@ -72,14 +75,14 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
TRTORCH_CHECK(mm1, "Unable to create matrix multiplication node: " << *n);
auto mm1_out = mm1->getOutput(0);

auto out1 = !args[4].IValue()->isNone() ? add_bias(mm1_out, args[4].ITensorOrFreeze(ctx), "b_ih", ctx, n) : mm1_out;
auto out1 = (args[4].isIValue() && args[4].IValue()->isNone()) ? mm1_out : add_bias(mm1_out, args[4].ITensorOrFreeze(ctx), "b_ih", ctx, n);

// calculate second half of gates
auto mm2 = ctx->net->addMatrixMultiply(*state[0], nvinfer1::MatrixOperation::kNONE, *w_hh, nvinfer1::MatrixOperation::kTRANSPOSE);
TRTORCH_CHECK(mm2, "Unable to create matrix multiplication node: " << *n);
auto mm2_out = mm2->getOutput(0);

auto out2 = !args[5].IValue()->isNone() ? add_bias(mm2_out, args[5].ITensorOrFreeze(ctx), "b_hh", ctx, n) : mm2_out;
auto out2 = (args[5].isIValue() && args[5].IValue()->isNone()) ? mm2_out : add_bias(mm2_out, args[5].ITensorOrFreeze(ctx), "b_hh", ctx, n);

// gates
auto add = ctx->net->addElementWise(*out1, *out2, nvinfer1::ElementWiseOperation::kSUM);
Expand Down Expand Up @@ -130,7 +133,7 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
TRTORCH_CHECK(forget_cx, "Unable to create ElementWise layer from node: " << *n);
auto in_cell = ctx->net->addElementWise(*ingate, *cellgate, nvinfer1::ElementWiseOperation::kPROD);
TRTORCH_CHECK(in_cell, "Unable to create ElementWise layer from node: " << *n);
auto cy = ctx->net->addElementWise(*forget_cx->getOutput(0), *in_cell->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
auto cy = ctx->net->addElementWise(*forget_cx->getOutput(0), *in_cell->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
TRTORCH_CHECK(cy, "Unable to create ElementWise layer from node: " << *n);
auto cy_out = ctx->AssociateValueAndTensor(n->outputs()[1], cy->getOutput(0));

Expand All @@ -143,7 +146,7 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()

LOG_DEBUG("Output tensor [hy] shape: " << hy_out->getDimensions());
LOG_DEBUG("Output tensor [cy] shape: " << cy_out->getDimensions());

return true;
}
});
Expand Down
7 changes: 6 additions & 1 deletion tests/core/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ converter_test(
name = "test_stack"
)

converter_test(
name = "test_lstm_cell"
)

test_suite(
name = "test_converters",
tests = [
Expand All @@ -83,6 +87,7 @@ test_suite(
":test_unary",
":test_interpolate",
":test_select",
":test_stack"
":test_stack",
":test_lstm_cell"
]
)
97 changes: 97 additions & 0 deletions tests/core/converters/test_lstm_cell.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#include <string>
#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "tests/util/util.h"
#include "core/compiler.h"

TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBias) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor,
%2 : Tensor,
%3 : Tensor,
%4 : Tensor,
%5 : Tensor,
%6 : Tensor):
%7 : Tensor[] = prim::ListConstruct(%1, %2)
%8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
return (%8))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto input = at::randn({50, 10}, {at::kCUDA});
auto h0 = at::randn({50, 20}, {at::kCUDA});
auto c0 = at::randn({50, 20}, {at::kCUDA});
auto w_ih = at::randn({4*20, 10}, {at::kCUDA});
auto w_hh = at::randn({4*20, 20}, {at::kCUDA});
auto b_ih = at::randn({4*20}, {at::kCUDA});
auto b_hh = at::randn({4*20}, {at::kCUDA});

auto jit_input = at::clone(input);
auto jit_h0 = at::clone(h0);
auto jit_c0 = at::clone(c0);
auto jit_w_ih = at::clone(w_ih);
auto jit_w_hh = at::clone(w_hh);
auto jit_b_ih = at::clone(b_ih);
auto jit_b_hh = at::clone(b_hh);

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh, jit_b_ih, jit_b_hh});

auto trt_input = at::clone(input);
auto trt_h0 = at::clone(h0);
auto trt_c0 = at::clone(c0);
auto trt_w_ih = at::clone(w_ih);
auto trt_w_hh = at::clone(w_hh);
auto trt_b_ih = at::clone(b_ih);
auto trt_b_hh = at::clone(b_hh);

params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh, trt_b_ih, trt_b_hh});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBias) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor,
%2 : Tensor,
%3 : Tensor,
%4 : Tensor):
%5 : None = prim::Constant()
%6 : None = prim::Constant()
%7 : Tensor[] = prim::ListConstruct(%1, %2)
%8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
return (%8))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto input = at::randn({50, 10}, {at::kCUDA});
auto h0 = at::randn({50, 20}, {at::kCUDA});
auto c0 = at::randn({50, 20}, {at::kCUDA});
auto w_ih = at::randn({4*20, 10}, {at::kCUDA});
auto w_hh = at::randn({4*20, 20}, {at::kCUDA});

auto jit_input = at::clone(input);
auto jit_h0 = at::clone(h0);
auto jit_c0 = at::clone(c0);
auto jit_w_ih = at::clone(w_ih);
auto jit_w_hh = at::clone(w_hh);

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh});

auto trt_input = at::clone(input);
auto trt_h0 = at::clone(h0);
auto trt_c0 = at::clone(c0);
auto trt_w_ih = at::clone(w_ih);
auto trt_w_hh = at::clone(w_hh);

params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

0 comments on commit 8c61248

Please sign in to comment.