From ea441a5e0c6851f99ee41aa449aeb9708dc6cc02 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Wed, 6 Nov 2024 16:22:38 +0100 Subject: [PATCH 1/2] Make InternalCentroid leaner (#116302) We are currently holding two fields to extract values; this commit makes them abstract methods so we don't use any heap. --- .../metrics/InternalCentroid.java | 50 +++++++------------ .../metrics/InternalGeoCentroid.java | 37 ++++++++------ .../metrics/InternalCartesianCentroid.java | 26 +++++++--- 3 files changed, 59 insertions(+), 54 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalCentroid.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalCentroid.java index 05dd82fd59c4f..eb789bcdd8a74 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalCentroid.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalCentroid.java @@ -23,7 +23,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.function.Function; /** * Serialization and merge logic for {@link GeoCentroidAggregator}. 
@@ -31,24 +30,13 @@ public abstract class InternalCentroid extends InternalAggregation implements CentroidAggregation { protected final SpatialPoint centroid; protected final long count; - private final FieldExtractor firstField; - private final FieldExtractor secondField; - - public InternalCentroid( - String name, - SpatialPoint centroid, - long count, - Map metadata, - FieldExtractor firstField, - FieldExtractor secondField - ) { + + public InternalCentroid(String name, SpatialPoint centroid, long count, Map metadata) { super(name, metadata); assert (centroid == null) == (count == 0); this.centroid = centroid; assert count >= 0; this.count = count; - this.firstField = firstField; - this.secondField = secondField; } protected abstract SpatialPoint centroidFromStream(StreamInput in) throws IOException; @@ -59,7 +47,7 @@ public InternalCentroid( * Read from a stream. */ @SuppressWarnings("this-escape") - protected InternalCentroid(StreamInput in, FieldExtractor firstField, FieldExtractor secondField) throws IOException { + protected InternalCentroid(StreamInput in) throws IOException { super(in); count = in.readVLong(); if (in.readBoolean()) { @@ -67,8 +55,6 @@ protected InternalCentroid(StreamInput in, FieldExtractor firstField, FieldExtra } else { centroid = null; } - this.firstField = firstField; - this.secondField = secondField; } @Override @@ -110,11 +96,11 @@ public void accept(InternalAggregation aggregation) { if (centroidAgg.count > 0) { totalCount += centroidAgg.count; if (Double.isNaN(firstSum)) { - firstSum = centroidAgg.count * firstField.extractor.apply(centroidAgg.centroid); - secondSum = centroidAgg.count * secondField.extractor.apply(centroidAgg.centroid); + firstSum = centroidAgg.count * extractFirst(centroidAgg.centroid); + secondSum = centroidAgg.count * extractSecond(centroidAgg.centroid); } else { - firstSum += centroidAgg.count * firstField.extractor.apply(centroidAgg.centroid); - secondSum += centroidAgg.count * 
secondField.extractor.apply(centroidAgg.centroid); + firstSum += centroidAgg.count * extractFirst(centroidAgg.centroid); + secondSum += centroidAgg.count * extractSecond(centroidAgg.centroid); } } } @@ -126,6 +112,14 @@ public InternalAggregation get() { }; } + protected abstract String nameFirst(); + + protected abstract double extractFirst(SpatialPoint point); + + protected abstract String nameSecond(); + + protected abstract double extractSecond(SpatialPoint point); + @Override public InternalAggregation finalizeSampling(SamplingContext samplingContext) { return copyWith(centroid, samplingContext.scaleUp(count)); @@ -136,16 +130,6 @@ protected boolean mustReduceOnSingleInternalAgg() { return false; } - protected static class FieldExtractor { - private final String name; - private final Function extractor; - - public FieldExtractor(String name, Function extractor) { - this.name = name; - this.extractor = extractor; - } - } - protected abstract double extractDouble(String name); @Override @@ -174,8 +158,8 @@ public XContentBuilder doXContentBody(XContentBuilder builder, Params params) th if (centroid != null) { builder.startObject(Fields.CENTROID.getPreferredName()); { - builder.field(firstField.name, firstField.extractor.apply(centroid)); - builder.field(secondField.name, secondField.extractor.apply(centroid)); + builder.field(nameFirst(), extractFirst(centroid)); + builder.field(nameSecond(), extractSecond(centroid)); } builder.endObject(); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalGeoCentroid.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalGeoCentroid.java index 10e301608ec2f..1609046d59708 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalGeoCentroid.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalGeoCentroid.java @@ -15,7 +15,6 @@ import org.elasticsearch.common.io.stream.StreamOutput; import 
org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.support.SamplingContext; -import org.elasticsearch.xcontent.ParseField; import java.io.IOException; import java.util.Map; @@ -26,21 +25,14 @@ public class InternalGeoCentroid extends InternalCentroid implements GeoCentroid { public InternalGeoCentroid(String name, SpatialPoint centroid, long count, Map metadata) { - super( - name, - centroid, - count, - metadata, - new FieldExtractor("lat", SpatialPoint::getY), - new FieldExtractor("lon", SpatialPoint::getX) - ); + super(name, centroid, count, metadata); } /** * Read from a stream. */ public InternalGeoCentroid(StreamInput in) throws IOException { - super(in, new FieldExtractor("lat", SpatialPoint::getY), new FieldExtractor("lon", SpatialPoint::getX)); + super(in); } public static InternalGeoCentroid empty(String name, Map metadata) { @@ -84,12 +76,27 @@ protected InternalGeoCentroid copyWith(double firstSum, double secondSum, long t } @Override - public InternalAggregation finalizeSampling(SamplingContext samplingContext) { - return new InternalGeoCentroid(name, centroid, samplingContext.scaleUp(count), getMetadata()); + protected String nameFirst() { + return "lat"; + } + + @Override + protected double extractFirst(SpatialPoint point) { + return point.getY(); + } + + @Override + protected String nameSecond() { + return "lon"; + } + + @Override + protected double extractSecond(SpatialPoint point) { + return point.getX(); } - static class Fields { - static final ParseField CENTROID_LAT = new ParseField("lat"); - static final ParseField CENTROID_LON = new ParseField("lon"); + @Override + public InternalAggregation finalizeSampling(SamplingContext samplingContext) { + return new InternalGeoCentroid(name, centroid, samplingContext.scaleUp(count), getMetadata()); } } diff --git a/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/search/aggregations/metrics/InternalCartesianCentroid.java 
b/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/search/aggregations/metrics/InternalCartesianCentroid.java index e009e07d35aa4..63f43458c79b5 100644 --- a/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/search/aggregations/metrics/InternalCartesianCentroid.java +++ b/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/search/aggregations/metrics/InternalCartesianCentroid.java @@ -13,7 +13,6 @@ import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.metrics.InternalCentroid; import org.elasticsearch.search.aggregations.support.SamplingContext; -import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xpack.spatial.common.CartesianPoint; import java.io.IOException; @@ -25,14 +24,14 @@ public class InternalCartesianCentroid extends InternalCentroid implements CartesianCentroid { public InternalCartesianCentroid(String name, SpatialPoint centroid, long count, Map metadata) { - super(name, centroid, count, metadata, new FieldExtractor("x", SpatialPoint::getX), new FieldExtractor("y", SpatialPoint::getY)); + super(name, centroid, count, metadata); } /** * Read from a stream. 
*/ public InternalCartesianCentroid(StreamInput in) throws IOException { - super(in, new FieldExtractor("x", SpatialPoint::getX), new FieldExtractor("y", SpatialPoint::getY)); + super(in); } @Override @@ -80,8 +79,23 @@ public InternalAggregation finalizeSampling(SamplingContext samplingContext) { return new InternalCartesianCentroid(name, centroid, samplingContext.scaleUp(count), getMetadata()); } - static class Fields { - static final ParseField CENTROID_X = new ParseField("x"); - static final ParseField CENTROID_Y = new ParseField("y"); + @Override + protected String nameFirst() { + return "x"; + } + + @Override + protected double extractFirst(SpatialPoint point) { + return point.getX(); + } + + @Override + protected String nameSecond() { + return "y"; + } + + @Override + protected double extractSecond(SpatialPoint point) { + return point.getY(); } } From 543e9bab67c73e7dcc90af5ae3da163476151463 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 6 Nov 2024 15:28:47 +0000 Subject: [PATCH 2/2] [ML] Add debug logging for tests failing with empty model download (#116263) --- muted-tests.yml | 3 --- .../xpack/inference/DefaultEndPointsIT.java | 14 +++++++++++--- .../ml/packageloader/action/ModelLoaderUtils.java | 11 ++++++++++- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index 86239120196a7..a404136c9f0ec 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -118,9 +118,6 @@ tests: - class: org.elasticsearch.xpack.remotecluster.RemoteClusterSecurityWithApmTracingRestIT method: testTracingCrossCluster issue: https://github.com/elastic/elasticsearch/issues/112731 -- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT - method: testInferDeploysDefaultElser - issue: https://github.com/elastic/elasticsearch/issues/114913 - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=esql/60_usage/Basic ESQL usage output (telemetry)} issue: https://github.com/elastic/elasticsearch/issues/115231 diff 
--git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java index 3a774a7a37d93..1fef26989d845 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference; +import org.elasticsearch.client.Request; import org.elasticsearch.inference.TaskType; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; @@ -27,8 +28,15 @@ public class DefaultEndPointsIT extends InferenceBaseRestTest { private TestThreadPool threadPool; @Before - public void createThreadPool() { + public void setupTest() throws IOException { threadPool = new TestThreadPool(DefaultEndPointsIT.class.getSimpleName()); + + Request loggingSettings = new Request("PUT", "_cluster/settings"); + loggingSettings.setJsonEntity(""" + {"persistent" : { + "logger.org.elasticsearch.xpack.ml.packageloader" : "DEBUG" + }}"""); + client().performRequest(loggingSettings); } @After @@ -64,7 +72,7 @@ private static void assertDefaultElserConfig(Map modelConfig) { assertThat( modelConfig.toString(), adaptiveAllocations, - Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 8)) + Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32)) ); } @@ -99,7 +107,7 @@ private static void assertDefaultE5Config(Map modelConfig) { assertThat( modelConfig.toString(), adaptiveAllocations, - Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, 
"max_number_of_allocations", 8)) + Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32)) ); } } diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java index e92aff74be463..2e08b845f6593 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.ml.packageloader.action; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceNotFoundException; @@ -56,10 +58,12 @@ */ final class ModelLoaderUtils { + private static final Logger logger = LogManager.getLogger(ModelLoaderUtils.class); + public static String METADATA_FILE_EXTENSION = ".metadata.json"; public static String MODEL_FILE_EXTENSION = ".pt"; - private static ByteSizeValue VOCABULARY_SIZE_LIMIT = new ByteSizeValue(20, ByteSizeUnit.MB); + private static final ByteSizeValue VOCABULARY_SIZE_LIMIT = new ByteSizeValue(20, ByteSizeUnit.MB); private static final String VOCABULARY = "vocabulary"; private static final String MERGES = "merges"; private static final String SCORES = "scores"; @@ -83,6 +87,7 @@ record BytesAndPartIndex(BytesArray bytes, int partIndex) {} private final AtomicInteger currentPart; private final int lastPartNumber; private final byte[] buf; + private final RequestRange range; // TODO debug only HttpStreamChunker(URI uri, RequestRange range, int chunkSize) { var inputStream = getHttpOrHttpsInputStream(uri, range); @@ -91,6 +96,7 @@ record 
BytesAndPartIndex(BytesArray bytes, int partIndex) {} this.lastPartNumber = range.startPart() + range.numParts(); this.currentPart = new AtomicInteger(range.startPart()); this.buf = new byte[chunkSize]; + this.range = range; } // This ctor exists for testing purposes only. @@ -100,6 +106,7 @@ record BytesAndPartIndex(BytesArray bytes, int partIndex) {} this.lastPartNumber = range.startPart() + range.numParts(); this.currentPart = new AtomicInteger(range.startPart()); this.buf = new byte[chunkSize]; + this.range = range; } public boolean hasNext() { @@ -113,6 +120,7 @@ public BytesAndPartIndex next() throws IOException { int read = inputStream.read(buf, bytesRead, chunkSize - bytesRead); // EOF?? if (read == -1) { + logger.debug("end of stream, " + bytesRead + " bytes read"); break; } bytesRead += read; @@ -122,6 +130,7 @@ public BytesAndPartIndex next() throws IOException { totalBytesRead.addAndGet(bytesRead); return new BytesAndPartIndex(new BytesArray(buf, 0, bytesRead), currentPart.getAndIncrement()); } else { + logger.warn("Empty part in range " + range + ", current part=" + currentPart.get() + ", last part=" + lastPartNumber); return new BytesAndPartIndex(BytesArray.EMPTY, currentPart.get()); } }