Skip to content

Commit

Permalink
fix cudnn 8.7+ bug on cudnnConvolutionBiasActivationForward (#55412)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome authored Jul 17, 2023
1 parent 4910f40 commit 6a2db61
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
"conv2d", "Input");

#if CUDNN_VERSION >= 8000
// NOTE(liuyuanle): cudnn [8.7, 8.9 now) version has bug when act is
// sigmoid/tanh. Ref to issue
// https://github.com/PaddlePaddle/Paddle/issues/50853
#if CUDNN_VERSION >= 8000 && CUDNN_VERSION < 8700
std::unordered_set<std::string> cudnn_act_set(
{"identity", "relu", "sigmoid", "tanh"});
#else
Expand All @@ -154,6 +157,7 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
patterns::ConvElementwiseadd2Act pattern(gpd.mutable_pattern(), pattern_name);
pattern(x, all_act_set);

int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
Expand Down Expand Up @@ -220,8 +224,10 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
elementwise_add_out,
elementwise_add_out_1,
act_op});
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}

} // namespace ir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
->assert_is_op_input("conv2d", "Input")
->AsInput();

#if CUDNN_VERSION >= 8000
// NOTE(liuyuanle): cudnn [8.7, 8.9 now) version has bug when act is
// sigmoid/tanh. Ref to issue
// https://github.com/PaddlePaddle/Paddle/issues/50853
#if CUDNN_VERSION >= 8000 && CUDNN_VERSION < 8700
std::unordered_set<std::string> cudnn_act_set(
{"identity", "relu", "sigmoid", "tanh"});
#else
Expand All @@ -175,6 +178,7 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
pattern(x, all_act_set);

int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
Expand Down Expand Up @@ -226,9 +230,11 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
GraphSafeRemoveNodes(
graph,
{conv_op, conv_out, elementwise_add_op, elementwise_add_out, act_op});
found_count++;
};

gpd(graph, handler);
AddStatis(found_count);
}

} // namespace ir
Expand Down

0 comments on commit 6a2db61

Please sign in to comment.