Skip to content

Commit

Permalink
save prune partition
Browse files Browse the repository at this point in the history
  • Loading branch information
924060929 committed Mar 8, 2024
1 parent 37ef2e3 commit 08c4e0f
Show file tree
Hide file tree
Showing 8 changed files with 384 additions and 248 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,16 @@ public class DateLiteral extends LiteralExpr {

private static final Pattern HAS_OFFSET_PART = Pattern.compile("[\\+\\-]\\d{2}:\\d{2}");

@Override
public boolean equals(Object o) {
if (o instanceof DateLiteral) {
DateLiteral that = (DateLiteral) o;
return year == that.year && month == that.month && day == that.day && hour == that.hour
&& minute == that.minute && second == that.second && microsecond == that.microsecond;
}
return super.equals(o);
}

// Date Literal persist type in meta
private enum DateLiteralType {
DATETIME(0),
Expand Down Expand Up @@ -626,6 +636,26 @@ public int compareLiteral(LiteralExpr expr) {
if (expr == MaxLiteral.MAX_VALUE) {
return -1;
}
if (expr instanceof DateLiteral) {
DateLiteral other = (DateLiteral) expr;
long yearMonthDay = year * 10000 + month * 100 + day;
long otherYearMonthDay = other.year * 10000 + other.month * 100 + other.day;
long diffDay = yearMonthDay - otherYearMonthDay;
if (diffDay != 0) {
return diffDay < 0 ? -1 : 1;
}

long hourMinuteSecond = hour * 10000 + minute * 100 + second;
long otherHourMinuteSecond = other.hour * 10000 + other.minute * 100 + other.second;
long diffSecond = hourMinuteSecond - otherHourMinuteSecond;
if (diffSecond != 0) {
return diffSecond < 0 ? -1 : 1;
}
long msDiff = this.microsecond - other.microsecond;
return msDiff < 0
? -1
: msDiff == 0 ? 0 : 1;
}
// date time will not overflow when doing addition and subtraction
return getStringValue().compareTo(expr.getStringValue());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMap.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
Expand All @@ -62,11 +63,11 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.IntStream;

/**
* OneRangePartitionEvaluator.
Expand Down Expand Up @@ -102,31 +103,33 @@ public OneRangePartitionEvaluator(long partitionId, List<Slot> partitionSlots,

PartitionRangeExpander expander = new PartitionRangeExpander();
this.partitionSlotTypes = expander.computePartitionSlotTypes(lowers, uppers);
this.slotToType = IntStream.range(0, partitionSlots.size())
.mapToObj(index -> Pair.of(partitionSlots.get(index), partitionSlotTypes.get(index)))
.collect(ImmutableMap.toImmutableMap(Pair::key, Pair::value));
this.slotToType = Maps.newHashMapWithExpectedSize(partitionSlots.size() * 2);
for (int i = 0; i < partitionSlots.size(); i++) {
slotToType.put(partitionSlots.get(i), partitionSlotTypes.get(i));
}

this.partitionSlotContainsNull = IntStream.range(0, partitionSlots.size())
.mapToObj(index -> {
Slot slot = partitionSlots.get(index);
if (!slot.nullable()) {
return Pair.of(slot, false);
}
PartitionSlotType partitionSlotType = partitionSlotTypes.get(index);
boolean maybeNull = false;
switch (partitionSlotType) {
case CONST:
case RANGE:
maybeNull = range.lowerEndpoint().getKeys().get(index).isMinValue();
break;
case OTHER:
maybeNull = true;
break;
default:
throw new AnalysisException("Unknown partition slot type: " + partitionSlotType);
}
return Pair.of(slot, maybeNull);
}).collect(ImmutableMap.toImmutableMap(Pair::key, Pair::value));
this.partitionSlotContainsNull = Maps.newHashMapWithExpectedSize(partitionSlots.size() * 2);
for (int i = 0; i < partitionSlots.size(); i++) {
Slot slot = partitionSlots.get(i);
if (!slot.nullable()) {
partitionSlotContainsNull.put(slot, false);
continue;
}
PartitionSlotType partitionSlotType = partitionSlotTypes.get(i);
boolean maybeNull = false;
switch (partitionSlotType) {
case CONST:
case RANGE:
maybeNull = range.lowerEndpoint().getKeys().get(i).isMinValue();
break;
case OTHER:
maybeNull = true;
break;
default:
throw new AnalysisException("Unknown partition slot type: " + partitionSlotType);
}
partitionSlotContainsNull.put(slot, maybeNull);
}

int expandThreshold = cascadesContext.getAndCacheSessionVariable(
"partitionPruningExpandThreshold",
Expand All @@ -147,62 +150,14 @@ public long getPartitionId() {

@Override
public List<Map<Slot, PartitionSlotInput>> getOnePartitionInputs() {
List<Map<Slot, PartitionSlotInput>> onePartitionInputs = Lists.newArrayList();
for (List<Expression> input : inputs) {
boolean previousIsLowerBoundLiteral = true;
boolean previousIsUpperBoundLiteral = true;
List<Pair<Slot, PartitionSlotInput>> slotToInputs = Lists.newArrayList();
for (int i = 0; i < partitionSlots.size(); ++i) {
Slot partitionSlot = partitionSlots.get(i);
// partitionSlot will be replaced to this expression
Expression expression = input.get(i);
ColumnRange slotRange = null;
PartitionSlotType partitionSlotType = partitionSlotTypes.get(i);
if (expression instanceof Literal) {
// const or expanded range
slotRange = ColumnRange.singleton((Literal) expression);
if (!expression.equals(lowers.get(i))) {
previousIsLowerBoundLiteral = false;
}
if (!expression.equals(uppers.get(i))) {
previousIsUpperBoundLiteral = false;
}
} else {
// un expanded range
switch (partitionSlotType) {
case RANGE:
boolean isLastPartitionColumn = i + 1 == partitionSlots.size();
BoundType rightBoundType = isLastPartitionColumn
? BoundType.OPEN : BoundType.CLOSED;
slotRange = ColumnRange.range(
lowers.get(i), BoundType.CLOSED, uppers.get(i), rightBoundType);
break;
case OTHER:
if (previousIsLowerBoundLiteral) {
slotRange = ColumnRange.atLeast(lowers.get(i));
} else if (previousIsUpperBoundLiteral) {
slotRange = ColumnRange.lessThen(uppers.get(i));
} else {
// unknown range
slotRange = ColumnRange.all();
}
break;
default:
throw new AnalysisException("Unknown partition slot type: " + partitionSlotType);
}
previousIsLowerBoundLiteral = false;
previousIsUpperBoundLiteral = false;
}
ImmutableMap<Slot, ColumnRange> slotToRange = ImmutableMap.of(partitionSlot, slotRange);
slotToInputs.add(Pair.of(partitionSlot, new PartitionSlotInput(expression, slotToRange)));
}

Map<Slot, PartitionSlotInput> slotPartitionSlotInputMap = fillSlotRangesToInputs(
slotToInputs.stream()
.collect(ImmutableMap.toImmutableMap(Pair::key, Pair::value)));
onePartitionInputs.add(slotPartitionSlotInputMap);
if (partitionSlots.size() == 1 && inputs.size() == 1 && inputs.get(0).size() == 1
&& inputs.get(0).get(0) instanceof Literal) {
// fast path
return computeSinglePartitionValueInputs();
} else {
// slow path
return commonComputeOnePartitionInputs();
}
return onePartitionInputs;
}

@Override
Expand Down Expand Up @@ -597,13 +552,14 @@ private EvaluateRangeResult mergeRanges(
}

private List<Literal> toNereidsLiterals(PartitionKey partitionKey) {
return IntStream.range(0, partitionKey.getKeys().size())
.mapToObj(index -> {
LiteralExpr literalExpr = partitionKey.getKeys().get(index);
PrimitiveType primitiveType = partitionKey.getTypes().get(index);
Type type = Type.fromPrimitiveType(primitiveType);
return Literal.fromLegacyLiteral(literalExpr, type);
}).collect(ImmutableList.toImmutableList());
List<Literal> literals = Lists.newArrayListWithCapacity(partitionKey.getKeys().size());
for (int i = 0; i < partitionKey.getKeys().size(); i++) {
LiteralExpr literalExpr = partitionKey.getKeys().get(i);
PrimitiveType primitiveType = partitionKey.getTypes().get(i);
Type type = Type.fromPrimitiveType(primitiveType);
literals.add(Literal.fromLegacyLiteral(literalExpr, type));
}
return literals;
}

@Override
Expand Down Expand Up @@ -655,15 +611,20 @@ private Optional<PartitionSlotType> getPartitionSlotType(Slot slot) {
private Map<Slot, PartitionSlotInput> fillSlotRangesToInputs(
Map<Slot, PartitionSlotInput> inputs) {

Map<Slot, ColumnRange> allColumnRanges = inputs.entrySet()
.stream()
.map(entry -> Pair.of(entry.getKey(), entry.getValue().columnRanges.get(entry.getKey())))
.collect(ImmutableMap.toImmutableMap(Pair::key, Pair::value));
Builder<Slot, ColumnRange> allColumnRangesBuilder =
ImmutableMap.builderWithExpectedSize(inputs.size() * 2);
for (Entry<Slot, PartitionSlotInput> entry : inputs.entrySet()) {
allColumnRangesBuilder.put(entry.getKey(), entry.getValue().columnRanges.get(entry.getKey()));
}

return inputs.keySet()
.stream()
.map(slot -> Pair.of(slot, new PartitionSlotInput(inputs.get(slot).result, allColumnRanges)))
.collect(ImmutableMap.toImmutableMap(Pair::key, Pair::value));
Map<Slot, ColumnRange> allColumnRanges = allColumnRangesBuilder.build();

Builder<Slot, PartitionSlotInput> partitionSlotInputs =
ImmutableMap.builderWithExpectedSize(inputs.size() * 2);
for (Slot slot : inputs.keySet()) {
partitionSlotInputs.put(slot, new PartitionSlotInput(inputs.get(slot).result, allColumnRanges));
}
return partitionSlotInputs.build();
}

/** EvaluateRangeInput */
Expand Down Expand Up @@ -728,4 +689,71 @@ public boolean isRejectNot() {
public boolean isDefaultPartition() {
return partitionItem.isDefaultPartition();
}

private List<Map<Slot, PartitionSlotInput>> computeSinglePartitionValueInputs() {
Slot partitionSlot = partitionSlots.get(0);
Literal literal = (Literal) inputs.get(0).get(0);
ColumnRange slotRange = ColumnRange.singleton(literal);
ImmutableMap<Slot, ColumnRange> slotToRange = ImmutableMap.of(partitionSlot, slotRange);
Map<Slot, PartitionSlotInput> slotToInputs =
ImmutableMap.of(partitionSlot, new PartitionSlotInput(literal, slotToRange));
return ImmutableList.of(slotToInputs);
}

private List<Map<Slot, PartitionSlotInput>> commonComputeOnePartitionInputs() {
List<Map<Slot, PartitionSlotInput>> onePartitionInputs = Lists.newArrayListWithCapacity(inputs.size());
for (List<Expression> input : inputs) {
boolean previousIsLowerBoundLiteral = true;
boolean previousIsUpperBoundLiteral = true;
Builder<Slot, PartitionSlotInput> slotToInputs = ImmutableMap.builderWithExpectedSize(16);
for (int i = 0; i < partitionSlots.size(); ++i) {
Slot partitionSlot = partitionSlots.get(i);
// partitionSlot will be replaced to this expression
Expression expression = input.get(i);
ColumnRange slotRange = null;
PartitionSlotType partitionSlotType = partitionSlotTypes.get(i);
if (expression instanceof Literal) {
// const or expanded range
slotRange = ColumnRange.singleton((Literal) expression);
if (!expression.equals(lowers.get(i))) {
previousIsLowerBoundLiteral = false;
}
if (!expression.equals(uppers.get(i))) {
previousIsUpperBoundLiteral = false;
}
} else {
// un expanded range
switch (partitionSlotType) {
case RANGE:
boolean isLastPartitionColumn = i + 1 == partitionSlots.size();
BoundType rightBoundType = isLastPartitionColumn
? BoundType.OPEN : BoundType.CLOSED;
slotRange = ColumnRange.range(
lowers.get(i), BoundType.CLOSED, uppers.get(i), rightBoundType);
break;
case OTHER:
if (previousIsLowerBoundLiteral) {
slotRange = ColumnRange.atLeast(lowers.get(i));
} else if (previousIsUpperBoundLiteral) {
slotRange = ColumnRange.lessThen(uppers.get(i));
} else {
// unknown range
slotRange = ColumnRange.all();
}
break;
default:
throw new AnalysisException("Unknown partition slot type: " + partitionSlotType);
}
previousIsLowerBoundLiteral = false;
previousIsUpperBoundLiteral = false;
}
ImmutableMap<Slot, ColumnRange> slotToRange = ImmutableMap.of(partitionSlot, slotRange);
slotToInputs.put(partitionSlot, new PartitionSlotInput(expression, slotToRange));
}

Map<Slot, PartitionSlotInput> slotPartitionSlotInputMap = fillSlotRangesToInputs(slotToInputs.build());
onePartitionInputs.add(slotPartitionSlotInputMap);
}
return onePartitionInputs;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@
import org.apache.doris.nereids.types.DateTimeType;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;

/**
Expand Down Expand Up @@ -91,11 +94,15 @@ public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context)
}
}

/** prune */
public List<Long> prune() {
return partitions.stream()
.filter(partitionEvaluator -> !canPrune(partitionEvaluator))
.map(OnePartitionEvaluator::getPartitionId)
.collect(ImmutableList.toImmutableList());
Builder<Long> scanPartitionIds = ImmutableList.builder();
for (OnePartitionEvaluator partition : partitions) {
if (!canPrune(partition)) {
scanPartitionIds.add(partition.getPartitionId());
}
}
return scanPartitionIds.build();
}

/**
Expand All @@ -107,11 +114,12 @@ public static List<Long> prune(List<Slot> partitionSlots, Expression partitionPr
partitionPredicate = TryEliminateUninterestedPredicates.rewrite(
partitionPredicate, ImmutableSet.copyOf(partitionSlots), cascadesContext);
partitionPredicate = PredicateRewriteForPartitionPrune.rewrite(partitionPredicate, cascadesContext);
List<OnePartitionEvaluator> evaluators = idToPartitions.entrySet()
.stream()
.map(kv -> toPartitionEvaluator(kv.getKey(), kv.getValue(), partitionSlots, cascadesContext,
partitionTableType))
.collect(ImmutableList.toImmutableList());

List<OnePartitionEvaluator> evaluators = Lists.newArrayListWithCapacity(idToPartitions.size());
for (Entry<Long, PartitionItem> kv : idToPartitions.entrySet()) {
evaluators.add(toPartitionEvaluator(
kv.getKey(), kv.getValue(), partitionSlots, cascadesContext, partitionTableType));
}

partitionPredicate = OrToIn.INSTANCE.rewrite(partitionPredicate, null);
PartitionPruner partitionPruner = new PartitionPruner(evaluators, partitionPredicate);
Expand Down
Loading

0 comments on commit 08c4e0f

Please sign in to comment.