Skip to content

Commit

Permalink
feat(): finished logic for LSTM cell, now to test
Browse files Browse the repository at this point in the history
Signed-off-by: Abhiram Iyer <[email protected]>

Signed-off-by: Abhiram Iyer <[email protected]>
  • Loading branch information
abhi-iyer authored and narendasan committed Aug 28, 2020
1 parent 546d790 commit a88cfaf
Showing 1 changed file with 50 additions and 9 deletions.
59 changes: 50 additions & 9 deletions core/conversion/converters/impl/lstm_cell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,57 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
TRTORCH_CHECK(add3, "Unable to create ElementWise layer from node: " << *n);
auto add3_out = add3->getOutput(0);





auto mm_layer = ctx->net->addMatrixMultiply(*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
mm_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
// chunk Tensor into 4 parts and apply activation functions
auto dims = util::toVec(add3_out->getDimensions());
auto batch = dims[0];
auto hidden = dims[1]/4;

auto size = util::toDims(std::vector<int64_t>({batch, hidden}));
auto stride = util::toDims(std::vector<int64_t>({1, 1}));

auto slice1 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, 0})), size, stride);
TRTORCH_CHECK(slice1, "Unable to create Slice layer from node: " << *n);
auto activ1 = ctx->net->addActivation(*slice1->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
TRTORCH_CHECK(activ1, "Unable to create sigmoid activation layer from node: " << *n);
auto ingate = activ1->getOutput(0);

auto slice2 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, hidden})), size, stride);
TRTORCH_CHECK(slice2, "Unable to create Slice layer from node: " << *n);
auto activ2 = ctx->net->addActivation(*slice2->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
TRTORCH_CHECK(activ2, "Unable to create sigmoid activation layer from node: " << *n);
auto forgetgate = activ2->getOutput(0);

auto slice3 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, 2*hidden})), size, stride);
TRTORCH_CHECK(slice3, "Unable to create Slice layer from node: " << *n);
auto activ3 = ctx->net->addActivation(*slice3->getOutput(0), nvinfer1::ActivationType::kTANH);
TRTORCH_CHECK(activ3, "Unable to create tanh activation layer from node: " << *n);
auto cellgate = activ3->getOutput(0);

auto slice4 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, 3*hidden})), size, stride);
TRTORCH_CHECK(slice4, "Unable to create Slice layer from node: " << *n);
auto activ4 = ctx->net->addActivation(*slice4->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
TRTORCH_CHECK(activ4, "Unable to create sigmoid activation layer from node: " << *n);
auto outgate = activ4->getOutput(0);

// compute cy
auto forget_cx = ctx->net->addElementWise(*forgetgate, *state[1], nvinfer1::ElementWiseOperation::kPROD);
TRTORCH_CHECK(forget_cx, "Unable to create ElementWise layer from node: " << *n);
auto in_cell = ctx->net->addElementWise(*ingate, *cellgate, nvinfer1::ElementWiseOperation::kPROD);
TRTORCH_CHECK(in_cell, "Unable to create ElementWise layer from node: " << *n);
auto cy = ctx->net->addElementWise(*forget_cx->getOutput(0), *in_cell->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
TRTORCH_CHECK(cy, "Unable to create ElementWise layer from node: " << *n);
auto cy_out = ctx->AssociateValueAndTensor(n->outputs()[1], cy->getOutput(0));

// compute hy
auto cy_tanh = ctx->net->addActivation(*cy_out, nvinfer1::ActivationType::kTANH);
TRTORCH_CHECK(cy_tanh, "Unable to create tanh activation layer from node: " << *n);
auto hy = ctx->net->addElementWise(*outgate, *cy_tanh->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
TRTORCH_CHECK(hy, "Unable to create ElementWise layer from node: " << *n);
auto hy_out = ctx->AssociateValueAndTensor(n->outputs()[0], hy->getOutput(0));

LOG_DEBUG("Output tensor [hy] shape: " << hy_out->getDimensions());
LOG_DEBUG("Output tensor [cy] shape: " << cy_out->getDimensions());

LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
}
});
Expand Down

0 comments on commit a88cfaf

Please sign in to comment.