Skip to content

Commit

Permalink
Push filter down cross join (#5473)
Browse files Browse the repository at this point in the history
* fix comment

* push down filter through cross join

---------

Co-authored-by: Sophie <[email protected]>
  • Loading branch information
yixinglu and Sophie-Xie committed Apr 6, 2023
1 parent 37a24f1 commit 4be13b0
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 57 deletions.
1 change: 1 addition & 0 deletions src/graph/executor/algo/CartesianProductExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "graph/executor/algo/CartesianProductExecutor.h"

#include "graph/planner/plan/Algo.h"
#include "graph/planner/plan/Query.h"

namespace nebula {
namespace graph {
Expand Down
1 change: 1 addition & 0 deletions src/graph/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ nebula_add_library(
OptGroup.cpp
OptRule.cpp
OptContext.cpp
rule/PushFilterDownCrossJoinRule.cpp
rule/PushFilterDownGetNbrsRule.cpp
rule/RemoveNoopProjectRule.cpp
rule/CombineFilterRule.cpp
Expand Down
139 changes: 139 additions & 0 deletions src/graph/optimizer/rule/PushFilterDownCrossJoinRule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/* Copyright (c) 2023 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

#include "graph/optimizer/rule/PushFilterDownCrossJoinRule.h"

#include "graph/optimizer/OptContext.h"
#include "graph/optimizer/OptGroup.h"
#include "graph/planner/plan/PlanNode.h"
#include "graph/planner/plan/Query.h"
#include "graph/util/ExpressionUtils.h"

using nebula::graph::CrossJoin;
using nebula::graph::ExpressionUtils;
using nebula::graph::Filter;
using nebula::graph::PlanNode;
using nebula::graph::QueryContext;

namespace nebula {
namespace opt {

std::unique_ptr<OptRule> PushFilterDownCrossJoinRule::kInstance =
std::unique_ptr<PushFilterDownCrossJoinRule>(new PushFilterDownCrossJoinRule());

PushFilterDownCrossJoinRule::PushFilterDownCrossJoinRule() {
RuleSet::QueryRules().addRule(this);
}

const Pattern& PushFilterDownCrossJoinRule::pattern() const {
static Pattern pattern = Pattern::create(
PlanNode::Kind::kFilter,
{Pattern::create(
PlanNode::Kind::kCrossJoin,
{Pattern::create(PlanNode::Kind::kUnknown), Pattern::create(PlanNode::Kind::kUnknown)})});
return pattern;
}

StatusOr<OptRule::TransformResult> PushFilterDownCrossJoinRule::transform(
OptContext* octx, const MatchedResult& matched) const {
auto* filterGroupNode = matched.node;
auto* oldFilterNode = filterGroupNode->node();
DCHECK_EQ(oldFilterNode->kind(), PlanNode::Kind::kFilter);

auto* crossJoinNode = matched.planNode({0, 0});
DCHECK_EQ(crossJoinNode->kind(), PlanNode::Kind::kCrossJoin);
auto* oldCrossJoinNode = static_cast<const CrossJoin*>(crossJoinNode);

const auto* condition = static_cast<Filter*>(oldFilterNode)->condition();
DCHECK(condition);

const auto& leftResult = matched.result({0, 0, 0});
const auto& rightResult = matched.result({0, 0, 1});

Expression *leftFilterUnpicked = nullptr, *rightFilterUnpicked = nullptr;
OptGroup* leftGroup = pushFilterDownChild(octx, leftResult, condition, &leftFilterUnpicked);
OptGroup* rightGroup =
pushFilterDownChild(octx, rightResult, leftFilterUnpicked, &rightFilterUnpicked);

if (!leftGroup && !rightGroup) {
return TransformResult::noTransform();
}

leftGroup = leftGroup ? leftGroup : const_cast<OptGroup*>(leftResult.node->group());
rightGroup = rightGroup ? rightGroup : const_cast<OptGroup*>(rightResult.node->group());

// produce new CrossJoin node
auto* newCrossJoinNode = static_cast<CrossJoin*>(oldCrossJoinNode->clone());
auto newJoinGroup = rightFilterUnpicked ? OptGroup::create(octx) : filterGroupNode->group();
// TODO(yee): it's too tricky
auto newGroupNode = rightFilterUnpicked
? const_cast<OptGroup*>(newJoinGroup)->makeGroupNode(newCrossJoinNode)
: OptGroupNode::create(octx, newCrossJoinNode, newJoinGroup);
newGroupNode->dependsOn(leftGroup);
newGroupNode->dependsOn(rightGroup);
newCrossJoinNode->setLeftVar(leftGroup->outputVar());
newCrossJoinNode->setRightVar(rightGroup->outputVar());

if (rightFilterUnpicked) {
auto newFilterNode = Filter::make(octx->qctx(), nullptr, rightFilterUnpicked);
newFilterNode->setOutputVar(oldFilterNode->outputVar());
newFilterNode->setColNames(oldFilterNode->colNames());
newFilterNode->setInputVar(newCrossJoinNode->outputVar());
newGroupNode = OptGroupNode::create(octx, newFilterNode, filterGroupNode->group());
newGroupNode->dependsOn(const_cast<OptGroup*>(newJoinGroup));
} else {
newCrossJoinNode->setOutputVar(oldFilterNode->outputVar());
newCrossJoinNode->setColNames(oldCrossJoinNode->colNames());
}

TransformResult result;
result.eraseAll = true;
result.newGroupNodes.emplace_back(newGroupNode);
return result;
}

OptGroup* PushFilterDownCrossJoinRule::pushFilterDownChild(OptContext* octx,
const MatchedResult& child,
const Expression* condition,
Expression** unpickedFilter) {
if (!condition) return nullptr;

const auto* childPlanNode = DCHECK_NOTNULL(child.node->node());
const auto& colNames = childPlanNode->colNames();

// split the `condition` based on whether the varPropExpr comes from the left child
auto picker = [&colNames](const Expression* e) -> bool {
return ExpressionUtils::checkColName(colNames, e);
};

Expression* filterPicked = nullptr;
ExpressionUtils::splitFilter(condition, picker, &filterPicked, unpickedFilter);
if (!filterPicked) return nullptr;

auto* newChildPlanNode = childPlanNode->clone();
DCHECK_NE(childPlanNode->outputVar(), newChildPlanNode->outputVar());
newChildPlanNode->setInputVar(childPlanNode->inputVar());
newChildPlanNode->setColNames(childPlanNode->colNames());
auto* newChildGroup = OptGroup::create(octx);
auto* newChildGroupNode = newChildGroup->makeGroupNode(newChildPlanNode);
for (auto* g : child.node->dependencies()) {
newChildGroupNode->dependsOn(g);
}

auto* newFilterNode = Filter::make(octx->qctx(), nullptr, filterPicked);
newFilterNode->setOutputVar(childPlanNode->outputVar());
newFilterNode->setColNames(colNames);
newFilterNode->setInputVar(newChildPlanNode->outputVar());
auto* group = OptGroup::create(octx);
group->makeGroupNode(newFilterNode)->dependsOn(newChildGroup);
return group;
}

std::string PushFilterDownCrossJoinRule::toString() const {
return "PushFilterDownCrossJoinRule";
}

} // namespace opt
} // namespace nebula
37 changes: 37 additions & 0 deletions src/graph/optimizer/rule/PushFilterDownCrossJoinRule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright (c) 2023 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

#ifndef GRAPH_OPTIMIZER_RULE_PUSHFILTERDOWNCROSSJOINRULE_H_
#define GRAPH_OPTIMIZER_RULE_PUSHFILTERDOWNCROSSJOINRULE_H_

#include "graph/optimizer/OptRule.h"

namespace nebula {
namespace opt {

// Push down the filter items into the child sub-plan of [[CrossJoin]]
class PushFilterDownCrossJoinRule final : public OptRule {
public:
const Pattern &pattern() const override;

StatusOr<OptRule::TransformResult> transform(OptContext *octx,
const MatchedResult &matched) const override;

std::string toString() const override;

private:
PushFilterDownCrossJoinRule();
static OptGroup *pushFilterDownChild(OptContext *octx,
const MatchedResult &child,
const Expression *condition,
Expression **unpickedFilter);

static std::unique_ptr<OptRule> kInstance;
};

} // namespace opt
} // namespace nebula

#endif // GRAPH_OPTIMIZER_RULE_PUSHFILTERDOWNCROSSJOINRULE_H_
3 changes: 1 addition & 2 deletions src/graph/planner/match/MatchSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,7 @@ Expression* MatchSolver::makeIndexFilter(const std::string& label,

auto* root = relationals[0];
for (auto i = 1u; i < relationals.size(); i++) {
auto* left = root;
root = LogicalExpression::makeAnd(qctx->objPool(), left, relationals[i]);
root = LogicalExpression::makeAnd(qctx->objPool(), root, relationals[i]);
}

return root;
Expand Down
2 changes: 1 addition & 1 deletion src/graph/planner/match/StartVidFinder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ using StartVidFinderInstantiateFunc = std::function<std::unique_ptr<StartVidFind
// 3. PropIndexSeek finds if a plan could traverse from some vids that could be
// read from the property indices.
// MATCH(n:Tag{prop:value}) RETURN n
// MATCH(n:Tag) WHERE n.prop = value RETURN n
// MATCH(n:Tag) WHERE n.Tag.prop = value RETURN n
//
// 4. LabelIndexSeek finds if a plan could traverse from some vids that could be
// read from the label indices.
Expand Down
28 changes: 0 additions & 28 deletions src/graph/planner/plan/Algo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,34 +139,6 @@ std::vector<std::string> CartesianProduct::inputVars() const {
return varNames;
}

std::unique_ptr<PlanNodeDescription> CrossJoin::explain() const {
return BinaryInputNode::explain();
}

PlanNode* CrossJoin::clone() const {
auto* node = make(qctx_);
node->cloneMembers(*this);
return node;
}

void CrossJoin::cloneMembers(const CrossJoin& r) {
BinaryInputNode::cloneMembers(r);
}

CrossJoin::CrossJoin(QueryContext* qctx, PlanNode* left, PlanNode* right)
: BinaryInputNode(qctx, Kind::kCrossJoin, left, right) {
auto lColNames = left->colNames();
auto rColNames = right->colNames();
lColNames.insert(lColNames.end(), rColNames.begin(), rColNames.end());
setColNames(lColNames);
}

void CrossJoin::accept(PlanNodeVisitor* visitor) {
visitor->visit(this);
}

CrossJoin::CrossJoin(QueryContext* qctx) : BinaryInputNode(qctx, Kind::kCrossJoin) {}

std::unique_ptr<PlanNodeDescription> Subgraph::explain() const {
auto desc = SingleInputNode::explain();
addDescription("src", src_ ? src_->toString() : "", desc.get());
Expand Down
26 changes: 0 additions & 26 deletions src/graph/planner/plan/Algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,32 +437,6 @@ class Subgraph final : public SingleInputNode {
std::unique_ptr<std::vector<EdgeProp>> edgeProps_;
};

class CrossJoin final : public BinaryInputNode {
public:
static CrossJoin* make(QueryContext* qctx, PlanNode* left, PlanNode* right) {
return qctx->objPool()->makeAndAdd<CrossJoin>(qctx, left, right);
}

std::unique_ptr<PlanNodeDescription> explain() const override;

PlanNode* clone() const override;

void accept(PlanNodeVisitor* visitor) override;

private:
friend ObjectPool;

// used for clone only
static CrossJoin* make(QueryContext* qctx) {
return qctx->objPool()->makeAndAdd<CrossJoin>(qctx);
}

void cloneMembers(const CrossJoin& r);

CrossJoin(QueryContext* qctx, PlanNode* left, PlanNode* right);
// use for clone
explicit CrossJoin(QueryContext* qctx);
};
} // namespace graph
} // namespace nebula
#endif // GRAPH_PLANNER_PLAN_ALGO_H_
28 changes: 28 additions & 0 deletions src/graph/planner/plan/Query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,34 @@ void HashInnerJoin::cloneMembers(const HashInnerJoin& l) {
HashJoin::cloneMembers(l);
}

std::unique_ptr<PlanNodeDescription> CrossJoin::explain() const {
return BinaryInputNode::explain();
}

PlanNode* CrossJoin::clone() const {
auto* node = make(qctx_);
node->cloneMembers(*this);
return node;
}

void CrossJoin::cloneMembers(const CrossJoin& r) {
BinaryInputNode::cloneMembers(r);
}

CrossJoin::CrossJoin(QueryContext* qctx, PlanNode* left, PlanNode* right)
: BinaryInputNode(qctx, Kind::kCrossJoin, left, right) {
auto lColNames = left->colNames();
auto rColNames = right->colNames();
lColNames.insert(lColNames.end(), rColNames.begin(), rColNames.end());
setColNames(lColNames);
}

void CrossJoin::accept(PlanNodeVisitor* visitor) {
visitor->visit(this);
}

CrossJoin::CrossJoin(QueryContext* qctx) : BinaryInputNode(qctx, Kind::kCrossJoin) {}

std::unique_ptr<PlanNodeDescription> RollUpApply::explain() const {
auto desc = BinaryInputNode::explain();
addDescription("compareCols", folly::toJson(util::toJson(compareCols_)), desc.get());
Expand Down
27 changes: 27 additions & 0 deletions src/graph/planner/plan/Query.h
Original file line number Diff line number Diff line change
Expand Up @@ -1899,6 +1899,33 @@ class HashInnerJoin final : public HashJoin {
void cloneMembers(const HashInnerJoin&);
};

class CrossJoin final : public BinaryInputNode {
public:
static CrossJoin* make(QueryContext* qctx, PlanNode* left, PlanNode* right) {
return qctx->objPool()->makeAndAdd<CrossJoin>(qctx, left, right);
}

std::unique_ptr<PlanNodeDescription> explain() const override;

PlanNode* clone() const override;

void accept(PlanNodeVisitor* visitor) override;

private:
friend ObjectPool;

// used for clone only
static CrossJoin* make(QueryContext* qctx) {
return qctx->objPool()->makeAndAdd<CrossJoin>(qctx);
}

void cloneMembers(const CrossJoin& r);

CrossJoin(QueryContext* qctx, PlanNode* left, PlanNode* right);
// use for clone
explicit CrossJoin(QueryContext* qctx);
};

// Roll Up Apply two results from two inputs.
class RollUpApply : public BinaryInputNode {
public:
Expand Down
32 changes: 32 additions & 0 deletions tests/tck/features/optimizer/PushFilterDownCrossJoinRule.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2023 vesoft inc. All rights reserved.
#
# This source code is licensed under Apache 2.0 License.
Feature: Push Filter down HashInnerJoin rule

Background:
Given a graph with space named "nba"

Scenario: push filter down HashInnerJoin
When profiling query:
"""
with ['Tim Duncan', 'Tony Parker'] as id_list
match (v1:player)-[e]-(v2:player)
where id(v1) in ['Tim Duncan', 'Tony Parker'] AND id(v2) in ['Tim Duncan', 'Tony Parker']
return count(e)
"""
Then the result should be, in any order:
| count(e) |
| 8 |
And the execution plan should be:
| id | name | dependencies | operator info |
| 11 | Aggregate | 14 | |
| 14 | CrossJoin | 1,16 | |
| 1 | Project | 2 | |
| 2 | Start | | |
| 16 | Project | 15 | |
| 15 | Filter | 18 | {"condition": "((id($-.v1) IN [\"Tim Duncan\",\"Tony Parker\"]) AND (id($-.v2) IN [\"Tim Duncan\",\"Tony Parker\"]))"} |
| 18 | AppendVertices | 17 | |
| 17 | Traverse | 4 | |
| 4 | Dedup | 3 | |
| 3 | PassThrough | 5 | |
| 5 | Start | | |

0 comments on commit 4be13b0

Please sign in to comment.