Skip to content

Commit

Permalink
Make InternalCentroid leaner (#116302)
Browse files Browse the repository at this point in the history
We are currently holding to fields to extract values, this commit makes them abstract methods so 
we don't use any heap.
  • Loading branch information
iverase authored and jozala committed Nov 13, 2024
1 parent 42e10e9 commit da3bbdf
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,20 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

/**
* Serialization and merge logic for {@link GeoCentroidAggregator}.
*/
public abstract class InternalCentroid extends InternalAggregation implements CentroidAggregation {
protected final SpatialPoint centroid;
protected final long count;
private final FieldExtractor firstField;
private final FieldExtractor secondField;

public InternalCentroid(
String name,
SpatialPoint centroid,
long count,
Map<String, Object> metadata,
FieldExtractor firstField,
FieldExtractor secondField
) {

public InternalCentroid(String name, SpatialPoint centroid, long count, Map<String, Object> metadata) {
super(name, metadata);
assert (centroid == null) == (count == 0);
this.centroid = centroid;
assert count >= 0;
this.count = count;
this.firstField = firstField;
this.secondField = secondField;
}

protected abstract SpatialPoint centroidFromStream(StreamInput in) throws IOException;
Expand All @@ -59,16 +47,14 @@ public InternalCentroid(
* Read from a stream.
*/
@SuppressWarnings("this-escape")
protected InternalCentroid(StreamInput in, FieldExtractor firstField, FieldExtractor secondField) throws IOException {
protected InternalCentroid(StreamInput in) throws IOException {
super(in);
count = in.readVLong();
if (in.readBoolean()) {
centroid = centroidFromStream(in);
} else {
centroid = null;
}
this.firstField = firstField;
this.secondField = secondField;
}

@Override
Expand Down Expand Up @@ -110,11 +96,11 @@ public void accept(InternalAggregation aggregation) {
if (centroidAgg.count > 0) {
totalCount += centroidAgg.count;
if (Double.isNaN(firstSum)) {
firstSum = centroidAgg.count * firstField.extractor.apply(centroidAgg.centroid);
secondSum = centroidAgg.count * secondField.extractor.apply(centroidAgg.centroid);
firstSum = centroidAgg.count * extractFirst(centroidAgg.centroid);
secondSum = centroidAgg.count * extractSecond(centroidAgg.centroid);
} else {
firstSum += centroidAgg.count * firstField.extractor.apply(centroidAgg.centroid);
secondSum += centroidAgg.count * secondField.extractor.apply(centroidAgg.centroid);
firstSum += centroidAgg.count * extractFirst(centroidAgg.centroid);
secondSum += centroidAgg.count * extractSecond(centroidAgg.centroid);
}
}
}
Expand All @@ -126,6 +112,14 @@ public InternalAggregation get() {
};
}

protected abstract String nameFirst();

protected abstract double extractFirst(SpatialPoint point);

protected abstract String nameSecond();

protected abstract double extractSecond(SpatialPoint point);

@Override
public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
return copyWith(centroid, samplingContext.scaleUp(count));
Expand All @@ -136,16 +130,6 @@ protected boolean mustReduceOnSingleInternalAgg() {
return false;
}

protected static class FieldExtractor {
private final String name;
private final Function<SpatialPoint, Double> extractor;

public FieldExtractor(String name, Function<SpatialPoint, Double> extractor) {
this.name = name;
this.extractor = extractor;
}
}

protected abstract double extractDouble(String name);

@Override
Expand Down Expand Up @@ -174,8 +158,8 @@ public XContentBuilder doXContentBody(XContentBuilder builder, Params params) th
if (centroid != null) {
builder.startObject(Fields.CENTROID.getPreferredName());
{
builder.field(firstField.name, firstField.extractor.apply(centroid));
builder.field(secondField.name, secondField.extractor.apply(centroid));
builder.field(nameFirst(), extractFirst(centroid));
builder.field(nameSecond(), extractSecond(centroid));
}
builder.endObject();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.xcontent.ParseField;

import java.io.IOException;
import java.util.Map;
Expand All @@ -26,21 +25,14 @@
public class InternalGeoCentroid extends InternalCentroid implements GeoCentroid {

public InternalGeoCentroid(String name, SpatialPoint centroid, long count, Map<String, Object> metadata) {
super(
name,
centroid,
count,
metadata,
new FieldExtractor("lat", SpatialPoint::getY),
new FieldExtractor("lon", SpatialPoint::getX)
);
super(name, centroid, count, metadata);
}

/**
* Read from a stream.
*/
public InternalGeoCentroid(StreamInput in) throws IOException {
super(in, new FieldExtractor("lat", SpatialPoint::getY), new FieldExtractor("lon", SpatialPoint::getX));
super(in);
}

public static InternalGeoCentroid empty(String name, Map<String, Object> metadata) {
Expand Down Expand Up @@ -84,12 +76,27 @@ protected InternalGeoCentroid copyWith(double firstSum, double secondSum, long t
}

@Override
public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
return new InternalGeoCentroid(name, centroid, samplingContext.scaleUp(count), getMetadata());
protected String nameFirst() {
return "lat";
}

@Override
protected double extractFirst(SpatialPoint point) {
return point.getY();
}

@Override
protected String nameSecond() {
return "lon";
}

@Override
protected double extractSecond(SpatialPoint point) {
return point.getX();
}

static class Fields {
static final ParseField CENTROID_LAT = new ParseField("lat");
static final ParseField CENTROID_LON = new ParseField("lon");
@Override
public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
return new InternalGeoCentroid(name, centroid, samplingContext.scaleUp(count), getMetadata());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.metrics.InternalCentroid;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.spatial.common.CartesianPoint;

import java.io.IOException;
Expand All @@ -25,14 +24,14 @@
public class InternalCartesianCentroid extends InternalCentroid implements CartesianCentroid {

public InternalCartesianCentroid(String name, SpatialPoint centroid, long count, Map<String, Object> metadata) {
super(name, centroid, count, metadata, new FieldExtractor("x", SpatialPoint::getX), new FieldExtractor("y", SpatialPoint::getY));
super(name, centroid, count, metadata);
}

/**
* Read from a stream.
*/
public InternalCartesianCentroid(StreamInput in) throws IOException {
super(in, new FieldExtractor("x", SpatialPoint::getX), new FieldExtractor("y", SpatialPoint::getY));
super(in);
}

@Override
Expand Down Expand Up @@ -80,8 +79,23 @@ public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
return new InternalCartesianCentroid(name, centroid, samplingContext.scaleUp(count), getMetadata());
}

static class Fields {
static final ParseField CENTROID_X = new ParseField("x");
static final ParseField CENTROID_Y = new ParseField("y");
@Override
protected String nameFirst() {
return "x";
}

@Override
protected double extractFirst(SpatialPoint point) {
return point.getX();
}

@Override
protected String nameSecond() {
return "y";
}

@Override
protected double extractSecond(SpatialPoint point) {
return point.getY();
}
}

0 comments on commit da3bbdf

Please sign in to comment.