Skip to content

Commit

Permalink
[fix](nereids)fix nullable property of ForEachCombinator (#37980)
Browse files Browse the repository at this point in the history
## Proposed changes

pick from master #37796

<!--Describe your changes.-->
  • Loading branch information
starocean999 authored Jul 17, 2024
1 parent 1875267 commit b2a4cff
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
Expand All @@ -36,16 +36,20 @@
/**
* combinator foreach
*/
public class ForEachCombinator extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable, Combinator {
public class ForEachCombinator extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, Combinator {

private final AggregateFunction nested;

/**
* constructor of ForEachCombinator
*/
public ForEachCombinator(List<Expression> arguments, AggregateFunction nested) {
super(nested.getName() + AggCombinerFunctionBuilder.FOREACH_SUFFIX, arguments);
this(arguments, false, nested);
}

public ForEachCombinator(List<Expression> arguments, boolean alwaysNullable, AggregateFunction nested) {
super(nested.getName() + AggCombinerFunctionBuilder.FOREACH_SUFFIX, false, alwaysNullable, arguments);

this.nested = Objects.requireNonNull(nested, "nested can not be null");
}
Expand All @@ -56,7 +60,7 @@ public static ForEachCombinator create(AggregateFunction nested) {

@Override
public ForEachCombinator withChildren(List<Expression> children) {
return new ForEachCombinator(children, nested);
return new ForEachCombinator(children, alwaysNullable, nested);
}

@Override
Expand Down Expand Up @@ -88,4 +92,9 @@ public AggregateFunction getNestedFunction() {
public AggregateFunction withDistinctAndChildren(boolean distinct, List<Expression> children) {
throw new UnsupportedOperationException("Unimplemented method 'withDistinctAndChildren'");
}

@Override
public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) {
return new ForEachCombinator(children, alwaysNullable, nested);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ default R visitUnionCombinator(UnionCombinator combinator, C context) {
}

default R visitForEachCombinator(ForEachCombinator combinator, C context) {
return visitAggregateFunction(combinator, context);
return visitNullableAggregateFunction(combinator, context);
}

default R visitJavaUdaf(JavaUdaf javaUdaf, C context) {
Expand Down
6 changes: 6 additions & 0 deletions regression-test/data/function_p0/test_agg_foreach_notnull.out
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select_sum_not_null --
[1, 2, 3]
[20]
[100]
[null, 2]

-- !sql --
[1, 2, 3] [1, 2, 3] [100, 2, 3] [100, 2, 3] [40.333333333333336, 2, 3] [85.95867768595042, 2, 3]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ suite("test_agg_foreach_not_null") {
(4,[null,2],[[2],null],[null,'c']);
"""

qt_select_sum_not_null """
select sum_foreach(a) from foreach_table_not_null group by id order by id;
"""

// this case also test combinator should be case-insensitive
qt_sql """
select min_ForEach(a), min_by_foreach(a,a),max_foreach(a),max_by_foreach(a,a) , avg_foreach(a),avg_weighted_foreach(a,a) from foreach_table_not_null ;
Expand Down

0 comments on commit b2a4cff

Please sign in to comment.