Skip to content

Commit

Permalink
fix normalize agg
Browse files Browse the repository at this point in the history
  • Loading branch information
feiniaofeiafei committed Apr 28, 2024
1 parent d80e4ba commit 9cef386
Showing 1 changed file with 18 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
Expand Down Expand Up @@ -127,10 +126,9 @@ private LogicalAggregate<Plan> normalizeRepeat(LogicalRepeat<Plan> repeat) {
.collect(ImmutableList.toImmutableList());

// replace the arguments of grouping scalar function to virtual slots
List<NamedExpression> normalizedAggOutput = repeat.getOutputExpressions().stream()
.map(expr -> (NamedExpression) expr.rewriteDownShortCircuit(
e -> normalizeGroupingScalarFunction(context, e)))
.collect(Collectors.toList());
// replace some complex expression to slot, e.g. `a + 1`
List<NamedExpression> normalizedAggOutput = context.normalizeToUseSlotRef(
repeat.getOutputExpressions(), this::normalizeGroupingScalarFunction);

Set<VirtualSlotReference> virtualSlotsInFunction =
ExpressionUtils.collect(normalizedAggOutput, VirtualSlotReference.class::isInstance);
Expand Down Expand Up @@ -185,17 +183,21 @@ private Set<Expression> collectNeedToSlotExpressions(LogicalRepeat<Plan> repeat)
.flatMap(function -> function.getArguments().stream())
.collect(ImmutableSet.toImmutableSet());

List<AggregateFunction> aggregateFunctions = CollectNonWindowedAggFuncs.collect(repeat.getOutputExpressions());

ImmutableSet<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
.flatMap(function -> function.getArguments().stream().map(arg -> {
if (arg instanceof OrderExpression) {
return arg.child(0);
} else {
return arg;
}
}))
.collect(ImmutableSet.toImmutableSet());
// List<AggregateFunction> aggregateFunctions = CollectNonWindowedAggFuncs
// .collect(repeat.getOutputExpressions());

// ImmutableSet<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
// .flatMap(function -> function.getArguments().stream().map(arg -> {
// if (arg instanceof OrderExpression) {
// return arg.child(0);
// } else {
// return arg;
// }
// }))
// .collect(ImmutableSet.toImmutableSet());

Set<SlotReference> argumentsOfAggregateFunction = ExpressionUtils.collect(
repeat.getOutputExpressions(), expr -> expr.getClass().equals(SlotReference.class));

ImmutableSet<Expression> needPushDown = ImmutableSet.<Expression>builder()
// grouping sets should be pushed down, e.g. grouping sets((k + 1)),
Expand Down

0 comments on commit 9cef386

Please sign in to comment.