diff --git a/src/graph/optimizer/rule/PushFilterDownTraverseRule.cpp b/src/graph/optimizer/rule/PushFilterDownTraverseRule.cpp index 23f04b9a21f..5ad6e287ec2 100644 --- a/src/graph/optimizer/rule/PushFilterDownTraverseRule.cpp +++ b/src/graph/optimizer/rule/PushFilterDownTraverseRule.cpp @@ -65,7 +65,8 @@ bool isEdgeAllPredicate(const Expression* e, if (static_cast(pe->collection())->prop() != edgeAlias) { return false; } - auto ves = graph::ExpressionUtils::collectAll(pe->filter(), {Expression::Kind::kAttribute}); + auto ves = + graph::ExpressionUtils::collectAll(pe->oldFilterNode(), {Expression::Kind::kAttribute}); for (const auto& ve : ves) { auto iv = static_cast(ve)->left(); if (iv->kind() != Expression::Kind::kVar) { @@ -105,7 +106,7 @@ Expression* rewriteEdgeAllPredicate(const Expression* expr, const std::string& e }; auto rewriter = [innerEdgeVar](const Expression* e) -> Expression* { DCHECK_EQ(e->kind(), Expression::Kind::kPredicate); - auto fe = static_cast(e)->filter(); + auto fe = static_cast(e)->oldFilterNode(); auto innerMatcher = [innerEdgeVar](const Expression* ae) { if (ae->kind() != Expression::Kind::kAttribute) { @@ -127,7 +128,7 @@ Expression* rewriteEdgeAllPredicate(const Expression* expr, const std::string& e auto& prop = static_cast(right)->value().getStr(); return EdgePropertyExpression::make(ae->getObjPool(), "*", prop); }; - // Rewrite all the inner var edge attribute expressions of `all` predicate's filter to + // Rewrite all the inner var edge attribute expressions of `all` predicate's oldFilterNode to // EdgePropertyExpression return graph::RewriteVisitor::transform(fe, std::move(innerMatcher), std::move(innerRewriter)); }; @@ -135,18 +136,15 @@ Expression* rewriteEdgeAllPredicate(const Expression* expr, const std::string& e } StatusOr PushFilterDownTraverseRule::transform( - OptContext* ctx, const MatchedResult& matched) const { - auto* filterGroupNode = matched.node; - auto* filterGroup = filterGroupNode->group(); - auto* filter = static_cast(filterGroupNode->node()); - auto* condition = filter->condition(); - - auto* tvGroupNode = matched.dependencies[0].node; - auto* tv = static_cast(tvGroupNode->node()); - auto& edgeAlias = tv->edgeAlias(); - auto srcNodeAlias = tv->nodeAlias(); - - auto qctx = ctx->qctx(); + OptContext* octx, const MatchedResult& matched) const { + auto* oldFilterGroupNode = matched.node; + auto* filterGroup = oldFilterGroupNode->group(); + auto* oldFilterNode = static_cast(oldFilterGroupNode->node()); + auto* condition = oldFilterNode->condition(); + auto* oldTvGroupNode = matched.dependencies[0].node; + auto* oldTvNode = static_cast(oldTvGroupNode->node()); + auto& edgeAlias = oldTvNode->edgeAlias(); + auto qctx = octx->qctx(); auto picker = [&edgeAlias](const Expression* expr) -> bool { bool neverPicked = false; @@ -178,6 +176,41 @@ StatusOr PushFilterDownTraverseRule::transform( } auto* edgeFilter = rewriteEdgeAllPredicate(filterPicked, edgeAlias); + auto* oldEdgeFilter = oldTvNode->eFilter(); + Expression* newEdgeFilter = + oldEdgeFilter ? LogicalExpression::makeAnd( + oldEdgeFilter->getObjPool(), edgeFilter, oldEdgeFilter->clone()) + : edgeFilter; + + // produce new Traverse node + auto* newTvNode = static_cast(oldTvNode->clone()); + newTvNode->setEdgeFilter(newEdgeFilter); + newTvNode->setInputVar(oldTvNode->inputVar()); + newTvNode->setColNames(oldTvNode->outputVarPtr()->colNames); + + TransformResult result; + result.eraseAll = true; + // czp + if (filterUnpicked) { + // produce new Filter node above + auto* newAboveFilterNode = graph::Filter::make(qctx, newTvNode, filterUnpicked); + newAboveFilterNode->setOutputVar(oldFilterNode->outputVar()); + auto newAboveFilterGroupNode = + OptGroupNode::create(octx, newAboveFilterNode, oldFilterGroupNode->group()); + + auto newTvGroup = OptGroup::create(octx); + auto newTvGroupNode = newTvGroup->makeGroupNode(newTvNode); + newTvGroupNode->setDeps(oldTvGroupNode->dependencies()); + newTvNode->setInputVar(oldTvNode->outputVar()); + newAboveFilterGroupNode->setDeps({newTvGroup}); + newAboveFilterNode->setInputVar(newTvNode->outputVar()); + result.newGroupNodes.emplace_back(newAboveFilterGroupNode); + } else { + newTvNode->setOutputVar(oldFilterNode->outputVar()); + auto newTvGroupNode = OptGroupNode::create(octx, newTvNode, oldFilterGroupNode->group()); + newTvGroupNode->setDeps(oldTvGroupNode->dependencies()); + result.newGroupNodes.emplace_back(newTvGroupNode); + } return result; }