[refactor](Nerieds) Refactor aggregate function/plan/rules and support related cbo rules (#14827)

# Proposed changes

## refactor
- add `AggregateExpression` to hide the differences in an `AggregateFunction` before and after disassembly
- request `GATHER` physical properties for queries: a query always collects its result at the coordinator, so requesting `GATHER` may let the optimizer select a better plan
- refactor `NormalizeAggregate`
- remove some physical fields from `LogicalAggregate`, such as `AggPhase` and `isDisassemble`
- remove `AggregateDisassemble` and `DistinctAggregateDisassemble`, and use `AggregateStrategies` to generate the various `PhysicalHashAggregate` alternatives, such as `two phases aggregate` and `three phases aggregate`, so that cascades can automatically select the lowest-cost one
- move `PushAggregateToOlapScan` to `AggregateStrategies`
- separate the traverse and visit methods in `FoldConstantRuleOnFE`
  - if an expression does not implement the visit method, the traverse method handles it and rewrites its children by default
  - if an expression implements the visit method, the user-defined traverse (which invokes the accept/visit method) returns quickly because the default visit method does not forward to the children, and the pre-processing in the traverse method is not skipped
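
One possible reading of the traverse/visit split is sketched below. This is a minimal illustration, not the `FoldConstantRuleOnFE` code: every class name (`Expr`, `Literal`, `Add`, `Neg`, `Folder`) is hypothetical. The `traverse` method rewrites children first and then dispatches to a `visit` method; the default `visit` returns the node unchanged and never recurses.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical mini expression tree; none of these are Nereids classes.
class FoldSketch {
    abstract static class Expr {
        final List<Expr> children;
        Expr(List<Expr> children) { this.children = children; }
        abstract Expr withChildren(List<Expr> newChildren);
        // Default dispatch goes to the generic visit, which does not recurse.
        Expr accept(Folder folder) { return folder.visit(this); }
    }

    static class Literal extends Expr {
        final int value;
        Literal(int value) { super(Collections.<Expr>emptyList()); this.value = value; }
        @Override Expr withChildren(List<Expr> c) { return this; }
    }

    static class Add extends Expr {
        Add(Expr left, Expr right) { super(Arrays.asList(left, right)); }
        @Override Expr withChildren(List<Expr> c) { return new Add(c.get(0), c.get(1)); }
        @Override Expr accept(Folder folder) { return folder.visitAdd(this); }
    }

    // An expression type that has no dedicated visit method.
    static class Neg extends Expr {
        Neg(Expr child) { super(Collections.singletonList(child)); }
        @Override Expr withChildren(List<Expr> c) { return new Neg(c.get(0)); }
    }

    static class Folder {
        // traverse: rewrite the children first, then dispatch to visit.
        Expr traverse(Expr expr) {
            List<Expr> newChildren = new ArrayList<>();
            for (Expr child : expr.children) {
                newChildren.add(traverse(child));
            }
            return expr.withChildren(newChildren).accept(this);
        }

        // Default visit: return the node as-is, without forwarding to children.
        Expr visit(Expr expr) { return expr; }

        // Specific visit: fold only when both children are already literals.
        Expr visitAdd(Add add) {
            Expr left = add.children.get(0);
            Expr right = add.children.get(1);
            if (left instanceof Literal && right instanceof Literal) {
                return new Literal(((Literal) left).value + ((Literal) right).value);
            }
            return add;
        }
    }

    public static void main(String[] args) {
        Expr folded = new Folder().traverse(new Neg(new Add(new Literal(1), new Literal(2))));
        System.out.println(folded.getClass().getSimpleName()
                + " over literal " + ((Literal) folded.children.get(0)).value);
    }
}
```

In this sketch `Neg` has no visit method of its own, yet its child `Add` is still folded, because the recursion lives in `traverse` rather than in `visit`.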

## new feature
- support the `disable_nereids_rules` session variable to skip the specified rules.

example:

1. create a table `n` with 1 bucket
```sql
CREATE TABLE `n` (
  `id` bigint(20) NOT NULL
) ENGINE=OLAP
DUPLICATE KEY(`id`)
COMMENT 'OLAP'
DISTRIBUTED BY HASH(`id`) BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"in_memory" = "false",
"storage_format" = "V2",
"disable_auto_compaction" = "false"
);
```

2. insert some rows into `n`
```sql
insert into n select * from numbers('number'='20000000')
```

3. query table `n`
```sql
SET enable_nereids_planner=true;
SET enable_vectorized_engine=true;
SET enable_fallback_to_original_planner=false;
explain plan select id from n group by id;
```

The result shows that a one-stage aggregate is used:
```
| PhysicalHashAggregate ( aggPhase=LOCAL, aggMode=INPUT_TO_RESULT, groupByExpr=[id#0], outputExpr=[id#0], partitionExpr=Optional.empty, requestProperties=[GATHER], stats=(rows=1, width=1, penalty=2.0E7) ) |
| +--PhysicalProject ( projects=[id#0], stats=(rows=20000000, width=1, penalty=0.0) )                                                                                                                                                                                                |
|    +--PhysicalOlapScan ( qualified=default_cluster:test.n, output=[id#0, name#1], stats=(rows=20000000, width=1, penalty=0.0) )                                                                                                                                                    |
```

4. disable one stage aggregate
```sql
explain plan select
  /*+SET_VAR(disable_nereids_rules=DISASSEMBLE_ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT)*/
  id
from n
group by id
```

The result is a two-stage aggregate:
```
| PhysicalHashAggregate ( aggPhase=GLOBAL, aggMode=BUFFER_TO_RESULT, groupByExpr=[id#0], outputExpr=[id#0], partitionExpr=Optional[[id#0]], requestProperties=[GATHER], stats=(rows=1, width=1, penalty=2.0E7) ) |
| +--PhysicalHashAggregate ( aggPhase=LOCAL, aggMode=INPUT_TO_BUFFER, groupByExpr=[id#0], outputExpr=[id#0], partitionExpr=Optional[[id#0]], requestProperties=[ANY], stats=(rows=1, width=1, penalty=2.0E7) )     |
|    +--PhysicalProject ( projects=[id#0], stats=(rows=20000000, width=1, penalty=0.0) )                                                                                                                                                                                                   |
|       +--PhysicalOlapScan ( qualified=default_cluster:test.n, output=[id#0, name#1], stats=(rows=20000000, width=1, penalty=0.0) )                                                                                                                                                       |
```
924060929 authored Dec 18, 2022
1 parent 13bc8c2 commit af4d9b6
Showing 144 changed files with 5,170 additions and 2,580 deletions.
@@ -235,14 +235,17 @@ public static AggregateInfo create(
* Used by new optimizer.
*/
public static AggregateInfo create(
ArrayList<Expr> groupingExprs, ArrayList<FunctionCallExpr> aggExprs,
TupleDescriptor tupleDesc, TupleDescriptor intermediateTupleDesc, AggPhase phase) {
ArrayList<Expr> groupingExprs, ArrayList<FunctionCallExpr> aggExprs, List<Integer> aggExprIds,
boolean isPartialAgg, TupleDescriptor tupleDesc, TupleDescriptor intermediateTupleDesc, AggPhase phase) {
AggregateInfo result = new AggregateInfo(groupingExprs, aggExprs, phase);
result.outputTupleDesc = tupleDesc;
result.intermediateTupleDesc = intermediateTupleDesc;
int aggExprSize = result.getAggregateExprs().size();
for (int i = 0; i < aggExprSize; i++) {
result.materializedSlots.add(i);
String label = (isPartialAgg ? "partial_" : "")
+ aggExprs.get(i).toSql() + "[#" + aggExprIds.get(i) + "]";
result.materializedSlotLabels.add(label);
}
return result;
}
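
The label format built in the loop above can be tried in isolation. The sketch below reproduces only the string concatenation; it takes plain SQL strings instead of `FunctionCallExpr` objects, and the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-alone version of the label construction.
class SlotLabelSketch {
    static List<String> buildLabels(List<String> aggSqls, List<Integer> aggExprIds, boolean isPartialAgg) {
        List<String> labels = new ArrayList<>();
        for (int i = 0; i < aggSqls.size(); i++) {
            // Same shape as in the diff: optional "partial_" prefix, SQL text, then "[#id]".
            labels.add((isPartialAgg ? "partial_" : "") + aggSqls.get(i) + "[#" + aggExprIds.get(i) + "]");
        }
        return labels;
    }

    public static void main(String[] args) {
        System.out.println(buildLabels(Arrays.asList("sum(a)", "count(b)"), Arrays.asList(3, 4), true));
    }
}
```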
@@ -72,6 +72,7 @@ public abstract class AggregateInfoBase {
// exprs that need to be materialized.
// Populated in materializeRequiredSlots() which must be implemented by subclasses.
protected ArrayList<Integer> materializedSlots = Lists.newArrayList();
protected List<String> materializedSlotLabels = Lists.newArrayList();

protected AggregateInfoBase(ArrayList<Expr> groupingExprs,
ArrayList<FunctionCallExpr> aggExprs) {
@@ -94,6 +95,7 @@ protected AggregateInfoBase(AggregateInfoBase other) {
intermediateTupleDesc = other.intermediateTupleDesc;
outputTupleDesc = other.outputTupleDesc;
materializedSlots = Lists.newArrayList(other.materializedSlots);
materializedSlotLabels = Lists.newArrayList(other.materializedSlotLabels);
}

/**
@@ -234,6 +236,10 @@ public TupleId getOutputTupleId() {
return outputTupleDesc.getId();
}

public List<String> getMaterializedAggregateExprLabels() {
return Lists.newArrayList(materializedSlotLabels);
}

public boolean requiresIntermediateTuple() {
Preconditions.checkNotNull(intermediateTupleDesc);
Preconditions.checkNotNull(outputTupleDesc);
@@ -70,7 +70,7 @@ public FunctionBuilder findFunctionBuilder(String name, List<? extends Object> a
.filter(functionBuilder -> functionBuilder.canApply(arguments))
.collect(Collectors.toList());
if (candidateBuilders.isEmpty()) {
String candidateHints = getCandidateHint(name, candidateBuilders);
String candidateHints = getCandidateHint(name, functionBuilders);
throw new AnalysisException("Can not found function '" + name
+ "' which has " + arity + " arity. Candidate functions are: " + candidateHints);
}
@@ -93,6 +93,6 @@ private void registerBuiltinFunctions(Map<String, List<FunctionBuilder>> name2Bu
public String getCandidateHint(String name, List<FunctionBuilder> candidateBuilders) {
return candidateBuilders.stream()
.map(builder -> name + builder.toString())
.collect(Collectors.joining(", "));
.collect(Collectors.joining(", ", "[", "]"));
}
}
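
The two changes in this file work together: the hint is now built from all builders registered under the name rather than from the (empty) filtered list, and the joined string is wrapped in brackets. The sketch below shows the three-argument `Collectors.joining` on plain strings; the class name and the signature strings are hypothetical stand-ins for `FunctionBuilder.toString()`.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical sketch using strings in place of FunctionBuilder objects.
class CandidateHintSketch {
    static String candidateHint(String name, List<String> signatures) {
        return signatures.stream()
                .map(signature -> name + signature)
                // delimiter, prefix, suffix
                .collect(Collectors.joining(", ", "[", "]"));
    }

    public static void main(String[] args) {
        System.out.println(candidateHint("sum", Arrays.asList("(int)", "(double)")));
        System.out.println(candidateHint("sum", Collections.<String>emptyList()));
    }
}
```

An empty list still yields the prefix and suffix, so passing the empty filtered list would have produced a hint with no candidates in it.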
4 changes: 2 additions & 2 deletions fe/fe-core/src/main/java/org/apache/doris/catalog/Table.java
@@ -518,7 +518,7 @@ public BaseAnalysisTask createAnalysisTask(AnalysisTaskScheduler scheduler, Anal
* @return estimated row count
*/
public long estimatedRowCount() {
long cardinality = 1;
long cardinality = 0;
if (this instanceof OlapTable) {
OlapTable table = (OlapTable) this;
for (long selectedPartitionId : table.getPartitionIds()) {
@@ -527,6 +527,6 @@ public long estimatedRowCount() {
cardinality += baseIndex.getRowCount();
}
}
return cardinality;
return Math.max(cardinality, 1);
}
}
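
The effect of this change is that the sum starts at 0 and is floored at 1 on return, instead of starting at 1 and overcounting every non-empty table by one row. A stand-alone sketch of the arithmetic follows, with a hypothetical class name and an array of row counts in place of the partition and index lookups.

```java
// Hypothetical stand-alone version of the row-count arithmetic.
class RowCountSketch {
    static long estimatedRowCount(long[] baseIndexRowCounts) {
        long cardinality = 0;
        for (long rowCount : baseIndexRowCounts) {
            cardinality += rowCount;
        }
        // Never report fewer than one row, even for an empty table.
        return Math.max(cardinality, 1);
    }

    public static void main(String[] args) {
        System.out.println(estimatedRowCount(new long[] {}));
        System.out.println(estimatedRowCount(new long[] {10, 20}));
    }
}
```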
@@ -17,6 +17,7 @@

package org.apache.doris.common.util;

import com.google.common.collect.ImmutableMap;
import org.apache.logging.log4j.Logger;

import java.io.ByteArrayOutputStream;
@@ -26,10 +27,21 @@
import java.lang.management.ThreadMXBean;
import java.lang.reflect.Constructor;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

public class ReflectionUtils {
private static final Class[] emptyArray = new Class[]{};
private static final Map<Class, Class> boxToPrimitiveTypes = ImmutableMap.<Class, Class>builder()
.put(Boolean.class, boolean.class)
.put(Character.class, char.class)
.put(Byte.class, byte.class)
.put(Short.class, short.class)
.put(Integer.class, int.class)
.put(Long.class, long.class)
.put(Float.class, float.class)
.put(Double.class, double.class)
.build();

/**
* Cache of constructors for each class. Pins the classes so they
@@ -162,4 +174,11 @@ static void clearCache() {
static int getCacheSize() {
return CONSTRUCTOR_CACHE.size();
}

public static Optional<Class> getPrimitiveType(Class<?> targetClass) {
if (targetClass.isPrimitive()) {
return Optional.of(targetClass);
}
return Optional.ofNullable(boxToPrimitiveTypes.get(targetClass));
}
}
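
The behavior of `getPrimitiveType` can be checked with a dependency-free sketch. The version below swaps Guava's `ImmutableMap` for a `java.util.HashMap` so it runs without extra libraries; the class name is hypothetical, and the lookup logic otherwise mirrors the method added above.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

// Hypothetical stand-alone version; uses HashMap instead of Guava's ImmutableMap.
class PrimitiveTypeSketch {
    private static final Map<Class<?>, Class<?>> BOX_TO_PRIMITIVE = new HashMap<>();

    static {
        BOX_TO_PRIMITIVE.put(Boolean.class, boolean.class);
        BOX_TO_PRIMITIVE.put(Character.class, char.class);
        BOX_TO_PRIMITIVE.put(Byte.class, byte.class);
        BOX_TO_PRIMITIVE.put(Short.class, short.class);
        BOX_TO_PRIMITIVE.put(Integer.class, int.class);
        BOX_TO_PRIMITIVE.put(Long.class, long.class);
        BOX_TO_PRIMITIVE.put(Float.class, float.class);
        BOX_TO_PRIMITIVE.put(Double.class, double.class);
    }

    static Optional<Class<?>> getPrimitiveType(Class<?> targetClass) {
        // A primitive class maps to itself.
        if (targetClass.isPrimitive()) {
            return Optional.of(targetClass);
        }
        // A boxed class maps to its primitive; anything else yields empty.
        return Optional.ofNullable(BOX_TO_PRIMITIVE.get(targetClass));
    }

    public static void main(String[] args) {
        System.out.println(getPrimitiveType(Integer.class));
        System.out.println(getPrimitiveType(long.class));
        System.out.println(getPrimitiveType(String.class));
    }
}
```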
@@ -76,8 +76,8 @@ public class CascadesContext {

private List<Table> tables = null;

public CascadesContext(Memo memo, StatementContext statementContext) {
this(memo, statementContext, new CTEContext());
public CascadesContext(Memo memo, StatementContext statementContext, PhysicalProperties requestProperties) {
this(memo, statementContext, new CTEContext(), requestProperties);
}

/**
@@ -86,20 +86,22 @@ public CascadesContext(Memo memo, StatementContext statementContext) {
* @param memo {@link Memo} reference
* @param statementContext {@link StatementContext} reference
*/
public CascadesContext(Memo memo, StatementContext statementContext, CTEContext cteContext) {
public CascadesContext(Memo memo, StatementContext statementContext,
CTEContext cteContext, PhysicalProperties requireProperties) {
this.memo = memo;
this.statementContext = statementContext;
this.ruleSet = new RuleSet();
this.jobPool = new JobStack();
this.jobScheduler = new SimpleJobScheduler();
this.currentJobContext = new JobContext(this, PhysicalProperties.ANY, Double.MAX_VALUE);
this.currentJobContext = new JobContext(this, requireProperties, Double.MAX_VALUE);
this.subqueryExprIsAnalyzed = new HashMap<>();
this.runtimeFilterContext = new RuntimeFilterContext(getConnectContext().getSessionVariable());
this.cteContext = cteContext;
}

public static CascadesContext newContext(StatementContext statementContext, Plan initPlan) {
return new CascadesContext(new Memo(initPlan), statementContext);
public static CascadesContext newContext(StatementContext statementContext,
Plan initPlan, PhysicalProperties requireProperties) {
return new CascadesContext(new Memo(initPlan), statementContext, requireProperties);
}

public NereidsAnalyzer newAnalyzer() {
@@ -39,6 +39,8 @@
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.commands.Command;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@@ -86,7 +88,10 @@ public void plan(StatementBase queryStmt, org.apache.doris.thrift.TQueryOptions
LogicalPlanAdapter logicalPlanAdapter = (LogicalPlanAdapter) queryStmt;

ExplainLevel explainLevel = getExplainLevel(queryStmt.getExplainOptions());
Plan resultPlan = plan(logicalPlanAdapter.getLogicalPlan(), PhysicalProperties.ANY, explainLevel);

LogicalPlan parsedPlan = logicalPlanAdapter.getLogicalPlan();
PhysicalProperties requireProperties = buildInitRequireProperties(parsedPlan);
Plan resultPlan = plan(parsedPlan, requireProperties, explainLevel);
if (explainLevel.isPlanLevel) {
return;
}
@@ -129,11 +134,11 @@ public PhysicalPlan plan(LogicalPlan plan, PhysicalProperties outputProperties)
* Do analyze and optimize for query plan.
*
* @param plan wait for plan
* @param outputProperties physical properties constraints
* @param requireProperties request physical properties constraints
* @return plan generated by this planner
* @throws AnalysisException throw exception if failed in ant stage
*/
public Plan plan(LogicalPlan plan, PhysicalProperties outputProperties, ExplainLevel explainLevel) {
public Plan plan(LogicalPlan plan, PhysicalProperties requireProperties, ExplainLevel explainLevel) {
if (explainLevel == ExplainLevel.PARSED_PLAN || explainLevel == ExplainLevel.ALL_PLAN) {
parsedPlan = plan;
if (explainLevel == ExplainLevel.PARSED_PLAN) {
@@ -144,7 +149,8 @@ public Plan plan(LogicalPlan plan, PhysicalProperties outputProperties, ExplainL
// pre-process logical plan out of memo, e.g. process SET_VAR hint
plan = preprocess(plan);

initCascadesContext(plan);
initCascadesContext(plan, requireProperties);

try (Lock lock = new Lock(plan, cascadesContext)) {
// resolve column, table and function
analyze();
@@ -174,7 +180,7 @@ public Plan plan(LogicalPlan plan, PhysicalProperties outputProperties, ExplainL
// cost-based optimize and explore plan space
optimize();

PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY);
PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), requireProperties);

// post-process physical plan out of memo, just for future use.
physicalPlan = postProcess(physicalPlan);
@@ -190,8 +196,8 @@ private LogicalPlan preprocess(LogicalPlan logicalPlan) {
return new PlanPreprocessors(statementContext).process(logicalPlan);
}

private void initCascadesContext(LogicalPlan plan) {
cascadesContext = CascadesContext.newContext(statementContext, plan);
private void initCascadesContext(LogicalPlan plan, PhysicalProperties requireProperties) {
cascadesContext = CascadesContext.newContext(statementContext, plan, requireProperties);
}

private void analyze() {
@@ -258,7 +264,8 @@ private PhysicalPlan chooseBestPlan(Group rootGroup, PhysicalProperties physical
throws AnalysisException {
try {
GroupExpression groupExpression = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow(
() -> new AnalysisException("lowestCostPlans with physicalProperties doesn't exist")).second;
() -> new AnalysisException("lowestCostPlans with physicalProperties("
+ physicalProperties + ") doesn't exist in root group")).second;
List<PhysicalProperties> inputPropertiesList = groupExpression.getInputPropertiesList(physicalProperties);

List<Plan> planChildren = Lists.newArrayList();
@@ -335,6 +342,11 @@ public CascadesContext getCascadesContext() {
return cascadesContext;
}

public static PhysicalProperties buildInitRequireProperties(Plan initPlan) {
boolean isQuery = !(initPlan instanceof Command) || (initPlan instanceof ExplainCommand);
return isQuery ? PhysicalProperties.GATHER : PhysicalProperties.ANY;
}

private ExplainLevel getExplainLevel(ExplainOptions explainOptions) {
if (explainOptions == null) {
return ExplainLevel.NONE;
@@ -24,6 +24,13 @@
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.Maps;

import java.util.Map;
import javax.annotation.concurrent.GuardedBy;

/**
* Statement context for nereids
*/
@@ -37,6 +44,9 @@ public class StatementContext {

private final IdGenerator<RelationId> relationIdGenerator = RelationId.createGenerator();

@GuardedBy("this")
private final Map<String, Supplier<Object>> contextCacheMap = Maps.newLinkedHashMap();

private StatementBase parsedStatement;

public StatementContext() {
@@ -79,4 +89,14 @@ public RelationId getNextRelationId() {
public void setParsedStatement(StatementBase parsedStatement) {
this.parsedStatement = parsedStatement;
}

/** getOrRegisterCache */
public synchronized <T> T getOrRegisterCache(String key, Supplier<T> cacheSupplier) {
Supplier<T> supplier = (Supplier<T>) contextCacheMap.get(key);
if (supplier == null) {
contextCacheMap.put(key, (Supplier<Object>) Suppliers.memoize(cacheSupplier));
supplier = cacheSupplier;
}
return supplier.get();
}
}
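
A dependency-free sketch of a per-statement cache in the spirit of `getOrRegisterCache` follows. It is not the committed code: the class name is hypothetical, and it stores the computed value directly instead of a Guava memoized supplier. One thing worth noting about the diff as written is that the first call reads from the raw `cacheSupplier` rather than from the memoized wrapper just placed in the map, so the supplier appears to run once more on the second call; the sketch below evaluates it exactly once per key.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

// Hypothetical sketch; stores values directly rather than memoized suppliers.
class ContextCacheSketch {
    private final Map<String, Object> contextCacheMap = new LinkedHashMap<>();

    @SuppressWarnings("unchecked")
    synchronized <T> T getOrRegisterCache(String key, Supplier<T> cacheSupplier) {
        // containsKey (not a null check) so that a cached null is also honored.
        if (!contextCacheMap.containsKey(key)) {
            contextCacheMap.put(key, cacheSupplier.get());
        }
        return (T) contextCacheMap.get(key);
    }

    public static void main(String[] args) {
        ContextCacheSketch context = new ContextCacheSketch();
        AtomicInteger calls = new AtomicInteger();
        Supplier<String> expensive = () -> "value-" + calls.incrementAndGet();
        System.out.println(context.getOrRegisterCache("k", expensive));
        System.out.println(context.getOrRegisterCache("k", expensive));
        System.out.println("supplier calls: " + calls.get());
    }
}
```

Because the original method calls `supplier.get()` immediately after registering, storing the value eagerly gives the same observable laziness as storing a memoized supplier.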
@@ -35,13 +35,11 @@ public class UnboundFunction extends Expression implements Unbound, PropagateNul

private final String name;
private final boolean isDistinct;
private final boolean isStar;

public UnboundFunction(String name, boolean isDistinct, boolean isStar, List<Expression> arguments) {
public UnboundFunction(String name, boolean isDistinct, List<Expression> arguments) {
super(arguments.toArray(new Expression[0]));
this.name = Objects.requireNonNull(name, "name can not be null");
this.name = Objects.requireNonNull(name, "name cannot be null");
this.isDistinct = isDistinct;
this.isStar = isStar;
}

public String getName() {
@@ -52,10 +50,6 @@ public boolean isDistinct() {
return isDistinct;
}

public boolean isStar() {
return isStar;
}

public List<Expression> getArguments() {
return children();
}
@@ -65,13 +59,13 @@ public String toSql() throws UnboundException {
String params = children.stream()
.map(Expression::toSql)
.collect(Collectors.joining(", "));
return name + "(" + (isDistinct ? "DISTINCT " : "") + params + ")";
return name + "(" + (isDistinct ? "distinct " : "") + params + ")";
}

@Override
public String toString() {
String params = Joiner.on(", ").join(children);
return "'" + name + "(" + (isDistinct ? "DISTINCT " : "") + params + ")";
return "'" + name + "(" + (isDistinct ? "distinct " : "") + params + ")";
}

@Override
@@ -81,7 +75,7 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public UnboundFunction withChildren(List<Expression> children) {
return new UnboundFunction(name, isDistinct, isStar, children);
return new UnboundFunction(name, isDistinct, children);
}

@Override
@@ -0,0 +1,33 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* Indicating that the current rule depends on other rules.
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface DependsRules {
/** depends rules */
Class[] value() default {};
}
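
Because the annotation has `RUNTIME` retention, its value can be read back through reflection. The sketch below shows one way a caller might do that; the annotation is redeclared locally so the example compiles on its own, and `RuleA`, `RuleB`, and `dependenciesOf` are hypothetical names. How Nereids actually consumes `@DependsRules` is not shown in this part of the diff.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical sketch; the annotation is redeclared here to stay self-contained.
class DependsRulesSketch {
    @Retention(RetentionPolicy.RUNTIME)
    @Target({ElementType.TYPE})
    @interface DependsRules {
        Class<?>[] value() default {};
    }

    static class RuleA {}

    @DependsRules({RuleA.class})
    static class RuleB {}

    // Read the declared dependencies of a rule class, or an empty list if unannotated.
    static List<Class<?>> dependenciesOf(Class<?> ruleClass) {
        DependsRules depends = ruleClass.getAnnotation(DependsRules.class);
        if (depends == null) {
            return Collections.emptyList();
        }
        return Arrays.asList(depends.value());
    }

    public static void main(String[] args) {
        System.out.println(dependenciesOf(RuleB.class));
        System.out.println(dependenciesOf(RuleA.class));
    }
}
```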