Skip to content

Commit

Permalink
Improve efficiency of T-Digest functions
Browse files Browse the repository at this point in the history
Use the `TDigest.valuesAt()` method to compute values at multiple quantiles
in a single pass over the T-Digest structure instead of making multiple
calls to the `TDigest.valueAt()` method.
Applies to:
- approx_percentile() aggregation,
- values_at_quantiles() function.
NOTE: In function values_at_quantiles(tdigest, array<double>),
it is required that the input percentile array be sorted in ascending order.
The similar function values_at_quantiles(qdigest, array<double>) allows
arbitrary order of percentiles.
In the approx_percentile() aggregation, ordering of percentiles is not required.
  • Loading branch information
kasiafi authored and martint committed Oct 4, 2020
1 parent 7788685 commit 0d41a63
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.prestosql.operator.aggregation;

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Doubles;
import io.airlift.stats.TDigest;
import io.prestosql.operator.aggregation.state.TDigestAndPercentileArrayState;
import io.prestosql.spi.block.Block;
Expand Down Expand Up @@ -96,14 +97,42 @@ public static void output(@AggregationState TDigestAndPercentileArrayState state

BlockBuilder blockBuilder = out.beginBlockEntry();

for (int i = 0; i < percentiles.size(); i++) {
Double percentile = percentiles.get(i);
DOUBLE.writeDouble(blockBuilder, digest.valueAt(percentile));
List<Double> valuesAtPercentiles = valuesAtPercentiles(digest, percentiles);
for (double value : valuesAtPercentiles) {
DOUBLE.writeDouble(blockBuilder, value);
}

out.closeEntry();
}

/**
 * Looks up the digest values for a list of percentiles that may be given in any order.
 *
 * <p>The percentiles are copied and sorted ascending, {@code digest.valuesAt()} is called
 * once with the sorted copy, and the returned values are then placed back at the positions
 * the corresponding percentiles occupied in the input list.
 *
 * @param digest the digest to query
 * @param percentiles the requested percentiles, in arbitrary order; not modified
 * @return the values reported by the digest, where element {@code i} corresponds to
 *         {@code percentiles.get(i)}
 */
public static List<Double> valuesAtPercentiles(TDigest digest, List<Double> percentiles)
{
    int count = percentiles.size();
    double[] ordered = new double[count];
    // originalPositions[k] remembers where the k-th smallest percentile sat in the input
    int[] originalPositions = new int[count];
    for (int position = 0; position < count; position++) {
        ordered[position] = percentiles.get(position);
        originalPositions[position] = position;
    }

    // Sort the percentiles and the position bookkeeping in lockstep:
    // every swap applied to one array is mirrored in the other.
    it.unimi.dsi.fastutil.Arrays.quickSort(
            0,
            count,
            (left, right) -> Doubles.compare(ordered[left], ordered[right]),
            (left, right) -> {
                int movedPosition = originalPositions[left];
                originalPositions[left] = originalPositions[right];
                originalPositions[right] = movedPosition;

                double movedPercentile = ordered[left];
                ordered[left] = ordered[right];
                ordered[right] = movedPercentile;
            });

    List<Double> sortedValues = digest.valuesAt(Doubles.asList(ordered));

    // Undo the permutation so the output lines up with the caller's ordering.
    double[] restored = new double[sortedValues.size()];
    int rank = 0;
    for (double value : sortedValues) {
        restored[originalPositions[rank++]] = value;
    }
    return Doubles.asList(restored);
}

private static void initializePercentilesArray(@AggregationState TDigestAndPercentileArrayState state, Block percentilesArrayBlock)
{
if (state.getPercentiles() == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import java.util.List;

import static io.prestosql.operator.aggregation.ApproximateDoublePercentileArrayAggregations.valuesAtPercentiles;
import static io.prestosql.operator.aggregation.ApproximateLongPercentileAggregations.toDoubleExact;
import static io.prestosql.spi.type.BigintType.BIGINT;

Expand Down Expand Up @@ -66,9 +67,9 @@ public static void output(@AggregationState TDigestAndPercentileArrayState state

BlockBuilder blockBuilder = out.beginBlockEntry();

for (int i = 0; i < percentiles.size(); i++) {
Double percentile = percentiles.get(i);
BIGINT.writeLong(blockBuilder, Math.round(digest.valueAt(percentile)));
List<Double> valuesAtPercentiles = valuesAtPercentiles(digest, percentiles);
for (double value : valuesAtPercentiles) {
BIGINT.writeLong(blockBuilder, Math.round(value));
}

out.closeEntry();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import java.util.List;

import static io.prestosql.operator.aggregation.ApproximateDoublePercentileArrayAggregations.valuesAtPercentiles;
import static io.prestosql.spi.type.RealType.REAL;
import static java.lang.Float.floatToRawIntBits;
import static java.lang.Float.intBitsToFloat;
Expand Down Expand Up @@ -67,9 +68,9 @@ public static void output(@AggregationState TDigestAndPercentileArrayState state

BlockBuilder blockBuilder = out.beginBlockEntry();

for (int i = 0; i < percentiles.size(); i++) {
Double percentile = percentiles.get(i);
REAL.writeLong(blockBuilder, floatToRawIntBits((float) digest.valueAt(percentile)));
List<Double> valuesAtPercentiles = valuesAtPercentiles(digest, percentiles);
for (double value : valuesAtPercentiles) {
REAL.writeLong(blockBuilder, floatToRawIntBits((float) value));
}

out.closeEntry();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.prestosql.operator.scalar;

import com.google.common.collect.Ordering;
import io.airlift.stats.TDigest;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
Expand All @@ -21,6 +22,10 @@
import io.prestosql.spi.function.SqlType;
import io.prestosql.spi.type.StandardTypes;

import java.util.List;
import java.util.stream.IntStream;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.prestosql.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.prestosql.spi.type.DoubleType.DOUBLE;
import static io.prestosql.util.Failures.checkCondition;
Expand All @@ -44,9 +49,15 @@ public static double valueAtQuantile(@SqlType(StandardTypes.TDIGEST) TDigest inp
@SqlType("array(double)")
public static Block valuesAtQuantiles(@SqlType(StandardTypes.TDIGEST) TDigest input, @SqlType("array(double)") Block percentilesArrayBlock)
{
List<Double> percentiles = IntStream.range(0, percentilesArrayBlock.getPositionCount())
.mapToDouble(i -> DOUBLE.getDouble(percentilesArrayBlock, i))
.boxed()
.collect(toImmutableList());
checkCondition(Ordering.natural().isOrdered(percentiles), INVALID_FUNCTION_ARGUMENT, "percentiles must be sorted in increasing order");
BlockBuilder output = DOUBLE.createBlockBuilder(null, percentilesArrayBlock.getPositionCount());
for (int i = 0; i < percentilesArrayBlock.getPositionCount(); i++) {
DOUBLE.writeDouble(output, input.valueAt(DOUBLE.getDouble(percentilesArrayBlock, i)));
List<Double> valuesAtPercentiles = input.valuesAt(percentiles);
for (Double value : valuesAtPercentiles) {
DOUBLE.writeDouble(output, value);
}
return output.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ public void testLongPartialStep()
createLongsBlock(1L, null, 2L, 2L, null, 2L, 2L, null, 2L, 2L, null, 3L, 3L, null, 3L, null, 3L, 4L, 5L, 6L, 7L),
createRLEBlock(ImmutableList.of(0.01, 0.5), 21));

// unsorted percentiles
assertAggregation(
LONG_APPROXIMATE_PERCENTILE_ARRAY_AGGREGATION,
ImmutableList.of(3L, 1L, 2L),
createLongsBlock(null, 1L, 2L, 3L),
createRLEBlock(ImmutableList.of(0.8, 0.2, 0.5), 4));

// weighted approx_percentile
assertAggregation(
LONG_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION,
Expand Down Expand Up @@ -294,6 +301,13 @@ public void testFloatPartialStep()
createBlockOfReals(1.0f, null, 2.0f, 2.0f, null, 2.0f, 2.0f, null, 2.0f, 2.0f, null, 3.0f, 3.0f, null, 3.0f, null, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f),
createRLEBlock(ImmutableList.of(0.01, 0.5), 21));

// unsorted percentiles
assertAggregation(
FLOAT_APPROXIMATE_PERCENTILE_ARRAY_AGGREGATION,
ImmutableList.of(3.0f, 1.0f, 2.0f),
createBlockOfReals(null, 1.0f, 2.0f, 3.0f),
createRLEBlock(ImmutableList.of(0.8, 0.2, 0.5), 4));

// weighted approx_percentile
assertAggregation(
FLOAT_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION,
Expand Down Expand Up @@ -425,6 +439,13 @@ public void testDoublePartialStep()
createDoublesBlock(1.0, null, 2.0, 2.0, null, 2.0, 2.0, null, 2.0, 2.0, null, 3.0, 3.0, null, 3.0, null, 3.0, 4.0, 5.0, 6.0, 7.0),
createRLEBlock(ImmutableList.of(0.01, 0.5), 21));

// unsorted percentiles
assertAggregation(
DOUBLE_APPROXIMATE_PERCENTILE_ARRAY_AGGREGATION,
ImmutableList.of(3.0, 1.0, 2.0),
createDoublesBlock(null, 1.0, 2.0, 3.0),
createRLEBlock(ImmutableList.of(0.8, 0.2, 0.5), 4));

// weighted approx_percentile
assertAggregation(
DOUBLE_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import static java.lang.Math.round;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public class TestTDigestFunctions
{
Expand Down Expand Up @@ -78,6 +79,11 @@ public void testValuesAtQuantiles()
"SELECT values_at_quantiles(tdigest_agg(d), ARRAY[0.0001e0, 0.75e0, 0.85e0]) " +
"FROM (VALUES 0.1e0, 0.1e0, 0.1e0, 0.1e0, 10e0) T(d)"))
.matches("VALUES ARRAY[0.1e0, 0.1e0, 10.0e0]");

assertThatThrownBy(() -> assertions.query(
"SELECT values_at_quantiles(tdigest_agg(d), ARRAY[1e0, 0e0]) " +
"FROM (VALUES 0.1e0) T(d)"))
.hasMessage("percentiles must be sorted in increasing order");
}

@Test
Expand Down

0 comments on commit 0d41a63

Please sign in to comment.