Sum aggregations for long values precision #50538

Closed · wants to merge 4 commits
@@ -31,13 +31,26 @@
import java.util.Objects;

public class InternalSum extends InternalNumericMetricsAggregation.SingleValue implements Sum {
private final double sum;
private final double doubleSum;
private final long longSum;
private final boolean isFloating;

InternalSum(String name, double sum, DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) {
super(name, pipelineAggregators, metaData);
this.sum = sum;
this.doubleSum = sum;
this.longSum = (long) sum;
this.format = formatter;
this.isFloating = true;
}

InternalSum(String name, long sum, DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) {
super(name, pipelineAggregators, metaData);
this.doubleSum = (double) sum;
this.longSum = sum;
this.format = formatter;
this.isFloating = false;
}

/**
@@ -46,13 +59,17 @@ public class InternalSum extends InternalNumericMetricsAggregation.SingleValue implements Sum {
public InternalSum(StreamInput in) throws IOException {
super(in);
format = in.readNamedWriteable(DocValueFormat.class);
sum = in.readDouble();
doubleSum = in.readDouble();
longSum = in.readLong();
isFloating = in.readBoolean();
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(format);
out.writeDouble(sum);
out.writeDouble(doubleSum);
out.writeLong(longSum);
out.writeBoolean(isFloating);
}

@Override
@@ -62,38 +79,61 @@ public String getWriteableName() {

@Override
public double value() {
return sum;
return doubleSum;
}

// For testing
public long longValue() {
return longSum;
}

@Override
public double getValue() {
return sum;
return doubleSum;
}

@Override
public InternalSum reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
for (InternalAggregation aggregation : aggregations) {
double value = ((InternalSum) aggregation).sum;
kahanSummation.add(value);
if (isFloating) {
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
for (InternalAggregation aggregation : aggregations) {
double value = ((InternalSum) aggregation).doubleSum;
kahanSummation.add(value);
}
return new InternalSum(name, kahanSummation.value(), format, pipelineAggregators(), getMetaData());
} else {
// Compute the sum of long values with naive summation.
long sum = 0L;
for (InternalAggregation aggregation : aggregations) {
long value = ((InternalSum) aggregation).longSum;
sum += value;
}
return new InternalSum(name, sum, format, pipelineAggregators(), getMetaData());
}
return new InternalSum(name, kahanSummation.value(), format, pipelineAggregators(), getMetaData());
}

@Override
public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.field(CommonFields.VALUE.getPreferredName(), sum);
if (format != DocValueFormat.RAW) {
builder.field(CommonFields.VALUE_AS_STRING.getPreferredName(), format.format(sum).toString());
if (isFloating) {
builder.field(CommonFields.VALUE.getPreferredName(), doubleSum);
if (format != DocValueFormat.RAW) {
builder.field(CommonFields.VALUE_AS_STRING.getPreferredName(), format.format(doubleSum).toString());
}
return builder;
} else {
builder.field(CommonFields.VALUE.getPreferredName(), longSum);
if (format != DocValueFormat.RAW) {
builder.field(CommonFields.VALUE_AS_STRING.getPreferredName(), format.format(longSum).toString());
}
return builder;
}
return builder;
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), sum);
return Objects.hash(super.hashCode(), doubleSum, longSum);
}

@Override
@@ -103,6 +143,6 @@ public boolean equals(Object obj) {
if (super.equals(obj) == false) return false;

InternalSum that = (InternalSum) obj;
return Objects.equals(sum, that.sum);
return Objects.equals(doubleSum, that.doubleSum) && Objects.equals(longSum, that.longSum);
}
}
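
Why the separate long path matters: a double represents every integer exactly only up to 2^53, so summing large long values through the existing double-only code can silently drop low-order bits. A minimal standalone sketch of the problem, using the same value as the tests further down (the class name LongPrecisionDemo is hypothetical and not part of this change):

// Standalone illustration: long arithmetic keeps the +1, the double rounds it away.
public class LongPrecisionDemo {
    public static void main(String[] args) {
        long exactSum = 17458313843517748L + 1L;    // 17458313843517749
        double lossySum = 17458313843517748d + 1d;  // rounds back down to 1.7458313843517748E16

        System.out.println(exactSum);               // 17458313843517749
        System.out.println((long) lossySum);        // 17458313843517748 -- the +1 is gone
        System.out.println(2L << 52);               // 9007199254740992 = 2^53; 2^53 + 1 is the
                                                    // first integer a double cannot represent
    }
}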
@@ -19,10 +19,12 @@
package org.elasticsearch.search.aggregations.metrics;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.DoubleArray;
import org.elasticsearch.common.util.LongArray;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.Aggregator;
@@ -42,7 +44,8 @@ class SumAggregator extends NumericMetricsAggregator.SingleValue {
private final ValuesSource.Numeric valuesSource;
private final DocValueFormat format;

private DoubleArray sums;
private DoubleArray doubleSums;
private LongArray longSums;
private DoubleArray compensations;

SumAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
@@ -51,7 +54,8 @@ class SumAggregator extends NumericMetricsAggregator.SingleValue {
this.valuesSource = valuesSource;
this.format = formatter;
if (valuesSource != null) {
sums = context.bigArrays().newDoubleArray(1, true);
doubleSums = context.bigArrays().newDoubleArray(1, true);
longSums = context.bigArrays().newLongArray(1, true);
compensations = context.bigArrays().newDoubleArray(1, true);
}
}
@@ -68,48 +72,79 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
return LeafBucketCollector.NO_OP_COLLECTOR;
}
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
sums = bigArrays.grow(sums, bucket + 1);
compensations = bigArrays.grow(compensations, bucket + 1);

if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
kahanSummation.reset(sum, compensation);

for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
kahanSummation.add(value);
if (valuesSource.isFloatingPoint()) {
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
doubleSums = bigArrays.grow(doubleSums, bucket + 1);
compensations = bigArrays.grow(compensations, bucket + 1);

if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = doubleSums.get(bucket);
double compensation = compensations.get(bucket);
kahanSummation.reset(sum, compensation);

for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
kahanSummation.add(value);
}

compensations.set(bucket, kahanSummation.delta());
doubleSums.set(bucket, kahanSummation.value());
}
}
};
} else {
final SortedNumericDocValues values = valuesSource.longValues(ctx);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
longSums = bigArrays.grow(longSums, bucket + 1);

if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
// Compute the sum of long values with naive summation.
long sum = longSums.get(bucket);

compensations.set(bucket, kahanSummation.delta());
sums.set(bucket, kahanSummation.value());
for (int i = 0; i < valuesCount; i++) {
long value = values.nextValue();
sum += value;
}

longSums.set(bucket, sum);
}
}
}
};
};
}
}

@Override
public double metric(long owningBucketOrd) {
if (valuesSource == null || owningBucketOrd >= sums.size()) {
if (valuesSource == null || owningBucketOrd >= doubleSums.size() || owningBucketOrd >= longSums.size()) {
return 0.0;
}
return sums.get(owningBucketOrd);
if (valuesSource.isFloatingPoint()) {
return doubleSums.get(owningBucketOrd);
} else {
return (double) longSums.get(owningBucketOrd);
}
}

@Override
public InternalAggregation buildAggregation(long bucket) {
if (valuesSource == null || bucket >= sums.size()) {
if (valuesSource == null || bucket >= doubleSums.size() || bucket >= longSums.size()) {
return buildEmptyAggregation();
}
return new InternalSum(name, sums.get(bucket), format, pipelineAggregators(), metaData());
if (valuesSource.isFloatingPoint()) {
return new InternalSum(name, doubleSums.get(bucket), format, pipelineAggregators(), metaData());
} else {
return new InternalSum(name, longSums.get(bucket), format, pipelineAggregators(), metaData());
}
}

@Override
@@ -119,6 +154,6 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(sums, compensations);
Releasables.close(doubleSums, longSums, compensations);
}
}
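
For reference, the floating-point branch above keeps using CompensatedSum, i.e. Kahan summation: next to the running sum it carries the low-order bits lost to rounding and feeds them back into the following addition. A standalone sketch of the idea (illustrative only; KahanSketch is a hypothetical class, not the Elasticsearch CompensatedSum implementation):

// Minimal Kahan (compensated) summation sketch.
final class KahanSketch {
    private double value;  // running sum
    private double delta;  // compensation: low-order bits lost by the last addition

    void add(double v) {
        double corrected = v - delta;            // re-apply the bits lost previously
        double newValue = value + corrected;     // this addition may round low-order bits away
        delta = (newValue - value) - corrected;  // algebraically zero; in practice the rounding error
        value = newValue;
    }

    double value() {
        return value;
    }

    public static void main(String[] args) {
        KahanSketch compensated = new KahanSketch();
        double naive = 1.0;
        compensated.add(1.0);
        for (int i = 0; i < 10; i++) {
            naive += 1e-16;          // below half an ulp of 1.0, so naive stays exactly 1.0
            compensated.add(1e-16);
        }
        System.out.println(naive);               // 1.0 -- every increment was rounded away
        System.out.println(compensated.value()); // close to the exact 1.000000000000001
    }
}

The long branch deliberately skips this machinery: long addition has no rounding error, so a plain running total is already exact (up to overflow, discussed after the tests below).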
@@ -81,6 +81,29 @@ public void testSummationAccuracy() {
verifySummationOfDoubles(largeValues, Double.NEGATIVE_INFINITY, 0d);
}

public void testSummationAccuracyLong() {
// Summing up a normal array of long values
long[] longValues = new long[]{1, 17458313843517748L};
verifySummationOfLongs(longValues, 17458313843517749L);

// The same values summed as doubles lose the +1 to double precision
double[] doubleValues = new double[]{1, 17458313843517748d};
verifySummationOfDoubles(doubleValues, 17458313843517748d, 0d);

// Summing up an array of random long values, including Long.MIN_VALUE and Long.MAX_VALUE,
// and expect a result equal to naive summation (overflow wraps around)
long[] values;
int n = randomIntBetween(5, 10);
values = new long[n];
long sum = 0;
for (int i = 0; i < n; i++) {
values[i] = frequently()
? randomFrom(0L, Long.MIN_VALUE, Long.MAX_VALUE)
: randomLongBetween(Long.MIN_VALUE, Long.MAX_VALUE);
sum += values[i];
}
verifySummationOfLongs(values, sum);
}

private void verifySummationOfDoubles(double[] values, double expected, double delta) {
List<InternalAggregation> aggregations = new ArrayList<>(values.length);
for (double value : values) {
@@ -91,6 +114,16 @@ private void verifySummationOfDoubles(double[] values, double expected, double delta) {
assertEquals(expected, reduced.value(), delta);
}

private void verifySummationOfLongs(long[] values, long expected) {
List<InternalAggregation> aggregations = new ArrayList<>(values.length);
for (long value : values) {
aggregations.add(new InternalSum("long1", value, null, null, null));
}
InternalSum internalSum = new InternalSum("long", 0, null, null, null);
InternalSum reduced = internalSum.reduce(aggregations, null);
assertEquals(expected, reduced.longValue());
}

@Override
protected void assertFromXContent(InternalSum sum, ParsedAggregation parsedAggregation) {
ParsedSum parsed = ((ParsedSum) parsedAggregation);
@@ -164,6 +164,29 @@ public void testSummationAccuracy() throws IOException {
verifySummationOfDoubles(largeValues, Double.NEGATIVE_INFINITY, 0d);
}

public void testSummationAccuracyLong() throws IOException {
// Summing up a normal array of long values
long[] longValues = new long[]{1, 17458313843517748L};
verifySummationOfLongs(longValues, 17458313843517749L);

// The same values summed as doubles lose the +1 to double precision
double[] doubleValues = new double[]{1, 17458313843517748d};
verifySummationOfDoubles(doubleValues, 17458313843517748d, 0d);

// Summing up an array of random long values, including Long.MIN_VALUE and Long.MAX_VALUE,
// and expect a result equal to naive summation (overflow wraps around)
long[] values;
int n = randomIntBetween(5, 10);
values = new long[n];
long sum = 0;
for (int i = 0; i < n; i++) {
values[i] = frequently()
? randomFrom(0L, Long.MIN_VALUE, Long.MAX_VALUE)
: randomLongBetween(Long.MIN_VALUE, Long.MAX_VALUE);
sum += values[i];
}
verifySummationOfLongs(values, sum);
}

private void verifySummationOfDoubles(double[] values, double expected, double delta) throws IOException {
testCase(new MatchAllDocsQuery(),
iw -> {
@@ -176,6 +199,18 @@ private void verifySummationOfDoubles(double[] values, double expected, double delta) throws IOException {
);
}

private void verifySummationOfLongs(long[] values, long expected) throws IOException {
testCase(new MatchAllDocsQuery(),
iw -> {
for (long value : values) {
iw.addDocument(singleton(new NumericDocValuesField(FIELD_NAME, value)));
}
},
result -> assertEquals(expected, result.longValue()),
NumberFieldMapper.NumberType.LONG
);
}

private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> indexer,
Consumer<InternalSum> verify) throws IOException {
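
A closing note on the long tests above: unlike doubles, longs have no NaN or infinities and no rounding error; the only edge case is two's-complement overflow, which wraps around. The tests account for this by computing the expected value with the same naive sum += values[i] loop, so the aggregation and the expectation wrap identically. A small standalone illustration (LongOverflowDemo is a hypothetical class, not part of the patch):

// Two's-complement overflow simply wraps; no exception is thrown.
public class LongOverflowDemo {
    public static void main(String[] args) {
        System.out.println(Long.MAX_VALUE);                   //  9223372036854775807
        System.out.println(Long.MAX_VALUE + 1);               // -9223372036854775808, i.e. Long.MIN_VALUE
        System.out.println(Long.MIN_VALUE + Long.MAX_VALUE);  // -1
        // Math.addExact(Long.MAX_VALUE, 1) would throw ArithmeticException instead,
        // but both the aggregator and the tests use plain '+', so they agree.
    }
}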