Skip to content

Commit

Permalink
fix compile
Browse files Browse the repository at this point in the history
  • Loading branch information
czpmango committed Apr 3, 2023
1 parent 4ecf0de commit 036df0e
Showing 1 changed file with 48 additions and 15 deletions.
63 changes: 48 additions & 15 deletions src/graph/optimizer/rule/PushFilterDownTraverseRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ bool isEdgeAllPredicate(const Expression* e,
if (static_cast<const PropertyExpression*>(pe->collection())->prop() != edgeAlias) {
return false;
}
auto ves = graph::ExpressionUtils::collectAll(pe->filter(), {Expression::Kind::kAttribute});
auto ves =
graph::ExpressionUtils::collectAll(pe->oldFilterNode(), {Expression::Kind::kAttribute});
for (const auto& ve : ves) {
auto iv = static_cast<const AttributeExpression*>(ve)->left();
if (iv->kind() != Expression::Kind::kVar) {
Expand Down Expand Up @@ -105,7 +106,7 @@ Expression* rewriteEdgeAllPredicate(const Expression* expr, const std::string& e
};
auto rewriter = [innerEdgeVar](const Expression* e) -> Expression* {
DCHECK_EQ(e->kind(), Expression::Kind::kPredicate);
auto fe = static_cast<const PredicateExpression*>(e)->filter();
auto fe = static_cast<const PredicateExpression*>(e)->oldFilterNode();

auto innerMatcher = [innerEdgeVar](const Expression* ae) {
if (ae->kind() != Expression::Kind::kAttribute) {
Expand All @@ -127,26 +128,23 @@ Expression* rewriteEdgeAllPredicate(const Expression* expr, const std::string& e
auto& prop = static_cast<const ConstantExpression*>(right)->value().getStr();
return EdgePropertyExpression::make(ae->getObjPool(), "*", prop);
};
// Rewrite all the inner var edge attribute expressions of `all` predicate's filter to
// Rewrite all the inner var edge attribute expressions of `all` predicate's oldFilterNode to
// EdgePropertyExpression
return graph::RewriteVisitor::transform(fe, std::move(innerMatcher), std::move(innerRewriter));
};
return graph::RewriteVisitor::transform(expr, std::move(matcher), std::move(rewriter));
}

StatusOr<OptRule::TransformResult> PushFilterDownTraverseRule::transform(
OptContext* ctx, const MatchedResult& matched) const {
auto* filterGroupNode = matched.node;
auto* filterGroup = filterGroupNode->group();
auto* filter = static_cast<graph::Filter*>(filterGroupNode->node());
auto* condition = filter->condition();

auto* tvGroupNode = matched.dependencies[0].node;
auto* tv = static_cast<graph::Traverse*>(tvGroupNode->node());
auto& edgeAlias = tv->edgeAlias();
auto srcNodeAlias = tv->nodeAlias();

auto qctx = ctx->qctx();
OptContext* octx, const MatchedResult& matched) const {
auto* oldFilterGroupNode = matched.node;
auto* filterGroup = oldFilterGroupNode->group();
auto* oldFilterNode = static_cast<graph::Filter*>(oldFilterGroupNode->node());
auto* condition = oldFilterNode->condition();
auto* oldTvGroupNode = matched.dependencies[0].node;
auto* oldTvNode = static_cast<graph::Traverse*>(oldTvGroupNode->node());
auto& edgeAlias = oldTvNode->edgeAlias();
auto qctx = octx->qctx();

auto picker = [&edgeAlias](const Expression* expr) -> bool {
bool neverPicked = false;
Expand Down Expand Up @@ -178,6 +176,41 @@ StatusOr<OptRule::TransformResult> PushFilterDownTraverseRule::transform(
}

auto* edgeFilter = rewriteEdgeAllPredicate(filterPicked, edgeAlias);
auto* oldEdgeFilter = oldTvNode->eFilter();
Expression* newEdgeFilter =
oldEdgeFilter ? LogicalExpression::makeAnd(
oldEdgeFilter->getObjPool(), edgeFilter, oldEdgeFilter->clone())
: edgeFilter;

// produce new Traverse node
auto* newTvNode = static_cast<graph::Traverse*>(oldTvNode->clone());
newTvNode->setEdgeFilter(newEdgeFilter);
newTvNode->setInputVar(oldTvNode->inputVar());
newTvNode->setColNames(oldTvNode->outputVarPtr()->colNames);

TransformResult result;
result.eraseAll = true;
// czp
if (filterUnpicked) {
// produce new Filter node above
auto* newAboveFilterNode = graph::Filter::make(qctx, newTvNode, filterUnpicked);
newAboveFilterNode->setOutputVar(oldFilterNode->outputVar());
auto newAboveFilterGroupNode =
OptGroupNode::create(octx, newAboveFilterNode, oldFilterGroupNode->group());

auto newTvGroup = OptGroup::create(octx);
auto newTvGroupNode = newTvGroup->makeGroupNode(newTvNode);
newTvGroupNode->setDeps(oldTvGroupNode->dependencies());
newTvNode->setInputVar(oldTvNode->outputVar());
newAboveFilterGroupNode->setDeps({newTvGroup});
newAboveFilterNode->setInputVar(newTvNode->outputVar());
result.newGroupNodes.emplace_back(newAboveFilterGroupNode);
} else {
newTvNode->setOutputVar(oldFilterNode->outputVar());
auto newTvGroupNode = OptGroupNode::create(octx, newTvNode, oldFilterGroupNode->group());
newTvGroupNode->setDeps(oldTvGroupNode->dependencies());
result.newGroupNodes.emplace_back(newTvGroupNode);
}

return result;
}
Expand Down

0 comments on commit 036df0e

Please sign in to comment.