Skip to content

Commit

Permalink
fix push down rank(e) (vesoft-inc#5135)
Browse files Browse the repository at this point in the history
  • Loading branch information
jievince authored Dec 29, 2022
1 parent 967a8c9 commit 373a847
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 43 deletions.
40 changes: 24 additions & 16 deletions src/graph/optimizer/rule/PushFilterDownTraverseRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,34 @@ StatusOr<OptRule::TransformResult> PushFilterDownTraverseRule::transform(
auto pool = qctx->objPool();

// Pick the expr looks like `$-.e[0].likeness
auto picker = [&edgeAlias](const Expression* e) -> bool {
// TODO(jie): Handle the strange exists expr. e.g. exists(e.likeness)
auto exprs = graph::ExpressionUtils::collectAll(e, {Expression::Kind::kPredicate});
for (auto* expr : exprs) {
if (static_cast<const PredicateExpression*>(expr)->name() == "exists") {
auto picker = [&edgeAlias](const Expression* expr) -> bool {
bool shouldNotPick = false;
auto finder = [&shouldNotPick, &edgeAlias](const Expression* e) -> bool {
// When visiting the expression tree and find an expession node is a one step edge property
// expression, stop visiting its children and return true.
if (graph::ExpressionUtils::isOneStepEdgeProp(edgeAlias, e)) return true;
// Otherwise, continue visiting its children. And if the following two conditions are met,
// mark the expression as shouldNotPick and return false.
if (e->kind() == Expression::Kind::kInputProperty ||
e->kind() == Expression::Kind::kVarProperty) {
shouldNotPick = true;
return false;
}
// TODO(jie): Handle the strange exists expr. e.g. exists(e.likeness)
if (e->kind() == Expression::Kind::kPredicate &&
static_cast<const PredicateExpression*>(e)->name() == "exists") {
shouldNotPick = true;
return false;
}
}

auto varProps = graph::ExpressionUtils::collectAll(
e, {Expression::Kind::kInputProperty, Expression::Kind::kVarProperty});
if (varProps.empty()) {
return false;
};
graph::FindVisitor visitor(finder, true, true);
const_cast<Expression*>(expr)->accept(&visitor);
if (shouldNotPick) return false;
if (!visitor.results().empty()) {
return true;
}
for (auto* expr : varProps) {
DCHECK(graph::ExpressionUtils::isPropertyExpr(expr));
auto& propName = static_cast<const PropertyExpression*>(expr)->prop();
if (propName != edgeAlias) return false;
}
return true;
return false;
};
Expression* filterPicked = nullptr;
Expression* filterUnpicked = nullptr;
Expand Down
36 changes: 34 additions & 2 deletions src/graph/util/ExpressionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,38 @@ Expression *ExpressionUtils::rewriteAttr2LabelTagProp(
return RewriteVisitor::transform(expr, std::move(matcher), std::move(rewriter));
}

// rewrite rank(e) to e._rank
Expression *ExpressionUtils::rewriteRankFunc2LabelAttribute(
const Expression *expr, const std::unordered_map<std::string, AliasType> &aliasTypeMap) {
ObjectPool *pool = expr->getObjPool();
auto matcher = [&aliasTypeMap](const Expression *e) -> bool {
if (e->kind() != Expression::Kind::kFunctionCall) return false;

auto *funcExpr = static_cast<const FunctionCallExpression *>(e);
auto funcName = funcExpr->name();
std::transform(funcName.begin(), funcName.end(), funcName.begin(), ::tolower);
if (funcName != "rank") return false;
auto args = funcExpr->args()->args();
if (args.size() != 1) return false;
if (args[0]->kind() != Expression::Kind::kLabel) return false;

auto &label = static_cast<const LabelExpression *>(args[0])->name();
auto iter = aliasTypeMap.find(label);
if (iter == aliasTypeMap.end() || iter->second != AliasType::kEdge) {
return false;
}
return true;
};
auto rewriter = [pool](const Expression *e) -> Expression * {
auto funcExpr = static_cast<const FunctionCallExpression *>(e);
auto args = funcExpr->args()->args();
return LabelAttributeExpression::make(
pool, static_cast<LabelExpression *>(args[0]), ConstantExpression::make(pool, "_rank"));
};

return RewriteVisitor::transform(expr, std::move(matcher), std::move(rewriter));
}

Expression *ExpressionUtils::rewriteLabelAttr2TagProp(const Expression *expr) {
ObjectPool *pool = expr->getObjPool();
auto matcher = [](const Expression *e) -> bool {
Expand Down Expand Up @@ -1518,7 +1550,7 @@ bool ExpressionUtils::checkExprDepth(const Expression *expr) {
}

/*static*/
bool ExpressionUtils::isSingleLenExpandExpr(const std::string &edgeAlias, const Expression *expr) {
bool ExpressionUtils::isOneStepEdgeProp(const std::string &edgeAlias, const Expression *expr) {
if (expr->kind() != Expression::Kind::kAttribute) {
return false;
}
Expand Down Expand Up @@ -1562,7 +1594,7 @@ bool ExpressionUtils::isSingleLenExpandExpr(const std::string &edgeAlias, const
const std::string &edgeAlias,
Expression *expr) {
graph::RewriteVisitor::Matcher matcher = [&edgeAlias](const Expression *e) -> bool {
return isSingleLenExpandExpr(edgeAlias, e);
return isOneStepEdgeProp(edgeAlias, e);
};
graph::RewriteVisitor::Rewriter rewriter = [pool](const Expression *e) -> Expression * {
DCHECK_EQ(e->kind(), Expression::Kind::kAttribute);
Expand Down
6 changes: 5 additions & 1 deletion src/graph/util/ExpressionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class ExpressionUtils {
static Expression* rewriteAttr2LabelTagProp(
const Expression* expr, const std::unordered_map<std::string, AliasType>& aliasTypeMap);

// rewrite rank(e) to e._rank
static Expression* rewriteRankFunc2LabelAttribute(
const Expression* expr, const std::unordered_map<std::string, AliasType>& aliasTypeMap);

// rewrite LabelAttr to tagProp
static Expression* rewriteLabelAttr2TagProp(const Expression* expr);

Expand Down Expand Up @@ -239,7 +243,7 @@ class ExpressionUtils {
static bool isVidPredication(const Expression* expr);

// Check if the expr looks like `$-.e[0].likeness`
static bool isSingleLenExpandExpr(const std::string& edgeAlias, const Expression* expr);
static bool isOneStepEdgeProp(const std::string& edgeAlias, const Expression* expr);

static Expression* rewriteEdgePropertyFilter(ObjectPool* pool,
const std::string& edgeAlias,
Expand Down
8 changes: 6 additions & 2 deletions src/graph/validator/MatchValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,12 @@ Status MatchValidator::validateFilter(const Expression *filter,
auto transformRes = ExpressionUtils::filterTransform(newFilter);
NG_RETURN_IF_ERROR(transformRes);
// rewrite Attribute to LabelTagProperty
whereClauseCtx.filter = ExpressionUtils::rewriteAttr2LabelTagProp(
transformRes.value(), whereClauseCtx.aliasesAvailable);
newFilter = ExpressionUtils::rewriteAttr2LabelTagProp(transformRes.value(),
whereClauseCtx.aliasesAvailable);
newFilter =
ExpressionUtils::rewriteRankFunc2LabelAttribute(newFilter, whereClauseCtx.aliasesAvailable);

whereClauseCtx.filter = newFilter;

auto typeStatus = deduceExprType(whereClauseCtx.filter);
NG_RETURN_IF_ERROR(typeStatus);
Expand Down
40 changes: 21 additions & 19 deletions src/graph/visitor/FindVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@ namespace nebula {
namespace graph {

void FindVisitor::visit(TypeCastingExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
expr->operand()->accept(this);
}

void FindVisitor::visit(UnaryExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
expr->operand()->accept(this);
}

void FindVisitor::visit(FunctionCallExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
for (const auto& arg : expr->args()->args()) {
arg->accept(this);
Expand All @@ -28,13 +28,13 @@ void FindVisitor::visit(FunctionCallExpression* expr) {
}

void FindVisitor::visit(AggregateExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
expr->arg()->accept(this);
}

void FindVisitor::visit(ListExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
for (const auto& item : expr->items()) {
item->accept(this);
Expand All @@ -43,7 +43,7 @@ void FindVisitor::visit(ListExpression* expr) {
}

void FindVisitor::visit(SetExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
for (const auto& item : expr->items()) {
item->accept(this);
Expand All @@ -52,7 +52,7 @@ void FindVisitor::visit(SetExpression* expr) {
}

void FindVisitor::visit(MapExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
for (const auto& pair : expr->items()) {
pair.second->accept(this);
Expand All @@ -61,7 +61,7 @@ void FindVisitor::visit(MapExpression* expr) {
}

void FindVisitor::visit(CaseExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;

if (expr->hasCondition()) {
Expand All @@ -81,7 +81,7 @@ void FindVisitor::visit(CaseExpression* expr) {
}

void FindVisitor::visit(PredicateExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;

expr->collection()->accept(this);
Expand All @@ -92,7 +92,7 @@ void FindVisitor::visit(PredicateExpression* expr) {
}

void FindVisitor::visit(ReduceExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;

expr->initial()->accept(this);
Expand All @@ -104,7 +104,7 @@ void FindVisitor::visit(ReduceExpression* expr) {
}

void FindVisitor::visit(ListComprehensionExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;

expr->collection()->accept(this);
Expand All @@ -121,7 +121,7 @@ void FindVisitor::visit(ListComprehensionExpression* expr) {
}

void FindVisitor::visit(LogicalExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
for (const auto& operand : expr->operands()) {
operand->accept(this);
Expand Down Expand Up @@ -190,7 +190,7 @@ void FindVisitor::visit(LabelExpression* expr) {
}

void FindVisitor::visit(LabelAttributeExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
expr->left()->accept(this);
if (!needFindAll_ && !foundExprs_.empty()) return;
Expand All @@ -202,7 +202,7 @@ void FindVisitor::visit(VertexExpression* expr) {
}

void FindVisitor::visit(LabelTagPropertyExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
expr->label()->accept(this);
}
Expand All @@ -216,7 +216,7 @@ void FindVisitor::visit(ColumnExpression* expr) {
}

void FindVisitor::visit(PathBuildExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
for (const auto& item : expr->items()) {
item->accept(this);
Expand All @@ -225,7 +225,7 @@ void FindVisitor::visit(PathBuildExpression* expr) {
}

void FindVisitor::visit(SubscriptRangeExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;

expr->list()->accept(this);
Expand All @@ -243,7 +243,7 @@ void FindVisitor::visit(SubscriptRangeExpression* expr) {
}

void FindVisitor::visit(MatchPathPatternExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
if (expr->genList() != nullptr) {
expr->genList()->accept(this);
Expand All @@ -252,17 +252,19 @@ void FindVisitor::visit(MatchPathPatternExpression* expr) {
}

void FindVisitor::visitBinaryExpr(BinaryExpression* expr) {
findInCurrentExpr(expr);
if (findInCurrentExpr(expr) && stopVisitChildrenAfterFind_) return;
if (!needFindAll_ && !foundExprs_.empty()) return;
expr->left()->accept(this);
if (!needFindAll_ && !foundExprs_.empty()) return;
expr->right()->accept(this);
}

void FindVisitor::findInCurrentExpr(Expression* expr) {
bool FindVisitor::findInCurrentExpr(Expression* expr) {
if (finder_(expr)) {
foundExprs_.emplace_back(expr);
return true;
}
return false;
}

} // namespace graph
Expand Down
11 changes: 8 additions & 3 deletions src/graph/visitor/FindVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ class FindVisitor final : public ExprVisitorImpl {
public:
using Finder = std::function<bool(Expression*)>;

explicit FindVisitor(Finder finder, bool needFindAll = false)
: finder_(finder), needFindAll_(needFindAll) {}
explicit FindVisitor(Finder finder,
bool needFindAll = false,
bool stopVisitChildrenAfterFind = false)
: finder_(finder),
needFindAll_(needFindAll),
stopVisitChildrenAfterFind_(stopVisitChildrenAfterFind) {}

bool ok() const override {
// TODO: delete this interface
Expand Down Expand Up @@ -80,11 +84,12 @@ class FindVisitor final : public ExprVisitorImpl {
void visit(MatchPathPatternExpression* expr) override;

void visitBinaryExpr(BinaryExpression* expr) override;
void findInCurrentExpr(Expression* expr);
bool findInCurrentExpr(Expression* expr);

private:
Finder finder_;
bool needFindAll_;
bool stopVisitChildrenAfterFind_{false};
std::vector<const Expression*> foundExprs_;
};

Expand Down
Loading

0 comments on commit 373a847

Please sign in to comment.