Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add vertex filter for GetNeighbors #4671

Merged
merged 5 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/clients/storage/StorageClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ StorageRpcRespFuture<cpp2::GetNeighborsResponse> StorageClient::getNeighbors(
bool random,
const std::vector<cpp2::OrderBy>& orderBy,
int64_t limit,
const Expression* filter) {
const Expression* filter,
const Expression* tagFilter) {
auto cbStatus = getIdFromValue(param.space);
if (!cbStatus.ok()) {
return folly::makeFuture<StorageRpcResponse<cpp2::GetNeighborsResponse>>(
Expand Down Expand Up @@ -97,6 +98,9 @@ StorageRpcRespFuture<cpp2::GetNeighborsResponse> StorageClient::getNeighbors(
if (filter != nullptr) {
spec.filter_ref() = filter->encode();
}
if (tagFilter != nullptr) {
spec.tag_filter_ref() = tagFilter->encode();
}
req.traverse_spec_ref() = std::move(spec);
}

Expand Down
3 changes: 2 additions & 1 deletion src/clients/storage/StorageClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ class StorageClient
bool random = false,
const std::vector<cpp2::OrderBy>& orderBy = std::vector<cpp2::OrderBy>(),
int64_t limit = std::numeric_limits<int64_t>::max(),
const Expression* filter = nullptr);
const Expression* filter = nullptr,
const Expression* tagFilter = nullptr);

StorageRpcRespFuture<cpp2::GetDstBySrcResponse> getDstBySrc(
const CommonRequestParam& param,
Expand Down
1 change: 1 addition & 0 deletions src/graph/executor/algo/BatchShortestPath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ folly::Future<Status> BatchShortestPath::getNeighbors(size_t rowNum, size_t step
false,
{},
-1,
nullptr,
nullptr)
.via(qctx_->rctx()->runner())
.thenValue([this, rowNum, reverse, stepNum, getNbrTime](auto&& resp) {
Expand Down
1 change: 1 addition & 0 deletions src/graph/executor/algo/SingleShortestPath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ folly::Future<Status> SingleShortestPath::getNeighbors(size_t rowNum,
false,
{},
-1,
nullptr,
nullptr)
.via(qctx_->rctx()->runner())
.thenValue([this, rowNum, stepNum, getNbrTime, reverse](auto&& resp) {
Expand Down
3 changes: 2 additions & 1 deletion src/graph/executor/query/GetNeighborsExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ folly::Future<Status> GetNeighborsExecutor::execute() {
gn_->random(),
gn_->orderBy(),
gn_->limit(qec),
gn_->filter())
gn_->filter(),
nullptr)
.via(runner())
.ensure([this, getNbrTime]() {
SCOPED_TIMER(&execTime_);
Expand Down
3 changes: 2 additions & 1 deletion src/graph/executor/query/TraverseExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ folly::Future<Status> TraverseExecutor::getNeighbors() {
finalStep ? traverse_->random() : false,
finalStep ? traverse_->orderBy() : std::vector<storage::cpp2::OrderBy>(),
finalStep ? traverse_->limit(qctx()) : -1,
selectFilter())
selectFilter(),
nullptr)
.via(runner())
.thenValue([this, getNbrTime](StorageRpcResponse<GetNeighborsResponse>&& resp) mutable {
SCOPED_TIMER(&execTime_);
Expand Down
2 changes: 2 additions & 0 deletions src/interface/storage.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ struct TraverseSpec {
10: optional i64 limit,
// If provided, only the rows that satisfy the given expression will be returned
11: optional binary filter,
// Contains only the tag-related filter expression; tag_filter is a subset of filter
12: optional binary tag_filter,
}


Expand Down
1 change: 1 addition & 0 deletions src/storage/CommonUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ enum class ResultStatus {
NORMAL = 0,
ILLEGAL_DATA = -1,
FILTER_OUT = -2,
TAG_FILTER_OUT = -3,
};

struct PropContext;
Expand Down
23 changes: 19 additions & 4 deletions src/storage/exec/FilterNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ class FilterNode : public IterateNode<T> {
FilterNode(RuntimeContext* context,
IterateNode<T>* upstream,
StorageExpressionContext* expCtx = nullptr,
Expression* exp = nullptr)
: IterateNode<T>(upstream), context_(context), expCtx_(expCtx), filterExp_(exp) {
Expression* exp = nullptr,
Expression* tagFilterExp = nullptr)
: IterateNode<T>(upstream),
context_(context),
expCtx_(expCtx),
filterExp_(exp),
tagFilterExp_(tagFilterExp) {
IterateNode<T>::name_ = "FilterNode";
}

Expand All @@ -55,7 +60,9 @@ class FilterNode : public IterateNode<T> {
break;
}
if (this->valid() && !check()) {
context_->resultStat_ = ResultStatus::FILTER_OUT;
if (context_->resultStat_ != ResultStatus::TAG_FILTER_OUT) {
context_->resultStat_ = ResultStatus::FILTER_OUT;
}
this->next();
continue;
}
Expand Down Expand Up @@ -93,6 +100,13 @@ class FilterNode : public IterateNode<T> {
// return true when the value iter points to a value which can filter
bool checkTagAndEdge() {
expCtx_->reset(this->reader(), this->key().str());
if (tagFilterExp_ != nullptr) {
auto res = tagFilterExp_->eval(*expCtx_);
if (!res.isBool() || !res.getBool()) {
context_->resultStat_ = ResultStatus::TAG_FILTER_OUT;
return false;
}
}
// result is false when filter out
auto result = filterExp_->eval(*expCtx_);
// NULL is always false
Expand All @@ -103,7 +117,8 @@ class FilterNode : public IterateNode<T> {
private:
RuntimeContext* context_;
StorageExpressionContext* expCtx_;
Expression* filterExp_;
Expression* filterExp_{nullptr};
Expression* tagFilterExp_{nullptr};
FilterMode mode_{FilterMode::TAG_AND_EDGE};
int32_t callCheck{0};
};
Expand Down
6 changes: 6 additions & 0 deletions src/storage/exec/GetNeighborsNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class GetNeighborsNode : public QueryNode<VertexID> {
return nebula::cpp2::ErrorCode::E_INVALID_DATA;
}

if (context_->resultStat_ == ResultStatus::TAG_FILTER_OUT) {
pengweisong marked this conversation as resolved.
Show resolved Hide resolved
// if the filter condition of the tag is not satisfied
// do not return the data for this vertex and corresponding edge
return nebula::cpp2::ErrorCode::SUCCEEDED;
}

std::vector<Value> row;
// vertexId is the first column
if (context_->isIntId()) {
Expand Down
23 changes: 13 additions & 10 deletions src/storage/query/GetNeighborsProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,12 @@ StoragePlan<VertexID> GetNeighborsProcessor::buildPlan(RuntimeContext* context,
}

if (filter_) {
auto filter =
std::make_unique<FilterNode<VertexID>>(context, upstream, expCtx, filter_->clone());
auto filter = std::make_unique<FilterNode<VertexID>>(
context, upstream, expCtx, filter_->clone(), tagFilter_ ? tagFilter_->clone() : nullptr);
filter->addDependency(upstream);
upstream = filter.get();
if (edges.empty()) {
filter.get()->setFilterMode(FilterMode::TAG_ONLY);
filter->setFilterMode(FilterMode::TAG_ONLY);
}
plan.addNode(std::move(filter));
}
Expand Down Expand Up @@ -313,13 +313,16 @@ nebula::cpp2::ErrorCode GetNeighborsProcessor::checkAndBuildContexts(
if (code != nebula::cpp2::ErrorCode::SUCCEEDED) {
return code;
}
code = buildFilter(req, [](const cpp2::GetNeighborsRequest& r) -> const std::string* {
if (r.get_traverse_spec().filter_ref().has_value()) {
return r.get_traverse_spec().get_filter();
} else {
return nullptr;
}
});
code =
buildFilter(req, [](const cpp2::GetNeighborsRequest& r, bool onlyTag) -> const std::string* {
if (onlyTag) {
return r.get_traverse_spec().tag_filter_ref().has_value()
? r.get_traverse_spec().get_tag_filter()
: nullptr;
}
return r.get_traverse_spec().filter_ref().has_value() ? r.get_traverse_spec().get_filter()
: nullptr;
});
if (code != nebula::cpp2::ErrorCode::SUCCEEDED) {
return code;
}
Expand Down
3 changes: 2 additions & 1 deletion src/storage/query/GetPropProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ nebula::cpp2::ErrorCode GetPropProcessor::checkAndBuildContexts(const cpp2::GetP
return code;
}
}
code = buildFilter(req, [](const cpp2::GetPropRequest& r) -> const std::string* {
code = buildFilter(req, [](const cpp2::GetPropRequest& r, bool onlyTag) -> const std::string* {
UNUSED(onlyTag);
if (r.filter_ref().has_value()) {
return r.get_filter();
} else {
Expand Down
11 changes: 9 additions & 2 deletions src/storage/query/QueryBaseProcessor-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ nebula::cpp2::ErrorCode QueryBaseProcessor<REQ, RESP>::buildYields(const REQ& re

template <typename REQ, typename RESP>
nebula::cpp2::ErrorCode QueryBaseProcessor<REQ, RESP>::buildFilter(
const REQ& req, std::function<const std::string*(const REQ& req)>&& getFilter) {
const auto* filterStr = getFilter(req);
const REQ& req, std::function<const std::string*(const REQ& req, bool onlyTag)>&& getFilter) {
const auto* filterStr = getFilter(req, false);
if (filterStr == nullptr) {
return nebula::cpp2::ErrorCode::SUCCEEDED;
}
Expand All @@ -152,6 +152,13 @@ nebula::cpp2::ErrorCode QueryBaseProcessor<REQ, RESP>::buildFilter(
if (filter_ == nullptr) {
return nebula::cpp2::ErrorCode::E_INVALID_FILTER;
}
const auto* tagFilterStr = getFilter(req, true);
if (tagFilterStr != nullptr && !tagFilterStr->empty()) {
tagFilter_ = Expression::decode(pool, *tagFilterStr);
if (tagFilter_ == nullptr) {
return nebula::cpp2::ErrorCode::E_INVALID_FILTER;
}
}
return checkExp(filter_, false, true, false, true);
}
return nebula::cpp2::ErrorCode::SUCCEEDED;
Expand Down
3 changes: 2 additions & 1 deletion src/storage/query/QueryBaseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class QueryBaseProcessor : public BaseProcessor<RESP> {
nebula::cpp2::ErrorCode handleEdgeProps(std::vector<cpp2::EdgeProp>& edgeProps);

nebula::cpp2::ErrorCode buildFilter(
const REQ& req, std::function<const std::string*(const REQ& req)>&& getFilter);
const REQ& req, std::function<const std::string*(const REQ& req, bool onlyTag)>&& getFilter);
nebula::cpp2::ErrorCode buildYields(const REQ& req);

// build ttl info map
Expand Down Expand Up @@ -207,6 +207,7 @@ class QueryBaseProcessor : public BaseProcessor<RESP> {
TagContext tagContext_;
EdgeContext edgeContext_;
Expression* filter_{nullptr};
Expression* tagFilter_{nullptr};

// Collect prop in value expression in upsert set clause
std::unordered_set<std::string> valueProps_;
Expand Down
3 changes: 2 additions & 1 deletion src/storage/query/ScanEdgeProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ nebula::cpp2::ErrorCode ScanEdgeProcessor::checkAndBuildContexts(const cpp2::Sca
std::vector<cpp2::EdgeProp> returnProps = *req.return_columns_ref();
ret = handleEdgeProps(returnProps);
buildEdgeColName(returnProps);
ret = buildFilter(req, [](const cpp2::ScanEdgeRequest& r) -> const std::string* {
ret = buildFilter(req, [](const cpp2::ScanEdgeRequest& r, bool onlyTag) -> const std::string* {
UNUSED(onlyTag);
if (r.filter_ref().has_value()) {
return r.get_filter();
} else {
Expand Down
3 changes: 2 additions & 1 deletion src/storage/query/ScanVertexProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ nebula::cpp2::ErrorCode ScanVertexProcessor::checkAndBuildContexts(
std::vector<cpp2::VertexProp> returnProps = *req.return_columns_ref();
ret = handleVertexProps(returnProps);
buildTagColName(returnProps);
ret = buildFilter(req, [](const cpp2::ScanVertexRequest& r) -> const std::string* {
ret = buildFilter(req, [](const cpp2::ScanVertexRequest& r, bool onlyTag) -> const std::string* {
UNUSED(onlyTag);
if (r.filter_ref().has_value()) {
return r.get_filter();
} else {
Expand Down
68 changes: 68 additions & 0 deletions src/storage/test/GetNeighborsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,74 @@ TEST(GetNeighborsTest, FilterTest) {
expected.rows.emplace_back(std::move(row));
ASSERT_EQ(expected, *resp.vertices_ref());
}
{
  // Scenario: one source vertex, and the same tag-only predicate is sent both as the
  // general filter and as the tag filter. The response is expected to keep its column
  // layout but carry no rows.
  LOG(INFO) << "Filter apply to vertices";

  std::vector<VertexID> srcIds = {"Tim Duncan"};
  std::vector<EdgeType> edgeTypes = {serve};

  std::vector<std::pair<TagID, std::vector<std::string>>> tagProps;
  tagProps.emplace_back(player, std::vector<std::string>{"name", "age"});

  std::vector<std::pair<EdgeType, std::vector<std::string>>> edgeProps;
  edgeProps.emplace_back(serve, std::vector<std::string>{"teamName", "startYear", "endYear"});

  auto request = QueryTestUtils::buildRequest(totalParts, srcIds, edgeTypes, tagProps, edgeProps);

  // where $^.player.age > 50
  auto* ageProp = SourcePropertyExpression::make(pool, folly::to<std::string>(player), "age");
  auto* ageBound = ConstantExpression::make(pool, Value(50));
  const auto& predicate = *RelationalExpression::makeGT(pool, ageProp, ageBound);
  auto& traverseSpec = *request.traverse_spec_ref();
  traverseSpec.filter_ref() = Expression::encode(predicate);
  traverseSpec.tag_filter_ref() = Expression::encode(predicate);

  auto* proc = GetNeighborsProcessor::instance(env, nullptr, threadPool.get());
  auto future = proc->getFuture();
  proc->process(request);
  auto response = std::move(future).get();

  ASSERT_EQ(0, (*response.result_ref()).failed_parts.size());

  // vId, stat, player, serve, expr
  nebula::DataSet expected;
  expected.colNames = {
      kVid, "_stats", "_tag:1:name:age", "_edge:+101:teamName:startYear:endYear", "_expr"};
  const auto& returned = *response.vertices_ref();
  ASSERT_EQ(expected.colNames, returned.colNames);
  // No row is expected back for the only requested vertex.
  ASSERT_EQ(0, returned.rows.size());
}
{
  // Scenario: two source vertices, and the same tag-only predicate is sent both as the
  // general filter and as the tag filter. The row returned for "Tim Duncan" must match
  // the expected row exactly.
  // NOTE(review): presumably "Tony Parker" fails the predicate and is dropped from the
  // response -- confirm against the mock data; this block only checks "Tim Duncan".
  LOG(INFO) << "Filter apply to vertices2";
  std::vector<VertexID> vertices = {"Tim Duncan", "Tony Parker"};
  std::vector<EdgeType> over = {serve};
  std::vector<std::pair<TagID, std::vector<std::string>>> tags;
  std::vector<std::pair<EdgeType, std::vector<std::string>>> edges;
  tags.emplace_back(player, std::vector<std::string>{"name", "age"});
  edges.emplace_back(serve, std::vector<std::string>{"teamName", "startYear", "endYear"});
  auto req = QueryTestUtils::buildRequest(totalParts, vertices, over, tags, edges);
  // where $^.player.age > 40
  const auto& exp = *RelationalExpression::makeGT(
      pool,
      SourcePropertyExpression::make(pool, folly::to<std::string>(player), "age"),
      ConstantExpression::make(pool, Value(40)));
  (*req.traverse_spec_ref()).filter_ref() = Expression::encode(exp);
  (*req.traverse_spec_ref()).tag_filter_ref() = Expression::encode(exp);

  auto* processor = GetNeighborsProcessor::instance(env, nullptr, threadPool.get());
  auto fut = processor->getFuture();
  processor->process(req);
  auto resp = std::move(fut).get();

  ASSERT_EQ(0, (*resp.result_ref()).failed_parts.size());
  // vId, stat, player, serve, expr
  nebula::DataSet expected;
  expected.colNames = {
      kVid, "_stats", "_tag:1:name:age", "_edge:+101:teamName:startYear:endYear", "_expr"};
  ASSERT_EQ(expected.colNames, (*resp.vertices_ref()).colNames);
  auto serveEdges = nebula::List();
  serveEdges.values.emplace_back(nebula::List({"Spurs", 1997, 2016}));
  nebula::Row row({"Tim Duncan", Value(), nebula::List({"Tim Duncan", 44}), serveEdges, Value()});

  // Walk only the rows that were actually returned. A fixed upper bound would index
  // past the end of `rows` whenever fewer rows come back than the bound assumes.
  const auto& rows = (*resp.vertices_ref()).rows;
  bool foundTimDuncan = false;
  for (const auto& actual : rows) {
    if (actual.values[0].getStr() == "Tim Duncan") {
      ASSERT_EQ(row, actual);
      foundTimDuncan = true;
      break;
    }
  }
  // Without this check the block would pass vacuously when the row is missing.
  ASSERT_TRUE(foundTimDuncan);
}
{
LOG(INFO) << "Filter apply to multi vertices";
std::vector<VertexID> vertices = {
Expand Down