diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index 484d551508..0ab30f6f46 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -89,14 +89,25 @@ class Graph : public cinn::common::Graph { return first.get() == second.get(); } }; + + struct WeakGroupHasher { + size_t operator()(const std::weak_ptr& group) const noexcept { + return std::hash()(reinterpret_cast(group.lock().get())); + } + }; + struct WeakGroupComparator { + bool operator()(const std::weak_ptr& first, const std::weak_ptr& second) const noexcept { + return first.lock().get() == second.lock().get(); + } + }; // input groups - std::unordered_set, SharedGroupHasher, SharedGroupComparator> producer_groups; + std::unordered_set, WeakGroupHasher, WeakGroupComparator> producer_groups; // output grous std::unordered_set, SharedGroupHasher, SharedGroupComparator> consumer_groups; // fused sub-groups, used for fusion merge pass std::vector> fused_sub_groups; // if as sub-group, used for belong groups. - std::unordered_set, SharedGroupHasher, SharedGroupComparator> belong_groups; + std::unordered_set, WeakGroupHasher, WeakGroupComparator> belong_groups; // for op lowering. std::vector input_names; diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 0121f8f056..2a00798e16 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -32,9 +32,14 @@ using common::GraphNode; using Comparator = Graph::Group::SharedGroupComparator; using Hasher = Graph::Group::SharedGroupHasher; +using WeakComparator = Graph::Group::WeakGroupComparator; +using WeakHasher = Graph::Group::WeakGroupHasher; + using GroupPtr = std::shared_ptr; using GroupList = std::vector; +using WeakGroupPtr = std::weak_ptr; + using ConditionFunction = std::function; // Op Fusion Pass which performs Ops fusion, Ops are fused @@ -62,7 +67,8 @@ class FusionMergePassHelper : public FusionHelperBase { VLOG(3) << " Fused Sub-Group -> " << sub_group->group_id; } for (auto& producer : group->producer_groups) { - VLOG(3) << " Producer -> " << producer->group_id; + CHECK(!producer.expired()); + VLOG(3) << " Producer -> " << producer.lock()->group_id; } for (auto& consumer : group->consumer_groups) { VLOG(3) << " Consumer -> " << consumer->group_id; @@ -147,8 +153,9 @@ class FusionMergePassHelper : public FusionHelperBase { bool exist = false; for (auto& producer : group->producer_groups) { - if (fusion_groups_set.count(producer)) { - VLOG(4) << group->group_id << " " << producer->group_id; + CHECK(!producer.expired()); + if (fusion_groups_set.count(producer.lock())) { + VLOG(4) << group->group_id << " " << producer.lock()->group_id; exist = true; break; } @@ -312,10 +319,11 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer group for (auto& producer : consumer->producer_groups) { - fused_group->producer_groups.insert(producer); + CHECK(!producer.expired()); + fused_group->producer_groups.insert(producer.lock()); // update producer's consumer - producer->consumer_groups.erase(consumer); - producer->consumer_groups.insert(fused_group); + producer.lock()->consumer_groups.erase(consumer); + producer.lock()->consumer_groups.insert(fused_group); } // consumer group for (auto& gconsumer : consumer->consumer_groups) { @@ -485,10 +493,11 @@ class FusionMergePassHelper : public FusionHelperBase { // producer groups for (auto& group : producer->producer_groups) { + CHECK(!group.expired()); fused_group->producer_groups.insert(group); // update producer's producer's consumer - group->consumer_groups.erase(producer); - group->consumer_groups.insert(fused_group); + group.lock()->consumer_groups.erase(producer); + group.lock()->consumer_groups.insert(fused_group); } // sub groups @@ -535,11 +544,12 @@ class FusionMergePassHelper : public FusionHelperBase { // producer nodes for (auto& group : consumer->producer_groups) { - if (group.get() != producer.get()) { + CHECK(!group.expired()); + if (group.lock().get() != producer.get()) { fused_group->producer_groups.insert(group); // update consumer's producer's consumer - group->consumer_groups.erase(consumer); - group->consumer_groups.insert(fused_group); + group.lock()->consumer_groups.erase(consumer); + group.lock()->consumer_groups.insert(fused_group); } } // consumer nodes @@ -723,15 +733,16 @@ class FusionMergePassHelper : public FusionHelperBase { auto& candidate = candidates.front(); candidates.pop(); for (auto& producer : candidate->producer_groups) { - if (producer.get() == producer_g.get()) { + CHECK(!producer.expired()); + if (producer.lock().get() == producer_g.get()) { continue; } - if (consumers.count(producer)) { + if (consumers.count(producer.lock())) { return true; } - if (!visited_set.count(producer)) { - visited_set.insert(producer); - candidates.push(producer); + if (!visited_set.count(producer.lock())) { + visited_set.insert(producer.lock()); + candidates.push(producer.lock()); } } } @@ -750,18 +761,19 @@ class FusionMergePassHelper : public FusionHelperBase { auto& candidate = candidates.front(); candidates.pop(); for (auto& producer : candidate->producer_groups) { - if (producer.get() == producer_g.get()) { + CHECK(!producer.expired()); + if (producer.lock().get() == producer_g.get()) { continue; } - if (producer->min_depth > check_upper_depth) { + if (producer.lock()->min_depth > check_upper_depth) { continue; } - if (consumers.count(producer)) { + if (consumers.count(producer.lock())) { return true; } - if (!visited_set.count(producer)) { - visited_set.insert(producer); - candidates.push(producer); + if (!visited_set.count(producer.lock())) { + visited_set.insert(producer.lock()); + candidates.push(producer.lock()); } } } @@ -824,10 +836,11 @@ class FusionMergePassHelper : public FusionHelperBase { updated_consumers.insert(cur); } else { for (auto& belong_group : cur->belong_groups) { - if (belong_group->group_id == cur->group_id) { - updated_consumers.insert(belong_group); + CHECK(!belong_group.expired()); + if (belong_group.lock()->group_id == cur->group_id) { + updated_consumers.insert(belong_group.lock()); } else { - fused_groups.push(belong_group); + fused_groups.push(belong_group.lock()); } } } @@ -881,16 +894,18 @@ class FusionMergePassHelper : public FusionHelperBase { // update producer and consumer. for (auto& group : fusion_groups_) { - std::unordered_set producers; + std::unordered_set producers; std::unordered_set consumers; for (auto& producer : group->producer_groups) { - CHECK(producer->belong_groups.size()); - producers.insert(*producer->belong_groups.begin()); + CHECK(!producer.expired()); + CHECK(producer.lock()->belong_groups.size()); + producers.insert(*producer.lock()->belong_groups.begin()); } for (auto& consumer : group->consumer_groups) { CHECK(consumer->belong_groups.size()); - consumers.insert(*consumer->belong_groups.begin()); + CHECK(!consumer->belong_groups.begin()->expired()); + consumers.insert(consumer->belong_groups.begin()->lock()); } CHECK_EQ(group->producer_groups.size(), producers.size()); CHECK_EQ(group->consumer_groups.size(), consumers.size()); diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc index cbb38cbd2a..2583dfdd6e 100644 --- a/cinn/hlir/pass/op_fusion_pass.cc +++ b/cinn/hlir/pass/op_fusion_pass.cc @@ -355,7 +355,8 @@ void OpFusionPassInternal(Graph* graph) { for (auto& group : graph->fusion_groups) { VLOG(3) << "Group Id : " << group->group_id; for (auto& producer : group->producer_groups) { - VLOG(3) << " producer group -> " << producer->group_id; + CHECK(!producer.expired()); + VLOG(3) << " producer group -> " << producer.lock()->group_id; } for (auto& consumer : group->consumer_groups) { VLOG(3) << " consumer group -> " << consumer->group_id;