Skip to content

Commit

Permalink
each group applies a rule only once
Browse files Browse the repository at this point in the history
  • Loading branch information
nevermore3 committed Oct 9, 2023
1 parent 798d48e commit 8def78a
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/graph/optimizer/OptContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <boost/core/noncopyable.hpp>
#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "common/cpp/helpers.h"

Expand All @@ -18,11 +19,14 @@ class ObjectPool;

namespace graph {
class QueryContext;
class PlanNode;
} // namespace graph

namespace opt {

class OptGroupNode;
class OptGroup;
class Optimizer;

class OptContext final : private boost::noncopyable, private cpp::NonMovable {
public:
Expand All @@ -48,12 +52,16 @@ class OptContext final : private boost::noncopyable, private cpp::NonMovable {
const OptGroupNode *findOptGroupNodeByPlanNodeId(int64_t planNodeId) const;

private:
friend OptGroup;
friend Optimizer;
// A global flag to record whether this iteration caused a change to the plan
bool changed_{true};
graph::QueryContext *qctx_{nullptr};
// Memo memory management in the Optimizer phase
std::unique_ptr<ObjectPool> objPool_;
std::unordered_map<int64_t, const OptGroupNode *> planNodeToOptGroupNodeMap_;
std::unordered_set<const OptGroup *> visited_;
std::unordered_map<const OptGroup *, const graph::PlanNode *> group2PlanNodeMap_;
};

} // namespace opt
Expand Down
18 changes: 17 additions & 1 deletion src/graph/optimizer/OptGroup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ OptGroup *OptGroup::create(OptContext *ctx) {
}

void OptGroup::setUnexplored(const OptRule *rule) {
if (!ctx_->visited_.emplace(this).second) {
return;
}
auto iter = std::find(exploredRules_.begin(), exploredRules_.end(), rule);
if (iter != exploredRules_.end()) {
exploredRules_.erase(iter);
Expand Down Expand Up @@ -100,6 +103,9 @@ Status OptGroup::validate(const OptRule *rule) const {
rule->toString().c_str(),
groupNodesReferenced_.size());
}
if (!ctx_->visited_.emplace(this).second) {
return Status::OK();
}
for (auto *gn : groupNodes_) {
NG_RETURN_IF_ERROR(gn->validate(rule));
if (gn->node()->outputVar() != outputVar_) {
Expand Down Expand Up @@ -138,6 +144,9 @@ Status OptGroup::explore(const OptRule *rule) {
return Status::OK();
}
setExplored(rule);
if (!ctx_->visited_.emplace(this).second) {
return Status::OK();
}

// TODO(yee): the opt group maybe in the loop body branch
// DCHECK(isRootGroup_ || !groupNodesReferenced_.empty())
Expand Down Expand Up @@ -241,8 +250,15 @@ double OptGroup::getCost() const {
}

const PlanNode *OptGroup::getPlan() const {
auto &group2PlanNodeMap = ctx_->group2PlanNodeMap_;
auto iter = group2PlanNodeMap.find(this);
if (iter != group2PlanNodeMap.end()) {
return iter->second;
}
const OptGroupNode *minGroupNode = findMinCostGroupNode().second;
return DCHECK_NOTNULL(minGroupNode)->getPlan();
const auto plan = DCHECK_NOTNULL(minGroupNode)->getPlan();
group2PlanNodeMap.emplace(this, plan);
return plan;
}

void OptGroup::deleteRefGroupNode(const OptGroupNode *node) {
Expand Down
11 changes: 10 additions & 1 deletion src/graph/optimizer/Optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ Status Optimizer::doExploration(OptContext *octx, OptGroup *rootGroup) {
for (auto rule : ruleSet->rules()) {
// Explore until the maximum number of iterations(Rules) is reached
NG_RETURN_IF_ERROR(rootGroup->exploreUntilMaxRound(rule));
octx->visited_.clear();
NG_RETURN_IF_ERROR(rootGroup->validate(rule));
octx->visited_.clear();
rootGroup->setUnexplored(rule);
octx->visited_.clear();
}
}
}
Expand Down Expand Up @@ -226,17 +229,23 @@ Status Optimizer::rewriteArgumentInputVar(PlanNode *root) {

Status Optimizer::checkPlanDepth(const PlanNode *root) const {
std::queue<const PlanNode *> queue;
std::unordered_set<const PlanNode *> visited;
queue.push(root);
visited.emplace(root);
size_t depth = 0;
while (!queue.empty()) {
size_t size = queue.size();
for (size_t i = 0; i < size; ++i) {
const PlanNode *node = queue.front();
queue.pop();
for (size_t j = 0; j < node->numDeps(); j++) {
queue.push(node->dep(j));
const auto *dep = node->dep(j);
if (visited.emplace(dep).second) {
queue.push(dep);
}
}
}

++depth;
if (depth > FLAGS_max_plan_depth) {
return Status::Error("The depth of plan tree has exceeded the max %lu level",
Expand Down

0 comments on commit 8def78a

Please sign in to comment.