Skip to content

Commit

Permalink
[pick](Branch2.0) generate left deep tree when stats is unknown (apac…
Browse files Browse the repository at this point in the history
  • Loading branch information
keanji-x authored and kaka11chen committed Oct 24, 2023
1 parent 7453019 commit 4220806
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeOlapScan;
Expand Down Expand Up @@ -311,17 +312,53 @@ public Cost visitPhysicalHashJoin(
}
// TODO: since the output rows may expand a lot, a penalty on them would cause broadcast (bc) never to be chosen.
// will refine this in next generation cost model.
if (isStatsUnknown(physicalHashJoin, buildStats, probeStats)) {
// forbid broadcast join when stats is unknown
return CostV1.of(rightRowCount * buildSideFactor + 1 / leftRowCount,
rightRowCount,
0
);
}
return CostV1.of(leftRowCount + rightRowCount * buildSideFactor + outputRowCount * probeSideFactor,
rightRowCount,
0
);
}
if (isStatsUnknown(physicalHashJoin, buildStats, probeStats)) {
return CostV1.of(rightRowCount + 1 / leftRowCount,
rightRowCount,
0);
}
return CostV1.of(leftRowCount + rightRowCount + outputRowCount,
rightRowCount,
0
);
}

/**
 * Checks whether the hash join references at least one condition slot whose column
 * statistics are not known on either input.
 *
 * <p>A slot counts as known when the build side or the probe side has an entry for it in
 * {@code columnStatistics()} and that entry's {@code isUnKnown} flag is false.
 *
 * @param join the hash join whose condition slots are inspected
 * @param build statistics of the build side
 * @param probe statistics of the probe side
 * @return true if some condition slot is known on neither side, false otherwise
 */
private boolean isStatsUnknown(PhysicalHashJoin<? extends Plan, ? extends Plan> join,
        Statistics build, Statistics probe) {
    for (Slot conditionSlot : join.getConditionSlot()) {
        boolean knownOnBuild = build.columnStatistics().containsKey(conditionSlot)
                && !build.columnStatistics().get(conditionSlot).isUnKnown;
        if (knownOnBuild) {
            // The build side already vouches for this slot; the probe side is not consulted.
            continue;
        }
        boolean knownOnProbe = probe.columnStatistics().containsKey(conditionSlot)
                && !probe.columnStatistics().get(conditionSlot).isUnKnown;
        if (!knownOnProbe) {
            return true;
        }
    }
    return false;
}

/**
 * Checks whether the nested loop join references at least one condition slot whose column
 * statistics are not known on either input.
 *
 * <p>A slot is treated as unknown on a side when that side has no entry for it in
 * {@code columnStatistics()} or the entry's {@code isUnKnown} flag is set.
 *
 * @param join the nested loop join whose condition slots are inspected
 * @param build statistics passed as the build side
 * @param probe statistics passed as the probe side
 * @return true if some condition slot is unknown on both sides, false otherwise
 */
private boolean isStatsUnknown(PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> join,
        Statistics build, Statistics probe) {
    // The probe side is only examined when the build side cannot supply known statistics.
    return join.getConditionSlot().stream().anyMatch(slot ->
            (!build.columnStatistics().containsKey(slot) || build.columnStatistics().get(slot).isUnKnown)
                    && (!probe.columnStatistics().containsKey(slot)
                    || probe.columnStatistics().get(slot).isUnKnown));
}

@Override
public Cost visitPhysicalNestedLoopJoin(
PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin,
Expand All @@ -330,7 +367,11 @@ public Cost visitPhysicalNestedLoopJoin(
Preconditions.checkState(context.arity() == 2);
Statistics leftStatistics = context.getChildStatistics(0);
Statistics rightStatistics = context.getChildStatistics(1);

if (isStatsUnknown(nestedLoopJoin, leftStatistics, rightStatistics)) {
return CostV1.of(rightStatistics.getRowCount() + 1 / leftStatistics.getRowCount(),
rightStatistics.getRowCount(),
0);
}
return CostV1.of(
leftStatistics.getRowCount() * rightStatistics.getRowCount(),
rightStatistics.getRowCount(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,21 @@ private static Statistics estimateHashJoin(Statistics leftStats, Statistics righ
}

private static Statistics estimateNestLoopJoin(Statistics leftStats, Statistics rightStats, Join join) {
if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
double rowCount = (leftStats.getRowCount() + rightStats.getRowCount());
// Slightly favor a nested loop join over an inner join when either side has exactly one row
if (leftStats.getRowCount() == 1 || rightStats.getRowCount() == 1) {
rowCount *= 0.99;
} else {
rowCount *= 1.01;
}
rowCount = Math.max(1, rowCount);
return new StatisticsBuilder()
.setRowCount(rowCount)
.putColumnStatistics(leftStats.columnStatistics())
.putColumnStatistics(rightStats.columnStatistics())
.build();
}
return new StatisticsBuilder()
.setRowCount(Math.max(1, leftStats.getRowCount() * rightStats.getRowCount()))
.putColumnStatistics(leftStats.columnStatistics())
Expand All @@ -156,7 +171,7 @@ private static Statistics estimateNestLoopJoin(Statistics leftStats, Statistics

private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rightStats, Join join) {
if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
double rowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount());
double rowCount = leftStats.getRowCount() + rightStats.getRowCount();
rowCount = Math.max(1, rowCount);
return new StatisticsBuilder()
.setRowCount(rowCount)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.AbstractPlan;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
Expand All @@ -34,13 +35,15 @@
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;

import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Physical hash join plan.
Expand Down Expand Up @@ -215,6 +218,11 @@ public int compare(Expression e1, Expression e2) {
}
}

/**
 * Collects the input slots of every join conjunct.
 *
 * @return an immutable set holding the input slots of all hash join conjuncts followed by
 *         those of all other join conjuncts
 */
public Set<Slot> getConditionSlot() {
    ImmutableSet.Builder<Slot> conditionSlots = ImmutableSet.builder();
    for (Expression conjunct : hashJoinConjuncts) {
        conditionSlots.addAll(conjunct.getInputSlots());
    }
    for (Expression conjunct : otherJoinConjuncts) {
        conditionSlots.addAll(conjunct.getInputSlots());
    }
    return conditionSlots.build();
}

@Override
public String shapeInfo() {
StringBuilder builder = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
Expand All @@ -31,11 +32,13 @@
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

/**
* Use nested loop algorithm to do join.
Expand Down Expand Up @@ -169,6 +172,11 @@ public boolean isBitMapRuntimeFilterConditionsEmpty() {
return bitMapRuntimeFilterConditions.isEmpty();
}

/**
 * Collects the input slots of every join conjunct.
 *
 * @return an immutable set holding the input slots of all hash join conjuncts followed by
 *         those of all other join conjuncts
 */
public Set<Slot> getConditionSlot() {
    Stream<Expression> allConjuncts = Stream.concat(
            hashJoinConjuncts.stream(), otherJoinConjuncts.stream());
    return allConjuncts
            .map(Expression::getInputSlots)
            .flatMap(Set::stream)
            .collect(ImmutableSet.toImmutableSet());
}

@Override
public String shapeInfo() {
StringBuilder builder = new StringBuilder("NestedLoopJoin");
Expand Down

0 comments on commit 4220806

Please sign in to comment.