Skip to content

Commit

Permalink
width
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Jan 2, 2024
1 parent 4cbbd25 commit 07c573e
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatisticsBuilder;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -619,7 +620,7 @@ public void updateConsumerStats(CTEId cteId, Statistics statistics) {
List<Pair<Map<Slot, Slot>, Group>> consumerGroups = this.statementContext.getCteIdToConsumerGroup().get(cteId);
for (Pair<Map<Slot, Slot>, Group> p : consumerGroups) {
Map<Slot, Slot> producerSlotToConsumerSlot = p.first;
Statistics updatedConsumerStats = new Statistics(statistics);
Statistics updatedConsumerStats = new StatisticsBuilder(statistics).build();
for (Entry<Expression, ColumnStatistic> entry : statistics.columnStatistics().entrySet()) {
updatedConsumerStats.addColumnStats(producerSlotToConsumerSlot.get(entry.getKey()), entry.getValue());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,13 @@
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContains;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TRuntimeFilterType;

import java.util.List;

/**
* Join Commute
*/
Expand Down Expand Up @@ -114,9 +111,10 @@ public static boolean isZigZagJoin(LogicalJoin<GroupPlan, GroupPlan> join) {
}

/**
 * Reports whether the plan behind {@code groupPlan} is treated as containing a join.
 *
 * <p>When the owning group already has statistics, a join-cluster width greater than 1
 * is taken to mean "contains a join". When statistics are not available, the method falls
 * back to the qualifier heuristic: output slots carrying differing qualifiers are assumed
 * to come from more than one relation.
 *
 * @param groupPlan the child group plan to inspect
 * @return {@code true} if the group is judged to contain a join
 */
private static boolean containJoin(GroupPlan groupPlan) {
    // NOTE(review): group statistics are assumed to be possibly null (e.g. not derived yet);
    // guard instead of dereferencing blindly.
    if (groupPlan.getGroup().getStatistics() != null) {
        return groupPlan.getGroup().getStatistics().getWidthInJoinCluster() > 1;
    }
    // Fallback: tmp way to judge containJoin, based on slot qualifiers.
    List<Slot> output = groupPlan.getOutput();
    if (output.isEmpty()) {
        // No slots means there are no differing qualifiers to observe.
        return false;
    }
    return !output.stream().map(Slot::getQualifier).allMatch(output.get(0).getQualifier()::equals);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ A not in (1, 2, 3, 100):
selectivity = 1.0;
}
}
Statistics estimated = new Statistics(context.statistics);
Statistics estimated = new StatisticsBuilder(context.statistics).build();
estimated = estimated.withSel(selectivity);
estimated.addColumnStats(compareExpr,
compareExprStatsBuilder.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,12 @@ public Statistics visitLogicalPartitionTopN(LogicalPartitionTopN<? extends Plan>

@Override
public Statistics visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) {
return JoinEstimation.estimate(groupExpression.childStatistics(0),
Statistics joinStats = JoinEstimation.estimate(groupExpression.childStatistics(0),
groupExpression.childStatistics(1), join);
joinStats = new StatisticsBuilder(joinStats).setWidthInJoinCluster(
groupExpression.childStatistics(0).getWidthInJoinCluster()
+ groupExpression.childStatistics(1).getWidthInJoinCluster()).build();
return joinStats;
}

@Override
Expand Down Expand Up @@ -542,6 +546,7 @@ public Statistics visitPhysicalGenerate(PhysicalGenerate<? extends Plan> generat
/**
 * Derives statistics for an assert-num-rows operator from its single child.
 *
 * <p>The row count is capped at 1 and the join-cluster width is reset to 1.
 *
 * @param desiredNumOfRows the asserted number of rows
 * @return a new {@link Statistics} with the capped row count and a width of 1
 */
private Statistics computeAssertNumRows(long desiredNumOfRows) {
    // NOTE(review): desiredNumOfRows is not consulted; the cap is hard-coded to 1 — confirm intended.
    Statistics childStats = groupExpression.childStatistics(0);
    // withRowCountAndEnforceValid returns a new Statistics rather than updating the receiver,
    // so its result must be kept; otherwise the capped row count is silently dropped.
    Statistics capped = childStats.withRowCountAndEnforceValid(Math.min(1, childStats.getRowCount()));
    return new StatisticsBuilder(capped).setWidthInJoinCluster(1).build();
}

Expand Down Expand Up @@ -751,7 +756,7 @@ private Statistics computeAggregate(Aggregate<? extends Plan> aggregate) {
builder.setDataSize(rowCount * outputExpression.getDataType().width());
slotToColumnStats.put(outputExpression.toSlot(), columnStat);
}
return new Statistics(rowCount, slotToColumnStats);
return new Statistics(rowCount, 1, slotToColumnStats);
// TODO: Update ColumnStats properly, add new mapping from output slot to ColumnStats
}

Expand All @@ -770,7 +775,7 @@ private Statistics computeRepeat(Repeat<? extends Plan> repeat) {
.setDataSize(stats.dataSize < 0 ? stats.dataSize : stats.dataSize * groupingSetNum);
return Pair.of(kv.getKey(), columnStatisticBuilder.build());
}).collect(Collectors.toMap(Pair::key, Pair::value));
return new Statistics(rowCount < 0 ? rowCount : rowCount * groupingSetNum, columnStatisticMap);
return new Statistics(rowCount < 0 ? rowCount : rowCount * groupingSetNum, 1, columnStatisticMap);
}

private Statistics computeProject(Project project) {
Expand All @@ -780,7 +785,7 @@ private Statistics computeProject(Project project) {
ColumnStatistic columnStatistic = ExpressionEstimation.estimate(projection, childStats);
return new SimpleEntry<>(projection.toSlot(), columnStatistic);
}).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (item1, item2) -> item1));
return new Statistics(childStats.getRowCount(), columnsStats);
return new Statistics(childStats.getRowCount(), childStats.getWidthInJoinCluster(), columnsStats);
}

private Statistics computeOneRowRelation(List<NamedExpression> projects) {
Expand All @@ -792,7 +797,7 @@ private Statistics computeOneRowRelation(List<NamedExpression> projects) {
})
.collect(Collectors.toMap(Pair::key, Pair::value));
int rowCount = 1;
return new Statistics(rowCount, columnStatsMap);
return new Statistics(rowCount, 1, columnStatsMap);
}

private Statistics computeEmptyRelation(EmptyRelation emptyRelation) {
Expand All @@ -807,7 +812,7 @@ private Statistics computeEmptyRelation(EmptyRelation emptyRelation) {
})
.collect(Collectors.toMap(Pair::key, Pair::value));
int rowCount = 0;
return new Statistics(rowCount, columnStatsMap);
return new Statistics(rowCount, 1, columnStatsMap);
}

private Statistics computeUnion(Union union) {
Expand Down Expand Up @@ -850,7 +855,7 @@ private Statistics computeUnion(Union union) {
statisticsBuilder.setRowCount(leftRowCount);
statisticsBuilder.putColumnStatistics(unionOutput.get(i), headStats.findColumnStatistics(headSlot));
}
return statisticsBuilder.build();
return statisticsBuilder.setWidthInJoinCluster(1).build();
}

private Statistics computeExcept(SetOperation setOperation) {
Expand All @@ -863,7 +868,7 @@ private Statistics computeExcept(SetOperation setOperation) {
statisticsBuilder.putColumnStatistics(operatorOutput.get(i), columnStatistic);
}
statisticsBuilder.setRowCount(leftStats.getRowCount());
return statisticsBuilder.build();
return statisticsBuilder.setWidthInJoinCluster(1).build();
}

private Statistics computeIntersect(SetOperation setOperation) {
Expand All @@ -890,7 +895,8 @@ private Statistics computeIntersect(SetOperation setOperation) {
leftChildStats.addColumnStats(outputs.get(i),
leftChildStats.findColumnStatistics(leftChildOutputs.get(i)));
}
return leftChildStats.withRowCountAndEnforceValid(rowCount);
return new StatisticsBuilder(leftChildStats.withRowCountAndEnforceValid(rowCount))
.setWidthInJoinCluster(1).build();
}

private Statistics computeGenerate(Generate generate) {
Expand All @@ -912,7 +918,7 @@ private Statistics computeGenerate(Generate generate) {
.build();
columnStatsMap.put(output, columnStatistic);
}
return new Statistics(count, columnStatsMap);
return new Statistics(count, 1, columnStatsMap);
}

private Statistics computeWindow(Window windowOperator) {
Expand Down Expand Up @@ -981,7 +987,7 @@ private Statistics computeWindow(Window windowOperator) {
return Pair.of(expr.toSlot(), colStatsBuilder.build());
}).collect(Collectors.toMap(Pair::key, Pair::value));
columnStatisticMap.putAll(childColumnStats);
return new Statistics(childStats.getRowCount(), columnStatisticMap);
return new Statistics(childStats.getRowCount(), 1, columnStatisticMap);
}

private ColumnStatistic unionColumn(ColumnStatistic leftStats, double leftRowCount, ColumnStatistic rightStats,
Expand Down Expand Up @@ -1020,7 +1026,8 @@ private Plan tryToFindChild(GroupExpression groupExpression) {

@Override
public Statistics visitLogicalCTEProducer(LogicalCTEProducer<? extends Plan> cteProducer, Void context) {
Statistics statistics = groupExpression.childStatistics(0);
StatisticsBuilder builder = new StatisticsBuilder(groupExpression.childStatistics(0));
Statistics statistics = builder.setWidthInJoinCluster(1).build();
cteIdToStats.put(cteProducer.getCteId(), statistics);
return statistics;
}
Expand All @@ -1032,7 +1039,7 @@ public Statistics visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Void c
cteConsumer.getProducerToConsumerOutputMap());
Statistics prodStats = cteIdToStats.get(cteId);
Preconditions.checkArgument(prodStats != null, String.format("Stats for CTE: %s not found", cteId));
Statistics consumerStats = new Statistics(prodStats.getRowCount(), new HashMap<>());
Statistics consumerStats = new Statistics(prodStats.getRowCount(), 1, new HashMap<>());
for (Slot slot : cteConsumer.getOutput()) {
Slot prodSlot = cteConsumer.getProducerSlot(slot);
ColumnStatistic colStats = prodStats.columnStatistics().get(prodSlot);
Expand All @@ -1052,7 +1059,8 @@ public Statistics visitLogicalCTEAnchor(LogicalCTEAnchor<? extends Plan, ? exten
@Override
public Statistics visitPhysicalCTEProducer(PhysicalCTEProducer<? extends Plan> cteProducer,
Void context) {
Statistics statistics = groupExpression.childStatistics(0);
Statistics statistics = new StatisticsBuilder(groupExpression.childStatistics(0))
.setWidthInJoinCluster(1).build();
cteIdToStats.put(cteProducer.getCteId(), statistics);
cascadesContext.updateConsumerStats(cteProducer.getCteId(), statistics);
return statistics;
Expand All @@ -1068,7 +1076,7 @@ public Statistics visitPhysicalCTEConsumer(PhysicalCTEConsumer cteConsumer, Void
prodStats = groupExpression.getOwnerGroup().getStatistics();
}
Preconditions.checkArgument(prodStats != null, String.format("Stats for CTE: %s not found", cteId));
Statistics consumerStats = new Statistics(prodStats.getRowCount(), new HashMap<>());
Statistics consumerStats = new Statistics(prodStats.getRowCount(), 1, new HashMap<>());
for (Slot slot : cteConsumer.getOutput()) {
Slot prodSlot = cteConsumer.getProducerSlot(slot);
ColumnStatistic colStats = prodStats.columnStatistics().get(prodSlot);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,19 @@ public class Statistics {
private final double rowCount;

private final Map<Expression, ColumnStatistic> expressionToColumnStats;
private final int widthInJoinCluster;

// the byte size of one tuple
private double tupleSize;

public Statistics(Statistics another) {
this.rowCount = another.rowCount;
this.expressionToColumnStats = new HashMap<>(another.expressionToColumnStats);
this.tupleSize = another.tupleSize;
public Statistics(double rowCount, Map<Expression, ColumnStatistic> expressionToColumnStats) {
this(rowCount, 1, expressionToColumnStats);
}

public Statistics(double rowCount, Map<Expression, ColumnStatistic> expressionToColumnStats) {
/**
 * Creates statistics from a row count, a join-cluster width and per-expression column statistics.
 *
 * @param rowCount the row count to record
 * @param widthInJoinCluster the join-cluster width to record
 * @param expressionToColumnStats column statistics keyed by expression; the map is stored
 *         as passed in (no defensive copy), so later changes to it are visible here
 */
public Statistics(double rowCount, int widthInJoinCluster,
Map<Expression, ColumnStatistic> expressionToColumnStats) {
this.rowCount = rowCount;
this.widthInJoinCluster = widthInJoinCluster;
this.expressionToColumnStats = expressionToColumnStats;
}

Expand All @@ -61,14 +62,14 @@ public double getRowCount() {
}

public Statistics withRowCount(double rowCount) {
return new Statistics(rowCount, new HashMap<>(expressionToColumnStats));
return new Statistics(rowCount, widthInJoinCluster, new HashMap<>(expressionToColumnStats));
}

/**
* Update by count.
*/
public Statistics withRowCountAndEnforceValid(double rowCount) {
Statistics statistics = new Statistics(rowCount, expressionToColumnStats);
Statistics statistics = new Statistics(rowCount, widthInJoinCluster, expressionToColumnStats);
statistics.enforceValid();
return statistics;
}
Expand Down Expand Up @@ -99,7 +100,7 @@ public Statistics withSel(double sel) {
return this;
}
double newCount = rowCount * sel;
return new Statistics(newCount, new HashMap<>(expressionToColumnStats));
return new Statistics(newCount, widthInJoinCluster, new HashMap<>(expressionToColumnStats));
}

public Statistics addColumnStats(Expression expression, ColumnStatistic columnStatistic) {
Expand Down Expand Up @@ -146,7 +147,7 @@ public String toString() {
return "-Infinite";
}
DecimalFormat format = new DecimalFormat("#,###.##");
return format.format(rowCount);
return format.format(rowCount) + " " + widthInJoinCluster;
}

public int getBENumber() {
Expand Down Expand Up @@ -181,10 +182,14 @@ public String detail(String prefix) {
StringBuilder builder = new StringBuilder();
builder.append(prefix).append("rows=").append(rowCount).append("\n");
builder.append(prefix).append("tupleSize=").append(computeTupleSize()).append("\n");

builder.append(prefix).append("width=").append(widthInJoinCluster).append("\n");
for (Entry<Expression, ColumnStatistic> entry : expressionToColumnStats.entrySet()) {
builder.append(prefix).append(entry.getKey()).append(" -> ").append(entry.getValue()).append("\n");
}
return builder.toString();
}

/**
 * Returns the join-cluster width recorded in these statistics.
 *
 * @return the value of {@code widthInJoinCluster} given at construction time
 */
public int getWidthInJoinCluster() {
return widthInJoinCluster;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
public class StatisticsBuilder {

private double rowCount;

private int widthInJoinCluster;
private final Map<Expression, ColumnStatistic> expressionToColumnStats;

public StatisticsBuilder() {
Expand All @@ -34,6 +34,7 @@ public StatisticsBuilder() {

/**
 * Creates a builder pre-populated from existing statistics.
 *
 * <p>Row count and join-cluster width are copied as values; the column statistics are
 * copied into a fresh map owned by this builder.
 *
 * @param statistics the statistics to start from
 */
public StatisticsBuilder(Statistics statistics) {
    rowCount = statistics.getRowCount();
    widthInJoinCluster = statistics.getWidthInJoinCluster();
    // Copy rather than alias, so later put calls on the builder do not touch the source map.
    expressionToColumnStats = new HashMap<>(statistics.columnStatistics());
}
Expand All @@ -43,6 +44,11 @@ public StatisticsBuilder setRowCount(double rowCount) {
return this;
}

/**
 * Sets the join-cluster width that {@code build()} will pass to the resulting statistics.
 *
 * @param widthInJoinCluster the width to record
 * @return this builder, to allow call chaining
 */
public StatisticsBuilder setWidthInJoinCluster(int widthInJoinCluster) {
this.widthInJoinCluster = widthInJoinCluster;
return this;
}

public StatisticsBuilder putColumnStatistics(
Map<Expression, ColumnStatistic> expressionToColumnStats) {
this.expressionToColumnStats.putAll(expressionToColumnStats);
Expand All @@ -55,6 +61,6 @@ public StatisticsBuilder putColumnStatistics(Expression expression, ColumnStatis
}

public Statistics build() {
return new Statistics(rowCount, expressionToColumnStats);
return new Statistics(rowCount, widthInJoinCluster, expressionToColumnStats);
}
}

0 comments on commit 07c573e

Please sign in to comment.