Skip to content

Commit

Permalink
[Feat](nereids) add rewrite rule PushCountIntoUnionAll
Browse files Browse the repository at this point in the history
  • Loading branch information
feiniaofeiafei committed Apr 11, 2024
1 parent e9b45dc commit c16524d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ public Rule build() {
private Plan doPush(LogicalAggregate<LogicalUnion> agg) {
LogicalUnion logicalUnion = agg.child();
List<Slot> outputs = logicalUnion.getOutput();
Map<Slot, Integer> mmap = new HashMap<>();
Map<Slot, Integer> replaceMap = new HashMap<>();
for (int i = 0; i < outputs.size(); i++) {
mmap.put(outputs.get(i), i);
replaceMap.put(outputs.get(i), i);
}
int childSize = logicalUnion.children().size();
List<Expression> upperGroupByExpressions = agg.getGroupByExpressions();
Expand All @@ -69,9 +69,9 @@ private Plan doPush(LogicalAggregate<LogicalUnion> agg) {
for (int i = 0; i < childSize; i++) {
Plan child = logicalUnion.children().get(i);
List<Slot> childOutputs = child.getOutput();
List<Expression> groupByExpressions = replaceExpressionByUnionAll(upperGroupByExpressions, mmap,
List<Expression> groupByExpressions = replaceExpressionByUnionAll(upperGroupByExpressions, replaceMap,
childOutputs);
List<NamedExpression> outputExpressions = replaceExpressionByUnionAll(upperOutputExpressions, mmap,
List<NamedExpression> outputExpressions = replaceExpressionByUnionAll(upperOutputExpressions, replaceMap,
childOutputs);
LogicalAggregate<Plan> logicalAggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions,
child);
Expand Down Expand Up @@ -110,31 +110,31 @@ private Plan doPush(LogicalAggregate<LogicalUnion> agg) {
}

private <E extends Expression> List<E> replaceExpressionByUnionAll(List<E> expressions,
Map<Slot, Integer> mmap, List<Slot> childOutputs) {
// 遍历expressions,如果出现了mmap中的slot,那么替换为childOutputs[mmap[slot]]
Map<Slot, Integer> replaceMap, List<Slot> childOutputs) {
// Traverse expressions. If a slot in replaceMap appears, replace it with childOutputs[replaceMap[slot]]
return ExpressionUtils.rewriteDownShortCircuit(expressions, expr -> {
if (expr instanceof Alias && ((Alias) expr).child() instanceof Count) {
Count cnt = (Count) ((Alias) expr).child();
if (cnt.isCountStar()) {
return new Alias(new Count());
} else {
Expression newCntChild = cnt.child(0).rewriteDownShortCircuit(e -> {
if (e instanceof SlotReference && mmap.containsKey(e)) {
return childOutputs.get(mmap.get(e));
if (e instanceof SlotReference && replaceMap.containsKey(e)) {
return childOutputs.get(replaceMap.get(e));
}
return e;
});
return new Alias(new Count(newCntChild));
}
} else if (expr instanceof SlotReference && mmap.containsKey(expr)) {
return childOutputs.get(mmap.get(expr));
} else if (expr instanceof SlotReference && replaceMap.containsKey(expr)) {
return childOutputs.get(replaceMap.get(expr));
}
return expr;
});
}

private boolean hasUnsuportedAggFunc(LogicalAggregate aggregate) {
// 如果有不是count的aggfunc,或者有count distinct 都不支持这个规则
// only support count, and support count(distinct)
return ExpressionUtils.deapAnyMatch(aggregate.getOutputExpressions(), expr -> {
if (expr instanceof AggregateFunction) {
return !(expr instanceof Count) || ((Count) expr).isDistinct();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
private final boolean ordinalIsResolved;
private final boolean generated;
private final boolean hasPushed;
public boolean done = false;

/**
* Desc: Constructor for LogicalAggregate.
Expand Down

0 comments on commit c16524d

Please sign in to comment.