fix(): added some fixes, trt/jit output still mismatches
Signed-off-by: Abhiram Iyer <[email protected]>

abhi-iyer authored and narendasan committed Aug 28, 2020
1 parent d7c3164 commit 723ac1d
Showing 5 changed files with 43 additions and 4 deletions.
9 changes: 6 additions & 3 deletions core/conversion/converters/impl/lstm_cell.cpp
@@ -84,7 +84,7 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()

auto out2 = (args[5].isIValue() && args[5].IValue()->isNone()) ? mm2_out : add_bias(mm2_out, args[5].ITensorOrFreeze(ctx), "b_hh", ctx, n);

- // gates
+ // get all 4 gates
auto add = ctx->net->addElementWise(*out1, *out2, nvinfer1::ElementWiseOperation::kSUM);
TRTORCH_CHECK(add, "Unable to create ElementWise layer from node: " << *n);
auto add_out = add->getOutput(0);
@@ -135,14 +135,17 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
TRTORCH_CHECK(in_cell, "Unable to create ElementWise layer from node: " << *n);
auto cy = ctx->net->addElementWise(*forget_cx->getOutput(0), *in_cell->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
TRTORCH_CHECK(cy, "Unable to create ElementWise layer from node: " << *n);
- auto cy_out = ctx->AssociateValueAndTensor(n->outputs()[1], cy->getOutput(0));
+ auto cy_out = cy->getOutput(0);

// compute hy
auto cy_tanh = ctx->net->addActivation(*cy_out, nvinfer1::ActivationType::kTANH);
TRTORCH_CHECK(cy_tanh, "Unable to create tanh activation layer from node: " << *n);
auto hy = ctx->net->addElementWise(*outgate, *cy_tanh->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
TRTORCH_CHECK(hy, "Unable to create ElementWise layer from node: " << *n);
- auto hy_out = ctx->AssociateValueAndTensor(n->outputs()[0], hy->getOutput(0));
+ auto hy_out = hy->getOutput(0);

+ ctx->AssociateValueAndTensor(n->outputs()[0], hy_out);
+ ctx->AssociateValueAndTensor(n->outputs()[1], cy_out);

LOG_DEBUG("Output tensor [hy] shape: " << hy_out->getDimensions());
LOG_DEBUG("Output tensor [cy] shape: " << cy_out->getDimensions());
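For context, the change above defers output binding until both TensorRT tensors exist: the converter now takes the raw ITensor from each layer and only afterwards associates each one with the corresponding JIT value. A condensed sketch of that pattern, restated from the hunk above purely for illustration (aten::lstm_cell returns the pair (hy, cy), so outputs()[0] is the hidden state and outputs()[1] the cell state):

// Condensed from the diff above; illustrative, not a drop-in replacement.
// Take the raw TensorRT tensors first...
auto cy_out = cy->getOutput(0);   // cy = forgetgate * cx + ingate * cellgate
auto hy_out = hy->getOutput(0);   // hy = outgate * tanh(cy)
// ...then bind them to the JIT node's two outputs in schema order (hy, cy).
ctx->AssociateValueAndTensor(n->outputs()[0], hy_out);
ctx->AssociateValueAndTensor(n->outputs()[1], cy_out);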
3 changes: 2 additions & 1 deletion core/lowering/lowering.cpp
@@ -32,9 +32,10 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
passes::RemoveDropout(g);
passes::FuseFlattenLinear(g);
passes::Conv2DToConvolution(g);
+ passes::Conv3DToConvolution(g);
passes::FuseAddMMBranches(g);
torch::jit::EliminateCommonSubexpression(g);
- torch::jit::UnrollLoops(g);
+ //torch::jit::UnrollLoops(g);
torch::jit::EliminateCommonSubexpression(g);
passes::UnpackAddMM(g);
//passes::UnpackBatchNorm(g);
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
@@ -14,6 +14,7 @@ cc_library(
],
srcs = [
"conv2d_to_convolution.cpp",
"conv3d_to_convolution.cpp",
"exception_elimination.cpp",
"fuse_addmm_branches.cpp",
"fuse_flatten_linear.cpp",
33 changes: 33 additions & 0 deletions core/lowering/passes/conv3d_to_convolution.cpp
@@ -0,0 +1,33 @@
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {

void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
std::string conv3d_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
return (%4))IR";
std::string convolution_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%1 : bool = prim::Constant[value=0]()
%2 : int[] = prim::Constant[value=[0, 0]]()
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1)
return (%4))IR";;

// replace aten::conv3d with aten::_convolution
torch::jit::SubgraphRewriter map_conv3d_to_convolution;
map_conv3d_to_convolution.RegisterRewritePattern(
conv3d_pattern, convolution_pattern);
map_conv3d_to_convolution.runOnGraph(graph);
LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
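The new pass is a thin wrapper around torch::jit::SubgraphRewriter: it rewrites every aten::conv3d call into an equivalent aten::_convolution call, mirroring the existing Conv2DToConvolution pass. A minimal standalone sketch of exercising the pass (not part of this commit; the irparser header path and the main() harness are assumptions for illustration):

#include <memory>
#include <string>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

#include "core/lowering/passes/passes.h"

int main() {
  // A small graph containing a single aten::conv3d node.
  const std::string src = R"IR(
    graph(%x : Tensor, %w : Tensor, %b : Tensor,
          %s : int[], %p : int[], %d : int[], %g : int):
      %out : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, g.get());

  // After the pass, the aten::conv3d node should appear as aten::_convolution.
  trtorch::core::lowering::passes::Conv3DToConvolution(g);
  g->dump();

  return 0;
}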
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -8,6 +8,7 @@ namespace lowering {
namespace passes {

void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
+ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
