From 723ac1d812f1f2ed98424fd0cb3e2630eef47d31 Mon Sep 17 00:00:00 2001 From: Abhiram Iyer Date: Tue, 4 Aug 2020 19:22:25 -0700 Subject: [PATCH] fix(): added some fixes, trt/jit output still mismatches Signed-off-by: Abhiram Iyer Signed-off-by: Abhiram Iyer --- core/conversion/converters/impl/lstm_cell.cpp | 9 +++-- core/lowering/lowering.cpp | 3 +- core/lowering/passes/BUILD | 1 + .../lowering/passes/conv3d_to_convolution.cpp | 33 +++++++++++++++++++ core/lowering/passes/passes.h | 1 + 5 files changed, 43 insertions(+), 4 deletions(-) create mode 100755 core/lowering/passes/conv3d_to_convolution.cpp diff --git a/core/conversion/converters/impl/lstm_cell.cpp b/core/conversion/converters/impl/lstm_cell.cpp index f83fff475a..f8ae878e20 100755 --- a/core/conversion/converters/impl/lstm_cell.cpp +++ b/core/conversion/converters/impl/lstm_cell.cpp @@ -84,7 +84,7 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() auto out2 = (args[5].isIValue() && args[5].IValue()->isNone()) ? 
mm2_out : add_bias(mm2_out, args[5].ITensorOrFreeze(ctx), "b_hh", ctx, n); - // gates + // get all 4 gates auto add = ctx->net->addElementWise(*out1, *out2, nvinfer1::ElementWiseOperation::kSUM); TRTORCH_CHECK(add, "Unable to create ElementWise layer from node: " << *n); auto add_out = add->getOutput(0); @@ -135,14 +135,17 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() TRTORCH_CHECK(in_cell, "Unable to create ElementWise layer from node: " << *n); auto cy = ctx->net->addElementWise(*forget_cx->getOutput(0), *in_cell->getOutput(0), nvinfer1::ElementWiseOperation::kSUM); TRTORCH_CHECK(cy, "Unable to create ElementWise layer from node: " << *n); - auto cy_out = ctx->AssociateValueAndTensor(n->outputs()[1], cy->getOutput(0)); + auto cy_out = cy->getOutput(0); // compute hy auto cy_tanh = ctx->net->addActivation(*cy_out, nvinfer1::ActivationType::kTANH); TRTORCH_CHECK(cy_tanh, "Unable to create tanh activation layer from node: " << *n); auto hy = ctx->net->addElementWise(*outgate, *cy_tanh->getOutput(0), nvinfer1::ElementWiseOperation::kPROD); TRTORCH_CHECK(hy, "Unable to create ElementWise layer from node: " << *n); - auto hy_out = ctx->AssociateValueAndTensor(n->outputs()[0], hy->getOutput(0)); + auto hy_out = hy->getOutput(0); + + ctx->AssociateValueAndTensor(n->outputs()[0], hy_out); + ctx->AssociateValueAndTensor(n->outputs()[1], cy_out); LOG_DEBUG("Output tensor [hy] shape: " << hy_out->getDimensions()); LOG_DEBUG("Output tensor [cy] shape: " << cy_out->getDimensions()); diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index eea21a265b..4b7a1db0c8 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -32,9 +32,10 @@ void LowerGraph(std::shared_ptr& g) { passes::RemoveDropout(g); passes::FuseFlattenLinear(g); passes::Conv2DToConvolution(g); + passes::Conv3DToConvolution(g); passes::FuseAddMMBranches(g); torch::jit::EliminateCommonSubexpression(g); - torch::jit::UnrollLoops(g); + 
//torch::jit::UnrollLoops(g); torch::jit::EliminateCommonSubexpression(g); passes::UnpackAddMM(g); //passes::UnpackBatchNorm(g); diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD index 67cebf2502..be6f3fcf42 100644 --- a/core/lowering/passes/BUILD +++ b/core/lowering/passes/BUILD @@ -14,6 +14,7 @@ cc_library( ], srcs = [ "conv2d_to_convolution.cpp", + "conv3d_to_convolution.cpp", "exception_elimination.cpp", "fuse_addmm_branches.cpp", "fuse_flatten_linear.cpp", diff --git a/core/lowering/passes/conv3d_to_convolution.cpp b/core/lowering/passes/conv3d_to_convolution.cpp new file mode 100755 index 0000000000..62df37896a --- /dev/null +++ b/core/lowering/passes/conv3d_to_convolution.cpp @@ -0,0 +1,33 @@ +#include <torch/csrc/jit/passes/subgraph_rewrite.h> + +#include "core/util/prelude.h" + +namespace trtorch { +namespace core { +namespace lowering { +namespace passes { + +void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) { + std::string conv3d_pattern = R"IR( + graph(%x, %w, %b, %s, %p, %d, %g): + %4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g) + return (%4))IR"; + std::string convolution_pattern = R"IR( + graph(%x, %w, %b, %s, %p, %d, %g): + %1 : bool = prim::Constant[value=0]() + %2 : int[] = prim::Constant[value=[0, 0, 0]]() + %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1) + return (%4))IR"; + + // rewrite aten::conv3d into the aten::_convolution form the conversion phase expects + torch::jit::SubgraphRewriter map_conv3d_to_convolution; + map_conv3d_to_convolution.RegisterRewritePattern( + conv3d_pattern, convolution_pattern); + map_conv3d_to_convolution.runOnGraph(graph); + LOG_GRAPH("Post map conv3d -> _convolution: " << *graph); +} + +} // namespace passes +} // namespace lowering +} // namespace core +} // namespace trtorch \ No newline at end of file diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h index 9fe21b918b..b3979c58e7 100644 --- a/core/lowering/passes/passes.h +++ b/core/lowering/passes/passes.h @@ -8,6 +8,7 @@ namespace lowering { namespace passes { void 
Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph); +void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph); void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph); void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph); void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);