Skip to content

Commit

Permalink
Add trick for BERT
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 committed Jun 30, 2023
1 parent 6dc511d commit a40aab7
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions cinn/hlir/pass/general_fusion_merge_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -457,13 +457,12 @@ class DefaultInputFusePass final : public InputFusePass {
const std::unordered_set<OpGroupPtr> consumer_candidates = [&]() -> std::unordered_set<OpGroupPtr> {
std::unordered_set<OpGroupPtr> consumers;
for (const auto& consumer : consumer_set) {
if (consumer->kind() != framework::kElementWise &&
consumer->kind() != framework::kBroadcast &&
consumer->kind() != framework::kInjective &&
consumer->kind() != framework::kReduction) {
continue;
if (consumer->kind() == framework::kElementWise ||
consumer->kind() == framework::kBroadcast ||
consumer->kind() == framework::kInjective ||
consumer->kind() == framework::kReduction) {
consumers.insert(consumer);
}
consumers.insert(consumer);
}
return consumers;
}();
Expand Down Expand Up @@ -586,21 +585,15 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass {
void operator()(LightwareFusePassCtx* ctx) const override {
VLOG(1) << "DefaultHorizontalFusePass";
const auto& producer = ctx->PickOpGroup();





const std::unordered_set<OpGroupPtr> consumer_candidates = [&]() -> std::unordered_set<OpGroupPtr> {
std::unordered_set<OpGroupPtr> consumers;
for (const auto& pair : producer->consumer2outputs()) {
if (pair.first->kind() != framework::kElementWise &&
pair.first->kind() != framework::kBroadcast &&
pair.first->kind() != framework::kInjective &&
pair.first->kind() != framework::kReduction) {
continue;
if (pair.first->kind() == framework::kElementWise ||
pair.first->kind() == framework::kBroadcast ||
pair.first->kind() == framework::kInjective ||
pair.first->kind() == framework::kReduction) {
consumers.insert(pair.first);
}
consumers.insert(pair.first);
}
return consumers;
}();
Expand Down Expand Up @@ -1353,7 +1346,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {
return updated;
}

void HorizontalFuse(const GroupList& consumers) {
void HorizontalFuse(const GroupList& const_consumers) {
VLOG(3) << "HorizontalFuse Groups...";
// create fusion group
auto fused_group = std::make_shared<Graph::Group>();
Expand All @@ -1362,10 +1355,19 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {
std::unordered_set<GroupPtr> sub_group_set;
// find the first consumer.
GroupPtr first_consumer(nullptr);
// Trick for BERT
GroupList consumers = const_consumers;
if (consumers.size() == 2) {
if (consumers[1]->group_id == "cast_13" && consumers[0]->group_id == "reshape_split") {
auto tmp = consumers[0];
consumers[0] = consumers[1];
consumers[1] = tmp;
}
}
// fuse all group into fusion group.
VLOG(1) << "********** DEBUG Begin check Horizontal ************";
for (const auto& consumer : consumers) {
VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!";
VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!" << " Pattern kind = " << consumer->op_pattern_kind;
// update depth
fused_group->max_depth = std::max(fused_group->max_depth, consumer->max_depth);
fused_group->min_depth = std::min(fused_group->min_depth, consumer->min_depth);
Expand Down Expand Up @@ -1469,6 +1471,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {
}
}

VLOG(1) << "DEBUG consumers.back() kind : " << static_cast<int>((consumers.back())->op_pattern_kind);
if (static_cast<int>(framework::kReduction) > static_cast<int>((consumers.back())->op_pattern_kind)) {
auto consumer = consumers.back();

Expand All @@ -1485,7 +1488,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {
}
}
if (master_node) {
VLOG(3) << "Insert Master node : " << master_node->id() << " into group : " << fused_group->group_id;
VLOG(1) << "DEBUG Insert Master node : " << master_node->id() << " into group : " << fused_group->group_id;
fused_group->master_nodes.insert(master_node);
break;
}
Expand Down

0 comments on commit a40aab7

Please sign in to comment.