Skip to content

Commit

Permalink
[fix](nereids) add a project node above sort node to eliminate unused…
Browse files Browse the repository at this point in the history
… order by keys (apache#17913)

if the order by keys are not simple slot in sort node, the order by exprs have to been added to sort node's output tuple. In that case, we need add a project node above sort node to eliminate the unused order by exprs. for example:

```sql
WITH t0 AS 
    (SELECT DATE_FORMAT(date,
         '%Y%m%d') AS date
    FROM cir_1756_t1 ), t3 AS 
    (SELECT date_format(date,
         '%Y%m%d') AS `date`
    FROM `cir_1756_t2`
    GROUP BY  date_format(date, '%Y%m%d')
    **ORDER BY  date_format(date, '%Y%m%d')** )
SELECT t0.date
FROM t0
LEFT JOIN t3
    ON t0.date = t3.date;
```

before:
```
+--------------------------------------------------------------------------------------------------------------------------------------------------+
| Explain String                                                                                                                                   |
+--------------------------------------------------------------------------------------------------------------------------------------------------+
| LogicalProject[159] ( distinct=false, projects=[date#1], excepts=[], canEliminate=true )                                                         |
| +--LogicalJoin[158] ( type=LEFT_OUTER_JOIN, markJoinSlotReference=Optional.empty, hashJoinConjuncts=[(date#1 = date#3)], otherJoinConjuncts=[] ) |
|    |--LogicalProject[151] ( distinct=false, projects=[date_format(date#0, '%Y%m%d') AS `date`#1], excepts=[], canEliminate=true )                |
|    |  +--LogicalOlapScan ( qualified=default_cluster:bugfix.cir_1756_t1, indexName=cir_1756_t1, selectedIndexId=412339, preAgg=ON )              |
|    +--LogicalSort[157] ( orderKeys=[date_format(cast(date#3 as DATETIME), '%Y%m%d') asc null first] )                                            |
|       +--LogicalAggregate[156] ( groupByExpr=[date#3], outputExpr=[date#3], hasRepeat=false )                                                    |
|          +--LogicalProject[155] ( distinct=false, projects=[date_format(date#2, '%Y%m%d') AS `date`apache#3], excepts=[], canEliminate=true )          |
|             +--LogicalOlapScan ( qualified=default_cluster:bugfix.cir_1756_t2, indexName=cir_1756_t2, selectedIndexId=412352, preAgg=ON )        |
+--------------------------------------------------------------------------------------------------------------------------------------------------+
```

after:
```
+--------------------------------------------------------------------------------------------------------------------------------------------------+
| Explain String                                                                                                                                   |
+--------------------------------------------------------------------------------------------------------------------------------------------------+
| LogicalProject[171] ( distinct=false, projects=[date#2], excepts=[], canEliminate=true )                                                         |
| +--LogicalJoin[170] ( type=LEFT_OUTER_JOIN, markJoinSlotReference=Optional.empty, hashJoinConjuncts=[(date#2 = date#4)], otherJoinConjuncts=[] ) |
|    |--LogicalProject[162] ( distinct=false, projects=[date_format(date#0, '%Y%m%d') AS `date`apache#2], excepts=[], canEliminate=true )                |
|    |  +--LogicalOlapScan ( qualified=default_cluster:bugfix.cir_1756_t1, indexName=cir_1756_t1, selectedIndexId=1049812, preAgg=ON )             |
|    +--LogicalProject[169] ( distinct=false, projects=[date#4], excepts=[], canEliminate=false )                                                  |
|       +--LogicalSort[168] ( orderKeys=[date_format(cast(date#4 as DATETIME), '%Y%m%d') asc null first] )                                         |
|          +--LogicalAggregate[167] ( groupByExpr=[date#4], outputExpr=[date#4], hasRepeat=false )                                                 |
|             +--LogicalProject[166] ( distinct=false, projects=[date_format(date#3, '%Y%m%d') AS `date`apache#4], excepts=[], canEliminate=true )       |
|                +--LogicalOlapScan ( qualified=default_cluster:bugfix.cir_1756_t2, indexName=cir_1756_t2, selectedIndexId=1049825, preAgg=ON )    |
+--------------------------------------------------------------------------------------------------------------------------------------------------+
```
  • Loading branch information
starocean999 authored Mar 22, 2023
1 parent 6cbf393 commit 17a1ce5
Show file tree
Hide file tree
Showing 16 changed files with 182 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ public enum RuleType {
COLUMN_PRUNE_FILTER_CHILD(RuleTypeClass.REWRITE),
PRUNE_ONE_ROW_RELATION_COLUMN(RuleTypeClass.REWRITE),
COLUMN_PRUNE_SORT_CHILD(RuleTypeClass.REWRITE),
COLUMN_PRUNE_SORT(RuleTypeClass.REWRITE),
COLUMN_PRUNE_JOIN_CHILD(RuleTypeClass.REWRITE),
COLUMN_PRUNE_REPEAT_CHILD(RuleTypeClass.REWRITE),
// expression of plan rewrite
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public Rule build() {
if (projects.equals(newProjects)) {
return project;
}
return new LogicalProject<>(newProjects, project.child());
return project.withProjectsAndChild(newProjects, project.child());
}).toRule(RuleType.REWRITE_PROJECT_EXPRESSION);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public List<Rule> buildRules() {
new PruneFilterChildColumns().build(),
new PruneAggChildColumns().build(),
new PruneJoinChildrenColumns().build(),
new PruneSortColumns().build(),
new PruneSortChildColumns().build(),
new MergeProjects().build(),
new PruneRepeatChildColumns().build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.util.ExpressionUtils;

Expand Down Expand Up @@ -65,10 +64,11 @@ public Rule build() {
boolean needAggregate = bottomProjects.stream().anyMatch(expr ->
expr.anyMatch(AggregateFunction.class::isInstance));
if (needAggregate) {
normalizedChild = new LogicalAggregate<>(
ImmutableList.of(), ImmutableList.copyOf(bottomProjects), project.child());
normalizedChild = new LogicalAggregate<>(ImmutableList.of(),
ImmutableList.copyOf(bottomProjects), project.child());
} else {
normalizedChild = new LogicalProject<>(ImmutableList.copyOf(bottomProjects), project.child());
normalizedChild = project.withProjectsAndChild(
ImmutableList.copyOf(bottomProjects), project.child());
}
}

Expand All @@ -89,7 +89,7 @@ public Rule build() {

// 3. handle top projects
List<NamedExpression> topProjects = ctxForWindows.normalizeToUseSlotRef(normalizedOutputs1);
return new LogicalProject<>(topProjects, normalizedLogicalWindow);
return project.withProjectsAndChild(topProjects, normalizedLogicalWindow);
}).toRule(RuleType.EXTRACT_AND_NORMALIZE_WINDOW_EXPRESSIONS);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import java.util.List;
Expand All @@ -47,9 +46,10 @@ public class MergeProjects extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalProject(logicalProject()).then(project -> {
LogicalProject<Plan> childProject = project.child();
LogicalProject childProject = project.child();
List<NamedExpression> projectExpressions = project.mergeProjections(childProject);
return new LogicalProject<>(projectExpressions, childProject.child(0));
LogicalProject newProject = childProject.canEliminate() ? project : childProject;
return newProject.withProjectsAndChild(projectExpressions, childProject.child(0));
}).toRule(RuleType.MERGE_PROJECTS);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite.logical;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableList;

import java.util.stream.Collectors;

/**
the sort node will create new slots for order by keys if the order by keys is not in the output
so need create a project above sort node to prune the unnecessary order by keys
*/
public class PruneSortColumns extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalSort()
.when(sort -> !sort.isOrderKeysPruned() && !sort.getOutputSet()
.containsAll(sort.getOrderKeys().stream()
.map(orderKey -> orderKey.getExpr()).collect(Collectors.toSet())))
.then(sort -> {
return new LogicalProject(sort.getOutput(), ImmutableList.of(), false,
sort.withOrderKeysPruned(true));
}).toRule(RuleType.COLUMN_PRUNE_SORT);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public Rule build() {
}
});

LogicalProject newProject = new LogicalProject<>(newProjects, filter.child());
LogicalProject newProject = project.withProjectsAndChild(newProjects, filter.child());
LogicalFilter newFilter = new LogicalFilter<>(filter.getConjuncts(), newProject);
LogicalAggregate newAgg = agg.withChildren(ImmutableList.of(newFilter));
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public Rule build() {
if (apply.getSubqueryExpr() instanceof ScalarSubquery) {
newProjects.add(project.getProjects().get(0));
}
return new LogicalProject(newProjects, newCorrelate);
return project.withProjectsAndChild(newProjects, newCorrelate);
}).toRule(RuleType.PULL_UP_PROJECT_UNDER_APPLY);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ public Rule build() {
if (leftOutput.equals(leftProjects)) {
left = join.left();
} else {
left = new LogicalProject<>(leftProjects, join.left());
left = project.withProjectsAndChild(leftProjects, join.left());
}
if (rightOutput.equals(rightProjects)) {
right = join.right();
} else {
right = new LogicalProject<>(rightProjects, join.right());
right = project.withProjectsAndChild(rightProjects, join.right());
}

// If condition use alias slot, we should replace condition
Expand All @@ -109,7 +109,7 @@ public Rule build() {
List<Expression> newOther = replaceJoinConjuncts(join.getOtherJoinConjuncts(), replaceMap);

Plan newJoin = join.withConjunctsChildren(newHash, newOther, left, right);
return new LogicalProject<>(newProjects, newJoin);
return project.withProjectsAndChild(newProjects, newJoin);
}).toRule(RuleType.PUSHDOWN_ALIAS_THROUGH_JOIN);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public List<Rule> buildRules() {
RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT.build(
logicalFilter(logicalProject()).then(filter -> {
LogicalProject<Plan> project = filter.child();
return new LogicalProject<>(
return project.withProjectsAndChild(
project.getProjects(),
new LogicalFilter<>(
ExpressionUtils.replace(filter.getConjuncts(), project.getAliasToProducer()),
Expand All @@ -60,7 +60,7 @@ public List<Rule> buildRules() {
LogicalLimit<LogicalProject<Plan>> limit = filter.child();
LogicalProject<Plan> project = limit.child();

return new LogicalProject<>(
return project.withProjectsAndChild(
project.getProjects(),
new LogicalFilter<>(
ExpressionUtils.replace(filter.getConjuncts(), project.getAliasToProducer()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public Rule build() {
LogicalProject<LogicalLimit<Plan>> logicalProject = ctx.root;
LogicalLimit<Plan> logicalLimit = logicalProject.child();
return new LogicalLimit<>(logicalLimit.getLimit(), logicalLimit.getOffset(),
logicalLimit.getPhase(), new LogicalProject<>(logicalProject.getProjects(),
logicalLimit.getPhase(), logicalProject.withProjectsAndChild(logicalProject.getProjects(),
logicalLimit.child()));
}).toRule(RuleType.PUSHDOWN_PROJECT_THROUGH_LIMIT);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public Rule build() {
.filter(e -> !projects.contains(e))
.map(NamedExpression.class::cast)
.forEach(projects::add);
LogicalProject newProject = new LogicalProject(projects, child);
LogicalProject newProject = project.withProjectsAndChild(projects, child);
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(),
apply.getSubCorrespondingConject(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ public LogicalProject<Plan> withChildren(List<Plan> children) {
return new LogicalProject<>(projects, excepts, canEliminate, children.get(0), isDistinct);
}

public LogicalProject<Plan> withProjectsAndChild(List<NamedExpression> projects, Plan child) {
return new LogicalProject<>(projects, excepts, canEliminate,
Optional.empty(), Optional.empty(), child, isDistinct);
}

@Override
public LogicalProject<Plan> withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalProject<>(projects, excepts, canEliminate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,29 @@ public class LogicalSort<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYP

private final List<OrderKey> orderKeys;

private final boolean orderKeysPruned;

public LogicalSort(List<OrderKey> orderKeys, CHILD_TYPE child) {
this(orderKeys, Optional.empty(), Optional.empty(), child);
}

public LogicalSort(List<OrderKey> orderKeys, CHILD_TYPE child, boolean orderKeysPruned) {
this(orderKeys, Optional.empty(), Optional.empty(), child, orderKeysPruned);
}

/**
* Constructor for LogicalSort.
*/
public LogicalSort(List<OrderKey> orderKeys, Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
this(orderKeys, groupExpression, logicalProperties, child, false);
}

public LogicalSort(List<OrderKey> orderKeys, Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child, boolean orderKeysPruned) {
super(PlanType.LOGICAL_SORT, groupExpression, logicalProperties, child);
this.orderKeys = ImmutableList.copyOf(Objects.requireNonNull(orderKeys, "orderKeys can not be null"));
this.orderKeysPruned = orderKeysPruned;
}

@Override
Expand All @@ -68,6 +80,10 @@ public List<OrderKey> getOrderKeys() {
return orderKeys;
}

public boolean isOrderKeysPruned() {
return orderKeysPruned;
}

@Override
public String toString() {
return Utils.toSqlString("LogicalSort[" + id.asInt() + "]",
Expand Down Expand Up @@ -106,21 +122,27 @@ public List<? extends Expression> getExpressions() {
@Override
public LogicalSort<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new LogicalSort<>(orderKeys, children.get(0));
return new LogicalSort<>(orderKeys, children.get(0), orderKeysPruned);
}

@Override
public LogicalSort<Plan> withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalSort<>(orderKeys, groupExpression, Optional.of(getLogicalProperties()), child());
return new LogicalSort<>(orderKeys, groupExpression, Optional.of(getLogicalProperties()), child(),
orderKeysPruned);
}

@Override
public LogicalSort<Plan> withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new LogicalSort<>(orderKeys, Optional.empty(), logicalProperties, child());
return new LogicalSort<>(orderKeys, Optional.empty(), logicalProperties, child(), false);
}

public LogicalSort<Plan> withOrderKeys(List<OrderKey> orderKeys) {
return new LogicalSort<>(orderKeys, Optional.empty(),
Optional.of(getLogicalProperties()), child());
Optional.of(getLogicalProperties()), child(), false);
}

public LogicalSort<Plan> withOrderKeysPruned(boolean orderKeysPruned) {
return new LogicalSort<>(orderKeys, groupExpression, Optional.of(getLogicalProperties()), child(),
orderKeysPruned);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select --
20200202

Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("inlineview_with_project") {
sql "SET enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"
sql """
drop table if exists cir_1756_t1;
"""

sql """
drop table if exists cir_1756_t2;
"""

sql """
create table cir_1756_t1 (`date` date not null)
ENGINE=OLAP
DISTRIBUTED BY HASH(`date`) BUCKETS 5
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"in_memory" = "false",
"storage_format" = "V2"
);
"""

sql """
create table cir_1756_t2 ( `date` date not null )
ENGINE=OLAP
DISTRIBUTED BY HASH(`date`) BUCKETS 5
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"in_memory" = "false",
"storage_format" = "V2"
);
"""

sql """
insert into cir_1756_t1 values("2020-02-02");
"""

sql """
insert into cir_1756_t2 values("2020-02-02");
"""

qt_select """
WITH t0 AS
(SELECT DATE_FORMAT(date,
'%Y%m%d') AS date
FROM cir_1756_t1 ), t3 AS
(SELECT date_format(date,
'%Y%m%d') AS `date`
FROM `cir_1756_t2`
GROUP BY date_format(date, '%Y%m%d')
ORDER BY date_format(date, '%Y%m%d') )
SELECT t0.date
FROM t0
LEFT JOIN t3
ON t0.date = t3.date;
"""

sql """
drop table if exists cir_1756_t1;
"""

sql """
drop table if exists cir_1756_t2;
"""
}

0 comments on commit 17a1ce5

Please sign in to comment.