Skip to content

Commit

Permalink
Feature/pattern expression ref local variable (vesoft-inc#1169)
Browse files Browse the repository at this point in the history
* Initial.

* Get pattern expression value by expression.

* Filter path by local variable.

* Add tests.

Co-authored-by: Sophie <[email protected]>

Co-authored-by: shylock <[email protected]>
Co-authored-by: Sophie <[email protected]>
  • Loading branch information
3 people authored Aug 16, 2022
1 parent 7350fef commit 120d852
Show file tree
Hide file tree
Showing 18 changed files with 426 additions and 44 deletions.
8 changes: 4 additions & 4 deletions src/common/expression/MatchPathPatternExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace nebula {

const Value& MatchPathPatternExpression::eval(ExpressionContext& ctx) {
result_ = DCHECK_NOTNULL(prop_)->eval(ctx);
result_ = DCHECK_NOTNULL(genList_)->eval(ctx);
return result_;
}

Expand All @@ -23,7 +23,7 @@ bool MatchPathPatternExpression::operator==(const Expression& rhs) const {
return false;
}

// The prop_ field is used for evaluation internally, so it don't identify the expression.
// The genList_ field is used for evaluation internally, so it don't identify the expression.
// We don't compare it here.
// Ditto for result_ field.

Expand All @@ -41,8 +41,8 @@ void MatchPathPatternExpression::accept(ExprVisitor* visitor) {
Expression* MatchPathPatternExpression::clone() const {
auto expr =
MatchPathPatternExpression::make(pool_, std::make_unique<MatchPath>(matchPath_->clone()));
if (prop_ != nullptr) {
expr->setInputProp(static_cast<InputPropertyExpression*>(prop_->clone()));
if (genList_ != nullptr) {
expr->setGenList(genList_->clone());
}
return expr;
}
Expand Down
18 changes: 9 additions & 9 deletions src/common/expression/MatchPathPatternExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,16 @@ class MatchPathPatternExpression final : public Expression {

Expression* clone() const override;

// Evaluate expression by fetch result from input variable
void setInputProp(const std::string& prop) {
prop_ = InputPropertyExpression::make(pool_, prop);
void setGenList(Expression* expr) {
genList_ = expr;
}

void setInputProp(InputPropertyExpression* expr) {
prop_ = expr;
const Expression* genList() const {
return genList_;
}

InputPropertyExpression* inputProp() const {
return prop_;
Expression* genList() {
return genList_;
}

const MatchPath& matchPath() const {
Expand Down Expand Up @@ -71,8 +70,9 @@ class MatchPathPatternExpression final : public Expression {

private:
std::unique_ptr<MatchPath> matchPath_;
InputPropertyExpression* prop_{
nullptr}; // The column of input stored the result of the expression
// The column of input stored the result of the expression
// The filter apply to each path in result List and generate a new result List.
Expression* genList_{nullptr};
Value result_;
};
} // namespace nebula
Expand Down
24 changes: 24 additions & 0 deletions src/common/function/FunctionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ std::unordered_map<std::string, std::vector<TypeSignature>> FunctionManager::typ
{TypeSignature({Value::Type::STRING}, Value::Type::DURATION),
TypeSignature({Value::Type::MAP}, Value::Type::DURATION)}},
{"extract", {TypeSignature({Value::Type::STRING, Value::Type::STRING}, Value::Type::LIST)}},
{"_nodeid", {TypeSignature({Value::Type::PATH, Value::Type::INT}, Value::Type::INT)}},
};

// static
Expand Down Expand Up @@ -2728,6 +2729,29 @@ FunctionManager::FunctionManager() {
return res;
};
}
{
auto &attr = functions_["_nodeid"];
attr.minArity_ = 2;
attr.maxArity_ = 2;
attr.isAlwaysPure_ = true;
attr.body_ = [](const auto &args) -> Value {
if (!args[0].get().isPath() || !args[1].get().isInt()) {
return Value::kNullBadType;
}

const auto &p = args[0].get().getPath();
const std::size_t nodeIndex = args[1].get().getInt();
if (nodeIndex < 0 || nodeIndex >= (1 + p.steps.size())) {
DLOG(FATAL) << "Out of range node index.";
return Value::kNullBadData;
}
if (nodeIndex == 0) {
return p.src.vid;
} else {
return p.steps[nodeIndex - 1].dst.vid;
}
};
}
} // NOLINT

// static
Expand Down
4 changes: 4 additions & 0 deletions src/graph/util/AnonVarGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ class AnonVarGenerator final {
return var;
}

void createVar(const std::string& var) const {
symTable_->newVariable(var);
}

// Check is variable anonymous
// The parser don't allow user name variable started with `_`,
// `_` started variable is generated by nebula only.
Expand Down
12 changes: 6 additions & 6 deletions src/graph/util/ParserUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@ namespace graph {
void ParserUtil::rewriteLC(QueryContext *qctx,
ListComprehensionExpression *lc,
const std::string &oldVarName) {
const auto &newVarName = qctx->vctx()->anonVarGen()->getVar();
qctx->ectx()->setValue(newVarName, Value());
qctx->vctx()->anonVarGen()->createVar(oldVarName);
qctx->ectx()->setValue(oldVarName, Value());
auto *pool = qctx->objPool();

auto matcher = [](const Expression *expr) -> bool {
return expr->kind() == Expression::Kind::kLabel ||
expr->kind() == Expression::Kind::kLabelAttribute;
};

auto rewriter = [&, pool, newVarName](const Expression *expr) {
auto rewriter = [&, pool, oldVarName](const Expression *expr) {
Expression *ret = nullptr;
if (expr->kind() == Expression::Kind::kLabel) {
auto *label = static_cast<const LabelExpression *>(expr);
if (label->name() == oldVarName) {
ret = VariableExpression::make(pool, newVarName, true);
ret = VariableExpression::make(pool, oldVarName, true);
} else {
ret = label->clone();
}
Expand All @@ -38,7 +38,7 @@ void ParserUtil::rewriteLC(QueryContext *qctx,
if (la->left()->name() == oldVarName) {
const auto &value = la->right()->value();
ret = AttributeExpression::make(pool,
VariableExpression::make(pool, newVarName, true),
VariableExpression::make(pool, oldVarName, true),
ConstantExpression::make(pool, value));
} else {
ret = la->clone();
Expand All @@ -48,7 +48,7 @@ void ParserUtil::rewriteLC(QueryContext *qctx,
};

lc->setOriginString(lc->toString());
lc->setInnerVar(newVarName);
lc->setInnerVar(oldVarName);
if (lc->hasFilter()) {
Expression *filter = lc->filter();
auto *newFilter = RewriteVisitor::transform(filter, matcher, rewriter);
Expand Down
14 changes: 13 additions & 1 deletion src/graph/validator/MatchValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "graph/util/ExpressionUtils.h"
#include "graph/visitor/ExtractGroupSuiteVisitor.h"
#include "graph/visitor/RewriteVisitor.h"
#include "graph/visitor/ValidatePatternExpressionVisitor.h"

namespace nebula {
namespace graph {
Expand Down Expand Up @@ -1007,6 +1008,9 @@ Status MatchValidator::validateMatchPathExpr(
Expression *expr,
const std::unordered_map<std::string, AliasType> &availableAliases,
std::vector<Path> &paths) {
auto *pool = qctx_->objPool();
ValidatePatternExpressionVisitor visitor(pool, vctx_);
expr->accept(&visitor);
auto matchPathExprs = ExpressionUtils::collectAll(expr, {Expression::Kind::kMatchPathPattern});
for (auto &matchPathExpr : matchPathExprs) {
// auto matchClauseCtx = getContext<MatchClauseContext>();
Expand All @@ -1021,7 +1025,11 @@ Status MatchValidator::validateMatchPathExpr(
auto &matchPath = matchPathExprImpl->matchPath();
auto pathAlias = matchPath.toString();
matchPath.setAlias(new std::string(pathAlias));
matchPathExprImpl->setInputProp(pathAlias);
if (matchPathExprImpl->genList() == nullptr) {
// Don't done in expression visitor
Expression *genList = InputPropertyExpression::make(pool, pathAlias);
matchPathExprImpl->setGenList(genList);
}
paths.emplace_back();
NG_RETURN_IF_ERROR(validatePath(&matchPath, paths.back()));
NG_RETURN_IF_ERROR(buildRollUpPathInfo(&matchPath, paths.back()));
Expand All @@ -1045,6 +1053,10 @@ Status MatchValidator::validateMatchPathExpr(
}
}
for (const auto &node : matchPath.nodes()) {
if (node->variableDefinedSource() == MatchNode::VariableDefinedSource::kExpression) {
// Checked in visitor
continue;
}
if (!node->alias().empty()) {
const auto find = availableAliases.find(node->alias());
if (find == availableAliases.end()) {
Expand Down
1 change: 1 addition & 0 deletions src/graph/visitor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ nebula_add_library(
VidExtractVisitor.cpp
EvaluableExprVisitor.cpp
ExtractGroupSuiteVisitor.cpp
ValidatePatternExpressionVisitor.cpp
)

nebula_add_library(
Expand Down
4 changes: 2 additions & 2 deletions src/graph/visitor/ExprVisitorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ void ExprVisitorImpl::visit(SubscriptRangeExpression *expr) {

void ExprVisitorImpl::visit(MatchPathPatternExpression *expr) {
DCHECK(ok()) << expr->toString();
if (expr->inputProp() != nullptr) {
expr->inputProp()->accept(this);
if (expr->genList() != nullptr) {
expr->genList()->accept(this);
if (!ok()) {
return;
}
Expand Down
4 changes: 2 additions & 2 deletions src/graph/visitor/FindVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ void FindVisitor::visit(SubscriptRangeExpression* expr) {
void FindVisitor::visit(MatchPathPatternExpression* expr) {
findInCurrentExpr(expr);
if (!needFindAll_ && !foundExprs_.empty()) return;
if (expr->inputProp() != nullptr) {
expr->inputProp()->accept(this);
if (expr->genList() != nullptr) {
expr->genList()->accept(this);
if (!needFindAll_ && !foundExprs_.empty()) return;
}
}
Expand Down
11 changes: 4 additions & 7 deletions src/graph/visitor/RewriteSymExprVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,14 +337,11 @@ void RewriteSymExprVisitor::visit(SubscriptRangeExpression *expr) {
}

void RewriteSymExprVisitor::visit(MatchPathPatternExpression *expr) {
if (expr->inputProp() != nullptr) {
expr->inputProp()->accept(this);
if (expr->genList() != nullptr) {
expr->genList()->accept(this);
if (expr_) {
if (expr_->kind() != Expression::Kind::kInputProperty) {
hasWrongType_ = true;
return;
}
expr->setInputProp(static_cast<InputPropertyExpression *>(expr_));
expr->setGenList(expr_);
expr_ = nullptr;
}
}
}
Expand Down
70 changes: 70 additions & 0 deletions src/graph/visitor/ValidatePatternExpressionVisitor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) 2022 vesoft inc. All rights reserved.
//
// This source code is licensed under Apache 2.0 License.
#include "graph/visitor/ValidatePatternExpressionVisitor.h"

#include "ExprVisitorImpl.h"
#include "graph/context/ValidateContext.h"

namespace nebula {
namespace graph {

void ValidatePatternExpressionVisitor::visit(ListComprehensionExpression *expr) {
DCHECK(ok()) << expr->toString();
// Store current available variables in expression
localVariables_.push_front(expr->innerVar());
SCOPE_EXIT {
localVariables_.pop_front();
};
ExprVisitorImpl::visit(expr);
}

void ValidatePatternExpressionVisitor::visit(MatchPathPatternExpression *expr) {
DCHECK(ok()) << expr->toString();
// don't need to process sub-expression
const auto &matchPath = expr->matchPath();
std::vector<Expression *> nodeFilters;
auto *pathList = InputPropertyExpression::make(pool_, matchPath.toString());
auto listElementVar = vctx_->anonVarGen()->getVar();
for (std::size_t i = 0; i < matchPath.nodes().size(); ++i) {
const auto &node = matchPath.nodes()[i];
if (!node->alias().empty()) {
const auto find = std::find(localVariables_.begin(), localVariables_.end(), node->alias());
if (find != localVariables_.end()) {
// TODO we should check variable is Node type
// from local variable
node->setVariableDefinedSource(MatchNode::VariableDefinedSource::kExpression);
auto *listElement = VariableExpression::make(pool_, listElementVar);
// Note: this require build path by node pattern order
auto *listElementId = FunctionCallExpression::make(
pool_,
"_nodeid",
{listElement, ConstantExpression::make(pool_, static_cast<int64_t>(i))});
auto *nodeValue = VariableExpression::make(pool_, node->alias());
auto *nodeId = FunctionCallExpression::make(pool_, "id", {nodeValue});
auto *equal = RelationalExpression::makeEQ(pool_, listElementId, nodeId);
nodeFilters.emplace_back(equal);
}
}
}
if (!nodeFilters.empty()) {
auto genList = ListComprehensionExpression::make(
pool_, listElementVar, pathList, andAll(nodeFilters), nullptr);
expr->setGenList(genList);
}
}

Expression *ValidatePatternExpressionVisitor::andAll(const std::vector<Expression *> &exprs) const {
CHECK(!exprs.empty());
if (exprs.size() == 1) {
return exprs[0];
}
auto *expr = exprs[0];
for (std::size_t i = 1; i < exprs.size(); ++i) {
expr = LogicalExpression::makeAnd(pool_, expr, exprs[i]);
}
return expr;
}

} // namespace graph
} // namespace nebula
61 changes: 61 additions & 0 deletions src/graph/visitor/ValidatePatternExpressionVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) 2022 vesoft inc. All rights reserved.
//
// This source code is licensed under Apache 2.0 License.

#pragma once

#include "common/expression/Expression.h"
#include "graph/visitor/ExprVisitorImpl.h"

namespace nebula {
namespace graph {

class ValidateContext;

class ValidatePatternExpressionVisitor final : public ExprVisitorImpl {
public:
explicit ValidatePatternExpressionVisitor(ObjectPool *pool, ValidateContext *vctx)
: pool_(pool), vctx_(vctx) {}

bool ok() const override {
// TODO: delete this interface
return true;
}

private:
using ExprVisitorImpl::visit;
void visit(ConstantExpression *) override {}
void visit(LabelExpression *) override {}
void visit(UUIDExpression *) override {}
void visit(VariableExpression *) override {}
void visit(VersionedVariableExpression *) override {}
void visit(TagPropertyExpression *) override {}
void visit(LabelTagPropertyExpression *) override {}
void visit(EdgePropertyExpression *) override {}
void visit(InputPropertyExpression *) override {}
void visit(VariablePropertyExpression *) override {}
void visit(DestPropertyExpression *) override {}
void visit(SourcePropertyExpression *) override {}
void visit(EdgeSrcIdExpression *) override {}
void visit(EdgeTypeExpression *) override {}
void visit(EdgeRankExpression *) override {}
void visit(EdgeDstIdExpression *) override {}
void visit(VertexExpression *) override {}
void visit(EdgeExpression *) override {}
void visit(ColumnExpression *) override {}

void visit(ListComprehensionExpression *expr) override;
// match path pattern expression
void visit(MatchPathPatternExpression *expr) override;

Expression *andAll(const std::vector<Expression *> &exprs) const;

private:
ObjectPool *pool_{nullptr};
ValidateContext *vctx_{nullptr};

std::list<std::string> localVariables_; // local variable defined in List Comprehension
};

} // namespace graph
} // namespace nebula
Loading

0 comments on commit 120d852

Please sign in to comment.