Skip to content

Commit

Permalink
Make InternalCentroid leaner
Browse files Browse the repository at this point in the history
  • Loading branch information
iverase committed Nov 6, 2024
1 parent 9eb4dd2 commit f09012e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,20 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

/**
* Serialization and merge logic for {@link GeoCentroidAggregator}.
*/
public abstract class InternalCentroid extends InternalAggregation implements CentroidAggregation {
protected final SpatialPoint centroid;
protected final long count;
private final FieldExtractor firstField;
private final FieldExtractor secondField;

public InternalCentroid(
String name,
SpatialPoint centroid,
long count,
Map<String, Object> metadata,
FieldExtractor firstField,
FieldExtractor secondField
) {

public InternalCentroid(String name, SpatialPoint centroid, long count, Map<String, Object> metadata) {
super(name, metadata);
assert (centroid == null) == (count == 0);
this.centroid = centroid;
assert count >= 0;
this.count = count;
this.firstField = firstField;
this.secondField = secondField;
}

protected abstract SpatialPoint centroidFromStream(StreamInput in) throws IOException;
Expand All @@ -59,16 +47,14 @@ public InternalCentroid(
* Read from a stream.
*/
@SuppressWarnings("this-escape")
protected InternalCentroid(StreamInput in, FieldExtractor firstField, FieldExtractor secondField) throws IOException {
protected InternalCentroid(StreamInput in) throws IOException {
super(in);
count = in.readVLong();
if (in.readBoolean()) {
centroid = centroidFromStream(in);
} else {
centroid = null;
}
this.firstField = firstField;
this.secondField = secondField;
}

@Override
Expand Down Expand Up @@ -110,11 +96,11 @@ public void accept(InternalAggregation aggregation) {
if (centroidAgg.count > 0) {
totalCount += centroidAgg.count;
if (Double.isNaN(firstSum)) {
firstSum = centroidAgg.count * firstField.extractor.apply(centroidAgg.centroid);
secondSum = centroidAgg.count * secondField.extractor.apply(centroidAgg.centroid);
firstSum = centroidAgg.count * extractFirst(centroidAgg.centroid);
secondSum = centroidAgg.count * extractSecond(centroidAgg.centroid);
} else {
firstSum += centroidAgg.count * firstField.extractor.apply(centroidAgg.centroid);
secondSum += centroidAgg.count * secondField.extractor.apply(centroidAgg.centroid);
firstSum += centroidAgg.count * extractFirst(centroidAgg.centroid);
secondSum += centroidAgg.count * extractSecond(centroidAgg.centroid);
}
}
}
Expand All @@ -126,6 +112,14 @@ public InternalAggregation get() {
};
}

protected abstract String nameFirst();

protected abstract double extractFirst(SpatialPoint point);

protected abstract String nameSecond();

protected abstract double extractSecond(SpatialPoint point);

@Override
public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
return copyWith(centroid, samplingContext.scaleUp(count));
Expand All @@ -136,16 +130,6 @@ protected boolean mustReduceOnSingleInternalAgg() {
return false;
}

protected static class FieldExtractor {
private final String name;
private final Function<SpatialPoint, Double> extractor;

public FieldExtractor(String name, Function<SpatialPoint, Double> extractor) {
this.name = name;
this.extractor = extractor;
}
}

protected abstract double extractDouble(String name);

@Override
Expand Down Expand Up @@ -174,8 +158,8 @@ public XContentBuilder doXContentBody(XContentBuilder builder, Params params) th
if (centroid != null) {
builder.startObject(Fields.CENTROID.getPreferredName());
{
builder.field(firstField.name, firstField.extractor.apply(centroid));
builder.field(secondField.name, secondField.extractor.apply(centroid));
builder.field(nameFirst(), extractFirst(centroid));
builder.field(nameSecond(), extractSecond(centroid));
}
builder.endObject();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.xcontent.ParseField;

import java.io.IOException;
import java.util.Map;
Expand All @@ -26,21 +25,14 @@
public class InternalGeoCentroid extends InternalCentroid implements GeoCentroid {

public InternalGeoCentroid(String name, SpatialPoint centroid, long count, Map<String, Object> metadata) {
super(
name,
centroid,
count,
metadata,
new FieldExtractor("lat", SpatialPoint::getY),
new FieldExtractor("lon", SpatialPoint::getX)
);
super(name, centroid, count, metadata);
}

/**
* Read from a stream.
*/
public InternalGeoCentroid(StreamInput in) throws IOException {
super(in, new FieldExtractor("lat", SpatialPoint::getY), new FieldExtractor("lon", SpatialPoint::getX));
super(in);
}

public static InternalGeoCentroid empty(String name, Map<String, Object> metadata) {
Expand Down Expand Up @@ -84,12 +76,27 @@ protected InternalGeoCentroid copyWith(double firstSum, double secondSum, long t
}

@Override
public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
return new InternalGeoCentroid(name, centroid, samplingContext.scaleUp(count), getMetadata());
protected String nameFirst() {
return "lat";
}

@Override
protected double extractFirst(SpatialPoint point) {
return point.getY();
}

@Override
protected String nameSecond() {
return "lon";
}

@Override
protected double extractSecond(SpatialPoint point) {
return point.getX();
}

static class Fields {
static final ParseField CENTROID_LAT = new ParseField("lat");
static final ParseField CENTROID_LON = new ParseField("lon");
@Override
public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
return new InternalGeoCentroid(name, centroid, samplingContext.scaleUp(count), getMetadata());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.metrics.InternalCentroid;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.spatial.common.CartesianPoint;

import java.io.IOException;
Expand All @@ -25,14 +24,14 @@
public class InternalCartesianCentroid extends InternalCentroid implements CartesianCentroid {

public InternalCartesianCentroid(String name, SpatialPoint centroid, long count, Map<String, Object> metadata) {
super(name, centroid, count, metadata, new FieldExtractor("x", SpatialPoint::getX), new FieldExtractor("y", SpatialPoint::getY));
super(name, centroid, count, metadata);
}

/**
* Read from a stream.
*/
public InternalCartesianCentroid(StreamInput in) throws IOException {
super(in, new FieldExtractor("x", SpatialPoint::getX), new FieldExtractor("y", SpatialPoint::getY));
super(in);
}

@Override
Expand Down Expand Up @@ -80,8 +79,23 @@ public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
return new InternalCartesianCentroid(name, centroid, samplingContext.scaleUp(count), getMetadata());
}

static class Fields {
static final ParseField CENTROID_X = new ParseField("x");
static final ParseField CENTROID_Y = new ParseField("y");
@Override
protected String nameFirst() {
return "x";
}

@Override
protected double extractFirst(SpatialPoint point) {
return point.getX();
}

@Override
protected String nameSecond() {
return "y";
}

@Override
protected double extractSecond(SpatialPoint point) {
return point.getY();
}
}

0 comments on commit f09012e

Please sign in to comment.