Skip to content

Commit

Permalink
Canonicalize aten::multiply to aten::mul
Browse files Browse the repository at this point in the history
  • Loading branch information
mfeliz-cruise committed Dec 2, 2022
1 parent 2b1cedf commit 43a11c9
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
14 changes: 14 additions & 0 deletions core/lowering/passes/op_aliasing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph) {
rewrite_scatter.RegisterRewritePattern(scatter_sub_pattern, scatter_pattern);
rewrite_scatter.runOnGraph(graph);
LOG_GRAPH("Post map scatter_ -> scatter: " << *graph);

std::string multiply_pattern = R"IR(
graph(%self, %other):
%o : Tensor = aten::multiply(%self, %other)
return (%o))IR";
std::string mul_pattern = R"IR(
graph(%self, %other):
%o : Tensor = aten::mul(%self, %other)
return (%o))IR";

torch::jit::SubgraphRewriter rewrite_multiply;
rewrite_multiply.RegisterRewritePattern(multiply_pattern, mul_pattern);
rewrite_multiply.runOnGraph(graph);
LOG_GRAPH("Post map multiply -> mul: " << *graph);
}

} // namespace passes
Expand Down
20 changes: 20 additions & 0 deletions tests/core/lowering/test_operator_aliasing_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,23 @@ TEST(LoweringPasses, LoweringTrueDivideCorrectly) {

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, LoweringMultiplyCorrectly) {
std::string source_graph = R"IR(
graph(%s, %o):
%2 = aten::multiply(%s, %o)
return (%2))IR";
std::string target_graph = R"IR(
graph(%s, %o):
%2 = aten::mul(%s, %o)
return (%2))IR";

auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());
torch_tensorrt::core::lowering::passes::AliasOperators(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

0 comments on commit 43a11c9

Please sign in to comment.