Skip to content

Commit

Permalink
make evaluation of inner var expr thread-safe (#4913)
Browse files Browse the repository at this point in the history
  • Loading branch information
jievince authored Nov 22, 2022
1 parent f2cee66 commit 77f13a1
Show file tree
Hide file tree
Showing 20 changed files with 133 additions and 42 deletions.
7 changes: 7 additions & 0 deletions src/common/context/ExpressionContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ class ExpressionContext {
// Get the latest version value for the given variable name, such as $a, $b
virtual const Value& getVar(const std::string& var) const = 0;

// Set the value of innerVar. The innerVar is a variable defined in an expression.
// e.g. ListComprehension
virtual void setInnerVar(const std::string& var, Value val) = 0;

// Get the value of innerVar.
virtual const Value& getInnerVar(const std::string& var) const = 0;

// Get the given version value for the given variable name, such as $a, $b
virtual const Value& getVersionedVar(const std::string& var, int64_t version) const = 0;

Expand Down
2 changes: 1 addition & 1 deletion src/common/expression/ListComprehensionExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const Value& ListComprehensionExpression::eval(ExpressionContext& ctx) {

for (size_t i = 0; i < list.size(); ++i) {
auto& v = list[i];
ctx.setVar(innerVar_, v);
ctx.setInnerVar(innerVar_, v);
if (filter_ != nullptr) {
auto& filterVal = filter_->eval(ctx);
if (!filterVal.empty() && !filterVal.isNull() && !filterVal.isImplicitBool()) {
Expand Down
8 changes: 4 additions & 4 deletions src/common/expression/PredicateExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ const Value& PredicateExpression::eval(ExpressionContext& ctx) {
result_ = true;
for (size_t i = 0; i < list.size(); ++i) {
auto& v = list[i];
ctx.setVar(innerVar_, v);
ctx.setInnerVar(innerVar_, v);
auto& filterVal = filter_->eval(ctx);
if (!filterVal.empty() && !filterVal.isNull() && !filterVal.isImplicitBool()) {
return Value::kNullBadType;
Expand All @@ -104,7 +104,7 @@ const Value& PredicateExpression::eval(ExpressionContext& ctx) {
result_ = false;
for (size_t i = 0; i < list.size(); ++i) {
auto& v = list[i];
ctx.setVar(innerVar_, v);
ctx.setInnerVar(innerVar_, v);
auto& filterVal = filter_->eval(ctx);
if (!filterVal.empty() && !filterVal.isNull() && !filterVal.isImplicitBool()) {
return Value::kNullBadType;
Expand All @@ -120,7 +120,7 @@ const Value& PredicateExpression::eval(ExpressionContext& ctx) {
result_ = false;
for (size_t i = 0; i < list.size(); ++i) {
auto& v = list[i];
ctx.setVar(innerVar_, v);
ctx.setInnerVar(innerVar_, v);
auto& filterVal = filter_->eval(ctx);
if (!filterVal.empty() && !filterVal.isNull() && !filterVal.isImplicitBool()) {
return Value::kNullBadType;
Expand All @@ -140,7 +140,7 @@ const Value& PredicateExpression::eval(ExpressionContext& ctx) {
result_ = true;
for (size_t i = 0; i < list.size(); ++i) {
auto& v = list[i];
ctx.setVar(innerVar_, v);
ctx.setInnerVar(innerVar_, v);
auto& filterVal = filter_->eval(ctx);
if (!filterVal.empty() && !filterVal.isNull() && !filterVal.isImplicitBool()) {
return Value::kNullBadType;
Expand Down
8 changes: 4 additions & 4 deletions src/common/expression/ReduceExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ const Value& ReduceExpression::eval(ExpressionContext& ctx) {
}
auto& list = listVal.getList();

ctx.setVar(accumulator_, initVal);
ctx.setInnerVar(accumulator_, initVal);
for (size_t i = 0; i < list.size(); ++i) {
auto& v = list[i];
ctx.setVar(innerVar_, v);
ctx.setInnerVar(innerVar_, v);
auto& mappingVal = mapping_->eval(ctx);
ctx.setVar(accumulator_, mappingVal);
ctx.setInnerVar(accumulator_, mappingVal);
}

result_ = ctx.getVar(accumulator_);
result_ = ctx.getInnerVar(accumulator_);
return result_;
}

Expand Down
3 changes: 3 additions & 0 deletions src/common/expression/VariableExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

namespace nebula {
const Value& VariableExpression::eval(ExpressionContext& ctx) {
if (isInner_) {
return ctx.getInnerVar(var_);
}
return ctx.getVar(var_);
}

Expand Down
15 changes: 10 additions & 5 deletions src/common/expression/VariableExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@
namespace nebula {
class VariableExpression final : public Expression {
public:
static VariableExpression* make(ObjectPool* pool,
const std::string& var = "",
bool isInner = false) {
return pool->makeAndAdd<VariableExpression>(pool, var, isInner);
// Make a non-inner variable expression
static VariableExpression* make(ObjectPool* pool, const std::string& var = "") {
return pool->makeAndAdd<VariableExpression>(pool, var, false);
}

// Make a inner variable expression. Inner variable is a variable defined in an expression.
// e.g. ListComprehensionExpression [i IN range(1, 10) | i+1]
static VariableExpression* makeInner(ObjectPool* pool, const std::string& var = "") {
return pool->makeAndAdd<VariableExpression>(pool, var, true);
}

const std::string& var() const {
Expand All @@ -39,7 +44,7 @@ class VariableExpression final : public Expression {
void accept(ExprVisitor* visitor) override;

Expression* clone() const override {
return VariableExpression::make(pool_, var(), isInner_);
return pool_->makeAndAdd<VariableExpression>(pool_, var(), isInner_);
}

private:
Expand Down
12 changes: 12 additions & 0 deletions src/common/expression/test/ExpressionContextMock.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ class ExpressionContextMock final : public ExpressionContext {
}
}

void setInnerVar(const std::string& var, Value val) override {
exprValueMap_[var] = std::move(val);
}

const Value& getInnerVar(const std::string& var) const override {
auto it = exprValueMap_.find(var);
DCHECK(it != exprValueMap_.end());
return it->second;
}

const Value& getVersionedVar(const std::string& var, int64_t version) const override {
auto found = indices_.find(var);
if (found == indices_.end()) {
Expand Down Expand Up @@ -143,5 +153,7 @@ class ExpressionContextMock final : public ExpressionContext {
static std::unordered_map<std::string, std::size_t> indices_;
static std::vector<Value> vals_;
std::unordered_map<std::string, std::regex> regex_;
// Expression value map that stores the value of innerVar
std::unordered_map<std::string, Value> exprValueMap_;
};
} // namespace nebula
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ TEST_F(ListComprehensionExpressionTest, ListComprehensionEvaluate) {
"n",
ListExpression::make(&pool, listItems),
RelationalExpression::makeGE(
&pool, VariableExpression::make(&pool, "n"), ConstantExpression::make(&pool, 2)),
&pool, VariableExpression::makeInner(&pool, "n"), ConstantExpression::make(&pool, 2)),
ArithmeticExpression::makeAdd(
&pool, VariableExpression::make(&pool, "n"), ConstantExpression::make(&pool, 10)));
&pool, VariableExpression::makeInner(&pool, "n"), ConstantExpression::make(&pool, 10)));

auto value = Expression::eval(expr, gExpCtxt);
List expected;
Expand Down Expand Up @@ -57,7 +57,7 @@ TEST_F(ListComprehensionExpressionTest, ListComprehensionEvaluate) {
ArithmeticExpression::makeAdd(
&pool,
AttributeExpression::make(&pool,
VariableExpression::make(&pool, "n"),
VariableExpression::makeInner(&pool, "n"),
ConstantExpression::make(&pool, "age")),
ConstantExpression::make(&pool, 5)));

Expand Down
10 changes: 5 additions & 5 deletions src/common/expression/test/PredicateExpressionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ TEST_F(PredicateExpressionTest, PredicateEvaluate) {
"n",
ListExpression::make(&pool, listItems),
RelationalExpression::makeGE(
&pool, VariableExpression::make(&pool, "n"), ConstantExpression::make(&pool, 2)));
&pool, VariableExpression::makeInner(&pool, "n"), ConstantExpression::make(&pool, 2)));

auto value = Expression::eval(expr, gExpCtxt);
ASSERT_TRUE(value.isBool());
Expand All @@ -51,7 +51,7 @@ TEST_F(PredicateExpressionTest, PredicateEvaluate) {
RelationalExpression::makeGE(
&pool,
AttributeExpression::make(&pool,
VariableExpression::make(&pool, "n"),
VariableExpression::makeInner(&pool, "n"),
ConstantExpression::make(&pool, "age")),
ConstantExpression::make(&pool, 19)));

Expand All @@ -74,7 +74,7 @@ TEST_F(PredicateExpressionTest, PredicateEvaluate) {
"n",
ListExpression::make(&pool, listItems),
RelationalExpression::makeEQ(
&pool, VariableExpression::make(&pool, "n"), ConstantExpression::make(&pool, 2)));
&pool, VariableExpression::makeInner(&pool, "n"), ConstantExpression::make(&pool, 2)));

auto value = Expression::eval(expr, gExpCtxt);
ASSERT_TRUE(value.isBool());
Expand All @@ -101,7 +101,7 @@ TEST_F(PredicateExpressionTest, PredicateEvaluate) {
RelationalExpression::makeGE(
&pool,
AttributeExpression::make(&pool,
VariableExpression::make(&pool, "n"),
VariableExpression::makeInner(&pool, "n"),
ConstantExpression::make(&pool, "age")),
ConstantExpression::make(&pool, 19)));

Expand All @@ -117,7 +117,7 @@ TEST_F(PredicateExpressionTest, PredicateEvaluate) {
"n",
ConstantExpression::make(&pool, Value(NullType::__NULL__)),
RelationalExpression::makeEQ(
&pool, VariableExpression::make(&pool, "n"), ConstantExpression::make(&pool, 1)));
&pool, VariableExpression::makeInner(&pool, "n"), ConstantExpression::make(&pool, 1)));

auto value = Expression::eval(expr, gExpCtxt);
ASSERT_EQ(Value::kNullValue, value.getNull());
Expand Down
7 changes: 4 additions & 3 deletions src/common/expression/test/ReduceExpressionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ TEST_F(ReduceExpressionTest, ReduceEvaluate) {
FunctionCallExpression::make(&pool, "range", argList),
ArithmeticExpression::makeAdd(
&pool,
VariableExpression::make(&pool, "totalNum"),
ArithmeticExpression::makeMultiply(
&pool, VariableExpression::make(&pool, "n"), ConstantExpression::make(&pool, 2))));
VariableExpression::makeInner(&pool, "totalNum"),
ArithmeticExpression::makeMultiply(&pool,
VariableExpression::makeInner(&pool, "n"),
ConstantExpression::make(&pool, 2))));

auto value = Expression::eval(expr, gExpCtxt);
ASSERT_EQ(Value::Type::INT, value.type());
Expand Down
9 changes: 9 additions & 0 deletions src/common/utils/DefaultValueContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ class DefaultValueContext final : public ExpressionContext {
return Value::kEmpty;
}

void setInnerVar(const std::string&, Value) override {
LOG(FATAL) << "Not allowed to call";
}

const Value& getInnerVar(const std::string&) const override {
LOG(FATAL) << "Not allowed to call";
return Value::kEmpty;
}

const Value& getVersionedVar(const std::string&, int64_t) const override {
LOG(FATAL) << "Not allowed to call";
return Value::kEmpty;
Expand Down
10 changes: 10 additions & 0 deletions src/graph/context/QueryExpressionContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ const Value& QueryExpressionContext::getVar(const std::string& var) const {
return ectx_->getValue(var);
}

void QueryExpressionContext::setInnerVar(const std::string& var, Value val) {
exprValueMap_[var] = std::move(val);
}

const Value& QueryExpressionContext::getInnerVar(const std::string& var) const {
auto it = exprValueMap_.find(var);
if (it == exprValueMap_.end()) return Value::kEmpty;
return it->second;
}

const Value& QueryExpressionContext::getVersionedVar(const std::string& var,
int64_t version) const {
if (ectx_ == nullptr) {
Expand Down
10 changes: 10 additions & 0 deletions src/graph/context/QueryExpressionContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ class QueryExpressionContext final : public ExpressionContext {
// Get the latest version value for the given variable name, such as $a, $b
const Value& getVar(const std::string& var) const override;

// Set the value of innerVar. The innerVar is a variable defined in an expression.
// e.g. ListComprehension
void setInnerVar(const std::string& var, Value val) override;

// Get the value of innerVar
const Value& getInnerVar(const std::string& var) const override;

// Get the given version value for the given variable name, such as $a, $b
const Value& getVersionedVar(const std::string& var, int64_t version) const override;

Expand Down Expand Up @@ -75,6 +82,9 @@ class QueryExpressionContext final : public ExpressionContext {
// could be evaluated as constant value.
ExecutionContext* ectx_{nullptr};
Iterator* iter_{nullptr};

// Expression value map that stores the value of innerVar
std::unordered_map<std::string, Value> exprValueMap_;
};

} // namespace graph
Expand Down
2 changes: 1 addition & 1 deletion src/graph/planner/match/MatchSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ static YieldColumn* buildEdgeColumn(QueryContext* qctx, const EdgeInfo& edge) {
} else {
auto innerVar = qctx->vctx()->anonVarGen()->getVar();
auto* args = ArgumentList::make(pool);
args->addArgument(VariableExpression::make(pool, innerVar));
args->addArgument(VariableExpression::makeInner(pool, innerVar));
auto* filter = FunctionCallExpression::make(pool, "is_edge", args);
expr = ListComprehensionExpression::make(
pool, innerVar, InputPropertyExpression::make(pool, edge.alias), filter);
Expand Down
20 changes: 8 additions & 12 deletions src/graph/util/ParserUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ void ParserUtil::rewriteLC(QueryContext *qctx,
// So to avoid conflict, we create a global unique anonymous variable name for it
// TODO store inner variable in inner
const auto &newVarName = qctx->vctx()->anonVarGen()->getVar();
qctx->ectx()->setValue(newVarName, Value());
auto *pool = qctx->objPool();

auto matcher = [](const Expression *expr) -> bool {
Expand All @@ -35,7 +34,7 @@ void ParserUtil::rewriteLC(QueryContext *qctx,
case Expression::Kind::kLabel: {
auto *label = static_cast<const LabelExpression *>(expr);
if (label->name() == oldVarName) {
ret = VariableExpression::make(pool, newVarName, true);
ret = VariableExpression::makeInner(pool, newVarName);
} else {
ret = label->clone();
}
Expand All @@ -46,7 +45,7 @@ void ParserUtil::rewriteLC(QueryContext *qctx,
if (la->left()->name() == oldVarName) {
const auto &value = la->right()->value();
ret = AttributeExpression::make(pool,
VariableExpression::make(pool, newVarName, true),
VariableExpression::makeInner(pool, newVarName),
ConstantExpression::make(pool, value));
} else {
ret = la->clone();
Expand Down Expand Up @@ -87,7 +86,6 @@ void ParserUtil::rewritePred(QueryContext *qctx,
PredicateExpression *pred,
const std::string &oldVarName) {
const auto &newVarName = qctx->vctx()->anonVarGen()->getVar();
qctx->ectx()->setValue(newVarName, Value());
auto *pool = qctx->objPool();

auto matcher = [](const Expression *expr) -> bool {
Expand All @@ -100,7 +98,7 @@ void ParserUtil::rewritePred(QueryContext *qctx,
if (expr->kind() == Expression::Kind::kLabel) {
auto *label = static_cast<const LabelExpression *>(expr);
if (label->name() == oldVarName) {
ret = VariableExpression::make(pool, newVarName, true);
ret = VariableExpression::makeInner(pool, newVarName);
} else {
ret = label->clone();
}
Expand All @@ -110,7 +108,7 @@ void ParserUtil::rewritePred(QueryContext *qctx,
if (la->left()->name() == oldVarName) {
const auto &value = la->right()->value();
ret = AttributeExpression::make(pool,
VariableExpression::make(pool, newVarName, true),
VariableExpression::makeInner(pool, newVarName),
ConstantExpression::make(pool, value));
} else {
ret = la->clone();
Expand All @@ -133,9 +131,7 @@ void ParserUtil::rewriteReduce(QueryContext *qctx,
const std::string &oldAccName,
const std::string &oldVarName) {
const auto &newAccName = qctx->vctx()->anonVarGen()->getVar();
qctx->ectx()->setValue(newAccName, Value());
const auto &newVarName = qctx->vctx()->anonVarGen()->getVar();
qctx->ectx()->setValue(newVarName, Value());
auto *pool = qctx->objPool();

auto matcher = [](const Expression *expr) -> bool {
Expand All @@ -147,9 +143,9 @@ void ParserUtil::rewriteReduce(QueryContext *qctx,
if (expr->kind() == Expression::Kind::kLabel) {
auto *label = static_cast<const LabelExpression *>(expr);
if (label->name() == oldAccName) {
ret = VariableExpression::make(pool, newAccName, true);
ret = VariableExpression::makeInner(pool, newAccName);
} else if (label->name() == oldVarName) {
ret = VariableExpression::make(pool, newVarName, true);
ret = VariableExpression::makeInner(pool, newVarName);
} else {
ret = label->clone();
}
Expand All @@ -159,12 +155,12 @@ void ParserUtil::rewriteReduce(QueryContext *qctx,
if (la->left()->name() == oldAccName) {
const auto &value = la->right()->value();
ret = AttributeExpression::make(pool,
VariableExpression::make(pool, newAccName, true),
VariableExpression::makeInner(pool, newAccName),
ConstantExpression::make(pool, value));
} else if (la->left()->name() == oldVarName) {
const auto &value = la->right()->value();
ret = AttributeExpression::make(pool,
VariableExpression::make(pool, newVarName, true),
VariableExpression::makeInner(pool, newVarName),
ConstantExpression::make(pool, value));
} else {
ret = la->clone();
Expand Down
7 changes: 5 additions & 2 deletions src/graph/visitor/ValidatePatternExpressionVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@ void ValidatePatternExpressionVisitor::visit(MatchPathPatternExpression *expr) {
// TODO we should check variable is Node type
// from local variable
node->setVariableDefinedSource(MatchNode::VariableDefinedSource::kExpression);
auto *listElement = VariableExpression::make(pool_, listElementVar);
auto *listElement = VariableExpression::makeInner(pool_, listElementVar);
// Note: this require build path by node pattern order
auto *listElementId = FunctionCallExpression::make(
pool_,
"_nodeid",
{listElement, ConstantExpression::make(pool_, static_cast<int64_t>(i))});
auto *nodeValue = VariableExpression::make(pool_, node->alias());
// The alias of node is converted to a inner variable.
// e.g. MATCH (v:player) WHERE [t in [v] | (v)-[:like]->(t)] RETURN v
// More cases could be found in PathExprRefLocalVariable.feature
auto *nodeValue = VariableExpression::makeInner(pool_, node->alias());
auto *nodeId = FunctionCallExpression::make(pool_, "id", {nodeValue});
auto *equal = RelationalExpression::makeEQ(pool_, listElementId, nodeId);
nodeFilters.emplace_back(equal);
Expand Down
Loading

0 comments on commit 77f13a1

Please sign in to comment.