diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 16319b6529..014e261bae 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -67,9 +67,12 @@ class FuseHelper { virtual bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0; + virtual bool IsReachable(const OpGroupPtr& lhs, const OpGroupPtr& rhs) const = 0; + virtual bool DetectCycleIfFuse(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0; - virtual bool IsReachable(const OpGroupPtr& lhs, const OpGroupPtr& rhs) const = 0; + virtual bool IsConsumerSetsReachable(const OpGroupPtr& group, + const std::unordered_set& consumers) const = 0; protected: FuseHelper() = default; @@ -108,6 +111,19 @@ class GraphGroupFuseHelper final : public FuseHelper { return ReachableIfDirectEdgeIgnored(lhs, rhs) || ReachableIfDirectEdgeIgnored(rhs, lhs); } + bool IsConsumerSetsReachable(const OpGroupPtr& group, + const std::unordered_set& consumers) const override { + for (const auto& consumer : consumers) { + if (group == consumer) { + continue; + } + if (IsReachableInDag(consumer, group)) { + return true; + } + } + return false; + } + private: bool IsReachableInDag(const OpGroupPtr& producer, const OpGroupPtr& consumer) const { const auto& MinDepth4Node = [&](OpGroupPtr node) { @@ -186,10 +202,9 @@ class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { EnableFuse_(EnableFuse), fuse_helper_(new GraphGroupFuseHelper(this)) {} - GraphGroupLightwareFusePassCtx( - const FusionHelperBase* graph_group_fusion_helper, - const OpGroupPtr& group, - const std::function& EnableFuseList) + GraphGroupLightwareFusePassCtx(const FusionHelperBase* graph_group_fusion_helper, + const OpGroupPtr& group, + const std::function& EnableFuseList) : graph_group_fusion_helper_(graph_group_fusion_helper), group_(group), EnableFuseList_(EnableFuseList), @@ -427,29 +442,6 @@ class DefaultInputFusePass final : public InputFusePass { int Benefit() const override { return 100; } - bool IsDependency(const OpGroupPtr& consumer, - const std::unordered_set& consumers) const { - std::queue candidates; - candidates.push(consumer); - - std::unordered_set visited_set; - while (!candidates.empty()) { - auto& candidate = candidates.front(); - candidates.pop(); - for (const auto& producer_and_list : candidate->producer_groups()) { - const auto& producer = producer_and_list.first; - if (consumers.count(producer)) { - return true; - } - if (!visited_set.count(producer)) { - visited_set.insert(producer); - candidates.push(producer); - } - } - } - return false; - } - void operator()(InputFusePassCtx* ctx) const override { VLOG(1) << "DefaultInputFusePass"; const auto& consumer_set = ctx->PickConsumersWithSameInputs(); @@ -457,11 +449,9 @@ class DefaultInputFusePass final : public InputFusePass { const std::unordered_set consumer_candidates = [&]() -> std::unordered_set { std::unordered_set consumers; for (const auto& consumer : consumer_set) { - if (consumer->kind() == framework::kElementWise || - consumer->kind() == framework::kBroadcast || - consumer->kind() == framework::kInjective || - consumer->kind() == framework::kReduction) { - consumers.insert(consumer); + if (consumer->kind() == framework::kElementWise || consumer->kind() == framework::kBroadcast || + consumer->kind() == framework::kInjective || consumer->kind() == framework::kReduction) { + consumers.insert(consumer); } } return consumers; @@ -472,22 +462,7 @@ class DefaultInputFusePass final : public InputFusePass { std::vector fusionable_consumers; for (auto& candidate : consumer_candidates) { - - // bool reachable = false; - // for (const auto& tmp: consumer_candidates) { - // if (tmp == candidate) { - // continue; - // } - // if (ctx->fuse_helper().IsReachable(candidate, tmp)) { - // reachable = true; - // break; - // } - // } - // if (reachable) { - // continue; - // } - - if (IsDependency(candidate, consumer_candidates)) { + if (ctx->fuse_helper().IsConsumerSetsReachable(candidate, consumer_candidates)) { continue; } if (fusionable_consumers.empty()) { @@ -511,8 +486,8 @@ class DefaultInputFusePass final : public InputFusePass { fusionable_consumers.push_back({candidate}); } } - - for (const auto& groups: fusionable_consumers) { + + for (const auto& groups : fusionable_consumers) { if (groups.size() > 1) { ctx->EnableFuse(groups); } @@ -554,45 +529,15 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { int Benefit() const override { return 100; } - bool IsDependency(const OpGroupPtr& producer_g, - const OpGroupPtr& consumer, - const std::unordered_set& consumers) const { - std::queue candidates; - candidates.push(consumer); - - std::unordered_set visited_set; - while (!candidates.empty()) { - auto& candidate = candidates.front(); - candidates.pop(); - for (const auto& producer_and_list : candidate->producer_groups()) { - const auto& producer = producer_and_list.first; - if (producer == producer_g) { - continue; - } - - if (consumers.count(producer)) { - return true; - } - if (!visited_set.count(producer)) { - visited_set.insert(producer); - candidates.push(producer); - } - } - } - return false; - } - void operator()(LightwareFusePassCtx* ctx) const override { VLOG(1) << "DefaultHorizontalFusePass"; - const auto& producer = ctx->PickOpGroup(); + const auto& producer = ctx->PickOpGroup(); const std::unordered_set consumer_candidates = [&]() -> std::unordered_set { std::unordered_set consumers; for (const auto& pair : producer->consumer2outputs()) { - if (pair.first->kind() == framework::kElementWise || - pair.first->kind() == framework::kBroadcast || - pair.first->kind() == framework::kInjective || - pair.first->kind() == framework::kReduction) { - consumers.insert(pair.first); + if (pair.first->kind() == framework::kElementWise || pair.first->kind() == framework::kBroadcast || + pair.first->kind() == framework::kInjective || pair.first->kind() == framework::kReduction) { + consumers.insert(pair.first); } } return consumers; @@ -603,22 +548,7 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { std::vector fusionable_consumers; for (auto& candidate : consumer_candidates) { - - // bool reachable = false; - // for (const auto& tmp: consumer_candidates) { - // if (tmp == candidate) { - // continue; - // } - // if (ctx->fuse_helper().IsReachable(candidate, tmp)) { - // reachable = true; - // break; - // } - // } - // if (reachable) { - // continue; - // } - - if (IsDependency(producer, candidate, consumer_candidates)) { + if (ctx->fuse_helper().IsConsumerSetsReachable(candidate, consumer_candidates)) { continue; } if (fusionable_consumers.empty()) { @@ -642,8 +572,8 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { fusionable_consumers.push_back({candidate}); } } - - for (const auto& groups: fusionable_consumers) { + + for (const auto& groups : fusionable_consumers) { if (groups.size() > 1) { VLOG(1) << "NOTICE DefaultHorizontalFusePass fuse groups.size() = " << groups.size(); ctx->EnableFuse(groups); @@ -835,11 +765,13 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { } if (candidates.empty()) { - VLOG(1) << "DEBUG fuse_consumers.empty(), exit fuse group " << std::dynamic_pointer_cast(producer)->group_id; + VLOG(1) << "DEBUG fuse_consumers.empty(), exit fuse group " + << std::dynamic_pointer_cast(producer)->group_id; } VLOG(1) << "DEBUG fuse_consumers_unsafe.size() = " << unsafe_candidates.size(); - if (!candidates.empty() && unsafe_candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { + if (!candidates.empty() && unsafe_candidates.size() == consumers.size() && + producer->kind() == framework::kElementWise) { for (const auto& consumer : consumers) { ctx->EnableFuse(producer, consumer); } @@ -859,7 +791,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { }; struct LightwareFusePassComparator { - bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) const{ + bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) const { return lhs->Benefit() > rhs->Benefit(); } }; @@ -1161,9 +1093,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { VLOG(1) << "DEBUG Horizontal, begin check : " << producer->group_id; const auto& GetFusableConsumerGroupLists = [&]() -> std::vector { std::vector tagged_lists; - const auto& EnableFuse = [&](const OpGroupList& candidates) { - tagged_lists.push_back(candidates); - }; + const auto& EnableFuse = [&](const OpGroupList& candidates) { tagged_lists.push_back(candidates); }; GraphGroupLightwareFusePassCtx fuse_ctx(this, producer, EnableFuse); EnableFusedHorizontalGroups(&fuse_ctx); return tagged_lists; @@ -1176,19 +1106,19 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { std::vector ret; for (const auto& group_list : group_lists) { GroupList tmp; - for (const auto& group: group_list) { + for (const auto& group : group_list) { tmp.push_back(std::dynamic_pointer_cast(group)); } ret.push_back(tmp); } return ret; }; - + const auto& group_lists = GetFusableConsumerGroupList(); if (group_lists.empty()) { return false; } - for (const auto& group_list: group_lists) { + for (const auto& group_list : group_lists) { VLOG(1) << "DEBUG horizontal fuse group " << producer->group_id; HorizontalFuse(group_list); } @@ -1239,12 +1169,10 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool CallGeneralInputFusePass(const std::unordered_set& consumers) { VLOG(3) << "CallGeneralInputFusePass...!"; - using OpGroupSets = std::set>; + using OpGroupSets = std::set>; const auto& GetFusableConsumerGroupLists = [&]() -> std::vector { std::vector tagged_lists; - const auto& EnableFuse = [&](const OpGroupList& candidates) { - tagged_lists.push_back(candidates); - }; + const auto& EnableFuse = [&](const OpGroupList& candidates) { tagged_lists.push_back(candidates); }; GraphGroupInputFusePassCtx fuse_ctx(this, consumers, EnableFuse); EnableFusedInputGroups(&fuse_ctx); return tagged_lists; @@ -1257,7 +1185,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { std::vector ret; for (const auto& group_list : group_lists) { GroupList tmp; - for (const auto& group: group_list) { + for (const auto& group : group_list) { tmp.push_back(std::dynamic_pointer_cast(group)); } ret.push_back(tmp); @@ -1269,7 +1197,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { if (group_lists.empty()) { return false; } - for (const auto& group_list: group_lists) { + for (const auto& group_list : group_lists) { HorizontalFuse(group_list); } @@ -1359,7 +1287,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { GroupList consumers = const_consumers; if (consumers.size() == 2) { if (consumers[1]->group_id == "cast_13" && consumers[0]->group_id == "reshape_split") { - auto tmp = consumers[0]; + auto tmp = consumers[0]; consumers[0] = consumers[1]; consumers[1] = tmp; } @@ -1367,7 +1295,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { // fuse all group into fusion group. VLOG(1) << "********** DEBUG Begin check Horizontal ************"; for (const auto& consumer : consumers) { - VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!" << " Pattern kind = " << consumer->op_pattern_kind; + VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!" + << " Pattern kind = " << consumer->op_pattern_kind; // update depth fused_group->max_depth = std::max(fused_group->max_depth, consumer->max_depth); fused_group->min_depth = std::min(fused_group->min_depth, consumer->min_depth); @@ -1864,7 +1793,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool update = false; auto consumer_groups = GetFusableConsumerGroupSet(); if (consumer_groups.size() > 0) { - CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size()) << "Recompute requires fuse all consumers!"; + CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size()) + << "Recompute requires fuse all consumers!"; VLOG(1) << "DEBUG recompute fuse group " << producer->group_id; RecomputeFuse(producer, consumer_groups); update = true;