Skip to content

Commit

Permalink
Remove IsDependency Trick
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 committed Jun 30, 2023
1 parent a40aab7 commit 172ba95
Showing 1 changed file with 51 additions and 121 deletions.
172 changes: 51 additions & 121 deletions cinn/hlir/pass/general_fusion_merge_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,12 @@ class FuseHelper {

virtual bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0;

virtual bool IsReachable(const OpGroupPtr& lhs, const OpGroupPtr& rhs) const = 0;

virtual bool DetectCycleIfFuse(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0;

virtual bool IsReachable(const OpGroupPtr& lhs, const OpGroupPtr& rhs) const = 0;
virtual bool IsConsumerSetsReachable(const OpGroupPtr& group,
const std::unordered_set<OpGroupPtr>& consumers) const = 0;

protected:
FuseHelper() = default;
Expand Down Expand Up @@ -108,6 +111,19 @@ class GraphGroupFuseHelper final : public FuseHelper {
return ReachableIfDirectEdgeIgnored(lhs, rhs) || ReachableIfDirectEdgeIgnored(rhs, lhs);
}

bool IsConsumerSetsReachable(const OpGroupPtr& group,
const std::unordered_set<OpGroupPtr>& consumers) const override {
for (const auto& consumer : consumers) {
if (group == consumer) {
continue;
}
if (IsReachableInDag(consumer, group)) {
return true;
}
}
return false;
}

private:
bool IsReachableInDag(const OpGroupPtr& producer, const OpGroupPtr& consumer) const {
const auto& MinDepth4Node = [&](OpGroupPtr node) {
Expand Down Expand Up @@ -186,10 +202,9 @@ class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx {
EnableFuse_(EnableFuse),
fuse_helper_(new GraphGroupFuseHelper<GraphGroupLightwareFusePassCtx>(this)) {}

GraphGroupLightwareFusePassCtx(
const FusionHelperBase* graph_group_fusion_helper,
const OpGroupPtr& group,
const std::function<void(const OpGroupList& candidates)>& EnableFuseList)
GraphGroupLightwareFusePassCtx(const FusionHelperBase* graph_group_fusion_helper,
const OpGroupPtr& group,
const std::function<void(const OpGroupList& candidates)>& EnableFuseList)
: graph_group_fusion_helper_(graph_group_fusion_helper),
group_(group),
EnableFuseList_(EnableFuseList),
Expand Down Expand Up @@ -427,41 +442,16 @@ class DefaultInputFusePass final : public InputFusePass {

int Benefit() const override { return 100; }

bool IsDependency(const OpGroupPtr& consumer,
const std::unordered_set<OpGroupPtr>& consumers) const {
std::queue<OpGroupPtr> candidates;
candidates.push(consumer);

std::unordered_set<OpGroupPtr> visited_set;
while (!candidates.empty()) {
auto& candidate = candidates.front();
candidates.pop();
for (const auto& producer_and_list : candidate->producer_groups()) {
const auto& producer = producer_and_list.first;
if (consumers.count(producer)) {
return true;
}
if (!visited_set.count(producer)) {
visited_set.insert(producer);
candidates.push(producer);
}
}
}
return false;
}

void operator()(InputFusePassCtx* ctx) const override {
VLOG(1) << "DefaultInputFusePass";
const auto& consumer_set = ctx->PickConsumersWithSameInputs();

const std::unordered_set<OpGroupPtr> consumer_candidates = [&]() -> std::unordered_set<OpGroupPtr> {
std::unordered_set<OpGroupPtr> consumers;
for (const auto& consumer : consumer_set) {
if (consumer->kind() == framework::kElementWise ||
consumer->kind() == framework::kBroadcast ||
consumer->kind() == framework::kInjective ||
consumer->kind() == framework::kReduction) {
consumers.insert(consumer);
if (consumer->kind() == framework::kElementWise || consumer->kind() == framework::kBroadcast ||
consumer->kind() == framework::kInjective || consumer->kind() == framework::kReduction) {
consumers.insert(consumer);
}
}
return consumers;
Expand All @@ -472,22 +462,7 @@ class DefaultInputFusePass final : public InputFusePass {

std::vector<OpGroupList> fusionable_consumers;
for (auto& candidate : consumer_candidates) {

// bool reachable = false;
// for (const auto& tmp: consumer_candidates) {
// if (tmp == candidate) {
// continue;
// }
// if (ctx->fuse_helper().IsReachable(candidate, tmp)) {
// reachable = true;
// break;
// }
// }
// if (reachable) {
// continue;
// }

if (IsDependency(candidate, consumer_candidates)) {
if (ctx->fuse_helper().IsConsumerSetsReachable(candidate, consumer_candidates)) {
continue;
}
if (fusionable_consumers.empty()) {
Expand All @@ -511,8 +486,8 @@ class DefaultInputFusePass final : public InputFusePass {
fusionable_consumers.push_back({candidate});
}
}
for (const auto& groups: fusionable_consumers) {

for (const auto& groups : fusionable_consumers) {
if (groups.size() > 1) {
ctx->EnableFuse(groups);
}
Expand Down Expand Up @@ -554,45 +529,15 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass {

int Benefit() const override { return 100; }

bool IsDependency(const OpGroupPtr& producer_g,
const OpGroupPtr& consumer,
const std::unordered_set<OpGroupPtr>& consumers) const {
std::queue<OpGroupPtr> candidates;
candidates.push(consumer);

std::unordered_set<OpGroupPtr> visited_set;
while (!candidates.empty()) {
auto& candidate = candidates.front();
candidates.pop();
for (const auto& producer_and_list : candidate->producer_groups()) {
const auto& producer = producer_and_list.first;
if (producer == producer_g) {
continue;
}

if (consumers.count(producer)) {
return true;
}
if (!visited_set.count(producer)) {
visited_set.insert(producer);
candidates.push(producer);
}
}
}
return false;
}

void operator()(LightwareFusePassCtx* ctx) const override {
VLOG(1) << "DefaultHorizontalFusePass";
const auto& producer = ctx->PickOpGroup();
const auto& producer = ctx->PickOpGroup();
const std::unordered_set<OpGroupPtr> consumer_candidates = [&]() -> std::unordered_set<OpGroupPtr> {
std::unordered_set<OpGroupPtr> consumers;
for (const auto& pair : producer->consumer2outputs()) {
if (pair.first->kind() == framework::kElementWise ||
pair.first->kind() == framework::kBroadcast ||
pair.first->kind() == framework::kInjective ||
pair.first->kind() == framework::kReduction) {
consumers.insert(pair.first);
if (pair.first->kind() == framework::kElementWise || pair.first->kind() == framework::kBroadcast ||
pair.first->kind() == framework::kInjective || pair.first->kind() == framework::kReduction) {
consumers.insert(pair.first);
}
}
return consumers;
Expand All @@ -603,22 +548,7 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass {

std::vector<OpGroupList> fusionable_consumers;
for (auto& candidate : consumer_candidates) {

// bool reachable = false;
// for (const auto& tmp: consumer_candidates) {
// if (tmp == candidate) {
// continue;
// }
// if (ctx->fuse_helper().IsReachable(candidate, tmp)) {
// reachable = true;
// break;
// }
// }
// if (reachable) {
// continue;
// }

if (IsDependency(producer, candidate, consumer_candidates)) {
if (ctx->fuse_helper().IsConsumerSetsReachable(candidate, consumer_candidates)) {
continue;
}
if (fusionable_consumers.empty()) {
Expand All @@ -642,8 +572,8 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass {
fusionable_consumers.push_back({candidate});
}
}
for (const auto& groups: fusionable_consumers) {

for (const auto& groups : fusionable_consumers) {
if (groups.size() > 1) {
VLOG(1) << "NOTICE DefaultHorizontalFusePass fuse groups.size() = " << groups.size();
ctx->EnableFuse(groups);
Expand Down Expand Up @@ -835,11 +765,13 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass {
}

if (candidates.empty()) {
VLOG(1) << "DEBUG fuse_consumers.empty(), exit fuse group " << std::dynamic_pointer_cast<Graph::Group>(producer)->group_id;
VLOG(1) << "DEBUG fuse_consumers.empty(), exit fuse group "
<< std::dynamic_pointer_cast<Graph::Group>(producer)->group_id;
}
VLOG(1) << "DEBUG fuse_consumers_unsafe.size() = " << unsafe_candidates.size();

if (!candidates.empty() && unsafe_candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) {
if (!candidates.empty() && unsafe_candidates.size() == consumers.size() &&
producer->kind() == framework::kElementWise) {
for (const auto& consumer : consumers) {
ctx->EnableFuse(producer, consumer);
}
Expand All @@ -859,7 +791,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass {
};

struct LightwareFusePassComparator {
bool operator()(const std::shared_ptr<LightwareFusePass>& lhs, const std::shared_ptr<LightwareFusePass>& rhs) const{
bool operator()(const std::shared_ptr<LightwareFusePass>& lhs, const std::shared_ptr<LightwareFusePass>& rhs) const {
return lhs->Benefit() > rhs->Benefit();
}
};
Expand Down Expand Up @@ -1161,9 +1093,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {
VLOG(1) << "DEBUG Horizontal, begin check : " << producer->group_id;
const auto& GetFusableConsumerGroupLists = [&]() -> std::vector<OpGroupList> {
std::vector<OpGroupList> tagged_lists;
const auto& EnableFuse = [&](const OpGroupList& candidates) {
tagged_lists.push_back(candidates);
};
const auto& EnableFuse = [&](const OpGroupList& candidates) { tagged_lists.push_back(candidates); };
GraphGroupLightwareFusePassCtx fuse_ctx(this, producer, EnableFuse);
EnableFusedHorizontalGroups(&fuse_ctx);
return tagged_lists;
Expand All @@ -1176,19 +1106,19 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {
std::vector<GroupList> ret;
for (const auto& group_list : group_lists) {
GroupList tmp;
for (const auto& group: group_list) {
for (const auto& group : group_list) {
tmp.push_back(std::dynamic_pointer_cast<Graph::Group>(group));
}
ret.push_back(tmp);
}
return ret;
};

const auto& group_lists = GetFusableConsumerGroupList();
if (group_lists.empty()) {
return false;
}
for (const auto& group_list: group_lists) {
for (const auto& group_list : group_lists) {
VLOG(1) << "DEBUG horizontal fuse group " << producer->group_id;
HorizontalFuse(group_list);
}
Expand Down Expand Up @@ -1239,12 +1169,10 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {

bool CallGeneralInputFusePass(const std::unordered_set<GroupPtr>& consumers) {
VLOG(3) << "CallGeneralInputFusePass...!";
using OpGroupSets = std::set<std::set<OpGroupPtr>>;
using OpGroupSets = std::set<std::set<OpGroupPtr>>;
const auto& GetFusableConsumerGroupLists = [&]() -> std::vector<OpGroupList> {
std::vector<OpGroupList> tagged_lists;
const auto& EnableFuse = [&](const OpGroupList& candidates) {
tagged_lists.push_back(candidates);
};
const auto& EnableFuse = [&](const OpGroupList& candidates) { tagged_lists.push_back(candidates); };
GraphGroupInputFusePassCtx fuse_ctx(this, consumers, EnableFuse);
EnableFusedInputGroups(&fuse_ctx);
return tagged_lists;
Expand All @@ -1257,7 +1185,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {
std::vector<GroupList> ret;
for (const auto& group_list : group_lists) {
GroupList tmp;
for (const auto& group: group_list) {
for (const auto& group : group_list) {
tmp.push_back(std::dynamic_pointer_cast<Graph::Group>(group));
}
ret.push_back(tmp);
Expand All @@ -1269,7 +1197,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {
if (group_lists.empty()) {
return false;
}
for (const auto& group_list: group_lists) {
for (const auto& group_list : group_lists) {
HorizontalFuse(group_list);
}

Expand Down Expand Up @@ -1359,15 +1287,16 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {
GroupList consumers = const_consumers;
if (consumers.size() == 2) {
if (consumers[1]->group_id == "cast_13" && consumers[0]->group_id == "reshape_split") {
auto tmp = consumers[0];
auto tmp = consumers[0];
consumers[0] = consumers[1];
consumers[1] = tmp;
}
}
// fuse all group into fusion group.
VLOG(1) << "********** DEBUG Begin check Horizontal ************";
for (const auto& consumer : consumers) {
VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!" << " Pattern kind = " << consumer->op_pattern_kind;
VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!"
<< " Pattern kind = " << consumer->op_pattern_kind;
// update depth
fused_group->max_depth = std::max(fused_group->max_depth, consumer->max_depth);
fused_group->min_depth = std::min(fused_group->min_depth, consumer->min_depth);
Expand Down Expand Up @@ -1864,7 +1793,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase {
bool update = false;
auto consumer_groups = GetFusableConsumerGroupSet();
if (consumer_groups.size() > 0) {
CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size()) << "Recompute requires fuse all consumers!";
CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size())
<< "Recompute requires fuse all consumers!";
VLOG(1) << "DEBUG recompute fuse group " << producer->group_id;
RecomputeFuse(producer, consumer_groups);
update = true;
Expand Down

0 comments on commit 172ba95

Please sign in to comment.