diff --git a/src/graph/executor/algo/CartesianProductExecutor.cpp b/src/graph/executor/algo/CartesianProductExecutor.cpp index 147b60e5af1..b48ab7a4bcc 100644 --- a/src/graph/executor/algo/CartesianProductExecutor.cpp +++ b/src/graph/executor/algo/CartesianProductExecutor.cpp @@ -5,6 +5,7 @@ #include "graph/executor/algo/CartesianProductExecutor.h" #include "graph/planner/plan/Algo.h" +#include "graph/planner/plan/Query.h" namespace nebula { namespace graph { diff --git a/src/graph/optimizer/CMakeLists.txt b/src/graph/optimizer/CMakeLists.txt index d53fa00498e..2d98c2c1a14 100644 --- a/src/graph/optimizer/CMakeLists.txt +++ b/src/graph/optimizer/CMakeLists.txt @@ -10,6 +10,7 @@ nebula_add_library( OptGroup.cpp OptRule.cpp OptContext.cpp + rule/PushFilterDownCrossJoinRule.cpp rule/PushFilterDownGetNbrsRule.cpp rule/RemoveNoopProjectRule.cpp rule/CombineFilterRule.cpp diff --git a/src/graph/optimizer/rule/PushFilterDownCrossJoinRule.cpp b/src/graph/optimizer/rule/PushFilterDownCrossJoinRule.cpp new file mode 100644 index 00000000000..9de2d8ad516 --- /dev/null +++ b/src/graph/optimizer/rule/PushFilterDownCrossJoinRule.cpp @@ -0,0 +1,139 @@ +/* Copyright (c) 2023 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +#include "graph/optimizer/rule/PushFilterDownCrossJoinRule.h" + +#include "graph/optimizer/OptContext.h" +#include "graph/optimizer/OptGroup.h" +#include "graph/planner/plan/PlanNode.h" +#include "graph/planner/plan/Query.h" +#include "graph/util/ExpressionUtils.h" + +using nebula::graph::CrossJoin; +using nebula::graph::ExpressionUtils; +using nebula::graph::Filter; +using nebula::graph::PlanNode; +using nebula::graph::QueryContext; + +namespace nebula { +namespace opt { + +std::unique_ptr PushFilterDownCrossJoinRule::kInstance = + std::unique_ptr(new PushFilterDownCrossJoinRule()); + +PushFilterDownCrossJoinRule::PushFilterDownCrossJoinRule() { + RuleSet::QueryRules().addRule(this); +} + +const Pattern& PushFilterDownCrossJoinRule::pattern() const { + static Pattern pattern = Pattern::create( + PlanNode::Kind::kFilter, + {Pattern::create( + PlanNode::Kind::kCrossJoin, + {Pattern::create(PlanNode::Kind::kUnknown), Pattern::create(PlanNode::Kind::kUnknown)})}); + return pattern; +} + +StatusOr PushFilterDownCrossJoinRule::transform( + OptContext* octx, const MatchedResult& matched) const { + auto* filterGroupNode = matched.node; + auto* oldFilterNode = filterGroupNode->node(); + DCHECK_EQ(oldFilterNode->kind(), PlanNode::Kind::kFilter); + + auto* crossJoinNode = matched.planNode({0, 0}); + DCHECK_EQ(crossJoinNode->kind(), PlanNode::Kind::kCrossJoin); + auto* oldCrossJoinNode = static_cast(crossJoinNode); + + const auto* condition = static_cast(oldFilterNode)->condition(); + DCHECK(condition); + + const auto& leftResult = matched.result({0, 0, 0}); + const auto& rightResult = matched.result({0, 0, 1}); + + Expression *leftFilterUnpicked = nullptr, *rightFilterUnpicked = nullptr; + OptGroup* leftGroup = pushFilterDownChild(octx, leftResult, condition, &leftFilterUnpicked); + OptGroup* rightGroup = + pushFilterDownChild(octx, rightResult, leftFilterUnpicked, &rightFilterUnpicked); + + if (!leftGroup && !rightGroup) { + return TransformResult::noTransform(); + } + + leftGroup = leftGroup ? leftGroup : const_cast(leftResult.node->group()); + rightGroup = rightGroup ? rightGroup : const_cast(rightResult.node->group()); + + // produce new CrossJoin node + auto* newCrossJoinNode = static_cast(oldCrossJoinNode->clone()); + auto newJoinGroup = rightFilterUnpicked ? OptGroup::create(octx) : filterGroupNode->group(); + // TODO(yee): it's too tricky + auto newGroupNode = rightFilterUnpicked + ? const_cast(newJoinGroup)->makeGroupNode(newCrossJoinNode) + : OptGroupNode::create(octx, newCrossJoinNode, newJoinGroup); + newGroupNode->dependsOn(leftGroup); + newGroupNode->dependsOn(rightGroup); + newCrossJoinNode->setLeftVar(leftGroup->outputVar()); + newCrossJoinNode->setRightVar(rightGroup->outputVar()); + + if (rightFilterUnpicked) { + auto newFilterNode = Filter::make(octx->qctx(), nullptr, rightFilterUnpicked); + newFilterNode->setOutputVar(oldFilterNode->outputVar()); + newFilterNode->setColNames(oldFilterNode->colNames()); + newFilterNode->setInputVar(newCrossJoinNode->outputVar()); + newGroupNode = OptGroupNode::create(octx, newFilterNode, filterGroupNode->group()); + newGroupNode->dependsOn(const_cast(newJoinGroup)); + } else { + newCrossJoinNode->setOutputVar(oldFilterNode->outputVar()); + newCrossJoinNode->setColNames(oldCrossJoinNode->colNames()); + } + + TransformResult result; + result.eraseAll = true; + result.newGroupNodes.emplace_back(newGroupNode); + return result; +} + +OptGroup* PushFilterDownCrossJoinRule::pushFilterDownChild(OptContext* octx, + const MatchedResult& child, + const Expression* condition, + Expression** unpickedFilter) { + if (!condition) return nullptr; + + const auto* childPlanNode = DCHECK_NOTNULL(child.node->node()); + const auto& colNames = childPlanNode->colNames(); + + // split the `condition` based on whether the varPropExpr comes from the left child + auto picker = [&colNames](const Expression* e) -> bool { + return ExpressionUtils::checkColName(colNames, e); + }; + + Expression* filterPicked = nullptr; + ExpressionUtils::splitFilter(condition, picker, &filterPicked, unpickedFilter); + if (!filterPicked) return nullptr; + + auto* newChildPlanNode = childPlanNode->clone(); + DCHECK_NE(childPlanNode->outputVar(), newChildPlanNode->outputVar()); + newChildPlanNode->setInputVar(childPlanNode->inputVar()); + newChildPlanNode->setColNames(childPlanNode->colNames()); + auto* newChildGroup = OptGroup::create(octx); + auto* newChildGroupNode = newChildGroup->makeGroupNode(newChildPlanNode); + for (auto* g : child.node->dependencies()) { + newChildGroupNode->dependsOn(g); + } + + auto* newFilterNode = Filter::make(octx->qctx(), nullptr, filterPicked); + newFilterNode->setOutputVar(childPlanNode->outputVar()); + newFilterNode->setColNames(colNames); + newFilterNode->setInputVar(newChildPlanNode->outputVar()); + auto* group = OptGroup::create(octx); + group->makeGroupNode(newFilterNode)->dependsOn(newChildGroup); + return group; +} + +std::string PushFilterDownCrossJoinRule::toString() const { + return "PushFilterDownCrossJoinRule"; +} + +} // namespace opt +} // namespace nebula diff --git a/src/graph/optimizer/rule/PushFilterDownCrossJoinRule.h b/src/graph/optimizer/rule/PushFilterDownCrossJoinRule.h new file mode 100644 index 00000000000..21af0c23701 --- /dev/null +++ b/src/graph/optimizer/rule/PushFilterDownCrossJoinRule.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2023 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +#ifndef GRAPH_OPTIMIZER_RULE_PUSHFILTERDOWNCROSSJOINRULE_H_ +#define GRAPH_OPTIMIZER_RULE_PUSHFILTERDOWNCROSSJOINRULE_H_ + +#include "graph/optimizer/OptRule.h" + +namespace nebula { +namespace opt { + +// Push down the filter items into the child sub-plan of [[CrossJoin]] +class PushFilterDownCrossJoinRule final : public OptRule { + public: + const Pattern &pattern() const override; + + StatusOr transform(OptContext *octx, + const MatchedResult &matched) const override; + + std::string toString() const override; + + private: + PushFilterDownCrossJoinRule(); + static OptGroup *pushFilterDownChild(OptContext *octx, + const MatchedResult &child, + const Expression *condition, + Expression **unpickedFilter); + + static std::unique_ptr kInstance; +}; + +} // namespace opt +} // namespace nebula + +#endif // GRAPH_OPTIMIZER_RULE_PUSHFILTERDOWNCROSSJOINRULE_H_ diff --git a/src/graph/planner/match/MatchSolver.cpp b/src/graph/planner/match/MatchSolver.cpp index 8687d906893..679829937d7 100644 --- a/src/graph/planner/match/MatchSolver.cpp +++ b/src/graph/planner/match/MatchSolver.cpp @@ -207,8 +207,7 @@ Expression* MatchSolver::makeIndexFilter(const std::string& label, auto* root = relationals[0]; for (auto i = 1u; i < relationals.size(); i++) { - auto* left = root; - root = LogicalExpression::makeAnd(qctx->objPool(), left, relationals[i]); + root = LogicalExpression::makeAnd(qctx->objPool(), root, relationals[i]); } return root; diff --git a/src/graph/planner/match/StartVidFinder.h b/src/graph/planner/match/StartVidFinder.h index 7726ef0b646..531d9dc54f1 100644 --- a/src/graph/planner/match/StartVidFinder.h +++ b/src/graph/planner/match/StartVidFinder.h @@ -28,7 +28,7 @@ using StartVidFinderInstantiateFunc = std::function CartesianProduct::inputVars() const { return varNames; } -std::unique_ptr CrossJoin::explain() const { - return BinaryInputNode::explain(); -} - -PlanNode* CrossJoin::clone() const { - auto* node = make(qctx_); - node->cloneMembers(*this); - return node; -} - -void CrossJoin::cloneMembers(const CrossJoin& r) { - BinaryInputNode::cloneMembers(r); -} - -CrossJoin::CrossJoin(QueryContext* qctx, PlanNode* left, PlanNode* right) - : BinaryInputNode(qctx, Kind::kCrossJoin, left, right) { - auto lColNames = left->colNames(); - auto rColNames = right->colNames(); - lColNames.insert(lColNames.end(), rColNames.begin(), rColNames.end()); - setColNames(lColNames); -} - -void CrossJoin::accept(PlanNodeVisitor* visitor) { - visitor->visit(this); -} - -CrossJoin::CrossJoin(QueryContext* qctx) : BinaryInputNode(qctx, Kind::kCrossJoin) {} - std::unique_ptr Subgraph::explain() const { auto desc = SingleInputNode::explain(); addDescription("src", src_ ? src_->toString() : "", desc.get()); diff --git a/src/graph/planner/plan/Algo.h b/src/graph/planner/plan/Algo.h index 58a5ff83602..74f5f12f14e 100644 --- a/src/graph/planner/plan/Algo.h +++ b/src/graph/planner/plan/Algo.h @@ -437,32 +437,6 @@ class Subgraph final : public SingleInputNode { std::unique_ptr> edgeProps_; }; -class CrossJoin final : public BinaryInputNode { - public: - static CrossJoin* make(QueryContext* qctx, PlanNode* left, PlanNode* right) { - return qctx->objPool()->makeAndAdd(qctx, left, right); - } - - std::unique_ptr explain() const override; - - PlanNode* clone() const override; - - void accept(PlanNodeVisitor* visitor) override; - - private: - friend ObjectPool; - - // used for clone only - static CrossJoin* make(QueryContext* qctx) { - return qctx->objPool()->makeAndAdd(qctx); - } - - void cloneMembers(const CrossJoin& r); - - CrossJoin(QueryContext* qctx, PlanNode* left, PlanNode* right); - // use for clone - explicit CrossJoin(QueryContext* qctx); -}; } // namespace graph } // namespace nebula #endif // GRAPH_PLANNER_PLAN_ALGO_H_ diff --git a/src/graph/planner/plan/Query.cpp b/src/graph/planner/plan/Query.cpp index 5a64b3e0c8d..6dd9c037f48 100644 --- a/src/graph/planner/plan/Query.cpp +++ b/src/graph/planner/plan/Query.cpp @@ -959,6 +959,34 @@ void HashInnerJoin::cloneMembers(const HashInnerJoin& l) { HashJoin::cloneMembers(l); } +std::unique_ptr CrossJoin::explain() const { + return BinaryInputNode::explain(); +} + +PlanNode* CrossJoin::clone() const { + auto* node = make(qctx_); + node->cloneMembers(*this); + return node; +} + +void CrossJoin::cloneMembers(const CrossJoin& r) { + BinaryInputNode::cloneMembers(r); +} + +CrossJoin::CrossJoin(QueryContext* qctx, PlanNode* left, PlanNode* right) + : BinaryInputNode(qctx, Kind::kCrossJoin, left, right) { + auto lColNames = left->colNames(); + auto rColNames = right->colNames(); + lColNames.insert(lColNames.end(), rColNames.begin(), rColNames.end()); + setColNames(lColNames); +} + +void CrossJoin::accept(PlanNodeVisitor* visitor) { + visitor->visit(this); +} + +CrossJoin::CrossJoin(QueryContext* qctx) : BinaryInputNode(qctx, Kind::kCrossJoin) {} + std::unique_ptr RollUpApply::explain() const { auto desc = BinaryInputNode::explain(); addDescription("compareCols", folly::toJson(util::toJson(compareCols_)), desc.get()); diff --git a/src/graph/planner/plan/Query.h b/src/graph/planner/plan/Query.h index c130d4361d7..fdd055a89af 100644 --- a/src/graph/planner/plan/Query.h +++ b/src/graph/planner/plan/Query.h @@ -1899,6 +1899,33 @@ class HashInnerJoin final : public HashJoin { void cloneMembers(const HashInnerJoin&); }; +class CrossJoin final : public BinaryInputNode { + public: + static CrossJoin* make(QueryContext* qctx, PlanNode* left, PlanNode* right) { + return qctx->objPool()->makeAndAdd(qctx, left, right); + } + + std::unique_ptr explain() const override; + + PlanNode* clone() const override; + + void accept(PlanNodeVisitor* visitor) override; + + private: + friend ObjectPool; + + // used for clone only + static CrossJoin* make(QueryContext* qctx) { + return qctx->objPool()->makeAndAdd(qctx); + } + + void cloneMembers(const CrossJoin& r); + + CrossJoin(QueryContext* qctx, PlanNode* left, PlanNode* right); + // use for clone + explicit CrossJoin(QueryContext* qctx); +}; + // Roll Up Apply two results from two inputs. class RollUpApply : public BinaryInputNode { public: diff --git a/tests/tck/features/optimizer/PushFilterDownCrossJoinRule.feature b/tests/tck/features/optimizer/PushFilterDownCrossJoinRule.feature new file mode 100644 index 00000000000..ee2263c921c --- /dev/null +++ b/tests/tck/features/optimizer/PushFilterDownCrossJoinRule.feature @@ -0,0 +1,32 @@ +# Copyright (c) 2023 vesoft inc. All rights reserved. +# +# This source code is licensed under Apache 2.0 License. +Feature: Push Filter down HashInnerJoin rule + + Background: + Given a graph with space named "nba" + + Scenario: push filter down HashInnerJoin + When profiling query: + """ + with ['Tim Duncan', 'Tony Parker'] as id_list + match (v1:player)-[e]-(v2:player) + where id(v1) in ['Tim Duncan', 'Tony Parker'] AND id(v2) in ['Tim Duncan', 'Tony Parker'] + return count(e) + """ + Then the result should be, in any order: + | count(e) | + | 8 | + And the execution plan should be: + | id | name | dependencies | operator info | + | 11 | Aggregate | 14 | | + | 14 | CrossJoin | 1,16 | | + | 1 | Project | 2 | | + | 2 | Start | | | + | 16 | Project | 15 | | + | 15 | Filter | 18 | {"condition": "((id($-.v1) IN [\"Tim Duncan\",\"Tony Parker\"]) AND (id($-.v2) IN [\"Tim Duncan\",\"Tony Parker\"]))"} | + | 18 | AppendVertices | 17 | | + | 17 | Traverse | 4 | | + | 4 | Dedup | 3 | | + | 3 | PassThrough | 5 | | + | 5 | Start | | |