Skip to content

Commit

Permalink
Remove trick to HorizontalFusePass
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 committed Jul 3, 2023
1 parent 172ba95 commit 91f67a6
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions cinn/hlir/pass/general_fusion_merge_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,17 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass {
for (const auto& groups : fusionable_consumers) {
if (groups.size() > 1) {
VLOG(1) << "NOTICE DefaultHorizontalFusePass fuse groups.size() = " << groups.size();

// Trick for BERT, maybe not required, wait for substitution from unordered_set to set
if (groups.size() == 2) {
OpGroupList fuse_group;
if (std::dynamic_pointer_cast<Graph::Group>(groups[1])->group_id == "cast_13" && std::dynamic_pointer_cast<Graph::Group>(groups[0])->group_id == "reshape_split") {
fuse_group.push_back(groups[1]);
fuse_group.push_back(groups[0]);
ctx->EnableFuse(fuse_group);
continue;
}
}
ctx->EnableFuse(groups);
}
}
Expand Down Expand Up @@ -1274,7 +1285,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {
return updated;
}

void HorizontalFuse(const GroupList& const_consumers) {
void HorizontalFuse(const GroupList& consumers) {
VLOG(3) << "HorizontalFuse Groups...";
// create fusion group
auto fused_group = std::make_shared<Graph::Group>();
Expand All @@ -1283,15 +1294,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {
std::unordered_set<GroupPtr> sub_group_set;
// find the first consumer.
GroupPtr first_consumer(nullptr);
// Trick for BERT
GroupList consumers = const_consumers;
if (consumers.size() == 2) {
if (consumers[1]->group_id == "cast_13" && consumers[0]->group_id == "reshape_split") {
auto tmp = consumers[0];
consumers[0] = consumers[1];
consumers[1] = tmp;
}
}
// fuse all group into fusion group.
VLOG(1) << "********** DEBUG Begin check Horizontal ************";
for (const auto& consumer : consumers) {
Expand Down

0 comments on commit 91f67a6

Please sign in to comment.