Skip to content

Commit

Permalink
fix: Implement a patch for gelu schema change in older NGC containers
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed Feb 2, 2022
1 parent d6694db commit 9ee3a04
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
34 changes: 33 additions & 1 deletion core/lowering/passes/reduce_gelu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@ namespace passes {

// Lowering pass that rewrites aten::gelu nodes in `graph` into a chain of pointwise ops
// (aten::mul / aten::add / aten::tanh) using torch::jit::SubgraphRewriter, then logs the
// resulting graph. `graph` is rewritten in place; nothing is returned.
//
// The replacement subgraph visible below computes
//   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
// where 0.79788456 is approximately sqrt(2 / pi).
void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
// Source pattern for the single-argument schema: aten::gelu(Tensor).
std::string gelu_pattern = R"IR(
graph(%x):
graph(%x : Tensor):
%out : Tensor = aten::gelu(%x)
return (%out))IR";

// Source pattern for the two-argument schema aten::gelu(Tensor, approx). This schema exists in the
// 21.11, 21.12 and 22.01 containers of pytorch, which ship an unmerged pytorch PR:
// https://github.com/pytorch/pytorch/pull/61439. It is reduced to the same pointwise form as regular gelu.
std::string gelu_approximate_pattern = R"IR(
graph(%x : Tensor, %approx):
%out : Tensor = aten::gelu(%x, %approx)
return (%out))IR";

// Replacement subgraph for gelu_pattern (single input).
// NOTE(review): only part of this pattern is visible in this listing; the comment on
// gelu_reduce_multi_input_pattern states the two are identical apart from %approx - confirm.
std::string gelu_reduce_pattern = R"IR(
graph(%x.1 : Tensor):
%6 : float = prim::Constant[value=0.044714999999999998]()
Expand All @@ -30,11 +37,36 @@ void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
%15 : Tensor = aten::mul(%7, %14)
return (%15))IR";

// Replacement subgraph for gelu_approximate_pattern. It is the same as gelu_reduce_pattern except
// for the additional input %approx, which is declared but never used in the body.
// SubgraphRewriter only works as expected when the source pattern (gelu_approximate_pattern)
// and its replacement (gelu_reduce_multi_input_pattern) take the same number of inputs.
std::string gelu_reduce_multi_input_pattern = R"IR(
graph(%x.1 : Tensor, %approx):
%6 : float = prim::Constant[value=0.044714999999999998]()
%5 : float = prim::Constant[value=0.79788456080000003]()
%4 : float = prim::Constant[value=1.]()
%3 : float = prim::Constant[value=0.5]()
%2 : int = prim::Constant[value=1]()
%7 : Tensor = aten::mul(%x.1, %3)
%8 : Tensor = aten::mul(%x.1, %5)
%9 : Tensor = aten::mul(%x.1, %6)
%10 : Tensor = aten::mul(%9, %x.1)
%11 : Tensor = aten::add(%10, %4, %2)
%12 : Tensor = aten::mul(%8, %11)
%13 : Tensor = aten::tanh(%12)
%14 : Tensor = aten::add(%13, %4, %2)
%15 : Tensor = aten::mul(%7, %14)
return (%15))IR";

// Replace single-argument aten::gelu with pointwise operations.
torch::jit::SubgraphRewriter map_gelu_to_pointwise_ops;
map_gelu_to_pointwise_ops.RegisterRewritePattern(gelu_pattern, gelu_reduce_pattern);
map_gelu_to_pointwise_ops.runOnGraph(graph);

// Replace two-argument aten::gelu(%x, %approx) with the same pointwise operations, using a
// separate rewriter instance that runs after the single-argument rewrite.
torch::jit::SubgraphRewriter map_gelu_approximate_to_pointwise_ops;
map_gelu_approximate_to_pointwise_ops.RegisterRewritePattern(gelu_approximate_pattern, gelu_reduce_multi_input_pattern);
map_gelu_approximate_to_pointwise_ops.runOnGraph(graph);

// Emit the graph as it looks after both rewrites.
LOG_GRAPH("Post lowering of [aten::gelu] -> " << *graph);
}

Expand Down
35 changes: 35 additions & 0 deletions tests/core/lowering/test_reduce_gelu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,38 @@ TEST(LoweringPasses, ReduceGeluCorrectly) {

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

// Verifies that ReduceGelu lowers the two-argument form aten::gelu(%x, %approx) into the
// pointwise mul/add/tanh subgraph: the expected IR, used as a pattern, must match somewhere
// in the graph produced by the pass.
TEST(LoweringPasses, ReduceGeluApproximateCorrectly) {
  namespace trt_logging = torch_tensorrt::core::util::logging;

  // Input IR: a single gelu call using the two-argument schema.
  const std::string input_ir = R"IR(
graph(%x, %approx):
%out : Tensor = aten::gelu(%x, %approx)
return (%out))IR";

  // IR expected after lowering; %approx stays in the signature but is unused in the body.
  const std::string expected_ir = R"IR(
graph(%x.1 : Tensor, %approx):
%6 : float = prim::Constant[value=0.044714999999999998]()
%5 : float = prim::Constant[value=0.79788456080000003]()
%4 : float = prim::Constant[value=1.]()
%3 : float = prim::Constant[value=0.5]()
%2 : int = prim::Constant[value=1]()
%7 : Tensor = aten::mul(%x.1, %3)
%8 : Tensor = aten::mul(%x.1, %5)
%9 : Tensor = aten::mul(%x.1, %6)
%10 : Tensor = aten::mul(%9, %x.1)
%11 : Tensor = aten::add(%10, %4, %2)
%12 : Tensor = aten::mul(%8, %11)
%13 : Tensor = aten::tanh(%12)
%14 : Tensor = aten::add(%13, %4, %2)
%15 : Tensor = aten::mul(%7, %14)
return (%15))IR";

  // Raise the reportable log level to kGRAPH before the pass runs.
  trt_logging::get_logger().set_reportable_log_level(trt_logging::LogLevel::kGRAPH);

  // Parse the input IR and run the lowering pass on it in place.
  auto lowered = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(input_ir, lowered.get());
  torch_tensorrt::core::lowering::passes::ReduceGelu(lowered);

  // Parse the expected IR and use it as the pattern to search for in the lowered graph.
  auto expected = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(expected_ir, expected.get());

  const auto matches = torch::jit::findPatternMatches(*expected, *lowered);
  ASSERT_TRUE(!matches.empty());
}

0 comments on commit 9ee3a04

Please sign in to comment.