diff --git a/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.cpp b/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.cpp index 2b61a62867c..68750ef37f2 100644 --- a/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.cpp +++ b/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.cpp @@ -26,7 +26,7 @@ OptimizeLeftJoinPredicateRule::OptimizeLeftJoinPredicateRule() { const Pattern& OptimizeLeftJoinPredicateRule::pattern() const { static Pattern pattern = Pattern::create( - PlanNode::Kind::kBiLeftJoin, + PlanNode::Kind::kHashLeftJoin, {Pattern::create(PlanNode::Kind::kUnknown), Pattern::create(PlanNode::Kind::kProject, {Pattern::create(PlanNode::Kind::kAppendVertices, @@ -38,7 +38,7 @@ StatusOr OptimizeLeftJoinPredicateRule::transform( OptContext* octx, const MatchedResult& matched) const { auto* leftJoinGroupNode = matched.node; auto* leftJoinGroup = leftJoinGroupNode->group(); - auto* leftJoin = static_cast(leftJoinGroupNode->node()); + auto* leftJoin = static_cast(leftJoinGroupNode->node()); auto* projectGroupNode = matched.dependencies[1].node; auto* projectGroup = projectGroupNode->group(); @@ -68,14 +68,14 @@ StatusOr OptimizeLeftJoinPredicateRule::transform( auto& probeKeys = leftJoin->probeKeys(); // Use visitor to collect all function `id` in the hashKeys - - std::vector hashKeyIdx; + bool found = false; + size_t hashKeyIdx; for (size_t i = 0; i < hashKeys.size(); ++i) { - auto* key = hashKeys[i]; - if (key->kind() != Expression::Kind::kFunctionCall) { + auto* hashKey = hashKeys[i]; + if (hashKey->kind() != Expression::Kind::kFunctionCall) { continue; } - auto* func = static_cast(key); + auto* func = static_cast(hashKey); if (func->name() != "id" || func->name() != "_joinkey") { continue; } @@ -87,14 +87,22 @@ StatusOr OptimizeLeftJoinPredicateRule::transform( } auto& alias = static_cast(arg)->prop(); if (alias != avNodeAlias) continue; - // FIXME(jie): Must check if probe keys contain the same key - hashKeyIdx.emplace_back(i); + // Must check if probe keys contain the same key + if (probeKeys[i] != hashKey) { + return TransformResult::noTransform(); + } + if (found) { + return TransformResult::noTransform(); + } + hashKeyIdx = i; + found = true; } - if (hashKeyIdx.size() != 1) { + if (!found) { return TransformResult::noTransform(); } - std::vector prjIdx; + found = false; + size_t prjIdx; for (size_t i = 0; i < project->columns()->size(); ++i) { const auto* col = project->columns()->columns()[i]; if (col->expr()->kind() != Expression::Kind::kInputProperty) { @@ -102,22 +110,26 @@ StatusOr OptimizeLeftJoinPredicateRule::transform( } auto* inputProp = static_cast(col->expr()); if (inputProp->prop() != avNodeAlias) continue; - prjIdx.push_back(i); + if (found) { + return TransformResult::noTransform(); + } + prjIdx = i; + found = true; } - if (prjIdx.size() != 1) { + if (!found) { return TransformResult::noTransform(); } auto* pool = octx->qctx()->objPool(); - // Let the new project generate expr `none_direct_dst($-.tvEdgeAlias)`, and let the new left join - // use it as hash key + // Let the new project generate expr `none_direct_dst($-.tvEdgeAlias)`, + // and let the new left join use it as hash key auto* args = ArgumentList::make(pool); args->addArgument(InputPropertyExpression::make(pool, tvEdgeAlias)); auto* newPrjExpr = FunctionCallExpression::make(pool, "none_direct_dst", args); auto* newYieldColumns = pool->makeAndAdd(); for (size_t i = 0; i < project->columns()->size(); ++i) { - if (i == prjIdx[0]) { + if (i == prjIdx) { newYieldColumns->addColumn(pool->makeAndAdd(newPrjExpr, newPrjExpr->toString())); } else { newYieldColumns->addColumn(project->columns()->columns()[i]); @@ -125,28 +137,30 @@ StatusOr OptimizeLeftJoinPredicateRule::transform( } auto* newProject = graph::Project::make(octx->qctx(), nullptr, newYieldColumns); + // $-.`none_direct_dst(tvEdgeAlias)` auto* newHashExpr = InputPropertyExpression::make(pool, newPrjExpr->toString()); std::vector newHashKeys; for (size_t i = 0; i < hashKeys.size(); ++i) { - if (i == hashKeyIdx[0]) { + if (i == hashKeyIdx) { newHashKeys.emplace_back(newHashExpr); } else { newHashKeys.emplace_back(hashKeys[i]); } } auto* newLeftJoin = - graph::BiLeftJoin::make(octx->qctx(), nullptr, nullptr, newHashKeys, probeKeys); + graph::HashLeftJoin::make(octx->qctx(), nullptr, nullptr, newHashKeys, probeKeys); TransformResult result; result.eraseAll = true; newProject->setInputVar(appendVertices->inputVar()); - newProject->setOutputVar(project->outputVar()); auto newProjectGroup = OptGroup::create(octx); auto* newProjectGroupNode = newProjectGroup->makeGroupNode(newProject); newProjectGroupNode->setDeps(projectGroupNode->dependencies()); - newLeftJoin->setDep(1, newProject); + newLeftJoin->setLeftVar(leftJoin->leftInputVar()); + newLeftJoin->setRightVar(newProject->outputVar()); + newLeftJoin->setOutputVar(leftJoin->outputVar()); auto* newLeftJoinGroupNode = OptGroupNode::create(octx, newLeftJoin, leftJoinGroup); newLeftJoinGroupNode->dependsOn(leftJoinGroupNode->dependencies()[0]); newLeftJoinGroupNode->dependsOn(newProjectGroup); diff --git a/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.h b/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.h index d075aefa7c1..21ff0c35e5c 100644 --- a/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.h +++ b/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.h @@ -11,22 +11,24 @@ namespace nebula { namespace opt { /* -Before: - BiLeftJoin({id(v)}, id(v)) - / \ - ... Project - \ - AppendVertices(v) - \ - Traverse(e) - -After: - BiLeftJoin({id(v)}, none_direct_dst(e)) - / \ - ... Project - \ - Traverse(e) -*/ + * Before: + * HashLeftJoin({id(v)}, {id(v)}) + * / \ + * ... Project + * / \ + * AppendVertices(v) AppendVertices(v) + * / \ + * ... Traverse(e) + * + * After: + * HashLeftJoin({id(v)}, {$-.`none_direct_dst(e)`}) + * / \ + * ... Project(none_direct_dst(e)) + * / \ + * AppendVertices(v) Traverse(e) + * / + * ... + */ class OptimizeLeftJoinPredicateRule final : public OptRule { public: const Pattern &pattern() const override; diff --git a/tests/tck/features/optimizer/OptimizeLeftJoinPredicateRule.feature b/tests/tck/features/optimizer/OptimizeLeftJoinPredicateRule.feature index 3386e1313d8..4b8c0df23fe 100644 --- a/tests/tck/features/optimizer/OptimizeLeftJoinPredicateRule.feature +++ b/tests/tck/features/optimizer/OptimizeLeftJoinPredicateRule.feature @@ -1,6 +1,7 @@ # Copyright (c) 2021 vesoft inc. All rights reserved. # # This source code is licensed under Apache 2.0 License. +@jie Feature: Optimize left join predicate Background: @@ -44,21 +45,21 @@ Feature: Optimize left join predicate | "Heat" | "Heat" | 0 | | "Jazz" | "Jazz" | 0 | And the execution plan should be: - | id | name | dependencies | profiling data | operator info | - | 21 | TopN | 18 | | | - | 18 | Project | 17 | | | - | 17 | Aggregate | 16 | | | - | 16 | BiLeftJoin | 10,15 | | | - | 10 | Dedup | 28 | | | - | 28 | Project | 22 | | | - | 22 | Filter | 26 | | | - | 26 | AppendVertices | 25 | | | - | 25 | Traverse | 24 | | | - | 24 | Traverse | 2 | | | - | 2 | Dedup | 1 | | | - | 1 | PassThrough | 3 | | | - | 3 | Start | | | | - | 15 | Project | 14 | | | - | 14 | Traverse | 12 | | | - | 12 | Traverse | 11 | | | - | 11 | Argument | | | | + | id | name | dependencies | operator info | + | 21 | TopN | 18 | | + | 18 | Project | 17 | | + | 17 | Aggregate | 16 | | + | 16 | HashLeftJoin | 10,15 | {"probeKeys": ["_joinkey($-.friendTeam)", "_joinkey($-.friendTeam)"], "hashKeys": ["$-.none_direct_dst(__VAR_3)", "_joinkey($-.friendTeam)"]} | + | 10 | Dedup | 28 | | + | 28 | Project | 22 | | + | 22 | Filter | 26 | | + | 26 | AppendVertices | 25 | | + | 25 | Traverse | 24 | | + | 24 | Traverse | 2 | | + | 2 | Dedup | 1 | | + | 1 | PassThrough | 3 | | + | 3 | Start | | | + | 15 | Project | 14 | {"columns": ["$-.friend AS friend, $-.friend2 AS friend2, none_direct_dst(__VAR_3) AS none_direct_dst(__VAR_3)"]} | + | 14 | Traverse | 12 | | + | 12 | Traverse | 11 | | + | 11 | Argument | | |