diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
index d4137f148e4021..904cb0848a9f41 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
@@ -235,14 +235,17 @@ public static AggregateInfo create(
      * Used by new optimizer.
      */
     public static AggregateInfo create(
-            ArrayList groupingExprs, ArrayList aggExprs,
-            TupleDescriptor tupleDesc, TupleDescriptor intermediateTupleDesc, AggPhase phase) {
+            ArrayList groupingExprs, ArrayList aggExprs, List aggExprIds,
+            boolean isPartialAgg, TupleDescriptor tupleDesc, TupleDescriptor intermediateTupleDesc, AggPhase phase) {
         AggregateInfo result = new AggregateInfo(groupingExprs, aggExprs, phase);
         result.outputTupleDesc = tupleDesc;
         result.intermediateTupleDesc = intermediateTupleDesc;
         int aggExprSize = result.getAggregateExprs().size();
         for (int i = 0; i < aggExprSize; i++) {
             result.materializedSlots.add(i);
+            String label = (isPartialAgg ? "partial_" : "")
+                    + aggExprs.get(i).toSql() + "[#" + aggExprIds.get(i) + "]";
+            result.materializedSlotLabels.add(label);
         }
         return result;
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java
index a8d4aef1612fe8..2501b308f24c2a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java
@@ -72,6 +72,7 @@ public abstract class AggregateInfoBase {
     // exprs that need to be materialized.
     // Populated in materializeRequiredSlots() which must be implemented by subclasses.
     protected ArrayList materializedSlots = Lists.newArrayList();
+    protected List materializedSlotLabels = Lists.newArrayList();
 
     protected AggregateInfoBase(ArrayList groupingExprs,
             ArrayList aggExprs)  {
@@ -94,6 +95,7 @@ protected AggregateInfoBase(AggregateInfoBase other) {
         intermediateTupleDesc = other.intermediateTupleDesc;
         outputTupleDesc = other.outputTupleDesc;
         materializedSlots = Lists.newArrayList(other.materializedSlots);
+        materializedSlotLabels = Lists.newArrayList(other.materializedSlotLabels);
     }
 
     /**
@@ -234,6 +236,10 @@ public TupleId getOutputTupleId() {
         return outputTupleDesc.getId();
     }
 
+    public List getMaterializedAggregateExprLabels() {
+        return Lists.newArrayList(materializedSlotLabels);
+    }
+
     public boolean requiresIntermediateTuple() {
         Preconditions.checkNotNull(intermediateTupleDesc);
         Preconditions.checkNotNull(outputTupleDesc);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java
index b33231d658f159..96b0ac47551090 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java
@@ -70,7 +70,7 @@ public FunctionBuilder findFunctionBuilder(String name, List a
                 .filter(functionBuilder -> functionBuilder.canApply(arguments))
                 .collect(Collectors.toList());
         if (candidateBuilders.isEmpty()) {
-            String candidateHints = getCandidateHint(name, candidateBuilders);
+            String candidateHints = getCandidateHint(name, functionBuilders);
             throw new AnalysisException("Can not found function '" + name
                     + "' which has " + arity + " arity. Candidate functions are: " + candidateHints);
         }
@@ -93,6 +93,6 @@ private void registerBuiltinFunctions(Map> name2Bu
     public String getCandidateHint(String name, List candidateBuilders) {
         return candidateBuilders.stream()
                 .map(builder -> name + builder.toString())
-                .collect(Collectors.joining(", "));
+                .collect(Collectors.joining(", ", "[", "]"));
     }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Table.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Table.java
index 56f02f9832534f..8cf8477387b391 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Table.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Table.java
@@ -518,7 +518,7 @@ public BaseAnalysisTask createAnalysisTask(AnalysisTaskScheduler scheduler, Anal
      * @return estimated row count
      */
    public long estimatedRowCount() {
-        long cardinality = 1;
+        long cardinality = 0;
        if (this instanceof OlapTable) {
            OlapTable table = (OlapTable) this;
            for (long selectedPartitionId : table.getPartitionIds()) {
@@ -527,6 +527,6 @@
                cardinality += baseIndex.getRowCount();
            }
        }
-        return cardinality;
+        return Math.max(cardinality, 1);
    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/util/ReflectionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/common/util/ReflectionUtils.java
index 0671f0440535cc..92085e3a92c08c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/common/util/ReflectionUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/common/util/ReflectionUtils.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.common.util;
 
+import com.google.common.collect.ImmutableMap;
 import org.apache.logging.log4j.Logger;
 
 import java.io.ByteArrayOutputStream;
@@ -26,10 +27,21 @@
 import java.lang.management.ThreadMXBean;
 import java.lang.reflect.Constructor;
 import java.util.Map;
+import java.util.Optional;
 import java.util.concurrent.ConcurrentHashMap;
 
 public class ReflectionUtils {
     private static final Class[] emptyArray = new Class[]{};
+    private static final Map boxToPrimitiveTypes = ImmutableMap.builder()
+            .put(Boolean.class, boolean.class)
+            .put(Character.class, char.class)
+            .put(Byte.class, byte.class)
+            .put(Short.class, short.class)
+            .put(Integer.class, int.class)
+            .put(Long.class, long.class)
+            .put(Float.class, float.class)
+            .put(Double.class, double.class)
+            .build();
 
     /**
      * Cache of constructors for each class. Pins the classes so they
@@ -162,4 +174,11 @@ static void clearCache() {
     static int getCacheSize() {
         return CONSTRUCTOR_CACHE.size();
     }
+
+    public static Optional getPrimitiveType(Class targetClass) {
+        if (targetClass.isPrimitive()) {
+            return Optional.of(targetClass);
+        }
+        return Optional.ofNullable(boxToPrimitiveTypes.get(targetClass));
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java
index e650ac58ac87bc..b05741b5b9243d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java
@@ -76,8 +76,8 @@ public class CascadesContext {
 
     private List tables = null;
 
-    public CascadesContext(Memo memo, StatementContext statementContext) {
-        this(memo, statementContext, new CTEContext());
+    public CascadesContext(Memo memo, StatementContext statementContext, PhysicalProperties requestProperties) {
+        this(memo, statementContext, new CTEContext(), requestProperties);
     }
 
     /**
@@ -86,20 +86,22 @@ public CascadesContext(Memo memo, StatementContext statementContext) {
      * @param memo {@link Memo} reference
      * @param statementContext {@link StatementContext} reference
     */
-    public CascadesContext(Memo memo, StatementContext statementContext, CTEContext cteContext) {
+    public CascadesContext(Memo memo, StatementContext statementContext,
+            CTEContext cteContext, PhysicalProperties requireProperties) {
         this.memo = memo;
         this.statementContext = statementContext;
         this.ruleSet = new RuleSet();
         this.jobPool = new JobStack();
         this.jobScheduler = new SimpleJobScheduler();
-        this.currentJobContext = new JobContext(this, PhysicalProperties.ANY, Double.MAX_VALUE);
+        this.currentJobContext = new JobContext(this, requireProperties, Double.MAX_VALUE);
         this.subqueryExprIsAnalyzed = new HashMap<>();
         this.runtimeFilterContext = new RuntimeFilterContext(getConnectContext().getSessionVariable());
         this.cteContext = cteContext;
     }
 
-    public static CascadesContext newContext(StatementContext statementContext, Plan initPlan) {
-        return new CascadesContext(new Memo(initPlan), statementContext);
+    public static CascadesContext newContext(StatementContext statementContext,
+            Plan initPlan, PhysicalProperties requireProperties) {
+        return new CascadesContext(new Memo(initPlan), statementContext, requireProperties);
     }
 
     public NereidsAnalyzer newAnalyzer() {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
index 433f96f183b559..03847cc6e8e24e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
@@ -39,6 +39,8 @@
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.commands.Command;
+import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
 import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@@ -86,7 +88,10 @@ public void plan(StatementBase queryStmt, org.apache.doris.thrift.TQueryOptions
         LogicalPlanAdapter logicalPlanAdapter = (LogicalPlanAdapter) queryStmt;
         ExplainLevel explainLevel = 
getExplainLevel(queryStmt.getExplainOptions()); - Plan resultPlan = plan(logicalPlanAdapter.getLogicalPlan(), PhysicalProperties.ANY, explainLevel); + + LogicalPlan parsedPlan = logicalPlanAdapter.getLogicalPlan(); + PhysicalProperties requireProperties = buildInitRequireProperties(parsedPlan); + Plan resultPlan = plan(parsedPlan, requireProperties, explainLevel); if (explainLevel.isPlanLevel) { return; } @@ -129,11 +134,11 @@ public PhysicalPlan plan(LogicalPlan plan, PhysicalProperties outputProperties) * Do analyze and optimize for query plan. * * @param plan wait for plan - * @param outputProperties physical properties constraints + * @param requireProperties request physical properties constraints * @return plan generated by this planner * @throws AnalysisException throw exception if failed in ant stage */ - public Plan plan(LogicalPlan plan, PhysicalProperties outputProperties, ExplainLevel explainLevel) { + public Plan plan(LogicalPlan plan, PhysicalProperties requireProperties, ExplainLevel explainLevel) { if (explainLevel == ExplainLevel.PARSED_PLAN || explainLevel == ExplainLevel.ALL_PLAN) { parsedPlan = plan; if (explainLevel == ExplainLevel.PARSED_PLAN) { @@ -144,7 +149,8 @@ public Plan plan(LogicalPlan plan, PhysicalProperties outputProperties, ExplainL // pre-process logical plan out of memo, e.g. process SET_VAR hint plan = preprocess(plan); - initCascadesContext(plan); + initCascadesContext(plan, requireProperties); + try (Lock lock = new Lock(plan, cascadesContext)) { // resolve column, table and function analyze(); @@ -174,7 +180,7 @@ public Plan plan(LogicalPlan plan, PhysicalProperties outputProperties, ExplainL // cost-based optimize and explore plan space optimize(); - PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY); + PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), requireProperties); // post-process physical plan out of memo, just for future use. physicalPlan = postProcess(physicalPlan); @@ -190,8 +196,8 @@ private LogicalPlan preprocess(LogicalPlan logicalPlan) { return new PlanPreprocessors(statementContext).process(logicalPlan); } - private void initCascadesContext(LogicalPlan plan) { - cascadesContext = CascadesContext.newContext(statementContext, plan); + private void initCascadesContext(LogicalPlan plan, PhysicalProperties requireProperties) { + cascadesContext = CascadesContext.newContext(statementContext, plan, requireProperties); } private void analyze() { @@ -258,7 +264,8 @@ private PhysicalPlan chooseBestPlan(Group rootGroup, PhysicalProperties physical throws AnalysisException { try { GroupExpression groupExpression = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow( - () -> new AnalysisException("lowestCostPlans with physicalProperties doesn't exist")).second; + () -> new AnalysisException("lowestCostPlans with physicalProperties(" + + physicalProperties + ") doesn't exist in root group")).second; List inputPropertiesList = groupExpression.getInputPropertiesList(physicalProperties); List planChildren = Lists.newArrayList(); @@ -335,6 +342,11 @@ public CascadesContext getCascadesContext() { return cascadesContext; } + public static PhysicalProperties buildInitRequireProperties(Plan initPlan) { + boolean isQuery = !(initPlan instanceof Command) || (initPlan instanceof ExplainCommand); + return isQuery ? 
PhysicalProperties.GATHER : PhysicalProperties.ANY; + } + private ExplainLevel getExplainLevel(ExplainOptions explainOptions) { if (explainOptions == null) { return ExplainLevel.NONE; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java index ca7e13d6b5542a..a91266ea39ac69 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java @@ -24,6 +24,13 @@ import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.OriginStatement; +import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; +import com.google.common.collect.Maps; + +import java.util.Map; +import javax.annotation.concurrent.GuardedBy; + /** * Statement context for nereids */ @@ -37,6 +44,9 @@ public class StatementContext { private final IdGenerator relationIdGenerator = RelationId.createGenerator(); + @GuardedBy("this") + private final Map> contextCacheMap = Maps.newLinkedHashMap(); + private StatementBase parsedStatement; public StatementContext() { @@ -79,4 +89,14 @@ public RelationId getNextRelationId() { public void setParsedStatement(StatementBase parsedStatement) { this.parsedStatement = parsedStatement; } + + /** getOrRegisterCache */ + public synchronized T getOrRegisterCache(String key, Supplier cacheSupplier) { + Supplier supplier = (Supplier) contextCacheMap.get(key); + if (supplier == null) { + contextCacheMap.put(key, (Supplier) Suppliers.memoize(cacheSupplier)); + supplier = cacheSupplier; + } + return supplier.get(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundFunction.java index f89abe7c709474..5261df39a88351 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundFunction.java @@ -35,13 +35,11 @@ public class UnboundFunction extends Expression implements Unbound, PropagateNul private final String name; private final boolean isDistinct; - private final boolean isStar; - public UnboundFunction(String name, boolean isDistinct, boolean isStar, List arguments) { + public UnboundFunction(String name, boolean isDistinct, List arguments) { super(arguments.toArray(new Expression[0])); - this.name = Objects.requireNonNull(name, "name can not be null"); + this.name = Objects.requireNonNull(name, "name cannot be null"); this.isDistinct = isDistinct; - this.isStar = isStar; } public String getName() { @@ -52,10 +50,6 @@ public boolean isDistinct() { return isDistinct; } - public boolean isStar() { - return isStar; - } - public List getArguments() { return children(); } @@ -65,13 +59,13 @@ public String toSql() throws UnboundException { String params = children.stream() .map(Expression::toSql) .collect(Collectors.joining(", ")); - return name + "(" + (isDistinct ? "DISTINCT " : "") + params + ")"; + return name + "(" + (isDistinct ? "distinct " : "") + params + ")"; } @Override public String toString() { String params = Joiner.on(", ").join(children); - return "'" + name + "(" + (isDistinct ? "DISTINCT " : "") + params + ")"; + return "'" + name + "(" + (isDistinct ? 
"distinct " : "") + params + ")"; } @Override @@ -81,7 +75,7 @@ public R accept(ExpressionVisitor visitor, C context) { @Override public UnboundFunction withChildren(List children) { - return new UnboundFunction(name, isDistinct, isStar, children); + return new UnboundFunction(name, isDistinct, children); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/annotation/DependsRules.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/annotation/DependsRules.java new file mode 100644 index 00000000000000..49ff26c89812aa --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/annotation/DependsRules.java @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Indicating that the current rule depends on other rules. + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE}) +public @interface DependsRules { + /** depends rules */ + Class[] value() default {}; +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java index 30b54c904b5596..904ef7fa880a37 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java @@ -25,15 +25,16 @@ import org.apache.doris.nereids.properties.DistributionSpecHash; import org.apache.doris.nereids.properties.DistributionSpecReplicated; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalLocalQuickSort; import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort; +import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.qe.ConnectContext; @@ -98,6 +99,15 @@ public CostEstimate 
visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, Pla return CostEstimate.ofCpu(statistics.getRowCount()); } + @Override + public CostEstimate visitPhysicalStorageLayerAggregate( + PhysicalStorageLayerAggregate storageLayerAggregate, PlanContext context) { + CostEstimate costEstimate = storageLayerAggregate.getRelation().accept(this, context); + // multiply a factor less than 1, so we can select PhysicalStorageLayerAggregate as far as possible + return new CostEstimate(costEstimate.getCpuCost() * 0.7, costEstimate.getMemoryCost(), + costEstimate.getNetworkCost(), costEstimate.getPenalty()); + } + @Override public CostEstimate visitPhysicalProject(PhysicalProject physicalProject, PlanContext context) { return CostEstimate.ofCpu(1); @@ -190,7 +200,8 @@ public CostEstimate visitPhysicalDistribute( } @Override - public CostEstimate visitPhysicalAggregate(PhysicalAggregate aggregate, PlanContext context) { + public CostEstimate visitPhysicalHashAggregate( + PhysicalHashAggregate aggregate, PlanContext context) { // TODO: stage..... StatsDeriveResult statistics = context.getStatisticsWithCheck(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java index d628660c9bfd0e..f019dee49e41f9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java @@ -33,10 +33,12 @@ import org.apache.doris.analysis.IsNullPredicate; import org.apache.doris.analysis.LikePredicate; import org.apache.doris.analysis.SlotRef; +import org.apache.doris.analysis.StringLiteral; import org.apache.doris.analysis.TimestampArithmeticExpr; import org.apache.doris.catalog.Function.NullableMode; import org.apache.doris.catalog.Type; import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement; @@ -65,11 +67,16 @@ import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction; +import org.apache.doris.nereids.trees.expressions.literal.DateLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal; +import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; -import org.apache.doris.nereids.trees.plans.AggPhase; import org.apache.doris.nereids.types.coercion.AbstractDataType; import org.apache.doris.thrift.TFunctionBinaryType; @@ -190,6 +197,31 @@ public Expr visitLiteral(Literal literal, PlanTranslatorContext context) { return literal.toLegacyLiteral(); } + @Override + public 
Expr visitNullLiteral(NullLiteral nullLiteral, PlanTranslatorContext context) { + org.apache.doris.analysis.NullLiteral nullLit = new org.apache.doris.analysis.NullLiteral(); + nullLit.setType(nullLiteral.getDataType().toCatalogDataType()); + return nullLit; + } + + @Override + public Expr visitDateLiteral(DateLiteral dateLiteral, PlanTranslatorContext context) { + // BE not support date v2 literal and datetime v2 literal + if (dateLiteral instanceof DateV2Literal) { + return new CastExpr(Type.DATEV2, new StringLiteral(dateLiteral.toString())); + } + return super.visitDateLiteral(dateLiteral, context); + } + + @Override + public Expr visitDateTimeLiteral(DateTimeLiteral dateTimeLiteral, PlanTranslatorContext context) { + // BE not support date v2 literal and datetime v2 literal + if (dateTimeLiteral instanceof DateTimeV2Literal) { + return new CastExpr(Type.DATETIMEV2, new StringLiteral(dateTimeLiteral.toString())); + } + return super.visitDateTimeLiteral(dateTimeLiteral, context); + } + @Override public Expr visitBetween(Between between, PlanTranslatorContext context) { throw new RuntimeException("Unexpected invocation"); @@ -261,62 +293,6 @@ public Expr visitInPredicate(InPredicate inPredicate, PlanTranslatorContext cont false); } - // TODO: Supports for `distinct` - @Override - public Expr visitAggregateFunction(AggregateFunction function, PlanTranslatorContext context) { - List catalogArguments = function.getArguments() - .stream() - .map(arg -> arg.accept(this, context)) - .collect(ImmutableList.toImmutableList()); - - // aggFnArguments is used to build TAggregateExpr.param_types, so backend can find the aggregate function - List aggFnArguments = function.getArgumentsBeforeDisassembled() - .stream() - .map(arg -> new SlotRef(arg.getDataType().toCatalogDataType(), arg.nullable())) - .collect(ImmutableList.toImmutableList()); - - FunctionParams aggFnParams; - if (function instanceof Count && ((Count) function).isStar()) { - aggFnParams = FunctionParams.createStarParam(); - } else { - aggFnParams = new FunctionParams(function.isDistinct(), aggFnArguments); - } - - ImmutableList argTypes = catalogArguments.stream() - .map(arg -> arg.getType()) - .collect(ImmutableList.toImmutableList()); - - NullableMode nullableMode = function.nullable() - ? NullableMode.ALWAYS_NULLABLE - : NullableMode.ALWAYS_NOT_NULLABLE; - - boolean isAnalyticFunction = false; - String functionName = function.isDistinct() ? 
"MULTI_DISTINCT_" + function.getName() : function.getName(); - if (function.getAggregateParam().aggPhase == AggPhase.DISTINCT_LOCAL - || function.getAggregateParam().aggPhase == AggPhase.DISTINCT_GLOBAL) { - if (function.getName().equalsIgnoreCase("count")) { - functionName = "SUM"; - } else { - functionName = function.getName(); - } - } - org.apache.doris.catalog.AggregateFunction catalogFunction = new org.apache.doris.catalog.AggregateFunction( - new FunctionName(functionName), argTypes, - function.getDataType().toCatalogDataType(), - function.getIntermediateTypes().toCatalogDataType(), - function.hasVarArguments(), - null, "", "", null, "", - null, "", null, false, - isAnalyticFunction, false, TFunctionBinaryType.BUILTIN, - true, true, nullableMode - ); - - boolean isMergeFn = function.isGlobal() && function.isDisassembled(); - - // create catalog FunctionCallExpr without analyze again - return new FunctionCallExpr(catalogFunction, aggFnParams, aggFnParams, isMergeFn, catalogArguments); - } - @Override public Expr visitScalarFunction(ScalarFunction function, PlanTranslatorContext context) { List arguments = function.getArguments() @@ -340,6 +316,22 @@ public Expr visitScalarFunction(ScalarFunction function, PlanTranslatorContext c return new FunctionCallExpr(catalogFunction, new FunctionParams(false, arguments)); } + @Override + public Expr visitAggregateExpression(AggregateExpression aggregateExpression, PlanTranslatorContext context) { + // aggFnArguments is used to build TAggregateExpr.param_types, so backend can find the aggregate function + List aggFnArguments = aggregateExpression.getFunction().children() + .stream() + .map(arg -> new SlotRef(arg.getDataType().toCatalogDataType(), arg.nullable())) + .collect(ImmutableList.toImmutableList()); + + Expression child = aggregateExpression.child(); + List currentPhaseArguments = child instanceof AggregateFunction + ? 
child.children() + : aggregateExpression.children(); + return translateAggregateFunction(aggregateExpression.getFunction(), + currentPhaseArguments, aggFnArguments, aggregateExpression.getAggregateParam(), context); + } + @Override public Expr visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, PlanTranslatorContext context) { return new ArithmeticExpr(binaryArithmetic.getLegacyOperator(), @@ -377,6 +369,57 @@ public Expr visitIsNull(IsNull isNull, PlanTranslatorContext context) { return new IsNullPredicate(isNull.child().accept(this, context), false); } + // TODO: Supports for `distinct` + private Expr translateAggregateFunction(AggregateFunction function, + List currentPhaseArguments, List aggFnArguments, + AggregateParam aggregateParam, PlanTranslatorContext context) { + List catalogArguments = currentPhaseArguments + .stream() + .map(arg -> arg.accept(this, context)) + .collect(ImmutableList.toImmutableList()); + + FunctionParams fnParams; + FunctionParams aggFnParams; + if (function instanceof Count && ((Count) function).isStar()) { + if (catalogArguments.isEmpty()) { + // for explain display the label: count(*) + fnParams = FunctionParams.createStarParam(); + } else { + fnParams = new FunctionParams(function.isDistinct(), catalogArguments); + } + aggFnParams = FunctionParams.createStarParam(); + } else { + fnParams = new FunctionParams(function.isDistinct(), catalogArguments); + aggFnParams = new FunctionParams(function.isDistinct(), aggFnArguments); + } + + ImmutableList argTypes = catalogArguments.stream() + .map(arg -> arg.getType()) + .collect(ImmutableList.toImmutableList()); + + NullableMode nullableMode = function.nullable() + ? NullableMode.ALWAYS_NULLABLE + : NullableMode.ALWAYS_NOT_NULLABLE; + + boolean isAnalyticFunction = false; + String functionName = function.getName(); + org.apache.doris.catalog.AggregateFunction catalogFunction = new org.apache.doris.catalog.AggregateFunction( + new FunctionName(functionName), argTypes, + function.getDataType().toCatalogDataType(), + function.getIntermediateTypes().toCatalogDataType(), + function.hasVarArguments(), + null, "", "", null, "", + null, "", null, false, + isAnalyticFunction, false, TFunctionBinaryType.BUILTIN, + true, true, nullableMode + ); + + boolean isMergeFn = aggregateParam.aggPhase.isGlobal(); + + // create catalog FunctionCallExpr without analyze again + return new FunctionCallExpr(catalogFunction, fnParams, aggFnParams, isMergeFn, catalogArguments); + } + public static org.apache.doris.analysis.AssertNumRowsElement translateAssert( AssertNumRowsElement assertNumRowsElement) { return new org.apache.doris.analysis.AssertNumRowsElement(assertNumRowsElement.getDesiredNumOfRows(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index b81292e753e968..2fcbcff96403d0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -35,9 +35,15 @@ import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.Table; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.properties.DistributionSpecAny; +import org.apache.doris.nereids.properties.DistributionSpecGather; import org.apache.doris.nereids.properties.DistributionSpecHash; 
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType; +import org.apache.doris.nereids.properties.DistributionSpecReplicated; import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.ExprId; @@ -46,19 +52,19 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.plans.AggMode; import org.apache.doris.nereids.trees.plans.AggPhase; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PreAggStatus; import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin; import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; import org.apache.doris.nereids.trees.plans.physical.PhysicalEmptyRelation; import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit; import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; @@ -68,6 +74,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort; import org.apache.doris.nereids.trees.plans.physical.PhysicalRepeat; +import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalTVFRelation; import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor; @@ -93,6 +100,7 @@ import org.apache.doris.planner.UnionNode; import org.apache.doris.tablefunction.TableValuedFunctionIf; import org.apache.doris.thrift.TPartitionType; +import org.apache.doris.thrift.TPushAggOp; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -111,6 +119,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -134,6 +143,23 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor 1) { rootFragment = exchangeToMergeFragment(rootFragment, context); } @@ -162,20 +188,12 @@ public PlanFragment translatePlan(PhysicalPlan physicalPlan, PlanTranslatorConte * Translate Agg. 
*/ @Override - public PlanFragment visitPhysicalAggregate( - PhysicalAggregate aggregate, + public PlanFragment visitPhysicalHashAggregate( + PhysicalHashAggregate aggregate, PlanTranslatorContext context) { PlanFragment inputPlanFragment = aggregate.child(0).accept(this, context); - // TODO: stale planner generate aggregate tuple in a special way. tuple include 2 parts: - // 1. group by expressions: removing duplicate expressions add to tuple - // 2. agg functions: only removing duplicate agg functions in output expression should appear in tuple. - // e.g. select sum(v1) + 1, sum(v1) + 2 from t1 should only generate one sum(v1) in tuple - // We need: - // 1. add a project after agg, if agg function is not the top output expression.(Done) - // 2. introduce canonicalized, semanticEquals and deterministic in Expression - // for removing duplicate. List groupByExpressionList = aggregate.getGroupByExpressions(); List outputExpressionList = aggregate.getOutputExpressions(); @@ -184,26 +202,34 @@ public PlanFragment visitPhysicalAggregate( ArrayList execGroupingExpressions = groupByExpressionList.stream() .map(e -> ExpressionTranslator.translate(e, context)) .collect(Collectors.toCollection(ArrayList::new)); - // 2. collect agg functions and generate agg function to slot reference map + // 2. collect agg expressions and generate agg function to slot reference map List aggFunctionOutput = Lists.newArrayList(); - List aggregateFunctionList = outputExpressionList.stream() - .filter(o -> o.anyMatch(AggregateFunction.class::isInstance)) + List aggregateExpressionList = outputExpressionList.stream() + .filter(o -> o.anyMatch(AggregateExpression.class::isInstance)) .peek(o -> aggFunctionOutput.add(o.toSlot())) - .map(o -> o.>collect(AggregateFunction.class::isInstance)) + .map(o -> o.>collect(AggregateExpression.class::isInstance)) .flatMap(Set::stream) .collect(Collectors.toList()); - ArrayList execAggregateFunctions = aggregateFunctionList.stream() + ArrayList execAggregateFunctions = aggregateExpressionList.stream() .map(aggregateFunction -> (FunctionCallExpr) ExpressionTranslator.translate(aggregateFunction, context)) .collect(Collectors.toCollection(ArrayList::new)); - // process partition list - List partitionExpressionList = aggregate.getPartitionExpressions(); - List execPartitionExpressions = partitionExpressionList.stream() - .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toList()); - DataPartition mergePartition = DataPartition.UNPARTITIONED; - if (CollectionUtils.isNotEmpty(execPartitionExpressions) - && aggregate.getAggPhase() != AggPhase.DISTINCT_GLOBAL) { - mergePartition = DataPartition.hashPartitioned(execPartitionExpressions); + PlanFragment currentFragment; + if (inputPlanFragment.getPlanRoot() instanceof ExchangeNode) { + Preconditions.checkState(aggregate.child() instanceof PhysicalDistribute, + "When the ExchangeNode is child of PhysicalHashAggregate, " + + "it should be created by PhysicalDistribute, but meet " + aggregate.child()); + ExchangeNode exchangeNode = (ExchangeNode) inputPlanFragment.getPlanRoot(); + Optional> partitionExpressions = aggregate.getPartitionExpressions(); + PhysicalDistribute physicalDistribute = (PhysicalDistribute) aggregate.child(); + DataPartition dataPartition = toDataPartition(physicalDistribute, partitionExpressions, context).get(); + currentFragment = new PlanFragment(context.nextFragmentId(), exchangeNode, dataPartition); + inputPlanFragment.setOutputPartition(dataPartition); + 
inputPlanFragment.setPlanRoot(exchangeNode.getChild(0)); + inputPlanFragment.setDestination(exchangeNode); + context.addPlanFragment(currentFragment); + } else { + currentFragment = inputPlanFragment; } // 3. generate output tuple @@ -213,40 +239,42 @@ public PlanFragment visitPhysicalAggregate( slotList.addAll(aggFunctionOutput); outputTupleDesc = generateTupleDesc(slotList, null, context); - AggregateInfo aggInfo = AggregateInfo.create(execGroupingExpressions, execAggregateFunctions, outputTupleDesc, - outputTupleDesc, aggregate.getAggPhase().toExec()); + List aggFunOutputIds = ImmutableList.of(); + if (!aggFunctionOutput.isEmpty()) { + aggFunOutputIds = outputTupleDesc + .getSlots() + .subList(groupSlotList.size(), outputTupleDesc.getSlots().size()) + .stream() + .map(slot -> slot.getId().asInt()) + .collect(ImmutableList.toImmutableList()); + } + boolean isPartial = aggregate.getAggregateParam().aggMode.productAggregateBuffer; + AggregateInfo aggInfo = AggregateInfo.create(execGroupingExpressions, execAggregateFunctions, + aggFunOutputIds, isPartial, outputTupleDesc, outputTupleDesc, aggregate.getAggPhase().toExec()); AggregationNode aggregationNode = new AggregationNode(context.nextPlanNodeId(), - inputPlanFragment.getPlanRoot(), aggInfo); - if (!aggregate.getAggPhase().isGlobal() && !aggregate.isFinalPhase()) { + currentFragment.getPlanRoot(), aggInfo); + if (!aggregate.getAggMode().isFinalPhase) { aggregationNode.unsetNeedsFinalize(); } - PlanFragment currentFragment = inputPlanFragment; + PhysicalHashAggregate firstAggregateInFragment = context.getFirstAggregateInFragment(currentFragment); + switch (aggregate.getAggPhase()) { case LOCAL: - aggregationNode.setUseStreamingPreagg(aggregate.isUsingStream()); - aggregationNode.setIntermediateTuple(); + // we should set is useStreamingAgg when has exchange, + // so the `aggregationNode.setUseStreamingPreagg()` in the visitPhysicalDistribute break; case DISTINCT_LOCAL: + aggregationNode.setIntermediateTuple(); + break; case GLOBAL: case DISTINCT_GLOBAL: - if (aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) { - aggregationNode.setIntermediateTuple(); - } - if (currentFragment.getPlanRoot() instanceof ExchangeNode) { - ExchangeNode exchangeNode = (ExchangeNode) currentFragment.getPlanRoot(); - currentFragment = new PlanFragment(context.nextFragmentId(), exchangeNode, mergePartition); - inputPlanFragment.setOutputPartition(mergePartition); - inputPlanFragment.setPlanRoot(exchangeNode.getChild(0)); - inputPlanFragment.setDestination(exchangeNode); - context.addPlanFragment(currentFragment); - } - if (aggregate.getAggPhase() != AggPhase.DISTINCT_LOCAL) { - currentFragment.updateDataPartition(mergePartition); - } break; default: throw new RuntimeException("Unsupported yet"); } + if (firstAggregateInFragment == null) { + context.setFirstAggregateInFragment(currentFragment, aggregate); + } currentFragment.setPlanRoot(aggregationNode); if (aggregate.getStats() != null) { aggregationNode.setCardinality((long) aggregate.getStats().getRowCount()); @@ -361,6 +389,35 @@ public PlanFragment visitPhysicalOneRowRelation(PhysicalOneRowRelation oneRowRel return planFragment; } + @Override + public PlanFragment visitPhysicalStorageLayerAggregate( + PhysicalStorageLayerAggregate storageLayerAggregate, PlanTranslatorContext context) { + Preconditions.checkState(storageLayerAggregate.getRelation() instanceof PhysicalOlapScan, + "PhysicalStorageLayerAggregate only support PhysicalOlapScan: " + + storageLayerAggregate.getRelation().getClass().getName()); + 
PlanFragment planFragment = storageLayerAggregate.getRelation().accept(this, context); + + OlapScanNode olapScanNode = (OlapScanNode) planFragment.getPlanRoot(); + TPushAggOp pushAggOp; + switch (storageLayerAggregate.getAggOp()) { + case COUNT: + pushAggOp = TPushAggOp.COUNT; + break; + case MIN_MAX: + pushAggOp = TPushAggOp.MINMAX; + break; + case MIX: + pushAggOp = TPushAggOp.MIX; + break; + default: + throw new AnalysisException("Unsupported storage layer aggregate: " + + storageLayerAggregate.getAggOp()); + } + olapScanNode.setPushDownAggNoGrouping(pushAggOp); + + return planFragment; + } + @Override public PlanFragment visitPhysicalOlapScan(PhysicalOlapScan olapScan, PlanTranslatorContext context) { // Create OlapScanNode @@ -379,7 +436,6 @@ public PlanFragment visitPhysicalOlapScan(PhysicalOlapScan olapScan, PlanTransla tupleDescriptor.setRef(tableRef); olapScanNode.setSelectedPartitionIds(olapScan.getSelectedPartitionIds()); olapScanNode.setSampleTabletIds(olapScan.getSelectedTabletIds()); - olapScanNode.setPushDownAggNoGrouping(olapScan.getPushDownAggOperator().toThrift()); switch (olapScan.getTable().getKeysType()) { case AGG_KEYS: @@ -400,6 +456,7 @@ public PlanFragment visitPhysicalOlapScan(PhysicalOlapScan olapScan, PlanTransla expr -> runtimeFilterGenerator.translateRuntimeFilterTarget(expr, olapScanNode, context) ) ); + olapScanNode.finalizeForNerieds(); // Create PlanFragment DataPartition dataPartition = DataPartition.RANDOM; if (olapScan.getDistributionSpec() instanceof DistributionSpecHash) { @@ -1063,6 +1120,18 @@ public PlanFragment visitPhysicalLimit(PhysicalLimit physicalLim public PlanFragment visitPhysicalDistribute(PhysicalDistribute distribute, PlanTranslatorContext context) { PlanFragment childFragment = distribute.child().accept(this, context); + + if (childFragment.getPlanRoot() instanceof AggregationNode + && distribute.child() instanceof PhysicalHashAggregate + && context.getFirstAggregateInFragment(childFragment) == distribute.child()) { + PhysicalHashAggregate hashAggregate = (PhysicalHashAggregate) distribute.child(); + if (hashAggregate.getAggPhase() == AggPhase.LOCAL + && hashAggregate.getAggMode() == AggMode.INPUT_TO_BUFFER) { + AggregationNode aggregationNode = (AggregationNode) childFragment.getPlanRoot(); + aggregationNode.setUseStreamingPreagg(hashAggregate.isMaybeUsingStream()); + } + } + ExchangeNode exchange = new ExchangeNode(context.nextPlanNodeId(), childFragment.getPlanRoot(), false); exchange.setNumInstances(childFragment.getPlanRoot().getNumInstances()); childFragment.setPlanRoot(exchange); @@ -1136,7 +1205,8 @@ private TupleDescriptor generateTupleDesc(List slotList, List or private PlanFragment createParentFragment(PlanFragment childFragment, DataPartition parentPartition, PlanTranslatorContext context) { - ExchangeNode exchangeNode = new ExchangeNode(context.nextPlanNodeId(), childFragment.getPlanRoot(), false); + ExchangeNode exchangeNode = new ExchangeNode(context.nextPlanNodeId(), + childFragment.getPlanRoot(), false); exchangeNode.setNumInstances(childFragment.getPlanRoot().getNumInstances()); PlanFragment parentFragment = new PlanFragment(context.nextFragmentId(), exchangeNode, parentPartition); childFragment.setDestination(exchangeNode); @@ -1168,10 +1238,11 @@ private PlanFragment exchangeToMergeFragment(PlanFragment inputFragment, PlanTra Preconditions.checkState(inputFragment.isPartitioned()); // exchange node clones the behavior of its input, aside from the conjuncts - ExchangeNode mergePlan = - new 
ExchangeNode(context.nextPlanNodeId(), inputFragment.getPlanRoot(), false); + ExchangeNode mergePlan = new ExchangeNode(context.nextPlanNodeId(), + inputFragment.getPlanRoot(), false); + DataPartition dataPartition = DataPartition.UNPARTITIONED; mergePlan.setNumInstances(inputFragment.getPlanRoot().getNumInstances()); - PlanFragment fragment = new PlanFragment(context.nextFragmentId(), mergePlan, DataPartition.UNPARTITIONED); + PlanFragment fragment = new PlanFragment(context.nextFragmentId(), mergePlan, dataPartition); inputFragment.setDestination(mergePlan); context.addPlanFragment(fragment); return fragment; @@ -1305,7 +1376,7 @@ private boolean projectOnAgg(PhysicalProject project) { while (child instanceof PhysicalFilter || child instanceof PhysicalDistribute) { child = (PhysicalPlan) child.child(0); } - return child instanceof PhysicalAggregate; + return child instanceof PhysicalHashAggregate; } private boolean hasExprCalc(PhysicalProject project) { @@ -1334,4 +1405,32 @@ private List removeAlias(PhysicalProject project) { } return slotReferences; } + + private Optional toDataPartition(PhysicalDistribute distribute, + Optional> partitionExpressions, PlanTranslatorContext context) { + if (distribute.getDistributionSpec() == DistributionSpecGather.INSTANCE) { + return Optional.of(DataPartition.UNPARTITIONED); + } else if (distribute.getDistributionSpec() == DistributionSpecReplicated.INSTANCE) { + // the data partition should be left child of join + return Optional.empty(); + } else if (distribute.getDistributionSpec() instanceof DistributionSpecHash + || distribute.getDistributionSpec() == DistributionSpecAny.INSTANCE) { + if (!partitionExpressions.isPresent()) { + throw new AnalysisException("Missing partition expressions"); + } + Preconditions.checkState( + partitionExpressions.get().stream().allMatch(expr -> expr instanceof SlotReference), + "All partition expression should be slot: " + partitionExpressions.get()); + if (!partitionExpressions.isPresent() || partitionExpressions.get().isEmpty()) { + return Optional.of(DataPartition.UNPARTITIONED); + } + List partitionExprs = partitionExpressions.get() + .stream() + .map(p -> ExpressionTranslator.translate(p, context)) + .collect(ImmutableList.toImmutableList()); + return Optional.of(new DataPartition(TPartitionType.HASH_PARTITIONED, partitionExprs)); + } else { + return Optional.empty(); + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java index 92cc8571df7d89..13fadc09a17356 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.planner.PlanFragment; import org.apache.doris.planner.PlanFragmentId; import org.apache.doris.planner.PlanNode; @@ -40,6 +41,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import java.util.IdentityHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -71,6 +73,9 @@ public class PlanTranslatorContext { private final IdGenerator 
nodeIdGenerator = PlanNodeId.createGenerator(); + private final IdentityHashMap firstAggInFragment + = new IdentityHashMap<>(); + public PlanTranslatorContext(CascadesContext ctx) { this.translator = new RuntimeFilterTranslator(ctx.getRuntimeFilterContext()); } @@ -133,6 +138,14 @@ public List getScanNodes() { return scanNodes; } + public PhysicalHashAggregate getFirstAggregateInFragment(PlanFragment planFragment) { + return firstAggInFragment.get(planFragment); + } + + public void setFirstAggregateInFragment(PlanFragment planFragment, PhysicalHashAggregate aggregate) { + firstAggInFragment.put(planFragment, aggregate); + } + /** * Create SlotDesc and add it to the mappings from expression to the stales epxr */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java index a0e45e17966a2b..c16a654490f89e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java @@ -17,6 +17,8 @@ package org.apache.doris.nereids.jobs; +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.memo.CopyInResult; import org.apache.doris.nereids.memo.Group; @@ -33,14 +35,20 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleSet; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import java.util.List; +import java.util.Locale; import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; /** @@ -57,6 +65,7 @@ public abstract class Job implements TracerSupplier { protected JobType type; protected JobContext context; protected boolean once; + protected final Set disableRules; public Job(JobType type, JobContext context) { this(type, context, true); @@ -67,6 +76,8 @@ public Job(JobType type, JobContext context, boolean once) { this.type = type; this.context = context; this.once = once; + this.disableRules = getAndCacheSessionVariable(context, "disableNereidsRules", + ImmutableSet.of(), SessionVariable::getDisableNereidsRules); } public void pushJob(Job job) { @@ -90,6 +101,7 @@ public boolean isOnce() { */ public List getValidRules(GroupExpression groupExpression, List candidateRules) { return candidateRules.stream() + .filter(rule -> !disableRules.contains(rule.getRuleType().name().toUpperCase(Locale.ROOT))) .filter(rule -> Objects.nonNull(rule) && rule.getPattern().matchRoot(groupExpression.getPlan()) && groupExpression.notApplied(rule)).collect(Collectors.toList()); } @@ -133,4 +145,21 @@ protected void countJobExecutionTimesOfGroupExpressions(GroupExpression groupExp COUNTER_TRACER.log(CounterEvent.of(Memo.getStateId(), CounterType.JOB_EXECUTION, groupExpression.getOwnerGroup(), groupExpression, groupExpression.getPlan())); } + + private T getAndCacheSessionVariable(JobContext context, String cacheName, + T defaultValue, Function variableSupplier) { + CascadesContext cascadesContext = context.getCascadesContext(); + ConnectContext connectContext = cascadesContext.getConnectContext(); + if (connectContext == null) { + 
return defaultValue; + } + + StatementContext statementContext = cascadesContext.getStatementContext(); + if (statementContext == null) { + return defaultValue; + } + T cacheResult = statementContext.getOrRegisterCache(cacheName, + () -> variableSupplier.apply(connectContext.getSessionVariable())); + return cacheResult; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/BatchRulesJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/BatchRulesJob.java index 52014728ca7e9d..310adbfb477eed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/BatchRulesJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/BatchRulesJob.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.jobs.rewrite.VisitorRewriteJob; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleFactory; +import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import java.util.ArrayList; @@ -74,8 +75,8 @@ protected Job topDownBatch(List ruleFactories, boolean once) { cascadesContext.getCurrentJobContext(), once); } - protected Job visitorJob(DefaultPlanRewriter planRewriter) { - return new VisitorRewriteJob(cascadesContext, planRewriter, true); + protected Job visitorJob(RuleType ruleType, DefaultPlanRewriter planRewriter) { + return new VisitorRewriteJob(cascadesContext, planRewriter, ruleType); } protected Job optimize() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java index 6162c609b385c3..9936a44d74e4d0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.jobs.Job; import org.apache.doris.nereids.rules.RuleSet; +import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite; import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject; import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization; @@ -27,6 +28,7 @@ import org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithAggregate; import org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithoutAggregate; import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning; +import org.apache.doris.nereids.rules.rewrite.logical.EliminateAggregate; import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter; import org.apache.doris.nereids.rules.rewrite.logical.EliminateGroupByConstant; import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit; @@ -40,7 +42,6 @@ import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate; import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition; import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanTablet; -import org.apache.doris.nereids.rules.rewrite.logical.PushAggregateToOlapScan; import org.apache.doris.nereids.rules.rewrite.logical.PushFilterInsideJoin; import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin; @@ -78,11 +79,14 @@ public NereidsRewriteJobExecutor(CascadesContext cascadesContext) { .add(topDownBatch(ImmutableList.of(new 
ExtractSingleTableExpressionFromDisjunction()))) .add(topDownBatch(ImmutableList.of(new NormalizeAggregate()))) .add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false)) - .add(visitorJob(new InferPredicates())) + .add(visitorJob(RuleType.INFER_PREDICATES, new InferPredicates())) .add(topDownBatch(ImmutableList.of(new ReorderJoin()))) .add(topDownBatch(ImmutableList.of(new ColumnPruning()))) .add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false)) - .add(visitorJob(new InferPredicates())) + .add(visitorJob(RuleType.INFER_PREDICATES, new InferPredicates())) + .add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false)) + .add(visitorJob(RuleType.INFER_PREDICATES, new InferPredicates())) + .add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false)) .add(topDownBatch(ImmutableList.of(PushFilterInsideJoin.INSTANCE))) .add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin()))) .add(topDownBatch(ImmutableList.of(new LimitPushDown()))) @@ -97,7 +101,7 @@ public NereidsRewriteJobExecutor(CascadesContext cascadesContext) { .add(topDownBatch(ImmutableList.of(new EliminateGroupByConstant()))) .add(topDownBatch(ImmutableList.of(new EliminateOrderByConstant()))) .add(topDownBatch(ImmutableList.of(new EliminateUnnecessaryProject()))) - .add(topDownBatch(ImmutableList.of(new PushAggregateToOlapScan()))) + .add(topDownBatch(ImmutableList.of(new EliminateAggregate()))) // this rule batch must keep at the end of rewrite to do some plan check .add(bottomUpBatch(ImmutableList.of(new CheckAfterRewrite()))) .build(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ApplyRuleJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ApplyRuleJob.java index 31117c050743ce..b4d7f098674076 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ApplyRuleJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ApplyRuleJob.java @@ -77,10 +77,11 @@ public void execute() throws AnalysisException { GroupExpression newGroupExpression = result.correspondingExpression; if (newPlan instanceof LogicalPlan) { pushJob(new OptimizeGroupExpressionJob(newGroupExpression, context)); - pushJob(new DeriveStatsJob(newGroupExpression, context)); } else { pushJob(new CostAndEnforcerJob(newGroupExpression, context)); } + // we should derive stats for the new logical/physical plan if the plan is missing stats + pushJob(new DeriveStatsJob(newGroupExpression, context)); APPLY_RULE_TRACER.log(TransformEvent.of(groupExpression, plan, newPlans, rule.getRuleType()), rule::isRewrite); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java index f11f693cf74364..60271c4fe87906 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java @@ -163,6 +163,11 @@ public void execute() { GroupExpression lowestCostExpr = lowestCostPlanOpt.get().second; lowestCostChildren.add(lowestCostExpr); PhysicalProperties outputProperties = lowestCostExpr.getOutputProperties(requestChildProperty); + + // use the child's outputProperties to reset the request properties, so no unnecessary enforcer is added. + // this is safe because `childGroup.getLowestCostPlan(current plan's requestChildProperty).
+ // getOutputProperties(current plan's requestChildProperty) == child plan's outputProperties`, + // the outputProperties must satisfy the original requestChildProperty requestChildrenProperties.set(curChildIndex, outputProperties); curTotalCost += lowestCostExpr.getLowestCostTable().get(requestChildProperty).first; @@ -178,7 +183,7 @@ public void execute() { // if break when running the loop above, the condition must be false. if (curChildIndex == groupExpression.arity()) { if (!calculateEnforce(requestChildrenProperties)) { - return; + return; // if an error occurred, return } if (curTotalCost < context.getCostUpperBound()) { context.setCostUpperBound(curTotalCost); @@ -217,14 +222,17 @@ private boolean calculateEnforce(List requestChildrenPropert return false; } StatsCalculator.estimate(groupExpression); + // the previous curTotalCost excludes the existing best cost of the current node curTotalCost -= curNodeCost; - curNodeCost = CostCalculator.calculateCost(groupExpression); + curNodeCost = CostCalculator.calculateCost(groupExpression); // recompute the current node's cost in the current context groupExpression.setCost(curNodeCost); + // (previous curTotalCost) - (previous curNodeCost) + (current curNodeCost) = (current curTotalCost). + // the current curTotalCost may be less than the previous curTotalCost; if so, we will update the lowest cost and plan + // on the group expression and the owner group curTotalCost += curNodeCost; // record map { outputProperty -> outputProperty }, { ANY -> outputProperty }, - recordPropertyAndCost(groupExpression, outputProperty, PhysicalProperties.ANY, - requestChildrenProperties); + recordPropertyAndCost(groupExpression, outputProperty, PhysicalProperties.ANY, requestChildrenProperties); recordPropertyAndCost(groupExpression, outputProperty, outputProperty, requestChildrenProperties); enforce(outputProperty, requestChildrenProperties); return true; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java index f274f5e6ebcc54..66638c5023a0b6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java @@ -28,6 +28,8 @@ import org.apache.doris.nereids.metrics.event.StatsStateEvent; import org.apache.doris.nereids.stats.StatsCalculator; +import java.util.List; + /** * Job to derive stats for {@link GroupExpression} in {@link org.apache.doris.nereids.memo.Memo}. */ @@ -45,31 +47,44 @@ public class DeriveStatsJob extends Job { * @param context context of current job */ public DeriveStatsJob(GroupExpression groupExpression, JobContext context) { - super(JobType.DERIVE_STATS, context); - this.groupExpression = groupExpression; - this.deriveChildren = false; + this(groupExpression, false, context); } - /** - * Copy constructor for DeriveStatsJob.
- * - * @param other DeriveStatsJob copied from - */ - public DeriveStatsJob(DeriveStatsJob other) { - super(JobType.DERIVE_STATS, other.context); - this.groupExpression = other.groupExpression; - this.deriveChildren = other.deriveChildren; + private DeriveStatsJob(GroupExpression groupExpression, boolean deriveChildren, JobContext context) { + super(JobType.DERIVE_STATS, context); + this.groupExpression = groupExpression; + this.deriveChildren = deriveChildren; } @Override public void execute() { countJobExecutionTimesOfGroupExpressions(groupExpression); - if (!deriveChildren) { - deriveChildren = true; - pushJob(new DeriveStatsJob(this)); - for (Group child : groupExpression.children()) { - if (!child.getLogicalExpressions().isEmpty()) { - pushJob(new DeriveStatsJob(child.getLogicalExpressions().get(0), context)); + if (groupExpression.isStatDerived()) { + return; + } + if (!deriveChildren && groupExpression.arity() > 0) { + pushJob(new DeriveStatsJob(groupExpression, true, context)); + + List children = groupExpression.children(); + // a rule may return new logical plans that wrap some new physical plans, + // so we should derive stats for them if they have no stats + for (int i = children.size() - 1; i >= 0; i--) { + Group childGroup = children.get(i); + + List logicalExpressions = childGroup.getLogicalExpressions(); + for (int j = logicalExpressions.size() - 1; j >= 0; j--) { + GroupExpression logicalChild = logicalExpressions.get(j); + if (!logicalChild.isStatDerived()) { + pushJob(new DeriveStatsJob(logicalChild, context)); + } + } + + List physicalExpressions = childGroup.getPhysicalExpressions(); + for (int j = physicalExpressions.size() - 1; j >= 0; j--) { + GroupExpression physicalChild = physicalExpressions.get(j); + if (!physicalChild.isStatDerived()) { + pushJob(new DeriveStatsJob(physicalChild, context)); + } } } } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/VisitorRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/VisitorRewriteJob.java index 9510a4147c1520..4836803be1bb2d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/VisitorRewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/VisitorRewriteJob.java @@ -26,15 +26,18 @@ import org.apache.doris.nereids.memo.Memo; import org.apache.doris.nereids.metrics.CounterType; import org.apache.doris.nereids.metrics.event.CounterEvent; +import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import java.util.Locale; import java.util.Objects; /** * Use visitor to rewrite the plan. */ public class VisitorRewriteJob extends Job { + private final RuleType ruleType; private final Group group; private final DefaultPlanRewriter planRewriter; @@ -42,14 +45,19 @@ public class VisitorRewriteJob extends Job { /** * Constructor.
*/ - public VisitorRewriteJob(CascadesContext cascadesContext, DefaultPlanRewriter rewriter, boolean once) { - super(JobType.VISITOR_REWRITE, cascadesContext.getCurrentJobContext(), once); + public VisitorRewriteJob(CascadesContext cascadesContext, + DefaultPlanRewriter rewriter, RuleType ruleType) { + super(JobType.VISITOR_REWRITE, cascadesContext.getCurrentJobContext(), true); + this.ruleType = Objects.requireNonNull(ruleType, "ruleType cannot be null"); this.group = Objects.requireNonNull(cascadesContext.getMemo().getRoot(), "group cannot be null"); this.planRewriter = Objects.requireNonNull(rewriter, "planRewriter cannot be null"); } @Override public void execute() { + if (disableRules.contains(ruleType.name().toUpperCase(Locale.ROOT))) { + return; + } GroupExpression logicalExpression = group.getLogicalExpression(); Plan root = context.getCascadesContext().getMemo().copyOut(logicalExpression, true); COUNTER_TRACER.log(CounterEvent.of(Memo.getStateId(), CounterType.JOB_EXECUTION, group, logicalExpression, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java index fb2adb9ac65b48..80244a1ca88961 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java @@ -197,7 +197,7 @@ public GroupExpression getBestPlan(PhysicalProperties properties) { */ public void setBestPlan(GroupExpression expression, double cost, PhysicalProperties properties) { if (lowestCostPlans.containsKey(properties)) { - if (lowestCostPlans.get(properties).first >= cost) { + if (lowestCostPlans.get(properties).first > cost) { lowestCostPlans.put(properties, Pair.of(cost, expression)); } } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java index ba6a1d73fb9655..6ca12867fc8844 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.metrics.consumer.LogConsumer; import org.apache.doris.nereids.metrics.event.GroupMergeEvent; import org.apache.doris.nereids.properties.LogicalProperties; +import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.analysis.CTEContext; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.GroupPlan; @@ -209,11 +210,11 @@ public Plan copyOut(GroupExpression logicalExpression, boolean includeGroupExpre * Utility function to create a new {@link CascadesContext} with this Memo. 
*/ public CascadesContext newCascadesContext(StatementContext statementContext) { - return new CascadesContext(this, statementContext); + return new CascadesContext(this, statementContext, PhysicalProperties.ANY); } public CascadesContext newCascadesContext(StatementContext statementContext, CTEContext cteContext) { - return new CascadesContext(this, statementContext, cteContext); + return new CascadesContext(this, statementContext, cteContext, PhysicalProperties.ANY); } /** @@ -677,7 +678,8 @@ public String toString() { builder.append(group).append("\n"); builder.append(" stats=").append(group.getStatistics()).append("\n"); StatsDeriveResult stats = group.getStatistics(); - if (stats != null && group.getLogicalExpressions().get(0).getPlan() instanceof LogicalOlapScan) { + if (stats != null && !group.getLogicalExpressions().isEmpty() + && group.getLogicalExpressions().get(0).getPlan() instanceof LogicalOlapScan) { for (Entry e : stats.getSlotIdToColumnStats().entrySet()) { builder.append(" ").append(e.getKey()).append(":").append(e.getValue()).append("\n"); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index 39546a55ebd906..7194fccabf5490 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -37,7 +37,6 @@ import org.apache.doris.nereids.DorisParser.DereferenceContext; import org.apache.doris.nereids.DorisParser.ExistContext; import org.apache.doris.nereids.DorisParser.ExplainContext; -import org.apache.doris.nereids.DorisParser.ExpressionContext; import org.apache.doris.nereids.DorisParser.FromClauseContext; import org.apache.doris.nereids.DorisParser.GroupingElementContext; import org.apache.doris.nereids.DorisParser.GroupingSetContext; @@ -134,6 +133,7 @@ import org.apache.doris.nereids.trees.expressions.TVFProperties; import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.scalar.DaysAdd; import org.apache.doris.nereids.trees.expressions.functions.scalar.DaysDiff; import org.apache.doris.nereids.trees.expressions.functions.scalar.DaysSub; @@ -158,7 +158,7 @@ import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; -import org.apache.doris.nereids.trees.expressions.literal.IntervalLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Interval; import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; @@ -207,7 +207,6 @@ import java.math.BigDecimal; import java.math.BigInteger; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Locale; @@ -543,15 +542,15 @@ public Expression visitArithmeticBinary(ArithmeticBinaryContext ctx) { Expression right = getExpression(ctx.right); int type = ctx.operator.getType(); - if (left instanceof IntervalLiteral) { + if (left instanceof Interval) { if (type != DorisParser.PLUS) { throw new ParseException("Only 
supported: " + Operator.ADD, ctx); } - IntervalLiteral interval = (IntervalLiteral) left; + Interval interval = (Interval) left; return new TimestampArithmetic(Operator.ADD, right, interval.value(), interval.timeUnit(), true); } - if (right instanceof IntervalLiteral) { + if (right instanceof Interval) { Operator op; if (type == DorisParser.PLUS) { op = Operator.ADD; @@ -560,7 +559,7 @@ public Expression visitArithmeticBinary(ArithmeticBinaryContext ctx) { } else { throw new ParseException("Only supported: " + Operator.ADD + " and " + Operator.SUBTRACT, ctx); } - IntervalLiteral interval = (IntervalLiteral) right; + Interval interval = (Interval) right; return new TimestampArithmetic(op, left, interval.value(), interval.timeUnit(), false); } @@ -741,32 +740,39 @@ public Expression visitCast(DorisParser.CastContext ctx) { public UnboundFunction visitExtract(DorisParser.ExtractContext ctx) { return ParserUtils.withOrigin(ctx, () -> { String functionName = ctx.field.getText(); - return new UnboundFunction(functionName, false, false, + return new UnboundFunction(functionName, false, Collections.singletonList(getExpression(ctx.source))); }); } @Override - public UnboundFunction visitFunctionCall(DorisParser.FunctionCallContext ctx) { + public Expression visitFunctionCall(DorisParser.FunctionCallContext ctx) { return ParserUtils.withOrigin(ctx, () -> { - // TODO:In the future, instead of specifying the function name, - // the function information is obtained by parsing the catalog. This method is more scalable. String functionName = ctx.identifier().getText(); boolean isDistinct = ctx.DISTINCT() != null; - List expressionContexts = ctx.expression(); - List params = visit(expressionContexts, Expression.class); - for (Expression expression : params) { - if (expression instanceof UnboundStar && functionName.equalsIgnoreCase("count") && !isDistinct) { - return new UnboundFunction(functionName, false, true, new ArrayList<>()); + List params = visit(ctx.expression(), Expression.class); + List unboundStars = ExpressionUtils.collectAll(params, UnboundStar.class::isInstance); + if (unboundStars.size() > 0) { + if (functionName.equalsIgnoreCase("count")) { + if (unboundStars.size() > 1) { + throw new ParseException( + "'*' can only be used once in conjunction with COUNT: " + functionName, ctx); + } + if (!unboundStars.get(0).getQualifier().isEmpty()) { + throw new ParseException("'*' can not has qualifier: " + unboundStars.size(), ctx); + } + return new Count(); } + throw new ParseException("'*' can only be used in conjunction with COUNT: " + functionName, ctx); + } else { + return new UnboundFunction(functionName, isDistinct, params); } - return new UnboundFunction(functionName, isDistinct, false, params); }); } @Override public Expression visitInterval(IntervalContext ctx) { - return new IntervalLiteral(getExpression(ctx.value), visitUnitIdentifier(ctx.unit)); + return new Interval(getExpression(ctx.value), visitUnitIdentifier(ctx.unit)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/MatchedAction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/MatchedAction.java index 4cb393b791f935..a577b24de57cc3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/MatchedAction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/MatchedAction.java @@ -20,7 +20,8 @@ import org.apache.doris.nereids.trees.plans.Plan; /** - * Define an callback action when match a pattern, usually implement as a rule body. 
+ * Define a callback action that is invoked when a pattern is matched and then transforms it to a plan, + * usually implemented as a rule body. * e.g. exchange join children for JoinCommutative Rule */ public interface MatchedAction { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/MatchedMultiAction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/MatchedMultiAction.java new file mode 100644 index 00000000000000..abc471d181e317 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/MatchedMultiAction.java @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.pattern; + +import org.apache.doris.nereids.trees.plans.Plan; + +import java.util.List; + +/** + * Define a callback action that is invoked when a pattern is matched and then transforms it to a batch of plans, + * usually implemented as a rule body. + * e.g. DisassembleAggregate + */ +public interface MatchedMultiAction { + + List apply(MatchingContext ctx); +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/PatternDescriptor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/PatternDescriptor.java index dd742c34e0d7f0..70249cdd578b7b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/PatternDescriptor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/PatternDescriptor.java @@ -54,11 +54,23 @@ public PatternDescriptor whenNot(Predicate predicate) { public PatternMatcher then( Function matchedAction) { - return new PatternMatcher<>(pattern, defaultPromise, ctx -> matchedAction.apply(ctx.root)); + MatchedAction adaptMatchedAction = ctx -> matchedAction.apply(ctx.root); + return new PatternMatcher<>(pattern, defaultPromise, adaptMatchedAction); } public PatternMatcher thenApply( MatchedAction matchedAction) { return new PatternMatcher<>(pattern, defaultPromise, matchedAction); } + + public PatternMatcher thenMulti( + Function> matchedAction) { + MatchedMultiAction adaptMatchedAction = ctx -> matchedAction.apply(ctx.root); + return new PatternMatcher<>(pattern, defaultPromise, adaptMatchedAction); + } + + public PatternMatcher thenApplyMulti( + MatchedMultiAction matchedAction) { + return new PatternMatcher<>(pattern, defaultPromise, matchedAction); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/PatternMatcher.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/PatternMatcher.java index 44b8103bdc55ae..2a6e1505bd1abc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/PatternMatcher.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/PatternMatcher.java @@ -37,6 +37,7 @@ public class PatternMatcher { public final Pattern pattern; public final RulePromise defaultRulePromise; public
final MatchedAction matchedAction; + public final MatchedMultiAction matchedMultiAction; /** * PatternMatcher wrap a pattern, defaultRulePromise and matchedAction. @@ -51,6 +52,16 @@ public PatternMatcher(Pattern pattern, RulePromise defaultRulePromis this.defaultRulePromise = Objects.requireNonNull( defaultRulePromise, "defaultRulePromise can not be null"); this.matchedAction = Objects.requireNonNull(matchedAction, "matchedAction can not be null"); + this.matchedMultiAction = null; + } + + public PatternMatcher(Pattern pattern, RulePromise defaultRulePromise, + MatchedMultiAction matchedAction) { + this.pattern = Objects.requireNonNull(pattern, "pattern can not be null"); + this.defaultRulePromise = Objects.requireNonNull( + defaultRulePromise, "defaultRulePromise can not be null"); + this.matchedMultiAction = Objects.requireNonNull(matchedAction, "matchedMultiAction can not be null"); + this.matchedAction = null; } public Rule toRule(RuleType ruleType) { @@ -68,10 +79,19 @@ public Rule toRule(RuleType ruleType, RulePromise rulePromise) { return new Rule(ruleType, pattern, rulePromise) { @Override public List transform(Plan originPlan, CascadesContext context) { - MatchingContext matchingContext = - new MatchingContext<>((INPUT_TYPE) originPlan, pattern, context); - OUTPUT_TYPE replacePlan = matchedAction.apply(matchingContext); - return ImmutableList.of(replacePlan == null ? originPlan : replacePlan); + if (matchedMultiAction != null) { + MatchingContext matchingContext = + new MatchingContext<>((INPUT_TYPE) originPlan, pattern, context); + List replacePlans = matchedMultiAction.apply(matchingContext); + return replacePlans == null || replacePlans.isEmpty() + ? ImmutableList.of(originPlan) + : ImmutableList.copyOf(replacePlans); + } else { + MatchingContext matchingContext = + new MatchingContext<>((INPUT_TYPE) originPlan, pattern, context); + OUTPUT_TYPE replacePlan = matchedAction.apply(matchingContext); + return ImmutableList.of(replacePlan == null ? 
originPlan : replacePlan); + } } }; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java index fe1bcbeac5ae9b..c580866a164284 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java @@ -33,6 +33,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; +import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate; import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter; import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.planner.RuntimeFilterId; @@ -142,6 +143,13 @@ public PhysicalOlapScan visitPhysicalOlapScan(PhysicalOlapScan scan, CascadesCon return scan; } + @Override + public PhysicalStorageLayerAggregate visitPhysicalStorageLayerAggregate( + PhysicalStorageLayerAggregate storageLayerAggregate, CascadesContext context) { + storageLayerAggregate.getRelation().accept(this, context); + return storageLayerAggregate; + } + private static Pair checkAndMaybeSwapChild(EqualTo expr, PhysicalHashJoin join) { if (expr.child(0).equals(expr.child(1)) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java index 058c59db7f5576..afca448aff82fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java @@ -24,10 +24,10 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.AbstractPlan; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit; import org.apache.doris.nereids.trees.plans.physical.PhysicalLocalQuickSort; @@ -59,7 +59,8 @@ public class RuntimeFilterPruner extends PlanPostProcessor { // Physical plans // ******************************* @Override - public PhysicalAggregate visitPhysicalAggregate(PhysicalAggregate agg, CascadesContext context) { + public PhysicalHashAggregate visitPhysicalHashAggregate( + PhysicalHashAggregate agg, CascadesContext context) { agg.child().accept(this, context); context.getRuntimeFilterContext().addEffectiveSrcNode(agg); return agg; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java index 86cf47386e413b..9683da9b0521c4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java @@ -24,11 +24,12 @@ import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit; import org.apache.doris.nereids.trees.plans.physical.PhysicalLocalQuickSort; @@ -36,6 +37,8 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort; +import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate; +import org.apache.doris.nereids.trees.plans.physical.PhysicalTVFRelation; import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.JoinUtils; @@ -76,7 +79,8 @@ public PhysicalProperties visit(Plan plan, PlanContext context) { } @Override - public PhysicalProperties visitPhysicalAggregate(PhysicalAggregate agg, PlanContext context) { + public PhysicalProperties visitPhysicalHashAggregate( + PhysicalHashAggregate agg, PlanContext context) { Preconditions.checkState(childrenOutputProperties.size() == 1); PhysicalProperties childOutputProperty = childrenOutputProperties.get(0); switch (agg.getAggPhase()) { @@ -218,13 +222,28 @@ public PhysicalProperties visitPhysicalNestedLoopJoin( @Override public PhysicalProperties visitPhysicalOlapScan(PhysicalOlapScan olapScan, PlanContext context) { - if (olapScan.getDistributionSpec() instanceof DistributionSpecHash) { + // TODO: find a better way to handle both tablet num == 1 and colocate table together in future + if (!olapScan.getTable().isColocateTable() && olapScan.getScanTabletNum() == 1) { + return PhysicalProperties.GATHER; + } else if (olapScan.getDistributionSpec() instanceof DistributionSpecHash) { return PhysicalProperties.createHash((DistributionSpecHash) olapScan.getDistributionSpec()); } else { return PhysicalProperties.ANY; } } + @Override + public PhysicalProperties visitPhysicalStorageLayerAggregate( + PhysicalStorageLayerAggregate storageLayerAggregate, PlanContext context) { + return storageLayerAggregate.getRelation().accept(this, context); + } + + @Override + public PhysicalProperties visitPhysicalTVFRelation(PhysicalTVFRelation tvfRelation, PlanContext context) { + TableValuedFunction function = tvfRelation.getFunction(); + return function.getPhysicalProperties(); + } + @Override public PhysicalProperties visitPhysicalAssertNumRows(PhysicalAssertNumRows assertNumRows, PlanContext context) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java index 
31de1bb7a5da58..3819b4429b02dc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java @@ -22,10 +22,8 @@ import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType; -import org.apache.doris.nereids.trees.plans.AggPhase; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; -import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.JoinUtils; @@ -72,12 +70,7 @@ public Double visit(Plan plan, Void context) { } @Override - public Double visitPhysicalAggregate(PhysicalAggregate agg, Void context) { - if (agg.isFinalPhase() - && agg.getAggPhase() == AggPhase.LOCAL - && children.get(0).getPlan() instanceof PhysicalDistribute) { - return -1.0; - } + public Double visitPhysicalHashAggregate(PhysicalHashAggregate agg, Void context) { return 0.0; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java index 35a41f3aa65207..d4c9c24c06293c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java @@ -17,7 +17,15 @@ package org.apache.doris.nereids.properties; +import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; + +import java.util.Collection; +import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; /** * Physical properties used in cascades. 
@@ -56,6 +64,19 @@ public PhysicalProperties(DistributionSpec distributionSpec, OrderSpec orderSpec this.orderSpec = orderSpec; } + public static PhysicalProperties createHash( + Collection orderedShuffledColumns, ShuffleType shuffleType) { + List partitionedSlots = orderedShuffledColumns.stream() + .map(SlotReference.class::cast) + .map(SlotReference::getExprId) + .collect(Collectors.toList()); + return createHash(partitionedSlots, shuffleType); + } + + public static PhysicalProperties createHash(List orderedShuffledColumns, ShuffleType shuffleType) { + return new PhysicalProperties(new DistributionSpecHash(orderedShuffledColumns, shuffleType)); + } + public static PhysicalProperties createHash(DistributionSpecHash distributionSpecHash) { return new PhysicalProperties(distributionSpecHash); } @@ -99,6 +120,15 @@ public int hashCode() { @Override public String toString() { + if (this.equals(ANY)) { + return "ANY"; + } + if (this.equals(REPLICATED)) { + return "REPLICATED"; + } + if (this.equals(GATHER)) { + return "GATHER"; + } return distributionSpec.toString() + " " + orderSpec.toString(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java index 8b0fc1c95d24b9..1b8f29fa9108a9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java @@ -27,15 +27,14 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.plans.AggPhase; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalLocalQuickSort; import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; import com.google.common.base.Preconditions; @@ -76,6 +75,14 @@ public List> getRequestChildrenPropertyList(GroupExpres @Override public Void visit(Plan plan, PlanContext context) { + if (plan instanceof RequirePropertiesSupplier) { + RequireProperties requireProperties = ((RequirePropertiesSupplier) plan).getRequireProperties(); + List requestPhysicalProperties = + requireProperties.computeRequirePhysicalProperties(plan, requestPropertyFromParent); + addRequestPropertyToChildren(requestPhysicalProperties); + return null; + } + List requiredPropertyList = Lists.newArrayListWithCapacity(context.getGroupExpression().arity()); for (int i = context.getGroupExpression().arity(); i > 0; --i) { @@ -85,49 +92,6 @@ public Void visit(Plan plan, PlanContext context) { return null; } - @Override - public Void visitPhysicalAggregate(PhysicalAggregate agg, PlanContext context) { - // 1. 
first phase agg just return any - if (agg.getAggPhase().isLocal() && !agg.isFinalPhase()) { - addRequestPropertyToChildren(PhysicalProperties.ANY); - return null; - } - if (agg.getAggPhase() == AggPhase.GLOBAL && !agg.isFinalPhase()) { - addRequestPropertyToChildren(requestPropertyFromParent); - return null; - } - // 2. second phase agg, need to return shuffle with partition key - List partitionExpressions = agg.getPartitionExpressions(); - if (partitionExpressions.isEmpty() && agg.getAggPhase() != AggPhase.DISTINCT_LOCAL) { - addRequestPropertyToChildren(PhysicalProperties.GATHER); - return null; - } - if (agg.getAggPhase() == AggPhase.DISTINCT_LOCAL) { - // use slots in distinct agg as shuffle slots - List shuffleSlots = extractFromDistinctFunction(agg.getOutputExpressions()); - Preconditions.checkState(!shuffleSlots.isEmpty()); - addRequestPropertyToChildren( - PhysicalProperties.createHash(new DistributionSpecHash(shuffleSlots, ShuffleType.AGGREGATE))); - return null; - } - // TODO: when parent is a join node, - // use requestPropertyFromParent to keep column order as join to avoid shuffle again. - if (partitionExpressions.stream().allMatch(SlotReference.class::isInstance)) { - List partitionedSlots = partitionExpressions.stream() - .map(SlotReference.class::cast) - .map(SlotReference::getExprId) - .collect(Collectors.toList()); - addRequestPropertyToChildren( - PhysicalProperties.createHash(new DistributionSpecHash(partitionedSlots, ShuffleType.AGGREGATE))); - return null; - } - - throw new RuntimeException("Need to add a rule to split aggregate to aggregate(project)," - + " see more in AggregateDisassemble"); - - // TODO: add other phase logical when we support distinct aggregate - } - @Override public Void visitPhysicalQuickSort(PhysicalQuickSort sort, PlanContext context) { addRequestPropertyToChildren(PhysicalProperties.ANY); @@ -183,22 +147,32 @@ private void addRequestPropertyToChildren(PhysicalProperties... 
physicalProperti requestPropertyToChildren.add(Lists.newArrayList(physicalProperties)); } - private List extractFromDistinctFunction(List outputExpression) { + private void addRequestPropertyToChildren(List physicalProperties) { + requestPropertyToChildren.add(physicalProperties); + } + + private List extractExprIdFromDistinctFunction(List outputExpression) { + Set distinctAggregateFunctions = ExpressionUtils.collect(outputExpression, expr -> + expr instanceof AggregateFunction && ((AggregateFunction) expr).isDistinct() + ); List exprIds = Lists.newArrayList(); - for (NamedExpression originOutputExpr : outputExpression) { - Set aggregateFunctions - = originOutputExpr.collect(AggregateFunction.class::isInstance); - for (AggregateFunction aggregateFunction : aggregateFunctions) { - if (aggregateFunction.isDistinct()) { - for (Expression expr : aggregateFunction.children()) { - Preconditions.checkState(expr instanceof SlotReference, "normalize aggregate failed to" - + " normalize aggregate function " + aggregateFunction.toSql()); - exprIds.add(((SlotReference) expr).getExprId()); - } - } + for (AggregateFunction aggregateFunction : distinctAggregateFunctions) { + for (Expression expr : aggregateFunction.children()) { + Preconditions.checkState(expr instanceof SlotReference, "normalize aggregate failed to" + + " normalize aggregate function " + aggregateFunction.toSql()); + exprIds.add(((SlotReference) expr).getExprId()); } } return exprIds; } + + private void addRequestHashDistribution(List hashColumns, ShuffleType shuffleType) { + List partitionedSlots = hashColumns.stream() + .map(SlotReference.class::cast) + .map(SlotReference::getExprId) + .collect(Collectors.toList()); + addRequestPropertyToChildren( + PhysicalProperties.createHash(new DistributionSpecHash(partitionedSlots, shuffleType))); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequireProperties.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequireProperties.java new file mode 100644 index 00000000000000..534c4814ef63ef --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequireProperties.java @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.properties; + +import org.apache.doris.nereids.trees.plans.Plan; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +/** RequireProperties */ +public class RequireProperties { + private final boolean followParentProperties; + private final List properties; + + private RequireProperties(PhysicalProperties... 
properties) { + this(false, properties); + } + + private RequireProperties(boolean followParentProperties, PhysicalProperties... requireProperties) { + Preconditions.checkArgument((followParentProperties == false && requireProperties.length > 0) + || (followParentProperties == true && requireProperties.length == 0)); + this.properties = ImmutableList.copyOf(requireProperties); + this.followParentProperties = followParentProperties; + } + + public static RequireProperties of(PhysicalProperties... properties) { + return new RequireProperties(properties); + } + + public static RequireProperties followParent() { + return new RequireProperties(true); + } + + public RequirePropertiesTree withChildren(RequireProperties... requireProperties) { + List children = Arrays.stream(requireProperties) + .map(child -> new RequirePropertiesTree(child, ImmutableList.of())) + .collect(ImmutableList.toImmutableList()); + return new RequirePropertiesTree(this, children); + } + + public RequirePropertiesTree withChildren(RequirePropertiesTree... children) { + return new RequirePropertiesTree(this, ImmutableList.copyOf(children)); + } + + public boolean isFollowParentProperties() { + return followParentProperties; + } + + public List getProperties() { + return properties; + } + + /** computeRequirePhysicalProperties */ + public List computeRequirePhysicalProperties( + Plan currentPlan, PhysicalProperties parentRequire) { + int childNum = currentPlan.arity(); + if (followParentProperties) { + // CostAndEnforcerJob will modify this list: requestChildrenProperties.set(curChildIndex, outputProperties) + List requireProperties = Lists.newArrayListWithCapacity(childNum); + for (int i = 0; i < childNum; i++) { + requireProperties.add(parentRequire); + } + return requireProperties; + } else { + Preconditions.checkState(properties.size() == childNum, + "Expect require physical properties num is " + childNum + ", but real is " + + properties.size()); + // CostAndEnforcerJob will modify this list: requestChildrenProperties.set(curChildIndex, outputProperties) + return Lists.newArrayList(properties); + } + } + + @Override + public String toString() { + if (followParentProperties) { + return "followParentProperties"; + } else { + return properties.toString(); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RequireProperties that = (RequireProperties) o; + return followParentProperties == that.followParentProperties + && Objects.equals(properties, that.properties); + } + + @Override + public int hashCode() { + return Objects.hash(followParentProperties, properties); + } + + /** RequirePropertiesTree */ + public static class RequirePropertiesTree { + public final RequireProperties requireProperties; + public final List children; + + private RequirePropertiesTree(RequireProperties requireProperties, List children) { + this.requireProperties = requireProperties; + this.children = children; + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequirePropertiesSupplier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequirePropertiesSupplier.java new file mode 100644 index 00000000000000..7f3f6084ff85dc --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequirePropertiesSupplier.java @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.properties; + +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.properties.RequireProperties.RequirePropertiesTree; +import org.apache.doris.nereids.trees.plans.Plan; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** RequirePropertiesSupplier */ +public interface RequirePropertiesSupplier

{ + List children(); + + RequireProperties getRequireProperties(); + + Plan withRequireAndChildren(RequireProperties requireProperties, List children); + + default P withRequire(RequireProperties requireProperties) { + return (P) withRequireAndChildren(requireProperties, children()); + } + + /** withRequireTree */ + default P withRequireTree(RequirePropertiesTree tree) { + List childrenRequires = tree.children; + List children = children(); + if (!childrenRequires.isEmpty() && children.size() != childrenRequires.size()) { + throw new AnalysisException("The number of RequireProperties mismatch the plan tree"); + } + + List newChildren = children; + if (!childrenRequires.isEmpty()) { + ImmutableList.Builder newChildrenBuilder = + ImmutableList.builderWithExpectedSize(childrenRequires.size()); + for (int i = 0; i < children.size(); i++) { + Plan child = children.get(i); + Preconditions.checkState(child instanceof RequirePropertiesSupplier, + "child should be RequirePropertiesTree: " + child); + Plan newChild = ((RequirePropertiesSupplier) child).withRequireTree(childrenRequires.get(i)); + newChildrenBuilder.add(newChild); + } + newChildren = newChildrenBuilder.build(); + } + + return (P) withRequireAndChildren(tree.requireProperties, newChildren); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index d3212f9214c2e0..3668bf6db52592 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -26,7 +26,6 @@ import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTransposeProject; import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose; import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTransposeProject; -import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg; import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows; import org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation; import org.apache.doris.nereids.rules.implementation.LogicalFilterToPhysicalFilter; @@ -40,8 +39,7 @@ import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort; import org.apache.doris.nereids.rules.implementation.LogicalTVFRelationToPhysicalTVFRelation; import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN; -import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; -import org.apache.doris.nereids.rules.rewrite.DistinctAggregateDisassemble; +import org.apache.doris.nereids.rules.rewrite.AggregateStrategies; import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuterJoin; import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters; import org.apache.doris.nereids.rules.rewrite.logical.MergeLimits; @@ -73,8 +71,7 @@ public class RuleSet { .add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP) .add(SemiJoinSemiJoinTranspose.INSTANCE) .add(SemiJoinSemiJoinTransposeProject.INSTANCE) - .add(new AggregateDisassemble()) - .add(new DistinctAggregateDisassemble()) + // .add(new DisassembleDistinctAggregate()) .add(new PushdownFilterThroughProject()) .add(new MergeProjects()) .build(); @@ -93,7 +90,6 @@ public class RuleSet { new MergeLimits()); public static final List IMPLEMENTATION_RULES = planRuleFactories() - .add(new LogicalAggToPhysicalHashAgg()) .add(new 
LogicalRepeatToPhysicalRepeat()) .add(new LogicalFilterToPhysicalFilter()) .add(new LogicalJoinToHashJoin()) @@ -107,6 +103,7 @@ public class RuleSet { .add(new LogicalOneRowRelationToPhysicalOneRowRelation()) .add(new LogicalEmptyRelationToPhysicalEmptyRelation()) .add(new LogicalTVFRelationToPhysicalTVFRelation()) + .add(new AggregateStrategies()) .build(); public static final List LEFT_DEEP_TREE_JOIN_REORDER = planRuleFactories() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 1f97de337d9f76..598a3ba64c9c39 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -74,10 +74,12 @@ public enum RuleType { CHECK_ROW_POLICY(RuleTypeClass.REWRITE), ELIMINATE_EXCEPT(RuleTypeClass.REWRITE), + ELIMINATE_AGGREGATE(RuleTypeClass.REWRITE), RESOLVE_ORDINAL_IN_ORDER_BY(RuleTypeClass.REWRITE), RESOLVE_ORDINAL_IN_GROUP_BY(RuleTypeClass.REWRITE), // check analysis rule + CHECK_AGGREGATE_ANALYSIS(RuleTypeClass.CHECK), CHECK_ANALYSIS(RuleTypeClass.CHECK), // rewrite rules @@ -85,11 +87,12 @@ public enum RuleType { NORMALIZE_REPEAT(RuleTypeClass.REWRITE), AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE), DISTINCT_AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE), - COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE), ELIMINATE_UNNECESSARY_PROJECT(RuleTypeClass.REWRITE), + MERGE_PROJECTS(RuleTypeClass.REWRITE), LOGICAL_SUB_QUERY_ALIAS_TO_LOGICAL_PROJECT(RuleTypeClass.REWRITE), ELIMINATE_GROUP_BY_CONSTANT(RuleTypeClass.REWRITE), ELIMINATE_ORDER_BY_CONSTANT(RuleTypeClass.REWRITE), + INFER_PREDICATES(RuleTypeClass.REWRITE), // subquery analyze ANALYZE_FILTER_SUBQUERY(RuleTypeClass.REWRITE), @@ -115,6 +118,7 @@ public enum RuleType { // column prune rules, COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE), COLUMN_PRUNE_FILTER_CHILD(RuleTypeClass.REWRITE), + PRUNE_ONE_ROW_RELATION_COLUMN(RuleTypeClass.REWRITE), COLUMN_PRUNE_SORT_CHILD(RuleTypeClass.REWRITE), COLUMN_PRUNE_JOIN_CHILD(RuleTypeClass.REWRITE), COLUMN_PRUNE_REPEAT_CHILD(RuleTypeClass.REWRITE), @@ -127,7 +131,6 @@ public enum RuleType { REORDER_JOIN(RuleTypeClass.REWRITE), // Merge Consecutive plan MERGE_FILTERS(RuleTypeClass.REWRITE), - MERGE_PROJECTS(RuleTypeClass.REWRITE), MERGE_LIMITS(RuleTypeClass.REWRITE), // Eliminate plan ELIMINATE_LIMIT(RuleTypeClass.REWRITE), @@ -189,6 +192,17 @@ public enum RuleType { LOGICAL_LIMIT_TO_PHYSICAL_LIMIT_RULE(RuleTypeClass.IMPLEMENTATION), LOGICAL_OLAP_SCAN_TO_PHYSICAL_OLAP_SCAN_RULE(RuleTypeClass.IMPLEMENTATION), LOGICAL_ASSERT_NUM_ROWS_TO_PHYSICAL_ASSERT_NUM_ROWS(RuleTypeClass.IMPLEMENTATION), + STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT(RuleTypeClass.IMPLEMENTATION), + STORAGE_LAYER_AGGREGATE_WITH_PROJECT(RuleTypeClass.IMPLEMENTATION), + ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT(RuleTypeClass.IMPLEMENTATION), + TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT(RuleTypeClass.IMPLEMENTATION), + TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI(RuleTypeClass.IMPLEMENTATION), + THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI(RuleTypeClass.IMPLEMENTATION), + TWO_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION), + ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI(RuleTypeClass.IMPLEMENTATION), + TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI(RuleTypeClass.IMPLEMENTATION), + TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT(RuleTypeClass.IMPLEMENTATION), + THREE_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION), 
IMPLEMENTATION_SENTINEL(RuleTypeClass.IMPLEMENTATION), LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE_PROJECT(RuleTypeClass.EXPLORATION), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java index 1e6527dcce0b68..5935c760323068 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java @@ -33,11 +33,8 @@ import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; -import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; -import org.apache.doris.nereids.trees.plans.AggPhase; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.RelationId; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; @@ -176,27 +173,21 @@ public LogicalTVFRelation bindTableValuedFunction(UnboundTVFRelation unboundTVFR } @Override - public BoundFunction visitUnboundFunction(UnboundFunction unboundFunction, Env env) { + public Expression visitUnboundFunction(UnboundFunction unboundFunction, Env env) { unboundFunction = (UnboundFunction) super.visitUnboundFunction(unboundFunction, env); - // FunctionRegistry can't support boolean arg now, tricky here. - if (unboundFunction.getName().equalsIgnoreCase("count")) { - List arguments = unboundFunction.getArguments(); - if ((arguments.size() == 0 && unboundFunction.isStar()) || arguments.stream() - .allMatch(Expression::isConstant)) { - return new Count(); - } - if (arguments.size() == 1) { - AggregateParam aggregateParam = new AggregateParam( - unboundFunction.isDistinct(), true, AggPhase.LOCAL, false); - return new Count(aggregateParam, unboundFunction.getArguments().get(0)); - } - } FunctionRegistry functionRegistry = env.getFunctionRegistry(); String functionName = unboundFunction.getName(); - FunctionBuilder builder = functionRegistry.findFunctionBuilder( - functionName, unboundFunction.getArguments()); - return builder.build(functionName, unboundFunction.getArguments()); + List arguments = unboundFunction.isDistinct() + ? 
ImmutableList.builder() + .add(unboundFunction.isDistinct()) + .addAll(unboundFunction.getArguments()) + .build() + : (List) unboundFunction.getArguments(); + + FunctionBuilder builder = functionRegistry.findFunctionBuilder(functionName, arguments); + BoundFunction boundFunction = builder.build(functionName, arguments); + return boundFunction; } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java index 853a6cb76da46e..df196ea0634a95 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java @@ -43,6 +43,7 @@ import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.LeafPlan; @@ -160,7 +161,7 @@ public List buildRules() { Map aliasNameToExpr = output.stream() .filter(ne -> ne instanceof Alias) .map(Alias.class::cast) - .collect(Collectors.toMap(Alias::getName, UnaryNode::child)); + .collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr)); List replacedGroupBy = agg.getGroupByExpressions().stream() .map(groupBy -> { if (groupBy instanceof UnboundSlot) { @@ -190,7 +191,7 @@ public List buildRules() { Map aliasNameToExpr = output.stream() .filter(ne -> ne instanceof Alias) .map(Alias.class::cast) - .collect(Collectors.toMap(Alias::getName, UnaryNode::child)); + .collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr)); List> replacedGroupingSets = repeat.getGroupingSets().stream() .map(groupBy -> groupBy.stream().map(expr -> { @@ -383,9 +384,6 @@ public Slot visitUnboundSlot(UnboundSlot unboundSlot, PlannerContext context) { @Override public Expression visitUnboundStar(UnboundStar unboundStar, PlannerContext context) { - if (!(plan instanceof LogicalProject)) { - throw new AnalysisException("UnboundStar must exists in Projection"); - } List qualifier = unboundStar.getQualifier(); switch (qualifier.size()) { case 0: // select * @@ -480,7 +478,7 @@ private boolean handleNamePartsTwoOrThree(Slot boundSlot, List nameParts } /** BoundStar is used to wrap list of slots for temporary. 
*/ - private class BoundStar extends NamedExpression implements PropagateNullable { + public static class BoundStar extends NamedExpression implements PropagateNullable { public BoundStar(List children) { super(children.toArray(new Slot[0])); Preconditions.checkArgument(children.stream().noneMatch(slot -> slot instanceof UnboundSlot), @@ -495,6 +493,11 @@ public String toSql() { public List getSlots() { return (List) children(); } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitBoundStar(this, context); + } } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java index 0f8bf93669664e..8ec06987cc4f29 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java @@ -22,11 +22,15 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.typecoercion.TypeCheckResult; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import com.google.common.collect.ImmutableList; import org.apache.commons.lang.StringUtils; +import java.util.List; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -34,15 +38,25 @@ /** * Check analysis rule to check semantic correct after analysis by Nereids. */ -public class CheckAnalysis extends OneAnalysisRuleFactory { +public class CheckAnalysis implements AnalysisRuleFactory { @Override - public Rule build() { - return any().then(plan -> { - checkBound(plan); - checkExpressionInputTypes(plan); - return null; - }).toRule(RuleType.CHECK_ANALYSIS); + public List buildRules() { + return ImmutableList.of( + RuleType.CHECK_ANALYSIS.build( + any().then(plan -> { + checkBound(plan); + checkExpressionInputTypes(plan); + return null; + }) + ), + RuleType.CHECK_AGGREGATE_ANALYSIS.build( + logicalAggregate().then(agg -> { + checkAggregate(agg); + return agg; + }) + ) + ); } private void checkExpressionInputTypes(Plan plan) { @@ -68,4 +82,18 @@ private void checkBound(Plan plan) { .collect(Collectors.toSet()), ", "))); } } + + private void checkAggregate(LogicalAggregate aggregate) { + Set aggregateFunctions = aggregate.getAggregateFunctions(); + boolean distinctMultiColumns = aggregateFunctions.stream() + .anyMatch(fun -> fun.isDistinct() && fun.arity() > 1); + long distinctFunctionNum = aggregateFunctions.stream() + .filter(AggregateFunction::isDistinct) + .count(); + + if (distinctMultiColumns && distinctFunctionNum > 1) { + throw new AnalysisException( + "The query contains multi count distinct or sum distinct, each can't have multi columns"); + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java index 421b19932be901..b1d01eac418dcf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java @@ -217,8 +217,10 @@ private boolean isEquivalent(Expression source, Expression 
expression) { return false; } - private boolean checkWhetherNestedAggregateFunctionsExist(AggregateFunction function) { - return function.children().stream().anyMatch(child -> child.anyMatch(AggregateFunction.class::isInstance)); + private boolean checkWhetherNestedAggregateFunctionsExist(AggregateFunction aggregateFunction) { + return aggregateFunction.children() + .stream() + .anyMatch(child -> child.anyMatch(AggregateFunction.class::isInstance)); } private void generateAliasForNewOutputSlots(Expression expression) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index 37719170e420af..d88e9b4591ee94 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -20,6 +20,8 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.logical.NormalizeToSlot.NormalizeToSlotContext; +import org.apache.doris.nereids.rules.rewrite.logical.NormalizeToSlot.NormalizeToSlotTriplet; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -41,12 +43,12 @@ import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; -import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import javax.annotation.Nullable; /** NormalizeRepeat * eg: select sum(k2 + 1), grouping(k1) from t1 group by grouping sets ((k1)); @@ -111,19 +113,19 @@ private void checkGroupingSetsSize(LogicalRepeat repeat) { } private LogicalAggregate normalizeRepeat(LogicalRepeat repeat) { - Set needPushDownExpr = collectPushDownExpressions(repeat); - PushDownContext pushDownContext = PushDownContext.toPushDownContext(repeat, needPushDownExpr); + Set needToSlots = collectNeedToSlotExpressions(repeat); + NormalizeToSlotContext context = buildContext(repeat, needToSlots); // normalize grouping sets to List> List> normalizedGroupingSets = repeat.getGroupingSets() .stream() - .map(groupingSet -> (List) (List) pushDownContext.normalizeToUseSlotRef(groupingSet)) + .map(groupingSet -> (List) (List) context.normalizeToUseSlotRef(groupingSet)) .collect(ImmutableList.toImmutableList()); // replace the arguments of grouping scalar function to virtual slots // replace some complex expression to slot, e.g. 
`a + 1` - List normalizedAggOutput = - pushDownContext.normalizeToUseSlotRef(repeat.getOutputExpressions()); + List normalizedAggOutput = context.normalizeToUseSlotRef( + repeat.getOutputExpressions(), this::normalizeGroupingScalarFunction); Set virtualSlotsInFunction = ExpressionUtils.collect(normalizedAggOutput, VirtualSlotReference.class::isInstance); @@ -150,7 +152,7 @@ private LogicalAggregate normalizeRepeat(LogicalRepeat repeat) { .addAll(allVirtualSlots) .build(); - Set pushedProject = pushDownContext.pushDownToNamedExpression(needPushDownExpr); + Set pushedProject = context.pushDownToNamedExpression(needToSlots); Plan normalizedChild = pushDownProject(pushedProject, repeat.child()); LogicalRepeat normalizedRepeat = repeat.withNormalizedExpr( @@ -164,7 +166,7 @@ private LogicalAggregate normalizeRepeat(LogicalRepeat repeat) { Optional.of(normalizedRepeat), normalizedRepeat); } - private Set collectPushDownExpressions(LogicalRepeat repeat) { + private Set collectNeedToSlotExpressions(LogicalRepeat repeat) { // 3 parts need push down: // flattenGroupingSetExpr, argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction @@ -204,109 +206,54 @@ private Plan pushDownProject(Set pushedExprs, Plan originBottom return originBottomPlan; } - private static class PushDownContext { - private final Map pushDownMap; - - public PushDownContext(Map pushDownMap) { - this.pushDownMap = pushDownMap; + /** toPushDownContext */ + public NormalizeToSlotContext buildContext(Repeat repeat, + Set sourceExpressions) { + Set aliases = ExpressionUtils.collect(repeat.getOutputExpressions(), Alias.class::isInstance); + Map existsAliasMap = Maps.newLinkedHashMap(); + for (Alias existsAlias : aliases) { + existsAliasMap.put(existsAlias.child(), existsAlias); } - public static PushDownContext toPushDownContext(Repeat repeat, - Set sourceExpressions) { - List groupingSetExpressions = ExpressionUtils.flatExpressions(repeat.getGroupingSets()); - Set commonGroupingSetExpressions = repeat.getCommonGroupingSetExpressions(); - - Map pushDownMap = Maps.newLinkedHashMap(); - for (Expression expression : sourceExpressions) { - Optional pushDownTriplet; - if (groupingSetExpressions.contains(expression)) { - boolean isCommonGroupingSetExpression = commonGroupingSetExpressions.contains(expression); - pushDownTriplet = PushDownTriplet.toGroupingSetExpressionPushDownTriplet( - isCommonGroupingSetExpression, expression); - } else { - pushDownTriplet = PushDownTriplet.toPushDownTriplet(expression); - } - - if (pushDownTriplet.isPresent()) { - pushDownMap.put(expression, pushDownTriplet.get()); - } + List groupingSetExpressions = ExpressionUtils.flatExpressions(repeat.getGroupingSets()); + Set commonGroupingSetExpressions = repeat.getCommonGroupingSetExpressions(); + + Map normalizeToSlotMap = Maps.newLinkedHashMap(); + for (Expression expression : sourceExpressions) { + Optional pushDownTriplet; + if (groupingSetExpressions.contains(expression)) { + boolean isCommonGroupingSetExpression = commonGroupingSetExpressions.contains(expression); + pushDownTriplet = toGroupingSetExpressionPushDownTriplet( + isCommonGroupingSetExpression, expression, existsAliasMap.get(expression)); + } else { + pushDownTriplet = Optional.of( + NormalizeToSlotTriplet.toTriplet(expression, existsAliasMap.get(expression))); } - return new PushDownContext(pushDownMap); - } - public List normalizeToUseSlotRef(List expressions) { - return expressions.stream() - .map(expr -> (E) expr.rewriteDownShortCircuit(child -> { - if (child instanceof 
GroupingScalarFunction) { - GroupingScalarFunction function = (GroupingScalarFunction) child; - List normalizedRealExpressions = normalizeToUseSlotRef(function.getArguments()); - function = function.withChildren(normalizedRealExpressions); - // eliminate GroupingScalarFunction and replace to VirtualSlotReference - return Repeat.generateVirtualSlotByFunction(function); - } - - PushDownTriplet pushDownTriplet = pushDownMap.get(child); - return pushDownTriplet == null ? child : pushDownTriplet.remainExpr; - })).collect(ImmutableList.toImmutableList()); - } - - /** - * generate bottom projections with groupByExpressions. - * eg: - * groupByExpressions: k1#0, k2#1 + 1; - * bottom: k1#0, (k2#1 + 1) AS (k2 + 1)#2; - */ - public Set pushDownToNamedExpression(Collection needToPushExpressions) { - return needToPushExpressions.stream() - .map(expr -> { - PushDownTriplet pushDownTriplet = pushDownMap.get(expr); - return pushDownTriplet == null ? (NamedExpression) expr : pushDownTriplet.pushedExpr; - }).collect(ImmutableSet.toImmutableSet()); - } - } - - private static class PushDownTriplet { - public final Expression originExpr; - public final Slot remainExpr; - public final NamedExpression pushedExpr; - - public PushDownTriplet(Expression originExpr, Slot remainExpr, NamedExpression pushedExpr) { - this.originExpr = originExpr; - this.remainExpr = remainExpr; - this.pushedExpr = pushedExpr; - } - - private static Optional toGroupingSetExpressionPushDownTriplet( - boolean isCommonGroupingSetExpression, Expression expression) { - Optional pushDownTriplet = toPushDownTriplet(expression); - if (!pushDownTriplet.isPresent()) { - return pushDownTriplet; + if (pushDownTriplet.isPresent()) { + normalizeToSlotMap.put(expression, pushDownTriplet.get()); } - - PushDownTriplet originTriplet = pushDownTriplet.get(); - SlotReference remainSlot = (SlotReference) originTriplet.remainExpr; - Slot newSlot = remainSlot.withCommonGroupingSetExpression(isCommonGroupingSetExpression); - return Optional.of(new PushDownTriplet(expression, newSlot, originTriplet.pushedExpr)); } + return new NormalizeToSlotContext(normalizeToSlotMap); + } - private static Optional toPushDownTriplet(Expression expression) { - - if (expression instanceof SlotReference) { - PushDownTriplet pushDownTriplet = - new PushDownTriplet(expression, (SlotReference) expression, (SlotReference) expression); - return Optional.of(pushDownTriplet); - } - - if (expression instanceof NamedExpression) { - NamedExpression namedExpression = (NamedExpression) expression; - PushDownTriplet pushDownTriplet = - new PushDownTriplet(expression, namedExpression.toSlot(), namedExpression); - return Optional.of(pushDownTriplet); - } + private Optional toGroupingSetExpressionPushDownTriplet( + boolean isCommonGroupingSetExpression, Expression expression, @Nullable Alias existsAlias) { + NormalizeToSlotTriplet originTriplet = NormalizeToSlotTriplet.toTriplet(expression, existsAlias); + SlotReference remainSlot = (SlotReference) originTriplet.remainExpr; + Slot newSlot = remainSlot.withCommonGroupingSetExpression(isCommonGroupingSetExpression); + return Optional.of(new NormalizeToSlotTriplet(expression, newSlot, originTriplet.pushedExpr)); + } - Alias alias = new Alias(expression, expression.toSql()); - PushDownTriplet pushDownTriplet = new PushDownTriplet(expression, alias.toSlot(), alias); - return Optional.of(pushDownTriplet); + private Expression normalizeGroupingScalarFunction(NormalizeToSlotContext context, Expression expr) { + if (expr instanceof 
GroupingScalarFunction) { + GroupingScalarFunction function = (GroupingScalarFunction) expr; + List normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments()); + function = function.withChildren(normalizedRealExpressions); + // eliminate GroupingScalarFunction and replace to VirtualSlotReference + return Repeat.generateVirtualSlotByFunction(function); + } else { + return expr; } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java index 951fed9ed990c0..4eb6913caf6a8d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java @@ -50,6 +50,10 @@ public ExpressionRewrite(ExpressionRuleExecutor rewriter) { this.rewriter = Objects.requireNonNull(rewriter, "rewriter is null"); } + public Expression rewrite(Expression expression) { + return rewriter.rewrite(expression); + } + @Override public List buildRules() { return ImmutableList.of( @@ -119,8 +123,7 @@ public Rule build() { return agg; } return new LogicalAggregate<>(newGroupByExprs, newOutputExpressions, - agg.isDisassembled(), agg.isNormalized(), agg.isFinalPhase(), agg.getAggPhase(), - agg.getSourceRepeat(), agg.child()); + agg.isNormalized(), agg.getSourceRepeat(), agg.child()); }).toRule(RuleType.REWRITE_AGG_EXPRESSION); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java index 1fc7c9ce00f68c..363869b99eaffb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java @@ -19,6 +19,7 @@ import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule; import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; import org.apache.doris.nereids.trees.expressions.CaseWhen; @@ -41,15 +42,16 @@ import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.util.ExpressionUtils; -import com.google.common.collect.Lists; +import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.List; -import java.util.stream.Collectors; /** * evaluate an expression on fe. 
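The next hunk rewrites FoldConstantRuleOnFE so that it skips distinct aggregate functions entirely, folds children first, returns a typed NullLiteral as soon as a PropagateNullable expression has a null argument, and only lets each per-expression visitor fold once every argument is already a literal. A minimal standalone sketch of that fold order, using hypothetical toy types (Expr, IntLit, NullLit, Add) rather than the real Doris expression classes:

import java.util.List;
import java.util.stream.Collectors;

// Hypothetical toy expression model, only to illustrate the fold order used by the
// rewritten rule; these are not the Doris classes.
interface Expr {
    List<Expr> children();
    Expr withChildren(List<Expr> children);
}

record IntLit(long value) implements Expr {
    public List<Expr> children() { return List.of(); }
    public Expr withChildren(List<Expr> c) { return this; }
}

record NullLit() implements Expr {
    public List<Expr> children() { return List.of(); }
    public Expr withChildren(List<Expr> c) { return this; }
}

// Add stands in for a "propagate nullable" expression: any null argument makes it null.
record Add(Expr left, Expr right) implements Expr {
    public List<Expr> children() { return List.of(left, right); }
    public Expr withChildren(List<Expr> c) { return new Add(c.get(0), c.get(1)); }
}

final class FoldSketch {
    // Post-order fold: children first (like rewriteChildren), then null propagation,
    // then evaluation only when every argument is a literal.
    static Expr fold(Expr expr) {
        List<Expr> folded = expr.children().stream().map(FoldSketch::fold).collect(Collectors.toList());
        Expr rewritten = folded.isEmpty() ? expr : expr.withChildren(folded);

        if (rewritten instanceof Add add) {
            if (add.children().stream().anyMatch(c -> c instanceof NullLit)) {
                return new NullLit();
            }
            if (add.left() instanceof IntLit l && add.right() instanceof IntLit r) {
                return new IntLit(l.value() + r.value());
            }
        }
        return rewritten;
    }

    public static void main(String[] args) {
        // The inner (1 + 2) folds to 3 first; the outer add then propagates the null literal.
        System.out.println(fold(new Add(new Add(new IntLit(1), new IntLit(2)), new NullLit())));
    }
}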
@@ -59,143 +61,148 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule { @Override public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return process(expr, ctx); - } + if (expr instanceof AggregateFunction && ((AggregateFunction) expr).isDistinct()) { + return expr; + } else if (expr instanceof AggregateExpression && ((AggregateExpression) expr).getFunction().isDistinct()) { + return expr; + } - @Override - public Expression visit(Expression expr, ExpressionRewriteContext context) { - return expr; + expr = rewriteChildren(expr, ctx); + if (expr instanceof PropagateNullable && argsHasNullLiteral(expr)) { + return new NullLiteral(expr.getDataType()); + } + return expr.accept(this, ctx); } /** * process constant expression. */ - public Expression process(Expression expr, ExpressionRewriteContext ctx) { - if (expr instanceof PropagateNullable) { - List children = expr.children() - .stream() - .map(child -> process(child, ctx)) - .collect(Collectors.toList()); - - if (ExpressionUtils.hasNullLiteral(children)) { - return Literal.of(null); - } + @Override + public Expression visit(Expression expr, ExpressionRewriteContext ctx) { + return expr; + } - if (!ExpressionUtils.isAllLiteral(children)) { - return expr.withChildren(children); - } - return expr.withChildren(children).accept(this, ctx); - } else { - return expr.accept(this, ctx); - } + @Override + public Expression visitSlot(Slot slot, ExpressionRewriteContext context) { + return slot; + } + + @Override + public Expression visitLiteral(Literal literal, ExpressionRewriteContext context) { + return literal; } @Override public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) { + if (!allArgsIsAllLiteral(equalTo)) { + return equalTo; + } return BooleanLiteral.of(((Literal) equalTo.left()).compareTo((Literal) equalTo.right()) == 0); } @Override public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) { + if (!allArgsIsAllLiteral(greaterThan)) { + return greaterThan; + } return BooleanLiteral.of(((Literal) greaterThan.left()).compareTo((Literal) greaterThan.right()) > 0); } @Override public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) { + if (!allArgsIsAllLiteral(greaterThanEqual)) { + return greaterThanEqual; + } return BooleanLiteral.of(((Literal) greaterThanEqual.left()) .compareTo((Literal) greaterThanEqual.right()) >= 0); } @Override public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext context) { + if (!allArgsIsAllLiteral(lessThan)) { + return lessThan; + } return BooleanLiteral.of(((Literal) lessThan.left()).compareTo((Literal) lessThan.right()) < 0); - } @Override public Expression visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) { + if (!allArgsIsAllLiteral(lessThanEqual)) { + return lessThanEqual; + } return BooleanLiteral.of(((Literal) lessThanEqual.left()).compareTo((Literal) lessThanEqual.right()) <= 0); } @Override public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, ExpressionRewriteContext context) { - Expression left = process(nullSafeEqual.left(), context); - Expression right = process(nullSafeEqual.right(), context); - if (ExpressionUtils.isAllLiteral(left, right)) { - Literal l = (Literal) left; - Literal r = (Literal) right; - if (l.isNullLiteral() && r.isNullLiteral()) { - return BooleanLiteral.TRUE; - } else if (!l.isNullLiteral() && !r.isNullLiteral()) { - return 
BooleanLiteral.of(l.compareTo(r) == 0); - } else { - return BooleanLiteral.FALSE; - } + if (!allArgsIsAllLiteral(nullSafeEqual)) { + return nullSafeEqual; + } + Literal l = (Literal) nullSafeEqual.left(); + Literal r = (Literal) nullSafeEqual.right(); + if (l.isNullLiteral() && r.isNullLiteral()) { + return BooleanLiteral.TRUE; + } else if (!l.isNullLiteral() && !r.isNullLiteral()) { + return BooleanLiteral.of(l.compareTo(r) == 0); + } else { + return BooleanLiteral.FALSE; } - return nullSafeEqual.withChildren(left, right); } @Override public Expression visitNot(Not not, ExpressionRewriteContext context) { + if (!allArgsIsAllLiteral(not)) { + return not; + } return BooleanLiteral.of(!((BooleanLiteral) not.child()).getValue()); } - @Override - public Expression visitSlot(Slot slot, ExpressionRewriteContext context) { - return slot; - } - - @Override - public Expression visitLiteral(Literal literal, ExpressionRewriteContext context) { - return literal; - } - @Override public Expression visitAnd(And and, ExpressionRewriteContext context) { - List children = Lists.newArrayList(); - for (Expression child : and.children()) { - Expression newChild = process(child, context); - if (newChild.equals(BooleanLiteral.FALSE)) { - return BooleanLiteral.FALSE; - } - if (!newChild.equals(BooleanLiteral.TRUE)) { - children.add(newChild); - } + if (and.getArguments().stream().anyMatch(BooleanLiteral.FALSE::equals)) { + return BooleanLiteral.FALSE; + } + if (argsHasNullLiteral(and)) { + return Literal.of(null); } - if (children.isEmpty()) { + List nonTrueLiteral = and.children() + .stream() + .filter(conjunct -> !BooleanLiteral.TRUE.equals(conjunct)) + .collect(ImmutableList.toImmutableList()); + if (nonTrueLiteral.isEmpty()) { return BooleanLiteral.TRUE; } - if (children.size() == 1) { - return children.get(0); + if (nonTrueLiteral.size() == and.arity()) { + return and; } - if (ExpressionUtils.isAllNullLiteral(children)) { - return Literal.of(null); + if (nonTrueLiteral.size() == 1) { + return nonTrueLiteral.get(0); } - return and.withChildren(children); + return and.withChildren(nonTrueLiteral); } @Override public Expression visitOr(Or or, ExpressionRewriteContext context) { - List children = Lists.newArrayList(); - for (Expression child : or.children()) { - Expression newChild = process(child, context); - if (newChild.equals(BooleanLiteral.TRUE)) { - return BooleanLiteral.TRUE; - } - if (!newChild.equals(BooleanLiteral.FALSE)) { - children.add(newChild); - } + if (or.getArguments().stream().anyMatch(BooleanLiteral.TRUE::equals)) { + return BooleanLiteral.TRUE; } - if (children.isEmpty()) { + if (ExpressionUtils.isAllNullLiteral(or.getArguments())) { + return Literal.of(null); + } + List nonFalseLiteral = or.children() + .stream() + .filter(conjunct -> !BooleanLiteral.FALSE.equals(conjunct)) + .collect(ImmutableList.toImmutableList()); + if (nonFalseLiteral.isEmpty()) { return BooleanLiteral.FALSE; } - if (children.size() == 1) { - return children.get(0); + if (nonFalseLiteral.size() == or.arity()) { + return or; } - if (ExpressionUtils.isAllNullLiteral(children)) { - return Literal.of(null); + if (nonFalseLiteral.size() == 1) { + return nonFalseLiteral.get(0); } - return or.withChildren(children); + return or.withChildren(nonFalseLiteral); } @Override @@ -205,15 +212,19 @@ public Expression visitLike(Like like, ExpressionRewriteContext context) { @Override public Expression visitCast(Cast cast, ExpressionRewriteContext context) { - Expression child = process(cast.child(), context); + if 
(!allArgsIsAllLiteral(cast)) { + return cast; + } + Expression child = cast.child(); // todo: process other null case if (child.isNullLiteral()) { - return Literal.of(null); + return new NullLiteral(cast.getDataType()); } - if (child.isLiteral()) { + try { return child.castTo(cast.getDataType()); + } catch (Throwable t) { + return cast; } - return cast.withChildren(child); } @Override @@ -222,12 +233,10 @@ public Expression visitBoundFunction(BoundFunction boundFunction, ExpressionRewr if (boundFunction.getArguments().isEmpty()) { return boundFunction; } - List newArgs = boundFunction.getArguments().stream().map(arg -> process(arg, context)) - .collect(Collectors.toList()); - if (ExpressionUtils.isAllLiteral(newArgs)) { - return ExpressionEvaluator.INSTANCE.eval(boundFunction.withChildren(newArgs)); + if (!ExpressionUtils.isAllLiteral(boundFunction.getArguments())) { + return boundFunction; } - return boundFunction.withChildren(newArgs); + return ExpressionEvaluator.INSTANCE.eval(boundFunction); } @Override @@ -242,13 +251,13 @@ public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext cont List whenClauses = new ArrayList<>(); for (WhenClause whenClause : caseWhen.getWhenClauses()) { - Expression whenOperand = process(whenClause.getOperand(), context); + Expression whenOperand = whenClause.getOperand(); if (!(whenOperand.isLiteral())) { - whenClauses.add(new WhenClause(whenOperand, process(whenClause.getResult(), context))); + whenClauses.add(new WhenClause(whenOperand, whenClause.getResult())); } else if (BooleanLiteral.TRUE.equals(whenOperand)) { foundNewDefault = true; - newDefault = process(whenClause.getResult(), context); + newDefault = whenClause.getResult(); break; } } @@ -257,7 +266,7 @@ public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext cont if (foundNewDefault) { defaultResult = newDefault; } else { - defaultResult = process(caseWhen.getDefaultValue().orElse(Literal.of(null)), context); + defaultResult = caseWhen.getDefaultValue().orElse(Literal.of(null)); } if (whenClauses.isEmpty()) { @@ -268,47 +277,56 @@ public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext cont @Override public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) { - Expression value = process(inPredicate.child(0), context); - List children = Lists.newArrayList(); - children.add(value); + Expression value = inPredicate.child(0); if (value.isNullLiteral()) { return Literal.of(null); } - boolean hasNull = false; - boolean hasUnresolvedValue = !value.isLiteral(); - for (int i = 1; i < inPredicate.children().size(); i++) { - Expression inValue = process(inPredicate.child(i), context); - children.add(inValue); - if (!inValue.isLiteral()) { - hasUnresolvedValue = true; - } - if (inValue.isNullLiteral()) { - hasNull = true; - } - if (inValue.isLiteral() && value.isLiteral() && ((Literal) value).compareTo((Literal) inValue) == 0) { - return Literal.of(true); - } + + boolean valueIsLiteral = value.isLiteral(); + if (!valueIsLiteral) { + return inPredicate; } - if (hasUnresolvedValue) { - return inPredicate.withChildren(children); + + for (Expression item : inPredicate.getOptions()) { + if (valueIsLiteral && value.equals(item)) { + return BooleanLiteral.TRUE; + } } - return hasNull ? 
Literal.of(null) : Literal.of(false); + return BooleanLiteral.FALSE; } @Override public Expression visitIsNull(IsNull isNull, ExpressionRewriteContext context) { - Expression child = process(isNull.child(), context); - if (child.isNullLiteral()) { - return Literal.of(true); - } else if (!child.nullable()) { - return Literal.of(false); + if (!allArgsIsAllLiteral(isNull)) { + return isNull; } - return isNull.withChildren(child); + return Literal.of(isNull.child().nullable()); } @Override public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, ExpressionRewriteContext context) { return ExpressionEvaluator.INSTANCE.eval(arithmetic); } + + private Expression rewriteChildren(Expression expr, ExpressionRewriteContext ctx) { + List newChildren = new ArrayList<>(); + boolean hasNewChildren = false; + for (Expression child : expr.children()) { + Expression newChild = rewrite(child, ctx); + if (newChild != child) { + hasNewChildren = true; + } + newChildren.add(newChild); + } + return hasNewChildren ? expr.withChildren(newChildren) : expr; + } + + private boolean allArgsIsAllLiteral(Expression expression) { + return ExpressionUtils.isAllLiteral(expression.getArguments()); + } + + private boolean argsHasNullLiteral(Expression expression) { + return ExpressionUtils.hasNullLiteral(expression.getArguments()); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java deleted file mode 100644 index 46874ff52629d7..00000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java +++ /dev/null @@ -1,45 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.implementation; - -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; - -/** - * Implementation rule that convert logical aggregation to physical hash aggregation. 
- */ -public class LogicalAggToPhysicalHashAgg extends OneImplementationRuleFactory { - @Override - public Rule build() { - return logicalAggregate().thenApply(ctx -> { - boolean useStreamAgg = !ctx.connectContext.getSessionVariable().disableStreamPreaggregations - && !ctx.root.getGroupByExpressions().isEmpty() - && !ctx.root.isFinalPhase(); - return new PhysicalAggregate<>( - ctx.root.getGroupByExpressions(), - ctx.root.getOutputExpressions(), - ctx.root.getPartitionExpressions(), - ctx.root.getAggPhase(), - useStreamAgg, - ctx.root.isFinalPhase(), - ctx.root.getLogicalProperties(), - ctx.root.child()); - }).toRule(RuleType.LOGICAL_AGG_TO_PHYSICAL_HASH_AGG_RULE); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalOlapScanToPhysicalOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalOlapScanToPhysicalOlapScan.java index 582027cef28702..bcde91d3dba573 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalOlapScanToPhysicalOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalOlapScanToPhysicalOlapScan.java @@ -53,7 +53,6 @@ public Rule build() { olapScan.getSelectedPartitionIds(), convertDistribution(olapScan), olapScan.getPreAggStatus(), - olapScan.getPushDownAggOperator(), Optional.empty(), olapScan.getLogicalProperties()) ).toRule(RuleType.LOGICAL_OLAP_SCAN_TO_PHYSICAL_OLAP_SCAN_RULE); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java deleted file mode 100644 index 4d0b55ebe32d0c..00000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java +++ /dev/null @@ -1,157 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
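The deleted implementation rule just above was also where the planner decided whether the local aggregate could use streaming pre-aggregation: the session variable must allow it, there must be at least one group-by key, and the aggregate must not be the final phase. AggregateStrategies later in this diff calls a maybeUsingStreamAgg(...) helper whose body is not part of this excerpt, so the standalone sketch below only mirrors the removed condition and uses hypothetical names:

import java.util.List;

// Standalone sketch with hypothetical names; it mirrors the condition in the removed
// LogicalAggToPhysicalHashAgg rule, not the new maybeUsingStreamAgg(...) helper.
final class StreamPreAggSketch {
    static boolean useStreamPreAggregation(boolean disableStreamPreaggregations,
                                           List<?> groupByExpressions,
                                           boolean isFinalPhase) {
        // Streaming pre-aggregation only helps a non-final (local) aggregate that has
        // group-by keys, and the session variable can disable it entirely.
        return !disableStreamPreaggregations
                && !groupByExpressions.isEmpty()
                && !isFinalPhase;
    }

    public static void main(String[] args) {
        System.out.println(useStreamPreAggregation(false, List.of("k1"), false)); // true
        System.out.println(useStreamPreAggregation(false, List.of(), false));     // false: no group-by keys
    }
}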
- -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.plans.AggPhase; -import org.apache.doris.nereids.trees.plans.GroupPlan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.util.ExpressionUtils; - -import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * Used to generate the merge agg node for distributed execution. - * NOTICE: GLOBAL output expressions' ExprId should SAME with ORIGIN output expressions' ExprId. - *
- *
- * If we have a query: SELECT SUM(v1 * v2) + 1 FROM t GROUP BY k + 1
- * the initial plan is:
- *   Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1, Alias(SUM(v1 * v2) + 1) #2], groupByExpr: [k + 1])
- *   +-- childPlan
- * we should rewrite to:
- *   Aggregate(phase: [GLOBAL], outputExpr: [Alias(b) #1, Alias(SUM(a) + 1) #2], groupByExpr: [b])
- *   +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as b], groupByExpr: [k + 1])
- *       +-- childPlan
- * - * TODO: - * 1. if instance count is 1, shouldn't disassemble the agg plan - */ -public class AggregateDisassemble extends OneRewriteRuleFactory { - - @Override - public Rule build() { - return logicalAggregate() - .when(LogicalAggregate::isFinalPhase) - .when(LogicalAggregate::isLocal) - .then(this::disassembleAggregateFunction).toRule(RuleType.AGGREGATE_DISASSEMBLE); - } - - private LogicalAggregate> disassembleAggregateFunction( - LogicalAggregate aggregate) { - List originOutputExprs = aggregate.getOutputExpressions(); - List originGroupByExprs = aggregate.getGroupByExpressions(); - Map inputSubstitutionMap = Maps.newHashMap(); - - // 1. generate a map from local aggregate output to global aggregate expr substitution. - // inputSubstitutionMap use for replacing expression in global aggregate - // replace rule is: - // a: Expression is a group by key and is a slot reference. e.g. group by k1 - // b. Expression is a group by key and is an expression. e.g. group by k1 + 1 - // c. Expression is an aggregate function. e.g. sum(v1) in select list - // +-----------+---------------------+-------------------------+--------------------------------+ - // | situation | origin expression | local output expression | expression in global aggregate | - // +-----------+---------------------+-------------------------+--------------------------------+ - // | a | Ref(k1)#1 | Ref(k1)#1 | Ref(k1)#1 | - // +-----------+---------------------+-------------------------+--------------------------------+ - // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2 | Ref(key)#2 | - // +-----------+---------------------+-------------------------+--------------------------------+ - // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3 | AF(af#3) | - // +-----------+---------------------+-------------------------+--------------------------------+ - // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x - // 2. collect local aggregate output expressions and local aggregate group by expression list - List localGroupByExprs = new ArrayList<>(aggregate.getGroupByExpressions()); - List localOutputExprs = Lists.newArrayList(); - for (Expression originGroupByExpr : originGroupByExprs) { - if (inputSubstitutionMap.containsKey(originGroupByExpr)) { - continue; - } - // group by expr must be SlotReference or NormalizeAggregate has bugs. 
- Preconditions.checkState(originGroupByExpr instanceof SlotReference, - "normalize aggregate failed to normalize group by expression " + originGroupByExpr.toSql()); - inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr); - localOutputExprs.add((SlotReference) originGroupByExpr); - } - for (NamedExpression originOutputExpr : originOutputExprs) { - Set aggregateFunctions - = originOutputExpr.collect(AggregateFunction.class::isInstance); - for (AggregateFunction aggregateFunction : aggregateFunctions) { - if (inputSubstitutionMap.containsKey(aggregateFunction)) { - continue; - } - AggregateFunction localAggregateFunction = aggregateFunction.withAggregateParam( - aggregateFunction.getAggregateParam() - .withPhaseAndDisassembled(false, AggPhase.LOCAL, true) - ); - NamedExpression localOutputExpr = new Alias(localAggregateFunction, aggregateFunction.toSql()); - - AggregateFunction substitutionValue = aggregateFunction - // save the origin input types to the global aggregate functions - .withAggregateParam(aggregateFunction.getAggregateParam() - .withPhaseAndDisassembled(true, AggPhase.GLOBAL, true)) - .withChildren(Lists.newArrayList(localOutputExpr.toSlot())); - - inputSubstitutionMap.put(aggregateFunction, substitutionValue); - localOutputExprs.add(localOutputExpr); - } - } - - // 3. replace expression in globalOutputExprs and globalGroupByExprs - List globalOutputExprs = aggregate.getOutputExpressions().stream() - .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap)) - .map(NamedExpression.class::cast) - .collect(Collectors.toList()); - List globalGroupByExprs = localGroupByExprs.stream() - .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap)).collect(Collectors.toList()); - // 4. generate new plan - LogicalAggregate localAggregate = new LogicalAggregate<>( - localGroupByExprs, - localOutputExprs, - true, - aggregate.isNormalized(), - false, - AggPhase.LOCAL, - aggregate.getSourceRepeat(), - aggregate.child() - ); - return new LogicalAggregate<>( - globalGroupByExprs, - globalOutputExprs, - true, - aggregate.isNormalized(), - true, - AggPhase.GLOBAL, - aggregate.getSourceRepeat(), - localAggregate - ); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java new file mode 100644 index 00000000000000..8377d1bcb02e24 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java @@ -0,0 +1,1234 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
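The new AggregateStrategies file that starts below replaces the disassemble-then-implement pipeline above: instead of rewriting LogicalAggregate into LOCAL and GLOBAL logical nodes, it emits PhysicalHashAggregate alternatives directly, each tagged with an AggregateParam that pairs an AggPhase (LOCAL, GLOBAL, DISTINCT_LOCAL) with an AggMode. As an assumption not spelled out in this diff, the mode names are read here as "what the step consumes" and "what it produces"; a hypothetical mini-model:

// Hypothetical mini-model of the AggMode names used by the plans below, under the
// assumption that each name encodes what the step consumes and what it produces.
enum AggregationStep {
    INPUT_TO_BUFFER,   // raw rows in, intermediate aggregation buffer out (local phase)
    BUFFER_TO_BUFFER,  // merge intermediate buffers into another buffer (middle phase)
    BUFFER_TO_RESULT,  // merge intermediate buffers into the final value (global phase)
    INPUT_TO_RESULT    // raw rows straight to the final value (one-phase plan)
}

final class AggModeSketch {
    public static void main(String[] args) {
        // Two-phase shape for "select name, count(value) from tbl group by name":
        // a local step on every instance, then a global merge step after the shuffle.
        System.out.println("local  aggregate: " + AggregationStep.INPUT_TO_BUFFER);
        System.out.println("global aggregate: " + AggregationStep.BUFFER_TO_RESULT);
    }
}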
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.catalog.Column; +import org.apache.doris.catalog.KeysType; +import org.apache.doris.catalog.PrimitiveType; +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.annotation.DependsRules; +import org.apache.doris.nereids.pattern.PatternDescriptor; +import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType; +import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.properties.RequireProperties; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.rewrite.rules.FoldConstantRuleOnFE; +import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion; +import org.apache.doris.nereids.rules.implementation.ImplementationRuleFactory; +import org.apache.doris.nereids.rules.implementation.LogicalOlapScanToPhysicalOlapScan; +import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate; +import org.apache.doris.nereids.trees.expressions.AggregateExpression; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.IsNull; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount; +import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.AggMode; +import org.apache.doris.nereids.trees.plans.AggPhase; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.Project; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; +import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate; +import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate.PushDownAggOp; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.qe.ConnectContext; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import 
java.util.Set; +import java.util.stream.Collectors; +import javax.annotation.Nullable; + +/** AggregateStrategies */ +@DependsRules({ + NormalizeAggregate.class, + FoldConstantRuleOnFE.class +}) +public class AggregateStrategies implements ImplementationRuleFactory { + + @Override + public List buildRules() { + PatternDescriptor> basePattern = logicalAggregate() + .when(LogicalAggregate::isNormalized); + + return ImmutableList.of( + RuleType.STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT.build( + logicalAggregate( + logicalOlapScan() + ) + .when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg()) + .thenApply(ctx -> storageLayerAggregate(ctx.root, null, ctx.root.child(), ctx.cascadesContext)) + ), + RuleType.STORAGE_LAYER_AGGREGATE_WITH_PROJECT.build( + logicalAggregate( + logicalProject( + logicalOlapScan() + ) + ) + .when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg()) + .thenApply(ctx -> { + LogicalAggregate> agg = ctx.root; + LogicalProject project = agg.child(); + LogicalOlapScan olapScan = project.child(); + return storageLayerAggregate(agg, project, olapScan, ctx.cascadesContext); + }) + ), + RuleType.ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT.build( + basePattern + .when(agg -> agg.getDistinctArguments().size() == 0) + .thenApplyMulti(ctx -> onePhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext)) + ), + RuleType.TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT.build( + basePattern + .when(agg -> agg.getDistinctArguments().size() == 0) + .thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext)) + ), + RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build( + basePattern + .when(this::containsCountDistinctMultiExpr) + .thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.connectContext)) + ), + RuleType.THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build( + basePattern + .when(this::containsCountDistinctMultiExpr) + .thenApplyMulti(ctx -> threePhaseAggregateWithCountDistinctMulti(ctx.root, ctx.connectContext)) + ), + RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build( + basePattern + .when(agg -> agg.getDistinctArguments().size() == 1) + .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) + ), + RuleType.ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build( + basePattern + .when(agg -> agg.getDistinctArguments().size() == 1 && enableSingleDistinctColumnOpt()) + .thenApplyMulti(ctx -> onePhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext)) + ), + RuleType.TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build( + basePattern + .when(agg -> agg.getDistinctArguments().size() == 1 && enableSingleDistinctColumnOpt()) + .thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext)) + ), + RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build( + basePattern + .when(agg -> agg.getDistinctArguments().size() == 1) + .thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) + ), + RuleType.TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT.build( + basePattern + .when(agg -> agg.getDistinctArguments().size() > 1 && !containsCountDistinctMultiExpr(agg)) + .thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext)) + ) + ); + } + + /** + * sql: select count(*) from tbl + * + * before: + * + * LogicalAggregate(groupBy=[], output=[count(*)]) + * | + * LogicalOlapScan(table=tbl) + * + * after: + * + * LogicalAggregate(groupBy=[], output=[count(*)]) + * | + * PhysicalStorageLayerAggregate(pushAggOp=COUNT, table=PhysicalOlapScan(table=tbl)) + * 
+ */ + private LogicalAggregate storageLayerAggregate( + LogicalAggregate aggregate, + @Nullable LogicalProject project, + LogicalOlapScan olapScan, CascadesContext cascadesContext) { + final LogicalAggregate canNotPush = aggregate; + + KeysType keysType = olapScan.getTable().getKeysType(); + if (keysType != KeysType.AGG_KEYS && keysType != KeysType.DUP_KEYS) { + return canNotPush; + } + + List groupByExpressions = aggregate.getGroupByExpressions(); + if (!groupByExpressions.isEmpty() || !aggregate.getDistinctArguments().isEmpty()) { + return canNotPush; + } + + Set aggregateFunctions = aggregate.getAggregateFunctions(); + Set> functionClasses = aggregateFunctions + .stream() + .map(AggregateFunction::getClass) + .collect(Collectors.toSet()); + + Map supportedAgg = PushDownAggOp.supportedFunctions(); + if (!supportedAgg.keySet().containsAll(functionClasses)) { + return canNotPush; + } + if (functionClasses.contains(Count.class) && keysType != KeysType.DUP_KEYS) { + return canNotPush; + } + if (aggregateFunctions.stream().anyMatch(fun -> fun.arity() > 1)) { + return canNotPush; + } + + // we already normalize the arguments to slotReference + List argumentsOfAggregateFunction = aggregateFunctions.stream() + .flatMap(aggregateFunction -> aggregateFunction.getArguments().stream()) + .collect(ImmutableList.toImmutableList()); + + if (project != null) { + argumentsOfAggregateFunction = Project.findProject( + (List) (List) argumentsOfAggregateFunction, project.getProjects()) + .stream() + .map(p -> p instanceof Alias ? p.child(0) : p) + .collect(ImmutableList.toImmutableList()); + } + + boolean onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction + .stream() + .allMatch(argument -> { + if (argument instanceof SlotReference) { + return true; + } + if (argument instanceof Cast) { + return argument.child(0) instanceof SlotReference + && argument.getDataType().isNumericType() + && argument.child(0).getDataType().isNumericType(); + } + return false; + }); + if (!onlyContainsSlotOrNumericCastSlot) { + return canNotPush; + } + + Set pushDownAggOps = functionClasses.stream() + .map(supportedAgg::get) + .collect(Collectors.toSet()); + + PushDownAggOp mergeOp = pushDownAggOps.size() == 1 + ? pushDownAggOps.iterator().next() + : PushDownAggOp.MIX; + + Set aggUsedSlots = + ExpressionUtils.collect(argumentsOfAggregateFunction, SlotReference.class::isInstance); + + List usedSlotInTable = (List) (List) Project.findProject(aggUsedSlots, + (List) (List) olapScan.getOutput()); + + for (SlotReference slot : usedSlotInTable) { + Column column = slot.getColumn().get(); + if (keysType == KeysType.AGG_KEYS && !column.isKey()) { + return canNotPush; + } + // The zone map max length of CharFamily is 512, do not + // over the length: https://github.com/apache/doris/pull/6293 + if (mergeOp == PushDownAggOp.MIN_MAX || mergeOp == PushDownAggOp.MIX) { + PrimitiveType colType = column.getType().getPrimitiveType(); + if (colType.isArrayType() || colType.isComplexType() || colType == PrimitiveType.STRING) { + return canNotPush; + } + if (colType.isCharFamily() && mergeOp != PushDownAggOp.COUNT && column.getType().getLength() > 512) { + return canNotPush; + } + } + if (mergeOp == PushDownAggOp.COUNT || mergeOp == PushDownAggOp.MIX) { + // NULL value behavior in `count` function is zero, so + // we should not use row_count to speed up query. 
the col + // must be not null + if (column.isAllowNull()) { + return canNotPush; + } + } + } + + PhysicalOlapScan physicalOlapScan = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan() + .build() + .transform(olapScan, cascadesContext) + .get(0); + + return aggregate.withChildren(ImmutableList.of( + new PhysicalStorageLayerAggregate(physicalOlapScan, mergeOp) + )); + } + + /** + * sql: select count(*) from tbl group by id + * + * before: + * + * LogicalAggregate(groupBy=[id], output=[count(*)]) + * | + * LogicalOlapScan(table=tbl) + * + * after: + * + * single node aggregate: + * + * PhysicalHashAggregate(groupBy=[id], output=[count(*)]) + * | + * PhysicalDistribute(distributionSpec=GATHER) + * | + * LogicalOlapScan(table=tbl) + * + * distribute node aggregate: + * + * PhysicalHashAggregate(groupBy=[id], output=[count(*)]) + * | + * LogicalOlapScan(table=tbl, **already distribute by id**) + * + */ + private List> onePhaseAggregateWithoutDistinct( + LogicalAggregate logicalAgg, ConnectContext connectContext) { + RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); + AggregateParam inputToResultParam = AggregateParam.localResult(); + List newOutput = ExpressionUtils.rewriteDownShortCircuit( + logicalAgg.getOutputExpressions(), outputChild -> { + if (outputChild instanceof AggregateFunction) { + return new AggregateExpression((AggregateFunction) outputChild, inputToResultParam); + } + return outputChild; + }); + PhysicalHashAggregate gatherLocalAgg = new PhysicalHashAggregate<>( + logicalAgg.getGroupByExpressions(), newOutput, Optional.empty(), + inputToResultParam, false, + logicalAgg.getLogicalProperties(), + requireGather, logicalAgg.child()); + + if (logicalAgg.getGroupByExpressions().isEmpty()) { + return ImmutableList.of(gatherLocalAgg); + } else { + RequireProperties requireHash = RequireProperties.of( + PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.AGGREGATE)); + PhysicalHashAggregate hashLocalAgg = gatherLocalAgg + .withRequire(requireHash) + .withPartitionExpressions(logicalAgg.getGroupByExpressions()); + return ImmutableList.>builder() + .add(gatherLocalAgg) + .add(hashLocalAgg) + .build(); + } + } + + /** + * sql: select count(distinct id, name) from tbl group by name + * + * before: + * + * LogicalAggregate(groupBy=[name], output=[count(distinct id, name)]) + * | + * LogicalOlapScan(table=tbl) + * + * after: + * + * single node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))]) + * | + * PhysicalHashAggregate(groupBy=[name, id], output=[name, id]) + * | + * PhysicalDistribute(distributionSpec=GATHER) + * | + * LogicalOlapScan(table=tbl) + * + * distribute node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))]) + * | + * PhysicalHashAggregate(groupBy=[name, id], output=[name, id]) + * | + * PhysicalDistribute(distributionSpec=HASH(name)) + * | + * LogicalOlapScan(table=tbl, **already distribute by name**) + * + */ + private List> twoPhaseAggregateWithCountDistinctMulti( + LogicalAggregate logicalAgg, ConnectContext connectContext) { + AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + + Set countDistinctArguments = logicalAgg.getDistinctArguments(); + + List localAggGroupBy = ImmutableList.copyOf(ImmutableSet.builder() + .addAll(logicalAgg.getGroupByExpressions()) + .addAll(countDistinctArguments) + .build()); + + Set aggregateFunctions = 
logicalAgg.getAggregateFunctions(); + + Map nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream() + .filter(aggregateFunction -> !aggregateFunction.isDistinct()) + .collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> { + AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam); + return new Alias(localAggExpr, localAggExpr.toSql()); + })); + + List partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg); + RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); + List localOutput = ImmutableList.builder() + .addAll((List) (List) localAggGroupBy.stream() + .filter(g -> !(g instanceof Literal)) + .collect(ImmutableList.toImmutableList())) + .addAll(nonDistinctAggFunctionToAliasPhase1.values()) + .build(); + PhysicalHashAggregate gatherLocalAgg = new PhysicalHashAggregate<>( + localAggGroupBy, localOutput, Optional.of(partitionExpressions), + new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER), + maybeUsingStreamAgg(connectContext, logicalAgg), + logicalAgg.getLogicalProperties(), requireGather, logicalAgg.child() + ); + + List distinctGroupBy = logicalAgg.getGroupByExpressions(); + + LogicalAggregate countIfAgg = countDistinctMultiExprToCountIf(logicalAgg, connectContext).first; + + AggregateParam distinctInputToResultParam + = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT); + AggregateParam globalBufferToResultParam + = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT); + List distinctOutput = ExpressionUtils.rewriteDownShortCircuit( + countIfAgg.getOutputExpressions(), outputChild -> { + if (outputChild instanceof AggregateFunction) { + AggregateFunction aggregateFunction = (AggregateFunction) outputChild; + Alias alias = nonDistinctAggFunctionToAliasPhase1.get(aggregateFunction); + if (alias == null) { + return new AggregateExpression(aggregateFunction, distinctInputToResultParam); + } else { + return new AggregateExpression(aggregateFunction, + globalBufferToResultParam, alias.toSlot()); + } + } else { + return outputChild; + } + }); + + PhysicalHashAggregate gatherLocalGatherDistinctAgg = new PhysicalHashAggregate<>( + distinctGroupBy, distinctOutput, Optional.of(partitionExpressions), + distinctInputToResultParam, false, + logicalAgg.getLogicalProperties(), requireGather, gatherLocalAgg + ); + + if (logicalAgg.getGroupByExpressions().isEmpty()) { + return ImmutableList.of(gatherLocalGatherDistinctAgg); + } else { + RequireProperties requireHash = RequireProperties.of( + PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.AGGREGATE)); + PhysicalHashAggregate hashLocalHashGlobalAgg = gatherLocalGatherDistinctAgg + .withRequireTree(requireHash.withChildren(requireHash)) + .withPartitionExpressions(logicalAgg.getGroupByExpressions()); + return ImmutableList.>builder() + .add(gatherLocalGatherDistinctAgg) + .add(hashLocalHashGlobalAgg) + .build(); + } + } + + /** + * sql: select count(distinct id, name) from tbl group by name + * + * before: + * + * LogicalAggregate(groupBy=[name], output=[count(distinct id, name)]) + * | + * LogicalOlapScan(table=tbl) + * + * after: + * + * single node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))]) + * | + * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) + * | + * PhysicalDistribute(distributionSpec=GATHER) + * | + * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) + * | + * 
LogicalOlapScan(table=tbl) + * + * distribute node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))]) + * | + * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) + * | + * PhysicalDistribute(distributionSpec=HASH(name)) + * | + * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) + * | + * LogicalOlapScan(table=tbl) + * + */ + private List> threePhaseAggregateWithCountDistinctMulti( + LogicalAggregate logicalAgg, ConnectContext connectContext) { + AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + + Set countDistinctArguments = logicalAgg.getDistinctArguments(); + + List localAggGroupBy = ImmutableList.copyOf(ImmutableSet.builder() + .addAll(logicalAgg.getGroupByExpressions()) + .addAll(countDistinctArguments) + .build()); + + Set aggregateFunctions = logicalAgg.getAggregateFunctions(); + + Map nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream() + .filter(aggregateFunction -> !aggregateFunction.isDistinct()) + .collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> { + AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam); + return new Alias(localAggExpr, localAggExpr.toSql()); + })); + + List partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg); + RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY); + List localOutput = ImmutableList.builder() + .addAll((List) (List) localAggGroupBy.stream() + .filter(g -> !(g instanceof Literal)) + .collect(ImmutableList.toImmutableList())) + .addAll(nonDistinctAggFunctionToAliasPhase1.values()) + .build(); + PhysicalHashAggregate anyLocalAgg = new PhysicalHashAggregate<>( + localAggGroupBy, localOutput, Optional.of(partitionExpressions), + new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER), + maybeUsingStreamAgg(connectContext, logicalAgg), + logicalAgg.getLogicalProperties(), requireAny, logicalAgg.child() + ); + + List globalAggGroupBy = localAggGroupBy; + + AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER); + Map nonDistinctAggFunctionToAliasPhase2 = + nonDistinctAggFunctionToAliasPhase1.entrySet() + .stream() + .collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv -> { + AggregateFunction originFunction = kv.getKey(); + Alias localOutputAlias = kv.getValue(); + AggregateExpression globalAggExpr = new AggregateExpression( + originFunction, bufferToBufferParam, localOutputAlias.toSlot()); + return new Alias(globalAggExpr, globalAggExpr.toSql()); + })); + + Set slotInCountDistinct = ExpressionUtils.collect( + ImmutableList.copyOf(countDistinctArguments), SlotReference.class::isInstance); + List globalAggOutput = ImmutableList.copyOf(ImmutableSet.builder() + .addAll((List) (List) logicalAgg.getGroupByExpressions()) + .addAll(slotInCountDistinct) + .addAll(nonDistinctAggFunctionToAliasPhase2.values()) + .build()); + + RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); + PhysicalHashAggregate anyLocalGatherGlobalAgg = new PhysicalHashAggregate<>( + globalAggGroupBy, globalAggOutput, Optional.of(partitionExpressions), + bufferToBufferParam, false, logicalAgg.getLogicalProperties(), + requireGather, anyLocalAgg); + + LogicalAggregate countIfAgg = countDistinctMultiExprToCountIf(logicalAgg, connectContext).first; + + AggregateParam distinctInputToResultParam + = new AggregateParam(AggPhase.DISTINCT_LOCAL, 
AggMode.INPUT_TO_RESULT); + AggregateParam globalBufferToResultParam + = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT); + List distinctOutput = ExpressionUtils.rewriteDownShortCircuit( + countIfAgg.getOutputExpressions(), outputChild -> { + if (outputChild instanceof AggregateFunction) { + AggregateFunction aggregateFunction = (AggregateFunction) outputChild; + Alias alias = nonDistinctAggFunctionToAliasPhase2.get(aggregateFunction); + if (alias == null) { + return new AggregateExpression(aggregateFunction, distinctInputToResultParam); + } else { + return new AggregateExpression(aggregateFunction, + globalBufferToResultParam, alias.toSlot()); + } + } else { + return outputChild; + } + }); + + PhysicalHashAggregate anyLocalGatherGlobalGatherAgg = new PhysicalHashAggregate<>( + logicalAgg.getGroupByExpressions(), distinctOutput, Optional.empty(), + distinctInputToResultParam, false, + logicalAgg.getLogicalProperties(), requireGather, anyLocalGatherGlobalAgg + ); + + RequireProperties requireDistinctHash = RequireProperties.of( + PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.AGGREGATE)); + PhysicalHashAggregate anyLocalHashGlobalGatherDistinctAgg + = anyLocalGatherGlobalGatherAgg.withChildren(ImmutableList.of( + anyLocalGatherGlobalAgg + .withRequire(requireDistinctHash) + .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) + )); + + if (logicalAgg.getGroupByExpressions().isEmpty()) { + return ImmutableList.>builder() + .add(anyLocalGatherGlobalGatherAgg) + .add(anyLocalHashGlobalGatherDistinctAgg) + .build(); + } else { + RequireProperties requireGroupByHash = RequireProperties.of( + PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.AGGREGATE)); + PhysicalHashAggregate> anyLocalHashGlobalHashDistinctAgg + = anyLocalGatherGlobalGatherAgg.withRequirePropertiesAndChild(requireGroupByHash, + anyLocalGatherGlobalAgg + .withRequire(requireGroupByHash) + .withPartitionExpressions(logicalAgg.getGroupByExpressions()) + ) + .withPartitionExpressions(logicalAgg.getGroupByExpressions()); + return ImmutableList.>builder() + .add(anyLocalGatherGlobalGatherAgg) + .add(anyLocalHashGlobalGatherDistinctAgg) + .add(anyLocalHashGlobalHashDistinctAgg) + .build(); + } + } + + /** + * sql: select name, count(value) from tbl group by name + * + * before: + * + * LogicalAggregate(groupBy=[name], output=[name, count(value)]) + * | + * LogicalOlapScan(table=tbl) + * + * after: + * + * single node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=BUFFER_TO_RESULT) + * | + * PhysicalDistribute(distributionSpec=GATHER) + * | + * PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=INPUT_TO_BUFFER) + * | + * LogicalOlapScan(table=tbl) + * + * distribute node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=BUFFER_TO_RESULT) + * | + * PhysicalDistribute(distributionSpec=HASH(name)) + * | + * PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=INPUT_TO_BUFFER) + * | + * LogicalOlapScan(table=tbl) + * + */ + private List> twoPhaseAggregateWithoutDistinct( + LogicalAggregate logicalAgg, ConnectContext connectContext) { + AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + Map inputToBufferAliases = logicalAgg.getAggregateFunctions() + .stream() + .collect(ImmutableMap.toImmutableMap(function -> function, function -> { + AggregateExpression inputToBuffer = 
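The comment above describes the two-phase shape for select name, count(value) from tbl group by name: each instance pre-aggregates into a buffer, and one merge step produces the result after a gather or a hash shuffle on name. The stand-alone sketch below only models the merge semantics of that split; the class and method names are invented for illustration and are not part of this patch or of Doris.

import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Conceptual model only: each "node" produces partial counts (the local,
// buffer-producing phase) and the global phase merges those buffers per group
// key. Illustrative names, not Doris APIs.
public class TwoPhaseCountSketch {
    // local phase: partial count(value) per name, skipping nulls like COUNT(col) does
    static Map<String, Long> localPhase(List<String[]> rows) {
        Map<String, Long> buffer = new HashMap<>();
        for (String[] row : rows) {            // row[0] = name, row[1] = value
            if (row[1] != null) {
                buffer.merge(row[0], 1L, Long::sum);
            } else {
                buffer.putIfAbsent(row[0], 0L); // keep the group even if value is null
            }
        }
        return buffer;
    }

    // global phase: merge the per-node buffers into the final result
    static Map<String, Long> globalPhase(List<Map<String, Long>> buffers) {
        Map<String, Long> result = new HashMap<>();
        for (Map<String, Long> buffer : buffers) {
            buffer.forEach((name, cnt) -> result.merge(name, cnt, Long::sum));
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, Long> node1 = localPhase(List.of(
                new String[]{"a", "x"}, new String[]{"a", null}, new String[]{"b", "y"}));
        Map<String, Long> node2 = localPhase(List.of(
                new String[]{"a", "z"}, new String[]{"b", "y"}));
        // prints a=2, b=2 (map order may vary): the same answer a single-node
        // count(value) group by name would give
        System.out.println(globalPhase(List.of(node1, node2)));
    }
}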
new AggregateExpression(function, inputToBufferParam); + return new Alias(inputToBuffer, inputToBuffer.toSql()); + })); + + List localAggGroupBy = logicalAgg.getGroupByExpressions(); + List partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg); + List localAggOutput = ImmutableList.builder() + // we already normalized the group by expressions to List by the NormalizeAggregate rule + .addAll((List) localAggGroupBy) + .addAll(inputToBufferAliases.values()) + .build(); + + RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY); + PhysicalHashAggregate anyLocalAgg = new PhysicalHashAggregate<>( + localAggGroupBy, localAggOutput, Optional.of(partitionExpressions), + inputToBufferParam, maybeUsingStreamAgg(connectContext, logicalAgg), + logicalAgg.getLogicalProperties(), requireAny, + logicalAgg.child()); + + AggregateParam bufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT); + List globalAggOutput = ExpressionUtils.rewriteDownShortCircuit( + logicalAgg.getOutputExpressions(), outputChild -> { + Alias inputToBufferAlias = inputToBufferAliases.get(outputChild); + if (inputToBufferAlias == null) { + return outputChild; + } + AggregateFunction function = (AggregateFunction) outputChild; + return new AggregateExpression(function, bufferToResultParam, inputToBufferAlias.toSlot()); + }); + + RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); + PhysicalHashAggregate anyLocalGatherGlobalAgg = new PhysicalHashAggregate( + localAggGroupBy, globalAggOutput, Optional.of(partitionExpressions), + bufferToResultParam, false, anyLocalAgg.getLogicalProperties(), + requireGather, anyLocalAgg); + + if (logicalAgg.getGroupByExpressions().isEmpty()) { + return ImmutableList.of(anyLocalGatherGlobalAgg); + } else { + RequireProperties requireHash = RequireProperties.of( + PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.AGGREGATE)); + + PhysicalHashAggregate anyLocalHashGlobalAgg = anyLocalGatherGlobalAgg + .withRequire(requireHash) + .withPartitionExpressions(logicalAgg.getGroupByExpressions()); + return ImmutableList.>builder() + .add(anyLocalGatherGlobalAgg) + .add(anyLocalHashGlobalAgg) + .build(); + } + } + + /** + * sql: select count(distinct id) from tbl group by name + * + * before: + * + * LogicalAggregate(groupBy=[name], output=[name, count(distinct id)]) + * | + * LogicalOlapScan(table=tbl) + * + * after: + * + * single node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) + * | + * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_RESULT) + * | + * PhysicalDistribute(distributionSpec=GATHER) + * | + * LogicalOlapScan(table=tbl) + * + * distribute node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) + * | + * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_RESULT) + * | + * PhysicalDistribute(distributionSpec=HASH(name)) + * | + * LogicalOlapScan(table=tbl, **if distribute by name**) + * + */ + private List> twoPhaseAggregateWithDistinct( + LogicalAggregate logicalAgg, ConnectContext connectContext) { + Set aggregateFunctions = logicalAgg.getAggregateFunctions(); + + Set distinctArguments = aggregateFunctions.stream() + .filter(aggregateExpression -> aggregateExpression.isDistinct()) + .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + 
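The comment above reduces count(distinct id) from tbl group by name to an aggregate on (name, id) followed by a plain count per name: once the pair is unique, every remaining row stands for exactly one distinct id. A minimal stand-alone simulation of that idea (invented names, not the Doris implementation):

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Conceptual sketch: grouping by (name, id) first removes duplicate ids per
// name, so the second phase only has to count rows per name.
public class TwoPhaseDistinctSketch {
    public static void main(String[] args) {
        List<String[]> rows = List.of(
                new String[]{"a", "1"}, new String[]{"a", "1"},
                new String[]{"a", "2"}, new String[]{"b", "1"});

        // first phase: deduplicate by the group key plus the distinct argument
        Set<List<String>> dedup = new HashSet<>();
        for (String[] row : rows) {
            dedup.add(List.of(row[0], row[1]));
        }

        // second phase: each remaining row contributes one distinct id to its name
        Map<String, Long> countDistinct = new HashMap<>();
        for (List<String> pair : dedup) {
            countDistinct.merge(pair.get(0), 1L, Long::sum);
        }
        System.out.println(countDistinct); // a=2, b=1 (map order may vary)
    }
}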
.collect(ImmutableSet.toImmutableSet()); + + Set localAggGroupBy = ImmutableSet.builder() + .addAll((List) logicalAgg.getGroupByExpressions()) + .addAll(distinctArguments) + .build(); + + AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + + Map nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream() + .filter(aggregateFunction -> !aggregateFunction.isDistinct()) + .collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> { + AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam); + return new Alias(localAggExpr, localAggExpr.toSql()); + })); + + List localAggOutput = ImmutableList.builder() + .addAll(localAggGroupBy) + .addAll(nonDistinctAggFunctionToAliasPhase1.values()) + .build(); + + RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); + + List partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg); + PhysicalHashAggregate gatherLocalAgg = new PhysicalHashAggregate<>(ImmutableList.copyOf(localAggGroupBy), + localAggOutput, Optional.of(partitionExpressions), inputToBufferParam, + /* + * should not use streaming, there has some bug in be will compute wrong result, + * see aggregate_strategies.groovy + */ + false, Optional.empty(), logicalAgg.getLogicalProperties(), + requireGather, logicalAgg.child()); + + AggregateParam inputToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.INPUT_TO_RESULT); + List globalOutput = ExpressionUtils.rewriteDownShortCircuit( + logicalAgg.getOutputExpressions(), outputChild -> { + if (outputChild instanceof AggregateFunction) { + AggregateFunction aggregateFunction = (AggregateFunction) outputChild; + if (aggregateFunction.isDistinct()) { + Preconditions.checkArgument(aggregateFunction.arity() == 1); + AggregateFunction nonDistinct = aggregateFunction + .withDistinctAndChildren(false, aggregateFunction.getArguments()); + return new AggregateExpression(nonDistinct, AggregateParam.localResult()); + } else { + Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild); + return new AggregateExpression( + aggregateFunction, inputToResultParam, alias.toSlot()); + } + } else { + return outputChild; + } + }); + + PhysicalHashAggregate gatherLocalGatherGlobalAgg + = new PhysicalHashAggregate<>(logicalAgg.getGroupByExpressions(), globalOutput, + Optional.empty(), inputToResultParam, false, + logicalAgg.getLogicalProperties(), requireGather, gatherLocalAgg); + + if (logicalAgg.getGroupByExpressions().isEmpty()) { + RequireProperties requireDistinctHash = RequireProperties.of(PhysicalProperties.createHash( + distinctArguments, ShuffleType.AGGREGATE)); + PhysicalHashAggregate hashLocalGatherGlobalAgg = gatherLocalGatherGlobalAgg + .withChildren(ImmutableList.of(gatherLocalAgg + .withRequire(requireDistinctHash) + .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) + )); + return ImmutableList.>builder() + .add(gatherLocalGatherGlobalAgg) + .add(hashLocalGatherGlobalAgg) + .build(); + } else { + RequireProperties requireGroupByHash = RequireProperties.of(PhysicalProperties.createHash( + logicalAgg.getGroupByExpressions(), ShuffleType.AGGREGATE)); + PhysicalHashAggregate> hashLocalHashGlobalAgg = gatherLocalGatherGlobalAgg + .withRequirePropertiesAndChild(requireGroupByHash, gatherLocalAgg + .withRequire(requireGroupByHash) + .withPartitionExpressions(logicalAgg.getGroupByExpressions()) + ) + .withPartitionExpressions(logicalAgg.getGroupByExpressions()); + return ImmutableList.>builder() + 
.add(gatherLocalGatherGlobalAgg) + .add(hashLocalHashGlobalAgg) + .build(); + } + } + + /** + * sql: select count(distinct id) from tbl group by name + * + * before: + * + * LogicalAggregate(groupBy=[name], output=[name, count(distinct id)]) + * | + * LogicalOlapScan(table=tbl) + * + * after: + * single node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) + * | + * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) + * | + * PhysicalDistribute(distributionSpec=GATHER) + * | + * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) + * | + * LogicalOlapScan(table=tbl) + * + * distribute node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) + * | + * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) + * | + * PhysicalDistribute(distributionSpec=HASH(name)) + * | + * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) + * | + * LogicalOlapScan(table=tbl) + * + */ + // TODO: support one phase aggregate(group by columns + distinct columns) + two phase distinct aggregate + private List> threePhaseAggregateWithDistinct( + LogicalAggregate logicalAgg, ConnectContext connectContext) { + Set aggregateFunctions = logicalAgg.getAggregateFunctions(); + + Set distinctArguments = aggregateFunctions.stream() + .filter(aggregateExpression -> aggregateExpression.isDistinct()) + .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + .collect(ImmutableSet.toImmutableSet()); + + Set localAggGroupBySet = ImmutableSet.builder() + .addAll((List) logicalAgg.getGroupByExpressions()) + .addAll(distinctArguments) + .build(); + + AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + + Map nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream() + .filter(aggregateFunction -> !aggregateFunction.isDistinct()) + .collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> { + AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam); + return new Alias(localAggExpr, localAggExpr.toSql()); + })); + + List localAggOutput = ImmutableList.builder() + .addAll(localAggGroupBySet) + .addAll(nonDistinctAggFunctionToAliasPhase1.values()) + .build(); + + List localAggGroupBy = ImmutableList.copyOf(localAggGroupBySet); + boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext, localAggGroupBy); + List partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg); + RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY); + PhysicalHashAggregate anyLocalAgg = new PhysicalHashAggregate<>(localAggGroupBy, + localAggOutput, Optional.of(partitionExpressions), inputToBufferParam, + maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(), + requireAny, logicalAgg.child()); + + AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER); + Map nonDistinctAggFunctionToAliasPhase2 = + nonDistinctAggFunctionToAliasPhase1.entrySet() + .stream() + .collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv -> { + AggregateFunction originFunction = kv.getKey(); + Alias localOutput = kv.getValue(); + AggregateExpression globalAggExpr = new AggregateExpression( + originFunction, bufferToBufferParam, localOutput.toSlot()); + return new Alias(globalAggExpr, globalAggExpr.toSql()); + })); + + List 
globalAggOutput = ImmutableList.builder() + .addAll(localAggGroupBySet) + .addAll(nonDistinctAggFunctionToAliasPhase2.values()) + .build(); + + RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); + PhysicalHashAggregate anyLocalGatherGlobalAgg = new PhysicalHashAggregate<>( + localAggGroupBy, globalAggOutput, Optional.of(partitionExpressions), + bufferToBufferParam, false, logicalAgg.getLogicalProperties(), + requireGather, anyLocalAgg); + + AggregateParam bufferToResultParam = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT); + List distinctOutput = ExpressionUtils.rewriteDownShortCircuit( + logicalAgg.getOutputExpressions(), expr -> { + if (expr instanceof AggregateFunction) { + AggregateFunction aggregateFunction = (AggregateFunction) expr; + if (aggregateFunction.isDistinct()) { + Preconditions.checkArgument(aggregateFunction.arity() == 1); + AggregateFunction nonDistinct = aggregateFunction + .withDistinctAndChildren(false, aggregateFunction.getArguments()); + return new AggregateExpression(nonDistinct, + bufferToResultParam, aggregateFunction.child(0)); + } else { + Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr); + return new AggregateExpression(aggregateFunction, bufferToResultParam, alias.toSlot()); + } + } + return expr; + }); + + PhysicalHashAggregate anyLocalGatherGlobalGatherDistinctAgg = new PhysicalHashAggregate<>( + logicalAgg.getGroupByExpressions(), distinctOutput, Optional.empty(), + bufferToResultParam, false, logicalAgg.getLogicalProperties(), + requireGather, anyLocalGatherGlobalAgg); + + RequireProperties requireDistinctHash = RequireProperties.of( + PhysicalProperties.createHash(logicalAgg.getDistinctArguments(), ShuffleType.AGGREGATE)); + PhysicalHashAggregate anyLocalHashGlobalGatherDistinctAgg + = anyLocalGatherGlobalGatherDistinctAgg + .withChildren(ImmutableList.of(anyLocalGatherGlobalAgg + .withRequire(requireDistinctHash) + .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) + )); + + if (logicalAgg.getGroupByExpressions().isEmpty()) { + return ImmutableList.>builder() + .add(anyLocalGatherGlobalGatherDistinctAgg) + .add(anyLocalHashGlobalGatherDistinctAgg) + .build(); + } else { + RequireProperties requireGroupByHash = RequireProperties.of( + PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.AGGREGATE)); + PhysicalHashAggregate anyLocalHashGlobalHashDistinctAgg + = anyLocalGatherGlobalGatherDistinctAgg + .withRequirePropertiesAndChild(requireGroupByHash, anyLocalGatherGlobalAgg + .withRequire(requireGroupByHash) + .withPartitionExpressions(logicalAgg.getGroupByExpressions()) + ) + .withPartitionExpressions(logicalAgg.getGroupByExpressions()); + return ImmutableList.>builder() + .add(anyLocalGatherGlobalGatherDistinctAgg) + .add(anyLocalHashGlobalGatherDistinctAgg) + .add(anyLocalHashGlobalHashDistinctAgg) + .build(); + } + } + + /** + * sql: select count(distinct id) from (...) 
group by name + * + * before: + * + * LogicalAggregate(groupBy=[name], output=[count(distinct id)]) + * | + * any plan + * + * after: + * + * single node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[multi_distinct_count(id)]) + * | + * PhysicalDistribute(distributionSpec=GATHER) + * | + * any plan + * + * distribute node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[multi_distinct_count(id)]) + * | + * any plan(**already distribute by name**) + * + */ + private List> onePhaseAggregateWithMultiDistinct( + LogicalAggregate logicalAgg, ConnectContext connectContext) { + AggregateParam inputToResultParam = AggregateParam.localResult(); + List newOutput = ExpressionUtils.rewriteDownShortCircuit( + logicalAgg.getOutputExpressions(), outputChild -> { + if (outputChild instanceof AggregateFunction) { + AggregateFunction function = tryConvertToMultiDistinct((AggregateFunction) outputChild); + return new AggregateExpression(function, inputToResultParam); + } + return outputChild; + }); + + RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); + PhysicalHashAggregate gatherLocalAgg = new PhysicalHashAggregate<>( + logicalAgg.getGroupByExpressions(), newOutput, inputToResultParam, + maybeUsingStreamAgg(connectContext, logicalAgg), + logicalAgg.getLogicalProperties(), requireGather, logicalAgg.child()); + if (logicalAgg.getGroupByExpressions().isEmpty()) { + return ImmutableList.of(gatherLocalAgg); + } else { + RequireProperties requireHash = RequireProperties.of( + PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.AGGREGATE)); + PhysicalHashAggregate hashLocalAgg = gatherLocalAgg + .withRequire(requireHash) + .withPartitionExpressions(logicalAgg.getGroupByExpressions()); + return ImmutableList.>builder() + .add(gatherLocalAgg) + .add(hashLocalAgg) + .build(); + } + } + + /** + * sql: select count(distinct id) from tbl group by name + * + * before: + * + * LogicalAggregate(groupBy=[name], output=[name, count(distinct id)]) + * | + * LogicalOlapScan(table=tbl) + * + * after: + * + * single node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=BUFFER_TO_RESULT) + * | + * PhysicalDistribute(distributionSpec=GATHER) + * | + * PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=INPUT_TO_BUFFER) + * | + * LogicalOlapScan(table=tbl) + * + * distribute node aggregate: + * + * PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=BUFFER_TO_RESULT) + * | + * PhysicalDistribute(distributionSpec=HASH(name)) + * | + * PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=INPUT_TO_BUFFER) + * | + * LogicalOlapScan(table=tbl) + * + */ + private List> twoPhaseAggregateWithMultiDistinct( + LogicalAggregate logicalAgg, ConnectContext connectContext) { + Set aggregateFunctions = logicalAgg.getAggregateFunctions(); + AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + Map aggFunctionToAliasPhase1 = aggregateFunctions.stream() + .collect(ImmutableMap.toImmutableMap(function -> function, function -> { + AggregateFunction multiDistinct = tryConvertToMultiDistinct(function); + AggregateExpression localAggExpr = new AggregateExpression(multiDistinct, inputToBufferParam); + return new Alias(localAggExpr, localAggExpr.toSql()); + })); + + List localAggOutput = ImmutableList.builder() + // already normalize group by expression to List 
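The strategies above replace count(distinct x) and sum(distinct x) with multi-distinct functions so that the distinct state can travel through the local and global phases as a mergeable buffer. Below is a rough conceptual model of such a buffer; the real backend state is not a Java HashSet, this only shows why the two-phase plan is correct.

import java.util.HashSet;
import java.util.Set;

// Rough conceptual model of a mergeable "distinct count" state: the local phase
// accumulates a set of values, buffers from different nodes can be merged, and
// the global phase reads the final count. Not the Doris/BE implementation.
public class DistinctCountBuffer {
    private final Set<Object> values = new HashSet<>();

    void add(Object value) {                 // local (partial) phase
        if (value != null) {
            values.add(value);
        }
    }

    void merge(DistinctCountBuffer other) {  // combine buffers from different nodes
        values.addAll(other.values);
    }

    long result() {                          // global (final) phase
        return values.size();
    }

    public static void main(String[] args) {
        DistinctCountBuffer node1 = new DistinctCountBuffer();
        node1.add(1); node1.add(2); node1.add(2);
        DistinctCountBuffer node2 = new DistinctCountBuffer();
        node2.add(2); node2.add(3); node2.add(null);
        node1.merge(node2);
        System.out.println(node1.result()); // 3 distinct non-null values: 1, 2, 3
    }
}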
+ .addAll((List) (List) logicalAgg.getGroupByExpressions()) + .addAll(aggFunctionToAliasPhase1.values()) + .build(); + + RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY); + PhysicalHashAggregate anyLocalAgg = new PhysicalHashAggregate<>( + logicalAgg.getGroupByExpressions(), localAggOutput, + inputToBufferParam, maybeUsingStreamAgg(connectContext, logicalAgg), + logicalAgg.getLogicalProperties(), requireAny, logicalAgg.child()); + + AggregateParam bufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT); + List globalOutput = ExpressionUtils.rewriteDownShortCircuit( + logicalAgg.getOutputExpressions(), outputChild -> { + if (outputChild instanceof AggregateFunction) { + Alias alias = aggFunctionToAliasPhase1.get(outputChild); + AggregateExpression localAggExpr = (AggregateExpression) alias.child(); + return new AggregateExpression(localAggExpr.getFunction(), + bufferToResultParam, alias.toSlot()); + } else { + return outputChild; + } + }); + + PhysicalHashAggregate anyLocalGatherGlobalAgg = new PhysicalHashAggregate<>( + logicalAgg.getGroupByExpressions(), globalOutput, Optional.empty(), + bufferToResultParam, false, logicalAgg.getLogicalProperties(), + RequireProperties.of(PhysicalProperties.GATHER), anyLocalAgg); + + if (logicalAgg.getGroupByExpressions().isEmpty()) { + Set distinctArguments = logicalAgg.getDistinctArguments(); + RequireProperties requireDistinctHash = RequireProperties.of(PhysicalProperties.createHash( + distinctArguments, ShuffleType.AGGREGATE)); + PhysicalHashAggregate hashLocalGatherGlobalAgg = anyLocalGatherGlobalAgg + .withChildren(ImmutableList.of(anyLocalAgg + .withRequire(requireDistinctHash) + .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) + )); + return ImmutableList.>builder() + .add(anyLocalGatherGlobalAgg) + .add(hashLocalGatherGlobalAgg) + .build(); + } else { + RequireProperties requireHash = RequireProperties.of( + PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.AGGREGATE)); + PhysicalHashAggregate anyLocalHashGlobalAgg = anyLocalGatherGlobalAgg + .withRequire(requireHash) + .withPartitionExpressions(logicalAgg.getGroupByExpressions()); + return ImmutableList.>builder() + .add(anyLocalGatherGlobalAgg) + .add(anyLocalHashGlobalAgg) + .build(); + } + } + + private boolean maybeUsingStreamAgg( + ConnectContext connectContext, LogicalAggregate logicalAggregate) { + return !connectContext.getSessionVariable().disableStreamPreaggregations + && !logicalAggregate.getGroupByExpressions().isEmpty(); + } + + private boolean maybeUsingStreamAgg( + ConnectContext connectContext, List groupByExpressions) { + return !connectContext.getSessionVariable().disableStreamPreaggregations + && !groupByExpressions.isEmpty(); + } + + private List getHashAggregatePartitionExpressions( + LogicalAggregate logicalAggregate) { + return logicalAggregate.getGroupByExpressions().isEmpty() + ? 
ImmutableList.copyOf(logicalAggregate.getDistinctArguments()) + : logicalAggregate.getGroupByExpressions(); + } + + private AggregateFunction tryConvertToMultiDistinct(AggregateFunction function) { + if (function instanceof Count && function.isDistinct()) { + return new MultiDistinctCount(function.getArgument(0), + function.getArguments().subList(1, function.arity()).toArray(new Expression[0])); + } else if (function instanceof Sum && function.isDistinct()) { + return new MultiDistinctSum(function.getArgument(0)); + } + return function; + } + + /** + * countDistinctMultiExprToCountIf. + * + * NOTE: this function will break the normalized output, e.g. from `count(distinct slot1, slot2)` to + * `count(if(slot1 is null, null, slot2))`. So if you invoke this method, and separate the + * phase of aggregate, please normalize to slot and create a bottom project like NormalizeAggregate. + */ + private Pair, List> countDistinctMultiExprToCountIf( + LogicalAggregate aggregate, ConnectContext connectContext) { + ImmutableList.Builder countIfList = ImmutableList.builder(); + List newOutput = ExpressionUtils.rewriteDownShortCircuit( + aggregate.getOutputExpressions(), outputChild -> { + if (outputChild instanceof Count) { + Count count = (Count) outputChild; + if (count.isDistinct() && count.arity() > 1) { + Set arguments = ImmutableSet.copyOf(count.getArguments()); + Expression countExpr = count.getArgument(arguments.size() - 1); + for (int i = arguments.size() - 2; i >= 0; --i) { + Expression argument = count.getArgument(i); + If ifNull = new If(new IsNull(argument), NullLiteral.INSTANCE, countExpr); + countExpr = assignNullType(ifNull, connectContext); + } + Count countIf = new Count(countExpr); + countIfList.add(countIf); + return countIf; + } + } + return outputChild; + }); + return Pair.of(aggregate.withAggOutput(newOutput), countIfList.build()); + } + + private boolean containsCountDistinctMultiExpr(LogicalAggregate aggregate) { + return ExpressionUtils.anyMatch(aggregate.getOutputExpressions(), expr -> + expr instanceof Count && ((Count) expr).isDistinct() && expr.arity() > 1); + } + + // don't invoke the ExpressionNormalization, because the expression maybe simplified and get rid of some slots + private If assignNullType(If ifExpr, ConnectContext context) { + If ifWithCoercion = (If) TypeCoercion.INSTANCE.rewrite(ifExpr, new ExpressionRewriteContext(context)); + Expression trueValue = ifWithCoercion.getArgument(1); + if (trueValue instanceof Cast && trueValue.child(0) instanceof NullLiteral) { + List newArgs = Lists.newArrayList(ifWithCoercion.getArguments()); + // backend don't support null type, so we should set the type + newArgs.set(1, new NullLiteral(((Cast) trueValue).getDataType())); + return ifWithCoercion.withChildren(newArgs); + } + return ifWithCoercion; + } + + private boolean enablePushDownNoGroupAgg() { + ConnectContext connectContext = ConnectContext.get(); + return connectContext == null || connectContext.getSessionVariable().enablePushDownNoGroupAgg(); + } + + private boolean enableSingleDistinctColumnOpt() { + ConnectContext connectContext = ConnectContext.get(); + return connectContext == null || connectContext.getSessionVariable().enableSingleDistinctColumnOpt(); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateDisassemble.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateDisassemble.java deleted file mode 100644 index 1d182cf346b697..00000000000000 --- 
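countDistinctMultiExprToCountIf above depends on the equivalence that, once the input is deduplicated on (a, b), count(distinct a, b) equals count(if(a is null, null, b)): the if-expression is null exactly when either argument is null, and count skips nulls. A small self-contained check of that equivalence on toy data (illustrative only, not Doris code):

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

// Checks that count(distinct a, b) equals count(if(a is null, null, b))
// evaluated over rows deduplicated on (a, b).
public class CountDistinctMultiCheck {
    public static void main(String[] args) {
        List<Object[]> rows = List.of(
                new Object[]{"x", 1}, new Object[]{"x", 1}, new Object[]{"x", 2},
                new Object[]{null, 3}, new Object[]{"y", null}, new Object[]{"y", 4});

        // count(distinct a, b): distinct (a, b) pairs where neither side is null
        Set<List<Object>> distinctPairs = new HashSet<>();
        for (Object[] r : rows) {
            if (r[0] != null && r[1] != null) {
                distinctPairs.add(Arrays.asList(r[0], r[1]));
            }
        }
        long countDistinct = distinctPairs.size();

        // rewrite: deduplicate on (a, b) first, then count(if(a is null, null, b))
        Set<List<Object>> dedup = new HashSet<>();
        for (Object[] r : rows) {
            dedup.add(Arrays.asList(r[0], r[1]));
        }
        long countIf = dedup.stream()
                .map(pair -> pair.get(0) == null ? null : pair.get(1)) // if(a is null, null, b)
                .filter(Objects::nonNull)                              // count(...) skips nulls
                .count();

        System.out.println(countDistinct + " == " + countIf); // 3 == 3
    }
}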
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateDisassemble.java +++ /dev/null @@ -1,190 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.common.Pair; -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.plans.AggPhase; -import org.apache.doris.nereids.trees.plans.GroupPlan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.util.ExpressionUtils; - -import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * Used to generate the merge agg node for distributed execution. - * NOTICE: DISTINCT GLOBAL output expressions' ExprId should SAME with ORIGIN output expressions' ExprId. - *
- * If we have a query: SELECT COUNT(distinct v1 * v2) + 1 FROM t
- * the initial plan is:
- *   +-- Aggregate(phase: [LOCAL], outputExpr: [Alias(COUNT(distinct v1 * v2) + 1) #2])
- *       +-- childPlan
- * we should rewrite to:
- *   Aggregate(phase: [GLOBAL DISTINCT], outputExpr: [Alias(SUM(c) + 1) #2])
- *   +-- Aggregate(phase: [LOCAL DISTINCT], outputExpr: [SUM(b) as c] )
- *       +-- Aggregate(phase: [GLOBAL], outputExpr: [COUNT(distinct a) as b])
- *           +-- Aggregate(phase: [LOCAL], outputExpr: [COUNT(distinct v1 * v2) as a])
- *               +-- childPlan
- * 
- */ -public class DistinctAggregateDisassemble extends OneRewriteRuleFactory { - - @Override - public Rule build() { - return logicalAggregate() - .when(LogicalAggregate::needDistinctDisassemble) - .then(this::disassembleAggregateFunction).toRule(RuleType.DISTINCT_AGGREGATE_DISASSEMBLE); - } - - private LogicalAggregate>>> - disassembleAggregateFunction( - LogicalAggregate aggregate) { - // Double-check to prevent incorrect changes - Preconditions.checkArgument(aggregate.getAggPhase() == AggPhase.LOCAL); - Preconditions.checkArgument(aggregate.isFinalPhase()); - List groupByExpressions = aggregate.getGroupByExpressions(); - if (groupByExpressions == null || groupByExpressions.isEmpty()) { - // If there are no group by expressions, in order to parallelize, - // we need to manually use the distinct function argument as group by expressions - groupByExpressions = new ArrayList<>(getDistinctFunctionParams(aggregate)); - } - Pair, List> localAndGlobal = - disassemble(aggregate.getOutputExpressions(), - groupByExpressions, - AggPhase.LOCAL, AggPhase.GLOBAL); - Pair, List> globalAndDistinctLocal = - disassemble(localAndGlobal.second, - groupByExpressions, - AggPhase.GLOBAL, AggPhase.DISTINCT_LOCAL); - Pair, List> distinctLocalAndDistinctGlobal = - disassemble(globalAndDistinctLocal.second, - aggregate.getGroupByExpressions(), - AggPhase.DISTINCT_LOCAL, AggPhase.DISTINCT_GLOBAL); - // generate new plan - LogicalAggregate localAggregate = new LogicalAggregate<>( - groupByExpressions, - localAndGlobal.first, - true, - aggregate.isNormalized(), - false, - AggPhase.LOCAL, - aggregate.getSourceRepeat(), - aggregate.child() - ); - LogicalAggregate> globalAggregate = new LogicalAggregate<>( - groupByExpressions, - globalAndDistinctLocal.first, - true, - aggregate.isNormalized(), - false, - AggPhase.GLOBAL, - aggregate.getSourceRepeat(), - localAggregate - ); - LogicalAggregate>> distinctLocalAggregate = - new LogicalAggregate<>( - aggregate.getGroupByExpressions(), - distinctLocalAndDistinctGlobal.first, - true, - aggregate.isNormalized(), - false, - AggPhase.DISTINCT_LOCAL, - aggregate.getSourceRepeat(), - globalAggregate - ); - return new LogicalAggregate<>( - aggregate.getGroupByExpressions(), - distinctLocalAndDistinctGlobal.second, - true, - aggregate.isNormalized(), - true, - AggPhase.DISTINCT_GLOBAL, - aggregate.getSourceRepeat(), - distinctLocalAggregate - ); - } - - private Pair, List> disassemble( - List originOutputExprs, - List childGroupByExprs, - AggPhase childPhase, - AggPhase parentPhase) { - Map inputSubstitutionMap = Maps.newHashMap(); - - List childOutputExprs = Lists.newArrayList(); - // The groupBy slots are placed at the beginning of the output, in line with the stale optimiser - childGroupByExprs.stream().forEach(expression -> childOutputExprs.add((SlotReference) expression)); - for (NamedExpression originOutputExpr : originOutputExprs) { - Set aggregateFunctions - = originOutputExpr.collect(AggregateFunction.class::isInstance); - for (AggregateFunction aggregateFunction : aggregateFunctions) { - if (inputSubstitutionMap.containsKey(aggregateFunction)) { - continue; - } - AggregateFunction childAggregateFunction = aggregateFunction.withAggregateParam( - aggregateFunction.getAggregateParam() - .withPhaseAndDisassembled(false, childPhase, true) - ); - NamedExpression childOutputExpr = new Alias(childAggregateFunction, aggregateFunction.toSql()); - AggregateFunction substitutionValue = aggregateFunction - // save the origin input types to the global aggregate functions - 
.withAggregateParam(aggregateFunction.getAggregateParam() - .withPhaseAndDisassembled(true, parentPhase, true)) - .withChildren(Lists.newArrayList(childOutputExpr.toSlot())); - - inputSubstitutionMap.put(aggregateFunction, substitutionValue); - childOutputExprs.add(childOutputExpr); - } - } - - // 3. replace expression in parentOutputExprs - List parentOutputExprs = originOutputExprs.stream() - .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap)) - .map(NamedExpression.class::cast) - .collect(Collectors.toList()); - return Pair.of(childOutputExprs, parentOutputExprs); - } - - private List getDistinctFunctionParams(LogicalAggregate agg) { - List result = new ArrayList<>(); - for (NamedExpression originOutputExpr : agg.getOutputExpressions()) { - Set aggregateFunctions - = originOutputExpr.collect(AggregateFunction.class::isInstance); - for (AggregateFunction aggregateFunction : aggregateFunctions) { - if (aggregateFunction.isDistinct()) { - result.addAll(aggregateFunction.children()); - } - } - } - return result; - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateAggregate.java new file mode 100644 index 00000000000000..8e3d4977af443d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateAggregate.java @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.algebra.Project; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; + +import java.util.List; + +/** EliminateAggregate */ +public class EliminateAggregate extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalAggregate(logicalAggregate()).then(outerAgg -> { + LogicalAggregate innerAgg = outerAgg.child(); + + if (!isSame(outerAgg.getGroupByExpressions(), innerAgg.getGroupByExpressions())) { + return outerAgg; + } + if (!onlyHasSlots(outerAgg.getOutputExpressions())) { + return outerAgg; + } + List prunedInnerAggOutput = Project.findProject(outerAgg.getOutputSet(), + innerAgg.getOutputExpressions()); + return innerAgg.withAggOutput(prunedInnerAggOutput); + }).toRule(RuleType.ELIMINATE_AGGREGATE); + } + + private boolean isSame(List list1, List list2) { + return list1.size() == list2.size() && list2.containsAll(list1); + } + + private boolean onlyHasSlots(List exprs) { + return exprs.stream().allMatch(SlotReference.class::isInstance); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java index efa4c570e3609b..d69cff4d58e18c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java @@ -23,22 +23,18 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import java.util.List; -import java.util.Map; -import java.util.Map.Entry; import java.util.Set; -import java.util.stream.Collectors; /** * normalize aggregate's group keys and AggregateFunction's child to SlotReference @@ -56,100 +52,86 @@ * After rule: * Project(k1#1, Alias(SR#9)#4, Alias(k1#1 + 1)#5, Alias(SR#10))#6, Alias(SR#11))#7, Alias(SR#10 + 1)#8) * +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10, Alias(SUM(v1#3 + 1))#11]) - * +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3) + * +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3) *
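EliminateAggregate above drops an outer aggregate whose group keys form the same set as the inner aggregate's keys and whose output contains only slots: the inner aggregate already made those keys unique, so regrouping cannot merge any further rows. A stand-alone illustration of that fact (not the rule implementation):

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// If the input is already unique on the key, grouping it again by the same key
// is the identity: every group has exactly one row and the key set is unchanged.
public class RegroupIsIdentity {
    static Map<String, Long> groupCount(List<String> keys) {
        Map<String, Long> result = new LinkedHashMap<>();
        for (String key : keys) {
            result.merge(key, 1L, Long::sum);
        }
        return result;
    }

    public static void main(String[] args) {
        List<String> rows = List.of("a", "b", "a", "c", "b");
        Map<String, Long> inner = groupCount(rows);               // {a=2, b=2, c=1}
        // outer aggregate over the inner aggregate's keys: every count is 1 and
        // the key set is unchanged, so the outer pass adds no information
        Map<String, Long> outer = groupCount(List.copyOf(inner.keySet()));
        System.out.println(inner.keySet().equals(outer.keySet())); // true
        System.out.println(outer);                                  // {a=1, b=1, c=1}
    }
}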

* More example could get from UT {NormalizeAggregateTest} */ -public class NormalizeAggregate extends OneRewriteRuleFactory { +public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot { @Override public Rule build() { return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> { - // substitution map used to substitute expression in aggregate's output to use it as top projections - Map substitutionMap = Maps.newHashMap(); - List keys = aggregate.getGroupByExpressions(); - List newOutputs = Lists.newArrayList(); - - // keys - Map> partitionedKeys = keys.stream() - .collect(Collectors.groupingBy(SlotReference.class::isInstance)); - List newKeys = Lists.newArrayList(); - List bottomProjections = Lists.newArrayList(); - if (partitionedKeys.containsKey(false)) { - // process non-SlotReference keys - newKeys.addAll(partitionedKeys.get(false).stream() - .map(e -> new Alias(e, e.toSql())) - .peek(a -> substitutionMap.put(a.child(), a.toSlot())) - .peek(bottomProjections::add) - .map(Alias::toSlot) - .collect(Collectors.toList())); - } - if (partitionedKeys.containsKey(true)) { - // process SlotReference keys - partitionedKeys.get(true).stream() - .map(SlotReference.class::cast) - .peek(s -> substitutionMap.put(s, s)) - .peek(bottomProjections::add) - .forEach(newKeys::add); - } - // add all necessary key to output - substitutionMap.entrySet().stream() - .filter(kv -> aggregate.getOutputExpressions().stream() - .anyMatch(e -> e.anyMatch(kv.getKey()::equals))) - .map(Entry::getValue) - .map(NamedExpression.class::cast) - .forEach(newOutputs::add); - - // if we generate bottom, we need to generate to project too. - // output - List outputs = aggregate.getOutputExpressions(); - Map> partitionedOutputs = outputs.stream() - .collect(Collectors.groupingBy(e -> e.anyMatch(AggregateFunction.class::isInstance))); - - boolean needBottomProjects = partitionedKeys.containsKey(false); - if (partitionedOutputs.containsKey(true)) { - // process expressions that contain aggregate function - Set aggregateFunctions = partitionedOutputs.get(true).stream() - .flatMap(e -> e.>collect(AggregateFunction.class::isInstance).stream()) - .collect(Collectors.toSet()); - - // replace all non-slot expression in aggregate functions children. 
- for (AggregateFunction aggregateFunction : aggregateFunctions) { - List newChildren = Lists.newArrayList(); - for (Expression child : aggregateFunction.getArguments()) { - if (child instanceof SlotReference || child instanceof Literal) { - newChildren.add(child); - if (child instanceof SlotReference) { - bottomProjections.add((SlotReference) child); - } - } else { - needBottomProjects = true; - Alias alias = new Alias(child, child.toSql()); - bottomProjections.add(alias); - newChildren.add(alias.toSlot()); - } - } - AggregateFunction newFunction = (AggregateFunction) aggregateFunction.withChildren(newChildren); - Alias alias = new Alias(newFunction, newFunction.toSql()); - newOutputs.add(alias); - substitutionMap.put(aggregateFunction, alias.toSlot()); - } - } - - // assemble - LogicalPlan root = aggregate.child(); - if (needBottomProjects) { - root = new LogicalProject<>(bottomProjections, root); - } - root = new LogicalAggregate<>(newKeys, newOutputs, aggregate.isDisassembled(), - true, aggregate.isFinalPhase(), aggregate.getAggPhase(), - aggregate.getSourceRepeat(), root); - List projections = outputs.stream() - .map(e -> ExpressionUtils.replace(e, substitutionMap)) - .map(NamedExpression.class::cast) - .collect(Collectors.toList()); - root = new LogicalProject<>(projections, root); - - return root; + // push expression to bottom project + Set existsAliases = ExpressionUtils.collect( + aggregate.getOutputExpressions(), Alias.class::isInstance); + Set needToSlots = collectGroupByAndArgumentsOfAggregateFunctions(aggregate); + NormalizeToSlotContext groupByAndArgumentToSlotContext = + NormalizeToSlotContext.buildContext(existsAliases, needToSlots); + Set bottomProjects = + groupByAndArgumentToSlotContext.pushDownToNamedExpression(needToSlots); + Plan normalizedChild = bottomProjects.isEmpty() + ? aggregate.child() + : new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child()); + + // begin normalize aggregate + + // replace groupBy and arguments of aggregate function to slot, may be this output contains + // some expression on the aggregate functions, e.g. `sum(value) + 1`, we should replace + // the sum(value) to slot and move the `slot + 1` to the upper project later. 
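The comments above describe the same split as the class-level example: pushed-down expressions are materialized in a bottom project, the aggregate consumes only slots, and expressions wrapping aggregate functions, such as sum(value) + 1, move to an upper project. A plain-Java sketch of that three-layer evaluation order, with invented names and data, purely for illustration:

import java.util.List;
import java.util.stream.Collectors;

// Evaluates "select sum(v + 1) + 1 from t group by k" for the single group
// k = 1, in the three layers that the normalization produces:
// bottom project -> aggregate -> upper project.
public class ThreeLayerSketch {
    public static void main(String[] args) {
        // rows of (k, v); only the group k = 1 is evaluated, to keep the sketch short
        List<int[]> table = List.of(new int[]{1, 10}, new int[]{1, 20}, new int[]{2, 30});
        int groupKey = 1;

        // bottom project: materialize the pushed-down expression (v + 1) as a "slot"
        List<Integer> pushedSlot = table.stream()
                .filter(row -> row[0] == groupKey)
                .map(row -> row[1] + 1)
                .collect(Collectors.toList());

        // aggregate: sum over the slot produced by the bottom project
        long sumSlot = pushedSlot.stream().mapToLong(Integer::longValue).sum();

        // upper project: the expression above the aggregate function, sum(...) + 1
        long result = sumSlot + 1;
        System.out.println(result); // (11 + 21) + 1 = 33
    }
}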
+ List normalizeOutputPhase1 = groupByAndArgumentToSlotContext + .normalizeToUseSlotRef(aggregate.getOutputExpressions()); + Set normalizedAggregateFunctions = + ExpressionUtils.collect(normalizeOutputPhase1, AggregateFunction.class::isInstance); + + existsAliases = ExpressionUtils.collect(normalizeOutputPhase1, Alias.class::isInstance); + + // now reuse the exists alias for the aggregate functions, + // or create new alias for the aggregate functions + NormalizeToSlotContext aggregateFunctionToSlotContext = + NormalizeToSlotContext.buildContext(existsAliases, normalizedAggregateFunctions); + + Set normalizedAggregateFunctionsWithAlias = + aggregateFunctionToSlotContext.pushDownToNamedExpression(normalizedAggregateFunctions); + + List normalizedGroupBy = + (List) groupByAndArgumentToSlotContext.normalizeToUseSlotRef(aggregate.getGroupByExpressions()); + + // we can safely add all groupBy and aggregate functions to output, because we will + // add a project on it, and the upper project can protect the scope of visible of slot + List normalizedAggregateOutput = ImmutableList.builder() + .addAll(normalizedGroupBy) + .addAll(normalizedAggregateFunctionsWithAlias) + .build(); + + LogicalAggregate normalizedAggregate = aggregate.withNormalized( + (List) normalizedGroupBy, normalizedAggregateOutput, normalizedChild); + + // replace aggregate function to slot + List upperProjects = + aggregateFunctionToSlotContext.normalizeToUseSlotRef(normalizeOutputPhase1); + return new LogicalProject<>(upperProjects, normalizedAggregate); }).toRule(RuleType.NORMALIZE_AGGREGATE); } + + private Set collectGroupByAndArgumentsOfAggregateFunctions(LogicalAggregate aggregate) { + // 2 parts need push down: + // groupingByExpr, argumentsOfAggregateFunction + + Set groupingByExpr = ImmutableSet.copyOf(aggregate.getGroupByExpressions()); + + Set aggregateFunctions = ExpressionUtils.collect( + aggregate.getOutputExpressions(), AggregateFunction.class::isInstance); + + ImmutableSet argumentsOfAggregateFunction = aggregateFunctions.stream() + .flatMap(function -> function.getArguments().stream()) + .collect(ImmutableSet.toImmutableSet()); + + ImmutableSet needPushDown = ImmutableSet.builder() + // group by should be pushed down, e.g. group by (k + 1), + // we should push down the `k + 1` to the bottom plan + .addAll(groupingByExpr) + // e.g. sum(k + 1), we should push down the `k + 1` to the bottom plan + .addAll(argumentsOfAggregateFunction) + .build(); + return needPushDown; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java new file mode 100644 index 00000000000000..34686534e39548 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiFunction; +import javax.annotation.Nullable; + +/** NormalizeToSlot */ +public interface NormalizeToSlot { + + /** NormalizeSlotContext */ + class NormalizeToSlotContext { + private final Map normalizeToSlotMap; + + public NormalizeToSlotContext(Map normalizeToSlotMap) { + this.normalizeToSlotMap = normalizeToSlotMap; + } + + /** buildContext */ + public static NormalizeToSlotContext buildContext( + Set existsAliases, Set sourceExpressions) { + Map normalizeToSlotMap = Maps.newLinkedHashMap(); + + Map existsAliasMap = Maps.newLinkedHashMap(); + for (Alias existsAlias : existsAliases) { + existsAliasMap.put(existsAlias.child(), existsAlias); + } + + for (Expression expression : sourceExpressions) { + if (normalizeToSlotMap.containsKey(expression)) { + continue; + } + NormalizeToSlotTriplet normalizeToSlotTriplet = + NormalizeToSlotTriplet.toTriplet(expression, existsAliasMap.get(expression)); + normalizeToSlotMap.put(expression, normalizeToSlotTriplet); + } + return new NormalizeToSlotContext(normalizeToSlotMap); + } + + /** normalizeToUseSlotRef, no custom normalize */ + public List normalizeToUseSlotRef(List expressions) { + return normalizeToUseSlotRef(expressions, (context, expr) -> expr); + } + + /** normalizeToUseSlotRef */ + public List normalizeToUseSlotRef(List expressions, + BiFunction customNormalize) { + return expressions.stream() + .map(expr -> (E) expr.rewriteDownShortCircuit(child -> { + Expression newChild = customNormalize.apply(this, child); + if (newChild != null && newChild != child) { + return newChild; + } + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); + return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr; + })).collect(ImmutableList.toImmutableList()); + } + + /** + * generate bottom projections with groupByExpressions. + * eg: + * groupByExpressions: k1#0, k2#1 + 1; + * bottom: k1#0, (k2#1 + 1) AS (k2 + 1)#2; + */ + public Set pushDownToNamedExpression(Collection needToPushExpressions) { + return needToPushExpressions.stream() + .map(expr -> { + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(expr); + return normalizeToSlotTriplet == null + ? (NamedExpression) expr + : normalizeToSlotTriplet.pushedExpr; + }).collect(ImmutableSet.toImmutableSet()); + } + } + + /** NormalizeToSlotTriplet */ + class NormalizeToSlotTriplet { + // which expression need to normalized to slot? + // e.g. `a + 1` + public final Expression originExpr; + // the slot already normalized. + // e.g. 
new Alias(`a + 1`).toSlot() + public final Slot remainExpr; + // the output expression need to push down to the bottom project. + // e.g. new Alias(`a + 1`) + public final NamedExpression pushedExpr; + + public NormalizeToSlotTriplet(Expression originExpr, Slot remainExpr, NamedExpression pushedExpr) { + this.originExpr = originExpr; + this.remainExpr = remainExpr; + this.pushedExpr = pushedExpr; + } + + /** toTriplet */ + public static NormalizeToSlotTriplet toTriplet(Expression expression, @Nullable Alias existsAlias) { + if (existsAlias != null) { + return new NormalizeToSlotTriplet(expression, existsAlias.toSlot(), existsAlias); + } + + if (expression instanceof NamedExpression) { + NamedExpression namedExpression = (NamedExpression) expression; + NormalizeToSlotTriplet normalizeToSlotTriplet = + new NormalizeToSlotTriplet(expression, namedExpression.toSlot(), namedExpression); + return normalizeToSlotTriplet; + } + + Alias alias = new Alias(expression, expression.toSql()); + return new NormalizeToSlotTriplet(expression, alias.toSlot(), alias); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushAggregateToOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushAggregateToOlapScan.java deleted file mode 100644 index f301c456c92b07..00000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushAggregateToOlapScan.java +++ /dev/null @@ -1,197 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
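NormalizeToSlotTriplet above keeps three views of one expression: the original expression, the named expression pushed into the bottom project, and the slot that replaces it further up. A minimal stand-alone analogue of that bookkeeping, using strings instead of expression trees (hypothetical names, not the Doris classes):

import java.util.LinkedHashMap;
import java.util.Map;

// Minimal analogue of the originExpr / pushedExpr / remainExpr bookkeeping,
// with plain strings standing in for expressions and slots.
public class TripletSketch {
    static final class Triplet {
        final String originExpr;  // e.g. "k + 1"
        final String pushedExpr;  // e.g. "(k + 1) AS slot1" in the bottom project
        final String remainExpr;  // e.g. "slot1", used above the bottom project

        Triplet(String originExpr, String pushedExpr, String remainExpr) {
            this.originExpr = originExpr;
            this.pushedExpr = pushedExpr;
            this.remainExpr = remainExpr;
        }
    }

    public static void main(String[] args) {
        Map<String, Triplet> context = new LinkedHashMap<>();
        context.put("k + 1", new Triplet("k + 1", "(k + 1) AS slot1", "slot1"));
        context.put("v + 1", new Triplet("v + 1", "(v + 1) AS slot2", "slot2"));

        // "push down": the bottom project materializes the pushed expressions
        context.values().forEach(t -> System.out.println("bottom project: " + t.pushedExpr));
        // "normalize to slot": references above the project now use the slots
        System.out.println("aggregate sees: sum(" + context.get("v + 1").remainExpr + ")");
    }
}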
- -package org.apache.doris.nereids.rules.rewrite.logical; - -import org.apache.doris.catalog.KeysType; -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Cast; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.Count; -import org.apache.doris.nereids.trees.expressions.functions.agg.Max; -import org.apache.doris.nereids.trees.expressions.functions.agg.Min; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.PushDownAggOperator; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.types.ArrayType; -import org.apache.doris.nereids.types.BitmapType; -import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.HllType; -import org.apache.doris.nereids.types.StringType; -import org.apache.doris.nereids.types.coercion.CharacterType; -import org.apache.doris.nereids.types.coercion.NumericType; -import org.apache.doris.qe.ConnectContext; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Maps; - -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * push aggregate without group by exprs to olap scan. 
- */ -public class PushAggregateToOlapScan implements RewriteRuleFactory { - @Override - public List buildRules() { - return ImmutableList.of( - logicalAggregate(logicalOlapScan()) - .when(aggregate -> check(aggregate, aggregate.child())) - .then(aggregate -> { - LogicalOlapScan olapScan = aggregate.child(); - Map projections = Maps.newHashMap(); - olapScan.getOutput().forEach(s -> projections.put(s, s)); - LogicalOlapScan pushed = pushAggregateToOlapScan(aggregate, olapScan, projections); - if (pushed == olapScan) { - return aggregate; - } else { - return aggregate.withChildren(pushed); - } - }) - .toRule(RuleType.PUSH_AGGREGATE_TO_OLAP_SCAN), - logicalAggregate(logicalProject(logicalOlapScan())) - .when(aggregate -> check(aggregate, aggregate.child().child())) - .then(aggregate -> { - LogicalProject project = aggregate.child(); - LogicalOlapScan olapScan = project.child(); - Map projections = Maps.newHashMap(); - olapScan.getOutput().forEach(s -> projections.put(s, s)); - project.getProjects().stream() - .filter(Alias.class::isInstance) - .map(Alias.class::cast) - .filter(alias -> alias.child() instanceof Slot) - .forEach(alias -> projections.put(alias.toSlot(), (Slot) alias.child())); - LogicalOlapScan pushed = pushAggregateToOlapScan(aggregate, olapScan, projections); - if (pushed == olapScan) { - return aggregate; - } else { - return aggregate.withChildren(project.withChildren(pushed)); - } - }) - .toRule(RuleType.PUSH_AGGREGATE_TO_OLAP_SCAN) - ); - } - - private boolean check(LogicalAggregate aggregate, LogicalOlapScan olapScan) { - // session variables - if (ConnectContext.get() != null && !ConnectContext.get().getSessionVariable().enablePushDownNoGroupAgg()) { - return false; - } - - // olap scan - if (olapScan.isAggPushed()) { - return false; - } - KeysType keysType = olapScan.getTable().getKeysType(); - if (keysType == KeysType.UNIQUE_KEYS || keysType == KeysType.PRIMARY_KEYS) { - return false; - } - - // aggregate - if (!aggregate.getGroupByExpressions().isEmpty()) { - return false; - } - List aggregateFunctions = aggregate.getOutputExpressions().stream() - .>map(e -> e.collect(AggregateFunction.class::isInstance)) - .flatMap(Set::stream).collect(Collectors.toList()); - if (aggregateFunctions.stream().anyMatch(af -> af.arity() > 1)) { - return false; - } - if (!aggregateFunctions.stream() - .allMatch(af -> af instanceof Count || af instanceof Min || af instanceof Max)) { - return false; - } - - // both - if (aggregateFunctions.stream().anyMatch(Count.class::isInstance) && keysType != KeysType.DUP_KEYS) { - return false; - } - - return true; - - } - - private LogicalOlapScan pushAggregateToOlapScan( - LogicalAggregate aggregate, - LogicalOlapScan olapScan, - Map projections) { - List aggregateFunctions = aggregate.getOutputExpressions().stream() - .>map(e -> e.collect(AggregateFunction.class::isInstance)) - .flatMap(Set::stream).collect(Collectors.toList()); - - PushDownAggOperator pushDownAggOperator = olapScan.getPushDownAggOperator(); - for (AggregateFunction aggregateFunction : aggregateFunctions) { - pushDownAggOperator = pushDownAggOperator.merge(aggregateFunction.getName()); - if (aggregateFunction.arity() == 0) { - continue; - } - Expression child = aggregateFunction.child(0); - Slot slot; - if (child instanceof Slot) { - slot = (Slot) child; - } else if (child instanceof Cast && child.child(0) instanceof SlotReference) { - if (child.getDataType() instanceof NumericType - && child.child(0).getDataType() instanceof NumericType) { - slot = (Slot) child.child(0); - } else { 
- return olapScan; - } - } else { - return olapScan; - } - - // replace by SlotReference in olap table. check no complex project on this SlotReference. - if (!projections.containsKey(slot)) { - return olapScan; - } - slot = projections.get(slot); - - DataType dataType = slot.getDataType(); - if (pushDownAggOperator.containsMinMax()) { - - if (dataType instanceof ArrayType - || dataType instanceof HllType - || dataType instanceof BitmapType - || dataType instanceof StringType) { - return olapScan; - } - } - - // The zone map max length of CharFamily is 512, do not - // over the length: https://github.com/apache/doris/pull/6293 - if (dataType instanceof CharacterType - && (((CharacterType) dataType).getLen() > 512 || ((CharacterType) dataType).getLen() < 0)) { - return olapScan; - } - - if (pushDownAggOperator.containsCount() && slot.nullable()) { - return olapScan; - } - } - return olapScan.withPushDownAggregateOperator(pushDownAggOperator); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java index 600ccafc2e202f..9d4ed1ba95fe02 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.stats; import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; import org.apache.doris.nereids.trees.expressions.CaseWhen; @@ -275,4 +276,10 @@ public ColumnStatistic visitVirtualReference(VirtualSlotReference virtualSlotRef public ColumnStatistic visitBoundFunction(BoundFunction boundFunction, StatsDeriveResult context) { return ColumnStatistic.DEFAULT; } + + @Override + public ColumnStatistic visitAggregateExpression(AggregateExpression aggregateExpression, + StatsDeriveResult context) { + return aggregateExpression.child().accept(this, context); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index 51d335d70ce63a..2e8cd7abffc0bc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -49,11 +49,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation; import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; import org.apache.doris.nereids.trees.plans.physical.PhysicalEmptyRelation; import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit; import org.apache.doris.nereids.trees.plans.physical.PhysicalLocalQuickSort; @@ -63,6 +63,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; import 
org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalRepeat;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalTVFRelation;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
 import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
@@ -191,7 +192,7 @@ public StatsDeriveResult visitPhysicalEmptyRelation(PhysicalEmptyRelation emptyR
     }

     @Override
-    public StatsDeriveResult visitPhysicalAggregate(PhysicalAggregate agg, Void context) {
+    public StatsDeriveResult visitPhysicalHashAggregate(PhysicalHashAggregate agg, Void context) {
         return computeAggregate(agg);
     }

@@ -210,6 +211,12 @@ public StatsDeriveResult visitPhysicalOlapScan(PhysicalOlapScan olapScan, Void c
         return computeScan(olapScan);
     }

+    @Override
+    public StatsDeriveResult visitPhysicalStorageLayerAggregate(
+            PhysicalStorageLayerAggregate storageLayerAggregate, Void context) {
+        return storageLayerAggregate.getRelation().accept(this, context);
+    }
+
     @Override
     public StatsDeriveResult visitPhysicalTVFRelation(PhysicalTVFRelation tvfRelation, Void context) {
         return tvfRelation.getFunction().computeStats(tvfRelation.getOutput());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/AggregateExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/AggregateExpression.java
new file mode 100644
index 00000000000000..6b51100c9fc3be
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/AggregateExpression.java
@@ -0,0 +1,146 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions;
+
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
+import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.trees.plans.AggMode;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.VarcharType;
+
+import com.google.common.base.Preconditions;
+
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * AggregateExpression.
+ *
+ * It is used to wrap some physical information for the aggregate function,
+ * so the aggregate function doesn't need to care about the phase of
+ * aggregation.
+ */
+public class AggregateExpression extends Expression implements UnaryExpression, PropagateNullable {
+    private final AggregateFunction function;
+
+    private final AggregateParam aggregateParam;
+
+    /** local aggregate */
+    public AggregateExpression(AggregateFunction aggregate, AggregateParam aggregateParam) {
+        this(aggregate, aggregateParam, aggregate);
+    }
+
+    /** the aggregate may consume a buffer, so the child could be a slot, not an aggregate function */
+    public AggregateExpression(AggregateFunction aggregate, AggregateParam aggregateParam, Expression child) {
+        super(child);
+        this.function = Objects.requireNonNull(aggregate, "function cannot be null");
+        this.aggregateParam = Objects.requireNonNull(aggregateParam, "aggregateParam cannot be null");
+    }
+
+    public AggregateFunction getFunction() {
+        return function;
+    }
+
+    public AggregateParam getAggregateParam() {
+        return aggregateParam;
+    }
+
+    public boolean isDistinct() {
+        return function.isDistinct();
+    }
+
+    @Override
+    public DataType getDataType() {
+        if (aggregateParam.aggMode.productAggregateBuffer) {
+            // buffer type
+            return VarcharType.SYSTEM_DEFAULT;
+        } else {
+            // final result type
+            return function.getDataType();
+        }
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitAggregateExpression(this, context);
+    }
+
+    @Override
+    public AggregateExpression withChildren(List<Expression> children) {
+        Preconditions.checkArgument(children.size() == 1);
+        Expression child = children.get(0);
+        if (!aggregateParam.aggMode.consumeAggregateBuffer) {
+            Preconditions.checkArgument(child instanceof AggregateFunction,
+                    "when aggregateMode is " + aggregateParam.aggMode.name()
+                            + ", the child of AggregateExpression should be AggregateFunction, but "
+                            + child.getClass());
+            return new AggregateExpression((AggregateFunction) child, aggregateParam);
+        } else {
+            return new AggregateExpression(function, aggregateParam, child);
+        }
+    }
+
+    public AggregateExpression withAggregateParam(AggregateParam aggregateParam) {
+        return new AggregateExpression(function, aggregateParam, child());
+    }
+
+    @Override
+    public String toSql() {
+        if (aggregateParam.aggMode.productAggregateBuffer) {
+            return "partial_" + function.toSql();
+        } else {
+            return function.toSql();
+        }
+    }
+
+    @Override
+    public String toString() {
+        AggMode aggMode = aggregateParam.aggMode;
+        String prefix = aggMode.productAggregateBuffer ? "partial_" : "";
+        if (aggMode.consumeAggregateBuffer) {
+            return prefix + function.getName() + "(" + child().toString() + ")";
+        } else {
+            return prefix + child().toString();
+        }
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        if (!super.equals(o)) {
+            return false;
+        }
+        AggregateExpression that = (AggregateExpression) o;
+        return Objects.equals(function, that.function)
+                && Objects.equals(aggregateParam, that.aggregateParam)
+                && Objects.equals(child(), that.child());
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(super.hashCode(), function, aggregateParam, child());
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
index bdf3d3776362a2..8c388b20d2e040 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
@@ -80,9 +80,7 @@ private TypeCheckResult checkInputDataTypes(List inputs, List
-    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
-        return visitor.visit(this, context);
-    }
+    public abstract <R, C> R accept(ExpressionVisitor<R, C> visitor, C context);

     @Override
     public List<Expression> children() {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TVFProperties.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TVFProperties.java
index 3a7c2f656b762a..f1a0e852f37ea6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TVFProperties.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TVFProperties.java
@@ -19,6 +19,7 @@
 import org.apache.doris.nereids.exceptions.UnboundException;
 import org.apache.doris.nereids.trees.expressions.shape.LeafExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.DataType;
 import org.apache.doris.nereids.types.MapType;
@@ -84,4 +85,9 @@ public boolean equals(Object o) {
     public int hashCode() {
         return Objects.hash(super.hashCode(), keyValues);
     }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitTVFProperties(this, context);
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TimestampArithmetic.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TimestampArithmetic.java
index 9f96fe9edf969b..8c1f74ebbcaaee 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TimestampArithmetic.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TimestampArithmetic.java
@@ -20,7 +20,7 @@
 import org.apache.doris.analysis.ArithmeticExpr.Operator;
 import org.apache.doris.nereids.exceptions.UnboundException;
 import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
-import org.apache.doris.nereids.trees.expressions.literal.IntervalLiteral.TimeUnit;
+import org.apache.doris.nereids.trees.expressions.literal.Interval.TimeUnit;
 import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
 import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java index 32ad5472c7f189..b2a021a5329fd1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java @@ -48,11 +48,9 @@ public abstract class BoundFunction extends Expression implements FunctionTrait, private final Supplier signatureCache = Suppliers.memoize(() -> { // first step: find the candidate signature in the signature list - List originArguments = getOriginArguments(); - FunctionSignature matchedSignature = searchSignature( - getOriginArgumentTypes(), originArguments, getSignatures()); + FunctionSignature matchedSignature = searchSignature(getSignatures()); // second step: change the signature, e.g. fill precision for decimal v2 - return computeSignature(matchedSignature, originArguments); + return computeSignature(matchedSignature); }); public BoundFunction(String name, Expression... arguments) { @@ -65,14 +63,14 @@ public BoundFunction(String name, List children) { this.name = Objects.requireNonNull(name, "name can not be null"); } - protected FunctionSignature computeSignature(FunctionSignature signature, List arguments) { + protected FunctionSignature computeSignature(FunctionSignature signature) { // NOTE: // this computed chain only process the common cases. // If you want to add some common cases to here, please separate the process code // to the other methods and add to this chain. // If you want to add some special cases, please override this method in the special // function class, like 'If' function and 'Substring' function. 
- return ComputeSignatureChain.from(signature, arguments) + return ComputeSignatureChain.from(signature, getArguments()) .then(this::computePrecisionForDatetimeV2) .then(this::upgradeDateOrDateTimeToV2) .then(this::upgradeDecimalV2ToV3) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignature.java index 7448d00f3197ad..f8649a1abb395e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignature.java @@ -19,7 +19,6 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.annotation.Developing; -import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AbstractDataType; @@ -48,8 +47,7 @@ public interface ComputeSignature extends FunctionTrait, ImplicitCastInputTypes * * @return the matched signature */ - FunctionSignature searchSignature(List argumentTypes, List arguments, - List signatures); + FunctionSignature searchSignature(List signatures); ///// re-defined other interface's methods, so we can mixin this interfaces like a trait ///// diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/CustomSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/CustomSignature.java index 2c5020fd5131da..edd1c144d58140 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/CustomSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/CustomSignature.java @@ -18,8 +18,6 @@ package org.apache.doris.nereids.trees.expressions.functions; import org.apache.doris.catalog.FunctionSignature; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.types.DataType; import com.google.common.collect.ImmutableList; @@ -29,19 +27,16 @@ public interface CustomSignature extends ComputeSignature { // custom generate a function signature. - FunctionSignature customSignature(List argumentTypes, List arguments); + FunctionSignature customSignature(); @Override default List getSignatures() { - List originArgumentTypes = getOriginArgumentTypes(); - List originArguments = getOriginArguments(); - return ImmutableList.of(customSignature(originArgumentTypes, originArguments)); + return ImmutableList.of(customSignature()); } // use the first signature as the candidate signature. 
@Override - default FunctionSignature searchSignature(List argumentTypes, List arguments, - List signatures) { + default FunctionSignature searchSignature(List signatures) { return signatures.get(0); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/DateTimeWithPrecision.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/DateTimeWithPrecision.java index 2726888da44194..2084671e0cc69c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/DateTimeWithPrecision.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/DateTimeWithPrecision.java @@ -41,17 +41,17 @@ public DateTimeWithPrecision(String name, List arguments) { } @Override - protected FunctionSignature computeSignature(FunctionSignature signature, List arguments) { + protected FunctionSignature computeSignature(FunctionSignature signature) { if (arity() == 1 && signature.returnType instanceof DateTimeV2Type) { // For functions in TIME_FUNCTIONS_WITH_PRECISION, we can't figure out which function should be use when // searching in FunctionSet. So we adjust the return type by hand here. - if (arguments.get(0) instanceof IntegerLikeLiteral) { - IntegerLikeLiteral integerLikeLiteral = (IntegerLikeLiteral) arguments.get(0); + if (getArgument(0) instanceof IntegerLikeLiteral) { + IntegerLikeLiteral integerLikeLiteral = (IntegerLikeLiteral) getArgument(0); signature = signature.withReturnType(DateTimeV2Type.of(integerLikeLiteral.getIntValue())); } else { signature = signature.withReturnType(DateTimeV2Type.of(6)); } } - return super.computeSignature(signature, arguments); + return super.computeSignature(signature); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java index e71c1449dcaf21..8fb7bd8a789ac9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java @@ -24,6 +24,8 @@ import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; @@ -153,18 +155,18 @@ public static DecimalLiteral multiplyDecimal(DecimalLiteral first, DecimalLitera } @ExecFunction(name = "divide", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE") - public static DoubleLiteral divideDouble(DoubleLiteral first, DoubleLiteral second) { + public static Literal divideDouble(DoubleLiteral first, DoubleLiteral second) { if (second.getValue() == 0.0) { - return null; + return new NullLiteral(first.getDataType()); } double result = first.getValue() / second.getValue(); return new DoubleLiteral(result); } @ExecFunction(name = "divide", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL") - public static DecimalLiteral divideDecimal(DecimalLiteral first, DecimalLiteral second) { + public static Literal divideDecimal(DecimalLiteral first, 
DecimalLiteral second) { if (first.getValue().compareTo(BigDecimal.ZERO) == 0) { - return null; + return new NullLiteral(first.getDataType()); } BigDecimal result = first.getValue().divide(second.getValue()); return new DecimalLiteral(result); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExplicitlyCastableSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExplicitlyCastableSignature.java index 2ddaec4d01d420..19a4d3fefc4acc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExplicitlyCastableSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExplicitlyCastableSignature.java @@ -19,8 +19,6 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.catalog.Type; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AbstractDataType; import java.util.List; @@ -38,9 +36,9 @@ static boolean isExplicitlyCastable(AbstractDataType signatureType, AbstractData } @Override - default FunctionSignature searchSignature(List argumentTypes, List arguments, - List signatures) { - return SearchSignature.from(signatures, arguments) + default FunctionSignature searchSignature(List signatures) { + + return SearchSignature.from(signatures, getArguments()) // first round, use identical strategy to find signature .orElseSearch(IdenticalSignature::isIdentical) // second round: if not found, use nullOrIdentical strategy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExpressionTrait.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExpressionTrait.java index ac72ac3db45138..8eeb415983b644 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExpressionTrait.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExpressionTrait.java @@ -22,6 +22,8 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.types.DataType; +import com.google.common.collect.ImmutableList; + import java.util.List; /** @@ -43,6 +45,13 @@ default Expression getArgument(int index) { return child(index); } + default List getArgumentsTypes() { + return getArguments() + .stream() + .map(Expression::getDataType) + .collect(ImmutableList.toImmutableList()); + } + default DataType getDataType() throws UnboundException { throw new UnboundException("dataType"); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionBuilder.java index e452895d566145..d2ead858e1f5ab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionBuilder.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.trees.expressions.functions; +import org.apache.doris.common.util.ReflectionUtils; import org.apache.doris.nereids.trees.expressions.Expression; import com.google.common.base.Preconditions; @@ -28,6 +29,7 @@ import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; /** @@ -57,8 +59,12 @@ public boolean canApply(List 
arguments) { } for (int i = 0; i < arguments.size(); i++) { Class constructorArgumentType = getConstructorArgumentType(i); - if (!constructorArgumentType.isInstance(arguments.get(i))) { - return false; + Object argument = arguments.get(i); + if (!constructorArgumentType.isInstance(argument)) { + Optional primitiveType = ReflectionUtils.getPrimitiveType(argument.getClass()); + if (!primitiveType.isPresent() || !constructorArgumentType.isAssignableFrom(primitiveType.get())) { + return false; + } } } return true; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionTrait.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionTrait.java index 68c1a557d386e1..152f0f777a976a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionTrait.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionTrait.java @@ -17,13 +17,6 @@ package org.apache.doris.nereids.trees.expressions.functions; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.types.DataType; - -import com.google.common.collect.ImmutableList; - -import java.util.List; - /** * FunctionTrait. */ @@ -31,15 +24,4 @@ public interface FunctionTrait extends ExpressionTrait { String getName(); boolean hasVarArguments(); - - default List getOriginArguments() { - return getArguments(); - } - - default List getOriginArgumentTypes() { - return getArguments() - .stream() - .map(Expression::getDataType) - .collect(ImmutableList.toImmutableList()); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/IdenticalSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/IdenticalSignature.java index e5018be876793a..5669160377c860 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/IdenticalSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/IdenticalSignature.java @@ -18,8 +18,6 @@ package org.apache.doris.nereids.trees.expressions.functions; import org.apache.doris.catalog.FunctionSignature; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AbstractDataType; import java.util.List; @@ -37,9 +35,8 @@ static boolean isIdentical(AbstractDataType signatureType, AbstractDataType real } @Override - default FunctionSignature searchSignature(List argumentTypes, List arguments, - List signatures) { - return SearchSignature.from(signatures, arguments) + default FunctionSignature searchSignature(List signatures) { + return SearchSignature.from(signatures, getArguments()) // first round, use identical strategy to find signature .orElseSearch(IdenticalSignature::isIdentical) .resultOrException(getName()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ImplicitlyCastableSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ImplicitlyCastableSignature.java index 77cfd54d08d344..24c9e2afb92ec4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ImplicitlyCastableSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ImplicitlyCastableSignature.java @@ -19,8 +19,6 @@ import org.apache.doris.catalog.FunctionSignature; import 
org.apache.doris.catalog.Type; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AbstractDataType; import java.util.List; @@ -38,9 +36,8 @@ static boolean isImplicitlyCastable(AbstractDataType signatureType, AbstractData } @Override - default FunctionSignature searchSignature(List argumentTypes, List arguments, - List signatures) { - return SearchSignature.from(signatures, arguments) + default FunctionSignature searchSignature(List signatures) { + return SearchSignature.from(signatures, getArguments()) // first round, use identical strategy to find signature .orElseSearch(IdenticalSignature::isIdentical) // second round: if not found, use nullOrIdentical strategy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/NullOrIdenticalSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/NullOrIdenticalSignature.java index 5f1dd0ee1a3c91..0073c70f93a7b6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/NullOrIdenticalSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/NullOrIdenticalSignature.java @@ -18,8 +18,6 @@ package org.apache.doris.nereids.trees.expressions.functions; import org.apache.doris.catalog.FunctionSignature; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.NullType; import org.apache.doris.nereids.types.coercion.AbstractDataType; @@ -39,9 +37,8 @@ static boolean isNullOrIdentical(AbstractDataType signatureType, AbstractDataTyp } @Override - default FunctionSignature searchSignature(List argumentTypes, List arguments, - List signatures) { - return SearchSignature.from(signatures, arguments) + default FunctionSignature searchSignature(List signatures) { + return SearchSignature.from(signatures, getArguments()) // first round, use identical strategy to find signature .orElseSearch(IdenticalSignature::isIdentical) // second round: if not found, use nullOrIdentical strategy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java index e97071f0a770b7..662f32ad934272 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java @@ -17,116 +17,73 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.PartialAggType; +import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.nereids.types.coercion.AbstractDataType; import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; /** * The function which consume arguments in lots of rows and product one 
value. */ public abstract class AggregateFunction extends BoundFunction implements ExpectsInputTypes { - private final AggregateParam aggregateParam; + protected final boolean isDistinct; public AggregateFunction(String name, Expression... arguments) { - this(name, AggregateParam.finalPhase(), arguments); + this(name, false, arguments); } - public AggregateFunction(String name, AggregateParam aggregateParam, Expression... arguments) { + public AggregateFunction(String name, boolean isDistinct, Expression... arguments) { super(name, arguments); - this.aggregateParam = Objects.requireNonNull(aggregateParam, "aggregateParam can not be null"); + this.isDistinct = isDistinct; } - @Override - public List getOriginArguments() { - return getArgumentsBeforeDisassembled(); + public AggregateFunction(String name, List children) { + this(name, false, children); } - @Override - public List getOriginArgumentTypes() { - return getArgumentTypesBeforeDisassembled(); + public AggregateFunction(String name, boolean isDistinct, List children) { + super(name, children); + this.isDistinct = isDistinct; } @Override public abstract AggregateFunction withChildren(List children); - public abstract AggregateFunction withAggregateParam(AggregateParam aggregateParam); + protected List intermediateTypes() { + return ImmutableList.of(VarcharType.SYSTEM_DEFAULT); + } - protected abstract List intermediateTypes(List argumentTypes, List arguments); + public abstract AggregateFunction withDistinctAndChildren(boolean isDistinct, List children); /** getIntermediateTypes */ public final PartialAggType getIntermediateTypes() { - if (isGlobal() && isDisassembled()) { - return (PartialAggType) child(0).getDataType(); - } - List arguments = getArgumentsBeforeDisassembled(); - List types = getArgumentTypesBeforeDisassembled(); - return new PartialAggType(getArguments(), intermediateTypes(types, arguments)); - } - - public final DataType getFinalType() { - return getSignature().returnType; + return new PartialAggType(getArguments(), intermediateTypes()); } @Override public final DataType getDataType() { - if (aggregateParam.aggPhase.isGlobal() || aggregateParam.isFinalPhase) { - return getFinalType(); - } else { - return getIntermediateTypes(); - } + return getSignature().returnType; } @Override - public final List expectedInputTypes() { - if (isGlobal() && isDisassembled()) { - return ImmutableList.of(getIntermediateTypes()); - } else { - return getSignature().argumentsTypes; - } - } - - public List getArgumentsBeforeDisassembled() { - if (arity() == 1 && getArgument(0).getDataType() instanceof PartialAggType) { - return ((PartialAggType) getArgument(0).getDataType()).getOriginArguments(); - } - return getArguments(); - } - - public List getArgumentTypesBeforeDisassembled() { - return getArgumentsBeforeDisassembled() - .stream() - .map(Expression::getDataType) - .collect(ImmutableList.toImmutableList()); + public List expectedInputTypes() { + return getSignature().argumentsTypes; } public boolean isDistinct() { - return aggregateParam.isDistinct; - } - - public boolean isGlobal() { - return aggregateParam.aggPhase.isGlobal(); - } - - public boolean isFinalPhase() { - return aggregateParam.isFinalPhase; - } - - public boolean isDisassembled() { - return aggregateParam.isDisassembled; - } - - public AggregateParam getAggregateParam() { - return aggregateParam; + return isDistinct; } @Override @@ -138,14 +95,14 @@ public boolean equals(Object o) { return false; } AggregateFunction that = (AggregateFunction) o; - return 
Objects.equals(aggregateParam, that.aggregateParam) + return Objects.equals(isDistinct, that.isDistinct) && Objects.equals(getName(), that.getName()) && Objects.equals(children, that.children); } @Override public int hashCode() { - return Objects.hash(aggregateParam, getName(), children); + return Objects.hash(isDistinct, getName(), children); } @Override @@ -157,4 +114,22 @@ public R accept(ExpressionVisitor visitor, C context) { public boolean hasVarArguments() { return false; } + + @Override + public String toSql() throws UnboundException { + String args = children() + .stream() + .map(Expression::toSql) + .collect(Collectors.joining(", ")); + return getName() + "(" + (isDistinct ? "DISTINCT " : "") + args + ")"; + } + + @Override + public String toString() { + String args = children() + .stream() + .map(Expression::toString) + .collect(Collectors.joining(", ")); + return getName() + "(" + (isDistinct ? "DISTINCT " : "") + args + ")"; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java index c96e437768522b..97851151e615b3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java @@ -17,57 +17,38 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.nereids.trees.plans.AggMode; import org.apache.doris.nereids.trees.plans.AggPhase; -import com.google.common.base.Preconditions; - import java.util.Objects; /** AggregateParam. */ public class AggregateParam { - public final boolean isFinalPhase; public final AggPhase aggPhase; - public final boolean isDistinct; - - public final boolean isDisassembled; + public final AggMode aggMode; /** AggregateParam */ - public AggregateParam(boolean isDistinct, boolean isFinalPhase, AggPhase aggPhase, boolean isDisassembled) { - this.isFinalPhase = isFinalPhase; - this.isDistinct = isDistinct; - this.aggPhase = aggPhase; - this.isDisassembled = isDisassembled; - if (!isFinalPhase) { - Preconditions.checkArgument(isDisassembled, - "non-final phase aggregate should be disassembed"); - } - } - - public static AggregateParam finalPhase() { - return new AggregateParam(false, true, AggPhase.LOCAL, false); + public AggregateParam(AggPhase aggPhase, AggMode aggMode) { + this.aggMode = Objects.requireNonNull(aggMode, "aggMode cannot be null"); + this.aggPhase = Objects.requireNonNull(aggPhase, "aggPhase cannot be null"); } - public static AggregateParam distinctAndFinalPhase() { - return new AggregateParam(true, true, AggPhase.LOCAL, false); - } - - public AggregateParam withDistinct(boolean isDistinct) { - return new AggregateParam(isDistinct, isFinalPhase, aggPhase, isDisassembled); + public static AggregateParam localResult() { + return new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT); } public AggregateParam withAggPhase(AggPhase aggPhase) { - return new AggregateParam(isDistinct, isFinalPhase, aggPhase, isDisassembled); + return new AggregateParam(aggPhase, aggMode); } - public AggregateParam withDisassembled(boolean isDisassembled) { - return new AggregateParam(isDistinct, isFinalPhase, aggPhase, isDisassembled); + public AggregateParam withAggPhase(AggMode aggMode) { + return new AggregateParam(aggPhase, aggMode); } - public AggregateParam withPhaseAndDisassembled(boolean 
isFinalPhase, AggPhase aggPhase, - boolean isDisassembled) { - return new AggregateParam(isDistinct, isFinalPhase, aggPhase, isDisassembled); + public AggregateParam withAppPhaseAndAppMode(AggPhase aggPhase, AggMode aggMode) { + return new AggregateParam(aggPhase, aggMode); } @Override @@ -79,14 +60,20 @@ public boolean equals(Object o) { return false; } AggregateParam that = (AggregateParam) o; - return isDistinct == that.isDistinct - && isFinalPhase == that.isFinalPhase - && Objects.equals(aggPhase, that.aggPhase) - && Objects.equals(isDisassembled, that.isDisassembled); + return Objects.equals(aggPhase, that.aggPhase) + && Objects.equals(aggMode, that.aggMode); } @Override public int hashCode() { - return Objects.hash(isDistinct, isFinalPhase, aggPhase, isDisassembled); + return Objects.hash(aggPhase, aggMode); + } + + @Override + public String toString() { + return "AggregateParam{" + + "aggPhase=" + aggPhase + + ", aggMode=" + aggMode + + '}'; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java index 17f4a0ba397120..3961d7ff1a5535 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java @@ -44,32 +44,33 @@ public Avg(Expression child) { super("avg", child); } - public Avg(AggregateParam aggregateParam, Expression child) { - super("avg", aggregateParam, child); + public Avg(boolean isDistinct, Expression child) { + super("avg", isDistinct, child); } @Override - public FunctionSignature customSignature(List argumentTypes, List arguments) { - DataType implicitCastType = implicitCast(argumentTypes.get(0)); + public FunctionSignature customSignature() { + DataType implicitCastType = implicitCast(getArgument(0).getDataType()); return FunctionSignature.ret(implicitCastType).args(implicitCastType); } @Override - protected List intermediateTypes(List argumentTypes, List arguments) { - DataType sumType = getFinalType(); + protected List intermediateTypes() { + DataType sumType = getDataType(); BigIntType countType = BigIntType.INSTANCE; return ImmutableList.of(sumType, countType); } @Override - public Avg withChildren(List children) { + public AggregateFunction withDistinctAndChildren(boolean isDistinct, List children) { Preconditions.checkArgument(children.size() == 1); - return new Avg(getAggregateParam(), children.get(0)); + return new Avg(isDistinct, children.get(0)); } @Override - public Avg withAggregateParam(AggregateParam aggregateParam) { - return new Avg(aggregateParam, child()); + public Avg withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new Avg(isDistinct, children.get(0)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapIntersect.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapIntersect.java index d62066a5b7deea..5bace587aeb583 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapIntersect.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapIntersect.java @@ -41,24 +41,20 @@ public BitmapIntersect(Expression arg0) { super("bitmap_intersect", arg0); } - public BitmapIntersect(AggregateParam aggregateParam, Expression arg0) { - 
super("bitmap_intersect", aggregateParam, arg0); - } - @Override - protected List intermediateTypes(List argumentTypes, List arguments) { + protected List intermediateTypes() { return ImmutableList.of(BitmapType.INSTANCE); } @Override - public BitmapIntersect withChildren(List children) { - Preconditions.checkArgument(children.size() == 1); - return new BitmapIntersect(getAggregateParam(), children.get(0)); + public BitmapIntersect withDistinctAndChildren(boolean isDistinct, List children) { + return withChildren(children); } @Override - public BitmapIntersect withAggregateParam(AggregateParam aggregateParam) { - return new BitmapIntersect(aggregateParam, child()); + public BitmapIntersect withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new BitmapIntersect(children.get(0)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnion.java index f3d2ef0d520b9f..7f8a80b0034b9b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnion.java @@ -41,24 +41,20 @@ public BitmapUnion(Expression arg0) { super("bitmap_union", arg0); } - public BitmapUnion(AggregateParam aggregateParam, Expression arg0) { - super("bitmap_union", aggregateParam, arg0); - } - @Override - protected List intermediateTypes(List argumentTypes, List arguments) { + protected List intermediateTypes() { return ImmutableList.of(BitmapType.INSTANCE); } @Override - public BitmapUnion withChildren(List children) { - Preconditions.checkArgument(children.size() == 1); - return new BitmapUnion(getAggregateParam(), children.get(0)); + public BitmapUnion withDistinctAndChildren(boolean isDistinct, List children) { + return withChildren(children); } @Override - public BitmapUnion withAggregateParam(AggregateParam aggregateParam) { - return new BitmapUnion(aggregateParam, child()); + public BitmapUnion withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new BitmapUnion(children.get(0)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionCount.java index a98653a6e48a7f..b1dd6345e8654a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionCount.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionCount.java @@ -42,24 +42,20 @@ public BitmapUnionCount(Expression arg0) { super("bitmap_union_count", arg0); } - public BitmapUnionCount(AggregateParam aggregateParam, Expression arg0) { - super("bitmap_union_count", aggregateParam, arg0); - } - @Override - protected List intermediateTypes(List argumentTypes, List arguments) { + protected List intermediateTypes() { return ImmutableList.of(BitmapType.INSTANCE); } @Override - public BitmapUnionCount withChildren(List children) { - Preconditions.checkArgument(children.size() == 1); - return new BitmapUnionCount(getAggregateParam(), children.get(0)); + public BitmapUnionCount withDistinctAndChildren(boolean isDistinct, List children) { + return withChildren(children); } @Override - public BitmapUnionCount withAggregateParam(AggregateParam 
aggregateParam) { - return new BitmapUnionCount(aggregateParam, child()); + public BitmapUnionCount withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new BitmapUnionCount(children.get(0)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionInt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionInt.java index 2e22ce64c3c0f3..00cbbc5a2d5680 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionInt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionInt.java @@ -48,24 +48,20 @@ public BitmapUnionInt(Expression arg0) { super("bitmap_union_int", arg0); } - public BitmapUnionInt(AggregateParam aggregateParam, Expression arg0) { - super("bitmap_union_int", aggregateParam, arg0); - } - @Override - protected List intermediateTypes(List argumentTypes, List arguments) { + protected List intermediateTypes() { return ImmutableList.of(BitmapType.INSTANCE); } @Override - public BitmapUnionInt withChildren(List children) { - Preconditions.checkArgument(children.size() == 1); - return new BitmapUnionInt(getAggregateParam(), children.get(0)); + public BitmapUnionInt withDistinctAndChildren(boolean isDistinct, List children) { + return withChildren(children); } @Override - public BitmapUnionInt withAggregateParam(AggregateParam aggregateParam) { - return new BitmapUnionInt(aggregateParam, child()); + public BitmapUnionInt withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new BitmapUnionInt(children.get(0)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java index 2a3ca1947e1758..cc05cec814a8ca 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java @@ -18,19 +18,18 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; import org.apache.doris.catalog.FunctionSignature; -import org.apache.doris.nereids.exceptions.UnboundException; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; import org.apache.doris.nereids.trees.expressions.functions.CustomSignature; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.util.ExpressionUtils; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.List; -import java.util.stream.Collectors; /** count agg function. 
*/ public class Count extends AggregateFunction implements AlwaysNotNullable, CustomSignature { @@ -42,19 +41,17 @@ public Count() { this.isStar = true; } - public Count(AggregateParam aggregateParam) { - super("count", aggregateParam); - this.isStar = true; - } - public Count(Expression child) { super("count", child); this.isStar = false; } - public Count(AggregateParam aggregateParam, Expression child) { - super("count", aggregateParam, child); + public Count(boolean isDistinct, Expression arg0, Expression... varArgs) { + super("count", isDistinct, ExpressionUtils.mergeArguments(arg0, varArgs)); this.isStar = false; + if (!isDistinct && arity() > 1) { + throw new AnalysisException("COUNT must have DISTINCT for multiple arguments" + this.toSql()); + } } public boolean isStar() { @@ -62,46 +59,49 @@ public boolean isStar() { } @Override - public FunctionSignature customSignature(List argumentTypes, List arguments) { - return FunctionSignature.of(BigIntType.INSTANCE, (List) argumentTypes); + public FunctionSignature customSignature() { + return FunctionSignature.of(BigIntType.INSTANCE, (List) getArgumentsTypes()); } @Override - protected List intermediateTypes(List argumentTypes, List arguments) { + protected List intermediateTypes() { return ImmutableList.of(BigIntType.INSTANCE); } @Override - public Count withChildren(List children) { - Preconditions.checkArgument(children.size() == 0 || children.size() == 1); + public Count withDistinctAndChildren(boolean isDistinct, List children) { if (children.size() == 0) { - return this; + if (isDistinct) { + throw new AnalysisException("Can not count distinct empty arguments"); + } + return new Count(); + } else if (children.size() == 1) { + return new Count(isDistinct, children.get(0)); + } else { + return new Count(isDistinct, children.get(0), + children.subList(1, children.size()).toArray(new Expression[0])); } - return new Count(getAggregateParam(), children.get(0)); } @Override - public Count withAggregateParam(AggregateParam aggregateParam) { - if (arity() == 0) { - return new Count(aggregateParam); + public Count withChildren(List children) { + if (children.size() == 0) { + return new Count(); + } + if (children.size() == 1) { + return new Count(isDistinct, children.get(0)); } else { - return new Count(aggregateParam, child(0)); + return new Count(isDistinct, children.get(0), + children.subList(1, children.size()).toArray(new Expression[0])); } } @Override - public String toSql() throws UnboundException { + public String toSql() { if (isStar) { return "count(*)"; } - String args = children() - .stream() - .map(Expression::toSql) - .collect(Collectors.joining(", ")); - if (isDistinct()) { - return "count(distinct " + args + ")"; - } - return "count(" + args + ")"; + return super.toSql(); } @Override @@ -109,14 +109,7 @@ public String toString() { if (isStar) { return "count(*)"; } - String args = children() - .stream() - .map(Expression::toString) - .collect(Collectors.joining(", ")); - if (isDistinct()) { - return "count(distinct " + args + ")"; - } - return "count(" + args + ")"; + return super.toString(); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupBitmapXor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupBitmapXor.java index 201e73d6e263b6..5b48359aea5e79 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupBitmapXor.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupBitmapXor.java @@ -41,24 +41,20 @@ public GroupBitmapXor(Expression arg0) { super("group_bitmap_xor", arg0); } - public GroupBitmapXor(AggregateParam aggregateParam, Expression arg0) { - super("group_bitmap_xor", aggregateParam, arg0); - } - @Override - protected List intermediateTypes(List argumentTypes, List arguments) { + protected List intermediateTypes() { return ImmutableList.of(BitmapType.INSTANCE); } @Override - public GroupBitmapXor withChildren(List children) { - Preconditions.checkArgument(children.size() == 1); - return new GroupBitmapXor(getAggregateParam(), children.get(0)); + public GroupBitmapXor withDistinctAndChildren(boolean isDistinct, List children) { + return withChildren(children); } @Override - public GroupBitmapXor withAggregateParam(AggregateParam aggregateParam) { - return new GroupBitmapXor(aggregateParam, child()); + public GroupBitmapXor withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new GroupBitmapXor(children.get(0)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnion.java index 5802872b19302b..415d4f9e550a52 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnion.java @@ -41,24 +41,20 @@ public HllUnion(Expression arg0) { super("hll_union", arg0); } - public HllUnion(AggregateParam aggregateParam, Expression arg0) { - super("hll_union", aggregateParam, arg0); - } - @Override - protected List intermediateTypes(List argumentTypes, List arguments) { + protected List intermediateTypes() { return ImmutableList.of(HllType.INSTANCE); } @Override - public HllUnion withChildren(List children) { - Preconditions.checkArgument(children.size() == 1); - return new HllUnion(getAggregateParam(), children.get(0)); + public HllUnion withDistinctAndChildren(boolean isDistinct, List children) { + return withChildren(children); } @Override - public HllUnion withAggregateParam(AggregateParam aggregateParam) { - return new HllUnion(aggregateParam, child()); + public HllUnion withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new HllUnion(children.get(0)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnionAgg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnionAgg.java index eade606a425ff8..16c8519c3ec204 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnionAgg.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnionAgg.java @@ -42,24 +42,20 @@ public HllUnionAgg(Expression arg0) { super("hll_union_agg", arg0); } - public HllUnionAgg(AggregateParam aggregateParam, Expression arg0) { - super("hll_union_agg", aggregateParam, arg0); - } - @Override - protected List intermediateTypes(List argumentTypes, List arguments) { + protected List intermediateTypes() { return ImmutableList.of(HllType.INSTANCE); } @Override - public HllUnionAgg withChildren(List children) { - Preconditions.checkArgument(children.size() == 1); - return new HllUnionAgg(getAggregateParam(), children.get(0)); + public HllUnionAgg 
withDistinctAndChildren(boolean isDistinct, List children) { + return withChildren(children); } @Override - public HllUnionAgg withAggregateParam(AggregateParam aggregateParam) { - return new HllUnionAgg(aggregateParam, child()); + public HllUnionAgg withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new HllUnionAgg(children.get(0)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java index 5c81e3ca404d0d..25764d8b72a176 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.types.DataType; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import java.util.List; @@ -35,29 +36,31 @@ public Max(Expression child) { super("max", child); } - public Max(AggregateParam aggregateParam, Expression child) { - super("max", aggregateParam, child); + public Max(boolean isDistinct, Expression arg) { + super("max", false, arg); } @Override - public FunctionSignature customSignature(List argumentTypes, List arguments) { - return FunctionSignature.ret(argumentTypes.get(0)).args(argumentTypes.get(0)); + public FunctionSignature customSignature() { + DataType dataType = getArgument(0).getDataType(); + return FunctionSignature.ret(dataType).args(dataType); } @Override - protected List intermediateTypes(List argumentTypes, List arguments) { - return argumentTypes; + protected List intermediateTypes() { + return ImmutableList.of(getDataType()); } @Override - public Max withChildren(List children) { + public Max withDistinctAndChildren(boolean isDistinct, List children) { Preconditions.checkArgument(children.size() == 1); - return new Max(getAggregateParam(), children.get(0)); + return new Max(isDistinct, children.get(0)); } @Override - public Max withAggregateParam(AggregateParam aggregateParam) { - return new Max(aggregateParam, child()); + public Max withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new Max(children.get(0)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java index 8a5d62fe548160..ea5e05584caaea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.types.DataType; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import java.util.List; @@ -36,29 +37,31 @@ public Min(Expression child) { super("min", child); } - public Min(AggregateParam aggregateParam, Expression child) { - super("min", aggregateParam, child); + public Min(boolean isDistinct, Expression arg) { + super("min", false, arg); } @Override - public FunctionSignature customSignature(List argumentTypes, List arguments) { - return FunctionSignature.ret(argumentTypes.get(0)).args(argumentTypes.get(0)); + public FunctionSignature customSignature() { + DataType dataType = getArgument(0).getDataType(); + return FunctionSignature.ret(dataType).args(dataType); } @Override - 
protected List intermediateTypes(List argumentTypes, List arguments) { - return argumentTypes; + protected List intermediateTypes() { + return ImmutableList.of(getDataType()); } @Override - public Min withChildren(List children) { + public Min withDistinctAndChildren(boolean isDistinct, List children) { Preconditions.checkArgument(children.size() == 1); - return new Min(getAggregateParam(), children.get(0)); + return new Min(isDistinct, children.get(0)); } @Override - public Min withAggregateParam(AggregateParam aggregateParam) { - return new Min(aggregateParam, child()); + public Min withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new Min(children.get(0)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java new file mode 100644 index 00000000000000..9ca2a203ffe061 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** MultiDistinctCount */ +public class MultiDistinctCount extends AggregateFunction + implements AlwaysNotNullable, ExplicitlyCastableSignature { + public MultiDistinctCount(Expression arg0, Expression... varArgs) { + super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs)); + } + + public MultiDistinctCount(boolean isDistinct, Expression arg0, Expression... 
varArgs) { + super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs)); + } + + @Override + public List getSignatures() { + List argumentsTypes = getArgumentsTypes(); + return ImmutableList.of(FunctionSignature.of(BigIntType.INSTANCE, (List) argumentsTypes)); + } + + @Override + public MultiDistinctCount withChildren(List children) { + Preconditions.checkArgument(children.size() > 0); + if (children.size() > 1) { + return new MultiDistinctCount(children.get(0), + children.subList(1, children.size()).toArray(new Expression[0])); + } else { + return new MultiDistinctCount(children.get(0)); + } + } + + @Override + public MultiDistinctCount withDistinctAndChildren(boolean isDistinct, List children) { + if (children.size() > 1) { + return new MultiDistinctCount(isDistinct, children.get(0), + children.subList(1, children.size()).toArray(new Expression[0])); + } else { + return new MultiDistinctCount(isDistinct, children.get(0)); + } + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitMultiDistinctCount(this, context); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java new file mode 100644 index 00000000000000..81927a2d20d48c --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
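For context, a minimal usage sketch (not part of the patch) of how a rewrite rule could swap a multi-argument distinct count for the new MultiDistinctCount; the rule itself, and the use of isDistinct()/arity()/getArguments(), follow the AggregateFunction API visible elsewhere in this diff and are illustrative assumptions only:

    // Hypothetical rewrite: count(DISTINCT a, b, ...) -> multi_distinct_count(a, b, ...)
    private Expression rewriteDistinctCount(Count count) {
        if (!count.isDistinct() || count.arity() <= 1) {
            return count;  // nothing to do for non-distinct or single-argument counts
        }
        List<Expression> args = count.getArguments();
        return new MultiDistinctCount(args.get(0),
                args.subList(1, args.size()).toArray(new Expression[0]));
    }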
+ +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; + +import com.google.common.base.Preconditions; + +import java.util.List; + +/** MultiDistinctSum */ +public class MultiDistinctSum extends AggregateFunction + implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature { + public MultiDistinctSum(Expression arg0) { + super("multi_distinct_sum", true, arg0); + } + + public MultiDistinctSum(boolean isDistinct, Expression arg0) { + super("multi_distinct_sum", true, arg0); + } + + @Override + public List getSignatures() { + return new Sum(getArgument(0)).getSignatures(); + } + + @Override + public MultiDistinctSum withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new MultiDistinctSum(children.get(0)); + } + + @Override + public MultiDistinctSum withDistinctAndChildren(boolean isDistinct, List children) { + Preconditions.checkArgument(children.size() == 1); + return new MultiDistinctSum(isDistinct, children.get(0)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitMultiDistinctSum(this, context); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java index 25385ccb0df90e..853f73f4693b1f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java @@ -43,30 +43,32 @@ public Sum(Expression child) { super("sum", child); } - public Sum(AggregateParam aggregateParam, Expression child) { - super("sum", aggregateParam, child); + public Sum(boolean isDistinct, Expression child) { + super("sum", isDistinct, child); } @Override - public FunctionSignature customSignature(List argumentTypes, List arguments) { - DataType implicitCastType = implicitCast(argumentTypes.get(0)); - return FunctionSignature.ret(implicitCastType).args(NumericType.INSTANCE); + public FunctionSignature customSignature() { + DataType originDataType = getArgument(0).getDataType(); + DataType implicitCastType = implicitCast(originDataType); + return FunctionSignature.ret(implicitCastType).args(originDataType); } @Override - protected List intermediateTypes(List argumentTypes, List arguments) { - return ImmutableList.of(getFinalType()); + protected List intermediateTypes() { + return ImmutableList.of(getDataType()); } @Override - public Sum withChildren(List children) { + public Sum withDistinctAndChildren(boolean isDistinct, List children) { Preconditions.checkArgument(children.size() == 1); - return new Sum(getAggregateParam(), children.get(0)); + return new Sum(isDistinct, children.get(0)); } @Override - public Sum withAggregateParam(AggregateParam aggregateParam) { - return new Sum(aggregateParam, child()); + public Sum withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new Sum(isDistinct, children.get(0)); } @Override diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Grouping.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Grouping.java index 57c799dbfd8238..73d12005af09ff 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Grouping.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Grouping.java @@ -24,7 +24,6 @@ import org.apache.doris.nereids.trees.plans.algebra.Repeat.GroupingSetShape; import org.apache.doris.nereids.trees.plans.algebra.Repeat.GroupingSetShapes; import org.apache.doris.nereids.types.BigIntType; -import org.apache.doris.nereids.types.DataType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -41,9 +40,9 @@ public Grouping(Expression child) { } @Override - public FunctionSignature customSignature(List argumentTypes, List arguments) { + public FunctionSignature customSignature() { // any argument type - return FunctionSignature.ret(BigIntType.INSTANCE).args(argumentTypes.get(0)); + return FunctionSignature.ret(BigIntType.INSTANCE).args(getArgument(0).getDataType()); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/GroupingId.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/GroupingId.java index 218afa085477e7..ee5a54187b5d30 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/GroupingId.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/GroupingId.java @@ -23,7 +23,6 @@ import org.apache.doris.nereids.trees.plans.algebra.Repeat.GroupingSetShape; import org.apache.doris.nereids.trees.plans.algebra.Repeat.GroupingSetShapes; import org.apache.doris.nereids.types.BigIntType; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.util.BitUtils; import org.apache.doris.nereids.util.ExpressionUtils; @@ -47,9 +46,9 @@ private GroupingId(List children) { } @Override - public FunctionSignature customSignature(List argumentTypes, List arguments) { + public FunctionSignature customSignature() { // any arguments type - return FunctionSignature.of(BigIntType.INSTANCE, (List) argumentTypes); + return FunctionSignature.of(BigIntType.INSTANCE, (List) getArgumentsTypes()); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java index 148b7b9e9b0313..7ffc68ec141fa0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java @@ -104,7 +104,7 @@ private DataType getWiderType(List argumentsTypes) { } @Override - protected FunctionSignature computeSignature(FunctionSignature signature, List arguments) { + protected FunctionSignature computeSignature(FunctionSignature signature) { DataType widerType = getWiderType(signature.argumentsTypes); List newArgumentsTypes = new ImmutableList.Builder() .add(signature.argumentsTypes.get(0)) @@ -113,7 +113,7 @@ protected FunctionSignature computeSignature(FunctionSignature signature, List arguments) { + protected FunctionSignature computeSignature(FunctionSignature signature) { /* * The return type of str_to_date depends on whether the time 
part is included in the format. * If included, it is datetime, otherwise it is date. @@ -83,8 +83,8 @@ protected FunctionSignature computeSignature(FunctionSignature signature, List arguments) { - Optional length = arguments.size() == 3 - ? Optional.of(arguments.get(2)) : Optional.empty(); + protected FunctionSignature computeSignature(FunctionSignature signature) { + Optional length = arity() == 3 + ? Optional.of(getArgument(2)) + : Optional.empty(); DataType returnType = VarcharType.SYSTEM_DEFAULT; if (length.isPresent() && length.get() instanceof IntegerLiteral) { returnType = VarcharType.createVarcharType(((IntegerLiteral) length.get()).getValue()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/Numbers.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/Numbers.java index 6831b3f365ee76..fc94a3dc091e69 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/Numbers.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/Numbers.java @@ -23,12 +23,12 @@ import org.apache.doris.common.Id; import org.apache.doris.common.NereidsException; import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.TVFProperties; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.statistics.ColumnStatistic; import org.apache.doris.statistics.StatsDeriveResult; import org.apache.doris.tablefunction.NumbersTableValuedFunction; @@ -47,8 +47,8 @@ public Numbers(TVFProperties properties) { } @Override - public FunctionSignature customSignature(List argumentTypes, List arguments) { - return FunctionSignature.of(BigIntType.INSTANCE, (List) argumentTypes); + public FunctionSignature customSignature() { + return FunctionSignature.of(BigIntType.INSTANCE, (List) getArgumentsTypes()); } @Override @@ -84,6 +84,15 @@ public R accept(ExpressionVisitor visitor, C context) { return visitor.visitNumbers(this, context); } + @Override + public PhysicalProperties getPhysicalProperties() { + String backendNum = getTVFProperties().getMap().getOrDefault(NumbersTableValuedFunction.BACKEND_NUM, "1"); + if (backendNum.trim().equals("1")) { + return PhysicalProperties.GATHER; + } + return PhysicalProperties.ANY; + } + @Override public Numbers withChildren(List children) { Preconditions.checkArgument(children().size() == 1 diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/TableValuedFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/TableValuedFunction.java index e570abe49db8d6..a1f83467c4290e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/TableValuedFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/TableValuedFunction.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.FunctionGenTable; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.exceptions.UnboundException; +import org.apache.doris.nereids.properties.PhysicalProperties; import 
org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.TVFProperties; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; @@ -89,6 +90,10 @@ public boolean nullable() { throw new UnboundException("TableValuedFunction can not compute nullable"); } + public PhysicalProperties getPhysicalProperties() { + return PhysicalProperties.ANY; + } + @Override public DataType getDataType() throws UnboundException { throw new UnboundException("TableValuedFunction can not compute data type"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/IntervalLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Interval.java similarity index 85% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/IntervalLiteral.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Interval.java index 79d0471b53b148..101004abafb198 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/IntervalLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Interval.java @@ -17,26 +17,26 @@ package org.apache.doris.nereids.trees.expressions.literal; -import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DateType; /** * Interval for timestamp calculation. */ -public class IntervalLiteral extends Expression implements AlwaysNotNullable { +public class Interval extends Expression implements AlwaysNotNullable { private final Expression value; private final TimeUnit timeUnit; - public IntervalLiteral(Expression value, String desc) { + public Interval(Expression value, String desc) { this.value = value; this.timeUnit = TimeUnit.valueOf(desc.toUpperCase()); } @Override - public DataType getDataType() throws UnboundException { + public DataType getDataType() { return DateType.INSTANCE; } @@ -48,6 +48,11 @@ public TimeUnit timeUnit() { return timeUnit; } + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitInterval(this, context); + } + /** * Supported time unit. 
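As a usage illustration only (not in the patch): building the renamed Interval for a term such as "INTERVAL 3 DAY"; the nested TimeUnit enum and an int-valued IntegerLiteral constructor are assumed from the surrounding expression classes:

    // "INTERVAL 3 DAY" -> value = 3, timeUnit = DAY (via TimeUnit.valueOf("DAY"))
    Interval threeDays = new Interval(new IntegerLiteral(3), "day");
    assert threeDays.timeUnit() == Interval.TimeUnit.DAY;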
*/ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/NullLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/NullLiteral.java index 561669098a4aad..ba70800e819975 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/NullLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/NullLiteral.java @@ -19,8 +19,11 @@ import org.apache.doris.analysis.LiteralExpr; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.NullType; +import java.util.Objects; + /** * Represents Null literal */ @@ -28,10 +31,16 @@ public class NullLiteral extends Literal { public static final NullLiteral INSTANCE = new NullLiteral(); + private DataType dataType; + public NullLiteral() { super(NullType.INSTANCE); } + public NullLiteral(DataType dataType) { + super(dataType); + } + @Override public Object getValue() { return null; @@ -56,4 +65,24 @@ public LiteralExpr toLegacyLiteral() { public double getDouble() { return 0; } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + NullLiteral that = (NullLiteral) o; + return Objects.equals(dataType, that.dataType); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), dataType); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java index 32257be5b816f8..b2e0dd83cdb1e8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java @@ -22,6 +22,8 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount; +import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; /** AggregateFunctionVisitor. 
*/ @@ -44,6 +46,14 @@ default R visitMin(Min min, C context) { return visitAggregateFunction(min, context); } + default R visitMultiDistinctCount(MultiDistinctCount multiDistinctCount, C context) { + return visitAggregateFunction(multiDistinctCount, context); + } + + default R visitMultiDistinctSum(MultiDistinctSum multiDistinctSum, C context) { + return visitAggregateFunction(multiDistinctSum, context); + } + default R visitSum(Sum sum, C context) { return visitAggregateFunction(sum, context); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java index 62e990c49847db..d3324d2bad953c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java @@ -21,7 +21,9 @@ import org.apache.doris.nereids.analyzer.UnboundFunction; import org.apache.doris.nereids.analyzer.UnboundSlot; import org.apache.doris.nereids.analyzer.UnboundStar; +import org.apache.doris.nereids.rules.analysis.BindSlotReference.BoundStar; import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement; @@ -62,6 +64,7 @@ import org.apache.doris.nereids.trees.expressions.StringRegexPredicate; import org.apache.doris.nereids.trees.expressions.SubqueryExpr; import org.apache.doris.nereids.trees.expressions.Subtract; +import org.apache.doris.nereids.trees.expressions.TVFProperties; import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; import org.apache.doris.nereids.trees.expressions.UnaryArithmetic; import org.apache.doris.nereids.trees.expressions.UnaryOperator; @@ -81,6 +84,7 @@ import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Interval; import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; @@ -116,6 +120,10 @@ public R visitBoundFunction(BoundFunction boundFunction, C context) { return visit(boundFunction, context); } + public R visitAggregateExpression(AggregateExpression aggregateExpression, C context) { + return visit(aggregateExpression, context); + } + public R visitAlias(Alias alias, C context) { return visitNamedExpression(alias, context); } @@ -364,6 +372,18 @@ public R visitVirtualReference(VirtualSlotReference virtualSlotReference, C cont return visit(virtualSlotReference, context); } + public R visitTVFProperties(TVFProperties tvfProperties, C context) { + return visit(tvfProperties, context); + } + + public R visitInterval(Interval interval, C context) { + return visit(interval, context); + } + + public R visitBoundStar(BoundStar boundStar, C context) { + return visit(boundStar, context); + } + /* ******************************************************************************************** * Unbound expressions * 
********************************************************************************************/ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AggMode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AggMode.java new file mode 100644 index 00000000000000..6fd1a477a5ed6a --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AggMode.java @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.plans; + +/** AggregateMode */ +public enum AggMode { + INPUT_TO_BUFFER(true, false, false), + INPUT_TO_RESULT(false, false, true), + BUFFER_TO_BUFFER(true, true, false), + BUFFER_TO_RESULT(false, true, true); + + public final boolean productAggregateBuffer; + public final boolean consumeAggregateBuffer; + + public final boolean isFinalPhase; + + AggMode(boolean productAggregateBuffer, boolean consumeAggregateBuffer, boolean isFinalPhase) { + this.productAggregateBuffer = productAggregateBuffer; + this.consumeAggregateBuffer = consumeAggregateBuffer; + this.isFinalPhase = isFinalPhase; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PushDownAggOperator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PushDownAggOperator.java deleted file mode 100644 index a73839c8efcfa8..00000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PushDownAggOperator.java +++ /dev/null @@ -1,70 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.trees.plans; - -import org.apache.doris.thrift.TPushAggOp; - -/** - * use for push down agg without group by exprs to olap scan. 
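A small illustration (not part of the patch) of how the new AggMode flags line up with the usual local/global phases; the variable names are descriptive only:

    // Two-phase plan: the first aggregate produces the buffer, the second consumes it.
    AggMode localMode  = AggMode.INPUT_TO_BUFFER;   // productAggregateBuffer = true,  isFinalPhase = false
    AggMode globalMode = AggMode.BUFFER_TO_RESULT;  // consumeAggregateBuffer = true,  isFinalPhase = true
    // One-phase plan: read raw input and emit the final result directly.
    AggMode onePhase   = AggMode.INPUT_TO_RESULT;   // isFinalPhase = true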
- */ -public class PushDownAggOperator { - public static PushDownAggOperator NONE = new PushDownAggOperator(TPushAggOp.NONE); - public static PushDownAggOperator MIN_MAX = new PushDownAggOperator(TPushAggOp.MINMAX); - public static PushDownAggOperator COUNT = new PushDownAggOperator(TPushAggOp.COUNT); - public static PushDownAggOperator MIX = new PushDownAggOperator(TPushAggOp.MIX); - - private final TPushAggOp thriftOperator; - - private PushDownAggOperator(TPushAggOp thriftOperator) { - this.thriftOperator = thriftOperator; - } - - /** - * merge operator. - */ - public PushDownAggOperator merge(String functionName) { - PushDownAggOperator newOne; - if ("COUNT".equalsIgnoreCase(functionName)) { - newOne = COUNT; - } else { - newOne = MIN_MAX; - } - if (this == NONE || this == newOne) { - return newOne; - } else { - return MIX; - } - } - - public TPushAggOp toThrift() { - return thriftOperator; - } - - public boolean containsMinMax() { - return this == MIN_MAX || this == MIX; - } - - public boolean containsCount() { - return this == COUNT || this == MIX; - } - - @Override - public String toString() { - return thriftOperator.toString(); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java index 362274fc5285c3..7163d86106b23e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java @@ -19,10 +19,15 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.UnaryPlan; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableSet; import java.util.List; +import java.util.Set; /** * Common interface for logical/physical Aggregate. @@ -37,4 +42,15 @@ public interface Aggregate extends UnaryPlan withChildren(List children); + + default Set getAggregateFunctions() { + return ExpressionUtils.collect(getOutputExpressions(), AggregateFunction.class::isInstance); + } + + default Set getDistinctArguments() { + return getAggregateFunctions().stream() + .filter(AggregateFunction::isDistinct) + .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + .collect(ImmutableSet.toImmutableSet()); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/OlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/OlapScan.java new file mode 100644 index 00000000000000..d0074c4119e08f --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/OlapScan.java @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.plans.algebra; + +import org.apache.doris.catalog.OlapTable; + +import java.util.List; + +/** OlapScan */ +public interface OlapScan extends Scan { + OlapTable getTable(); + + long getSelectedIndexId(); + + List getSelectedPartitionIds(); + + List getSelectedTabletIds(); + + /** getScanTabletNum */ + default int getScanTabletNum() { + List selectedTabletIds = getSelectedTabletIds(); + if (selectedTabletIds.size() > 0) { + return selectedTabletIds.size(); + } + + OlapTable olapTable = getTable(); + Integer selectTabletNumInPartitions = getSelectedPartitionIds().stream() + .map(partitionId -> olapTable.getPartition(partitionId)) + .map(partition -> partition.getDistributionInfo().getBucketNum()) + .reduce((b1, b2) -> b1 + b2) + .orElse(0); + if (selectTabletNumInPartitions > 0) { + return selectTabletNumInPartitions; + } + + // all partition's tablet + return olapTable.getAllPartitions() + .stream() + .map(partition -> partition.getDistributionInfo().getBucketNum()) + .reduce((b1, b2) -> b1 + b2) + .orElse(0); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java index ee2b6aacf357db..5ee0370134e6aa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java @@ -17,13 +17,19 @@ package org.apache.doris.nereids.trees.plans.algebra; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; @@ -81,6 +87,41 @@ default List mergeProjections(Project childProject) { .collect(Collectors.toList()); } + /** + * find projects, if not found the slot, then throw AnalysisException + */ + static List findProject( + Collection slotReferences, + List projects) throws AnalysisException { + Map exprIdToProject = projects.stream() + .collect(ImmutableMap.toImmutableMap(p -> p.getExprId(), p -> p)); + + return slotReferences.stream() + .map(slot -> { + ExprId exprId = slot.getExprId(); + NamedExpression project = exprIdToProject.get(exprId); + if (project == null) { + throw new AnalysisException("ExprId " + slot.getExprId() + " no exists in " + projects); + } + return project; + }) + .collect(ImmutableList.toImmutableList()); + } + + /** + * findUsedProject. 
if not found the slot, then skip it + */ + static List filterUsedOutputs( + Collection slotReferences, List childOutput) { + Map exprIdToChildOutput = childOutput.stream() + .collect(ImmutableMap.toImmutableMap(p -> p.getExprId(), p -> p)); + + return slotReferences.stream() + .map(slot -> exprIdToChildOutput.get(slot.getExprId())) + .filter(project -> project != null) + .collect(ImmutableList.toImmutableList()); + } + /** * replace alias */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index e94d32dfbac80c..15571453921b35 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -22,8 +22,6 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.plans.AggPhase; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.algebra.Aggregate; @@ -36,7 +34,6 @@ import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.Set; /** * Logical Aggregate plan. @@ -56,23 +53,11 @@ public class LogicalAggregate extends LogicalUnary implements Aggregate { - private final boolean disassembled; private final boolean normalized; - private final AggPhase aggPhase; - private final ImmutableList groupByExpressions; - private final ImmutableList outputExpressions; - // TODO: we should decide partition expression according to cost. - private final Optional> partitionExpressions; + private final List groupByExpressions; + private final List outputExpressions; - // use for scenes containing distinct agg - // 1. If there is LOCAL only, LOCAL is the final phase - // 2. If there are LOCAL and GLOBAL phases, global is the final phase - // 3. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase - // 4. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases, - // DISTINCT_GLOBAL is the final phase - private final boolean isFinalPhase; - - // When there are goruping sets/rollup/cube, LogicalAgg is generated by LogicalRepeat. + // When there are grouping sets/rollup/cube, LogicalAgg is generated by LogicalRepeat. 
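For reference, a hedged sketch (not in the patch) of the difference between the two new Project helpers; usedSlots, projects and childOutputs are placeholder variables:

    // findProject: every used slot must resolve to a project item, otherwise AnalysisException is thrown.
    List<NamedExpression> required = Project.findProject(usedSlots, projects);
    // filterUsedOutputs: slots with no matching ExprId are silently skipped instead of failing.
    List<NamedExpression> present = Project.filterUsedOutputs(usedSlots, childOutputs);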
private final Optional sourceRepeat; /** @@ -82,8 +67,8 @@ public LogicalAggregate( List groupByExpressions, List outputExpressions, CHILD_TYPE child) { - this(groupByExpressions, outputExpressions, Optional.empty(), false, - false, true, AggPhase.LOCAL, Optional.empty(), child); + this(groupByExpressions, outputExpressions, + false, Optional.empty(), child); } /** @@ -95,35 +80,17 @@ public LogicalAggregate( List outputExpressions, Optional sourceRepeat, CHILD_TYPE child) { - this(groupByExpressions, outputExpressions, Optional.empty(), false, - false, true, AggPhase.LOCAL, sourceRepeat, child); + this(groupByExpressions, outputExpressions, false, sourceRepeat, child); } public LogicalAggregate( List groupByExpressions, List outputExpressions, - boolean disassembled, boolean normalized, - boolean isFinalPhase, - AggPhase aggPhase, Optional sourceRepeat, CHILD_TYPE child) { - this(groupByExpressions, outputExpressions, Optional.empty(), disassembled, normalized, - isFinalPhase, aggPhase, sourceRepeat, Optional.empty(), Optional.empty(), child); - } - - public LogicalAggregate( - List groupByExpressions, - List outputExpressions, - Optional> partitionExpressions, - boolean disassembled, - boolean normalized, - boolean isFinalPhase, - AggPhase aggPhase, - Optional sourceRepeat, - CHILD_TYPE child) { - this(groupByExpressions, outputExpressions, partitionExpressions, disassembled, normalized, isFinalPhase, - aggPhase, sourceRepeat, Optional.empty(), Optional.empty(), child); + this(groupByExpressions, outputExpressions, normalized, sourceRepeat, + Optional.empty(), Optional.empty(), child); } /** @@ -132,11 +99,7 @@ public LogicalAggregate( public LogicalAggregate( List groupByExpressions, List outputExpressions, - Optional> partitionExpressions, - boolean disassembled, boolean normalized, - boolean isFinalPhase, - AggPhase aggPhase, Optional sourceRepeat, Optional groupExpression, Optional logicalProperties, @@ -144,11 +107,7 @@ public LogicalAggregate( super(PlanType.LOGICAL_AGGREGATE, groupExpression, logicalProperties, child); this.groupByExpressions = ImmutableList.copyOf(groupByExpressions); this.outputExpressions = ImmutableList.copyOf(outputExpressions); - this.partitionExpressions = partitionExpressions.map(ImmutableList::copyOf); - this.disassembled = disassembled; this.normalized = normalized; - this.isFinalPhase = isFinalPhase; - this.aggPhase = aggPhase; this.sourceRepeat = Objects.requireNonNull(sourceRepeat, "sourceRepeat cannot be null"); } @@ -160,14 +119,6 @@ public List getOutputExpressions() { return outputExpressions; } - public List getPartitionExpressions() { - return partitionExpressions.orElse(groupByExpressions); - } - - public AggPhase getAggPhase() { - return aggPhase; - } - public Optional getSourceRepeat() { return sourceRepeat; } @@ -179,9 +130,8 @@ public boolean hasRepeat() { @Override public String toString() { return Utils.toSqlString("LogicalAggregate", - "phase", aggPhase, - "outputExpr", outputExpressions, "groupByExpr", groupByExpressions, + "outputExpr", outputExpressions, "hasRepeat", sourceRepeat.isPresent() ); } @@ -206,49 +156,10 @@ public List getExpressions() { .build(); } - public boolean isLocal() { - return aggPhase.isLocal(); - } - - public boolean isDisassembled() { - return disassembled; - } - - /** - * Check if disassembling is possible - * @return true means that disassembling is possible - */ - public boolean needDistinctDisassemble() { - // It is sufficient to split an aggregate function with a groupBy expression into two stages(Local & 
Global), - // no need for four stages(Local & Global & Distinct Local & Distinct Global) - if (!isFinalPhase || aggPhase != AggPhase.LOCAL) { - return false; - } - int distinctFunctionCount = 0; - for (NamedExpression originOutputExpr : outputExpressions) { - Set aggregateFunctions - = originOutputExpr.collect(AggregateFunction.class::isInstance); - for (AggregateFunction aggregateFunction : aggregateFunctions) { - if (aggregateFunction.isDistinct()) { - distinctFunctionCount++; - if (distinctFunctionCount > 1) { - return false; - } - } - } - } - // Only one distinct function is supported - return distinctFunctionCount == 1; - } - public boolean isNormalized() { return normalized; } - public boolean isFinalPhase() { - return isFinalPhase; - } - /** * Determine the equality with another plan */ @@ -262,51 +173,50 @@ public boolean equals(Object o) { LogicalAggregate that = (LogicalAggregate) o; return Objects.equals(groupByExpressions, that.groupByExpressions) && Objects.equals(outputExpressions, that.outputExpressions) - && Objects.equals(partitionExpressions, that.partitionExpressions) - && aggPhase == that.aggPhase - && disassembled == that.disassembled && normalized == that.normalized - && isFinalPhase == that.isFinalPhase && Objects.equals(sourceRepeat, that.sourceRepeat); } @Override public int hashCode() { - return Objects.hash(groupByExpressions, outputExpressions, partitionExpressions, - aggPhase, normalized, disassembled, isFinalPhase, sourceRepeat); + return Objects.hash(groupByExpressions, outputExpressions, normalized, sourceRepeat); } @Override public LogicalAggregate withChildren(List children) { Preconditions.checkArgument(children.size() == 1); - return new LogicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions.map(List.class::cast), - disassembled, normalized, isFinalPhase, aggPhase, sourceRepeat, children.get(0)); + return new LogicalAggregate<>(groupByExpressions, outputExpressions, + normalized, sourceRepeat, children.get(0)); } @Override public LogicalAggregate withGroupExpression(Optional groupExpression) { - return new LogicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions.map(List.class::cast), - disassembled, normalized, isFinalPhase, aggPhase, sourceRepeat, - groupExpression, Optional.of(getLogicalProperties()), children.get(0)); + return new LogicalAggregate<>(groupByExpressions, outputExpressions, + normalized, sourceRepeat, groupExpression, Optional.of(getLogicalProperties()), children.get(0)); } @Override public LogicalAggregate withLogicalProperties(Optional logicalProperties) { - return new LogicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions.map(List.class::cast), - disassembled, normalized, isFinalPhase, aggPhase, sourceRepeat, + return new LogicalAggregate<>(groupByExpressions, outputExpressions, + normalized, sourceRepeat, Optional.empty(), logicalProperties, children.get(0)); } public LogicalAggregate withGroupByAndOutput(List groupByExprList, List outputExpressionList) { - return new LogicalAggregate<>(groupByExprList, outputExpressionList, partitionExpressions.map(List.class::cast), - disassembled, normalized, isFinalPhase, aggPhase, sourceRepeat, child()); + return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, sourceRepeat, child()); } @Override public LogicalAggregate withAggOutput(List newOutput) { - return new LogicalAggregate<>(groupByExpressions, newOutput, partitionExpressions.map(List.class::cast), - disassembled, normalized, isFinalPhase, 
aggPhase, sourceRepeat, Optional.empty(), - Optional.empty(), child()); + return new LogicalAggregate<>(groupByExpressions, newOutput, normalized, + sourceRepeat, Optional.empty(), Optional.empty(), child()); + } + + public LogicalAggregate withNormalized(List normalizedGroupBy, + List normalizedOutput, Plan normalizedChild) { + return new LogicalAggregate<>(normalizedGroupBy, normalizedOutput, true, + sourceRepeat, Optional.empty(), + Optional.empty(), normalizedChild); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCheckPolicy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCheckPolicy.java index 353f9c45851e12..ffefee32428d5b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCheckPolicy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCheckPolicy.java @@ -74,9 +74,7 @@ public List computeOutput() { @Override public String toString() { - return Utils.toSqlString("LogicalCheckPolicy", - "child", child() - ); + return Utils.toSqlString("LogicalCheckPolicy"); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalLimit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalLimit.java index 94ea4616e490be..48f94727d7870b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalLimit.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalLimit.java @@ -43,7 +43,9 @@ * limit: 10 * offset 100 */ -public class LogicalLimit extends LogicalUnary implements Limit { +public class LogicalLimit + extends LogicalUnary + implements Limit { private final long limit; private final long offset; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java index a6bb23029cb095..4b19cd2d634355 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java @@ -27,9 +27,9 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.PreAggStatus; -import org.apache.doris.nereids.trees.plans.PushDownAggOperator; import org.apache.doris.nereids.trees.plans.RelationId; import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation; +import org.apache.doris.nereids.trees.plans.algebra.OlapScan; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.Utils; @@ -45,7 +45,7 @@ /** * Logical OlapScan. 
*/ -public class LogicalOlapScan extends LogicalRelation implements CatalogRelation { +public class LogicalOlapScan extends LogicalRelation implements CatalogRelation, OlapScan { private final long selectedIndexId; private final List selectedTabletIds; @@ -57,9 +57,6 @@ public class LogicalOlapScan extends LogicalRelation implements CatalogRelation private final PreAggStatus preAggStatus; - private final boolean aggPushed; - private final PushDownAggOperator pushDownAggOperator; - public LogicalOlapScan(RelationId id, OlapTable table) { this(id, table, ImmutableList.of()); } @@ -67,13 +64,13 @@ public LogicalOlapScan(RelationId id, OlapTable table) { public LogicalOlapScan(RelationId id, OlapTable table, List qualifier) { this(id, table, qualifier, Optional.empty(), Optional.empty(), table.getPartitionIds(), false, ImmutableList.of(), false, - ImmutableList.of(), false, PreAggStatus.on(), false, PushDownAggOperator.NONE); + ImmutableList.of(), false, PreAggStatus.on()); } public LogicalOlapScan(RelationId id, Table table, List qualifier) { this(id, table, qualifier, Optional.empty(), Optional.empty(), ((OlapTable) table).getPartitionIds(), false, ImmutableList.of(), false, - ImmutableList.of(), false, PreAggStatus.on(), false, PushDownAggOperator.NONE); + ImmutableList.of(), false, PreAggStatus.on()); } /** @@ -83,8 +80,7 @@ public LogicalOlapScan(RelationId id, Table table, List qualifier, Optional groupExpression, Optional logicalProperties, List selectedPartitionIds, boolean partitionPruned, List selectedTabletIds, boolean tabletPruned, - List candidateIndexIds, boolean indexSelected, PreAggStatus preAggStatus, - boolean aggPushed, PushDownAggOperator pushDownAggOperator) { + List candidateIndexIds, boolean indexSelected, PreAggStatus preAggStatus) { super(id, PlanType.LOGICAL_OLAP_SCAN, table, qualifier, groupExpression, logicalProperties, selectedPartitionIds); // TODO: use CBO manner to select best index id, according to index's statistics info, @@ -97,8 +93,6 @@ public LogicalOlapScan(RelationId id, Table table, List qualifier, this.candidateIndexIds = ImmutableList.copyOf(candidateIndexIds); this.indexSelected = indexSelected; this.preAggStatus = preAggStatus; - this.aggPushed = aggPushed; - this.pushDownAggOperator = pushDownAggOperator; } @Override @@ -121,8 +115,7 @@ public String toString() { "output", getOutput(), "candidateIndexIds", candidateIndexIds, "selectedIndexId", selectedIndexId, - "preAgg", preAggStatus, - "pushAgg", pushDownAggOperator + "preAgg", preAggStatus ); } @@ -136,51 +129,44 @@ public boolean equals(Object o) { } return Objects.equals(selectedPartitionIds, ((LogicalOlapScan) o).selectedPartitionIds) && Objects.equals(candidateIndexIds, ((LogicalOlapScan) o).candidateIndexIds) - && Objects.equals(selectedTabletIds, ((LogicalOlapScan) o).selectedTabletIds) - && Objects.equals(pushDownAggOperator, ((LogicalOlapScan) o).pushDownAggOperator); + && Objects.equals(selectedTabletIds, ((LogicalOlapScan) o).selectedTabletIds); } @Override public int hashCode() { - return Objects.hash(id, selectedPartitionIds, candidateIndexIds, selectedTabletIds, pushDownAggOperator); + return Objects.hash(id, selectedPartitionIds, candidateIndexIds, selectedTabletIds); } @Override public Plan withGroupExpression(Optional groupExpression) { return new LogicalOlapScan(id, table, qualifier, groupExpression, Optional.of(getLogicalProperties()), selectedPartitionIds, partitionPruned, selectedTabletIds, tabletPruned, - candidateIndexIds, indexSelected, preAggStatus, aggPushed, 
pushDownAggOperator); + candidateIndexIds, indexSelected, preAggStatus); } @Override public LogicalOlapScan withLogicalProperties(Optional logicalProperties) { return new LogicalOlapScan(id, table, qualifier, Optional.empty(), logicalProperties, selectedPartitionIds, partitionPruned, selectedTabletIds, tabletPruned, - candidateIndexIds, indexSelected, preAggStatus, aggPushed, pushDownAggOperator); + candidateIndexIds, indexSelected, preAggStatus); } public LogicalOlapScan withSelectedPartitionIds(List selectedPartitionIds) { return new LogicalOlapScan(id, table, qualifier, Optional.empty(), Optional.of(getLogicalProperties()), selectedPartitionIds, true, selectedTabletIds, tabletPruned, - candidateIndexIds, indexSelected, preAggStatus, aggPushed, pushDownAggOperator); + candidateIndexIds, indexSelected, preAggStatus); } public LogicalOlapScan withMaterializedIndexSelected(PreAggStatus preAgg, List candidateIndexIds) { return new LogicalOlapScan(id, table, qualifier, Optional.empty(), Optional.of(getLogicalProperties()), selectedPartitionIds, partitionPruned, selectedTabletIds, tabletPruned, - candidateIndexIds, true, preAgg, aggPushed, pushDownAggOperator); + candidateIndexIds, true, preAgg); } public LogicalOlapScan withSelectedTabletIds(List selectedTabletIds) { return new LogicalOlapScan(id, table, qualifier, Optional.empty(), Optional.of(getLogicalProperties()), selectedPartitionIds, partitionPruned, selectedTabletIds, true, - candidateIndexIds, indexSelected, preAggStatus, aggPushed, pushDownAggOperator); - } - - public LogicalOlapScan withPushDownAggregateOperator(PushDownAggOperator pushDownAggOperator) { - return new LogicalOlapScan(id, table, qualifier, Optional.empty(), Optional.of(getLogicalProperties()), - selectedPartitionIds, partitionPruned, selectedTabletIds, true, - candidateIndexIds, indexSelected, preAggStatus, true, pushDownAggOperator); + candidateIndexIds, indexSelected, preAggStatus); } @Override @@ -200,6 +186,7 @@ public List getSelectedTabletIds() { return selectedTabletIds; } + @Override public long getSelectedIndexId() { return selectedIndexId; } @@ -212,14 +199,6 @@ public PreAggStatus getPreAggStatus() { return preAggStatus; } - public boolean isAggPushed() { - return aggPushed; - } - - public PushDownAggOperator getPushDownAggOperator() { - return pushDownAggOperator; - } - @VisibleForTesting public Optional getSelectedMaterializedIndexName() { return indexSelected ? Optional.ofNullable(((OlapTable) table).getIndexNameById(selectedIndexId)) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java deleted file mode 100644 index 960d47006584b3..00000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java +++ /dev/null @@ -1,220 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.trees.plans.physical; - -import org.apache.doris.nereids.memo.GroupExpression; -import org.apache.doris.nereids.properties.LogicalProperties; -import org.apache.doris.nereids.properties.PhysicalProperties; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.plans.AggPhase; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.PlanType; -import org.apache.doris.nereids.trees.plans.algebra.Aggregate; -import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; -import org.apache.doris.nereids.util.Utils; -import org.apache.doris.statistics.StatsDeriveResult; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; - -import java.util.List; -import java.util.Objects; -import java.util.Optional; - -/** - * Physical aggregation plan. - * TODO: change class name to PhysicalHashAggregate - */ -public class PhysicalAggregate extends PhysicalUnary - implements Aggregate { - - private final ImmutableList groupByExpressions; - - private final ImmutableList outputExpressions; - - private final ImmutableList partitionExpressions; - - private final AggPhase aggPhase; - - private final boolean usingStream; - - // use for scenes containing distinct agg - // 1. If there are LOCAL and GLOBAL phases, global is the final phase - // 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase - // 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases, - // DISTINCT_GLOBAL is the final phase - private final boolean isFinalPhase; - - public PhysicalAggregate(List groupByExpressions, List outputExpressions, - List partitionExpressions, AggPhase aggPhase, boolean usingStream, - boolean isFinalPhase, LogicalProperties logicalProperties, CHILD_TYPE child) { - this(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, usingStream, - isFinalPhase, Optional.empty(), logicalProperties, child); - } - - /** - * Constructor of PhysicalAggNode. - * - * @param groupByExpressions group by expr list. - * @param outputExpressions agg expr list. - * @param partitionExpressions partition expr list, used for analytic agg. - * @param usingStream whether it's stream agg. - */ - public PhysicalAggregate(List groupByExpressions, List outputExpressions, - List partitionExpressions, AggPhase aggPhase, boolean usingStream, boolean isFinalPhase, - Optional groupExpression, LogicalProperties logicalProperties, - CHILD_TYPE child) { - super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties, child); - this.groupByExpressions = ImmutableList.copyOf(groupByExpressions); - this.outputExpressions = ImmutableList.copyOf(outputExpressions); - this.aggPhase = aggPhase; - this.partitionExpressions = ImmutableList.copyOf(partitionExpressions); - this.usingStream = usingStream; - this.isFinalPhase = isFinalPhase; - } - - /** - * Constructor of PhysicalAggNode. - * - * @param groupByExpressions group by expr list. 
- * @param outputExpressions agg expr list. - * @param partitionExpressions partition expr list, used for analytic agg. - * @param usingStream whether it's stream agg. - */ - public PhysicalAggregate(List groupByExpressions, List outputExpressions, - List partitionExpressions, AggPhase aggPhase, boolean usingStream, boolean isFinalPhase, - Optional groupExpression, LogicalProperties logicalProperties, - PhysicalProperties physicalProperties, StatsDeriveResult statsDeriveResult, CHILD_TYPE child) { - super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties, physicalProperties, statsDeriveResult, - child); - this.groupByExpressions = ImmutableList.copyOf(groupByExpressions); - this.outputExpressions = ImmutableList.copyOf(outputExpressions); - this.aggPhase = aggPhase; - this.partitionExpressions = ImmutableList.copyOf(partitionExpressions); - this.usingStream = usingStream; - this.isFinalPhase = isFinalPhase; - } - - public AggPhase getAggPhase() { - return aggPhase; - } - - public List getGroupByExpressions() { - return groupByExpressions; - } - - public List getOutputExpressions() { - return outputExpressions; - } - - public boolean isFinalPhase() { - return isFinalPhase; - } - - public boolean isUsingStream() { - return usingStream; - } - - public List getPartitionExpressions() { - return partitionExpressions; - } - - @Override - public R accept(PlanVisitor visitor, C context) { - return visitor.visitPhysicalAggregate(this, context); - } - - @Override - public List getExpressions() { - return new ImmutableList.Builder() - .addAll(groupByExpressions) - .addAll(outputExpressions) - .addAll(partitionExpressions).build(); - } - - @Override - public String toString() { - return Utils.toSqlString("PhysicalAggregate", - "phase", aggPhase, - "outputExpr", outputExpressions, - "groupByExpr", groupByExpressions, - "partitionExpr", partitionExpressions, - "stats", statsDeriveResult - ); - } - - /** - * Determine the equality with another operator - */ - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - PhysicalAggregate that = (PhysicalAggregate) o; - return Objects.equals(groupByExpressions, that.groupByExpressions) - && Objects.equals(outputExpressions, that.outputExpressions) - && Objects.equals(partitionExpressions, that.partitionExpressions) - && usingStream == that.usingStream - && aggPhase == that.aggPhase - && isFinalPhase == that.isFinalPhase; - } - - @Override - public int hashCode() { - return Objects.hash(groupByExpressions, outputExpressions, partitionExpressions, - aggPhase, usingStream, isFinalPhase); - } - - @Override - public PhysicalAggregate withChildren(List children) { - Preconditions.checkArgument(children.size() == 1); - return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, - aggPhase, usingStream, isFinalPhase, getLogicalProperties(), children.get(0)); - } - - @Override - public PhysicalAggregate withGroupExpression(Optional groupExpression) { - return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, - aggPhase, usingStream, isFinalPhase, groupExpression, getLogicalProperties(), child()); - } - - @Override - public PhysicalAggregate withLogicalProperties(Optional logicalProperties) { - return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, - aggPhase, usingStream, isFinalPhase, Optional.empty(), logicalProperties.get(), child()); - } - - @Override - public 
PhysicalAggregate withPhysicalPropertiesAndStats(PhysicalProperties physicalProperties, - StatsDeriveResult statsDeriveResult) { - return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, - aggPhase, usingStream, isFinalPhase, - Optional.empty(), getLogicalProperties(), physicalProperties, statsDeriveResult, child()); - } - - @Override - public PhysicalAggregate withAggOutput(List newOutput) { - return new PhysicalAggregate<>(groupByExpressions, newOutput, partitionExpressions, - aggPhase, usingStream, isFinalPhase, Optional.empty(), getLogicalProperties(), - physicalProperties, statsDeriveResult, child()); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java new file mode 100644 index 00000000000000..4d4f1d9a240e5a --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java @@ -0,0 +1,270 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.plans.physical; + +import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.properties.LogicalProperties; +import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.properties.RequireProperties; +import org.apache.doris.nereids.properties.RequirePropertiesSupplier; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; +import org.apache.doris.nereids.trees.plans.AggMode; +import org.apache.doris.nereids.trees.plans.AggPhase; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.PlanType; +import org.apache.doris.nereids.trees.plans.algebra.Aggregate; +import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; +import org.apache.doris.nereids.util.Utils; +import org.apache.doris.statistics.StatsDeriveResult; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +/** + * Physical hash aggregation plan. 
+ */ +public class PhysicalHashAggregate extends PhysicalUnary + implements Aggregate, RequirePropertiesSupplier> { + + private final List groupByExpressions; + + private final List outputExpressions; + + private final Optional> partitionExpressions; + + private final AggregateParam aggregateParam; + + private final boolean maybeUsingStream; + + private final RequireProperties requireProperties; + + public PhysicalHashAggregate(List groupByExpressions, List outputExpressions, + AggregateParam aggregateParam, boolean maybeUsingStream, LogicalProperties logicalProperties, + RequireProperties requireProperties, CHILD_TYPE child) { + this(groupByExpressions, outputExpressions, Optional.empty(), aggregateParam, + maybeUsingStream, Optional.empty(), logicalProperties, requireProperties, child); + } + + public PhysicalHashAggregate(List groupByExpressions, List outputExpressions, + Optional> partitionExpressions, AggregateParam aggregateParam, + boolean maybeUsingStream, LogicalProperties logicalProperties, RequireProperties requireProperties, + CHILD_TYPE child) { + this(groupByExpressions, outputExpressions, partitionExpressions, aggregateParam, + maybeUsingStream, Optional.empty(), logicalProperties, requireProperties, child); + } + + /** + * Constructor of PhysicalAggNode. + * + * @param groupByExpressions group by expr list. + * @param outputExpressions agg expr list. + * @param partitionExpressions hash distribute expr list + * @param maybeUsingStream whether it's stream agg. + * @param requireProperties the request physical properties + */ + public PhysicalHashAggregate(List groupByExpressions, List outputExpressions, + Optional> partitionExpressions, AggregateParam aggregateParam, boolean maybeUsingStream, + Optional groupExpression, LogicalProperties logicalProperties, + RequireProperties requireProperties, CHILD_TYPE child) { + super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties, child); + this.groupByExpressions = ImmutableList.copyOf( + Objects.requireNonNull(groupByExpressions, "groupByExpressions cannot be null")); + this.outputExpressions = ImmutableList.copyOf( + Objects.requireNonNull(outputExpressions, "outputExpressions cannot be null")); + this.partitionExpressions = Objects.requireNonNull( + partitionExpressions, "partitionExpressions cannot be null"); + this.aggregateParam = Objects.requireNonNull(aggregateParam, "aggregate param cannot be null"); + this.maybeUsingStream = maybeUsingStream; + this.requireProperties = Objects.requireNonNull(requireProperties, "requireProperties cannot be null"); + } + + /** + * Constructor of PhysicalAggNode. + * + * @param groupByExpressions group by expr list. + * @param outputExpressions agg expr list. + * @param partitionExpressions hash distribute expr list + * @param maybeUsingStream whether it's stream agg. 
+ * @param requireProperties the request physical properties + */ + public PhysicalHashAggregate(List groupByExpressions, List outputExpressions, + Optional> partitionExpressions, AggregateParam aggregateParam, boolean maybeUsingStream, + Optional groupExpression, LogicalProperties logicalProperties, + RequireProperties requireProperties, PhysicalProperties physicalProperties, + StatsDeriveResult statsDeriveResult, CHILD_TYPE child) { + super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties, physicalProperties, statsDeriveResult, + child); + this.groupByExpressions = ImmutableList.copyOf( + Objects.requireNonNull(groupByExpressions, "groupByExpressions cannot be null")); + this.outputExpressions = ImmutableList.copyOf( + Objects.requireNonNull(outputExpressions, "outputExpressions cannot be null")); + this.partitionExpressions = Objects.requireNonNull( + partitionExpressions, "partitionExpressions cannot be null"); + this.aggregateParam = Objects.requireNonNull(aggregateParam, "aggregate param cannot be null"); + this.maybeUsingStream = maybeUsingStream; + this.requireProperties = Objects.requireNonNull(requireProperties, "requireProperties cannot be null"); + } + + public List getGroupByExpressions() { + return groupByExpressions; + } + + public List getOutputExpressions() { + return outputExpressions; + } + + public Optional> getPartitionExpressions() { + return partitionExpressions; + } + + public AggregateParam getAggregateParam() { + return aggregateParam; + } + + public AggPhase getAggPhase() { + return aggregateParam.aggPhase; + } + + public AggMode getAggMode() { + return aggregateParam.aggMode; + } + + public boolean isMaybeUsingStream() { + return maybeUsingStream; + } + + @Override + public RequireProperties getRequireProperties() { + return requireProperties; + } + + @Override + public PhysicalHashAggregate withRequireAndChildren( + RequireProperties requireProperties, List children) { + Preconditions.checkArgument(children.size() == 1); + return withRequirePropertiesAndChild(requireProperties, children.get(0)); + } + + @Override + public R accept(PlanVisitor visitor, C context) { + return visitor.visitPhysicalHashAggregate(this, context); + } + + @Override + public List getExpressions() { + return new ImmutableList.Builder() + .addAll(groupByExpressions) + .addAll(outputExpressions) + .addAll(partitionExpressions.orElse(ImmutableList.of())) + .build(); + } + + @Override + public String toString() { + return Utils.toSqlString("PhysicalHashAggregate", + "aggPhase", aggregateParam.aggPhase, + "aggMode", aggregateParam.aggMode, + "maybeUseStreaming", maybeUsingStream, + "groupByExpr", groupByExpressions, + "outputExpr", outputExpressions, + "partitionExpr", partitionExpressions, + "requireProperties", requireProperties, + "stats", statsDeriveResult + ); + } + + /** + * Determine the equality with another operator + */ + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PhysicalHashAggregate that = (PhysicalHashAggregate) o; + return Objects.equals(groupByExpressions, that.groupByExpressions) + && Objects.equals(outputExpressions, that.outputExpressions) + && Objects.equals(partitionExpressions, that.partitionExpressions) + && Objects.equals(aggregateParam, that.aggregateParam) + && maybeUsingStream == that.maybeUsingStream + && Objects.equals(requireProperties, that.requireProperties); + } + + @Override + public int hashCode() { + return Objects.hash(groupByExpressions, 
outputExpressions, partitionExpressions, + aggregateParam, maybeUsingStream, requireProperties); + } + + @Override + public PhysicalHashAggregate withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, + aggregateParam, maybeUsingStream, getLogicalProperties(), requireProperties, children.get(0)); + } + + public PhysicalHashAggregate withPartitionExpressions(List partitionExpressions) { + return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, + Optional.ofNullable(partitionExpressions), aggregateParam, maybeUsingStream, + Optional.empty(), getLogicalProperties(), requireProperties, child()); + } + + @Override + public PhysicalHashAggregate withGroupExpression(Optional groupExpression) { + return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, + aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(), requireProperties, child()); + } + + @Override + public PhysicalHashAggregate withLogicalProperties(Optional logicalProperties) { + return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, + aggregateParam, maybeUsingStream, Optional.empty(), logicalProperties.get(), + requireProperties, child()); + } + + @Override + public PhysicalHashAggregate withPhysicalPropertiesAndStats(PhysicalProperties physicalProperties, + StatsDeriveResult statsDeriveResult) { + return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, + aggregateParam, maybeUsingStream, Optional.empty(), getLogicalProperties(), + requireProperties, physicalProperties, statsDeriveResult, + child()); + } + + @Override + public PhysicalHashAggregate withAggOutput(List newOutput) { + return new PhysicalHashAggregate<>(groupByExpressions, newOutput, partitionExpressions, + aggregateParam, maybeUsingStream, Optional.empty(), getLogicalProperties(), + requireProperties, physicalProperties, statsDeriveResult, child()); + } + + public PhysicalHashAggregate withRequirePropertiesAndChild( + RequireProperties requireProperties, C newChild) { + return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, + aggregateParam, maybeUsingStream, Optional.empty(), getLogicalProperties(), + requireProperties, physicalProperties, statsDeriveResult, newChild); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalOlapScan.java index c4228b91b47f84..2910e727184350 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalOlapScan.java @@ -24,8 +24,8 @@ import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.PreAggStatus; -import org.apache.doris.nereids.trees.plans.PushDownAggOperator; import org.apache.doris.nereids.trees.plans.RelationId; +import org.apache.doris.nereids.trees.plans.algebra.OlapScan; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.Utils; import org.apache.doris.statistics.StatsDeriveResult; @@ -39,22 +39,20 @@ /** * Physical olap scan plan. 
*/ -public class PhysicalOlapScan extends PhysicalRelation { +public class PhysicalOlapScan extends PhysicalRelation implements OlapScan { private final OlapTable olapTable; private final DistributionSpec distributionSpec; private final long selectedIndexId; private final ImmutableList selectedTabletIds; private final ImmutableList selectedPartitionIds; private final PreAggStatus preAggStatus; - private final PushDownAggOperator pushDownAggOperator; /** * Constructor for PhysicalOlapScan. */ public PhysicalOlapScan(RelationId id, OlapTable olapTable, List qualifier, long selectedIndexId, List selectedTabletIds, List selectedPartitionIds, DistributionSpec distributionSpec, - PreAggStatus preAggStatus, PushDownAggOperator pushDownAggOperator, - Optional groupExpression, LogicalProperties logicalProperties) { + PreAggStatus preAggStatus, Optional groupExpression, LogicalProperties logicalProperties) { super(id, PlanType.PHYSICAL_OLAP_SCAN, qualifier, groupExpression, logicalProperties); this.olapTable = olapTable; this.selectedIndexId = selectedIndexId; @@ -62,7 +60,6 @@ public PhysicalOlapScan(RelationId id, OlapTable olapTable, List qualifi this.selectedPartitionIds = ImmutableList.copyOf(selectedPartitionIds); this.distributionSpec = distributionSpec; this.preAggStatus = preAggStatus; - this.pushDownAggOperator = pushDownAggOperator; } /** @@ -70,8 +67,7 @@ public PhysicalOlapScan(RelationId id, OlapTable olapTable, List qualifi */ public PhysicalOlapScan(RelationId id, OlapTable olapTable, List qualifier, long selectedIndexId, List selectedTabletIds, List selectedPartitionIds, DistributionSpec distributionSpec, - PreAggStatus preAggStatus, PushDownAggOperator pushDownAggOperator, - Optional groupExpression, LogicalProperties logicalProperties, + PreAggStatus preAggStatus, Optional groupExpression, LogicalProperties logicalProperties, PhysicalProperties physicalProperties, StatsDeriveResult statsDeriveResult) { super(id, PlanType.PHYSICAL_OLAP_SCAN, qualifier, groupExpression, logicalProperties, physicalProperties, statsDeriveResult); @@ -81,13 +77,14 @@ public PhysicalOlapScan(RelationId id, OlapTable olapTable, List qualifi this.selectedPartitionIds = ImmutableList.copyOf(selectedPartitionIds); this.distributionSpec = distributionSpec; this.preAggStatus = preAggStatus; - this.pushDownAggOperator = pushDownAggOperator; } + @Override public long getSelectedIndexId() { return selectedIndexId; } + @Override public List getSelectedTabletIds() { return selectedTabletIds; } @@ -109,10 +106,6 @@ public PreAggStatus getPreAggStatus() { return preAggStatus; } - public PushDownAggOperator getPushDownAggOperator() { - return pushDownAggOperator; - } - @Override public String toString() { return Utils.toSqlString("PhysicalOlapScan", @@ -150,23 +143,21 @@ public R accept(PlanVisitor visitor, C context) { @Override public PhysicalOlapScan withGroupExpression(Optional groupExpression) { return new PhysicalOlapScan(id, olapTable, qualifier, selectedIndexId, selectedTabletIds, - selectedPartitionIds, distributionSpec, preAggStatus, pushDownAggOperator, - groupExpression, getLogicalProperties()); + selectedPartitionIds, distributionSpec, preAggStatus, groupExpression, getLogicalProperties()); } @Override public PhysicalOlapScan withLogicalProperties(Optional logicalProperties) { return new PhysicalOlapScan(id, olapTable, qualifier, selectedIndexId, selectedTabletIds, - selectedPartitionIds, distributionSpec, preAggStatus, pushDownAggOperator, - Optional.empty(), logicalProperties.get()); + 
selectedPartitionIds, distributionSpec, preAggStatus, Optional.empty(), + logicalProperties.get()); } @Override public PhysicalOlapScan withPhysicalPropertiesAndStats( PhysicalProperties physicalProperties, StatsDeriveResult statsDeriveResult) { return new PhysicalOlapScan(id, olapTable, qualifier, selectedIndexId, selectedTabletIds, - selectedPartitionIds, distributionSpec, preAggStatus, pushDownAggOperator, - Optional.empty(), getLogicalProperties(), - physicalProperties, statsDeriveResult); + selectedPartitionIds, distributionSpec, preAggStatus, Optional.empty(), + getLogicalProperties(), physicalProperties, statsDeriveResult); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java new file mode 100644 index 00000000000000..08eebb3e36860d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
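// The new file below defines PhysicalStorageLayerAggregate, a physical operator that wraps a
// PhysicalRelation so that a no-group-by COUNT/MIN/MAX aggregate can be answered by the storage layer.
// A rough sketch (not part of the patch; `scan` is a hypothetical PhysicalOlapScan built elsewhere) of
// how a rule might wrap a scan, using the two-argument constructor and the PushDownAggOp enum declared
// in this file:
//
//     PhysicalStorageLayerAggregate pushed = new PhysicalStorageLayerAggregate(
//             scan, PhysicalStorageLayerAggregate.PushDownAggOp.COUNT);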
+ +package org.apache.doris.nereids.trees.plans.physical; + +import org.apache.doris.catalog.Table; +import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.properties.LogicalProperties; +import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; +import org.apache.doris.nereids.util.Utils; +import org.apache.doris.statistics.StatsDeriveResult; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** PhysicalStorageLayerAggregate */ +public class PhysicalStorageLayerAggregate extends PhysicalRelation { + private final PhysicalRelation relation; + private final PushDownAggOp aggOp; + + public PhysicalStorageLayerAggregate(PhysicalRelation relation, PushDownAggOp aggOp) { + super(relation.getId(), relation.getType(), relation.getQualifier(), + Optional.empty(), relation.getLogicalProperties()); + this.relation = Objects.requireNonNull(relation, "relation cannot be null"); + this.aggOp = Objects.requireNonNull(aggOp, "aggOp cannot be null"); + } + + public PhysicalStorageLayerAggregate(PhysicalRelation relation, PushDownAggOp aggOp, + Optional groupExpression, LogicalProperties logicalProperties, + PhysicalProperties physicalProperties, StatsDeriveResult statsDeriveResult) { + super(relation.getId(), relation.getType(), relation.getQualifier(), groupExpression, + logicalProperties, physicalProperties, statsDeriveResult); + this.relation = Objects.requireNonNull(relation, "relation cannot be null"); + this.aggOp = Objects.requireNonNull(aggOp, "aggOp cannot be null"); + } + + public PhysicalRelation getRelation() { + return relation; + } + + public PushDownAggOp getAggOp() { + return aggOp; + } + + @Override + public Table getTable() { + return relation.getTable(); + } + + @Override + public R accept(PlanVisitor visitor, C context) { + return visitor.visitPhysicalStorageLayerAggregate(this, context); + } + + @Override + public List getExpressions() { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + PhysicalStorageLayerAggregate that = (PhysicalStorageLayerAggregate) o; + return Objects.equals(relation, that.relation) && aggOp == that.aggOp; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), relation, aggOp); + } + + @Override + public String toString() { + return Utils.toSqlString("PhysicalStorageLayerAggregate", + "pushDownAggOp", aggOp, + "relation", relation, + "stats", statsDeriveResult + ); + } + + public PhysicalStorageLayerAggregate withPhysicalOlapScan(PhysicalOlapScan physicalOlapScan) { + return new PhysicalStorageLayerAggregate(relation, aggOp); + } + + @Override + public PhysicalStorageLayerAggregate withGroupExpression(Optional groupExpression) { + return new PhysicalStorageLayerAggregate(relation, aggOp, groupExpression, getLogicalProperties(), + physicalProperties, statsDeriveResult); + } + + 
@Override + public Plan withLogicalProperties(Optional logicalProperties) { + return new PhysicalStorageLayerAggregate(relation, aggOp, Optional.empty(), + logicalProperties.get(), physicalProperties, statsDeriveResult); + } + + @Override + public PhysicalPlan withPhysicalPropertiesAndStats(PhysicalProperties physicalProperties, + StatsDeriveResult statsDeriveResult) { + return new PhysicalStorageLayerAggregate(relation, aggOp, Optional.empty(), + getLogicalProperties(), physicalProperties, statsDeriveResult); + } + + /** PushAggOp */ + public enum PushDownAggOp { + COUNT, MIN_MAX, MIX; + + public static Map supportedFunctions() { + return ImmutableMap.builder() + .put(Count.class, PushDownAggOp.COUNT) + .put(Min.class, PushDownAggOp.MIN_MAX) + .put(Max.class, PushDownAggOp.MIN_MAX) + .build(); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/DefaultPlanRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/DefaultPlanRewriter.java index 144da6428c39dc..5339b13d42afc0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/DefaultPlanRewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/DefaultPlanRewriter.java @@ -18,6 +18,8 @@ package org.apache.doris.nereids.trees.plans.visitor; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate; import java.util.ArrayList; import java.util.List; @@ -41,4 +43,13 @@ public Plan visit(Plan plan, C context) { } return hasNewChildren ? plan.withChildren(newChildren) : plan; } + + @Override + public Plan visitPhysicalStorageLayerAggregate(PhysicalStorageLayerAggregate storageLayerAggregate, C context) { + PhysicalOlapScan olapScan = (PhysicalOlapScan) storageLayerAggregate.getRelation().accept(this, context); + if (olapScan != storageLayerAggregate.getRelation()) { + return storageLayerAggregate.withPhysicalOlapScan(olapScan); + } + return storageLayerAggregate; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/PlanVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/PlanVisitor.java index 0490b70c8c71ff..07ca552e8d5d89 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/PlanVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/PlanVisitor.java @@ -47,11 +47,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin; import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; import org.apache.doris.nereids.trees.plans.physical.PhysicalEmptyRelation; import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit; import org.apache.doris.nereids.trees.plans.physical.PhysicalLocalQuickSort; @@ -62,6 +62,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort; import 
org.apache.doris.nereids.trees.plans.physical.PhysicalRelation; import org.apache.doris.nereids.trees.plans.physical.PhysicalRepeat; +import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalTVFRelation; import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; @@ -195,7 +196,7 @@ public R visitLogicalHaving(LogicalHaving having, C context) { // Physical plans // ******************************* - public R visitPhysicalAggregate(PhysicalAggregate agg, C context) { + public R visitPhysicalHashAggregate(PhysicalHashAggregate agg, C context) { return visit(agg, context); } @@ -219,6 +220,10 @@ public R visitPhysicalOlapScan(PhysicalOlapScan olapScan, C context) { return visitPhysicalScan(olapScan, context); } + public R visitPhysicalStorageLayerAggregate(PhysicalStorageLayerAggregate storageLayerAggregate, C context) { + return storageLayerAggregate.getRelation().accept(this, context); + } + public R visitPhysicalTVFRelation(PhysicalTVFRelation tvfRelation, C context) { return visitPhysicalScan(tvfRelation, context); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 6d85eec105a93e..17f45d452bbc38 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -49,6 +49,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.Function; /** * Expression rewrite helper class. @@ -267,6 +268,13 @@ public static List replace(List exprs, .collect(ImmutableList.toImmutableList()); } + public static List rewriteDownShortCircuit( + List exprs, Function rewriteFunction) { + return exprs.stream() + .map(expr -> (E) expr.rewriteDownShortCircuit(rewriteFunction)) + .collect(ImmutableList.toImmutableList()); + } + private static class ExpressionReplacer extends DefaultExpressionRewriter> { public static final ExpressionReplacer INSTANCE = new ExpressionReplacer(); @@ -344,6 +352,10 @@ public static boolean anyMatch(List expressions, Predicate .anyMatch(expr -> expr.anyMatch(predicate)); } + public static boolean containsType(List expressions, Class type) { + return anyMatch(expressions, type::isInstance); + } + public static Set collect(List expressions, Predicate> predicate) { return expressions.stream() @@ -351,6 +363,13 @@ public static Set collect(List expressions, .collect(ImmutableSet.toImmutableSet()); } + public static List collectAll(List expressions, + Predicate> predicate) { + return expressions.stream() + .flatMap(expr -> expr.>collect(predicate).stream()) + .collect(ImmutableList.toImmutableList()); + } + public static List> rollupToGroupingSets(List rollupExpressions) { List> groupingSets = Lists.newArrayList(); for (int end = rollupExpressions.size(); end >= 0; --end) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java index da1492c2537c04..0e69bfba07cd4d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java @@ -42,6 +42,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import org.apache.commons.lang3.StringUtils; import 
org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -304,8 +305,14 @@ public String getNodeExplainString(String detailPrefix, TExplainLevel detailLeve } if (aggInfo.getAggregateExprs() != null && aggInfo.getMaterializedAggregateExprs().size() > 0) { - output.append(detailPrefix).append("output: ") - .append(getExplainString(aggInfo.getMaterializedAggregateExprs())).append("\n"); + List<String> labels = aggInfo.getMaterializedAggregateExprLabels(); + if (labels.isEmpty()) { + output.append(detailPrefix).append("output: ") + .append(getExplainString(aggInfo.getMaterializedAggregateExprs())).append("\n"); + } else { + output.append(detailPrefix).append("output: ") + .append(StringUtils.join(labels, ", ")).append("\n"); + } } // TODO: group by can be very long. Break it into multiple lines output.append(detailPrefix).append("group by: ") diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java index 9a2cf9c865d784..6edf3fc0238feb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java @@ -985,6 +985,9 @@ public String getNodeExplainString(String prefix, TExplainLevel detailLevel) { output.append(prefix).append(String.format("cardinality=%s", cardinality)) .append(String.format(", avgRowSize=%s", avgRowSize)).append(String.format(", numNodes=%s", numNodes)); output.append("\n"); + if (pushDownAggNoGroupingOp != null) { + output.append(prefix).append("pushAggOp=").append(pushDownAggNoGroupingOp).append("\n"); + } return output.toString(); } @@ -1216,4 +1219,19 @@ public String getReasonOfPreAggregation() { public String getSelectedIndexName() { return olapTable.getIndexNameById(selectedIndexId); } + + public void finalizeForNerieds() { + computeNumNodes(); + computeStatsForNerieds(); + } + + private void computeStatsForNerieds() { + if (cardinality > 0 && avgRowSize <= 0) { + avgRowSize = totalBytes / (float) cardinality * COMPRESSION_RATIO; + capCardinalityAtLimit(); + } + // when node scan has no data, cardinality should be 0 instead of an invalid + // value after computeStats() + cardinality = cardinality == -1 ? 0 : cardinality; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanFragment.java b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanFragment.java index 7cc059ed471285..cae81933112b77 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanFragment.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanFragment.java @@ -231,6 +231,10 @@ public boolean hasColocatePlanNode() { return hasColocatePlanNode; } + public void setDataPartition(DataPartition dataPartition) { + this.dataPartition = dataPartition; + } + /** * Finalize plan tree and create stream sink, if needed.
*/ diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 677d8042eaf839..5e615850eed2a8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -28,6 +28,7 @@ import org.apache.doris.thrift.TResourceLimit; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -39,8 +40,10 @@ import java.io.IOException; import java.io.Serializable; import java.lang.reflect.Field; +import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Random; import java.util.Set; @@ -190,6 +193,7 @@ public class SessionVariable implements Serializable, Writable { = "trim_tailing_spaces_for_external_table_query"; public static final String ENABLE_NEREIDS_PLANNER = "enable_nereids_planner"; + public static final String DISABLE_NEREIDS_RULES = "disable_nereids_rules"; public static final String ENABLE_FALLBACK_TO_ORIGINAL_PLANNER = "enable_fallback_to_original_planner"; @@ -540,6 +544,9 @@ public class SessionVariable implements Serializable, Writable { @VariableMgr.VarAttr(name = ENABLE_NEREIDS_PLANNER) private boolean enableNereidsPlanner = false; + @VariableMgr.VarAttr(name = DISABLE_NEREIDS_RULES) + private String disableNereidsRules = ""; + @VariableMgr.VarAttr(name = NEREIDS_STAR_SCHEMA_SUPPORT) private boolean nereidsStarSchemaSupport = true; @@ -1233,6 +1240,16 @@ public void setEnableNereidsPlanner(boolean enableNereidsPlanner) { this.enableNereidsPlanner = enableNereidsPlanner; } + public Set getDisableNereidsRules() { + return Arrays.stream(disableNereidsRules.split(",[\\s]*")) + .map(rule -> rule.toUpperCase(Locale.ROOT)) + .collect(ImmutableSet.toImmutableSet()); + } + + public void setDisableNereidsRules(String disableNereidsRules) { + this.disableNereidsRules = disableNereidsRules; + } + public boolean isNereidsStarSchemaSupport() { return isEnableNereidsPlanner() && nereidsStarSchemaSupport; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslatorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslatorTest.java index 2ce0f10c7a7de7..f9553d9b146ee3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslatorTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslatorTest.java @@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.PreAggStatus; -import org.apache.doris.nereids.trees.plans.PushDownAggOperator; import org.apache.doris.nereids.trees.plans.logical.RelationUtil; import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; @@ -64,9 +63,8 @@ public void testOlapPrune(@Injectable LogicalProperties placeHolder) throws Exce t1Output.add(col3); LogicalProperties t1Properties = new LogicalProperties(() -> t1Output); PhysicalOlapScan scan = new PhysicalOlapScan(RelationUtil.newRelationId(), t1, qualifier, 0L, - Collections.emptyList(), Collections.emptyList(), null, 
PreAggStatus.on(), PushDownAggOperator.NONE, - Optional.empty(), - t1Properties); + Collections.emptyList(), Collections.emptyList(), null, PreAggStatus.on(), + Optional.empty(), t1Properties); Literal t1FilterRight = new IntegerLiteral(1); Expression t1FilterExpr = new GreaterThan(col1, t1FilterRight); PhysicalFilter filter = diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/MergeProjectPostProcessTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/MergeProjectPostProcessTest.java index 9cd75c66ee4791..080c40a3edae58 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/MergeProjectPostProcessTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/MergeProjectPostProcessTest.java @@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.PreAggStatus; -import org.apache.doris.nereids.trees.plans.PushDownAggOperator; import org.apache.doris.nereids.trees.plans.RelationId; import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; @@ -76,7 +75,7 @@ public void testMergeProj(@Injectable LogicalProperties placeHolder, @Injectable t1Output.add(c); LogicalProperties t1Properties = new LogicalProperties(() -> t1Output); PhysicalOlapScan scan = new PhysicalOlapScan(RelationId.createGenerator().getNextId(), t1, qualifier, 0L, - Collections.emptyList(), Collections.emptyList(), null, PreAggStatus.on(), PushDownAggOperator.NONE, + Collections.emptyList(), Collections.emptyList(), null, PreAggStatus.on(), Optional.empty(), t1Properties); Alias x = new Alias(a, "x"); List projList3 = Lists.newArrayList(x, b, c); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java index ec7347ebbddd62..0810b5ed0a953d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java @@ -27,13 +27,15 @@ import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; +import org.apache.doris.nereids.trees.plans.AggMode; import org.apache.doris.nereids.trees.plans.AggPhase; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit; import org.apache.doris.nereids.trees.plans.physical.PhysicalLocalQuickSort; @@ -271,14 +273,13 @@ public void testNestedLoopJoin() { @Test public void testLocalPhaseAggregate() { SlotReference key = new SlotReference("col1", IntegerType.INSTANCE); - PhysicalAggregate aggregate = new 
PhysicalAggregate<>( + PhysicalHashAggregate aggregate = new PhysicalHashAggregate<>( Lists.newArrayList(key), Lists.newArrayList(key), - Lists.newArrayList(key), - AggPhase.LOCAL, - true, + new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER), true, logicalProperties, + RequireProperties.of(PhysicalProperties.GATHER), groupPlan ); GroupExpression groupExpression = new GroupExpression(aggregate); @@ -296,14 +297,13 @@ public void testLocalPhaseAggregate() { public void testGlobalPhaseAggregate() { SlotReference key = new SlotReference("col1", IntegerType.INSTANCE); SlotReference partition = new SlotReference("col2", BigIntType.INSTANCE); - PhysicalAggregate aggregate = new PhysicalAggregate<>( + PhysicalHashAggregate aggregate = new PhysicalHashAggregate<>( Lists.newArrayList(key), Lists.newArrayList(key), - Lists.newArrayList(partition), - AggPhase.GLOBAL, - true, + new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT), true, logicalProperties, + RequireProperties.of(PhysicalProperties.createHash(ImmutableList.of(partition), ShuffleType.AGGREGATE)), groupPlan ); GroupExpression groupExpression = new GroupExpression(aggregate); @@ -326,14 +326,13 @@ public void testGlobalPhaseAggregate() { @Test public void testAggregateWithoutGroupBy() { - PhysicalAggregate aggregate = new PhysicalAggregate<>( + PhysicalHashAggregate aggregate = new PhysicalHashAggregate<>( Lists.newArrayList(), Lists.newArrayList(), - Lists.newArrayList(), - AggPhase.GLOBAL, - true, + new AggregateParam(AggPhase.LOCAL, AggMode.BUFFER_TO_RESULT), true, logicalProperties, + RequireProperties.of(PhysicalProperties.GATHER), groupPlan ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java index b6330c29482540..9e007f8bea1260 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java @@ -24,19 +24,22 @@ import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; +import org.apache.doris.nereids.trees.plans.AggMode; import org.apache.doris.nereids.trees.plans.AggPhase; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import mockit.Expectations; import mockit.Injectable; @@ -143,14 +146,13 @@ Pair, List> getOnClauseUsedSlots( @Test public void testLocalAggregate() { SlotReference key = new SlotReference("col1", 
IntegerType.INSTANCE); - PhysicalAggregate aggregate = new PhysicalAggregate<>( + PhysicalHashAggregate aggregate = new PhysicalHashAggregate<>( Lists.newArrayList(key), Lists.newArrayList(key), - Lists.newArrayList(key), - AggPhase.LOCAL, + new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT), true, - false, logicalProperties, + RequireProperties.of(PhysicalProperties.ANY), groupPlan ); GroupExpression groupExpression = new GroupExpression(aggregate); @@ -166,14 +168,13 @@ public void testLocalAggregate() { public void testGlobalAggregate() { SlotReference key = new SlotReference("col1", IntegerType.INSTANCE); SlotReference partition = new SlotReference("partition", IntegerType.INSTANCE); - PhysicalAggregate aggregate = new PhysicalAggregate<>( + PhysicalHashAggregate aggregate = new PhysicalHashAggregate<>( Lists.newArrayList(key), Lists.newArrayList(key), - Lists.newArrayList(partition), - AggPhase.GLOBAL, - true, + new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT), true, logicalProperties, + RequireProperties.of(PhysicalProperties.createHash(ImmutableList.of(partition), ShuffleType.AGGREGATE)), groupPlan ); GroupExpression groupExpression = new GroupExpression(aggregate); @@ -191,14 +192,13 @@ public void testGlobalAggregate() { @Test public void testGlobalAggregateWithoutPartition() { SlotReference key = new SlotReference("col1", IntegerType.INSTANCE); - PhysicalAggregate aggregate = new PhysicalAggregate<>( + PhysicalHashAggregate aggregate = new PhysicalHashAggregate<>( Lists.newArrayList(), Lists.newArrayList(key), - Lists.newArrayList(), - AggPhase.GLOBAL, - true, + new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT), true, logicalProperties, + RequireProperties.of(PhysicalProperties.GATHER), groupPlan ); GroupExpression groupExpression = new GroupExpression(aggregate); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java index 6e59b9123dc07f..8aa1e1cc84134e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java @@ -25,7 +25,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleSet; -import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; +import org.apache.doris.nereids.rules.rewrite.AggregateStrategies; import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnAgg; import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnProjectUnderAgg; import org.apache.doris.nereids.rules.rewrite.logical.ExistsApplyToJoin; @@ -126,7 +126,7 @@ public void testTranslateCase() throws Exception { new MockUp() { @Mock public List getExplorationRules() { - return Lists.newArrayList(new AggregateDisassemble().build()); + return Lists.newArrayList(new AggregateStrategies().buildRules()); } }; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckAnalysisTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckAnalysisTest.java index 65c15602e0a415..045a6f026ec4df 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckAnalysisTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckAnalysisTest.java @@ -39,6 +39,7 @@ public 
class CheckAnalysisTest { public void testCheckExpressionInputTypes() { Plan plan = new LogicalFilter<>(new And(new IntegerLiteral(1), BooleanLiteral.TRUE), groupPlan); CheckAnalysis checkAnalysis = new CheckAnalysis(); - Assertions.assertThrows(RuntimeException.class, () -> checkAnalysis.build().transform(plan, cascadesContext)); + Assertions.assertThrows(RuntimeException.class, () -> + checkAnalysis.buildRules().forEach(rule -> rule.transform(plan, cascadesContext))); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/RegisterCTETest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/RegisterCTETest.java index 6a0321be355306..542e28f1a4c12d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/RegisterCTETest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/RegisterCTETest.java @@ -28,7 +28,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleSet; -import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; +import org.apache.doris.nereids.rules.rewrite.AggregateStrategies; import org.apache.doris.nereids.rules.rewrite.logical.InApplyToJoin; import org.apache.doris.nereids.rules.rewrite.logical.PushApplyUnderFilter; import org.apache.doris.nereids.rules.rewrite.logical.PushApplyUnderProject; @@ -120,7 +120,7 @@ public void testTranslateCase() throws Exception { new MockUp() { @Mock public List getExplorationRules() { - return Lists.newArrayList(new AggregateDisassemble().build()); + return Lists.newArrayList(new AggregateStrategies().buildRules()); } }; @@ -198,9 +198,9 @@ public void testCTEInHavingAndSubquery() { false, ImmutableList.of("cte1")); SlotReference region2 = new SlotReference(new ExprId(12), "s_region", VarcharType.INSTANCE, false, ImmutableList.of("cte2")); - SlotReference count = new SlotReference(new ExprId(14), "count()", BigIntType.INSTANCE, + SlotReference count = new SlotReference(new ExprId(14), "count(*)", BigIntType.INSTANCE, false, ImmutableList.of()); - Alias countAlias = new Alias(new ExprId(14), new Count(), "count()"); + Alias countAlias = new Alias(new ExprId(14), new Count(), "count(*)"); PlanChecker.from(connectContext) .analyze(sql3) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java index c84af5e198c06b..6ce5ff20626697 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java @@ -28,7 +28,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; -import org.apache.doris.nereids.trees.expressions.literal.IntervalLiteral.TimeUnit; +import org.apache.doris.nereids.trees.expressions.literal.Interval.TimeUnit; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DateTimeType; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/implementation/ImplementationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/implementation/ImplementationTest.java index 
1a624f32d97f4c..657eea4f7b669e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/implementation/ImplementationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/implementation/ImplementationTest.java @@ -24,7 +24,6 @@ import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PlanType; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; @@ -54,7 +53,6 @@ public class ImplementationTest { private static final Map rulesMap = ImmutableMap.builder() .put(LogicalProject.class.getName(), (new LogicalProjectToPhysicalProject()).build()) - .put(LogicalAggregate.class.getName(), (new LogicalAggToPhysicalHashAgg()).build()) .put(LogicalJoin.class.getName(), (new LogicalJoinToHashJoin()).build()) .put(LogicalOlapScan.class.getName(), (new LogicalOlapScanToPhysicalOlapScan()).build()) .put(LogicalFilter.class.getName(), (new LogicalFilterToPhysicalFilter()).build()) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectRollupIndexTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectRollupIndexTest.java index 6eb5f301600f37..840402bb72b10d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectRollupIndexTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectRollupIndexTest.java @@ -364,7 +364,7 @@ public void testCountDistinctKeyColumn() { public void testCountDistinctValueColumn() { singleTableTest("select k1, count(distinct v1) from from t group by k1", scan -> { Assertions.assertFalse(scan.isPreAggregation()); - Assertions.assertEquals("Count distinct is only valid for key columns, but meet count(distinct v1).", + Assertions.assertEquals("Count distinct is only valid for key columns, but meet count(DISTINCT v1).", scan.getReasonOfPreAggregation()); Assertions.assertEquals("t", scan.getSelectedIndexName()); }); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java deleted file mode 100644 index 4292c52541c101..00000000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java +++ /dev/null @@ -1,354 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
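// A minimal sketch of the constructor shape the updated property-deriver tests above now use:
// PhysicalHashAggregate takes an AggregateParam (phase + mode) and a RequireProperties instead of a
// bare AggPhase plus partition list. `key`, `logicalProperties` and `groupPlan` stand in for the test
// fixtures and are assumptions of this snippet, mirroring the testLocalAggregate construction:
PhysicalHashAggregate<GroupPlan> localAgg = new PhysicalHashAggregate<>(
        Lists.newArrayList(key),                                      // group by expressions
        Lists.newArrayList(key),                                      // output expressions
        new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT),  // one-phase local aggregation
        true,                                                         // maybeUsingStream
        logicalProperties,
        RequireProperties.of(PhysicalProperties.ANY),                 // no shuffle required from the child
        groupPlan);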
- -package org.apache.doris.nereids.rules.rewrite.logical; - -import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; -import org.apache.doris.nereids.rules.rewrite.DistinctAggregateDisassemble; -import org.apache.doris.nereids.trees.expressions.Add; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; -import org.apache.doris.nereids.trees.expressions.functions.agg.Count; -import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; -import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; -import org.apache.doris.nereids.trees.plans.AggPhase; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.RelationUtil; -import org.apache.doris.nereids.util.MemoTestUtils; -import org.apache.doris.nereids.util.PatternMatchSupported; -import org.apache.doris.nereids.util.PlanChecker; -import org.apache.doris.nereids.util.PlanConstructor; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.TestInstance; - -import java.util.ArrayList; -import java.util.List; - -@TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class AggregateDisassembleTest implements PatternMatchSupported { - private Plan rStudent; - - @BeforeAll - public final void beforeAll() { - rStudent = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student, - ImmutableList.of("")); - } - - /** - *
-     * the initial plan is:
-     *   Aggregate(phase: [GLOBAL], outputExpr: [age, SUM(id) as sum], groupByExpr: [age])
-     *   +--childPlan(id, name, age)
-     * we should rewrite to:
-     *   Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr: [a])
-     *   +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age])
-     *       +--childPlan(id, name, age)
-     *
- */ - @Test - public void slotReferenceGroupBy() { - List groupExpressionList = Lists.newArrayList( - rStudent.getOutput().get(2).toSlot()); - List outputExpressionList = Lists.newArrayList( - rStudent.getOutput().get(2).toSlot(), - new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum")); - Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); - - Expression localOutput0 = rStudent.getOutput().get(2).toSlot(); - Sum localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot()); - Expression localGroupBy = rStudent.getOutput().get(2).toSlot(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), root) - .applyTopDown(new AggregateDisassemble()) - .printlnTree() - .matchesFromRoot( - logicalAggregate( - logicalAggregate() - .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) - .when(agg -> agg.getOutputExpressions().size() == 2) - .when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0)) - .when(agg -> agg.getOutputExpressions().get(1).child(0) - .children().equals(localOutput1.children())) - .when(agg -> agg.getGroupByExpressions().size() == 1) - .when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy)) - ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) - .when(agg -> agg.getOutputExpressions().size() == 2) - .when(agg -> agg.getOutputExpressions().get(0) - .equals(agg.child().getOutputExpressions().get(0).toSlot())) - .when(agg -> agg.getOutputExpressions().get(1).child(0).child(0) - .equals(agg.child().getOutputExpressions().get(1).toSlot())) - .when(agg -> agg.getGroupByExpressions().size() == 1) - .when(agg -> agg.getGroupByExpressions().get(0) - .equals(agg.child().getOutputExpressions().get(0).toSlot())) - // check id: - .when(agg -> agg.getOutputExpressions().get(0).getExprId() - .equals(outputExpressionList.get(0).getExprId())) - .when(agg -> agg.getOutputExpressions().get(1).getExprId() - .equals(outputExpressionList.get(1).getExprId())) - ); - } - - /** - *
-     * the initial plan is:
-     *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [])
-     *   +--childPlan(id, name, age)
-     * we should rewrite to:
-     *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as b], groupByExpr: [])
-     *   +--Aggregate(phase: [LOCAL], outputExpr: [SUM(id) as a], groupByExpr: [])
-     *       +--childPlan(id, name, age)
-     *
- */ - @Test - public void globalAggregate() { - List groupExpressionList = Lists.newArrayList(); - List outputExpressionList = Lists.newArrayList( - new Alias(new Sum(rStudent.getOutput().get(0)), "sum")); - Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); - - Sum localOutput0 = new Sum(rStudent.getOutput().get(0).toSlot()); - - PlanChecker.from(MemoTestUtils.createConnectContext(), root) - .applyTopDown(new AggregateDisassemble()) - .printlnTree() - .matchesFromRoot( - logicalAggregate( - logicalAggregate() - .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) - .when(agg -> agg.getOutputExpressions().size() == 1) - .when(agg -> agg.getOutputExpressions().get(0).child(0).child(0) - .equals(localOutput0.child())) - .when(agg -> agg.getGroupByExpressions().size() == 0) - ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) - .when(agg -> agg.getOutputExpressions().size() == 1) - .when(agg -> agg.getOutputExpressions().get(0) instanceof Alias) - .when(agg -> agg.getOutputExpressions().get(0).child(0).child(0) - .equals(agg.child().getOutputExpressions().get(0).toSlot())) - .when(agg -> agg.getGroupByExpressions().size() == 0) - // check id: - .when(agg -> agg.getOutputExpressions().get(0).getExprId() - .equals(outputExpressionList.get(0).getExprId())) - ); - } - - /** - *
-     * the initial plan is:
-     *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [age])
-     *   +--childPlan(id, name, age)
-     * we should rewrite to:
-     *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as c], groupByExpr: [a])
-     *   +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age])
-     *       +--childPlan(id, name, age)
-     *
- */ - @Test - public void groupExpressionNotInOutput() { - List groupExpressionList = Lists.newArrayList( - rStudent.getOutput().get(2).toSlot()); - List outputExpressionList = Lists.newArrayList( - new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum")); - Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); - - Expression localOutput0 = rStudent.getOutput().get(2).toSlot(); - Sum localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot()); - Expression localGroupBy = rStudent.getOutput().get(2).toSlot(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), root) - .applyTopDown(new AggregateDisassemble()) - .printlnTree() - .matchesFromRoot( - logicalAggregate( - logicalAggregate() - .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) - .when(agg -> agg.getOutputExpressions().size() == 2) - .when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0)) - .when(agg -> agg.getOutputExpressions().get(1).child(0).child(0) - .equals(localOutput1.child())) - .when(agg -> agg.getGroupByExpressions().size() == 1) - .when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy)) - ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) - .when(agg -> agg.getOutputExpressions().size() == 1) - .when(agg -> agg.getOutputExpressions().get(0) instanceof Alias) - .when(agg -> agg.getOutputExpressions().get(0).child(0).child(0) - .equals(agg.child().getOutputExpressions().get(1).toSlot())) - .when(agg -> agg.getGroupByExpressions().size() == 1) - .when(agg -> agg.getGroupByExpressions().get(0) - .equals(agg.child().getOutputExpressions().get(0).toSlot())) - // check id: - .when(agg -> agg.getOutputExpressions().get(0).getExprId() - .equals(outputExpressionList.get(0).getExprId())) - ); - } - - /** - *
-     * the initial plan is:
-     *   Aggregate(phase: [LOCAL], outputExpr: [(COUNT(distinct age) + 2) as c], groupByExpr: [])
-     *   +-- childPlan(id, name, age)
-     * we should rewrite to:
-     *   Aggregate(phase: [GLOBAL], outputExpr: [count(distinct c)], groupByExpr: [])
-     *   +-- Aggregate(phase: [LOCAL], outputExpr: [(COUNT(distinct age) + 2) as c], groupByExpr: [])
-     *       +-- childPlan(id, name, age)
-     *
- */ - @Test - public void distinctAggregateWithoutGroupByApply2PhaseRule() { - List groupExpressionList = new ArrayList<>(); - List outputExpressionList = Lists.newArrayList(new Alias( - new Add(new Count(AggregateParam.distinctAndFinalPhase(), rStudent.getOutput().get(2).toSlot()), - new IntegerLiteral(2)), "c")); - Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); - - PlanChecker.from(MemoTestUtils.createConnectContext(), root) - .applyTopDown(new AggregateDisassemble()) - .matchesFromRoot( - logicalAggregate( - logicalAggregate() - .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) - .when(agg -> agg.getOutputExpressions().size() == 1) - .when(agg -> agg.getGroupByExpressions().isEmpty()) - ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) - .when(agg -> agg.getOutputExpressions().size() == 1) - .when(agg -> agg.getGroupByExpressions().isEmpty()) - ); - } - - @Test - public void distinctWithNormalAggregateFunctionApply2PhaseRule() { - List groupExpressionList = Lists.newArrayList(rStudent.getOutput().get(0).toSlot()); - List outputExpressionList = Lists.newArrayList( - new Alias(new Count(AggregateParam.distinctAndFinalPhase(), rStudent.getOutput().get(2).toSlot()), "c"), - new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum")); - Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); - - // check local: - // id - Expression localOutput0 = rStudent.getOutput().get(0); - // count - Count localOutput1 = new Count(new AggregateParam(true, false, AggPhase.LOCAL, true), rStudent.getOutput().get(2).toSlot()); - // sum - Sum localOutput2 = new Sum(new AggregateParam(false, false, AggPhase.LOCAL, true), rStudent.getOutput().get(0).toSlot()); - // id - Expression localGroupBy0 = rStudent.getOutput().get(0); - - PlanChecker.from(MemoTestUtils.createConnectContext(), root) - .applyTopDown(new AggregateDisassemble()) - .matchesFromRoot( - logicalAggregate( - logicalAggregate() - .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) - .when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0)) - .when(agg -> agg.getOutputExpressions().get(1).child(0).equals(localOutput1)) - .when(agg -> agg.getOutputExpressions().get(2).child(0).equals(localOutput2)) - .when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy0)) - ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) - .when(agg -> { - Slot child = agg.child().getOutputExpressions().get(1).toSlot(); - Assertions.assertTrue(agg.getOutputExpressions().get(0).child(0) instanceof Count); - return agg.getOutputExpressions().get(0).child(0).child(0).equals(child); - }) - .when(agg -> { - Slot child = agg.child().getOutputExpressions().get(2).toSlot(); - Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof Sum); - return ((Sum) agg.getOutputExpressions().get(1).child(0)).child().equals(child); - }) - .when(agg -> agg.getGroupByExpressions().get(0) - .equals(agg.child().getOutputExpressions().get(0))) - ); - } - - @Test - public void distinctWithNormalAggregateFunctionApply4PhaseRule() { - List groupExpressionList = Lists.newArrayList(rStudent.getOutput().get(0).toSlot()); - List outputExpressionList = Lists.newArrayList( - new Alias(new Count(AggregateParam.distinctAndFinalPhase(), rStudent.getOutput().get(2).toSlot()), "c"), - new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum")); - Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); - - // check local: - // id - 
Expression localOutput0 = rStudent.getOutput().get(0); - // count - Count localOutput1 = new Count(new AggregateParam(true, false, AggPhase.LOCAL, true), rStudent.getOutput().get(2).toSlot()); - // sum - Sum localOutput2 = new Sum(new AggregateParam(false, false, AggPhase.LOCAL, true), rStudent.getOutput().get(0).toSlot()); - // id - Expression localGroupBy0 = rStudent.getOutput().get(0); - - PlanChecker.from(MemoTestUtils.createConnectContext(), root) - .applyTopDown(new DistinctAggregateDisassemble()) - .matchesFromRoot( - logicalAggregate( - logicalAggregate( - logicalAggregate( - logicalAggregate() - .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) - .when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0)) - .when(agg -> agg.getOutputExpressions().get(1).child(0).equals(localOutput1)) - .when(agg -> agg.getOutputExpressions().get(2).child(0).equals(localOutput2)) - .when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy0)) - ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) - .when(agg -> { - Slot child = agg.child().getOutputExpressions().get(1).toSlot(); - Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof Count); - return agg.getOutputExpressions().get(1).child(0).child(0).equals(child); - }) - .when(agg -> { - Slot child = agg.child().getOutputExpressions().get(2).toSlot(); - Assertions.assertTrue(agg.getOutputExpressions().get(2).child(0) instanceof Sum); - return ((Sum) agg.getOutputExpressions().get(2).child(0)).child().equals(child); - }) - .when(agg -> agg.getGroupByExpressions().get(0) - .equals(agg.child().getOutputExpressions().get(0))) - ).when(agg -> agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL)) - .when(agg -> { - Slot child = agg.child().getOutputExpressions().get(1).toSlot(); - Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof Count); - return agg.getOutputExpressions().get(1).child(0).child(0).equals(child); - }) - .when(agg -> { - Slot child = agg.child().getOutputExpressions().get(2).toSlot(); - Assertions.assertTrue(agg.getOutputExpressions().get(2).child(0) instanceof Sum); - return ((Sum) agg.getOutputExpressions().get(2).child(0)).child().equals(child); - }) - .when(agg -> agg.getGroupByExpressions().get(0) - .equals(agg.child().getOutputExpressions().get(0))) - ).when(agg -> agg.getAggPhase().equals(AggPhase.DISTINCT_GLOBAL)) - .when(agg -> agg.getOutputExpressions().size() == 2) - .when(agg -> agg.getOutputExpressions().get(0) instanceof Alias) - .when(agg -> agg.getOutputExpressions().get(0).child(0) instanceof Count) - .when(agg -> agg.getOutputExpressions().get(1).child(0) instanceof Sum) - .when(agg -> agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get( - 0).getExprId()) - .when(agg -> agg.getOutputExpressions().get(1).getExprId() == outputExpressionList.get( - 1).getExprId()) - .when(agg -> agg.getGroupByExpressions().get(0) - .equals(agg.child().child().child().getOutputExpressions().get(0))) - ); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateStrategiesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateStrategiesTest.java new file mode 100644 index 00000000000000..0376aa23a559d2 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateStrategiesTest.java @@ -0,0 +1,400 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.annotation.Developing; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.AggregateStrategies; +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.AggregateExpression; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.plans.AggMode; +import org.apache.doris.nereids.trees.plans.AggPhase; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.RelationUtil; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class AggregateStrategiesTest implements PatternMatchSupported { + private Plan rStudent; + + @BeforeAll + public final void beforeAll() { + rStudent = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student, + ImmutableList.of("")); + } + + /** + *
+     * the initial plan is:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [age, SUM(id) as sum], groupByExpr: [age])
+     *   +--childPlan(id, name, age)
+     * we should rewrite to:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr: [a])
+     *   +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age])
+     *       +--childPlan(id, name, age)
+     *
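+     * A minimal sketch (illustration only, not from this patch) of what the two-phase
+     * split computes, assuming plain Java rows with int fields age and id:
+     *   Map<Integer, Long> local = new HashMap<>();   // LOCAL: partial SUM(id) per age
+     *   rows.forEach(r -> local.merge(r.age, (long) r.id, Long::sum));
+     *   Map<Integer, Long> global = new HashMap<>();  // GLOBAL: merge the partial sums (one local map per instance)
+     *   local.forEach((age, partial) -> global.merge(age, partial, Long::sum));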
+ */ + @Test + public void slotReferenceGroupBy() { + List groupExpressionList = Lists.newArrayList( + rStudent.getOutput().get(2).toSlot()); + List outputExpressionList = Lists.newArrayList( + rStudent.getOutput().get(2).toSlot(), + new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum")); + Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, + true, Optional.empty(), rStudent); + + Expression localOutput0 = rStudent.getOutput().get(2).toSlot(); + Sum localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot()); + Slot localGroupBy = rStudent.getOutput().get(2).toSlot(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyImplementation(twoPhaseAggregateWithoutDistinct()) + .matches( + physicalHashAggregate( + physicalHashAggregate() + .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) + .when(agg -> agg.getOutputExpressions().size() == 2) + .when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0)) + .when(agg -> agg.getOutputExpressions().get(1).child(0).child(0) + .children().equals(localOutput1.children())) + .when(agg -> agg.getGroupByExpressions().size() == 1) + .when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy)) + ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) + .when(agg -> agg.getOutputExpressions().size() == 2) + .when(agg -> agg.getOutputExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0).toSlot())) + .when(agg -> agg.getOutputExpressions().get(1).child(0).child(0) + .equals(agg.child().getOutputExpressions().get(1).toSlot())) + .when(agg -> agg.getGroupByExpressions().size() == 1) + .when(agg -> agg.getGroupByExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0).toSlot())) + // check id: + .when(agg -> agg.getOutputExpressions().get(0).getExprId() + .equals(outputExpressionList.get(0).getExprId())) + .when(agg -> agg.getOutputExpressions().get(1).getExprId() + .equals(outputExpressionList.get(1).getExprId())) + ); + } + + /** + *
+     * the initial plan is:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [])
+     *   +--childPlan(id, name, age)
+     * we should rewrite to:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(a) as sum], groupByExpr: [])
+     *   +--Aggregate(phase: [LOCAL], outputExpr: [SUM(id) as a], groupByExpr: [])
+     *       +--childPlan(id, name, age)
+     *
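+     * A minimal sketch (illustration only) of the no-group-by case: every LOCAL instance
+     * reduces its rows to a single partial sum and the GLOBAL phase adds the partials up,
+     * assuming rows and partials are plain Java collections:
+     *   long partial = rows.stream().mapToLong(r -> r.id).sum();        // LOCAL: SUM(id) as a
+     *   long sum = partials.stream().mapToLong(Long::longValue).sum();  // GLOBAL: SUM(a) as sum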
+ */ + @Test + public void globalAggregate() { + List groupExpressionList = Lists.newArrayList(); + List outputExpressionList = Lists.newArrayList( + new Alias(new Sum(rStudent.getOutput().get(0)), "sum")); + Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, + true, Optional.empty(), rStudent); + + Sum localOutput0 = new Sum(rStudent.getOutput().get(0).toSlot()); + + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyImplementation(twoPhaseAggregateWithoutDistinct()) + .matches( + physicalHashAggregate( + physicalHashAggregate() + .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) + .when(agg -> agg.getOutputExpressions().size() == 1) + .when(agg -> agg.getOutputExpressions().get(0).child(0).child(0) + .equals(localOutput0)) + .when(agg -> agg.getGroupByExpressions().size() == 0) + ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) + .when(agg -> agg.getOutputExpressions().size() == 1) + .when(agg -> agg.getOutputExpressions().get(0) instanceof Alias) + .when(agg -> agg.getOutputExpressions().get(0).child(0).child(0) + .equals(agg.child().getOutputExpressions().get(0).toSlot())) + .when(agg -> agg.getGroupByExpressions().size() == 0) + // check id: + .when(agg -> agg.getOutputExpressions().get(0).getExprId() + .equals(outputExpressionList.get(0).getExprId())) + ); + } + + /** + *
+     * the initial plan is:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [age])
+     *   +--childPlan(id, name, age)
+     * we should rewrite to:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as c], groupByExpr: [a])
+     *   +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age])
+     *       +--childPlan(id, name, age)
+     *
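+     * Illustration only: although the final output has no grouping column, the LOCAL phase
+     * still has to emit the key (age as a) so the GLOBAL phase can merge partials per key;
+     * only the projection on top drops it:
+     *   Map<Integer, Long> local = new HashMap<>();   // LOCAL: age as a, SUM(id) as b
+     *   rows.forEach(r -> local.merge(r.age, (long) r.id, Long::sum));
+     *   // GLOBAL merges per key a; the final projection keeps only SUM(b) as c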
+ */ + @Test + public void groupExpressionNotInOutput() { + List groupExpressionList = Lists.newArrayList( + rStudent.getOutput().get(2).toSlot()); + List outputExpressionList = Lists.newArrayList( + new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum")); + Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, + true, Optional.empty(), rStudent); + + Expression localOutput0 = rStudent.getOutput().get(2).toSlot(); + Sum localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot()); + Expression localGroupBy = rStudent.getOutput().get(2).toSlot(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyImplementation(twoPhaseAggregateWithoutDistinct()) + .matches( + physicalHashAggregate( + physicalHashAggregate() + .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) + .when(agg -> agg.getOutputExpressions().size() == 2) + .when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0)) + .when(agg -> agg.getOutputExpressions().get(1).child(0).child(0) + .equals(localOutput1)) + .when(agg -> agg.getGroupByExpressions().size() == 1) + .when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy)) + ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) + .when(agg -> agg.getOutputExpressions().size() == 1) + .when(agg -> agg.getOutputExpressions().get(0) instanceof Alias) + .when(agg -> agg.getOutputExpressions().get(0).child(0).child(0) + .equals(agg.child().getOutputExpressions().get(1).toSlot())) + .when(agg -> agg.getGroupByExpressions().size() == 1) + .when(agg -> agg.getGroupByExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0).toSlot())) + // check id: + .when(agg -> agg.getOutputExpressions().get(0).getExprId() + .equals(outputExpressionList.get(0).getExprId())) + ); + } + + /** + *
+     * the initial plan is:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age) + 2) as c], groupByExpr: [])
+     *   +-- childPlan(id, name, age)
+     * we should rewrite to:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [count(distinct age)], groupByExpr: [])
+     *   +-- Aggregate(phase: [LOCAL], outputExpr: [age], groupByExpr: [age])
+     *       +-- childPlan(id, name, age)
+     *
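+     * A minimal sketch (illustration only) of the distinct split, assuming localSets holds
+     * one deduplicated set per LOCAL instance:
+     *   Set<Integer> localAges = new HashSet<>();     // LOCAL: group by age, output age
+     *   rows.forEach(r -> localAges.add(r.age));
+     *   Set<Integer> merged = new HashSet<>();        // GLOBAL: union the local sets
+     *   localSets.forEach(merged::addAll);
+     *   long c = merged.size() + 2;                   // (COUNT(distinct age) + 2) as c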
+ */ + @Test + public void distinctAggregateWithoutGroupByApply2PhaseRule() { + List groupExpressionList = new ArrayList<>(); + List outputExpressionList = Lists.newArrayList(new Alias( + new Add(new Count(true, rStudent.getOutput().get(2).toSlot()), + new IntegerLiteral(2)), "c")); + Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, + false, Optional.empty(), rStudent); + + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyBottomUp(new NormalizeAggregate()) + .applyImplementation(twoPhaseAggregateWithDistinct()) + .matches( + physicalHashAggregate( + physicalHashAggregate() + .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) + .when(agg -> agg.getOutputExpressions().size() == 1) + .when(agg -> agg.getGroupByExpressions().size() == 1 + && agg.getGroupByExpressions().get(0).equals(rStudent.getOutput().get(2).toSlot())) // group by name + ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) + .when(agg -> agg.getOutputExpressions().size() == 1 + && agg.getOutputExpressions().get(0).child(0) instanceof AggregateExpression + && agg.getOutputExpressions().get(0).child(0).child(0) instanceof Count + && agg.getOutputExpressions().get(0).child(0).child(0).child(0).equals(rStudent.getOutput().get(2).toSlot())) // count(name) + .when(agg -> agg.getGroupByExpressions().isEmpty()) + ); + } + + @Test + public void distinctWithNormalAggregateFunctionApply2PhaseRule() { + Slot id = rStudent.getOutput().get(0); + Slot name = rStudent.getOutput().get(2).toSlot(); + List groupExpressionList = Lists.newArrayList(id.toSlot()); + List outputExpressionList = Lists.newArrayList( + new Alias(new Count(true, name), "c"), + new Alias(new Sum(id.toSlot()), "sum")); + Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, + true, Optional.empty(), rStudent); + + // check local: + // id + AggregateParam phaseTwoCountAggParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT); + AggregateParam phaseOneSumAggParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + AggregateParam phaseTwoSumAggParam = new AggregateParam(AggPhase.GLOBAL, AggMode.INPUT_TO_RESULT); + // sum + Sum sumId = new Sum(false, id.toSlot()); + + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyImplementation(twoPhaseAggregateWithDistinct()) + .matches( + physicalHashAggregate( + physicalHashAggregate() + .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL) + && agg.getOutputExpressions().size() == 3 + && agg.getGroupByExpressions().size() == 2) + .when(agg -> agg.getOutputExpressions().get(0).equals(id)) + .when(agg -> agg.getOutputExpressions().get(1).equals(name)) + .when(agg -> agg.getOutputExpressions().get(2).child(0).child(0).equals(sumId) + && ((AggregateExpression) agg.getOutputExpressions().get(2).child(0)).getAggregateParam().equals(phaseOneSumAggParam)) + .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(id, name))) + ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) + .when(agg -> { + Slot child = agg.child().getOutputExpressions().get(1).toSlot(); + Assertions.assertTrue(agg.getOutputExpressions().get(0).child(0).child(0) instanceof Count); + return agg.getOutputExpressions().get(0).child(0).child(0).child(0).equals(child); + }) + .when(agg -> { + Slot partialSum = agg.child().getOutputExpressions().get(2).toSlot(); + Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof AggregateExpression); + Assertions.assertEquals(phaseTwoSumAggParam, ((AggregateExpression) 
agg.getOutputExpressions().get(1).child(0)).getAggregateParam()); + Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0).child(0).equals(partialSum)); + + Assertions.assertEquals(phaseTwoCountAggParam, ((AggregateExpression) agg.getOutputExpressions().get(0).child(0)).getAggregateParam()); + return true; + }) + .when(agg -> agg.getGroupByExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0))) + ); + } + + @Test + @Disabled + @Developing("not support four phase aggregate") + public void distinctWithNormalAggregateFunctionApply4PhaseRule() { + Slot id = rStudent.getOutput().get(0).toSlot(); + Slot name = rStudent.getOutput().get(2).toSlot(); + List groupExpressionList = Lists.newArrayList(id); + List outputExpressionList = Lists.newArrayList( + new Alias(new Count(true, name), "c"), + new Alias(new Sum(id), "sum")); + Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, + true, Optional.empty(), rStudent); + + // check local: + // count + Count phaseOneCountName = new Count(true, name); + // sum + Sum phaseOneSumId = new Sum(id); + + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyImplementation(fourPhaseAggregateWithDistinct()) + .matchesFromRoot( + physicalHashAggregate( + physicalHashAggregate( + physicalHashAggregate( + physicalHashAggregate() // select id, count(distinct name), sum(id) group by id + .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) + .when(agg -> agg.getOutputExpressions().get(0).equals(id)) + .when(agg -> agg.getOutputExpressions().get(1).child(0).equals(phaseOneCountName)) + .when(agg -> agg.getOutputExpressions().get(2).child(0).equals(phaseOneSumId)) + .when(agg -> agg.getGroupByExpressions().get(0).equals(id)) + ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) // select id, count(distinct name), sum(id) group by id + .when(agg -> { + Slot child = agg.child().getOutputExpressions().get(1).toSlot(); + Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof Count); + return agg.getOutputExpressions().get(1).child(0).child(0).equals(child); + }) + .when(agg -> { + Slot child = agg.child().getOutputExpressions().get(2).toSlot(); + Assertions.assertTrue(agg.getOutputExpressions().get(2).child(0) instanceof Sum); + return ((Sum) agg.getOutputExpressions().get(2).child(0)).child().equals(child); + }) + .when(agg -> agg.getGroupByExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0))) + ).when(agg -> agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL)) + .when(agg -> { + Slot child = agg.child().getOutputExpressions().get(1).toSlot(); + Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof Count); + return agg.getOutputExpressions().get(1).child(0).child(0).equals(child); + }) + .when(agg -> { + Slot child = agg.child().getOutputExpressions().get(2).toSlot(); + Assertions.assertTrue(agg.getOutputExpressions().get(2).child(0) instanceof Sum); + return ((Sum) agg.getOutputExpressions().get(2).child(0)).child().equals(child); + }) + .when(agg -> agg.getGroupByExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0))) + ).when(agg -> agg.getAggPhase().equals(AggPhase.DISTINCT_GLOBAL)) + .when(agg -> agg.getOutputExpressions().size() == 2) + .when(agg -> agg.getOutputExpressions().get(0) instanceof Alias) + .when(agg -> agg.getOutputExpressions().get(0).child(0) instanceof Count) + .when(agg -> agg.getOutputExpressions().get(1).child(0) instanceof Sum) + .when(agg -> agg.getOutputExpressions().get(0).getExprId() == 
outputExpressionList.get( + 0).getExprId()) + .when(agg -> agg.getOutputExpressions().get(1).getExprId() == outputExpressionList.get( + 1).getExprId()) + .when(agg -> agg.getGroupByExpressions().get(0) + .equals(agg.child().child().child().getOutputExpressions().get(0))) + ); + } + + private Rule twoPhaseAggregateWithoutDistinct() { + return new AggregateStrategies().buildRules() + .stream() + .filter(rule -> rule.getRuleType() == RuleType.TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT) + .findFirst() + .get(); + } + + private Rule twoPhaseAggregateWithDistinct() { + return new AggregateStrategies().buildRules() + .stream() + .filter(rule -> rule.getRuleType() == RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT) + .findFirst() + .get(); + } + + @Developing + private Rule fourPhaseAggregateWithDistinct() { + return new AggregateStrategies().buildRules() + .stream() + .filter(rule -> rule.getRuleType() == RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT) + .findFirst() + .get(); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/InferPredicatesTest.java index 3c31b0466caa59..79d7af0e49c2e9 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/InferPredicatesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/InferPredicatesTest.java @@ -17,7 +17,6 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.util.PatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; @@ -78,8 +77,7 @@ protected void runBeforeAll() throws Exception { @Test public void inferPredicatesTest01() { String sql = "select * from student join score on student.id = score.sid where student.id > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -98,8 +96,7 @@ public void inferPredicatesTest01() { @Test public void inferPredicatesTest02() { String sql = "select * from student join score on student.id = score.sid"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -114,8 +111,7 @@ public void inferPredicatesTest02() { @Test public void inferPredicatesTest03() { String sql = "select * from student join score on student.id = score.sid where student.id in (1,2,3)"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -132,8 +128,7 @@ public void inferPredicatesTest03() { @Test public void inferPredicatesTest04() { String sql = "select * from student join score on student.id = score.sid and student.id in (1,2,3)"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -150,8 +145,7 @@ public void inferPredicatesTest04() { @Test public void inferPredicatesTest05() { String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id where student.id > 1"; - Plan plan = 
PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -175,8 +169,7 @@ public void inferPredicatesTest05() { @Test public void inferPredicatesTest06() { String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id and score.sid > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -200,8 +193,7 @@ public void inferPredicatesTest06() { @Test public void inferPredicatesTest07() { String sql = "select * from student left join score on student.id = score.sid where student.id > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -220,8 +212,7 @@ public void inferPredicatesTest07() { @Test public void inferPredicatesTest08() { String sql = "select * from student left join score on student.id = score.sid and student.id > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -239,8 +230,7 @@ public void inferPredicatesTest08() { public void inferPredicatesTest09() { // convert left join to inner join String sql = "select * from student left join score on student.id = score.sid where score.sid > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -259,8 +249,7 @@ public void inferPredicatesTest09() { @Test public void inferPredicatesTest10() { String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid where t.nid > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -281,8 +270,7 @@ public void inferPredicatesTest10() { @Test public void inferPredicatesTest11() { String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid and t.nid > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -301,24 +289,21 @@ public void inferPredicatesTest11() { @Test public void inferPredicatesTest12() { String sql = "select * from student left join (select sid as nid, sum(grade) from score group by sid) s on s.nid = student.id where student.id > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() .matchesFromRoot( logicalJoin( - logicalFilter( - logicalOlapScan() - ).when(filer -> filer.getPredicates().toSql().contains("id > 1")), + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("id > 1")), + logicalAggregate( logicalProject( - logicalAggregate( - logicalProject( - logicalFilter( - logicalOlapScan() - ).when(filer -> filer.getPredicates().toSql().contains("sid > 1")) - )) - ) + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("sid > 1")) + )) ) 
); } @@ -326,8 +311,7 @@ public void inferPredicatesTest12() { @Test public void inferPredicatesTest13() { String sql = "select * from (select id, name from student where id = 1) t left join score on t.id = score.sid"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -348,8 +332,7 @@ public void inferPredicatesTest13() { @Test public void inferPredicatesTest14() { String sql = "select * from student left semi join score on student.id = score.sid where student.id > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -370,8 +353,7 @@ public void inferPredicatesTest14() { @Test public void inferPredicatesTest15() { String sql = "select * from student left semi join score on student.id = score.sid and student.id > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -392,8 +374,7 @@ public void inferPredicatesTest15() { @Test public void inferPredicatesTest16() { String sql = "select * from student left anti join score on student.id = score.sid and student.id > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -412,8 +393,7 @@ public void inferPredicatesTest16() { @Test public void inferPredicatesTest17() { String sql = "select * from student left anti join score on student.id = score.sid and score.sid > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -432,8 +412,7 @@ public void inferPredicatesTest17() { @Test public void inferPredicatesTest18() { String sql = "select * from student left anti join score on student.id = score.sid where student.id > 1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -453,9 +432,30 @@ public void inferPredicatesTest18() { @Test public void inferPredicatesTest19() { - String sql = "select * from subquery1 left semi join (select t1.k3 from (select * from subquery3 left semi join (select k1 from subquery4 where k1 = 3) t on subquery3.k3 = t.k1) t1 inner join (select k2,sum(k2) as sk2 from subquery2 group by k2) t2 on t2.k2 = t1.v1 and t1.v2 > t2.sk2) t3 on t3.k3 = subquery1.k1"; - Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); - System.out.println(plan.treeString()); + String sql = "select * from subquery1\n" + + "left semi join (\n" + + " select t1.k3\n" + + " from (\n" + + " select *\n" + + " from subquery3\n" + + " left semi join\n" + + " (\n" + + " select k1\n" + + " from subquery4\n" + + " where k1 = 3\n" + + " ) t\n" + + " on subquery3.k3 = t.k1\n" + + " ) t1\n" + + " inner join\n" + + " (\n" + + " select k2,sum(k2) as sk2\n" + + " from subquery2\n" + + " group by k2\n" + + " ) t2\n" + + " on t2.k2 = t1.v1 and t1.v2 > t2.sk2\n" + + ") t3\n" + + "on t3.k3 = subquery1.k1"; + PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -465,21 +465,26 @@ public void inferPredicatesTest19() { logicalOlapScan() 
).when(filter -> filter.getPredicates().toSql().contains("k1 = 3")), logicalProject( + logicalJoin( logicalJoin( - logicalJoin( - logicalProject( - logicalFilter( - logicalOlapScan() - ).when(filter -> filter.getPredicates().toSql().contains("k3 = 3")) - ), - logicalProject( - logicalFilter( - logicalOlapScan() - ).when(filter -> filter.getPredicates().toSql().contains("k1 = 3")) - ) - ), - logicalProject() + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("k3 = 3")) + ), + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("k1 = 3")) + ) + + ), + logicalAggregate( + logicalProject( + logicalOlapScan() + ) ) + ) ) ) ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java index 2ef52a42cf404a..53848e4c50299f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java @@ -53,7 +53,7 @@ public class NormalizeAggregateTest implements PatternMatchSupported { @BeforeAll public final void beforeAll() { rStudent = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student, - ImmutableList.of("")); + ImmutableList.of()); } /*- @@ -79,23 +79,25 @@ public void testSimpleKeyWithSimpleAggregateFunction() { .matchesFromRoot( logicalProject( logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key))) - .when(aggregate -> aggregate.getOutputExpressions().get(0).equals(key)) - .when(aggregate -> aggregate.getOutputExpressions().get(1).child(0) - .equals(aggregateFunction.child(0))) - .when(FieldChecker.check("normalized", true)) + logicalProject( + logicalOlapScan() + ) + ).when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(key))) + .when(aggregate -> aggregate.getOutputExpressions().get(0).equals(key)) + .when(aggregate -> aggregate.getOutputExpressions().get(1).child(0) + .equals(aggregateFunction.child(0))) + .when(FieldChecker.check("normalized", true)) ).when(project -> project.getProjects().get(0).equals(key)) - .when(project -> project.getProjects().get(1) instanceof Alias) - .when(project -> (project.getProjects().get(1)).getExprId() - .equals(aggregateFunction.getExprId())) - .when(project -> project.getProjects().get(1).child(0) instanceof SlotReference) + .when(project -> project.getProjects().get(1) instanceof Alias) + .when(project -> (project.getProjects().get(1)).getExprId() + .equals(aggregateFunction.getExprId())) + .when(project -> project.getProjects().get(1).child(0) instanceof SlotReference) ); } /*- * original plan: - * LogicalAggregate (phase: [GLOBAL], output: [(sum((id#0 * 1)) + 2) AS `(sum((id * 1)) + 2)`#4], groupBy: [name#2]) + * LogicalAggregate (output: [(sum((id#0 * 1)) + 2) AS `(sum((id * 1)) + 2)`#4], groupBy: [name#2]) * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3]) * * after rewrite: @@ -126,11 +128,12 @@ public void testComplexFuncWithComplexOutputOfFunc() { ).when(project -> project.getProjects().size() == 2) .when(project -> project.getProjects().get(0) instanceof SlotReference) .when(project -> project.getProjects().get(1).child(0).equals(multiply)) - ).when(FieldChecker.check("groupByExpressions", - 
ImmutableList.of(rStudent.getOutput().get(2)))) - .when(aggregate -> aggregate.getOutputExpressions().size() == 1) - .when(aggregate -> aggregate.getOutputExpressions().get(0) - .child(0) instanceof AggregateFunction) + ).when(agg -> agg.getGroupByExpressions().equals( + ImmutableList.of(rStudent.getOutput().get(2))) + ) + .when(aggregate -> aggregate.getOutputExpressions().size() == 2) + .when(aggregate -> aggregate.getOutputExpressions().get(1) + .child(0) instanceof AggregateFunction) ).when(project -> project.getProjects().size() == 1) .when(project -> project.getProjects().get(0) instanceof Alias) .when(project -> project.getProjects().get(0).getExprId().equals(output.getExprId())) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushAggregateToOlapScanTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java similarity index 55% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushAggregateToOlapScanTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java index 0f275e465e9766..dac706d82598a3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushAggregateToOlapScanTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java @@ -18,69 +18,87 @@ package org.apache.doris.nereids.rules.rewrite.logical; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.pattern.GeneratedPatterns; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RulePromise; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.AggregateStrategies; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.scalar.Ln; -import org.apache.doris.nereids.trees.plans.PushDownAggOperator; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate.PushDownAggOp; import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.Collections; +import java.util.Optional; -public class PushAggregateToOlapScanTest { +public class PhysicalStorageLayerAggregateTest implements GeneratedPatterns { @Test public void testWithoutProject() { LogicalOlapScan olapScan = PlanConstructor.newLogicalOlapScan(1, "tbl", 0); LogicalAggregate aggregate; CascadesContext context; - LogicalOlapScan pushedOlapScan; // min max aggregate = new LogicalAggregate<>( Collections.emptyList(), ImmutableList.of(new Alias(new Min(olapScan.getOutput().get(0)), "min")), - olapScan); + true, Optional.empty(), olapScan); context = MemoTestUtils.createCascadesContext(aggregate); - context.topDownRewrite(new PushAggregateToOlapScan()); - 
pushedOlapScan = (LogicalOlapScan) (context.getMemo().copyOut().child(0)); - Assertions.assertTrue(pushedOlapScan.isAggPushed()); - Assertions.assertEquals(PushDownAggOperator.MIN_MAX, pushedOlapScan.getPushDownAggOperator()); - + PlanChecker.from(context) + .applyImplementation(storageLayerAggregateWithoutProject()) + .matches( + logicalAggregate( + physicalStorageLayerAggregate().when(agg -> agg.getAggOp() == PushDownAggOp.MIN_MAX) + ) + ); // count aggregate = new LogicalAggregate<>( Collections.emptyList(), ImmutableList.of(new Alias(new Count(olapScan.getOutput().get(0)), "count")), - olapScan); + true, Optional.empty(), olapScan); context = MemoTestUtils.createCascadesContext(aggregate); - context.topDownRewrite(new PushAggregateToOlapScan()); - pushedOlapScan = (LogicalOlapScan) (context.getMemo().copyOut().child(0)); - Assertions.assertTrue(pushedOlapScan.isAggPushed()); - Assertions.assertEquals(PushDownAggOperator.COUNT, pushedOlapScan.getPushDownAggOperator()); + PlanChecker.from(context) + .applyImplementation(storageLayerAggregateWithoutProject()) + .matches( + logicalAggregate( + physicalStorageLayerAggregate().when(agg -> agg.getAggOp() == PushDownAggOp.COUNT) + ) + ); // mix aggregate = new LogicalAggregate<>( Collections.emptyList(), ImmutableList.of(new Alias(new Count(olapScan.getOutput().get(0)), "count"), new Alias(new Max(olapScan.getOutput().get(0)), "max")), - olapScan); + true, Optional.empty(), olapScan); context = MemoTestUtils.createCascadesContext(aggregate); - context.topDownRewrite(new PushAggregateToOlapScan()); - pushedOlapScan = (LogicalOlapScan) (context.getMemo().copyOut().child(0)); - Assertions.assertTrue(pushedOlapScan.isAggPushed()); - Assertions.assertEquals(PushDownAggOperator.MIX, pushedOlapScan.getPushDownAggOperator()); + PlanChecker.from(context) + .applyImplementation(storageLayerAggregateWithoutProject()) + .matches( + logicalAggregate( + physicalStorageLayerAggregate().when(agg -> agg.getAggOp() == PushDownAggOp.MIX) + ) + ); + } + + @Override + public RulePromise defaultPromise() { + return RulePromise.IMPLEMENT; } @Test @@ -90,44 +108,53 @@ public void testWithProject() { ImmutableList.of(olapScan.getOutput().get(0)), olapScan); LogicalAggregate> aggregate; CascadesContext context; - LogicalOlapScan pushedOlapScan; // min max aggregate = new LogicalAggregate<>( Collections.emptyList(), ImmutableList.of(new Alias(new Min(project.getOutput().get(0)), "min")), - project); + true, Optional.empty(), project); context = MemoTestUtils.createCascadesContext(aggregate); - context.topDownRewrite(new PushAggregateToOlapScan()); - pushedOlapScan = (LogicalOlapScan) (context.getMemo().copyOut().child(0).child(0)); - Assertions.assertTrue(pushedOlapScan.isAggPushed()); - Assertions.assertEquals(PushDownAggOperator.MIN_MAX, pushedOlapScan.getPushDownAggOperator()); + PlanChecker.from(context) + .applyImplementation(storageLayerAggregateWithProject()) + .matches( + logicalAggregate( + physicalStorageLayerAggregate().when(agg -> agg.getAggOp() == PushDownAggOp.MIN_MAX) + ) + ); // count aggregate = new LogicalAggregate<>( Collections.emptyList(), ImmutableList.of(new Alias(new Count(project.getOutput().get(0)), "count")), - project); + true, Optional.empty(), project); context = MemoTestUtils.createCascadesContext(aggregate); - context.topDownRewrite(new PushAggregateToOlapScan()); - pushedOlapScan = (LogicalOlapScan) (context.getMemo().copyOut().child(0).child(0)); - Assertions.assertTrue(pushedOlapScan.isAggPushed()); - 
Assertions.assertEquals(PushDownAggOperator.COUNT, pushedOlapScan.getPushDownAggOperator()); + PlanChecker.from(context) + .applyImplementation(storageLayerAggregateWithProject()) + .matches( + logicalAggregate( + physicalStorageLayerAggregate().when(agg -> agg.getAggOp() == PushDownAggOp.COUNT) + ) + ); // mix aggregate = new LogicalAggregate<>( Collections.emptyList(), ImmutableList.of(new Alias(new Count(project.getOutput().get(0)), "count"), new Alias(new Max(olapScan.getOutput().get(0)), "max")), + true, Optional.empty(), project); context = MemoTestUtils.createCascadesContext(aggregate); - context.topDownRewrite(new PushAggregateToOlapScan()); - pushedOlapScan = (LogicalOlapScan) (context.getMemo().copyOut().child(0).child(0)); - Assertions.assertTrue(pushedOlapScan.isAggPushed()); - Assertions.assertEquals(PushDownAggOperator.MIX, pushedOlapScan.getPushDownAggOperator()); + PlanChecker.from(context) + .applyImplementation(storageLayerAggregateWithProject()) + .matches( + logicalAggregate( + physicalStorageLayerAggregate().when(agg -> agg.getAggOp() == PushDownAggOp.MIX) + ) + ); } @Test @@ -137,7 +164,6 @@ void testProjectionCheck() { ImmutableList.of(new Alias(new Ln(olapScan.getOutput().get(0)), "alias")), olapScan); LogicalAggregate> aggregate; CascadesContext context; - LogicalOlapScan pushedOlapScan; // min max aggregate = new LogicalAggregate<>( @@ -146,9 +172,30 @@ void testProjectionCheck() { project); context = MemoTestUtils.createCascadesContext(aggregate); - context.topDownRewrite(new PushAggregateToOlapScan()); - pushedOlapScan = (LogicalOlapScan) (context.getMemo().copyOut().child(0).child(0)); - Assertions.assertFalse(pushedOlapScan.isAggPushed()); - Assertions.assertEquals(PushDownAggOperator.NONE, pushedOlapScan.getPushDownAggOperator()); + PlanChecker.from(context) + .applyImplementation(storageLayerAggregateWithProject()) + .matches( + logicalAggregate( + logicalProject( + logicalOlapScan() + ) + ) + ); + } + + private Rule storageLayerAggregateWithoutProject() { + return new AggregateStrategies().buildRules() + .stream() + .filter(rule -> rule.getRuleType() == RuleType.STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT) + .findFirst() + .get(); + } + + private Rule storageLayerAggregateWithProject() { + return new AggregateStrategies().buildRules() + .stream() + .filter(rule -> rule.getRuleType() == RuleType.STORAGE_LAYER_AGGREGATE_WITH_PROJECT) + .findFirst() + .get(); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java index c3bd3830cc8c70..c26c0205634832 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java @@ -20,7 +20,6 @@ import org.apache.doris.nereids.analyzer.UnboundAlias; import org.apache.doris.nereids.analyzer.UnboundFunction; import org.apache.doris.nereids.analyzer.UnboundStar; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.types.IntegerType; @@ -156,8 +155,8 @@ public void testArithmetic() { @Test public void testUnboundFunction() { - UnboundFunction unboundFunction1 = new UnboundFunction("name", false, false, Lists.newArrayList(child1)); - UnboundFunction 
unboundFunction2 = new UnboundFunction("name", false, false, Lists.newArrayList(child2)); + UnboundFunction unboundFunction1 = new UnboundFunction("name", false, Lists.newArrayList(child1)); + UnboundFunction unboundFunction2 = new UnboundFunction("name", false, Lists.newArrayList(child2)); Assertions.assertEquals(unboundFunction1, unboundFunction2); Assertions.assertEquals(unboundFunction1.hashCode(), unboundFunction2.hashCode()); } @@ -177,13 +176,13 @@ public void testAggregateFunction() { Assertions.assertEquals(count1, count2); Assertions.assertEquals(count1.hashCode(), count2.hashCode()); - Count count3 = new Count(AggregateParam.distinctAndFinalPhase(), child1); - Count count4 = new Count(AggregateParam.distinctAndFinalPhase(), child2); + Count count3 = new Count(true, child1); + Count count4 = new Count(true, child2); Assertions.assertEquals(count3, count4); Assertions.assertEquals(count3.hashCode(), count4.hashCode()); // bad case - Count count5 = new Count(AggregateParam.distinctAndFinalPhase(), child1); + Count count5 = new Count(true, child1); Count count6 = new Count(child2); Assertions.assertNotEquals(count5, count6); Assertions.assertNotEquals(count5.hashCode(), count6.hashCode()); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java index bf5a57ec77ebe2..c91cffabf9bc08 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java @@ -22,10 +22,13 @@ import org.apache.doris.nereids.properties.DistributionSpecHash; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.properties.RequireProperties; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; @@ -34,8 +37,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.trees.plans.logical.RelationUtil; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; @@ -73,17 +76,17 @@ public void testLogicalAggregate(@Mocked Plan child) { unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of( new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())), - true, false, true, AggPhase.GLOBAL, Optional.empty(), child); + false, Optional.empty(), child); Assertions.assertNotEquals(unexpected, actual); unexpected = new 
LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of( new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())), - false, true, true, AggPhase.GLOBAL, Optional.empty(), child); + true, Optional.empty(), child); Assertions.assertNotEquals(unexpected, actual); unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of( new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())), - false, false, true, AggPhase.LOCAL, Optional.empty(), child); + false, Optional.empty(), child); Assertions.assertNotEquals(unexpected, actual); } @@ -183,21 +186,24 @@ public void testLogicalSort(@Mocked Plan child) { public void testPhysicalAggregate(@Mocked Plan child, @Mocked LogicalProperties logicalProperties) { List outputExpressionList = ImmutableList.of( new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())); - PhysicalAggregate actual = new PhysicalAggregate<>(Lists.newArrayList(), outputExpressionList, - Lists.newArrayList(), AggPhase.LOCAL, true, true, logicalProperties, child); + PhysicalHashAggregate actual = new PhysicalHashAggregate<>(Lists.newArrayList(), outputExpressionList, + new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT), true, logicalProperties, + RequireProperties.of(PhysicalProperties.GATHER), child); List outputExpressionList1 = ImmutableList.of( new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())); - PhysicalAggregate expected = new PhysicalAggregate<>(Lists.newArrayList(), + PhysicalHashAggregate expected = new PhysicalHashAggregate<>(Lists.newArrayList(), outputExpressionList1, - Lists.newArrayList(), AggPhase.LOCAL, true, true, logicalProperties, child); + new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT), true, logicalProperties, + RequireProperties.of(PhysicalProperties.GATHER), child); Assertions.assertEquals(expected, actual); List outputExpressionList2 = ImmutableList.of( new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())); - PhysicalAggregate unexpected = new PhysicalAggregate<>(Lists.newArrayList(), + PhysicalHashAggregate unexpected = new PhysicalHashAggregate<>(Lists.newArrayList(), outputExpressionList2, - Lists.newArrayList(), AggPhase.LOCAL, false, true, logicalProperties, child); + new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT), false, logicalProperties, + RequireProperties.of(PhysicalProperties.GATHER), child); Assertions.assertNotEquals(unexpected, actual); } @@ -252,16 +258,16 @@ public void testPhysicalOlapScan( PhysicalOlapScan actual = new PhysicalOlapScan(id, olapTable, Lists.newArrayList("a"), olapTable.getBaseIndexId(), selectedTabletId, olapTable.getPartitionIds(), distributionSpecHash, - PreAggStatus.on(), PushDownAggOperator.NONE, Optional.empty(), logicalProperties); + PreAggStatus.on(), Optional.empty(), logicalProperties); PhysicalOlapScan expected = new PhysicalOlapScan(id, olapTable, Lists.newArrayList("a"), olapTable.getBaseIndexId(), selectedTabletId, olapTable.getPartitionIds(), distributionSpecHash, - PreAggStatus.on(), PushDownAggOperator.NONE, Optional.empty(), logicalProperties); + PreAggStatus.on(), Optional.empty(), logicalProperties); Assertions.assertEquals(expected, actual); PhysicalOlapScan unexpected = new PhysicalOlapScan(id, olapTable, Lists.newArrayList("b"), olapTable.getBaseIndexId(), selectedTabletId, olapTable.getPartitionIds(), distributionSpecHash, - PreAggStatus.on(), PushDownAggOperator.NONE, Optional.empty(), 
logicalProperties); + PreAggStatus.on(), Optional.empty(), logicalProperties); Assertions.assertNotEquals(unexpected, actual); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java index 97710b3713aa3f..5ddbce94c49bb4 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java @@ -56,7 +56,7 @@ public void testLogicalAggregate(@Mocked Plan child) { new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), child); Assertions.assertTrue(plan.toString() - .matches("LogicalAggregate \\( phase=LOCAL, outputExpr=\\[a#\\d+], groupByExpr=\\[], hasRepeat=false \\)")); + .matches("LogicalAggregate \\( groupByExpr=\\[], outputExpr=\\[a#\\d+], hasRepeat=false \\)")); } @Test @@ -81,7 +81,7 @@ public void testLogicalOlapScan() { LogicalOlapScan plan = PlanConstructor.newLogicalOlapScan(0, "table", 0); Assertions.assertTrue( plan.toString().matches("LogicalOlapScan \\( qualified=db\\.table, " - + "output=\\[id#\\d+, name#\\d+], candidateIndexIds=\\[], selectedIndexId=-1, preAgg=ON, pushAgg=NONE \\)")); + + "output=\\[id#\\d+, name#\\d+], candidateIndexIds=\\[], selectedIndexId=-1, preAgg=ON \\)")); } @Test diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/GroupMatchingUtils.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/GroupMatchingUtils.java index 5b6338aaf961a5..15861f724ec065 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/GroupMatchingUtils.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/GroupMatchingUtils.java @@ -26,12 +26,26 @@ public class GroupMatchingUtils { public static boolean topDownFindMatching(Group group, Pattern pattern) { - GroupExpression logicalExpr = group.getLogicalExpression(); - GroupExpressionMatching matchingResult = new GroupExpressionMatching(pattern, logicalExpr); + for (GroupExpression logicalExpr : group.getLogicalExpressions()) { + if (topDownFindMatch(logicalExpr, pattern)) { + return true; + } + } + + for (GroupExpression physicalExpr : group.getPhysicalExpressions()) { + if (topDownFindMatch(physicalExpr, pattern)) { + return true; + } + } + return false; + } + + public static boolean topDownFindMatch(GroupExpression groupExpression, Pattern pattern) { + GroupExpressionMatching matchingResult = new GroupExpressionMatching(pattern, groupExpression); if (matchingResult.iterator().hasNext()) { return true; } else { - for (Group childGroup : logicalExpr.children()) { + for (Group childGroup : groupExpression.children()) { boolean checkResult = topDownFindMatching(childGroup, pattern); if (checkResult) { return true; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/MemoTestUtils.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/MemoTestUtils.java index 4999cde6012fed..f80fc7b0c4a1c5 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/MemoTestUtils.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/MemoTestUtils.java @@ -20,8 +20,10 @@ import org.apache.doris.analysis.UserIdentity; import org.apache.doris.catalog.Env; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.parser.NereidsParser; +import 
org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.qe.ConnectContext; @@ -79,7 +81,9 @@ public static CascadesContext createCascadesContext(ConnectContext connectContex } public static CascadesContext createCascadesContext(StatementContext statementContext, Plan initPlan) { - CascadesContext cascadesContext = CascadesContext.newContext(statementContext, initPlan); + PhysicalProperties requestProperties = NereidsPlanner.buildInitRequireProperties(initPlan); + CascadesContext cascadesContext = CascadesContext.newContext( + statementContext, initPlan, requestProperties); MemoValidator.validateInitState(cascadesContext.getMemo(), initPlan); return cascadesContext; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index fa7376885eb1b9..87754d1642e7be 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -268,6 +268,46 @@ private void applyExploration(GroupExpression groupExpression, Rule rule) { } } + public PlanChecker applyImplementation(Rule rule) { + return applyImplementation(cascadesContext.getMemo().getRoot(), rule); + } + + private PlanChecker applyImplementation(Group group, Rule rule) { + // copy groupExpressions can prevent ConcurrentModificationException + for (GroupExpression logicalExpression : Lists.newArrayList(group.getLogicalExpressions())) { + applyImplementation(logicalExpression, rule); + } + + for (GroupExpression physicalExpression : Lists.newArrayList(group.getPhysicalExpressions())) { + applyImplementation(physicalExpression, rule); + } + return this; + } + + private PlanChecker applyImplementation(GroupExpression groupExpression, Rule rule) { + GroupExpressionMatching matchResult = new GroupExpressionMatching(rule.getPattern(), groupExpression); + + for (Plan before : matchResult) { + List afters = rule.transform(before, cascadesContext); + for (Plan after : afters) { + if (before != after) { + cascadesContext.getMemo().copyIn(after, before.getGroupExpression().get().getOwnerGroup(), false); + } + } + } + + for (Group childGroup : groupExpression.children()) { + for (GroupExpression logicalExpression : childGroup.getLogicalExpressions()) { + applyImplementation(logicalExpression, rule); + } + + for (GroupExpression physicalExpression : childGroup.getPhysicalExpressions()) { + applyImplementation(physicalExpression, rule); + } + } + return this; + } + public PlanChecker deriveStats() { cascadesContext.pushJob( new DeriveStatsJob(cascadesContext.getMemo().getRoot().getLogicalExpression(), diff --git a/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java b/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java index 3da84c62c35764..d6190b18919355 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java +++ b/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java @@ -53,8 +53,10 @@ import org.apache.doris.common.MetaNotFoundException; import org.apache.doris.common.util.SqlParserUtils; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.parser.NereidsParser; +import 
org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.planner.Planner; @@ -177,7 +179,8 @@ protected StatementContext createStatementCtx(String sql) { protected CascadesContext createCascadesContext(String sql) { StatementContext statementCtx = createStatementCtx(sql); LogicalPlan initPlan = new NereidsParser().parseSingle(sql); - return CascadesContext.newContext(statementCtx, initPlan); + PhysicalProperties requestProperties = NereidsPlanner.buildInitRequireProperties(initPlan); + return CascadesContext.newContext(statementCtx, initPlan, requestProperties); } public LogicalPlan analyze(String sql) { diff --git a/regression-test/data/nereids_syntax_p0/aggregate_strategies.out b/regression-test/data/nereids_syntax_p0/aggregate_strategies.out new file mode 100644 index 00000000000000..5c8d99b8c3c82f --- /dev/null +++ b/regression-test/data/nereids_syntax_p0/aggregate_strategies.out @@ -0,0 +1,237 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !count_all -- +10 + +-- !count_all_group_by -- +2 +2 +2 +2 +2 + +-- !count_all_group_by_2 -- +2 +2 +2 +2 +2 + +-- !count_2_all_group_by_2 -- +2 2 +2 2 +2 2 +2 2 +2 2 + +-- !count_sum -- +20 10 + +-- !group_select_same -- +0 +1 +2 +3 +4 + +-- !group_select_difference -- +2 +2 +2 +2 +2 + +-- !count_distinct -- +5 + +-- !count_distinct_group_by -- +1 +1 +1 +1 +1 + +-- !count_distinct_group_by_select_key -- +name_0 1 +name_1 1 +name_2 1 +name_3 1 +name_4 1 + +-- !count_distinct_muilti -- +5 + +-- !count_distinct_muilti_group_by -- +1 +1 +1 +1 +1 + +-- !count_distinct_muilti_group_by_select_key -- +name_0 1 +name_1 1 +name_2 1 +name_3 1 +name_4 1 + +-- !count_distinct_sum_distinct_same -- +4 10 + +-- !count_distinct_sum_distinct_same -- +4 10 + +-- !count_distinct_sum_distinct_difference -- +5 10 + +-- !count_distinct_sum_distinct_group_by -- +1 0 +1 1 +1 2 +1 3 +1 4 + +-- !count_distinct_sum_distinct_group_by_select_key -- +name_0 1 0 +name_1 1 1 +name_2 1 2 +name_3 1 3 +name_4 1 4 + +-- !group_by_all_group_by -- +0 +1 +2 +3 +4 + +-- !group_by_partial_group_by -- +0 +1 +2 +3 +4 + +-- !group_by_count_distinct_sum_distinct -- +5 5 + +-- !group_by_count_distinct -- +5 + +-- !count_all -- +10 + +-- !count_all_group_by -- +2 +2 +2 +2 +2 + +-- !count_all_group_by_2 -- +2 +2 +2 +2 +2 + +-- !count_2_all_group_by_2 -- +2 2 +2 2 +2 2 +2 2 +2 2 + +-- !count_sum -- +20 10 + +-- !group_select_same -- +0 +1 +2 +3 +4 + +-- !group_select_difference -- +2 +2 +2 +2 +2 + +-- !count_distinct -- +5 + +-- !count_distinct_group_by -- +1 +1 +1 +1 +1 + +-- !count_distinct_group_by_select_key -- +name_0 1 +name_1 1 +name_2 1 +name_3 1 +name_4 1 + +-- !count_distinct_muilti -- +5 + +-- !count_distinct_muilti_group_by -- +1 +1 +1 +1 +1 + +-- !count_distinct_muilti_group_by_select_key -- +name_0 1 +name_1 1 +name_2 1 +name_3 1 +name_4 1 + +-- !count_distinct_sum_distinct_same -- +4 10 + +-- !count_distinct_sum_distinct_same -- +4 10 + +-- !count_distinct_sum_distinct_difference -- +5 10 + +-- !count_distinct_sum_distinct_group_by -- +1 0 +1 1 +1 2 +1 3 +1 4 + +-- !count_distinct_sum_distinct_group_by_select_key -- +name_0 1 0 +name_1 1 1 +name_2 1 2 +name_3 1 3 +name_4 1 4 + +-- !group_by_all_group_by -- +0 +1 +2 +3 +4 + +-- !group_by_partial_group_by -- +0 +1 +2 +3 +4 + +-- !group_by_count_distinct_sum_distinct -- +5 5 + +-- !group_by_count_distinct -- +5 + diff --git 
a/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy b/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy
new file mode 100644
index 00000000000000..977c51f0e44aa0
--- /dev/null
+++ b/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy
@@ -0,0 +1,214 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("aggregate_strategies") {
+
+    def test_aggregate_strategies = { tableName, bucketNum ->
+        sql "SET enable_fallback_to_original_planner=true"
+
+        sql "drop table if exists $tableName"
+        sql """CREATE TABLE `$tableName` (
+            `id` int(11) NOT NULL,
+            `name` varchar(32) NULL
+        ) ENGINE=OLAP
+        DUPLICATE KEY(`id`)
+        COMMENT 'OLAP'
+        DISTRIBUTED BY HASH(`id`) BUCKETS $bucketNum
+        PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1",
+            "in_memory" = "false",
+            "storage_format" = "V2",
+            "disable_auto_compaction" = "false"
+        );"""
+
+
+        // insert 10 rows, with duplicates
+        sql "insert into $tableName select number, concat('name_', number) from numbers('number'='5')"
+        sql "insert into $tableName select number, concat('name_', number) from numbers('number'='5')"
+
+
+        sql "SET enable_vectorized_engine=true"
+        sql "SET enable_nereids_planner=true"
+        sql "SET enable_fallback_to_original_planner=false"
+
+        order_qt_count_all "select count(*) from $tableName"
+        order_qt_count_all_group_by "select count(*) from $tableName group by id"
+        order_qt_count_all_group_by_2 "select count(*) from $tableName group by id, name"
+        order_qt_count_2_all_group_by_2 "select count(*), count(*) from $tableName group by id, name"
+        order_qt_count_sum "select sum(id), count(name) from $tableName"
+        order_qt_group_select_same "select id from $tableName group by id"
+        order_qt_group_select_difference "select count(name) from $tableName group by id"
+
+        order_qt_count_distinct "select count(distinct id) from $tableName"
+
+        /*
+         * should not use streaming aggregation here: there is a bug in BE that computes a wrong result.
+         *
+         * the case is:
+         * ```
+         * CREATE TABLE `n` (
+         *   `id` bigint NOT NULL
+         * ) ENGINE=OLAP
+         * DUPLICATE KEY(`id`)
+         * COMMENT 'OLAP'
+         * DISTRIBUTED BY HASH(`id`) BUCKETS 1
+         * PROPERTIES (
+         *   "replication_allocation" = "tag.location.default: 1",
+         *   "in_memory" = "false",
+         *   "storage_format" = "V2",
+         *   "disable_auto_compaction" = "false"
+         * );
+         *
+         * insert into n select number from numbers('number'='10000000');
+         * insert into n select number from numbers('number'='10000000');
+         * ```
+         *
+         * when streaming aggregation is enabled, the result is 19999800, but the correct result is 10000000
+         */
+        explain {
+            sql """
+                select
+                /*+SET_VAR(disable_nereids_rules='ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,THREE_PHASE_AGGREGATE_WITH_DISTINCT')*/
+                count(distinct id)
+                from $tableName
+                """
+
+            notContains "STREAMING"
+        }
+
+        explain {
+            sql """
+                select count(*)
+                from (
+                    select id
+                    from $tableName
+                    group by id
+                )a
+                """
+
+            notContains "STREAMING"
+        }
+
+        test {
+            sql """select
+                    /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/
+                    count(distinct id)
+                  from $tableName"""
+            result([[5L]])
+        }
+
+        test {
+            sql """select
+                    /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT')*/
+                    count(distinct id)
+                  from $tableName"""
+            result([[5L]])
+        }
+
+        order_qt_count_distinct_group_by "select count(distinct id) from $tableName group by name"
+        order_qt_count_distinct_group_by_select_key "select name, count(distinct id) from $tableName group by name"
+        order_qt_count_distinct_muilti "select count(distinct id, name) from $tableName"
+        order_qt_count_distinct_muilti_group_by "select count(distinct id, name) from $tableName group by name"
+        order_qt_count_distinct_muilti_group_by_select_key "select name, count(distinct id, name) from $tableName group by name"
+
+        order_qt_count_distinct_sum_distinct_same "select max(distinct id), sum(distinct id) from $tableName"
+
+        // explain plan select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT,TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI')*/ max(distinct id), sum(distinct id) from test_bucket1_table;
+
+        order_qt_count_distinct_sum_distinct_same "select max(distinct id), sum(distinct id) from $tableName"
+        order_qt_count_distinct_sum_distinct_difference "select count(distinct name), sum(distinct id) from $tableName"
+
+        order_qt_count_distinct_sum_distinct_group_by """
+            select count(distinct name), sum(distinct id)
+            from $tableName group by name"""
+
+        order_qt_count_distinct_sum_distinct_group_by_select_key """
+            select name, count(distinct name), sum(distinct id)
+            from $tableName group by name"""
+
+        order_qt_group_by_all_group_by """
+            select id
+            from (
+                select id, name
+                from $tableName
+                group by id, name
+            )a
+            group by id"""
+
+        order_qt_group_by_partial_group_by """
+            select id
+            from (
+                select id, name
+                from $tableName
+                group by name, id
+            )a
+            group by id"""
+
+        order_qt_group_by_count_distinct_sum_distinct """
+            select c, c from (select count(distinct id) as c, sum(distinct id) as s
+            from $tableName)a group by c, s"""
+
+        order_qt_group_by_count_distinct """
+            select c
+            from (
+                select count(distinct id) as c
+                from $tableName
+            )a
+            group by c"""
+
+
+        test {
+            sql "select count(distinct id, name), count(distinct id) from $tableName"
+            exception "The query contains multi count distinct or sum distinct, each can't have multi columns"
+        }
+    }
+
+    test_aggregate_strategies('test_bucket1_table', 1)
+    test_aggregate_strategies('test_bucket10_table', 10)
+
+    test {
+        sql """select
+                /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/
+                count(distinct number)
+              from numbers('number' = '10000000', 'backend_num'='10')"""
+        result([[10000000L]])
+    }
+
+    test {
+        sql """select
+                /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT')*/
+                count(distinct number)
+              from numbers('number' = '10000000', 'backend_num'='10')"""
+        result([[10000000L]])
+    }
+
+    test {
+        sql """select
+                /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/
+                count(distinct number)
+              from numbers('number' = '10000000', 'backend_num'='1')"""
+        result([[10000000L]])
+    }
+
+    test {
+        sql """select
+                /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT')*/
+                count(distinct number)
+              from numbers('number' = '10000000', 'backend_num'='1')"""
+        result([[10000000L]])
+    }
+}
diff --git a/regression-test/suites/nereids_syntax_p0/explain.groovy b/regression-test/suites/nereids_syntax_p0/explain.groovy
index 473fdffcada63e..734c33fc691988 100644
--- a/regression-test/suites/nereids_syntax_p0/explain.groovy
+++ b/regression-test/suites/nereids_syntax_p0/explain.groovy
@@ -24,8 +24,8 @@ suite("nereids_explain") {
     explain {
         sql("select count(2) + 1, sum(2) + sum(lo_suppkey) from lineorder")
-        contains "(sum(2) + sum(lo_suppkey))[#24]"
-        contains "project output tuple id: 3"
+        contains "(sum(2) + sum(lo_suppkey))[#"
+        contains "project output tuple id: 1"
     }