Skip to content

Commit

Permalink
[opt](Nereids) use 1 as narrowest column when do column pruning on un…
Browse files Browse the repository at this point in the history
…ion (apache#41719) (apache#42860)

pick from master apache#41719

This follows the same approach as the previous PR apache#41548.

This PR processes the union node to ensure that it does not require any column
from its children when its parent requires it with an empty slot set.
  • Loading branch information
morrySnow authored Oct 29, 2024
1 parent 1ea434f commit 189e94f
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
Expand All @@ -42,6 +43,7 @@
import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
Expand Down Expand Up @@ -345,6 +347,8 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context)
}
List<NamedExpression> prunedOutputs = Lists.newArrayList();
List<List<NamedExpression>> constantExprsList = union.getConstantExprsList();
List<List<SlotReference>> regularChildrenOutputs = union.getRegularChildrenOutputs();
List<Plan> children = union.children();
List<Integer> extractColumnIndex = Lists.newArrayList();
for (int i = 0; i < originOutput.size(); i++) {
NamedExpression output = originOutput.get(i);
Expand All @@ -353,31 +357,41 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context)
extractColumnIndex.add(i);
}
}
if (prunedOutputs.isEmpty()) {
List<NamedExpression> candidates = Lists.newArrayList(originOutput);
candidates.retainAll(keys);
if (candidates.isEmpty()) {
candidates = originOutput;
}
NamedExpression minimumColumn = ExpressionUtils.selectMinimumColumn(candidates);
prunedOutputs = ImmutableList.of(minimumColumn);
extractColumnIndex.add(originOutput.indexOf(minimumColumn));
}

int len = extractColumnIndex.size();
ImmutableList.Builder<List<NamedExpression>> prunedConstantExprsList
= ImmutableList.builderWithExpectedSize(constantExprsList.size());
for (List<NamedExpression> row : constantExprsList) {
ImmutableList.Builder<NamedExpression> newRow = ImmutableList.builderWithExpectedSize(len);
for (int idx : extractColumnIndex) {
newRow.add(row.get(idx));
if (prunedOutputs.isEmpty()) {
// process prune all columns
NamedExpression originSlot = originOutput.get(0);
prunedOutputs = ImmutableList.of(new SlotReference(originSlot.getExprId(), originSlot.getName(),
TinyIntType.INSTANCE, false, originSlot.getQualifier()));
regularChildrenOutputs = Lists.newArrayListWithCapacity(regularChildrenOutputs.size());
children = Lists.newArrayListWithCapacity(children.size());
for (int i = 0; i < union.getArity(); i++) {
LogicalProject<?> project = new LogicalProject<>(
ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1))), union.child(i));
regularChildrenOutputs.add((List) project.getOutput());
children.add(project);
}
for (int i = 0; i < constantExprsList.size(); i++) {
prunedConstantExprsList.add(ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1))));
}
} else {
int len = extractColumnIndex.size();
for (List<NamedExpression> row : constantExprsList) {
ImmutableList.Builder<NamedExpression> newRow = ImmutableList.builderWithExpectedSize(len);
for (int idx : extractColumnIndex) {
newRow.add(row.get(idx));
}
prunedConstantExprsList.add(newRow.build());
}
prunedConstantExprsList.add(newRow.build());
}
if (prunedOutputs.equals(originOutput)) {

if (prunedOutputs.equals(originOutput) && !context.requiredSlots.isEmpty()) {
return union;
} else {
return union.withNewOutputsAndConstExprsList(prunedOutputs, prunedConstantExprsList.build());
return union.withNewOutputsChildrenAndConstExprsList(prunedOutputs, children,
regularChildrenOutputs, prunedConstantExprsList.build());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.DoubleType;
Expand Down Expand Up @@ -313,6 +314,21 @@ public void pruneAggregateOutput() {
);
}

/**
 * Checks the plan shape produced by {@code ColumnPruning} for a {@code count()}
 * over a {@code UNION ALL} subquery.
 *
 * <p>The rewritten plan must match a project on top of a union with two project
 * children. Every one of these three projects must have exactly one projection
 * whose first child is a {@code TinyIntLiteral}.
 */
@Test
public void pruneUnionAllWithCount() {
PlanChecker.from(connectContext)
.analyze("select count() from (select 1, 2 union all select id, age from student) t")
.customRewrite(new ColumnPruning())
.matches(
logicalProject(
logicalUnion(
// both union children: a project with a single TinyIntLiteral-based projection
logicalProject().when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral),
logicalProject().when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral)
)
// the project above the union must satisfy the same single-literal condition
).when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral)
);
}

/**
 * Convenience overload for a {@link LogicalProject}: passes {@code p.getOutputs()}
 * to the other {@code getOutputQualifiedNames} overload and returns its result.
 *
 * <p>NOTE(review): the delegated overload is not visible in this hunk; presumably
 * it maps each output to its qualified name, as the method name suggests. Confirm
 * against that overload.
 *
 * @param p the project whose outputs are inspected
 * @return the list returned by the delegated overload
 */
private List<String> getOutputQualifiedNames(LogicalProject<? extends Plan> p) {
return getOutputQualifiedNames(p.getOutputs());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
suite("const_expr_column_pruning") {
sql """SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"""
// should only keep one column in union
sql "select count(1) from(select 3, 6 union all select 1, 3) t"
sql "select count(a) from(select 3 a, 6 union all select 1, 3) t"
}
sql """select count(1) from(select 3, 6 union all select 1, 3) t"""
sql """select count(1) from(select 3, 6 union all select "1", 3) t"""
sql """select count(a) from(select 3 a, 6 union all select "1", 3) t"""
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,10 @@ suite("window_column_pruning") {
sql "select id from (select id, rank() over() px from window_column_pruning union all select id, rank() over() px from window_column_pruning) a"
notContains "rank"
}

// The explain output for this count() over a window-function subquery
// must not contain "row_number".
explain {
sql "select count() from (select row_number() over(partition by id) from window_column_pruning) tmp"
notContains "row_number"
}
}

0 comments on commit 189e94f

Please sign in to comment.