Skip to content

Commit

Permalink
feat: replace view with reshape during lowering
Browse files Browse the repository at this point in the history
Signed-off-by: Ruoqian Guo <[email protected]>
  • Loading branch information
ruoqianguo committed Nov 25, 2021
1 parent 09afccb commit d39b918
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
passes::ReduceToOperation(g);
passes::ReduceGelu(g);
passes::RemoveContiguous(g);
passes::ViewToReshape(g);
passes::RemoveDropout(g);
passes::LinearToAddMM(g);
passes::Conv1DToConvolution(g);
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ cc_library(
"reduce_gelu.cpp",
"remove_bn_dim_check.cpp",
"remove_contiguous.cpp",
"view_to_reshape.cpp",
"remove_dropout.cpp",
"remove_nops.cpp",
"silu_to_sigmoid_multiplication.cpp",
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);
void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims);
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
Expand Down
31 changes: 31 additions & 0 deletions core/lowering/passes/view_to_reshape.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph) {
std::string view_pattern = R"IR(
graph(%x, %1):
%out : Tensor = aten::view(%x, %1)
return (%out))IR";

std::string reshape_pattern = R"IR(
graph(%x, %1):
%out : Tensor = aten::reshape(%x, %1)
return (%out))IR";

// replace aten::view with aten::reshape
torch::jit::SubgraphRewriter map_view_to_reshape;
map_view_to_reshape.RegisterRewritePattern(view_pattern, reshape_pattern);
map_view_to_reshape.runOnGraph(graph);

LOG_GRAPH("Post lowering of aten::view -> " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
5 changes: 5 additions & 0 deletions tests/core/lowering/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ lowering_test(
name = "test_remove_detach_pass",
)

lowering_test(
name = "test_view_to_reshape_pass",
)

lowering_test(
name = "test_operator_aliasing_pass",
)
Expand All @@ -75,6 +79,7 @@ test_suite(
":test_operator_aliasing_pass",
":test_remove_contiguous_pass",
":test_remove_detach_pass",
":test_view_to_reshape_pass",
":test_remove_dropout_pass",
":test_reduce_to_pass",
":test_reduce_gelu",
Expand Down
35 changes: 35 additions & 0 deletions tests/core/lowering/test_view_to_reshape_pass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/ir/subgraph_matcher.h"

TEST(LoweringPasses, ViewToReshapeCorrectly) {
std::string source_graph = R"IR(
graph(%x : Tensor, %1, %1.1):
%0 : int = prim::Constant[value=0]()
%2 : Tensor = aten::permute(%x, %1)
%3 : Tensor = aten::contiguous(%2, %0)
%4 : Tensor = aten::view(%3, %1.1)
return (%4))IR";
std::string target_graph = R"IR(
graph(%x : Tensor, %1, %1.1):
%0 : int = prim::Constant[value=0]()
%2 : Tensor = aten::permute(%x, %1)
%4 : Tensor = aten::reshape(%2, %1.1)
return (%4))IR";

torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, &*sg);
torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);
torch_tensorrt::core::lowering::passes::ViewToReshape(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, &*tg);

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

0 comments on commit d39b918

Please sign in to comment.