feat(//core/lowering): Remove aten::contiguous
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed May 25, 2020
1 parent a93e783 commit 630b615
Showing 8 changed files with 33 additions and 6 deletions.
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
@@ -25,6 +25,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
passes::EliminateExceptionOrPassPattern(g);
torch::jit::FuseLinear(g);
torch::jit::LowerAllTuples(g);
passes::RemoveContiguous(g);
passes::RemoveDropout(g);
passes::FuseFlattenLinear(g);
passes::Conv2DToConvolution(g);
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
@@ -16,6 +16,7 @@ cc_library(
"conv2d_to_convolution.cpp",
"exception_elimination.cpp",
"fuse_flatten_linear.cpp",
"remove_contiguous.cpp",
"remove_dropout.cpp",
"unpack_addmm.cpp",
"unpack_batch_norm.cpp",
1 change: 0 additions & 1 deletion core/lowering/passes/fuse_flatten_linear.cpp
@@ -1,4 +1,3 @@
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -9,6 +9,7 @@ namespace passes {

void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
30 changes: 30 additions & 0 deletions core/lowering/passes/remove_contiguous.cpp
@@ -0,0 +1,30 @@
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {

void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph) {
std::string contiguous_pattern = R"IR(
graph(%input, %1):
%2 = aten::contiguous(%input, %1)
return (%2))IR";
std::string no_contiguous_pattern = R"IR(
graph(%input, %1):
return (%input))IR";

// Replace aten::contiguous(%input, %memory_format) with %input
torch::jit::SubgraphRewriter remove_contiguous;
remove_contiguous.RegisterRewritePattern(
contiguous_pattern, no_contiguous_pattern);
remove_contiguous.runOnGraph(graph);
LOG_GRAPH("Post remove contiguous: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
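
For context, here is a minimal standalone sketch (not part of the commit) of what the new pass does to a graph: it parses a toy TorchScript IR graph containing aten::contiguous, runs RemoveContiguous over it, and prints the result. It assumes a binary linked against libtorch and the TRTorch core/lowering target, and that torch/csrc/jit/ir/irparser.h is the parseIR header for the libtorch version in use.

#include <iostream>
#include <memory>
#include <string>

#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

#include "core/lowering/passes/passes.h"

int main() {
  // Toy graph: a single aten::contiguous call fed by a constant
  // memory_format (0 == MemoryFormat::Contiguous).
  const std::string ir = R"IR(
    graph(%input : Tensor):
      %mf : int = prim::Constant[value=0]()
      %out : Tensor = aten::contiguous(%input, %mf)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, g.get());

  // The rewriter redirects all uses of %out to %input and destroys the
  // aten::contiguous node (the dead prim::Constant may linger until DCE),
  // so the graph effectively becomes: graph(%input): return (%input)
  trtorch::core::lowering::passes::RemoveContiguous(g);
  std::cout << *g << std::endl;
  return 0;
}

Dropping the op this way is sound for conversion because aten::contiguous only changes a tensor's memory layout, never its values, and TensorRT chooses its own layouts when building an engine.
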
1 change: 0 additions & 1 deletion core/lowering/passes/remove_dropout.cpp
@@ -1,4 +1,3 @@
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include "core/util/prelude.h"
1 change: 0 additions & 1 deletion core/lowering/passes/unpack_addmm.cpp
@@ -1,4 +1,3 @@
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"
3 changes: 0 additions & 3 deletions core/lowering/passes/unpack_log_softmax.cpp
@@ -1,4 +1,3 @@
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"
@@ -14,8 +13,6 @@ void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) {
// https://github.com/onnx/onnx-tensorrt/blob/5dca8737851118f6ab8a33ea1f7bcb7c9f06caf5/builtin_op_importers.cpp#L1593
// Should the reshapes be added here or in the converter?

// TODO: In the future this should be removed in favor of a dedicated log_softmax converter (more efficient),
// but it's easier to stand up a working system if the number of op converters is lower
std::string logsoftmax_pattern = R"IR(
graph(%input, %dim, %dtype):
%log_softmax = aten::log_softmax(%input, %dim, %dtype)
