diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java index 3c6eaa5999ff064..fecdbf650c3d731 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java @@ -64,6 +64,7 @@ import org.apache.doris.qe.SessionVariable; import org.apache.doris.statistics.ColumnStatistic; import org.apache.doris.statistics.Statistics; +import org.apache.doris.statistics.StatisticsBuilder; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -619,7 +620,7 @@ public void updateConsumerStats(CTEId cteId, Statistics statistics) { List, Group>> consumerGroups = this.statementContext.getCteIdToConsumerGroup().get(cteId); for (Pair, Group> p : consumerGroups) { Map producerSlotToConsumerSlot = p.first; - Statistics updatedConsumerStats = new Statistics(statistics); + Statistics updatedConsumerStats = new StatisticsBuilder(statistics).build(); for (Entry entry : statistics.columnStatistics().entrySet()) { updatedConsumerStats.addColumnStats(producerSlotToConsumerSlot.get(entry.getKey()), entry.getValue()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java index efdd46f821f9b32..6cbaddfee6aaa9e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Not; -import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContains; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.Plan; @@ -30,8 +29,6 @@ import org.apache.doris.qe.ConnectContext; import org.apache.doris.thrift.TRuntimeFilterType; -import java.util.List; - /** * Join Commute */ @@ -114,9 +111,10 @@ public static boolean isZigZagJoin(LogicalJoin join) { } private static boolean containJoin(GroupPlan groupPlan) { - // TODO: tmp way to judge containJoin - List output = groupPlan.getOutput(); - return !output.stream().map(Slot::getQualifier).allMatch(output.get(0).getQualifier()::equals); + return groupPlan.getGroup().getStatistics().getWidthInJoinCluster() > 1; + // // TODO: tmp way to judge containJoin + // List output = groupPlan.getOutput(); + // return !output.stream().map(Slot::getQualifier).allMatch(output.get(0).getQualifier()::equals); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java index b716b350f24b840..9c9250dd26f74f4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java @@ -331,7 +331,7 @@ A not in (1, 2, 3, 100): selectivity = 1.0; } } - Statistics estimated = new Statistics(context.statistics); + Statistics estimated = new StatisticsBuilder(context.statistics).build(); estimated = estimated.withSel(selectivity); estimated.addColumnStats(compareExpr, compareExprStatsBuilder.build()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index 8495411745941bf..e30e27fbc7d9cbf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -350,8 +350,12 @@ public Statistics visitLogicalPartitionTopN(LogicalPartitionTopN @Override public Statistics visitLogicalJoin(LogicalJoin join, Void context) { - return JoinEstimation.estimate(groupExpression.childStatistics(0), + Statistics joinStats = JoinEstimation.estimate(groupExpression.childStatistics(0), groupExpression.childStatistics(1), join); + joinStats = new StatisticsBuilder(joinStats).setWidthInJoinCluster( + groupExpression.childStatistics(0).getWidthInJoinCluster() + + groupExpression.childStatistics(1).getWidthInJoinCluster()).build(); + return joinStats; } @Override @@ -542,6 +546,7 @@ public Statistics visitPhysicalGenerate(PhysicalGenerate generat private Statistics computeAssertNumRows(long desiredNumOfRows) { Statistics statistics = groupExpression.childStatistics(0); statistics.withRowCountAndEnforceValid(Math.min(1, statistics.getRowCount())); + statistics = new StatisticsBuilder(statistics).setWidthInJoinCluster(1).build(); return statistics; } @@ -751,7 +756,7 @@ private Statistics computeAggregate(Aggregate aggregate) { builder.setDataSize(rowCount * outputExpression.getDataType().width()); slotToColumnStats.put(outputExpression.toSlot(), columnStat); } - return new Statistics(rowCount, slotToColumnStats); + return new Statistics(rowCount, 1, slotToColumnStats); // TODO: Update ColumnStats properly, add new mapping from output slot to ColumnStats } @@ -770,7 +775,7 @@ private Statistics computeRepeat(Repeat repeat) { .setDataSize(stats.dataSize < 0 ? stats.dataSize : stats.dataSize * groupingSetNum); return Pair.of(kv.getKey(), columnStatisticBuilder.build()); }).collect(Collectors.toMap(Pair::key, Pair::value)); - return new Statistics(rowCount < 0 ? rowCount : rowCount * groupingSetNum, columnStatisticMap); + return new Statistics(rowCount < 0 ? rowCount : rowCount * groupingSetNum, 1, columnStatisticMap); } private Statistics computeProject(Project project) { @@ -780,7 +785,7 @@ private Statistics computeProject(Project project) { ColumnStatistic columnStatistic = ExpressionEstimation.estimate(projection, childStats); return new SimpleEntry<>(projection.toSlot(), columnStatistic); }).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (item1, item2) -> item1)); - return new Statistics(childStats.getRowCount(), columnsStats); + return new Statistics(childStats.getRowCount(), childStats.getWidthInJoinCluster(), columnsStats); } private Statistics computeOneRowRelation(List projects) { @@ -792,7 +797,7 @@ private Statistics computeOneRowRelation(List projects) { }) .collect(Collectors.toMap(Pair::key, Pair::value)); int rowCount = 1; - return new Statistics(rowCount, columnStatsMap); + return new Statistics(rowCount, 1, columnStatsMap); } private Statistics computeEmptyRelation(EmptyRelation emptyRelation) { @@ -807,7 +812,7 @@ private Statistics computeEmptyRelation(EmptyRelation emptyRelation) { }) .collect(Collectors.toMap(Pair::key, Pair::value)); int rowCount = 0; - return new Statistics(rowCount, columnStatsMap); + return new Statistics(rowCount, 1, columnStatsMap); } private Statistics computeUnion(Union union) { @@ -850,7 +855,7 @@ private Statistics computeUnion(Union union) { statisticsBuilder.setRowCount(leftRowCount); statisticsBuilder.putColumnStatistics(unionOutput.get(i), headStats.findColumnStatistics(headSlot)); } - return statisticsBuilder.build(); + return statisticsBuilder.setWidthInJoinCluster(1).build(); } private Statistics computeExcept(SetOperation setOperation) { @@ -863,7 +868,7 @@ private Statistics computeExcept(SetOperation setOperation) { statisticsBuilder.putColumnStatistics(operatorOutput.get(i), columnStatistic); } statisticsBuilder.setRowCount(leftStats.getRowCount()); - return statisticsBuilder.build(); + return statisticsBuilder.setWidthInJoinCluster(1).build(); } private Statistics computeIntersect(SetOperation setOperation) { @@ -890,7 +895,8 @@ private Statistics computeIntersect(SetOperation setOperation) { leftChildStats.addColumnStats(outputs.get(i), leftChildStats.findColumnStatistics(leftChildOutputs.get(i))); } - return leftChildStats.withRowCountAndEnforceValid(rowCount); + return new StatisticsBuilder(leftChildStats.withRowCountAndEnforceValid(rowCount)) + .setWidthInJoinCluster(1).build(); } private Statistics computeGenerate(Generate generate) { @@ -912,7 +918,7 @@ private Statistics computeGenerate(Generate generate) { .build(); columnStatsMap.put(output, columnStatistic); } - return new Statistics(count, columnStatsMap); + return new Statistics(count, 1, columnStatsMap); } private Statistics computeWindow(Window windowOperator) { @@ -981,7 +987,7 @@ private Statistics computeWindow(Window windowOperator) { return Pair.of(expr.toSlot(), colStatsBuilder.build()); }).collect(Collectors.toMap(Pair::key, Pair::value)); columnStatisticMap.putAll(childColumnStats); - return new Statistics(childStats.getRowCount(), columnStatisticMap); + return new Statistics(childStats.getRowCount(), 1, columnStatisticMap); } private ColumnStatistic unionColumn(ColumnStatistic leftStats, double leftRowCount, ColumnStatistic rightStats, @@ -1020,7 +1026,8 @@ private Plan tryToFindChild(GroupExpression groupExpression) { @Override public Statistics visitLogicalCTEProducer(LogicalCTEProducer cteProducer, Void context) { - Statistics statistics = groupExpression.childStatistics(0); + StatisticsBuilder builder = new StatisticsBuilder(groupExpression.childStatistics(0)); + Statistics statistics = builder.setWidthInJoinCluster(1).build(); cteIdToStats.put(cteProducer.getCteId(), statistics); return statistics; } @@ -1032,7 +1039,7 @@ public Statistics visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Void c cteConsumer.getProducerToConsumerOutputMap()); Statistics prodStats = cteIdToStats.get(cteId); Preconditions.checkArgument(prodStats != null, String.format("Stats for CTE: %s not found", cteId)); - Statistics consumerStats = new Statistics(prodStats.getRowCount(), new HashMap<>()); + Statistics consumerStats = new Statistics(prodStats.getRowCount(), 1, new HashMap<>()); for (Slot slot : cteConsumer.getOutput()) { Slot prodSlot = cteConsumer.getProducerSlot(slot); ColumnStatistic colStats = prodStats.columnStatistics().get(prodSlot); @@ -1052,7 +1059,8 @@ public Statistics visitLogicalCTEAnchor(LogicalCTEAnchor cteProducer, Void context) { - Statistics statistics = groupExpression.childStatistics(0); + Statistics statistics = new StatisticsBuilder(groupExpression.childStatistics(0)) + .setWidthInJoinCluster(1).build(); cteIdToStats.put(cteProducer.getCteId(), statistics); cascadesContext.updateConsumerStats(cteProducer.getCteId(), statistics); return statistics; @@ -1068,7 +1076,7 @@ public Statistics visitPhysicalCTEConsumer(PhysicalCTEConsumer cteConsumer, Void prodStats = groupExpression.getOwnerGroup().getStatistics(); } Preconditions.checkArgument(prodStats != null, String.format("Stats for CTE: %s not found", cteId)); - Statistics consumerStats = new Statistics(prodStats.getRowCount(), new HashMap<>()); + Statistics consumerStats = new Statistics(prodStats.getRowCount(), 1, new HashMap<>()); for (Slot slot : cteConsumer.getOutput()) { Slot prodSlot = cteConsumer.getProducerSlot(slot); ColumnStatistic colStats = prodStats.columnStatistics().get(prodSlot); diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java index 5afcdae06658686..00a32f643569191 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java +++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java @@ -33,18 +33,19 @@ public class Statistics { private final double rowCount; private final Map expressionToColumnStats; + private final int widthInJoinCluster; // the byte size of one tuple private double tupleSize; - public Statistics(Statistics another) { - this.rowCount = another.rowCount; - this.expressionToColumnStats = new HashMap<>(another.expressionToColumnStats); - this.tupleSize = another.tupleSize; + public Statistics(double rowCount, Map expressionToColumnStats) { + this(rowCount, 1, expressionToColumnStats); } - public Statistics(double rowCount, Map expressionToColumnStats) { + public Statistics(double rowCount, int widthInJoinCluster, + Map expressionToColumnStats) { this.rowCount = rowCount; + this.widthInJoinCluster = widthInJoinCluster; this.expressionToColumnStats = expressionToColumnStats; } @@ -61,14 +62,14 @@ public double getRowCount() { } public Statistics withRowCount(double rowCount) { - return new Statistics(rowCount, new HashMap<>(expressionToColumnStats)); + return new Statistics(rowCount, widthInJoinCluster, new HashMap<>(expressionToColumnStats)); } /** * Update by count. */ public Statistics withRowCountAndEnforceValid(double rowCount) { - Statistics statistics = new Statistics(rowCount, expressionToColumnStats); + Statistics statistics = new Statistics(rowCount, widthInJoinCluster, expressionToColumnStats); statistics.enforceValid(); return statistics; } @@ -99,7 +100,7 @@ public Statistics withSel(double sel) { return this; } double newCount = rowCount * sel; - return new Statistics(newCount, new HashMap<>(expressionToColumnStats)); + return new Statistics(newCount, widthInJoinCluster, new HashMap<>(expressionToColumnStats)); } public Statistics addColumnStats(Expression expression, ColumnStatistic columnStatistic) { @@ -146,7 +147,7 @@ public String toString() { return "-Infinite"; } DecimalFormat format = new DecimalFormat("#,###.##"); - return format.format(rowCount); + return format.format(rowCount) + " " + widthInJoinCluster; } public int getBENumber() { @@ -181,10 +182,14 @@ public String detail(String prefix) { StringBuilder builder = new StringBuilder(); builder.append(prefix).append("rows=").append(rowCount).append("\n"); builder.append(prefix).append("tupleSize=").append(computeTupleSize()).append("\n"); - + builder.append(prefix).append("width=").append(widthInJoinCluster).append("\n"); for (Entry entry : expressionToColumnStats.entrySet()) { builder.append(prefix).append(entry.getKey()).append(" -> ").append(entry.getValue()).append("\n"); } return builder.toString(); } + + public int getWidthInJoinCluster() { + return widthInJoinCluster; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticsBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticsBuilder.java index a0e75f7df380907..53d8f49cb14c0f1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticsBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticsBuilder.java @@ -25,7 +25,7 @@ public class StatisticsBuilder { private double rowCount; - + private int widthInJoinCluster; private final Map expressionToColumnStats; public StatisticsBuilder() { @@ -34,6 +34,7 @@ public StatisticsBuilder() { public StatisticsBuilder(Statistics statistics) { this.rowCount = statistics.getRowCount(); + this.widthInJoinCluster = statistics.getWidthInJoinCluster(); expressionToColumnStats = new HashMap<>(); expressionToColumnStats.putAll(statistics.columnStatistics()); } @@ -43,6 +44,11 @@ public StatisticsBuilder setRowCount(double rowCount) { return this; } + public StatisticsBuilder setWidthInJoinCluster(int widthInJoinCluster) { + this.widthInJoinCluster = widthInJoinCluster; + return this; + } + public StatisticsBuilder putColumnStatistics( Map expressionToColumnStats) { this.expressionToColumnStats.putAll(expressionToColumnStats); @@ -55,6 +61,6 @@ public StatisticsBuilder putColumnStatistics(Expression expression, ColumnStatis } public Statistics build() { - return new Statistics(rowCount, expressionToColumnStats); + return new Statistics(rowCount, widthInJoinCluster, expressionToColumnStats); } }