diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/ElasticsearchBuildCompletePlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/ElasticsearchBuildCompletePlugin.java
index 14baa55794c95..b1207a2f5161d 100644
--- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/ElasticsearchBuildCompletePlugin.java
+++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/ElasticsearchBuildCompletePlugin.java
@@ -29,6 +29,8 @@ import org.gradle.api.provider.Property;
 import org.gradle.api.tasks.Input;
 import org.jetbrains.annotations.NotNull;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.io.BufferedInputStream;
 import java.io.BufferedOutputStream;
@@ -47,6 +49,8 @@ public abstract class ElasticsearchBuildCompletePlugin implements Plugin<Project> {
 
+    private static final Logger log = LoggerFactory.getLogger(ElasticsearchBuildCompletePlugin.class);
+
     @Inject
     protected abstract FlowScope getFlowScope();
 
@@ -241,8 +245,11 @@ private static void createBuildArchiveTar(List<File> files, File projectDir, Fil
         tOut.setLongFileMode(TarArchiveOutputStream.LONGFILE_GNU);
         tOut.setBigNumberMode(TarArchiveOutputStream.BIGNUMBER_STAR);
         for (Path path : files.stream().map(File::toPath).toList()) {
-            if (!Files.isRegularFile(path)) {
-                throw new IOException("Support only file!");
+            if (Files.exists(path) == false) {
+                log.warn("File disappeared before it could be added to CI archive: " + path);
+                continue;
+            } else if (!Files.isRegularFile(path)) {
+                throw new IOException("Only regular files are supported: " + path);
             }
             long entrySize = Files.size(path);
diff --git a/docs/changelog/117229.yaml b/docs/changelog/117229.yaml
new file mode 100644
index 0000000000000..f1b859c03e4fa
--- /dev/null
+++ b/docs/changelog/117229.yaml
@@ -0,0 +1,6 @@
+pr: 117229
+summary: "In this PR, a 400 error is returned when _source / _seq_no / _feature /\
+  \ _nested_path / _field_names is requested, rather than a 5xx"
+area: Search
+type: bug
+issues: []
diff --git a/docs/changelog/117731.yaml b/docs/changelog/117731.yaml
new file mode 100644
index 0000000000000..f69cd5bf31100
--- /dev/null
+++ b/docs/changelog/117731.yaml
@@ -0,0 +1,5 @@
+pr: 117731
+summary: Add cluster level reduction
+area: ES|QL
+type: enhancement
+issues: []
diff --git a/docs/changelog/117842.yaml b/docs/changelog/117842.yaml
new file mode 100644
index 0000000000000..9b528a158288c
--- /dev/null
+++ b/docs/changelog/117842.yaml
@@ -0,0 +1,5 @@
+pr: 117842
+summary: Limit size of `Literal#toString`
+area: ES|QL
+type: bug
+issues: []
diff --git a/docs/changelog/117865.yaml b/docs/changelog/117865.yaml
new file mode 100644
index 0000000000000..33dc497725f92
--- /dev/null
+++ b/docs/changelog/117865.yaml
@@ -0,0 +1,5 @@
+pr: 117865
+summary: Fix BWC for ES|QL cluster request
+area: ES|QL
+type: bug
+issues: []
diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/RankFeatureMetaFieldMapper.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/RankFeatureMetaFieldMapper.java
index 15398b1f178ee..ed1cc57b84863 100644
--- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/RankFeatureMetaFieldMapper.java
+++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/RankFeatureMetaFieldMapper.java
@@ -48,7 +48,7 @@ public String typeName() {
 
     @Override
     public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
-        throw new UnsupportedOperationException("Cannot fetch values for internal field [" + typeName() + "].");
+        throw new IllegalArgumentException("Cannot fetch values for internal field [" + typeName() + "].");
     }
 
     @Override
diff --git a/muted-tests.yml b/muted-tests.yml
index 9d4cb1a89a963..e26f21eb14920 100644
--- a/muted-tests.yml
+++ b/muted-tests.yml
@@ -409,3 +409,6 @@ tests:
 - class: org.elasticsearch.search.ccs.CrossClusterIT
   method: testCancel
   issue: https://github.com/elastic/elasticsearch/issues/108061
+- class: org.elasticsearch.xpack.ml.integration.RegressionIT
+  method: testTwoJobsWithSameRandomizeSeedUseSameTrainingSet
+  issue: https://github.com/elastic/elasticsearch/issues/117805
diff --git a/rest-api-spec/build.gradle b/rest-api-spec/build.gradle
index fd07ef098b334..e1b51a3e1a6ae 100644
--- a/rest-api-spec/build.gradle
+++ b/rest-api-spec/build.gradle
@@ -253,4 +253,5 @@ tasks.named("yamlRestTestV7CompatTransform").configure({ task ->
   task.skipTest("logsdb/20_source_mapping/stored _source mode is supported", "no longer serialize source_mode")
   task.skipTest("logsdb/20_source_mapping/include/exclude is supported with stored _source", "no longer serialize source_mode")
   task.skipTest("logsdb/20_source_mapping/synthetic _source is default", "no longer serialize source_mode")
+  task.skipTest("search/520_fetch_fields/fetch _seq_no via fields", "error code is changed from 5xx to 400 in 9.0")
 })
diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/520_fetch_fields.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/520_fetch_fields.yml
index 2b309f502f0c2..9a43199755d75 100644
--- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/520_fetch_fields.yml
+++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/520_fetch_fields.yml
@@ -128,18 +128,88 @@ fetch _seq_no via stored_fields:
 
 ---
 fetch _seq_no via fields:
+  - requires:
+      cluster_features: ["meta_fetch_fields_error_code_changed"]
+      reason: The fields API returns a 400 instead of a 5xx when _seq_no is requested via fields
 
   - do:
-      catch: "request"
+      catch: bad_request
       search:
         index: test
         body:
           fields: [ _seq_no ]
 
-  # This should be `unauthorized` (401) or `forbidden` (403) or at least `bad request` (400)
-  # while instead it is mapped to an `internal_server_error (500)`
-  - match: { status: 500 }
-  - match: { error.root_cause.0.type: unsupported_operation_exception }
+  - match: { status: 400 }
+  - match: { error.root_cause.0.type: illegal_argument_exception }
+  - match: { error.root_cause.0.reason: "error fetching [_seq_no]: Cannot fetch values for internal field [_seq_no]." }
+
+---
+fetch _source via fields:
+  - requires:
+      cluster_features: ["meta_fetch_fields_error_code_changed"]
+      reason: The fields API returns a 400 instead of a 5xx when _source is requested via fields
+
+  - do:
+      catch: bad_request
+      search:
+        index: test
+        body:
+          fields: [ _source ]
+
+  - match: { status: 400 }
+  - match: { error.root_cause.0.type: illegal_argument_exception }
+  - match: { error.root_cause.0.reason: "error fetching [_source]: Cannot fetch values for internal field [_source]." }
+
+---
+fetch _feature via fields:
+  - requires:
+      cluster_features: ["meta_fetch_fields_error_code_changed"]
+      reason: The fields API returns a 400 instead of a 5xx when _feature is requested via fields
+
+  - do:
+      catch: bad_request
+      search:
+        index: test
+        body:
+          fields: [ _feature ]
+
+  - match: { status: 400 }
+  - match: { error.root_cause.0.type: illegal_argument_exception }
+  - match: { error.root_cause.0.reason: "error fetching [_feature]: Cannot fetch values for internal field [_feature]." }
+
+---
+fetch _nested_path via fields:
+  - requires:
+      cluster_features: ["meta_fetch_fields_error_code_changed"]
+      reason: The fields API returns a 400 instead of a 5xx when _nested_path is requested via fields
+
+  - do:
+      catch: bad_request
+      search:
+        index: test
+        body:
+          fields: [ _nested_path ]
+
+  - match: { status: 400 }
+  - match: { error.root_cause.0.type: illegal_argument_exception }
+  - match: { error.root_cause.0.reason: "error fetching [_nested_path]: Cannot fetch values for internal field [_nested_path]." }
+
+---
+fetch _field_names via fields:
+  - requires:
+      cluster_features: ["meta_fetch_fields_error_code_changed"]
+      reason: The fields API returns a 400 instead of a 5xx when _field_names is requested via fields
+
+  - do:
+      catch: bad_request
+      search:
+        index: test
+        body:
+          fields: [ _field_names ]
+
+  - match: { status: 400 }
+  - match: { error.root_cause.0.type: illegal_argument_exception }
+  - match: { error.root_cause.0.reason: "error fetching [_field_names]: Cannot fetch values for internal field [_field_names]." }
 
 ---
 fetch fields with none stored_fields:
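The tests above pass because of how the REST layer maps exception types to status codes: IllegalArgumentException is treated as a client error, while an unrecognized runtime exception such as UnsupportedOperationException falls through to a 500. A simplified sketch of that translation (the full mapping lives in ExceptionsHelper/ElasticsearchException and covers many more exception types):

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.rest.RestStatus;

// Simplified sketch of the REST-layer status translation; not the full
// upstream implementation, which handles many more exception types.
final class StatusMappingSketch {
    static RestStatus status(Throwable t) {
        if (t instanceof ElasticsearchException e) {
            return e.status();                   // exceptions that carry their own status
        } else if (t instanceof IllegalArgumentException) {
            return RestStatus.BAD_REQUEST;       // 400: the request itself was invalid
        }
        return RestStatus.INTERNAL_SERVER_ERROR; // 500: anything unrecognized, e.g.
                                                 // UnsupportedOperationException before this change
    }
}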
diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/SignificantTermsSignificanceScoreIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/SignificantTermsSignificanceScoreIT.java
index bf11c1d69bcc6..671f60e2b9d5e 100644
--- a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/SignificantTermsSignificanceScoreIT.java
+++ b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/SignificantTermsSignificanceScoreIT.java
@@ -495,7 +495,7 @@ public void testScriptScore() throws ExecutionException, InterruptedException, I
         for (SignificantTerms.Bucket bucket : sigTerms.getBuckets()) {
             assertThat(
                 bucket.getSignificanceScore(),
-                is((double) bucket.getSubsetDf() + bucket.getSubsetSize() + bucket.getSupersetDf() + bucket.getSupersetSize())
+                is((double) bucket.getSubsetDf() + sigTerms.getSubsetSize() + bucket.getSupersetDf() + sigTerms.getSupersetSize())
             );
         }
     }
diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java
index 5990e544cb8f1..841233c116c7f 100644
--- a/server/src/main/java/module-info.java
+++ b/server/src/main/java/module-info.java
@@ -454,8 +454,8 @@
             org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat,
             org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat,
             org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat,
-            org.elasticsearch.index.codec.vectors.ES816BinaryQuantizedVectorsFormat,
-            org.elasticsearch.index.codec.vectors.ES816HnswBinaryQuantizedVectorsFormat;
+            org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat,
+            org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat;
 
     provides org.apache.lucene.codecs.Codec
         with
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BinarizedByteVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/BinarizedByteVectorValues.java
similarity index 97%
rename from server/src/main/java/org/elasticsearch/index/codec/vectors/BinarizedByteVectorValues.java
rename to server/src/main/java/org/elasticsearch/index/codec/vectors/es816/BinarizedByteVectorValues.java
index 73dd4273a794e..76269628e8442 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BinarizedByteVectorValues.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/BinarizedByteVectorValues.java
@@ -17,7 +17,7 @@
  *
  * Modifications copyright (C) 2024 Elasticsearch B.V.
  */
-package org.elasticsearch.index.codec.vectors;
+package org.elasticsearch.index.codec.vectors.es816;
 
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.VectorScorer;
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BinaryQuantizer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/BinaryQuantizer.java
similarity index 98%
rename from server/src/main/java/org/elasticsearch/index/codec/vectors/BinaryQuantizer.java
rename to server/src/main/java/org/elasticsearch/index/codec/vectors/es816/BinaryQuantizer.java
index 192fb9092ac3a..5bdee45d2bddf 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BinaryQuantizer.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/BinaryQuantizer.java
@@ -17,11 +17,13 @@
  *
  * Modifications copyright (C) 2024 Elasticsearch B.V.
  */
-package org.elasticsearch.index.codec.vectors;
+package org.elasticsearch.index.codec.vectors.es816;
 
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
 
 import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
 import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java
similarity index 95%
rename from server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorer.java
rename to server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java
index f4d22edc6dfdb..656d872798060 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorer.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java
@@ -17,7 +17,7 @@
  *
 * Modifications copyright (C) 2024 Elasticsearch B.V.
 */
-package org.elasticsearch.index.codec.vectors;
+package org.elasticsearch.index.codec.vectors.es816;
 
 import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
 import org.apache.lucene.index.VectorSimilarityFunction;
@@ -26,6 +26,8 @@
 import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
 import org.apache.lucene.util.hnsw.RandomVectorScorer;
 import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
 import org.elasticsearch.simdvec.ESVectorUtil;
 
 import java.io.IOException;
@@ -35,10 +37,10 @@
 import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
 
 /** Vector scorer over binarized vector values */
-public class ES816BinaryFlatVectorsScorer implements FlatVectorsScorer {
+class ES816BinaryFlatVectorsScorer implements FlatVectorsScorer {
     private final FlatVectorsScorer nonQuantizedDelegate;
 
-    public ES816BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) {
+    ES816BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) {
         this.nonQuantizedDelegate = nonQuantizedDelegate;
     }
 
@@ -144,10 +146,10 @@ public RandomVectorScorerSupplier copy() throws IOException {
     }
 
     /** A binarized query representing its quantized form along with factors */
-    public record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors factors) {}
+    record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors factors) {}
 
     /** Vector scorer over binarized vector values */
-    public static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
+    static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
         private final BinaryQueryVector queryVector;
         private final RandomAccessBinarizedByteVectorValues targetVectors;
         private final VectorSimilarityFunction similarityFunction;
@@ -155,7 +157,7 @@ public static class BinarizedRandomVectorScorer extends RandomVectorScorer.Abstr
         private final float sqrtDimensions;
         private final float maxX1;
 
-        public BinarizedRandomVectorScorer(
+        BinarizedRandomVectorScorer(
             BinaryQueryVector queryVectors,
             RandomAccessBinarizedByteVectorValues targetVectors,
             VectorSimilarityFunction similarityFunction
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java
similarity index 98%
rename from server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormat.java
rename to server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java
index e32aea0fb04ae..d864ec5dee8c5 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormat.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java
@@ -17,7 +17,7 @@
  *
 * Modifications copyright (C) 2024 Elasticsearch B.V.
 */
-package org.elasticsearch.index.codec.vectors;
+package org.elasticsearch.index.codec.vectors.es816;
 
 import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
 import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsReader.java
similarity index 98%
rename from server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsReader.java
rename to server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsReader.java
index b0378fee6793d..55f9bec577d82 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsReader.java
@@ -17,7 +17,7 @@
  *
 * Modifications copyright (C) 2024 Elasticsearch B.V.
 */
-package org.elasticsearch.index.codec.vectors;
+package org.elasticsearch.index.codec.vectors.es816;
 
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
@@ -42,6 +42,7 @@
 import org.apache.lucene.util.SuppressForbidden;
 import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
 import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
 
 import java.io.IOException;
 import java.util.HashMap;
@@ -54,7 +55,7 @@
  * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
 */
 @SuppressForbidden(reason = "Lucene classes")
-public class ES816BinaryQuantizedVectorsReader extends FlatVectorsReader {
+class ES816BinaryQuantizedVectorsReader extends FlatVectorsReader {
 
     private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES816BinaryQuantizedVectorsReader.class);
 
@@ -63,7 +64,7 @@ public class ES816BinaryQuantizedVectorsReader extends FlatVectorsReader {
     private final FlatVectorsReader rawVectorsReader;
     private final ES816BinaryFlatVectorsScorer vectorScorer;
 
-    public ES816BinaryQuantizedVectorsReader(
+    ES816BinaryQuantizedVectorsReader(
         SegmentReadState state,
         FlatVectorsReader rawVectorsReader,
         ES816BinaryFlatVectorsScorer vectorsScorer
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java
similarity index 98%
rename from server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsWriter.java
rename to server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java
index 92837a8ffce45..11f8004b7b79a 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java
@@ -17,7 +17,7 @@
  *
 * Modifications copyright (C) 2024 Elasticsearch B.V.
 */
-package org.elasticsearch.index.codec.vectors;
+package org.elasticsearch.index.codec.vectors.es816;
 
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.KnnVectorsReader;
@@ -48,6 +48,8 @@
 import org.apache.lucene.util.hnsw.RandomVectorScorer;
 import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
 import org.elasticsearch.core.SuppressForbidden;
+import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
 
 import java.io.Closeable;
 import java.io.IOException;
@@ -61,14 +63,14 @@
 import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
-import static org.elasticsearch.index.codec.vectors.ES816BinaryQuantizedVectorsFormat.BINARIZED_VECTOR_COMPONENT;
-import static org.elasticsearch.index.codec.vectors.ES816BinaryQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
+import static org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat.BINARIZED_VECTOR_COMPONENT;
+import static org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
 
 /**
  * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
 */
 @SuppressForbidden(reason = "Lucene classes")
-public class ES816BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
+class ES816BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
     private static final long SHALLOW_RAM_BYTES_USED = shallowSizeOfInstance(ES816BinaryQuantizedVectorsWriter.class);
 
     private final SegmentWriteState segmentWriteState;
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java
similarity index 99%
rename from server/src/main/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormat.java
rename to server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java
index 097cdffff6ae4..52f9f14b7bf97 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormat.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java
@@ -17,7 +17,7 @@
  *
 * Modifications copyright (C) 2024 Elasticsearch B.V.
 */
-package org.elasticsearch.index.codec.vectors;
+package org.elasticsearch.index.codec.vectors.es816;
 
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.KnnVectorsReader;
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/OffHeapBinarizedVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/OffHeapBinarizedVectorValues.java
similarity index 97%
rename from server/src/main/java/org/elasticsearch/index/codec/vectors/OffHeapBinarizedVectorValues.java
rename to server/src/main/java/org/elasticsearch/index/codec/vectors/es816/OffHeapBinarizedVectorValues.java
index 628480e273b34..76ca98f99c5f5 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/OffHeapBinarizedVectorValues.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/OffHeapBinarizedVectorValues.java
@@ -17,7 +17,7 @@
  *
 * Modifications copyright (C) 2024 Elasticsearch B.V.
 */
-package org.elasticsearch.index.codec.vectors;
+package org.elasticsearch.index.codec.vectors.es816;
 
 import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
 import org.apache.lucene.codecs.lucene90.IndexedDISI;
@@ -29,6 +29,7 @@
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.hnsw.RandomVectorScorer;
 import org.apache.lucene.util.packed.DirectMonotonicReader;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -37,7 +38,7 @@
 import static org.elasticsearch.index.codec.vectors.BQVectorUtils.constSqrt;
 
 /** Binarized vector values loaded from off-heap */
-public abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues implements RandomAccessBinarizedByteVectorValues {
+abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues implements RandomAccessBinarizedByteVectorValues {
 
     protected final int dimension;
     protected final int size;
@@ -251,10 +252,10 @@ public static OffHeapBinarizedVectorValues load(
     }
 
     /** Dense off-heap binarized vector values */
-    public static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues {
+    static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues {
         private int doc = -1;
 
-        public DenseOffHeapVectorValues(
+        DenseOffHeapVectorValues(
             int dimension,
             int size,
             float[] centroid,
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/RandomAccessBinarizedByteVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/RandomAccessBinarizedByteVectorValues.java
similarity index 96%
rename from server/src/main/java/org/elasticsearch/index/codec/vectors/RandomAccessBinarizedByteVectorValues.java
rename to server/src/main/java/org/elasticsearch/index/codec/vectors/es816/RandomAccessBinarizedByteVectorValues.java
index 5163baf617c29..2ca58bd00904c 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/RandomAccessBinarizedByteVectorValues.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/RandomAccessBinarizedByteVectorValues.java
@@ -17,10 +17,11 @@
  *
 * Modifications copyright (C) 2024 Elasticsearch B.V.
 */
-package org.elasticsearch.index.codec.vectors;
+package org.elasticsearch.index.codec.vectors.es816;
 
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
 
 import java.io.IOException;
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldNamesFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldNamesFieldMapper.java
index 565b1ff28a39f..425e3c664c262 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/FieldNamesFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldNamesFieldMapper.java
@@ -135,7 +135,7 @@ public boolean isEnabled() {
 
     @Override
     public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
-        throw new UnsupportedOperationException("Cannot fetch values for internal field [" + name() + "].");
+        throw new IllegalArgumentException("Cannot fetch values for internal field [" + name() + "].");
     }
 
     @Override
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java
index 365919d7852db..e56fc19b4e3a3 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java
@@ -55,6 +55,8 @@ public Set<NodeFeature> getFeatures() {
         "mapper.constant_keyword.synthetic_source_write_fix"
     );
 
+    public static final NodeFeature META_FETCH_FIELDS_ERROR_CODE_CHANGED = new NodeFeature("meta_fetch_fields_error_code_changed");
+
     @Override
     public Set<NodeFeature> getTestFeatures() {
         return Set.of(
@@ -64,7 +66,8 @@ public Set<NodeFeature> getTestFeatures() {
             SourceFieldMapper.SOURCE_MODE_FROM_INDEX_SETTING,
             IgnoredSourceFieldMapper.ALWAYS_STORE_OBJECT_ARRAYS_IN_NESTED_OBJECTS,
             MapperService.LOGSDB_DEFAULT_IGNORE_DYNAMIC_BEYOND_LIMIT,
-            CONSTANT_KEYWORD_SYNTHETIC_SOURCE_WRITE_FIX
+            CONSTANT_KEYWORD_SYNTHETIC_SOURCE_WRITE_FIX,
+            META_FETCH_FIELDS_ERROR_CODE_CHANGED
         );
     }
 }
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/NestedPathFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/NestedPathFieldMapper.java
index b22c3a12fcda3..1cd752dc34403 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/NestedPathFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/NestedPathFieldMapper.java
@@ -67,7 +67,7 @@ public Query existsQuery(SearchExecutionContext context) {
 
     @Override
     public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
-        throw new UnsupportedOperationException("Cannot fetch values for internal field [" + name() + "].");
+        throw new IllegalArgumentException("Cannot fetch values for internal field [" + name() + "].");
     }
 
     @Override
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SeqNoFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SeqNoFieldMapper.java
index e126102b0f3c2..66ee42dfc56f9 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/SeqNoFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/SeqNoFieldMapper.java
@@ -168,7 +168,7 @@ public boolean mayExistInIndex(SearchExecutionContext context) {
 
     @Override
     public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
-        throw new UnsupportedOperationException("Cannot fetch values for internal field [" + name() + "].");
+        throw new IllegalArgumentException("Cannot fetch values for internal field [" + name() + "].");
     }
 
     @Override
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java
index 9f34b9d4afb9e..39dea35bfc2d2 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java
@@ -324,7 +324,7 @@ public String typeName() {
 
     @Override
     public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
-        throw new UnsupportedOperationException("Cannot fetch values for internal field [" + name() + "].");
+        throw new IllegalArgumentException("Cannot fetch values for internal field [" + name() + "].");
     }
 
     @Override
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
index 582f5f1d3e881..a68ce88609bfc 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
@@ -45,8 +45,8 @@
 import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
 import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat;
 import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat;
-import org.elasticsearch.index.codec.vectors.ES816BinaryQuantizedVectorsFormat;
-import org.elasticsearch.index.codec.vectors.ES816HnswBinaryQuantizedVectorsFormat;
+import org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat;
+import org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat;
 import org.elasticsearch.index.fielddata.FieldDataContext;
 import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java
index 6e34ff401e9bb..84dd2e7b1e529 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java
@@ -985,7 +985,7 @@ SignificantStringTerms.Bucket[] buildBuckets(int size) {
 
         @Override
         SignificantStringTerms.Bucket buildEmptyTemporaryBucket() {
-            return new SignificantStringTerms.Bucket(new BytesRef(), 0, 0, 0, 0, null, format, 0);
+            return new SignificantStringTerms.Bucket(new BytesRef(), 0, 0, null, format, 0);
         }
 
         private long subsetSize(long owningBucketOrd) {
@@ -994,22 +994,19 @@ private long subsetSize(long owningBucketOrd) {
         }
 
         @Override
-        BucketUpdater<SignificantStringTerms.Bucket> bucketUpdater(long owningBucketOrd, GlobalOrdLookupFunction lookupGlobalOrd)
-            throws IOException {
+        BucketUpdater<SignificantStringTerms.Bucket> bucketUpdater(long owningBucketOrd, GlobalOrdLookupFunction lookupGlobalOrd) {
             long subsetSize = subsetSize(owningBucketOrd);
             return (spare, globalOrd, bucketOrd, docCount) -> {
                 spare.bucketOrd = bucketOrd;
                 oversizedCopy(lookupGlobalOrd.apply(globalOrd), spare.termBytes);
                 spare.subsetDf = docCount;
-                spare.subsetSize = subsetSize;
                 spare.supersetDf = backgroundFrequencies.freq(spare.termBytes);
-                spare.supersetSize = supersetSize;
                 /*
                  * During shard-local down-selection we use subset/superset stats
                  * that are for this shard only. Back at the central reducer these
                  * properties will be updated with global stats.
                  */
-                spare.updateScore(significanceHeuristic);
+                spare.updateScore(significanceHeuristic, subsetSize, supersetSize);
             };
         }
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalMappedSignificantTerms.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalMappedSignificantTerms.java
index 3f75a27306ab4..8c6d21cc74119 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalMappedSignificantTerms.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalMappedSignificantTerms.java
@@ -59,7 +59,7 @@ protected InternalMappedSignificantTerms(StreamInput in, Bucket.Reader<B> bucket
         subsetSize = in.readVLong();
         supersetSize = in.readVLong();
         significanceHeuristic = in.readNamedWriteable(SignificanceHeuristic.class);
-        buckets = in.readCollectionAsList(stream -> bucketReader.read(stream, subsetSize, supersetSize, format));
+        buckets = in.readCollectionAsList(stream -> bucketReader.read(stream, format));
     }
 
     @Override
@@ -91,12 +91,12 @@ public B getBucketByKey(String term) {
     }
 
     @Override
-    protected long getSubsetSize() {
+    public long getSubsetSize() {
         return subsetSize;
     }
 
     @Override
-    protected long getSupersetSize() {
+    public long getSupersetSize() {
         return supersetSize;
     }
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTerms.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTerms.java
index 6c0eb465d1f80..78ae2481f5d99 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTerms.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTerms.java
@@ -53,13 +53,11 @@ public abstract static class Bucket<B extends Bucket<B>> extends InternalMultiBu
      */
     @FunctionalInterface
     public interface Reader<B extends Bucket<B>> {
-        B read(StreamInput in, long subsetSize, long supersetSize, DocValueFormat format) throws IOException;
+        B read(StreamInput in, DocValueFormat format) throws IOException;
     }
 
     long subsetDf;
-    long subsetSize;
     long supersetDf;
-    long supersetSize;
     /**
      * Ordinal of the bucket while it is being built. Not used after it is
      * returned from {@link Aggregator#buildAggregations(org.elasticsearch.common.util.LongArray)} and not
@@ -70,16 +68,7 @@ public interface Reader<B extends Bucket<B>> {
     protected InternalAggregations aggregations;
     final transient DocValueFormat format;
 
-    protected Bucket(
-        long subsetDf,
-        long subsetSize,
-        long supersetDf,
-        long supersetSize,
-        InternalAggregations aggregations,
-        DocValueFormat format
-    ) {
-        this.subsetSize = subsetSize;
-        this.supersetSize = supersetSize;
+    protected Bucket(long subsetDf, long supersetDf, InternalAggregations aggregations, DocValueFormat format) {
         this.subsetDf = subsetDf;
         this.supersetDf = supersetDf;
         this.aggregations = aggregations;
@@ -89,9 +78,7 @@ protected Bucket(
     /**
      * Read from a stream.
      */
-    protected Bucket(StreamInput in, long subsetSize, long supersetSize, DocValueFormat format) {
-        this.subsetSize = subsetSize;
-        this.supersetSize = supersetSize;
+    protected Bucket(StreamInput in, DocValueFormat format) {
         this.format = format;
     }
 
@@ -105,20 +92,10 @@ public long getSupersetDf() {
         return supersetDf;
     }
 
-    @Override
-    public long getSupersetSize() {
-        return supersetSize;
-    }
-
-    @Override
-    public long getSubsetSize() {
-        return subsetSize;
-    }
-
     // TODO we should refactor to remove this, since buckets should be immutable after they are generated.
     // This can lead to confusing bugs if the bucket is re-created (via createBucket() or similar) without
     // the score
-    void updateScore(SignificanceHeuristic significanceHeuristic) {
+    void updateScore(SignificanceHeuristic significanceHeuristic, long subsetSize, long supersetSize) {
         score = significanceHeuristic.getScore(subsetDf, subsetSize, supersetDf, supersetSize);
     }
 
@@ -262,13 +239,11 @@ public InternalAggregation get() {
         buckets.forEach(entry -> {
             final B b = createBucket(
                 entry.value.subsetDf[0],
-                globalSubsetSize,
                 entry.value.supersetDf[0],
-                globalSupersetSize,
                 entry.value.reducer.getAggregations(),
                 entry.value.reducer.getProto()
             );
-            b.updateScore(heuristic);
+            b.updateScore(heuristic, globalSubsetSize, globalSupersetSize);
             if (((b.score > 0) && (b.subsetDf >= minDocCount)) || reduceContext.isFinalReduce() == false) {
                 final B removed = ordered.insertWithOverflow(b);
                 if (removed == null) {
@@ -317,9 +292,7 @@ public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
             .map(
                 b -> createBucket(
                     samplingContext.scaleUp(b.subsetDf),
-                    subsetSize,
                     samplingContext.scaleUp(b.supersetDf),
-                    supersetSize,
                     InternalAggregations.finalizeSampling(b.aggregations, samplingContext),
                     b
                 )
@@ -328,14 +301,7 @@ public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
         );
     }
 
-    abstract B createBucket(
-        long subsetDf,
-        long subsetSize,
-        long supersetDf,
-        long supersetSize,
-        InternalAggregations aggregations,
-        B prototype
-    );
+    abstract B createBucket(long subsetDf, long supersetDf, InternalAggregations aggregations, B prototype);
 
     protected abstract A create(long subsetSize, long supersetSize, List<B> buckets);
 
@@ -344,10 +310,6 @@ abstract B createBucket(
      */
     protected abstract B[] createBucketsArray(int size);
 
-    protected abstract long getSubsetSize();
-
-    protected abstract long getSupersetSize();
-
     protected abstract SignificanceHeuristic getSignificanceHeuristic();
 
     @Override
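With the set sizes removed from the bucket, the significance heuristic now receives all four counts at scoring time, as the new updateScore signature shows. A hypothetical heuristic illustrating the getScore contract (the shipped heuristics such as JLH implement the same four-argument signature):

// Hypothetical significance heuristic, shown only to illustrate the four
// counts that flow into getScore() at scoring time instead of being stored
// on each bucket.
final class FrequencyLiftHeuristic {
    double getScore(long subsetDf, long subsetSize, long supersetDf, long supersetSize) {
        double subsetRate = (double) subsetDf / subsetSize;       // term rate in the foreground set
        double supersetRate = (double) supersetDf / supersetSize; // term rate in the background set
        return supersetRate == 0 ? 0 : subsetRate / supersetRate; // "lift" of the term in the foreground
    }
}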
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java
index 6ae47d5975479..b96c495d37489 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java
@@ -47,7 +47,6 @@
 import java.util.function.BiConsumer;
 import java.util.function.Function;
 import java.util.function.LongConsumer;
-import java.util.function.Supplier;
 
 import static org.elasticsearch.search.aggregations.InternalOrder.isKeyOrder;
 
@@ -296,7 +295,7 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro
                 try (ObjectArrayPriorityQueue<B> ordered = buildPriorityQueue(size)) {
                     B spare = null;
                     BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningOrd);
-                    Supplier<B> emptyBucketBuilder = emptyBucketBuilder(owningOrd);
+                    BucketUpdater<B> bucketUpdater = bucketUpdater(owningOrd);
                     while (ordsEnum.next()) {
                         long docCount = bucketDocCount(ordsEnum.ord());
                         otherDocCounts.increment(ordIdx, docCount);
@@ -305,9 +304,9 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro
                         }
                         if (spare == null) {
                             checkRealMemoryCBForInternalBucket();
-                            spare = emptyBucketBuilder.get();
+                            spare = buildEmptyBucket();
                         }
-                        updateBucket(spare, ordsEnum, docCount);
+                        bucketUpdater.updateBucket(spare, ordsEnum, docCount);
                         spare = ordered.insertWithOverflow(spare);
                     }
 
@@ -348,9 +347,9 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro
         abstract void collectZeroDocEntriesIfNeeded(long owningBucketOrd, boolean excludeDeletedDocs) throws IOException;
 
         /**
-         * Build an empty temporary bucket.
+         * Build an empty bucket.
         */
-        abstract Supplier<B> emptyBucketBuilder(long owningBucketOrd);
+        abstract B buildEmptyBucket();
 
         /**
          * Build a {@link PriorityQueue} to sort the buckets. After we've
@@ -362,7 +361,7 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro
         * Update fields in {@code spare} to reflect information collected for
         * this bucket ordinal.
         */
-        abstract void updateBucket(B spare, BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum, long docCount) throws IOException;
+        abstract BucketUpdater<B> bucketUpdater(long owningBucketOrd);
 
         /**
          * Build an array to hold the "top" buckets for each ordinal.
@@ -399,6 +398,10 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro
         abstract R buildEmptyResult();
     }
 
+    interface BucketUpdater<B extends InternalMultiBucketAggregation.InternalBucket> {
+        void updateBucket(B spare, BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum, long docCount) throws IOException;
+    }
+
     /**
      * Builds results for the standard {@code terms} aggregation.
     */
@@ -490,8 +493,8 @@ private void collectZeroDocEntries(BinaryDocValues values, Bits liveDocs, int ma
         }
 
         @Override
-        Supplier<StringTerms.Bucket> emptyBucketBuilder(long owningBucketOrd) {
-            return () -> new StringTerms.Bucket(new BytesRef(), 0, null, showTermDocCountError, 0, format);
+        StringTerms.Bucket buildEmptyBucket() {
+            return new StringTerms.Bucket(new BytesRef(), 0, null, showTermDocCountError, 0, format);
         }
 
         @Override
@@ -500,10 +503,12 @@ ObjectArrayPriorityQueue<StringTerms.Bucket> buildPriorityQueue(int size) {
         }
 
         @Override
-        void updateBucket(StringTerms.Bucket spare, BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum, long docCount) throws IOException {
-            ordsEnum.readValue(spare.termBytes);
-            spare.docCount = docCount;
-            spare.bucketOrd = ordsEnum.ord();
+        BucketUpdater<StringTerms.Bucket> bucketUpdater(long owningBucketOrd) {
+            return (spare, ordsEnum, docCount) -> {
+                ordsEnum.readValue(spare.termBytes);
+                spare.docCount = docCount;
+                spare.bucketOrd = ordsEnum.ord();
+            };
         }
 
         @Override
@@ -615,9 +620,8 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
         void collectZeroDocEntriesIfNeeded(long owningBucketOrd, boolean excludeDeletedDocs) throws IOException {}
 
         @Override
-        Supplier<SignificantStringTerms.Bucket> emptyBucketBuilder(long owningBucketOrd) {
-            long subsetSize = subsetSizes.get(owningBucketOrd);
-            return () -> new SignificantStringTerms.Bucket(new BytesRef(), 0, subsetSize, 0, 0, null, format, 0);
+        SignificantStringTerms.Bucket buildEmptyBucket() {
+            return new SignificantStringTerms.Bucket(new BytesRef(), 0, 0, null, format, 0);
         }
 
         @Override
@@ -626,20 +630,20 @@ ObjectArrayPriorityQueue<SignificantStringTerms.Bucket> buildPriorityQueue(int s
         }
 
         @Override
-        void updateBucket(SignificantStringTerms.Bucket spare, BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum, long docCount)
-            throws IOException {
-
-            ordsEnum.readValue(spare.termBytes);
-            spare.bucketOrd = ordsEnum.ord();
-            spare.subsetDf = docCount;
-            spare.supersetDf = backgroundFrequencies.freq(spare.termBytes);
-            spare.supersetSize = supersetSize;
-            /*
-             * During shard-local down-selection we use subset/superset stats
-             * that are for this shard only. Back at the central reducer these
-             * properties will be updated with global stats.
-             */
-            spare.updateScore(significanceHeuristic);
+        BucketUpdater<SignificantStringTerms.Bucket> bucketUpdater(long owningBucketOrd) {
+            long subsetSize = subsetSizes.get(owningBucketOrd);
+            return (spare, ordsEnum, docCount) -> {
+                ordsEnum.readValue(spare.termBytes);
+                spare.bucketOrd = ordsEnum.ord();
+                spare.subsetDf = docCount;
+                spare.supersetDf = backgroundFrequencies.freq(spare.termBytes);
+                /*
+                 * During shard-local down-selection we use subset/superset stats
+                 * that are for this shard only. Back at the central reducer these
+                 * properties will be updated with global stats.
+                 */
+                spare.updateScore(significanceHeuristic, subsetSize, supersetSize);
+            };
         }
 
         @Override
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/NumericTermsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/NumericTermsAggregator.java
index ce89b95b76a05..5d4c15d8a3b80 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/NumericTermsAggregator.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/NumericTermsAggregator.java
@@ -43,7 +43,6 @@
 import java.util.Map;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
-import java.util.function.Supplier;
 
 import static java.util.Collections.emptyList;
 import static org.elasticsearch.search.aggregations.InternalOrder.isKeyOrder;
 
@@ -177,7 +176,7 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro
             try (ObjectArrayPriorityQueue<B> ordered = buildPriorityQueue(size)) {
                 B spare = null;
                 BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrd);
-                Supplier<B> emptyBucketBuilder = emptyBucketBuilder(owningBucketOrd);
+                BucketUpdater<B> bucketUpdater = bucketUpdater(owningBucketOrd);
                 while (ordsEnum.next()) {
                     long docCount = bucketDocCount(ordsEnum.ord());
                     otherDocCounts.increment(ordIdx, docCount);
@@ -186,9 +185,9 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro
                     }
                     if (spare == null) {
                         checkRealMemoryCBForInternalBucket();
-                        spare = emptyBucketBuilder.get();
+                        spare = buildEmptyBucket();
                     }
-                    updateBucket(spare, ordsEnum, docCount);
+                    bucketUpdater.updateBucket(spare, ordsEnum, docCount);
                     spare = ordered.insertWithOverflow(spare);
                 }
 
@@ -240,17 +239,16 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro
         abstract B[] buildBuckets(int size);
 
         /**
-         * Build a {@linkplain Supplier} that can be used to build "empty"
-         * buckets. Those buckets will then be {@link #updateBucket updated}
+         * Build an empty bucket. Those buckets will then be {@link #bucketUpdater(long) updated}
         * for each collected bucket.
         */
-        abstract Supplier<B> emptyBucketBuilder(long owningBucketOrd);
+        abstract B buildEmptyBucket();
 
         /**
          * Update fields in {@code spare} to reflect information collected for
          * this bucket ordinal.
         */
-        abstract void updateBucket(B spare, BucketOrdsEnum ordsEnum, long docCount) throws IOException;
+        abstract BucketUpdater<B> bucketUpdater(long owningBucketOrd);
 
         /**
          * Build a {@link ObjectArrayPriorityQueue} to sort the buckets. After we've
@@ -282,6 +280,10 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro
         abstract R buildEmptyResult();
     }
 
+    interface BucketUpdater<B extends InternalMultiBucketAggregation.InternalBucket> {
+        void updateBucket(B spare, BucketOrdsEnum ordsEnum, long docCount) throws IOException;
+    }
+
     abstract class StandardTermsResultStrategy<R extends InternalMappedTerms<R, B>, B extends InternalTerms.Bucket<B>> extends ResultStrategy<R, B> {
         protected final boolean showTermDocCountError;
 
@@ -305,13 +307,6 @@ final void buildSubAggs(ObjectArray<B> topBucketsPerOrd) throws IOException {
             buildSubAggsForAllBuckets(topBucketsPerOrd, b -> b.bucketOrd, (b, aggs) -> b.aggregations = aggs);
         }
 
-        @Override
-        Supplier<B> emptyBucketBuilder(long owningBucketOrd) {
-            return this::buildEmptyBucket;
-        }
-
-        abstract B buildEmptyBucket();
-
         @Override
         final void collectZeroDocEntriesIfNeeded(long owningBucketOrd, boolean excludeDeletedDocs) throws IOException {
             if (bucketCountThresholds.getMinDocCount() != 0) {
@@ -375,10 +370,12 @@ LongTerms.Bucket buildEmptyBucket() {
         }
 
         @Override
-        void updateBucket(LongTerms.Bucket spare, BucketOrdsEnum ordsEnum, long docCount) {
-            spare.term = ordsEnum.value();
-            spare.docCount = docCount;
-            spare.bucketOrd = ordsEnum.ord();
+        BucketUpdater<LongTerms.Bucket> bucketUpdater(long owningBucketOrd) {
+            return (LongTerms.Bucket spare, BucketOrdsEnum ordsEnum, long docCount) -> {
+                spare.term = ordsEnum.value();
+                spare.docCount = docCount;
+                spare.bucketOrd = ordsEnum.ord();
+            };
         }
 
         @Override
@@ -457,10 +454,12 @@ DoubleTerms.Bucket buildEmptyBucket() {
         }
 
         @Override
-        void updateBucket(DoubleTerms.Bucket spare, BucketOrdsEnum ordsEnum, long docCount) {
-            spare.term = NumericUtils.sortableLongToDouble(ordsEnum.value());
-            spare.docCount = docCount;
-            spare.bucketOrd = ordsEnum.ord();
+        BucketUpdater<DoubleTerms.Bucket> bucketUpdater(long owningBucketOrd) {
+            return (DoubleTerms.Bucket spare, BucketOrdsEnum ordsEnum, long docCount) -> {
+                spare.term = NumericUtils.sortableLongToDouble(ordsEnum.value());
+                spare.docCount = docCount;
+                spare.bucketOrd = ordsEnum.ord();
+            };
         }
 
         @Override
@@ -565,20 +564,22 @@ SignificantLongTerms.Bucket[] buildBuckets(int size) {
         }
 
         @Override
-        Supplier<SignificantLongTerms.Bucket> emptyBucketBuilder(long owningBucketOrd) {
-            long subsetSize = subsetSizes.get(owningBucketOrd);
-            return () -> new SignificantLongTerms.Bucket(0, subsetSize, 0, supersetSize, 0, null, format, 0);
+        SignificantLongTerms.Bucket buildEmptyBucket() {
+            return new SignificantLongTerms.Bucket(0, 0, 0, null, format, 0);
         }
 
         @Override
-        void updateBucket(SignificantLongTerms.Bucket spare, BucketOrdsEnum ordsEnum, long docCount) throws IOException {
-            spare.term = ordsEnum.value();
-            spare.subsetDf = docCount;
-            spare.supersetDf = backgroundFrequencies.freq(spare.term);
-            spare.bucketOrd = ordsEnum.ord();
-            // During shard-local down-selection we use subset/superset stats that are for this shard only
-            // Back at the central reducer these properties will be updated with global stats
-            spare.updateScore(significanceHeuristic);
+        BucketUpdater<SignificantLongTerms.Bucket> bucketUpdater(long owningBucketOrd) {
+            long subsetSize = subsetSizes.get(owningBucketOrd);
+            return (spare, ordsEnum, docCount) -> {
+                spare.term = ordsEnum.value();
+                spare.subsetDf = docCount;
+                spare.supersetDf = backgroundFrequencies.freq(spare.term);
+                spare.bucketOrd = ordsEnum.ord();
+                // During shard-local down-selection we use subset/superset stats that are for this shard only
+                // Back at the central reducer these properties will be updated with global stats
+                spare.updateScore(significanceHeuristic, subsetSize, supersetSize);
+            };
        }
 
         @Override
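The bucketUpdater(long) indirection in both aggregators hoists work that depends only on the owning bucket ordinal, such as the subsetSizes lookup, out of the per-bucket loop; the returned lambda captures it once. A self-contained sketch of the pattern with hypothetical types:

import java.util.function.LongUnaryOperator;

// Hypothetical, self-contained illustration of the capture-once pattern the
// refactor applies: per-ordinal state is resolved a single time, then reused
// by the returned updater for every bucket under that ordinal.
final class CaptureOncePattern {
    interface BucketUpdater<B> {
        void updateBucket(B spare, long term, long docCount);
    }

    static final class Bucket {
        long term, docCount;
        double score;
    }

    static BucketUpdater<Bucket> bucketUpdater(long owningBucketOrd, LongUnaryOperator subsetSizes) {
        long subsetSize = subsetSizes.applyAsLong(owningBucketOrd); // resolved once per ordinal
        return (spare, term, docCount) -> {
            spare.term = term;
            spare.docCount = docCount;
            spare.score = (double) docCount / subsetSize; // captured size reused per bucket
        };
    }
}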
a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantLongTerms.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantLongTerms.java index 2aace2a714a26..17ea290b7aaaf 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantLongTerms.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantLongTerms.java @@ -30,23 +30,14 @@ public static class Bucket extends InternalSignificantTerms.Bucket { long term; - public Bucket( - long subsetDf, - long subsetSize, - long supersetDf, - long supersetSize, - long term, - InternalAggregations aggregations, - DocValueFormat format, - double score - ) { - super(subsetDf, subsetSize, supersetDf, supersetSize, aggregations, format); + public Bucket(long subsetDf, long supersetDf, long term, InternalAggregations aggregations, DocValueFormat format, double score) { + super(subsetDf, supersetDf, aggregations, format); this.term = term; this.score = score; } - Bucket(StreamInput in, long subsetSize, long supersetSize, DocValueFormat format) throws IOException { - super(in, subsetSize, supersetSize, format); + Bucket(StreamInput in, DocValueFormat format) throws IOException { + super(in, format); subsetDf = in.readVLong(); supersetDf = in.readVLong(); term = in.readLong(); @@ -136,16 +127,7 @@ public SignificantLongTerms create(List buckets) { @Override public Bucket createBucket(InternalAggregations aggregations, SignificantLongTerms.Bucket prototype) { - return new Bucket( - prototype.subsetDf, - prototype.subsetSize, - prototype.supersetDf, - prototype.supersetSize, - prototype.term, - aggregations, - prototype.format, - prototype.score - ); + return new Bucket(prototype.subsetDf, prototype.supersetDf, prototype.term, aggregations, prototype.format, prototype.score); } @Override @@ -169,14 +151,7 @@ protected Bucket[] createBucketsArray(int size) { } @Override - Bucket createBucket( - long subsetDf, - long subsetSize, - long supersetDf, - long supersetSize, - InternalAggregations aggregations, - SignificantLongTerms.Bucket prototype - ) { - return new Bucket(subsetDf, subsetSize, supersetDf, supersetSize, prototype.term, aggregations, format, prototype.score); + Bucket createBucket(long subsetDf, long supersetDf, InternalAggregations aggregations, SignificantLongTerms.Bucket prototype) { + return new Bucket(subsetDf, supersetDf, prototype.term, aggregations, format, prototype.score); } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantStringTerms.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantStringTerms.java index 791c09d3cbd99..b255f17d2843b 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantStringTerms.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantStringTerms.java @@ -34,14 +34,12 @@ public static class Bucket extends InternalSignificantTerms.Bucket { public Bucket( BytesRef term, long subsetDf, - long subsetSize, long supersetDf, - long supersetSize, InternalAggregations aggregations, DocValueFormat format, double score ) { - super(subsetDf, subsetSize, supersetDf, supersetSize, aggregations, format); + super(subsetDf, supersetDf, aggregations, format); this.termBytes = term; this.score = score; } @@ -49,8 +47,8 @@ public Bucket( /** * Read from a stream. 
*/ - public Bucket(StreamInput in, long subsetSize, long supersetSize, DocValueFormat format) throws IOException { - super(in, subsetSize, supersetSize, format); + public Bucket(StreamInput in, DocValueFormat format) throws IOException { + super(in, format); termBytes = in.readBytesRef(); subsetDf = in.readVLong(); supersetDf = in.readVLong(); @@ -140,16 +138,7 @@ public SignificantStringTerms create(List buckets @Override public Bucket createBucket(InternalAggregations aggregations, SignificantStringTerms.Bucket prototype) { - return new Bucket( - prototype.termBytes, - prototype.subsetDf, - prototype.subsetSize, - prototype.supersetDf, - prototype.supersetSize, - aggregations, - prototype.format, - prototype.score - ); + return new Bucket(prototype.termBytes, prototype.subsetDf, prototype.supersetDf, aggregations, prototype.format, prototype.score); } @Override @@ -173,14 +162,7 @@ protected Bucket[] createBucketsArray(int size) { } @Override - Bucket createBucket( - long subsetDf, - long subsetSize, - long supersetDf, - long supersetSize, - InternalAggregations aggregations, - SignificantStringTerms.Bucket prototype - ) { - return new Bucket(prototype.termBytes, subsetDf, subsetSize, supersetDf, supersetSize, aggregations, format, prototype.score); + Bucket createBucket(long subsetDf, long supersetDf, InternalAggregations aggregations, SignificantStringTerms.Bucket prototype) { + return new Bucket(prototype.termBytes, subsetDf, supersetDf, aggregations, format, prototype.score); } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTerms.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTerms.java index f02b5338eea74..e8f160193bc71 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTerms.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTerms.java @@ -17,6 +17,18 @@ */ public interface SignificantTerms extends MultiBucketsAggregation, Iterable { + /** + * @return The numbers of docs in the subset (also known as "foreground set"). + * This number is equal to the document count of the containing aggregation. + */ + long getSubsetSize(); + + /** + * @return The numbers of docs in the superset (ordinarily the background count + * of the containing aggregation). + */ + long getSupersetSize(); + interface Bucket extends MultiBucketsAggregation.Bucket { /** @@ -30,24 +42,12 @@ interface Bucket extends MultiBucketsAggregation.Bucket { */ long getSubsetDf(); - /** - * @return The numbers of docs in the subset (also known as "foreground set"). - * This number is equal to the document count of the containing aggregation. - */ - long getSubsetSize(); - /** * @return The number of docs in the superset containing a particular term (also * known as the "background count" of the bucket) */ long getSupersetDf(); - /** - * @return The numbers of docs in the superset (ordinarily the background count - * of the containing aggregation). 
- */ - long getSupersetSize(); - } @Override diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/UnmappedSignificantTerms.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/UnmappedSignificantTerms.java index 8bd14a46bff96..6d1370f147f36 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/UnmappedSignificantTerms.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/UnmappedSignificantTerms.java @@ -40,16 +40,8 @@ public class UnmappedSignificantTerms extends InternalSignificantTerms { - private Bucket( - BytesRef term, - long subsetDf, - long subsetSize, - long supersetDf, - long supersetSize, - InternalAggregations aggregations, - DocValueFormat format - ) { - super(subsetDf, subsetSize, supersetDf, supersetSize, aggregations, format); + private Bucket(BytesRef term, long subsetDf, long supersetDf, InternalAggregations aggregations, DocValueFormat format) { + super(subsetDf, supersetDf, aggregations, format); } } @@ -95,14 +87,7 @@ protected UnmappedSignificantTerms create(long subsetSize, long supersetSize, Li } @Override - Bucket createBucket( - long subsetDf, - long subsetSize, - long supersetDf, - long supersetSize, - InternalAggregations aggregations, - Bucket prototype - ) { + Bucket createBucket(long subsetDf, long supersetDf, InternalAggregations aggregations, Bucket prototype) { throw new UnsupportedOperationException("not supported for UnmappedSignificantTerms"); } @@ -153,12 +138,12 @@ protected SignificanceHeuristic getSignificanceHeuristic() { } @Override - protected long getSubsetSize() { + public long getSubsetSize() { return 0; } @Override - protected long getSupersetSize() { + public long getSupersetSize() { return 0; } } diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java index 546586a9ff3c3..2fbe3c1fc1532 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java @@ -195,13 +195,10 @@ protected SearchHit nextDoc(int doc) throws IOException { context.shardTarget(), context.searcher().getIndexReader(), docIdsToLoad, - context.request().allowPartialSearchResults() + context.request().allowPartialSearchResults(), + context.queryResult() ); - if (docsIterator.isTimedOut()) { - context.queryResult().searchTimedOut(true); - } - if (context.isCancelled()) { for (SearchHit hit : hits) { // release all hits that would otherwise become owned and eventually released by SearchHits below diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java index df4e7649ffd3b..4a242f70e8d02 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java @@ -16,6 +16,7 @@ import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.internal.ContextIndexSearcher; +import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.query.SearchTimeoutException; import java.io.IOException; @@ -30,12 +31,6 @@ */ abstract class FetchPhaseDocsIterator { - private boolean timedOut = false; - - public boolean isTimedOut() { - return timedOut; - } - /** * Called when a new leaf 
reader is reached * @param ctx the leaf reader for this set of doc ids @@ -53,7 +48,13 @@ public boolean isTimedOut() { /** * Iterate over a set of docsIds within a particular shard and index reader */ - public final SearchHit[] iterate(SearchShardTarget shardTarget, IndexReader indexReader, int[] docIds, boolean allowPartialResults) { + public final SearchHit[] iterate( + SearchShardTarget shardTarget, + IndexReader indexReader, + int[] docIds, + boolean allowPartialResults, + QuerySearchResult querySearchResult + ) { SearchHit[] searchHits = new SearchHit[docIds.length]; DocIdToIndex[] docs = new DocIdToIndex[docIds.length]; for (int index = 0; index < docIds.length; index++) { @@ -69,12 +70,10 @@ public final SearchHit[] iterate(SearchShardTarget shardTarget, IndexReader inde int[] docsInLeaf = docIdsInLeaf(0, endReaderIdx, docs, ctx.docBase); try { setNextReader(ctx, docsInLeaf); - } catch (ContextIndexSearcher.TimeExceededException timeExceededException) { - if (allowPartialResults) { - timedOut = true; - return SearchHits.EMPTY; - } - throw new SearchTimeoutException(shardTarget, "Time exceeded"); + } catch (ContextIndexSearcher.TimeExceededException e) { + SearchTimeoutException.handleTimeout(allowPartialResults, shardTarget, querySearchResult); + assert allowPartialResults; + return SearchHits.EMPTY; } for (int i = 0; i < docs.length; i++) { try { @@ -88,15 +87,15 @@ public final SearchHit[] iterate(SearchShardTarget shardTarget, IndexReader inde currentDoc = docs[i].docId; assert searchHits[docs[i].index] == null; searchHits[docs[i].index] = nextDoc(docs[i].docId); - } catch (ContextIndexSearcher.TimeExceededException timeExceededException) { - if (allowPartialResults) { - timedOut = true; - SearchHit[] partialSearchHits = new SearchHit[i]; - System.arraycopy(searchHits, 0, partialSearchHits, 0, i); - return partialSearchHits; + } catch (ContextIndexSearcher.TimeExceededException e) { + if (allowPartialResults == false) { + purgeSearchHits(searchHits); } - purgeSearchHits(searchHits); - throw new SearchTimeoutException(shardTarget, "Time exceeded"); + SearchTimeoutException.handleTimeout(allowPartialResults, shardTarget, querySearchResult); + assert allowPartialResults; + SearchHit[] partialSearchHits = new SearchHit[i]; + System.arraycopy(searchHits, 0, partialSearchHits, 0, i); + return partialSearchHits; } } } catch (SearchTimeoutException e) { diff --git a/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java index da5d2d093fbd8..f1c24af580110 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java @@ -162,8 +162,8 @@ public void setProfiler(QueryProfiler profiler) { * Add a {@link Runnable} that will be run on a regular basis while accessing documents in the * DirectoryReader but also while collecting them and check for query cancellation or timeout. */ - public Runnable addQueryCancellation(Runnable action) { - return this.cancellable.add(action); + public void addQueryCancellation(Runnable action) { + this.cancellable.add(action); } /** @@ -407,8 +407,16 @@ public void throwTimeExceededException() { } } - public static class TimeExceededException extends RuntimeException { + /** + * Exception thrown whenever a search timeout occurs. May be thrown by {@link ContextIndexSearcher} or {@link ExitableDirectoryReader}. 
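A side effect of the ContextIndexSearcher change above: addQueryCancellation(Runnable) no longer echoes its argument back, so call sites that relied on the return value for later deregistration should keep their own reference. A sketch, assuming a CancellableTask named task and the existing removeQueryCancellation counterpart:

    Runnable checkCancelled = () -> {
        if (task.isCancelled()) {
            throw new TaskCancelledException("task cancelled");
        }
    };
    searcher.addQueryCancellation(checkCancelled); // now returns void
    // ... when the phase finishes ...
    searcher.removeQueryCancellation(checkCancelled);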
+ */ + public static final class TimeExceededException extends RuntimeException { // This exception should never be re-thrown, but we fill in the stacktrace to be able to trace where it does not get properly caught + + /** + * Created via {@link #throwTimeExceededException()} + */ + private TimeExceededException() {} } @Override @@ -552,14 +560,12 @@ public DirectoryReader getDirectoryReader() { } private static class MutableQueryTimeout implements ExitableDirectoryReader.QueryCancellation { - private final List runnables = new ArrayList<>(); - private Runnable add(Runnable action) { + private void add(Runnable action) { Objects.requireNonNull(action, "cancellation runnable should not be null"); assert runnables.contains(action) == false : "Cancellation runnable already added"; runnables.add(action); - return action; } private void remove(Runnable action) { diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java index d17cd4f69dec7..40da2e2a03a77 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java @@ -217,10 +217,11 @@ static void addCollectorsAndSearch(SearchContext searchContext) throws QueryPhas queryResult.topDocs(queryPhaseResult.topDocsAndMaxScore(), queryPhaseResult.sortValueFormats()); if (searcher.timeExceeded()) { assert timeoutRunnable != null : "TimeExceededException thrown even though timeout wasn't set"; - if (searchContext.request().allowPartialSearchResults() == false) { - throw new SearchTimeoutException(searchContext.shardTarget(), "Time exceeded"); - } - queryResult.searchTimedOut(true); + SearchTimeoutException.handleTimeout( + searchContext.request().allowPartialSearchResults(), + searchContext.shardTarget(), + searchContext.queryResult() + ); } if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER) { queryResult.terminatedEarly(queryPhaseResult.terminatedAfter()); diff --git a/server/src/main/java/org/elasticsearch/search/query/SearchTimeoutException.java b/server/src/main/java/org/elasticsearch/search/query/SearchTimeoutException.java index 0ed64811fee28..e006f176ff91a 100644 --- a/server/src/main/java/org/elasticsearch/search/query/SearchTimeoutException.java +++ b/server/src/main/java/org/elasticsearch/search/query/SearchTimeoutException.java @@ -33,4 +33,17 @@ public SearchTimeoutException(StreamInput in) throws IOException { public RestStatus status() { return RestStatus.GATEWAY_TIMEOUT; } + + /** + * Propagate a timeout according to whether partial search results are allowed or not. + * In case partial results are allowed, a flag will be set to the provided {@link QuerySearchResult} to indicate that there was a + * timeout, but the execution will continue and partial results will be returned to the user. + * When partial results are disallowed, a {@link SearchTimeoutException} will be thrown and returned to the user. 
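The Javadoc above describes the new central helper: it either throws (partial results disallowed) or flags the provided result and returns, which is why the call sites in this patch can assert allowPartialResults after it returns. Its typical call shape, sketched with the surrounding variables assumed from a timed search phase:

    try {
        // ... timed work against the searcher ...
    } catch (ContextIndexSearcher.TimeExceededException e) {
        SearchTimeoutException.handleTimeout(allowPartialResults, shardTarget, querySearchResult);
        assert allowPartialResults; // handleTimeout would have thrown otherwise
        // fall through and return whatever partial hits were collected
    }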
+ */ + public static void handleTimeout(boolean allowPartialSearchResults, SearchShardTarget target, QuerySearchResult querySearchResult) { + if (allowPartialSearchResults == false) { + throw new SearchTimeoutException(target, "Time exceeded"); + } + querySearchResult.searchTimedOut(true); + } } diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java index 1227db5d8e1db..7e3646e7689cc 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java @@ -73,10 +73,11 @@ public static void execute(SearchContext context) { } catch (IOException e) { throw new ElasticsearchException("Rescore Phase Failed", e); } catch (ContextIndexSearcher.TimeExceededException e) { - if (context.request().allowPartialSearchResults() == false) { - throw new SearchTimeoutException(context.shardTarget(), "Time exceeded"); - } - context.queryResult().searchTimedOut(true); + SearchTimeoutException.handleTimeout( + context.request().allowPartialSearchResults(), + context.shardTarget(), + context.queryResult() + ); } } diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index c2201f5b1c319..389555e60b43b 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -3,5 +3,5 @@ org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat -org.elasticsearch.index.codec.vectors.ES816BinaryQuantizedVectorsFormat -org.elasticsearch.index.codec.vectors.ES816HnswBinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/BinaryQuantizationTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/BinaryQuantizationTests.java similarity index 99% rename from server/src/test/java/org/elasticsearch/index/codec/vectors/BinaryQuantizationTests.java rename to server/src/test/java/org/elasticsearch/index/codec/vectors/es816/BinaryQuantizationTests.java index 32d717bd76f91..205cbb4119dd6 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/BinaryQuantizationTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/BinaryQuantizationTests.java @@ -17,11 +17,13 @@ * * Modifications copyright (C) 2024 Elasticsearch B.V. 
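The META-INF/services edit above keeps Lucene's SPI registry in step with the package move to ...vectors.es816; lookups are unaffected because they key on the format's declared name rather than its class. A sketch using Lucene's standard forName lookup (the name string is assumed to match the format's NAME constant):

    // Resolved through META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat:
    KnnVectorsFormat format = KnnVectorsFormat.forName("ES816BinaryQuantizedVectorsFormat");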
*/ -package org.elasticsearch.index.codec.vectors; +package org.elasticsearch.index.codec.vectors.es816; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; import java.util.Random; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java similarity index 99% rename from server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorerTests.java rename to server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java index 04d4ef2079b99..ce3aaacf96858 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorerTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java @@ -17,12 +17,14 @@ * * Modifications copyright (C) 2024 Elasticsearch B.V. */ -package org.elasticsearch.index.codec.vectors; +package org.elasticsearch.index.codec.vectors.es816; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; import java.io.IOException; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java similarity index 98% rename from server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormatTests.java rename to server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java index 0892436891ff1..077285d067d3b 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java @@ -17,7 +17,7 @@ * * Modifications copyright (C) 2024 Elasticsearch B.V. 
*/ -package org.elasticsearch.index.codec.vectors; +package org.elasticsearch.index.codec.vectors.es816; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; @@ -40,6 +40,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; import java.io.IOException; import java.util.Locale; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java similarity index 99% rename from server/src/test/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormatTests.java rename to server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java index f607de57e1fd5..4fbe1211d7a27 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java @@ -17,7 +17,7 @@ * * Modifications copyright (C) 2024 Elasticsearch B.V. */ -package org.elasticsearch.index.codec.vectors; +package org.elasticsearch.index.codec.vectors.es816; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTermsTestCase.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTermsTestCase.java index 6d49d6855caca..7e5d19977fe9f 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTermsTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTermsTestCase.java @@ -59,8 +59,6 @@ protected void assertSampled( InternalSignificantTerms.Bucket sampledBucket = sampledIt.next(); assertEquals(sampledBucket.subsetDf, samplingContext.scaleUp(reducedBucket.subsetDf)); assertEquals(sampledBucket.supersetDf, samplingContext.scaleUp(reducedBucket.supersetDf)); - assertEquals(sampledBucket.subsetSize, samplingContext.scaleUp(reducedBucket.subsetSize)); - assertEquals(sampledBucket.supersetSize, samplingContext.scaleUp(reducedBucket.supersetSize)); assertThat(sampledBucket.score, closeTo(reducedBucket.score, 1e-14)); } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantLongTermsTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantLongTermsTests.java index a303199338783..92bfa2f6f89f4 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantLongTermsTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantLongTermsTests.java @@ -49,17 +49,8 @@ public void setUp() throws Exception { Set terms = new HashSet<>(); for (int i = 0; i < numBuckets; ++i) { long term = randomValueOtherThanMany(l -> terms.add(l) == false, random()::nextLong); - SignificantLongTerms.Bucket bucket = new SignificantLongTerms.Bucket( - subsetDfs[i], - subsetSize, - supersetDfs[i], - supersetSize, - term, - aggs, - format, - 0 - ); - bucket.updateScore(significanceHeuristic); + SignificantLongTerms.Bucket bucket = new 
SignificantLongTerms.Bucket(subsetDfs[i], supersetDfs[i], term, aggs, format, 0); + bucket.updateScore(significanceHeuristic, subsetSize, supersetSize); buckets.add(bucket); } return new SignificantLongTerms(name, requiredSize, 1L, metadata, format, subsetSize, supersetSize, significanceHeuristic, buckets); @@ -90,8 +81,6 @@ public void setUp() throws Exception { randomLong(), randomNonNegativeLong(), randomNonNegativeLong(), - randomNonNegativeLong(), - randomNonNegativeLong(), InternalAggregations.EMPTY, format, 0 diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantStringTermsTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantStringTermsTests.java index a91566c615eaf..7499831f371aa 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantStringTermsTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantStringTermsTests.java @@ -42,17 +42,8 @@ public class SignificantStringTermsTests extends InternalSignificantTermsTestCas Set terms = new HashSet<>(); for (int i = 0; i < numBuckets; ++i) { BytesRef term = randomValueOtherThanMany(b -> terms.add(b) == false, () -> new BytesRef(randomAlphaOfLength(10))); - SignificantStringTerms.Bucket bucket = new SignificantStringTerms.Bucket( - term, - subsetDfs[i], - subsetSize, - supersetDfs[i], - supersetSize, - aggs, - format, - 0 - ); - bucket.updateScore(significanceHeuristic); + SignificantStringTerms.Bucket bucket = new SignificantStringTerms.Bucket(term, subsetDfs[i], supersetDfs[i], aggs, format, 0); + bucket.updateScore(significanceHeuristic, subsetSize, supersetSize); buckets.add(bucket); } return new SignificantStringTerms( @@ -93,8 +84,6 @@ public class SignificantStringTermsTests extends InternalSignificantTermsTestCas new BytesRef(randomAlphaOfLengthBetween(1, 10)), randomNonNegativeLong(), randomNonNegativeLong(), - randomNonNegativeLong(), - randomNonNegativeLong(), InternalAggregations.EMPTY, format, 0 diff --git a/server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java b/server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java index d5e930321db95..c8d1b6721c64b 100644 --- a/server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java +++ b/server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java @@ -17,6 +17,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -77,7 +78,7 @@ protected SearchHit nextDoc(int doc) { } }; - SearchHit[] hits = it.iterate(null, reader, docs, randomBoolean()); + SearchHit[] hits = it.iterate(null, reader, docs, randomBoolean(), new QuerySearchResult()); assertThat(hits.length, equalTo(docs.length)); for (int i = 0; i < hits.length; i++) { @@ -125,7 +126,10 @@ protected SearchHit nextDoc(int doc) { } }; - Exception e = expectThrows(FetchPhaseExecutionException.class, () -> it.iterate(null, reader, docs, randomBoolean())); + Exception e = expectThrows( + FetchPhaseExecutionException.class, + () -> it.iterate(null, reader, docs, randomBoolean(), new QuerySearchResult()) + ); assertThat(e.getMessage(), containsString("Error running fetch phase for doc [" + badDoc + "]")); assertThat(e.getCause(), 
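As the test hunks above show, buckets no longer carry the set sizes, so significance scoring takes them explicitly: updateScore(heuristic) becomes updateScore(heuristic, subsetSize, supersetSize). Condensed, with the variable values assumed:

    SignificantLongTerms.Bucket bucket = new SignificantLongTerms.Bucket(
        subsetDf, supersetDf, term, InternalAggregations.EMPTY, DocValueFormat.RAW, 0
    );
    bucket.updateScore(significanceHeuristic, subsetSize, supersetSize);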
instanceOf(IllegalArgumentException.class)); diff --git a/server/src/test/java/org/elasticsearch/search/fetch/subphase/FieldFetcherTests.java b/server/src/test/java/org/elasticsearch/search/fetch/subphase/FieldFetcherTests.java index f01f760ed71c3..c5f1efe561c22 100644 --- a/server/src/test/java/org/elasticsearch/search/fetch/subphase/FieldFetcherTests.java +++ b/server/src/test/java/org/elasticsearch/search/fetch/subphase/FieldFetcherTests.java @@ -271,7 +271,7 @@ public void testMetadataFields() throws IOException { FieldNamesFieldMapper.NAME, NestedPathFieldMapper.name(IndexVersion.current()) )) { - expectThrows(UnsupportedOperationException.class, () -> fetchFields(mapperService, source, fieldname)); + expectThrows(IllegalArgumentException.class, () -> fetchFields(mapperService, source, fieldname)); } } diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java index 008a056e87901..8b9176a346e30 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java @@ -295,15 +295,10 @@ private Response concat(int evals) throws IOException { * Returns many moderately long strings. */ public void testManyConcat() throws IOException { + int strings = 300; initManyLongs(); - Response resp = manyConcat(300); - Map map = responseAsMap(resp); - ListMatcher columns = matchesList(); - for (int s = 0; s < 300; s++) { - columns = columns.item(matchesMap().entry("name", "str" + s).entry("type", "keyword")); - } - MapMatcher mapMatcher = matchesMap(); - assertMap(map, mapMatcher.entry("columns", columns).entry("values", any(List.class)).entry("took", greaterThanOrEqualTo(0))); + Response resp = manyConcat("FROM manylongs", strings); + assertManyStrings(resp, strings); } /** @@ -311,15 +306,24 @@ public void testManyConcat() throws IOException { */ public void testHugeManyConcat() throws IOException { initManyLongs(); - assertCircuitBreaks(() -> manyConcat(2000)); + assertCircuitBreaks(() -> manyConcat("FROM manylongs", 2000)); + } + + /** + * Returns many moderately long strings. + */ + public void testManyConcatFromRow() throws IOException { + int strings = 2000; + Response resp = manyConcat("ROW a=9999, b=9999, c=9999, d=9999, e=9999", strings); + assertManyStrings(resp, strings); } /** * Tests that generate many moderately long strings. */ - private Response manyConcat(int strings) throws IOException { + private Response manyConcat(String init, int strings) throws IOException { StringBuilder query = startQuery(); - query.append("FROM manylongs | EVAL str = CONCAT("); + query.append(init).append(" | EVAL str = CONCAT("); query.append( Arrays.stream(new String[] { "a", "b", "c", "d", "e" }) .map(f -> "TO_STRING(" + f + ")") @@ -344,7 +348,64 @@ private Response manyConcat(int strings) throws IOException { query.append("str").append(s); } query.append("\"}"); - return query(query.toString(), null); + return query(query.toString(), "columns"); + } + + /** + * Returns many moderately long strings. 
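The FieldFetcherTests change above pairs with the RankFeatureMetaFieldMapper edit earlier in the patch: fetching an internal metadata field through the fields API now surfaces as IllegalArgumentException, which the REST layer maps to 400, instead of the UnsupportedOperationException that mapped to 500. In test form:

    // was: expectThrows(UnsupportedOperationException.class, ...) -> HTTP 500
    expectThrows(IllegalArgumentException.class, () -> fetchFields(mapperService, source, "_seq_no"));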
+ */ + public void testManyRepeat() throws IOException { + int strings = 30; + initManyLongs(); + Response resp = manyRepeat("FROM manylongs", strings); + assertManyStrings(resp, strings); + } + + /** + * Hits a circuit breaker by building many moderately long strings. + */ + public void testHugeManyRepeat() throws IOException { + initManyLongs(); + assertCircuitBreaks(() -> manyRepeat("FROM manylongs", 75)); + } + + /** + * Returns many moderately long strings. + */ + public void testManyRepeatFromRow() throws IOException { + int strings = 10000; + Response resp = manyRepeat("ROW a = 99", strings); + assertManyStrings(resp, strings); + } + + /** + * Tests that generate many moderately long strings. + */ + private Response manyRepeat(String init, int strings) throws IOException { + StringBuilder query = startQuery(); + query.append(init).append(" | EVAL str = TO_STRING(a)"); + for (int s = 0; s < strings; s++) { + query.append(",\nstr").append(s).append("=REPEAT(str, 10000)"); + } + query.append("\n|KEEP "); + for (int s = 0; s < strings; s++) { + if (s != 0) { + query.append(", "); + } + query.append("str").append(s); + } + query.append("\"}"); + return query(query.toString(), "columns"); + } + + private void assertManyStrings(Response resp, int strings) throws IOException { + Map map = responseAsMap(resp); + ListMatcher columns = matchesList(); + for (int s = 0; s < strings; s++) { + columns = columns.item(matchesMap().entry("name", "str" + s).entry("type", "keyword")); + } + MapMatcher mapMatcher = matchesMap(); + assertMap(map, mapMatcher.entry("columns", columns)); } public void testManyEval() throws IOException { diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/bucket/AbstractSignificanceHeuristicTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/bucket/AbstractSignificanceHeuristicTestCase.java index ae5083c245538..a3c03526c9b93 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/bucket/AbstractSignificanceHeuristicTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/bucket/AbstractSignificanceHeuristicTestCase.java @@ -95,22 +95,20 @@ public void testStreamResponse() throws Exception { InternalMappedSignificantTerms read = (InternalMappedSignificantTerms) in.readNamedWriteable(InternalAggregation.class); assertEquals(sigTerms.getSignificanceHeuristic(), read.getSignificanceHeuristic()); + assertThat(read.getSubsetSize(), equalTo(10L)); + assertThat(read.getSupersetSize(), equalTo(20L)); SignificantTerms.Bucket originalBucket = sigTerms.getBuckets().get(0); SignificantTerms.Bucket streamedBucket = read.getBuckets().get(0); assertThat(originalBucket.getKeyAsString(), equalTo(streamedBucket.getKeyAsString())); assertThat(originalBucket.getSupersetDf(), equalTo(streamedBucket.getSupersetDf())); assertThat(originalBucket.getSubsetDf(), equalTo(streamedBucket.getSubsetDf())); - assertThat(streamedBucket.getSubsetSize(), equalTo(10L)); - assertThat(streamedBucket.getSupersetSize(), equalTo(20L)); } InternalMappedSignificantTerms getRandomSignificantTerms(SignificanceHeuristic heuristic) { if (randomBoolean()) { SignificantLongTerms.Bucket bucket = new SignificantLongTerms.Bucket( 1, - 2, 3, - 4, 123, InternalAggregations.EMPTY, DocValueFormat.RAW, @@ -121,9 +119,7 @@ public void testStreamResponse() throws Exception { SignificantStringTerms.Bucket bucket = new SignificantStringTerms.Bucket( new BytesRef("someterm"), 1, - 2, 3, - 4, InternalAggregations.EMPTY,
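For orientation, the request body that manyRepeat(init, strings) above assembles looks roughly as follows for strings == 2, assuming startQuery() opens the {"query":" envelope that the closing "} suggests:

    String body = """
        {"query":"ROW a = 99 | EVAL str = TO_STRING(a),
        str0=REPEAT(str, 10000),
        str1=REPEAT(str, 10000)
        |KEEP str0, str1"}
        """;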
DocValueFormat.RAW, randomDoubleBetween(0, 100, true) @@ -136,15 +132,13 @@ public void testReduce() { List aggs = createInternalAggregations(); AggregationReduceContext context = InternalAggregationTestCase.emptyReduceContextBuilder().forFinalReduction(); SignificantTerms reducedAgg = (SignificantTerms) InternalAggregationTestCase.reduce(aggs, context); + assertThat(reducedAgg.getSubsetSize(), equalTo(16L)); + assertThat(reducedAgg.getSupersetSize(), equalTo(30L)); assertThat(reducedAgg.getBuckets().size(), equalTo(2)); assertThat(reducedAgg.getBuckets().get(0).getSubsetDf(), equalTo(8L)); - assertThat(reducedAgg.getBuckets().get(0).getSubsetSize(), equalTo(16L)); assertThat(reducedAgg.getBuckets().get(0).getSupersetDf(), equalTo(10L)); - assertThat(reducedAgg.getBuckets().get(0).getSupersetSize(), equalTo(30L)); assertThat(reducedAgg.getBuckets().get(1).getSubsetDf(), equalTo(8L)); - assertThat(reducedAgg.getBuckets().get(1).getSubsetSize(), equalTo(16L)); assertThat(reducedAgg.getBuckets().get(1).getSupersetDf(), equalTo(10L)); - assertThat(reducedAgg.getBuckets().get(1).getSupersetSize(), equalTo(30L)); } public void testBasicScoreProperties() { @@ -234,9 +228,9 @@ private List createInternalAggregations() { : new AbstractSignificanceHeuristicTestCase.LongTestAggFactory(); List aggs = new ArrayList<>(); - aggs.add(factory.createAggregation(significanceHeuristic, 4, 10, 1, (f, i) -> f.createBucket(4, 4, 5, 10, 0))); - aggs.add(factory.createAggregation(significanceHeuristic, 4, 10, 1, (f, i) -> f.createBucket(4, 4, 5, 10, 1))); - aggs.add(factory.createAggregation(significanceHeuristic, 8, 10, 2, (f, i) -> f.createBucket(4, 4, 5, 10, i))); + aggs.add(factory.createAggregation(significanceHeuristic, 4, 10, 1, (f, i) -> f.createBucket(4, 5, 0))); + aggs.add(factory.createAggregation(significanceHeuristic, 4, 10, 1, (f, i) -> f.createBucket(4, 5, 1))); + aggs.add(factory.createAggregation(significanceHeuristic, 8, 10, 2, (f, i) -> f.createBucket(4, 5, i))); return aggs; } @@ -254,7 +248,7 @@ final A createAggregation( abstract A createAggregation(SignificanceHeuristic significanceHeuristic, long subsetSize, long supersetSize, List buckets); - abstract B createBucket(long subsetDF, long subsetSize, long supersetDF, long supersetSize, long label); + abstract B createBucket(long subsetDF, long supersetDF, long label); } private class StringTestAggFactory extends TestAggFactory { @@ -279,13 +273,11 @@ SignificantStringTerms createAggregation( } @Override - SignificantStringTerms.Bucket createBucket(long subsetDF, long subsetSize, long supersetDF, long supersetSize, long label) { + SignificantStringTerms.Bucket createBucket(long subsetDF, long supersetDF, long label) { return new SignificantStringTerms.Bucket( new BytesRef(Long.toString(label).getBytes(StandardCharsets.UTF_8)), subsetDF, - subsetSize, supersetDF, - supersetSize, InternalAggregations.EMPTY, DocValueFormat.RAW, 0 @@ -315,17 +307,8 @@ SignificantLongTerms createAggregation( } @Override - SignificantLongTerms.Bucket createBucket(long subsetDF, long subsetSize, long supersetDF, long supersetSize, long label) { - return new SignificantLongTerms.Bucket( - subsetDF, - subsetSize, - supersetDF, - supersetSize, - label, - InternalAggregations.EMPTY, - DocValueFormat.RAW, - 0 - ); + SignificantLongTerms.Bucket createBucket(long subsetDF, long supersetDF, long label) { + return new SignificantLongTerms.Bucket(subsetDF, supersetDF, label, InternalAggregations.EMPTY, DocValueFormat.RAW, 0); } } diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index e2435c3396fa8..f5923a4942634 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -71,6 +71,8 @@ import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams; import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.core.ml.search.TextExpansionQueryBuilder; import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder; import org.elasticsearch.xpack.core.monitoring.MonitoringFeatureSetUsage; import org.elasticsearch.xpack.core.rollup.RollupFeatureSetUsage; @@ -398,6 +400,14 @@ public List getNamedXContent() { @Override public List> getQueries() { return List.of( + new QuerySpec<>(SparseVectorQueryBuilder.NAME, SparseVectorQueryBuilder::new, SparseVectorQueryBuilder::fromXContent), + new QuerySpec( + TextExpansionQueryBuilder.NAME, + TextExpansionQueryBuilder::new, + TextExpansionQueryBuilder::fromXContent + ), + // TODO: The WeightedTokensBuilder is slated for removal after the SparseVectorQueryBuilder is available. + // The logic to create a Boolean query based on weighted tokens will remain and/or be moved to server. new SearchPlugin.QuerySpec( WeightedTokensQueryBuilder.NAME, WeightedTokensQueryBuilder::new, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/action/AbstractTransportSetUpgradeModeAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/action/AbstractTransportSetUpgradeModeAction.java new file mode 100644 index 0000000000000..bbd90448cf855 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/action/AbstractTransportSetUpgradeModeAction.java @@ -0,0 +1,186 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
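Registering the relocated builders in XPackClientPlugin#getQueries makes them parseable wherever the client plugin's named specs are installed. The registration idiom written out uniformly; this is a sketch of the hunk above, and the fromXContent reference for WeightedTokensQueryBuilder is assumed from its truncated spec:

    @Override
    public List<SearchPlugin.QuerySpec<?>> getQueries() {
        return List.of(
            new SearchPlugin.QuerySpec<>(SparseVectorQueryBuilder.NAME, SparseVectorQueryBuilder::new, SparseVectorQueryBuilder::fromXContent),
            new SearchPlugin.QuerySpec<>(TextExpansionQueryBuilder.NAME, TextExpansionQueryBuilder::new, TextExpansionQueryBuilder::fromXContent),
            new SearchPlugin.QuerySpec<>(WeightedTokensQueryBuilder.NAME, WeightedTokensQueryBuilder::new, WeightedTokensQueryBuilder::fromXContent)
        );
    }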
+ */ + +package org.elasticsearch.xpack.core.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ElasticsearchTimeoutException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateTaskListener; +import org.elasticsearch.cluster.SimpleBatchedExecutor; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.cluster.service.MasterServiceTaskQueue; +import org.elasticsearch.common.Priority; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.Strings; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; + +import java.util.concurrent.atomic.AtomicBoolean; + +public abstract class AbstractTransportSetUpgradeModeAction extends AcknowledgedTransportMasterNodeAction { + + private static final Logger logger = LogManager.getLogger(AbstractTransportSetUpgradeModeAction.class); + private final AtomicBoolean isRunning = new AtomicBoolean(false); + private final MasterServiceTaskQueue taskQueue; + + public AbstractTransportSetUpgradeModeAction( + String actionName, + String taskQueuePrefix, + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver + ) { + super( + actionName, + transportService, + clusterService, + threadPool, + actionFilters, + SetUpgradeModeActionRequest::new, + indexNameExpressionResolver, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + + this.taskQueue = clusterService.createTaskQueue(taskQueuePrefix + " upgrade mode", Priority.NORMAL, new UpdateModeExecutor()); + } + + @Override + protected void masterOperation( + Task task, + SetUpgradeModeActionRequest request, + ClusterState state, + ActionListener listener + ) throws Exception { + // Don't want folks spamming this endpoint while it is in progress, only allow one request to be handled at a time + if (isRunning.compareAndSet(false, true) == false) { + String msg = Strings.format( + "Attempted to set [upgrade_mode] for feature name [%s] to [%s] from [%s] while previous request was processing.", + featureName(), + request.enabled(), + upgradeMode(state) + ); + logger.info(msg); + Exception detail = new IllegalStateException(msg); + listener.onFailure( + new ElasticsearchStatusException( + "Cannot change [upgrade_mode] for feature name [{}]. 
Previous request is still being processed.", + RestStatus.TOO_MANY_REQUESTS, + detail, + featureName() + ) + ); + return; + } + + // Noop, nothing for us to do, simply return fast to the caller + var upgradeMode = upgradeMode(state); + if (request.enabled() == upgradeMode) { + logger.info("Upgrade mode noop"); + isRunning.set(false); + listener.onResponse(AcknowledgedResponse.TRUE); + return; + } + + logger.info( + "Starting to set [upgrade_mode] for feature name [{}] to [{}] from [{}]", + featureName(), + request.enabled(), + upgradeMode + ); + + ActionListener wrappedListener = ActionListener.wrap(r -> { + logger.info("Finished setting [upgrade_mode] for feature name [{}]", featureName()); + isRunning.set(false); + listener.onResponse(r); + }, e -> { + logger.info("Failed to set [upgrade_mode] for feature name [{}]", featureName()); + isRunning.set(false); + listener.onFailure(e); + }); + + ActionListener setUpgradeModeListener = wrappedListener.delegateFailure((delegate, ack) -> { + if (ack.isAcknowledged()) { + upgradeModeSuccessfullyChanged(task, request, state, delegate); + } else { + logger.info("Cluster state update is NOT acknowledged"); + wrappedListener.onFailure(new ElasticsearchTimeoutException("Unknown error occurred while updating cluster state")); + } + }); + + taskQueue.submitTask(featureName(), new UpdateModeStateListener(request, setUpgradeModeListener), request.ackTimeout()); + } + + /** + * Define the feature name, used in log messages and naming the task on the task queue. + */ + protected abstract String featureName(); + + /** + * Parse the ClusterState for the implementation's {@link org.elasticsearch.cluster.metadata.Metadata.Custom} and find the upgradeMode + * boolean stored there. We will compare this boolean with the request's desired state to determine if we should change the metadata. + */ + protected abstract boolean upgradeMode(ClusterState state); + + /** + * This is called from the ClusterState updater and is expected to return quickly. + */ + protected abstract ClusterState createUpdatedState(SetUpgradeModeActionRequest request, ClusterState state); + + /** + * This method is only called when the cluster state was successfully changed. + * If we failed to update for any reason, this will not be called. + * The ClusterState param is the previous ClusterState before we called update. 
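AbstractTransportSetUpgradeModeAction is a template method: the base class owns the in-flight lock, the no-op fast path, and the cluster-state task queue, while subclasses fill in the four hooks. A minimal hypothetical subclass; every name below is illustrative, including the metadata type:

    public class TransportSetMyFeatureUpgradeModeAction extends AbstractTransportSetUpgradeModeAction {

        TransportSetMyFeatureUpgradeModeAction(TransportService transportService, ClusterService clusterService,
                                               ThreadPool threadPool, ActionFilters actionFilters,
                                               IndexNameExpressionResolver resolver) {
            super("cluster:admin/my_feature/set_upgrade_mode", "my-feature", transportService, clusterService,
                threadPool, actionFilters, resolver);
        }

        @Override
        protected String featureName() {
            return "my_feature";
        }

        @Override
        protected boolean upgradeMode(ClusterState state) {
            MyFeatureMetadata metadata = state.metadata().custom(MyFeatureMetadata.TYPE); // hypothetical Metadata.Custom
            return metadata != null && metadata.isUpgradeMode();
        }

        @Override
        protected ClusterState createUpdatedState(SetUpgradeModeActionRequest request, ClusterState state) {
            // Runs on the master update thread and must return quickly; placeholder for the real metadata flip.
            return ClusterState.builder(state).build();
        }

        @Override
        protected void upgradeModeSuccessfullyChanged(Task task, SetUpgradeModeActionRequest request,
                                                      ClusterState state, ActionListener<AcknowledgedResponse> listener) {
            listener.onResponse(AcknowledgedResponse.TRUE);
        }
    }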
+ */ + protected abstract void upgradeModeSuccessfullyChanged( + Task task, + SetUpgradeModeActionRequest request, + ClusterState state, + ActionListener listener + ); + + @Override + protected ClusterBlockException checkBlock(SetUpgradeModeActionRequest request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } + + private record UpdateModeStateListener(SetUpgradeModeActionRequest request, ActionListener listener) + implements + ClusterStateTaskListener { + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + } + + private class UpdateModeExecutor extends SimpleBatchedExecutor { + @Override + public Tuple executeTask(UpdateModeStateListener clusterStateListener, ClusterState clusterState) { + return Tuple.tuple(createUpdatedState(clusterStateListener.request(), clusterState), null); + } + + @Override + public void taskSucceeded(UpdateModeStateListener clusterStateListener, Void unused) { + clusterStateListener.listener().onResponse(AcknowledgedResponse.TRUE); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/action/SetUpgradeModeActionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/action/SetUpgradeModeActionRequest.java new file mode 100644 index 0000000000000..98e30b284c21a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/action/SetUpgradeModeActionRequest.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.action; + +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class SetUpgradeModeActionRequest extends AcknowledgedRequest implements ToXContentObject { + + private final boolean enabled; + + private static final ParseField ENABLED = new ParseField("enabled"); + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "set_upgrade_mode_action_request", + a -> new SetUpgradeModeActionRequest((Boolean) a[0]) + ); + + static { + PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), ENABLED); + } + + public SetUpgradeModeActionRequest(boolean enabled) { + super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); + this.enabled = enabled; + } + + public SetUpgradeModeActionRequest(StreamInput in) throws IOException { + super(in); + this.enabled = in.readBoolean(); + } + + public boolean enabled() { + return enabled; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(enabled); + } + + @Override + public int hashCode() { + return Objects.hash(enabled); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + SetUpgradeModeActionRequest other = (SetUpgradeModeActionRequest) obj; + return enabled == 
other.enabled(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ENABLED.getPreferredName(), enabled); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/SetUpgradeModeAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/SetUpgradeModeAction.java index 821caf001f3e0..a67ae33e85801 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/SetUpgradeModeAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/SetUpgradeModeAction.java @@ -7,17 +7,13 @@ package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.action.ActionType; -import org.elasticsearch.action.support.master.AcknowledgedRequest; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.action.SetUpgradeModeActionRequest; import java.io.IOException; -import java.util.Objects; public class SetUpgradeModeAction extends ActionType { @@ -28,9 +24,7 @@ private SetUpgradeModeAction() { super(NAME); } - public static class Request extends AcknowledgedRequest implements ToXContentObject { - - private final boolean enabled; + public static class Request extends SetUpgradeModeActionRequest { private static final ParseField ENABLED = new ParseField("enabled"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -43,48 +37,11 @@ public static class Request extends AcknowledgedRequest implements ToXC } public Request(boolean enabled) { - super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); - this.enabled = enabled; + super(enabled); } public Request(StreamInput in) throws IOException { super(in); - this.enabled = in.readBoolean(); - } - - public boolean isEnabled() { - return enabled; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeBoolean(enabled); - } - - @Override - public int hashCode() { - return Objects.hash(enabled); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || obj.getClass() != getClass()) { - return false; - } - Request other = (Request) obj; - return Objects.equals(enabled, other.enabled); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(ENABLED.getPreferredName(), enabled); - builder.endObject(); - return builder; } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java similarity index 97% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java index 5a63ad8e85e9b..e9e4e90421adc 100644 --- 
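The wire and REST formats are unchanged by the move: the body is just {"enabled": <boolean>}. A round-trip sketch, assuming an XContentParser positioned on that object:

    SetUpgradeModeActionRequest request = SetUpgradeModeActionRequest.PARSER.apply(parser, null);
    boolean enabled = request.enabled(); // note the accessor rename from isEnabled()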
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml.queries; +package org.elasticsearch.xpack.core.ml.search; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; @@ -33,9 +33,6 @@ import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; -import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; -import org.elasticsearch.xpack.core.ml.search.WeightedTokensUtils; import java.io.IOException; import java.util.ArrayList; @@ -210,7 +207,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { return (shouldPruneTokens) ? WeightedTokensUtils.queryBuilderWithPrunedTokens(fieldName, tokenPruningConfig, queryVectors, ft, context) - : WeightedTokensUtils.queryBuilderWithAllTokens(queryVectors, ft, context); + : WeightedTokensUtils.queryBuilderWithAllTokens(fieldName, queryVectors, ft, context); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryWrapper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryWrapper.java new file mode 100644 index 0000000000000..234560f620d95 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryWrapper.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.search; + +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; +import org.elasticsearch.index.query.SearchExecutionContext; + +import java.io.IOException; +import java.util.Objects; + +/** + * A wrapper class for the Lucene query generated by {@link SparseVectorQueryBuilder#toQuery(SearchExecutionContext)}. + * This wrapper facilitates the extraction of the complete sparse vector query using a {@link QueryVisitor}. 
+ */ +public class SparseVectorQueryWrapper extends Query { + private final String fieldName; + private final Query termsQuery; + + public SparseVectorQueryWrapper(String fieldName, Query termsQuery) { + this.fieldName = fieldName; + this.termsQuery = termsQuery; + } + + public Query getTermsQuery() { + return termsQuery; + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + var rewrite = termsQuery.rewrite(indexSearcher); + if (rewrite != termsQuery) { + return new SparseVectorQueryWrapper(fieldName, rewrite); + } + return this; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return termsQuery.createWeight(searcher, scoreMode, boost); + } + + @Override + public String toString(String field) { + return termsQuery.toString(field); + } + + @Override + public void visit(QueryVisitor visitor) { + if (visitor.acceptField(fieldName)) { + termsQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this)); + } + } + + @Override + public boolean equals(Object obj) { + if (sameClassAs(obj) == false) { + return false; + } + SparseVectorQueryWrapper that = (SparseVectorQueryWrapper) obj; + return fieldName.equals(that.fieldName) && termsQuery.equals(that.termsQuery); + } + + @Override + public int hashCode() { + return Objects.hash(classHash(), fieldName, termsQuery); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java similarity index 98% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java index 6d972bcf5863a..81758ec5f9342 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml.queries; +package org.elasticsearch.xpack.core.ml.search; import org.apache.lucene.search.Query; import org.apache.lucene.util.SetOnce; @@ -32,8 +32,6 @@ import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; -import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig; -import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilder.java index 256c90c3eaa62..f41fcd77ce627 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilder.java @@ -125,7 +125,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { } return (this.tokenPruningConfig == null) - ? WeightedTokensUtils.queryBuilderWithAllTokens(tokens, ft, context) + ? 
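The visit(...) override above is what makes the wrapper useful: an outer visitor sees the wrapper as the parent when descending, so the complete sparse vector query can be recovered from an arbitrary query tree. A sketch of such an extraction:

    List<Query> sparseQueries = new ArrayList<>();
    topLevelQuery.visit(new QueryVisitor() {
        @Override
        public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) {
            if (parent instanceof SparseVectorQueryWrapper wrapper) {
                sparseQueries.add(wrapper.getTermsQuery());
                return QueryVisitor.EMPTY_VISITOR; // no need to descend into the terms
            }
            return this;
        }
    });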
WeightedTokensUtils.queryBuilderWithAllTokens(fieldName, tokens, ft, context) : WeightedTokensUtils.queryBuilderWithPrunedTokens(fieldName, tokenPruningConfig, tokens, ft, context); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensUtils.java index 133920416d227..1c2ac23151e6e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensUtils.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensUtils.java @@ -24,13 +24,18 @@ public final class WeightedTokensUtils { private WeightedTokensUtils() {} - public static Query queryBuilderWithAllTokens(List tokens, MappedFieldType ft, SearchExecutionContext context) { + public static Query queryBuilderWithAllTokens( + String fieldName, + List tokens, + MappedFieldType ft, + SearchExecutionContext context + ) { var qb = new BooleanQuery.Builder(); for (var token : tokens) { qb.add(new BoostQuery(ft.termQuery(token.token(), context), token.weight()), BooleanClause.Occur.SHOULD); } - return qb.setMinimumNumberShouldMatch(1).build(); + return new SparseVectorQueryWrapper(fieldName, qb.setMinimumNumberShouldMatch(1).build()); } public static Query queryBuilderWithPrunedTokens( @@ -64,7 +69,7 @@ public static Query queryBuilderWithPrunedTokens( } } - return qb.setMinimumNumberShouldMatch(1).build(); + return new SparseVectorQueryWrapper(fieldName, qb.setMinimumNumberShouldMatch(1).build()); } /** diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/action/AbstractTransportSetUpgradeModeActionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/action/AbstractTransportSetUpgradeModeActionTests.java new file mode 100644 index 0000000000000..d780b7fbc32f4 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/action/AbstractTransportSetUpgradeModeActionTests.java @@ -0,0 +1,219 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateTaskListener; +import org.elasticsearch.cluster.SimpleBatchedExecutor; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.cluster.service.MasterServiceTaskQueue; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AbstractTransportSetUpgradeModeActionTests extends ESTestCase { + /** + * Creates a TaskQueue that invokes the SimpleBatchedExecutor. + */ + public static ClusterService clusterService() { + AtomicReference> executor = new AtomicReference<>(); + MasterServiceTaskQueue taskQueue = mock(); + ClusterService clusterService = mock(); + doAnswer(ans -> { + executor.set(ans.getArgument(2)); + return taskQueue; + }).when(clusterService).createTaskQueue(any(), any(), any()); + doAnswer(ans -> { + if (executor.get() == null) { + fail("We should create the task queue before we submit tasks to it"); + } else { + executor.get().executeTask(ans.getArgument(1), ClusterState.EMPTY_STATE); + executor.get().taskSucceeded(ans.getArgument(1), null); + } + return null; + }).when(taskQueue).submitTask(any(), any(), any()); + return clusterService; + } + + /** + * Creates a TaskQueue that calls the listener with an error. + */ + public static ClusterService clusterServiceWithError(Exception e) { + MasterServiceTaskQueue taskQueue = mock(); + ClusterService clusterService = mock(); + when(clusterService.createTaskQueue(any(), any(), any())).thenReturn(taskQueue); + doAnswer(ans -> { + ClusterStateTaskListener listener = ans.getArgument(1); + listener.onFailure(e); + return null; + }).when(taskQueue).submitTask(any(), any(), any()); + return clusterService; + } + + /** + * TaskQueue that does nothing. 
+ */ + public static ClusterService clusterServiceThatDoesNothing() { + ClusterService clusterService = mock(); + when(clusterService.createTaskQueue(any(), any(), any())).thenReturn(mock()); + return clusterService; + } + + public void testIdempotent() throws Exception { + // create with update mode set to false + var action = new TestTransportSetUpgradeModeAction(clusterServiceThatDoesNothing(), false); + + // flip to true but do nothing (cluster service mock won't invoke the listener) + action.runWithoutWaiting(true); + // call again + var response = action.run(true); + + assertThat(response.v1(), nullValue()); + assertThat(response.v2(), notNullValue()); + assertThat(response.v2(), instanceOf(ElasticsearchStatusException.class)); + assertThat( + response.v2().getMessage(), + is("Cannot change [upgrade_mode] for feature name [" + action.featureName() + "]. Previous request is still being processed.") + ); + } + + public void testUpdateDoesNotRun() throws Exception { + var shouldNotChange = new AtomicBoolean(true); + var action = new TestTransportSetUpgradeModeAction(true, l -> shouldNotChange.set(false)); + + var response = action.run(true); + + assertThat(response.v1(), is(AcknowledgedResponse.TRUE)); + assertThat(response.v2(), nullValue()); + assertThat(shouldNotChange.get(), is(true)); + } + + public void testErrorReleasesLock() throws Exception { + var action = new TestTransportSetUpgradeModeAction(false, l -> l.onFailure(new IllegalStateException("hello there"))); + + action.run(true); + var response = action.run(true); + assertThat( + "Previous request should have finished processing.", + response.v2().getMessage(), + not(containsString("Previous request is still being processed")) + ); + } + + public void testErrorFromAction() throws Exception { + var expectedException = new IllegalStateException("hello there"); + var action = new TestTransportSetUpgradeModeAction(false, l -> l.onFailure(expectedException)); + + var response = action.run(true); + + assertThat(response.v1(), nullValue()); + assertThat(response.v2(), is(expectedException)); + } + + public void testErrorFromTaskQueue() throws Exception { + var expectedException = new IllegalStateException("hello there"); + var action = new TestTransportSetUpgradeModeAction(clusterServiceWithError(expectedException), false); + + var response = action.run(true); + + assertThat(response.v1(), nullValue()); + assertThat(response.v2(), is(expectedException)); + } + + public void testSuccess() throws Exception { + var action = new TestTransportSetUpgradeModeAction(false, l -> l.onResponse(AcknowledgedResponse.TRUE)); + + var response = action.run(true); + + assertThat(response.v1(), is(AcknowledgedResponse.TRUE)); + assertThat(response.v2(), nullValue()); + } + + private static class TestTransportSetUpgradeModeAction extends AbstractTransportSetUpgradeModeAction { + private final boolean upgradeMode; + private final ClusterState updatedClusterState; + private final Consumer> successFunc; + + TestTransportSetUpgradeModeAction(boolean upgradeMode, Consumer> successFunc) { + super("actionName", "taskQueuePrefix", mock(), clusterService(), mock(), mock(), mock()); + this.upgradeMode = upgradeMode; + this.updatedClusterState = ClusterState.EMPTY_STATE; + this.successFunc = successFunc; + } + + TestTransportSetUpgradeModeAction(ClusterService clusterService, boolean upgradeMode) { + super("actionName", "taskQueuePrefix", mock(), clusterService, mock(), mock(), mock()); + this.upgradeMode = upgradeMode; + this.updatedClusterState = 
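clusterService() above wires a mocked MasterServiceTaskQueue so that submitted tasks run immediately and synchronously on the caller's thread. The capture-the-executor idiom in isolation, as a Mockito sketch with generics trimmed for brevity:

    AtomicReference<SimpleBatchedExecutor<ClusterStateTaskListener, Void>> executor = new AtomicReference<>();
    MasterServiceTaskQueue<ClusterStateTaskListener> taskQueue = mock();
    ClusterService clusterService = mock();
    doAnswer(invocation -> {
        executor.set(invocation.getArgument(2)); // the executor handed to createTaskQueue
        return taskQueue;
    }).when(clusterService).createTaskQueue(any(), any(), any());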
ClusterState.EMPTY_STATE; + this.successFunc = listener -> {}; + } + + public void runWithoutWaiting(boolean upgrade) throws Exception { + masterOperation(mock(), new SetUpgradeModeActionRequest(upgrade), ClusterState.EMPTY_STATE, ActionListener.noop()); + } + + public Tuple run(boolean upgrade) throws Exception { + AtomicReference> response = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + masterOperation(mock(), new SetUpgradeModeActionRequest(upgrade), ClusterState.EMPTY_STATE, ActionListener.wrap(r -> { + response.set(Tuple.tuple(r, null)); + latch.countDown(); + }, e -> { + response.set(Tuple.tuple(null, e)); + latch.countDown(); + })); + assertTrue("Failed to run TestTransportSetUpgradeModeAction in 10s", latch.await(10, TimeUnit.SECONDS)); + return response.get(); + } + + @Override + protected String featureName() { + return "test-feature-name"; + } + + @Override + protected boolean upgradeMode(ClusterState state) { + return upgradeMode; + } + + @Override + protected ClusterState createUpdatedState(SetUpgradeModeActionRequest request, ClusterState state) { + return updatedClusterState; + } + + @Override + protected void upgradeModeSuccessfullyChanged( + Task task, + SetUpgradeModeActionRequest request, + ClusterState state, + ActionListener listener + ) { + successFunc.accept(listener); + } + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java similarity index 94% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilderTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java index 3d17d8dd23ff6..b5296bef05b77 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilderTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.ml.queries; +package org.elasticsearch.xpack.core.ml.search; import org.apache.lucene.document.Document; import org.apache.lucene.document.FeatureField; @@ -40,9 +40,6 @@ import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; -import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; -import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; import java.lang.reflect.Method; @@ -50,7 +47,7 @@ import java.util.Collection; import java.util.List; -import static org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder.QUERY_VECTOR_FIELD; +import static org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder.QUERY_VECTOR_FIELD; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.Matchers.either; import static org.hamcrest.Matchers.hasSize; @@ -102,7 +99,7 @@ private SparseVectorQueryBuilder createTestQueryBuilder(TokenPruningConfig token @Override protected Collection> getPlugins() { - return List.of(MachineLearning.class, MapperExtrasPlugin.class, XPackClientPlugin.class); + return List.of(MapperExtrasPlugin.class, XPackClientPlugin.class); } @Override @@ -156,8 +153,10 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws @Override protected void doAssertLuceneQuery(SparseVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) { - assertThat(query, instanceOf(BooleanQuery.class)); - BooleanQuery booleanQuery = (BooleanQuery) query; + assertThat(query, instanceOf(SparseVectorQueryWrapper.class)); + var sparseQuery = (SparseVectorQueryWrapper) query; + assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class)); + BooleanQuery booleanQuery = (BooleanQuery) sparseQuery.getTermsQuery(); assertEquals(booleanQuery.getMinimumNumberShouldMatch(), 1); assertThat(booleanQuery.clauses(), hasSize(NUM_TOKENS)); @@ -233,11 +232,13 @@ public void testToQuery() throws IOException { private void testDoToQuery(SparseVectorQueryBuilder queryBuilder, SearchExecutionContext context) throws IOException { Query query = queryBuilder.doToQuery(context); + assertTrue(query instanceof SparseVectorQueryWrapper); + var sparseQuery = (SparseVectorQueryWrapper) query; if (queryBuilder.shouldPruneTokens()) { // It's possible that all documents were pruned for aggressive pruning configurations - assertTrue(query instanceof BooleanQuery || query instanceof MatchNoDocsQuery); + assertTrue(sparseQuery.getTermsQuery() instanceof BooleanQuery || sparseQuery.getTermsQuery() instanceof MatchNoDocsQuery); } else { - assertTrue(query instanceof BooleanQuery); + assertTrue(sparseQuery.getTermsQuery() instanceof BooleanQuery); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilderTests.java similarity index 96% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilderTests.java index 8da6fc843614e..9d8a286df1e66 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilderTests.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml.queries; +package org.elasticsearch.xpack.core.ml.search; import org.apache.lucene.document.Document; import org.apache.lucene.document.FeatureField; @@ -35,10 +35,6 @@ import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; -import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; -import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder; -import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; import java.lang.reflect.Method; @@ -77,7 +73,7 @@ protected TextExpansionQueryBuilder doCreateTestQueryBuilder() { @Override protected Collection> getPlugins() { - return List.of(MachineLearning.class, MapperExtrasPlugin.class, XPackClientPlugin.class); + return List.of(MapperExtrasPlugin.class, XPackClientPlugin.class); } @Override @@ -129,8 +125,10 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws @Override protected void doAssertLuceneQuery(TextExpansionQueryBuilder queryBuilder, Query query, SearchExecutionContext context) { - assertThat(query, instanceOf(BooleanQuery.class)); - BooleanQuery booleanQuery = (BooleanQuery) query; + assertThat(query, instanceOf(SparseVectorQueryWrapper.class)); + var sparseQuery = (SparseVectorQueryWrapper) query; + assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class)); + BooleanQuery booleanQuery = (BooleanQuery) sparseQuery.getTermsQuery(); assertEquals(booleanQuery.getMinimumNumberShouldMatch(), 1); assertThat(booleanQuery.clauses(), hasSize(NUM_TOKENS)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilderTests.java index bb727204e2651..cf63bfc269899 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilderTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilderTests.java @@ -271,8 +271,11 @@ public void testPruningIsAppliedCorrectly() throws IOException { } private void assertCorrectLuceneQuery(String name, Query query, List expectedFeatureFields) { - assertTrue(query instanceof BooleanQuery); - List booleanClauses = ((BooleanQuery) query).clauses(); + assertThat(query, instanceOf(SparseVectorQueryWrapper.class)); + var sparseQuery = (SparseVectorQueryWrapper) query; + assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class)); + BooleanQuery booleanQuery = (BooleanQuery) sparseQuery.getTermsQuery(); + List booleanClauses = booleanQuery.clauses(); assertEquals( name + " had " + booleanClauses.size() + " clauses, expected " + expectedFeatureFields.size(), expectedFeatureFields.size(), @@ -343,8 +346,10 @@ public void testMustRewrite() throws IOException { @Override protected void doAssertLuceneQuery(WeightedTokensQueryBuilder queryBuilder, Query query, SearchExecutionContext context) { - assertThat(query, instanceOf(BooleanQuery.class)); - BooleanQuery booleanQuery = (BooleanQuery) query; + assertThat(query, instanceOf(SparseVectorQueryWrapper.class)); + var sparseQuery = 
(SparseVectorQueryWrapper) query; + assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class)); + BooleanQuery booleanQuery = (BooleanQuery) sparseQuery.getTermsQuery(); assertEquals(booleanQuery.getMinimumNumberShouldMatch(), 1); assertThat(booleanQuery.clauses(), hasSize(NUM_TOKENS)); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java index 20cdbaf6acdbf..53f559c5c82fe 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java @@ -122,7 +122,11 @@ public boolean equals(Object obj) { @Override public String toString() { - return String.valueOf(value); + String str = String.valueOf(value); + if (str.length() > 500) { + return str.substring(0, 500) + "..."; + } + return str; } @Override diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/LiteralTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/LiteralTests.java index a4c67a8076479..a628916e67746 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/LiteralTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/LiteralTests.java @@ -6,9 +6,12 @@ */ package org.elasticsearch.xpack.esql.core.expression; +import joptsimple.internal.Strings; + import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.tree.AbstractNodeTestCase; +import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.tree.SourceTests; import org.elasticsearch.xpack.esql.core.type.Converter; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -17,6 +20,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Objects; import java.util.function.Function; import java.util.function.Supplier; @@ -29,9 +33,12 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; import static org.elasticsearch.xpack.esql.core.type.DataType.LONG; import static org.elasticsearch.xpack.esql.core.type.DataType.SHORT; +import static org.hamcrest.Matchers.equalTo; public class LiteralTests extends AbstractNodeTestCase { + static class ValueAndCompatibleTypes { + final Supplier valueSupplier; final List validDataTypes; @@ -120,6 +127,19 @@ public void testReplaceChildren() { assertEquals("this type of node doesn't have any children to replace", e.getMessage()); } + public void testToString() { + assertThat(new Literal(Source.EMPTY, 1, LONG).toString(), equalTo("1")); + assertThat(new Literal(Source.EMPTY, "short", KEYWORD).toString(), equalTo("short")); + // toString should limit it's length + String tooLong = Strings.repeat('a', 510); + assertThat(new Literal(Source.EMPTY, tooLong, KEYWORD).toString(), equalTo(Strings.repeat('a', 500) + "...")); + + for (ValueAndCompatibleTypes g : GENERATORS) { + Literal lit = new Literal(Source.EMPTY, g.valueSupplier.get(), randomFrom(g.validDataTypes)); + assertThat(lit.toString(), equalTo(Objects.toString(lit.value()))); + } + } + private static Object randomValueOfTypeOtherThan(Object original, DataType type) { for (ValueAndCompatibleTypes gen : GENERATORS) { if 
     private static Object randomValueOfTypeOtherThan(Object original, DataType type) {
         for (ValueAndCompatibleTypes gen : GENERATORS) {
             if (gen.validDataTypes.get(0) == type) {
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeRequest.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeRequest.java
index 6ed2cc7e587be..1e8700bcd4030 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeRequest.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeRequest.java
@@ -40,6 +40,17 @@ public void writeTo(StreamOutput out) throws IOException {
         out.writeBoolean(sourcesFinished);
     }
 
+    @Override
+    public TaskId getParentTask() {
+        // Exchange requests with `sourcesFinished=true` complete the remote sink and return without blocking.
+        // Masking the parent task allows these requests to bypass task cancellation, ensuring cleanup of the remote sink.
+        // TODO: Maybe add a separate action/request for closing exchange sinks?
+        if (sourcesFinished) {
+            return TaskId.EMPTY_TASK_ID;
+        }
+        return super.getParentTask();
+    }
+
     /**
      * True if the {@link ExchangeSourceHandler} has enough input.
      * The corresponding {@link ExchangeSinkHandler} can drain pages and finish itself.
@@ -70,9 +81,9 @@ public int hashCode() {
 
     @Override
     public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
-        if (parentTaskId.isSet() == false) {
-            assert false : "ExchangeRequest must have a parent task";
-            throw new IllegalStateException("ExchangeRequest must have a parent task");
+        if (sourcesFinished == false && parentTaskId.isSet() == false) {
+            assert false : "ExchangeRequest with sourcesFinished=false must have a parent task";
+            throw new IllegalStateException("ExchangeRequest with sourcesFinished=false must have a parent task");
         }
         return new CancellableTask(id, type, action, "", parentTaskId, headers) {
             @Override
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java
index a943a90d02e87..00c68c4f48e86 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java
@@ -314,28 +314,20 @@ static final class TransportRemoteSink implements RemoteSink {
 
         @Override
         public void fetchPageAsync(boolean allSourcesFinished, ActionListener<ExchangeResponse> listener) {
             if (allSourcesFinished) {
-                if (finished.compareAndSet(false, true)) {
-                    doFetchPageAsync(true, listener);
-                } else {
-                    // already finished or promised
-                    listener.onResponse(new ExchangeResponse(blockFactory, null, true));
-                }
-            } else {
-                // already finished
-                if (finished.get()) {
-                    listener.onResponse(new ExchangeResponse(blockFactory, null, true));
-                    return;
-                }
-                doFetchPageAsync(false, ActionListener.wrap(r -> {
-                    if (r.finished()) {
-                        finished.set(true);
-                    }
-                    listener.onResponse(r);
-                }, e -> {
-                    finished.set(true);
-                    listener.onFailure(e);
-                }));
+                close(listener.map(unused -> new ExchangeResponse(blockFactory, null, true)));
+                return;
+            }
+            // already finished
+            if (finished.get()) {
+                listener.onResponse(new ExchangeResponse(blockFactory, null, true));
+                return;
             }
+            doFetchPageAsync(false, ActionListener.wrap(r -> {
+                if (r.finished()) {
+                    finished.set(true);
+                }
+                listener.onResponse(r);
+            }, e -> close(ActionListener.running(() -> listener.onFailure(e)))));
         }
 
         private void doFetchPageAsync(boolean allSourcesFinished, ActionListener<ExchangeResponse> listener) {
@@ -361,6 +353,15 @@ private void doFetchPageAsync(boolean allSourcesFinished, ActionListener<ExchangeResponse> listener) {
+
+        @Override
+        public void close(ActionListener<Void> listener) {
+            if (finished.compareAndSet(false, true)) {
+                doFetchPageAsync(true, listener.delegateFailure((l, unused) -> l.onResponse(null)));
+            } else {
+                listener.onResponse(null);
+            }
+        }
     }
 
     // For testing
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java
index 61b3386ce0274..375016a5d51d5 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java
@@ -224,8 +224,10 @@ void onSinkFailed(Exception e) {
             buffer.waitForReading().listener().onResponse(null); // resume the Driver if it is being blocked on reading
             if (finished == false) {
                 finished = true;
-                outstandingSinks.finishInstance();
-                completionListener.onFailure(e);
+                remoteSink.close(ActionListener.running(() -> {
+                    outstandingSinks.finishInstance();
+                    completionListener.onFailure(e);
+                }));
             }
         }
 
@@ -262,7 +264,7 @@ public void onFailure(Exception e) {
                 failure.unwrapAndCollect(e);
             }
             buffer.waitForReading().listener().onResponse(null); // resume the Driver if it is being blocked on reading
-            sinkListener.onFailure(e);
+            remoteSink.close(ActionListener.running(() -> sinkListener.onFailure(e)));
         }
 
         @Override
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/RemoteSink.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/RemoteSink.java
index 7d81cd3f66600..aaa937ef17c0e 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/RemoteSink.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/RemoteSink.java
@@ -12,4 +12,14 @@ public interface RemoteSink {
 
     void fetchPageAsync(boolean allSourcesFinished, ActionListener<ExchangeResponse> listener);
+
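+    // Closing a remote sink is a final fetch with allSourcesFinished=true: any page held by the
+    // returned response is released before the listener is completed.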
+    default void close(ActionListener<Void> listener) {
+        fetchPageAsync(true, listener.delegateFailure((l, r) -> {
+            try {
+                r.close();
+            } finally {
+                l.onResponse(null);
+            }
+        }));
+    }
 }
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeRequestTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeRequestTests.java
new file mode 100644
index 0000000000000..8a0891651a497
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeRequestTests.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.operator.exchange;
+
+import org.elasticsearch.tasks.TaskId;
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class ExchangeRequestTests extends ESTestCase {
+
+    public void testParentTask() {
+        ExchangeRequest r1 = new ExchangeRequest("1", true);
+        r1.setParentTask(new TaskId("node-1", 1));
+        assertSame(TaskId.EMPTY_TASK_ID, r1.getParentTask());
+
+        ExchangeRequest r2 = new ExchangeRequest("1", false);
+        r2.setParentTask(new TaskId("node-2", 2));
+        assertTrue(r2.getParentTask().isSet());
+        assertThat(r2.getParentTask(), equalTo((new TaskId("node-2", 2))));
+    }
+}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java
index 4178f02898d79..fc6c850ba187b 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java
@@ -491,7 +491,7 @@ public void testConcurrentWithTransportActions() {
         }
     }
 
-    public void testFailToRespondPage() {
+    public void testFailToRespondPage() throws Exception {
         Settings settings = Settings.builder().build();
         MockTransportService node0 = newTransportService();
         ExchangeService exchange0 = new ExchangeService(settings, threadPool, ESQL_TEST_EXECUTOR, blockFactory());
@@ -558,7 +558,9 @@ public void sendResponse(TransportResponse transportResponse) {
         Throwable cause = ExceptionsHelper.unwrap(err, IOException.class);
         assertNotNull(cause);
         assertThat(cause.getMessage(), equalTo("page is too large"));
-        sinkHandler.onFailure(new RuntimeException(cause));
+        PlainActionFuture<Void> sinkCompletionFuture = new PlainActionFuture<>();
+        sinkHandler.addCompletionListener(sinkCompletionFuture);
+        assertBusy(() -> assertTrue(sinkCompletionFuture.isDone()));
         expectThrows(Exception.class, () -> sourceCompletionFuture.actionGet(10, TimeUnit.SECONDS));
     }
 }
diff --git a/x-pack/plugin/esql/qa/server/multi-clusters/build.gradle b/x-pack/plugin/esql/qa/server/multi-clusters/build.gradle
index 2c432eb94ebf1..c5a2636b07a59 100644
--- a/x-pack/plugin/esql/qa/server/multi-clusters/build.gradle
+++ b/x-pack/plugin/esql/qa/server/multi-clusters/build.gradle
@@ -24,9 +24,22 @@ def supportedVersion = bwcVersion -> {
 }
 
 buildParams.bwcVersions.withWireCompatible(supportedVersion) { bwcVersion, baseName ->
-  tasks.register(bwcTaskName(bwcVersion), StandaloneRestIntegTestTask) {
+  tasks.register("${baseName}#newToOld", StandaloneRestIntegTestTask) {
+    usesBwcDistribution(bwcVersion)
+    systemProperty("tests.version.remote_cluster", bwcVersion)
+    maxParallelForks = 1
+  }
+
+  tasks.register("${baseName}#oldToNew", StandaloneRestIntegTestTask) {
     usesBwcDistribution(bwcVersion)
-    systemProperty("tests.old_cluster_version", bwcVersion)
+    systemProperty("tests.version.local_cluster", bwcVersion)
+    maxParallelForks = 1
+  }
+
+  // TODO: avoid running tests twice with the current version
+  tasks.register(bwcTaskName(bwcVersion), StandaloneRestIntegTestTask) {
+    dependsOn tasks.named("${baseName}#oldToNew")
+    dependsOn tasks.named("${baseName}#newToOld")
     maxParallelForks = 1
   }
 }
diff --git a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/Clusters.java b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/Clusters.java
index fa8cb49c59aed..5f3f135810322 100644
--- a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/Clusters.java
+++ b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/Clusters.java
@@ -20,7 +20,7 @@ public static ElasticsearchCluster remoteCluster() {
         return ElasticsearchCluster.local()
             .name(REMOTE_CLUSTER_NAME)
             .distribution(DistributionType.DEFAULT)
-            .version(Version.fromString(System.getProperty("tests.old_cluster_version")))
+            .version(distributionVersion("tests.version.remote_cluster"))
             .nodes(2)
             .setting("node.roles", "[data,ingest,master]")
             .setting("xpack.security.enabled", "false")
@@ -34,7 +34,7 @@ public static ElasticsearchCluster localCluster(ElasticsearchCluster remoteClust
         return ElasticsearchCluster.local()
             .name(LOCAL_CLUSTER_NAME)
             .distribution(DistributionType.DEFAULT)
-            .version(Version.CURRENT)
+            .version(distributionVersion("tests.version.local_cluster"))
             .nodes(2)
             .setting("xpack.security.enabled", "false")
             .setting("xpack.license.self_generated.type", "trial")
@@ -46,7 +46,18 @@ public static ElasticsearchCluster localCluster(ElasticsearchCluster remoteClust
             .build();
     }
 
-    public static org.elasticsearch.Version oldVersion() {
-        return org.elasticsearch.Version.fromString(System.getProperty("tests.old_cluster_version"));
+    public static org.elasticsearch.Version localClusterVersion() {
+        String prop = System.getProperty("tests.version.local_cluster");
+        return prop != null ? org.elasticsearch.Version.fromString(prop) : org.elasticsearch.Version.CURRENT;
+    }
+
+    public static org.elasticsearch.Version remoteClusterVersion() {
+        String prop = System.getProperty("tests.version.remote_cluster");
+        return prop != null ? org.elasticsearch.Version.fromString(prop) : org.elasticsearch.Version.CURRENT;
+    }
+
+    private static Version distributionVersion(String key) {
+        final String val = System.getProperty(key);
+        return val != null ? Version.fromString(val) : Version.CURRENT;
     }
 }
diff --git a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/EsqlRestValidationIT.java b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/EsqlRestValidationIT.java
index 21307c5362417..55500aa1c9537 100644
--- a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/EsqlRestValidationIT.java
+++ b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/EsqlRestValidationIT.java
@@ -10,12 +10,14 @@
 import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;
 
 import org.apache.http.HttpHost;
+import org.elasticsearch.Version;
 import org.elasticsearch.client.RestClient;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.test.TestClustersThreadFilter;
 import org.elasticsearch.test.cluster.ElasticsearchCluster;
 import org.elasticsearch.xpack.esql.qa.rest.EsqlRestValidationTestCase;
 import org.junit.AfterClass;
+import org.junit.Before;
 import org.junit.ClassRule;
 import org.junit.rules.RuleChain;
 import org.junit.rules.TestRule;
@@ -78,4 +80,9 @@ private RestClient remoteClusterClient() throws IOException {
         }
         return remoteClient;
     }
+
+    @Before
+    public void skipTestOnOldVersions() {
+        assumeTrue("skip on old versions", Clusters.localClusterVersion().equals(Version.V_8_16_0));
+    }
 }
diff --git a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java
index af5eadc7358a2..6c7b700af5b1a 100644
--- a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java
+++ b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java
@@ -12,6 +12,7 @@
 
 import org.apache.http.HttpEntity;
 import org.apache.http.HttpHost;
+import org.elasticsearch.Version;
 import org.elasticsearch.client.Request;
 import org.elasticsearch.client.Response;
 import org.elasticsearch.client.RestClient;
@@ -118,10 +119,8 @@ protected void shouldSkipTest(String testName) throws IOException {
         // Do not run tests including "METADATA _index" unless marked with metadata_fields_remote_test,
         // because they may produce inconsistent results with multiple clusters.
         assumeFalse("can't test with _index metadata", (remoteMetadata == false) && hasIndexMetadata(testCase.query));
-        assumeTrue(
-            "Test " + testName + " is skipped on " + Clusters.oldVersion(),
-            isEnabled(testName, instructions, Clusters.oldVersion())
-        );
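+        // Either cluster may now be the older one, so gate version-dependent specs on the
+        // minimum of the local and remote cluster versions.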
+        Version oldVersion = Version.min(Clusters.localClusterVersion(), Clusters.remoteClusterVersion());
+        assumeTrue("Test " + testName + " is skipped on " + oldVersion, isEnabled(testName, instructions, oldVersion));
         assumeFalse("INLINESTATS not yet supported in CCS", testCase.requiredCapabilities.contains(INLINESTATS.capabilityName()));
         assumeFalse("INLINESTATS not yet supported in CCS", testCase.requiredCapabilities.contains(INLINESTATS_V2.capabilityName()));
         assumeFalse("INLINESTATS not yet supported in CCS", testCase.requiredCapabilities.contains(JOIN_PLANNING_V1.capabilityName()));
diff --git a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClustersIT.java b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClustersIT.java
index dbeaed1596eff..452f40baa34a8 100644
--- a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClustersIT.java
+++ b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClustersIT.java
@@ -10,6 +10,7 @@
 import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;
 
 import org.apache.http.HttpHost;
+import org.elasticsearch.Version;
 import org.elasticsearch.client.Request;
 import org.elasticsearch.client.RestClient;
 import org.elasticsearch.common.Strings;
@@ -29,7 +30,6 @@
 import java.io.IOException;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
@@ -127,10 +127,12 @@ void indexDocs(RestClient client, String index, List docs) throws IOExcepti
     }
 
     private Map<String, Object> run(String query, boolean includeCCSMetadata) throws IOException {
-        Map<String, Object> resp = runEsql(
-            new RestEsqlTestCase.RequestObjectBuilder().query(query).includeCCSMetadata(includeCCSMetadata).build()
-        );
-        logger.info("--> query {} response {}", query, resp);
+        var queryBuilder = new RestEsqlTestCase.RequestObjectBuilder().query(query);
+        if (includeCCSMetadata) {
+            queryBuilder.includeCCSMetadata(true);
+        }
+        Map<String, Object> resp = runEsql(queryBuilder.build());
+        logger.info("--> query {} response {}", queryBuilder, resp);
         return resp;
     }
 
@@ -156,7 +158,7 @@ private Map runEsql(RestEsqlTestCase.RequestObjectBuilder reques
 
     public void testCount() throws Exception {
         {
-            boolean includeCCSMetadata = randomBoolean();
+            boolean includeCCSMetadata = includeCCSMetadata();
             Map<String, Object> result = run("FROM test-local-index,*:test-remote-index | STATS c = COUNT(*)", includeCCSMetadata);
             var columns = List.of(Map.of("name", "c", "type", "long"));
             var values = List.of(List.of(localDocs.size() + remoteDocs.size()));
@@ -165,13 +167,16 @@
             if (includeCCSMetadata) {
                 mapMatcher = mapMatcher.entry("_clusters", any(Map.class));
             }
-            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values).entry("took", greaterThanOrEqualTo(0)));
+            if (ccsMetadataAvailable()) {
+                mapMatcher = mapMatcher.entry("took", greaterThanOrEqualTo(0));
+            }
+            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values));
             if (includeCCSMetadata) {
                 assertClusterDetailsMap(result, false);
             }
         }
         {
-            boolean includeCCSMetadata = randomBoolean();
+            boolean includeCCSMetadata = includeCCSMetadata();
             Map<String, Object> result = run("FROM *:test-remote-index | STATS c = COUNT(*)", includeCCSMetadata);
             var columns = List.of(Map.of("name", "c", "type", "long"));
             var values = List.of(List.of(remoteDocs.size()));
@@ -180,7 +185,10 @@
             if (includeCCSMetadata) {
                 mapMatcher = mapMatcher.entry("_clusters", any(Map.class));
             }
-            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values).entry("took", greaterThanOrEqualTo(0)));
+            if (ccsMetadataAvailable()) {
+                mapMatcher = mapMatcher.entry("took", greaterThanOrEqualTo(0));
+            }
+            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values));
             if (includeCCSMetadata) {
                 assertClusterDetailsMap(result, true);
             }
         }
     }
 
     public void testUngroupedAggs() throws Exception {
         {
-            boolean includeCCSMetadata = randomBoolean();
+            boolean includeCCSMetadata = includeCCSMetadata();
             Map<String, Object> result = run("FROM test-local-index,*:test-remote-index | STATS total = SUM(data)", includeCCSMetadata);
             var columns = List.of(Map.of("name", "total", "type", "long"));
             long sum = Stream.concat(localDocs.stream(), remoteDocs.stream()).mapToLong(d -> d.data).sum();
@@ -200,13 +208,16 @@
             if (includeCCSMetadata) {
                 mapMatcher = mapMatcher.entry("_clusters", any(Map.class));
             }
-            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values).entry("took", greaterThanOrEqualTo(0)));
+            if (ccsMetadataAvailable()) {
+                mapMatcher = mapMatcher.entry("took", greaterThanOrEqualTo(0));
+            }
+            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values));
             if (includeCCSMetadata) {
                 assertClusterDetailsMap(result, false);
             }
         }
         {
-            boolean includeCCSMetadata = randomBoolean();
+            boolean includeCCSMetadata = includeCCSMetadata();
             Map<String, Object> result = run("FROM *:test-remote-index | STATS total = SUM(data)", includeCCSMetadata);
             var columns = List.of(Map.of("name", "total", "type", "long"));
             long sum = remoteDocs.stream().mapToLong(d -> d.data).sum();
@@ -216,12 +227,16 @@
             if (includeCCSMetadata) {
                 mapMatcher = mapMatcher.entry("_clusters", any(Map.class));
             }
-            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values).entry("took", greaterThanOrEqualTo(0)));
+            if (ccsMetadataAvailable()) {
+                mapMatcher = mapMatcher.entry("took", greaterThanOrEqualTo(0));
+            }
+            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values));
             if (includeCCSMetadata) {
                 assertClusterDetailsMap(result, true);
             }
         }
         {
+            assumeTrue("requires ccs metadata", ccsMetadataAvailable());
             Map<String, Object> result = runWithColumnarAndIncludeCCSMetadata("FROM *:test-remote-index | STATS total = SUM(data)");
             var columns = List.of(Map.of("name", "total", "type", "long"));
             long sum = remoteDocs.stream().mapToLong(d -> d.data).sum();
@@ -293,7 +308,7 @@ private void assertClusterDetailsMap(Map result, boolean remoteO
 
     public void testGroupedAggs() throws Exception {
         {
-            boolean includeCCSMetadata = randomBoolean();
+            boolean includeCCSMetadata = includeCCSMetadata();
             Map<String, Object> result = run(
                 "FROM test-local-index,*:test-remote-index | STATS total = SUM(data) BY color | SORT color",
                 includeCCSMetadata
             );
@@ -311,13 +326,16 @@
             if (includeCCSMetadata) {
                 mapMatcher = mapMatcher.entry("_clusters", any(Map.class));
             }
-            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values).entry("took", greaterThanOrEqualTo(0)));
+            if (ccsMetadataAvailable()) {
+                mapMatcher = mapMatcher.entry("took", greaterThanOrEqualTo(0));
+            }
+            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values));
             if (includeCCSMetadata) {
                 assertClusterDetailsMap(result, false);
             }
         }
         {
-            boolean includeCCSMetadata = randomBoolean();
+            boolean includeCCSMetadata = includeCCSMetadata();
             Map<String, Object> result = run(
                 "FROM *:test-remote-index | STATS total = SUM(data) by color | SORT color",
                 includeCCSMetadata
             );
@@ -336,29 +354,57 @@
             if (includeCCSMetadata) {
                 mapMatcher = mapMatcher.entry("_clusters", any(Map.class));
             }
-            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values).entry("took", greaterThanOrEqualTo(0)));
+            if (ccsMetadataAvailable()) {
+                mapMatcher = mapMatcher.entry("took", greaterThanOrEqualTo(0));
+            }
+            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values));
             if (includeCCSMetadata) {
                 assertClusterDetailsMap(result, true);
             }
         }
     }
 
+    public void testIndexPattern() throws Exception {
+        {
+            String indexPattern = randomFrom(
+                "test-local-index,*:test-remote-index",
+                "test-local-index,*:test-remote-*",
+                "test-local-index,*:test-*",
+                "test-*,*:test-remote-index"
+            );
+            Map<String, Object> result = run("FROM " + indexPattern + " | STATS c = COUNT(*)", false);
+            var columns = List.of(Map.of("name", "c", "type", "long"));
+            var values = List.of(List.of(localDocs.size() + remoteDocs.size()));
+            MapMatcher mapMatcher = matchesMap();
+            if (ccsMetadataAvailable()) {
+                mapMatcher = mapMatcher.entry("took", greaterThanOrEqualTo(0));
+            }
+            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values));
+        }
+        {
+            String indexPattern = randomFrom("*:test-remote-index", "*:test-remote-*", "*:test-*");
+            Map<String, Object> result = run("FROM " + indexPattern + " | STATS c = COUNT(*)", false);
+            var columns = List.of(Map.of("name", "c", "type", "long"));
+            var values = List.of(List.of(remoteDocs.size()));
+
+            MapMatcher mapMatcher = matchesMap();
+            if (ccsMetadataAvailable()) {
+                mapMatcher = mapMatcher.entry("took", greaterThanOrEqualTo(0));
+            }
+            assertMap(result, mapMatcher.entry("columns", columns).entry("values", values));
+        }
+    }
+
     private RestClient remoteClusterClient() throws IOException {
         var clusterHosts = parseClusterHosts(remoteCluster.getHttpAddresses());
         return buildClient(restClientSettings(), clusterHosts.toArray(new HttpHost[0]));
     }
 
-    private TestFeatureService remoteFeaturesService() throws IOException {
-        if (remoteFeaturesService == null) {
-            try (RestClient remoteClient = remoteClusterClient()) {
-                var remoteNodeVersions = readVersionsFromNodesInfo(remoteClient);
-                var semanticNodeVersions = remoteNodeVersions.stream()
-                    .map(ESRestTestCase::parseLegacyVersion)
-                    .flatMap(Optional::stream)
-                    .collect(Collectors.toSet());
-                remoteFeaturesService = createTestFeatureService(getClusterStateFeatures(remoteClient), semanticNodeVersions);
-            }
-        }
-        return remoteFeaturesService;
+    private static boolean ccsMetadataAvailable() {
+        return Clusters.localClusterVersion().onOrAfter(Version.V_8_16_0);
+    }
+
+    private static boolean includeCCSMetadata() {
+        return ccsMetadataAvailable() && randomBoolean();
     }
 }
diff --git a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java
index 9a184b9a620fd..050259bbb5b5c 100644
--- a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java
+++ b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java
@@ -76,7 +76,6 @@ public void testBasicEsql() throws IOException {
         indexTimestampData(1);
 
         RequestObjectBuilder builder = requestObjectBuilder().query(fromIndex() + " | stats avg(value)");
-        requestObjectBuilder().includeCCSMetadata(randomBoolean());
         if (Build.current().isSnapshot()) {
             builder.pragmas(Settings.builder().put("data_partitioning", "shard").build());
         }
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec
index 7bbf011176693..b29c489910f65 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec
@@ -503,6 +503,27 @@ FROM employees
 //end::reuseGroupingFunctionWithExpression-result[]
 ;
 
+reuseGroupingFunctionImplicitAliasWithExpression#[skip:-8.13.99, reason:BUCKET renamed in 8.14]
+FROM employees
+| STATS s1 = `BUCKET(salary / 100 + 99, 50.)` + 1, s2 = BUCKET(salary / 1000 + 999, 50.) + 2 BY BUCKET(salary / 100 + 99, 50.), b2 = BUCKET(salary / 1000 + 999, 50.)
+| SORT `BUCKET(salary / 100 + 99, 50.)`, b2
+| KEEP s1, `BUCKET(salary / 100 + 99, 50.)`, s2, b2
+;
+
+ s1:double | BUCKET(salary / 100 + 99, 50.):double | s2:double | b2:double
+351.0 |350.0 |1002.0 |1000.0
+401.0 |400.0 |1002.0 |1000.0
+451.0 |450.0 |1002.0 |1000.0
+501.0 |500.0 |1002.0 |1000.0
+551.0 |550.0 |1002.0 |1000.0
+601.0 |600.0 |1002.0 |1000.0
+601.0 |600.0 |1052.0 |1050.0
+651.0 |650.0 |1052.0 |1050.0
+701.0 |700.0 |1052.0 |1050.0
+751.0 |750.0 |1052.0 |1050.0
+801.0 |800.0 |1052.0 |1050.0
+;
+
 reuseGroupingFunctionWithinAggs#[skip:-8.13.99, reason:BUCKET renamed in 8.14]
 FROM employees
 | STATS sum = 1 + MAX(1 + BUCKET(salary, 1000.)) BY BUCKET(salary, 1000.) + 1
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec
index e45b10d1aa122..804c1c56a1eb5 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec
@@ -1,5 +1,5 @@
 standard aggs
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS count=COUNT(),
@@ -17,7 +17,7 @@ count:long | sum:long | avg:double | count_distinct:long | category:keyw
 ;
 
 values aggs
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS values=MV_SORT(VALUES(message)),
@@ -33,7 +33,7 @@ values:keyword | top
 ;
 
 mv
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM mv_sample_data
 | STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(message)
@@ -48,7 +48,7 @@ COUNT():long | SUM(event_duration):long | category:keyword
 ;
 
 row mv
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = ["connected to a", "connected to b", "disconnected"], str = ["a", "b", "c"]
 | STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message)
@@ -61,7 +61,7 @@ COUNT():long | VALUES(str):keyword | category:keyword
 ;
 
 skips stopwords
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = ["Mon Tue connected to a", "Jul Aug connected to b September ", "UTC connected GMT to c UTC"]
 | STATS COUNT() BY category=CATEGORIZE(message)
@@ -73,7 +73,7 @@ COUNT():long | category:keyword
 ;
 
 with multiple indices
-required_capability: categorize_v4
+required_capability: categorize_v5
 required_capability: union_types
 
 FROM sample_data*
@@ -88,7 +88,7 @@ COUNT():long | category:keyword
 ;
 
 mv with many values
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM employees
 | STATS COUNT() BY category=CATEGORIZE(job_positions)
@@ -105,7 +105,7 @@ COUNT():long | category:keyword
 ;
 
 mv with many values and SUM
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM employees
 | STATS SUM(languages) BY category=CATEGORIZE(job_positions)
@@ -120,7 +120,7 @@ SUM(languages):long | category:keyword
 ;
 
 mv with many values and nulls and SUM
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM employees
 | STATS SUM(languages) BY category=CATEGORIZE(job_positions)
@@ -134,7 +134,7 @@ SUM(languages):long | category:keyword
 ;
 
 mv via eval
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | EVAL message = MV_APPEND(message, "Banana")
@@ -150,7 +150,7 @@ COUNT():long | category:keyword
 ;
 
 mv via eval const
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | EVAL message = ["Banana", "Bread"]
@@ -164,7 +164,7 @@ COUNT():long | category:keyword
 ;
 
 mv via eval const without aliases
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | EVAL message = ["Banana", "Bread"]
@@ -178,7 +178,7 @@ COUNT():long | CATEGORIZE(message):keyword
 ;
 
 mv const in parameter
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
@@ -191,7 +191,7 @@ COUNT():long | c:keyword
 ;
 
 agg alias shadowing
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS c = COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
@@ -206,7 +206,7 @@
 c:keyword
 ;
 
 chained aggregations using categorize
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS COUNT() BY category=CATEGORIZE(message)
@@ -221,7 +221,7 @@ COUNT():long | category:keyword
 ;
 
 stats without aggs
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS BY category=CATEGORIZE(message)
@@ -235,7 +235,7 @@ category:keyword
 ;
 
 text field
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM hosts
 | STATS COUNT() BY category=CATEGORIZE(host_group)
@@ -253,7 +253,7 @@ COUNT():long | category:keyword
 ;
 
 on TO_UPPER
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS COUNT() BY category=CATEGORIZE(TO_UPPER(message))
@@ -267,7 +267,7 @@ COUNT():long | category:keyword
 ;
 
 on CONCAT
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " banana"))
@@ -281,7 +281,7 @@ COUNT():long | category:keyword
 ;
 
 on CONCAT with unicode
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " 👍🏽😊"))
@@ -295,7 +295,7 @@ COUNT():long | category:keyword
 ;
 
 on REVERSE(CONCAT())
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS COUNT() BY category=CATEGORIZE(REVERSE(CONCAT(message, " 👍🏽😊")))
@@ -309,7 +309,7 @@ COUNT():long | category:keyword
 ;
 
 and then TO_LOWER
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS COUNT() BY category=CATEGORIZE(message)
@@ -324,7 +324,7 @@ COUNT():long | category:keyword
 ;
 
 on const empty string
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS COUNT() BY category=CATEGORIZE("")
@@ -336,7 +336,7 @@ COUNT():long | category:keyword
 ;
 
 on const empty string from eval
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | EVAL x = ""
@@ -349,7 +349,7 @@ COUNT():long | category:keyword
 ;
 
 on null
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | EVAL x = null
@@ -362,7 +362,7 @@ COUNT():long | SUM(event_duration):long | category:keyword
 ;
 
 on null string
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | EVAL x = null::string
@@ -375,7 +375,7 @@ COUNT():long | category:keyword
 ;
 
 filtering out all data
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | WHERE @timestamp < "2023-10-23T00:00:00Z"
@@ -387,7 +387,7 @@ COUNT():long | category:keyword
 ;
 
 filtering out all data with constant
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS COUNT() BY category=CATEGORIZE(message)
@@ -398,7 +398,7 @@ COUNT():long | category:keyword
 ;
 
 drop output columns
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS count=COUNT() BY category=CATEGORIZE(message)
@@ -413,7 +413,7 @@ x:integer
 ;
 
 category value processing
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = ["connected to a", "connected to b", "disconnected"]
 | STATS COUNT() BY category=CATEGORIZE(message)
@@ -427,7 +427,7 @@ COUNT():long | category:keyword
 ;
 
 row aliases
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = "connected to xyz"
 | EVAL x = message
@@ -441,7 +441,7 @@ COUNT():long | category:keyword | y:keyword
 ;
 
 from aliases
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | EVAL x = message
@@ -457,7 +457,7 @@ COUNT():long | category:keyword | y:keyword
 ;
 
 row aliases with keep
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = "connected to xyz"
 | EVAL x = message
@@ -473,7 +473,7 @@ COUNT():long | y:keyword
 ;
 
 from aliases with keep
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | EVAL x = message
@@ -491,7 +491,7 @@ COUNT():long | y:keyword
 ;
 
 row rename
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = "connected to xyz"
 | RENAME message as x
@@ -505,7 +505,7 @@ COUNT():long | y:keyword
 ;
 
 from rename
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | RENAME message as x
@@ -521,7 +521,7 @@ COUNT():long | y:keyword
 ;
 
 row drop
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = "connected to a"
 | STATS c = COUNT() BY category=CATEGORIZE(message)
@@ -534,7 +534,7 @@ c:long
 ;
 
 from drop
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
 | STATS c = COUNT() BY category=CATEGORIZE(message)
@@ -547,3 +547,48 @@ c:long
 3
 3
 ;
+
+categorize in aggs inside function
+required_capability: categorize_v5
+
+FROM sample_data
+  | STATS COUNT(), x = MV_APPEND(category, category) BY category=CATEGORIZE(message)
+  | SORT x
+  | KEEP `COUNT()`, x
+;
+
+COUNT():long | x:keyword
+           3 | [.*?Connected.+?to.*?,.*?Connected.+?to.*?]
+           3 | [.*?Connection.+?error.*?,.*?Connection.+?error.*?]
+           1 | [.*?Disconnected.*?,.*?Disconnected.*?]
+;
+
+categorize in aggs same as grouping inside function
+required_capability: categorize_v5
+
+FROM sample_data
+  | STATS COUNT(), x = MV_APPEND(CATEGORIZE(message), `CATEGORIZE(message)`) BY CATEGORIZE(message)
+  | SORT x
+  | KEEP `COUNT()`, x
+;
+
+COUNT():long | x:keyword
+           3 | [.*?Connected.+?to.*?,.*?Connected.+?to.*?]
+           3 | [.*?Connection.+?error.*?,.*?Connection.+?error.*?]
+           1 | [.*?Disconnected.*?,.*?Disconnected.*?]
+;
+
+categorize in aggs same as grouping inside function with explicit alias
+required_capability: categorize_v5
+
+FROM sample_data
+  | STATS COUNT(), x = MV_APPEND(CATEGORIZE(message), category) BY category=CATEGORIZE(message)
+  | SORT x
+  | KEEP `COUNT()`, x
+;
+
+COUNT():long | x:keyword
+           3 | [.*?Connected.+?to.*?,.*?Connected.+?to.*?]
+           3 | [.*?Connection.+?error.*?,.*?Connection.+?error.*?]
+           1 | [.*?Disconnected.*?,.*?Disconnected.*?]
+;
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/docs.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/docs.csv-spec
index 24baf1263d06a..aa89c775da4cf 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/docs.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/docs.csv-spec
@@ -678,7 +678,7 @@ Bangalore | 9 | 72
 ;
 
 docsCategorize
-required_capability: categorize_v4
+required_capability: categorize_v5
 // tag::docsCategorize[]
 FROM sample_data
 | STATS count=COUNT() BY category=CATEGORIZE(message)
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
index f61452f13fb53..6e0a55655ee1c 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
@@ -522,7 +522,7 @@ h:d | languages:i
 1.41 | null
 ;
 
-groupByAlias#[skip:-8.13.99,reason:muted, see https://github.com/elastic/elasticsearch/issues/117770]
+groupByAlias
 from employees | rename languages as l | keep l, height | stats m = min(height) by l | sort l;
 
 m:d | l:i
@@ -951,7 +951,7 @@ c:l
 49
 ;
 
-countFieldWithGrouping#[skip:-8.13.99,reason:muted, see https://github.com/elastic/elasticsearch/issues/117784]
+countFieldWithGrouping
 from employees | rename languages as l | where emp_no < 10050 | stats c = count(emp_no) by l | sort l;
 
 c:l | l:i
@@ -963,7 +963,7 @@ c:l | l:i
 10 | null
 ;
 
-countFieldWithAliasWithGrouping#[skip:-8.13.99,reason:muted, see https://github.com/elastic/elasticsearch/issues/117784]
+countFieldWithAliasWithGrouping
 from employees | rename languages as l | eval e = emp_no | where emp_no < 10050 | stats c = count(e) by l | sort l;
 
 c:l | l:i
@@ -982,7 +982,7 @@ c:l
 49
 ;
 
-countEvalExpWithGrouping#[skip:-8.13.99,reason:muted, see https://github.com/elastic/elasticsearch/issues/117784]
+countEvalExpWithGrouping
 from employees | rename languages as l | eval e = case(emp_no < 10050, emp_no, null) | stats c = count(e) by l | sort l;
 
 c:l | l:i
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/version.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/version.csv-spec
index a4f6bd554881c..eb0d6d75a7d07 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/version.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/version.csv-spec
@@ -159,7 +159,7 @@ id:i |name:s |version:v |o:v
 13 |lllll |null |null
 ;
 
-countVersion#[skip:-8.13.99,reason:muted, see https://github.com/elastic/elasticsearch/issues/117784]
+countVersion
 FROM apps | RENAME name AS k | STATS v = COUNT(version) BY k | SORT k;
 
 v:l | k:s
diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java
index 5ffc92636b272..0910e820c118a 100644
--- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java
+++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java
@@ -238,4 +238,41 @@ public void testSameRemoteClusters() throws Exception {
             }
         }
     }
+
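+    // The remote cluster is expected to perform the cluster-level (intermediate) reduction: one
+    // compute task whose single driver ends in an intermediate aggregation.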
1"); + request.pragmas(randomPragmas()); + ActionFuture requestFuture = client().execute(EsqlQueryAction.INSTANCE, request); + assertTrue(PauseFieldPlugin.startEmitting.await(30, TimeUnit.SECONDS)); + try { + assertBusy(() -> { + List clusterTasks = client(REMOTE_CLUSTER).admin() + .cluster() + .prepareListTasks() + .setActions(ComputeService.CLUSTER_ACTION_NAME) + .get() + .getTasks(); + assertThat(clusterTasks.size(), equalTo(1)); + List drivers = client(REMOTE_CLUSTER).admin() + .cluster() + .prepareListTasks() + .setTargetParentTaskId(clusterTasks.get(0).taskId()) + .setActions(DriverTaskRunner.ACTION_NAME) + .setDetailed(true) + .get() + .getTasks(); + assertThat(drivers.size(), equalTo(1)); + TaskInfo driver = drivers.get(0); + assertThat(driver.description(), equalTo(""" + \\_ExchangeSourceOperator[] + \\_AggregationOperator[mode = INTERMEDIATE, aggs = sum of longs] + \\_ExchangeSinkOperator""")); + }); + } finally { + PauseFieldPlugin.allowEmitting.countDown(); + } + requestFuture.actionGet(30, TimeUnit.SECONDS).close(); + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 4422c8280dd0f..9fad9123944ff 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -400,7 +400,7 @@ public enum Cap { /** * Supported the text categorization function "CATEGORIZE". */ - CATEGORIZE_V4(Build.current().isSnapshot()), + CATEGORIZE_V5, /** * QSTR function diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java index 5f8c011cff53a..49d8a5ee8caad 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; -import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.function.Function; @@ -63,12 +62,10 @@ import java.util.ArrayList; import java.util.BitSet; import java.util.Collection; -import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Locale; -import java.util.Map; import java.util.Set; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -364,35 +361,35 @@ private static void checkCategorizeGrouping(Aggregate agg, Set failures ); }); - // Forbid CATEGORIZE being used in the aggregations - agg.aggregates().forEach(a -> { - a.forEachDown( - Categorize.class, - categorize -> failures.add( - fail(categorize, "cannot use CATEGORIZE grouping function [{}] within the aggregations", categorize.sourceText()) + // Forbid CATEGORIZE being used in the aggregations, unless it appears as a grouping + agg.aggregates() + .forEach( + a -> a.forEachDown( + AggregateFunction.class, + aggregateFunction -> aggregateFunction.forEachDown( + Categorize.class, + categorize 
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java
index 5f8c011cff53a..49d8a5ee8caad 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java
@@ -20,7 +20,6 @@
 import org.elasticsearch.xpack.esql.core.expression.Expressions;
 import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
 import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
-import org.elasticsearch.xpack.esql.core.expression.NameId;
 import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
 import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
 import org.elasticsearch.xpack.esql.core.expression.function.Function;
@@ -63,12 +62,10 @@
 import java.util.ArrayList;
 import java.util.BitSet;
 import java.util.Collection;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Locale;
-import java.util.Map;
 import java.util.Set;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
@@ -364,35 +361,35 @@ private static void checkCategorizeGrouping(Aggregate agg, Set<Failure> failures
             );
         });
 
-        // Forbid CATEGORIZE being used in the aggregations
-        agg.aggregates().forEach(a -> {
-            a.forEachDown(
-                Categorize.class,
-                categorize -> failures.add(
-                    fail(categorize, "cannot use CATEGORIZE grouping function [{}] within the aggregations", categorize.sourceText())
+        // Forbid CATEGORIZE being used in the aggregations, unless it appears as a grouping
+        agg.aggregates()
+            .forEach(
+                a -> a.forEachDown(
+                    AggregateFunction.class,
+                    aggregateFunction -> aggregateFunction.forEachDown(
+                        Categorize.class,
+                        categorize -> failures.add(
+                            fail(categorize, "cannot use CATEGORIZE grouping function [{}] within an aggregation", categorize.sourceText())
+                        )
+                    )
                 )
             );
-        });
 
-        // Forbid CATEGORIZE being referenced in the aggregation functions
-        Map<NameId, Categorize> categorizeByAliasId = new HashMap<>();
+        // Forbid CATEGORIZE being referenced as a child of an aggregation function
+        AttributeMap<Categorize> categorizeByAttribute = new AttributeMap<>();
         agg.groupings().forEach(g -> {
             g.forEachDown(Alias.class, alias -> {
                 if (alias.child() instanceof Categorize categorize) {
-                    categorizeByAliasId.put(alias.id(), categorize);
+                    categorizeByAttribute.put(alias.toAttribute(), categorize);
                 }
             });
         });
         agg.aggregates()
             .forEach(a -> a.forEachDown(AggregateFunction.class, aggregate -> aggregate.forEachDown(Attribute.class, attribute -> {
-                var categorize = categorizeByAliasId.get(attribute.id());
+                var categorize = categorizeByAttribute.get(attribute);
                 if (categorize != null) {
                     failures.add(
-                        fail(
-                            attribute,
-                            "cannot reference CATEGORIZE grouping function [{}] within the aggregations",
-                            attribute.sourceText()
-                        )
+                        fail(attribute, "cannot reference CATEGORIZE grouping function [{}] within an aggregation", attribute.sourceText())
                     );
                 }
             })));
@@ -449,7 +446,7 @@ private static void checkInvalidNamedExpressionUsage(
         // check the bucketing function against the group
         else if (c instanceof GroupingFunction gf) {
             if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
-                failures.add(fail(gf, "can only use grouping function [{}] part of the BY clause", gf.sourceText()));
+                failures.add(fail(gf, "can only use grouping function [{}] as part of the BY clause", gf.sourceText()));
             }
         }
     });
@@ -466,7 +463,7 @@ else if (c instanceof GroupingFunction gf) {
             // optimizer will later unroll expressions with aggs and non-aggs with a grouping function into an EVAL, but that will no longer
             // be verified (by check above in checkAggregate()), so do it explicitly here
             if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
-                failures.add(fail(gf, "can only use grouping function [{}] part of the BY clause", gf.sourceText()));
+                failures.add(fail(gf, "can only use grouping function [{}] as part of the BY clause", gf.sourceText()));
             } else if (level == 0) {
                 addFailureOnGroupingUsedNakedInAggs(failures, gf, "function");
             }
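
Annotation: restated as queries, the rule the verifier now enforces (these mirror the VerifierTests changes later in this diff; query(...) and error(...) are the helpers used there, so this block is a usage illustration rather than standalone code):

// Accepted: an aggregation may consume the grouping result, by alias or by name.
query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)");

// Rejected: CATEGORIZE evaluated inside an aggregate function.
error("FROM test | STATS COUNT(CATEGORIZE(first_name)) BY CATEGORIZE(first_name)");
// -> 1:25: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] within an aggregation

// Rejected: the grouping alias referenced as an aggregate function argument.
error("FROM test | STATS COUNT(cat) BY cat = CATEGORIZE(first_name)");
// -> 1:25: cannot reference CATEGORIZE grouping function [cat] within an aggregation
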
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineProjections.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineProjections.java
index be7096538fb9a..957db4a7273e5 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineProjections.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineProjections.java
@@ -22,6 +22,7 @@
 import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
 
 import java.util.ArrayList;
+import java.util.LinkedHashSet;
 import java.util.List;
 
 public final class CombineProjections extends OptimizerRules.OptimizerRule<UnaryPlan> {
@@ -144,30 +145,31 @@ private static List<NamedExpression> combineUpperGroupingsAndLowerProjections(
         List<? extends NamedExpression> upperGroupings,
         List<? extends NamedExpression> lowerProjections
     ) {
+        assert upperGroupings.size() <= 1
+            || upperGroupings.stream().anyMatch(group -> group.anyMatch(expr -> expr instanceof Categorize)) == false
+            : "CombineProjections only tested with a single CATEGORIZE with no additional groups";
         // Collect the alias map for resolving the source (f1 = 1, f2 = f1, etc..)
-        AttributeMap<Expression> aliases = new AttributeMap<>();
+        AttributeMap<Attribute> aliases = new AttributeMap<>();
         for (NamedExpression ne : lowerProjections) {
-            // record the alias
-            aliases.put(ne.toAttribute(), Alias.unwrap(ne));
+            // Record the aliases.
+            // Projections are just aliases for attributes, so casting is safe.
+            aliases.put(ne.toAttribute(), (Attribute) Alias.unwrap(ne));
         }
-        // Replace any matching attribute directly with the aliased attribute from the projection.
-        AttributeSet seen = new AttributeSet();
-        List<NamedExpression> replaced = new ArrayList<>();
+
+        // Propagate any renames from the lower projection into the upper groupings.
+        // This can lead to duplicates: e.g.
+        // | EVAL x = y | STATS ... BY x, y
+        // All substitutions happen before; groupings must be attributes at this point except for CATEGORIZE which will be an alias like
+        // `c = CATEGORIZE(attribute)`.
+        // Therefore, it is correct to deduplicate based on simple equality (based on names) instead of name ids (Set vs. AttributeSet).
+        // TODO: The deduplication based on simple equality will be insufficient in case of multiple CATEGORIZEs, e.g. for
+        // `| EVAL x = y | STATS ... BY CATEGORIZE(x), CATEGORIZE(y)`. That will require semantic equality instead.
+        LinkedHashSet<NamedExpression> resolvedGroupings = new LinkedHashSet<>();
         for (NamedExpression ne : upperGroupings) {
-            // Duplicated attributes are ignored.
-            if (ne instanceof Attribute attribute) {
-                var newExpression = aliases.resolve(attribute, attribute);
-                if (newExpression instanceof Attribute newAttribute && seen.add(newAttribute) == false) {
-                    // Already seen, skip
-                    continue;
-                }
-                replaced.add(newExpression);
-            } else {
-                // For grouping functions, this will replace nested properties too
-                replaced.add(ne.transformUp(Attribute.class, a -> aliases.resolve(a, a)));
-            }
+            NamedExpression transformed = (NamedExpression) ne.transformUp(Attribute.class, a -> aliases.resolve(a, a));
+            resolvedGroupings.add(transformed);
         }
-        return replaced;
+        return new ArrayList<>(resolvedGroupings);
     }
 
     /**
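
Annotation: a self-contained sketch of the deduplication step above, with plain strings standing in for attributes. After renames from the lower projection are propagated, `| EVAL x = y | STATS ... BY x, y` yields the groupings [y, y]; a LinkedHashSet collapses the duplicate while keeping the original grouping order:

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;

final class GroupingDedupSketch {
    public static void main(String[] args) {
        Map<String, String> renames = Map.of("x", "y");   // lower projection: x = y
        List<String> upperGroupings = List.of("x", "y");  // STATS ... BY x, y

        LinkedHashSet<String> resolved = new LinkedHashSet<>();
        for (String g : upperGroupings) {
            resolved.add(renames.getOrDefault(g, g));     // propagate the rename
        }
        System.out.println(new ArrayList<>(resolved));    // prints [y]
    }
}
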
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateAggExpressionWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateAggExpressionWithEval.java
index 2361b46b2be6f..c36d4caf7f599 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateAggExpressionWithEval.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateAggExpressionWithEval.java
@@ -9,18 +9,21 @@
 import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.xpack.esql.core.expression.Alias;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.util.Holder;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
 import org.elasticsearch.xpack.esql.plan.logical.Eval;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.esql.plan.logical.Project;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -51,6 +54,16 @@ protected LogicalPlan rule(Aggregate aggregate) {
         AttributeMap<Expression> aliases = new AttributeMap<>();
         aggregate.forEachExpressionUp(Alias.class, a -> aliases.put(a.toAttribute(), a.child()));
 
+        // Build Categorize grouping functions map.
+        // Functions like BUCKET() shouldn't reach this point,
+        // as they are moved to an early EVAL by ReplaceAggregateNestedExpressionWithEval
+        Map<Categorize, Attribute> groupingAttributes = new HashMap<>();
+        aggregate.forEachExpressionUp(Alias.class, a -> {
+            if (a.child() instanceof Categorize groupingFunction) {
+                groupingAttributes.put(groupingFunction, a.toAttribute());
+            }
+        });
+
         // break down each aggregate into AggregateFunction and/or grouping key
         // preserve the projection at the end
         List<? extends NamedExpression> aggs = aggregate.aggregates();
@@ -109,6 +122,9 @@ protected LogicalPlan rule(Aggregate aggregate) {
                     return alias.toAttribute();
                 });
 
+                // replace grouping functions with their references
+                aggExpression = aggExpression.transformUp(Categorize.class, groupingAttributes::get);
+
                 Alias alias = as.replaceChild(aggExpression);
                 newEvals.add(alias);
                 newProjections.add(alias.toAttribute());
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java
index 985e68252a1f9..4dbc43454a023 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java
@@ -51,6 +51,7 @@ protected LogicalPlan rule(Aggregate aggregate) {
             // Exception: Categorize is internal to the aggregation and remains in the groupings. We move its child expression into an eval.
             if (g instanceof Alias as) {
                 if (as.child() instanceof Categorize cat) {
+                    // For Categorize grouping function, we only move the child expression into an eval
                     if (cat.field() instanceof Attribute == false) {
                         groupingChanged = true;
                         var fieldAs = new Alias(as.source(), as.name(), cat.field(), null, true);
@@ -59,7 +60,6 @@ protected LogicalPlan rule(Aggregate aggregate) {
                         evalNames.put(fieldAs.name(), fieldAttr);
                         Categorize replacement = cat.replaceChildren(List.of(fieldAttr));
                         newGroupings.set(i, as.replaceChild(replacement));
-                        groupingAttributes.put(cat, fieldAttr);
                     }
                 } else {
                     groupingChanged = true;
@@ -135,6 +135,10 @@ protected LogicalPlan rule(Aggregate aggregate) {
             });
             // replace any grouping functions with their references pointing to the added synthetic eval
             replaced = replaced.transformDown(GroupingFunction.class, gf -> {
+                // Categorize in aggs depends on the grouping result, not on an early eval
+                if (gf instanceof Categorize) {
+                    return gf;
+                }
                 aggsChanged.set(true);
                 // should never return null, as it's verified.
                 // but even if broken, the transform will fail safely; otoh, returning `gf` will fail later due to incorrect plan.
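
Annotation: a minimal, self-contained sketch (toy types, not the real ESQL expression tree) of the substitution the first rule above performs: every Categorize occurrence inside an aggregate expression is swapped for the attribute of the grouping alias that computes it, so the aggregation consumes the grouping result instead of re-running CATEGORIZE:

import java.util.List;
import java.util.Map;
import java.util.function.UnaryOperator;

final class CategorizeSubstitutionSketch {
    sealed interface Expr permits Categorize, Attribute, Call {}

    record Categorize(String field) implements Expr {}
    record Attribute(String name) implements Expr {}
    record Call(String function, List<Expr> args) implements Expr {}

    // A tiny bottom-up transform, standing in for Expression#transformUp.
    static Expr transformUp(Expr e, UnaryOperator<Expr> rule) {
        if (e instanceof Call c) {
            e = new Call(c.function(), c.args().stream().map(a -> transformUp(a, rule)).toList());
        }
        return rule.apply(e);
    }

    public static void main(String[] args) {
        // grouping: cat = CATEGORIZE(message); aggregate: MV_COUNT(CATEGORIZE(message))
        Map<Expr, Expr> groupingAttributes = Map.of(new Categorize("message"), new Attribute("cat"));
        Expr agg = new Call("MV_COUNT", List.of(new Categorize("message")));
        Expr rewritten = transformUp(agg, e -> groupingAttributes.getOrDefault(e, e));
        System.out.println(rewritten); // Call[function=MV_COUNT, args=[Attribute[name=cat]]]
    }
}
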
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
index 69e2d1c45aa3c..35aba7665ec87 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
@@ -120,10 +120,14 @@ public final PhysicalOperation groupingPhysicalOperation(
                  * - before stats (keep x = a | stats by x) which requires the partial input to use a's channel
                  * - after stats (stats by a | keep x = a) which causes the output layout to refer to the follow-up alias
                  */
+                // TODO: This is likely required only for pre-8.14 node compatibility; confirm and remove if possible.
+                // Since https://github.com/elastic/elasticsearch/pull/104958, it shouldn't be possible to have aliases in the aggregates
+                // which the groupings refer to. Except for `BY CATEGORIZE(field)`, which remains as alias in the grouping, all aliases
+                // should've become EVALs before or after the STATS.
                 for (NamedExpression agg : aggregates) {
                     if (agg instanceof Alias a) {
                         if (a.child() instanceof Attribute attr) {
-                            if (groupAttribute.id().equals(attr.id())) {
+                            if (sourceGroupAttribute.id().equals(attr.id())) {
                                 groupAttributeLayout.nameIds().add(a.id());
                                 // TODO: investigate whether a break could be used since it shouldn't be possible to have multiple
                                 // attributes pointing to the same attribute
@@ -133,8 +137,8 @@ public final PhysicalOperation groupingPhysicalOperation(
                             // is in the output form
                             // if the group points to an alias declared in the aggregate, use the alias child as source
                             else if (aggregatorMode.isOutputPartial()) {
-                                if (groupAttribute.semanticEquals(a.toAttribute())) {
-                                    groupAttribute = attr;
+                                if (sourceGroupAttribute.semanticEquals(a.toAttribute())) {
+                                    sourceGroupAttribute = attr;
                                     break;
                                 }
                             }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java
index c998af2215169..5e13825d91bda 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java
@@ -29,14 +29,8 @@
 import org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer;
 import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
 import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizer;
-import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
 import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
 import org.elasticsearch.xpack.esql.plan.logical.Filter;
-import org.elasticsearch.xpack.esql.plan.logical.Limit;
-import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
-import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
-import org.elasticsearch.xpack.esql.plan.logical.TopN;
-import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
 import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
 import org.elasticsearch.xpack.esql.plan.physical.EsSourceExec;
 import org.elasticsearch.xpack.esql.plan.physical.EstimatesRowSize;
@@ -44,10 +38,7 @@
 import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
 import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
 import org.elasticsearch.xpack.esql.plan.physical.FragmentExec;
-import org.elasticsearch.xpack.esql.plan.physical.LimitExec;
-import org.elasticsearch.xpack.esql.plan.physical.OrderExec;
 import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
-import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
 import org.elasticsearch.xpack.esql.planner.mapper.LocalMapper;
 import org.elasticsearch.xpack.esql.planner.mapper.Mapper;
 import org.elasticsearch.xpack.esql.session.Configuration;
@@ -83,29 +74,25 @@ public static Tuple<PhysicalPlan, PhysicalPlan> breakPlanBetweenCoordinatorAndDa
         return new Tuple<>(coordinatorPlan, dataNodePlan.get());
     }
 
-    public static PhysicalPlan dataNodeReductionPlan(LogicalPlan plan, PhysicalPlan unused) {
-        var pipelineBreakers = plan.collectFirstChildren(Mapper::isPipelineBreaker);
+    public static PhysicalPlan reductionPlan(PhysicalPlan plan) {
+        // find the logical fragment
+        var fragments = plan.collectFirstChildren(p -> p instanceof FragmentExec);
+        if (fragments.isEmpty()) {
+            return null;
+        }
+        final FragmentExec fragment = (FragmentExec) fragments.get(0);
 
-        if (pipelineBreakers.isEmpty() == false) {
-            UnaryPlan pipelineBreaker = (UnaryPlan) pipelineBreakers.get(0);
-            if (pipelineBreaker instanceof TopN) {
-                LocalMapper mapper = new LocalMapper();
-                var physicalPlan = EstimatesRowSize.estimateRowSize(0, mapper.map(plan));
-                return physicalPlan.collectFirstChildren(TopNExec.class::isInstance).get(0);
-            } else if (pipelineBreaker instanceof Limit limit) {
-                return new LimitExec(limit.source(), unused, limit.limit());
-            } else if (pipelineBreaker instanceof OrderBy order) {
-                return new OrderExec(order.source(), unused, order.order());
-            } else if (pipelineBreaker instanceof Aggregate) {
-                LocalMapper mapper = new LocalMapper();
-                var physicalPlan = EstimatesRowSize.estimateRowSize(0, mapper.map(plan));
-                var aggregate = (AggregateExec) physicalPlan.collectFirstChildren(AggregateExec.class::isInstance).get(0);
-                return aggregate.withMode(AggregatorMode.INITIAL);
-            } else {
-                throw new EsqlIllegalArgumentException("unsupported unary physical plan node [" + pipelineBreaker.nodeName() + "]");
-            }
+        final var pipelineBreakers = fragment.fragment().collectFirstChildren(Mapper::isPipelineBreaker);
+        if (pipelineBreakers.isEmpty()) {
+            return null;
+        }
+        final var pipelineBreaker = pipelineBreakers.get(0);
+        final LocalMapper mapper = new LocalMapper();
+        PhysicalPlan reducePlan = mapper.map(pipelineBreaker);
+        if (reducePlan instanceof AggregateExec agg) {
+            reducePlan = agg.withMode(AggregatorMode.INITIAL); // force to emit intermediate outputs
         }
-        return null;
+        return EstimatesRowSize.estimateRowSize(fragment.estimatedRowSize(), reducePlan);
     }
 
     /**
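
Annotation: tying this back to the testTasks assertion earlier in the diff. For `FROM *:test | STATS total=sum(const) | LIMIT 1` the first pipeline breaker in the fragment is the STATS, so the mapped reduce plan is an aggregation forced to INITIAL mode, and the reduce driver on the remote coordinator ends up with exactly the operator stack the test expects. A sketch of that data flow (a diagram, not an additional code change; the arrows are annotations):

// Reduce driver shape asserted by testTasks:
//
//   \_ExchangeSourceOperator[]                                       <- pages from the remote data nodes
//   \_AggregationOperator[mode = INTERMEDIATE, aggs = sum of longs]  <- the forced intermediate reduction
//   \_ExchangeSinkOperator                                           <- pages toward the querying cluster
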
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
index b06dd3cdb64d3..9aea1577a4137 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
@@ -60,12 +60,10 @@
 import org.elasticsearch.xpack.esql.action.EsqlQueryAction;
 import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction;
 import org.elasticsearch.xpack.esql.core.expression.Attribute;
-import org.elasticsearch.xpack.esql.core.util.Holder;
 import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
 import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
 import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
 import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
-import org.elasticsearch.xpack.esql.plan.physical.FragmentExec;
 import org.elasticsearch.xpack.esql.plan.physical.OutputExec;
 import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
 import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders;
@@ -780,35 +778,24 @@ private void runComputeOnDataNode(
         }
     }
 
+    private static PhysicalPlan reductionPlan(ExchangeSinkExec plan, boolean enable) {
+        PhysicalPlan reducePlan = new ExchangeSourceExec(plan.source(), plan.output(), plan.isIntermediateAgg());
+        if (enable) {
+            PhysicalPlan p = PlannerUtils.reductionPlan(plan);
+            if (p != null) {
+                reducePlan = p.replaceChildren(List.of(reducePlan));
+            }
+        }
+        return new ExchangeSinkExec(plan.source(), plan.output(), plan.isIntermediateAgg(), reducePlan);
+    }
+
     private class DataNodeRequestHandler implements TransportRequestHandler<DataNodeRequest> {
         @Override
         public void messageReceived(DataNodeRequest request, TransportChannel channel, Task task) {
             final ActionListener<ComputeResponse> listener = new ChannelActionListener<>(channel);
-            final ExchangeSinkExec reducePlan;
+            final PhysicalPlan reductionPlan;
             if (request.plan() instanceof ExchangeSinkExec plan) {
-                var fragments = plan.collectFirstChildren(FragmentExec.class::isInstance);
-                if (fragments.isEmpty()) {
-                    listener.onFailure(new IllegalStateException("expected a fragment plan for a remote compute; got " + request.plan()));
-                    return;
-                }
-                var localExchangeSource = new ExchangeSourceExec(plan.source(), plan.output(), plan.isIntermediateAgg());
-                Holder<PhysicalPlan> reducePlanHolder = new Holder<>();
-                if (request.pragmas().nodeLevelReduction()) {
-                    PhysicalPlan dataNodePlan = request.plan();
-                    request.plan()
-                        .forEachUp(
-                            FragmentExec.class,
-                            f -> { reducePlanHolder.set(PlannerUtils.dataNodeReductionPlan(f.fragment(), dataNodePlan)); }
-                        );
-                }
-                reducePlan = new ExchangeSinkExec(
-                    plan.source(),
-                    plan.output(),
-                    plan.isIntermediateAgg(),
-                    reducePlanHolder.get() != null
-                        ? reducePlanHolder.get().replaceChildren(List.of(localExchangeSource))
-                        : localExchangeSource
-                );
+                reductionPlan = reductionPlan(plan, request.pragmas().nodeLevelReduction());
             } else {
                 listener.onFailure(new IllegalStateException("expected exchange sink for a remote compute; got " + request.plan()));
                 return;
@@ -825,7 +812,7 @@ public void messageReceived(DataNodeRequest request, TransportChannel channel, T
                 request.indicesOptions()
             );
             try (var computeListener = ComputeListener.create(transportService, (CancellableTask) task, listener)) {
-                runComputeOnDataNode((CancellableTask) task, sessionId, reducePlan, request, computeListener);
+                runComputeOnDataNode((CancellableTask) task, sessionId, reductionPlan, request, computeListener);
             }
         }
     }
@@ -871,10 +858,10 @@ public void messageReceived(ClusterComputeRequest request, TransportChannel chan
      * Performs a compute on a remote cluster. The output pages are placed in an exchange sink specified by
      * {@code globalSessionId}. The coordinator on the main cluster will poll pages from there.
      *
-     * Currently, the coordinator on the remote cluster simply collects pages from data nodes in the remote cluster
-     * and places them in the exchange sink. We can achieve this by using a single exchange buffer to minimize overhead.
-     * However, here we use two exchange buffers so that we can run an actual plan on this coordinator to perform partial
-     * reduce operations, such as limit, topN, and partial-to-partial aggregation in the future.
+     * Currently, the coordinator on the remote cluster polls pages from data nodes within the remote cluster
+     * and performs cluster-level reduction before sending pages to the querying cluster. This reduction aims
+     * to minimize data transfers across clusters but may require additional CPU resources for operations like
+     * aggregations.
      */
     void runComputeOnRemoteCluster(
@@ -892,6 +879,7 @@ void runComputeOnRemoteCluster(
             () -> exchangeService.finishSinkHandler(globalSessionId, new TaskCancelledException(parentTask.getReasonCancelled()))
         );
         final String localSessionId = clusterAlias + ":" + globalSessionId;
+        final PhysicalPlan coordinatorPlan = reductionPlan(plan, true);
         var exchangeSource = new ExchangeSourceHandler(
             configuration.pragmas().exchangeBufferSize(),
             transportService.getThreadPool().executor(ThreadPool.Names.SEARCH),
@@ -899,12 +887,6 @@ void runComputeOnRemoteCluster(
         );
         try (Releasable ignored = exchangeSource.addEmptySink()) {
             exchangeSink.addCompletionListener(computeListener.acquireAvoid());
-            PhysicalPlan coordinatorPlan = new ExchangeSinkExec(
-                plan.source(),
-                plan.output(),
-                plan.isIntermediateAgg(),
-                new ExchangeSourceExec(plan.source(), plan.output(), plan.isIntermediateAgg())
-            );
             runCompute(
                 parentTask,
                 new ComputeContext(localSessionId, clusterAlias, List.of(), configuration, exchangeSource, exchangeSink),
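
Annotation: both call sites now funnel through the same helper: data nodes pass `request.pragmas().nodeLevelReduction()`, while the remote coordinator always enables reduction. A sketch of the resulting wiring, assuming a reduction plan was found (otherwise the sink wraps the source directly):

// ExchangeSinkExec                                   -- pages leave this node / cluster
//   \_ <reduce plan, e.g. AggregateExec[INITIAL]>    -- partial reduction (limit, topN, agg)
//       \_ ExchangeSourceExec                        -- pages arrive from shards or remote data nodes
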
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/RemoteClusterPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/RemoteClusterPlan.java
index 8564e4b3afde1..031bfd7139a84 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/RemoteClusterPlan.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/RemoteClusterPlan.java
@@ -9,12 +9,14 @@
 import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.OriginalIndices;
-import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
 
 import java.io.IOException;
+import java.util.Arrays;
+import java.util.Objects;
 
 record RemoteClusterPlan(PhysicalPlan plan, String[] targetIndices, OriginalIndices originalIndices) {
     static RemoteClusterPlan from(PlanStreamInput planIn) throws IOException {
@@ -24,7 +26,8 @@ static RemoteClusterPlan from(PlanStreamInput planIn) throws IOException {
         if (planIn.getTransportVersion().onOrAfter(TransportVersions.ESQL_ORIGINAL_INDICES)) {
             originalIndices = OriginalIndices.readOriginalIndices(planIn);
         } else {
-            originalIndices = new OriginalIndices(planIn.readStringArray(), IndicesOptions.strictSingleIndexNoExpandForbidClosed());
+            // fallback to the previous behavior
+            originalIndices = new OriginalIndices(planIn.readStringArray(), SearchRequest.DEFAULT_INDICES_OPTIONS);
         }
         return new RemoteClusterPlan(plan, targetIndices, originalIndices);
     }
@@ -38,4 +41,18 @@ public void writeTo(PlanStreamOutput out) throws IOException {
             out.writeStringArray(originalIndices.indices());
         }
     }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == null || getClass() != o.getClass()) return false;
+        RemoteClusterPlan that = (RemoteClusterPlan) o;
+        return Objects.equals(plan, that.plan)
+            && Objects.deepEquals(targetIndices, that.targetIndices)
+            && Objects.equals(originalIndices, that.originalIndices);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(plan, Arrays.hashCode(targetIndices), originalIndices);
+    }
 }
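
Annotation: records derive equals/hashCode from their components, but for an array component that means reference equality, which is why the explicit overrides above are needed before the request can be compared in the serialization tests added below. A self-contained illustration:

import java.util.Arrays;
import java.util.Objects;

final class ArrayComponentEquality {
    // Implicit equals: compares the array component by reference.
    record Derived(String name, String[] indices) {}

    // Explicit equals/hashCode: compares the array component by content.
    record Explicit(String name, String[] indices) {
        @Override
        public boolean equals(Object o) {
            if (o == null || getClass() != o.getClass()) return false;
            Explicit that = (Explicit) o;
            return Objects.equals(name, that.name) && Objects.deepEquals(indices, that.indices);
        }

        @Override
        public int hashCode() {
            return Objects.hash(name, Arrays.hashCode(indices));
        }
    }

    public static void main(String[] args) {
        System.out.println(new Derived("a", new String[] { "i" }).equals(new Derived("a", new String[] { "i" })));   // false
        System.out.println(new Explicit("a", new String[] { "i" }).equals(new Explicit("a", new String[] { "i" }))); // true
    }
}
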
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
index d02e78202e0c2..74e2de1141728 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
@@ -407,12 +407,12 @@ public void testAggFilterOnBucketingOrAggFunctions() {
 
         // but fails if it's different
         assertEquals(
-            "1:32: can only use grouping function [bucket(a, 3)] part of the BY clause",
+            "1:32: can only use grouping function [bucket(a, 3)] as part of the BY clause",
             error("row a = 1 | stats sum(a) where bucket(a, 3) > -1 by bucket(a,2)")
         );
 
         assertEquals(
-            "1:40: can only use grouping function [bucket(salary, 10)] part of the BY clause",
+            "1:40: can only use grouping function [bucket(salary, 10)] as part of the BY clause",
             error("from test | stats max(languages) WHERE bucket(salary, 10) > 1 by emp_no")
         );
@@ -444,19 +444,19 @@ public void testAggWithNonBooleanFilter() {
 
     public void testGroupingInsideAggsAsAgg() {
         assertEquals(
-            "1:18: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause",
+            "1:18: can only use grouping function [bucket(emp_no, 5.)] as part of the BY clause",
             error("from test| stats bucket(emp_no, 5.) by emp_no")
         );
         assertEquals(
-            "1:18: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause",
+            "1:18: can only use grouping function [bucket(emp_no, 5.)] as part of the BY clause",
             error("from test| stats bucket(emp_no, 5.)")
         );
         assertEquals(
-            "1:18: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause",
+            "1:18: can only use grouping function [bucket(emp_no, 5.)] as part of the BY clause",
             error("from test| stats bucket(emp_no, 5.) by bucket(emp_no, 6.)")
         );
         assertEquals(
-            "1:22: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause",
+            "1:22: can only use grouping function [bucket(emp_no, 5.)] as part of the BY clause",
             error("from test| stats 3 + bucket(emp_no, 5.) by bucket(emp_no, 6.)")
         );
     }
@@ -1846,7 +1846,7 @@ public void testIntervalAsString() {
     }
 
     public void testCategorizeSingleGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
 
         query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)");
         query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)");
@@ -1875,7 +1875,7 @@ public void testCategorizeSingleGrouping() {
     }
 
     public void testCategorizeNestedGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
 
         query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)");
 
@@ -1890,27 +1890,33 @@ public void testCategorizeNestedGrouping() {
     }
 
     public void testCategorizeWithinAggregations() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
 
         query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)");
+        query("from test | STATS MV_COUNT(CATEGORIZE(first_name)), COUNT(*) BY cat = CATEGORIZE(first_name)");
+        query("from test | STATS MV_COUNT(CATEGORIZE(first_name)), COUNT(*) BY CATEGORIZE(first_name)");
 
         assertEquals(
-            "1:25: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] within the aggregations",
+            "1:25: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] within an aggregation",
             error("FROM test | STATS COUNT(CATEGORIZE(first_name)) BY CATEGORIZE(first_name)")
         );
-
         assertEquals(
-            "1:25: cannot reference CATEGORIZE grouping function [cat] within the aggregations",
+            "1:25: cannot reference CATEGORIZE grouping function [cat] within an aggregation",
             error("FROM test | STATS COUNT(cat) BY cat = CATEGORIZE(first_name)")
         );
         assertEquals(
-            "1:30: cannot reference CATEGORIZE grouping function [cat] within the aggregations",
+            "1:30: cannot reference CATEGORIZE grouping function [cat] within an aggregation",
             error("FROM test | STATS SUM(LENGTH(cat::keyword) + LENGTH(last_name)) BY cat = CATEGORIZE(first_name)")
        );
         assertEquals(
-            "1:25: cannot reference CATEGORIZE grouping function [`CATEGORIZE(first_name)`] within the aggregations",
+            "1:25: cannot reference CATEGORIZE grouping function [`CATEGORIZE(first_name)`] within an aggregation",
             error("FROM test | STATS COUNT(`CATEGORIZE(first_name)`) BY CATEGORIZE(first_name)")
         );
+
+        assertEquals(
+            "1:28: can only use grouping function [CATEGORIZE(last_name)] as part of the BY clause",
+            error("FROM test | STATS MV_COUNT(CATEGORIZE(last_name)) BY CATEGORIZE(first_name)")
+        );
     }
 
     public void testSortByAggregate() {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
index ec02995978d97..4c2ad531f3f1c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
@@ -1212,12 +1212,12 @@ public void testCombineProjectionWithAggregationFirstAndAliasedGroupingUsedInAgg
      * \_EsRelation[test][_meta_field{f}#23, emp_no{f}#17, first_name{f}#18, ..]
      */
     public void testCombineProjectionWithCategorizeGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
 
         var plan = plan("""
             from test
             | eval k = first_name, k1 = k
-            | stats s = sum(salary) by cat = CATEGORIZE(k)
+            | stats s = sum(salary) by cat = CATEGORIZE(k1)
             | keep s, cat
             """);
 
@@ -3949,7 +3949,7 @@ public void testNestedExpressionsInGroups() {
      * \_EsRelation[test][_meta_field{f}#14, emp_no{f}#8, first_name{f}#9, ge..]
      */
     public void testNestedExpressionsInGroupsWithCategorize() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
 
         var plan = optimizedPlan("""
             from test
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java
new file mode 100644
index 0000000000000..07ca112e8c527
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java
@@ -0,0 +1,206 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plugin;
+
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.action.OriginalIndices;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.IndexMode;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.test.TransportVersionUtils;
+import org.elasticsearch.xpack.esql.ConfigurationTestUtils;
+import org.elasticsearch.xpack.esql.EsqlTestUtils;
+import org.elasticsearch.xpack.esql.analysis.Analyzer;
+import org.elasticsearch.xpack.esql.analysis.AnalyzerContext;
+import org.elasticsearch.xpack.esql.core.type.EsField;
+import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
+import org.elasticsearch.xpack.esql.index.EsIndex;
+import org.elasticsearch.xpack.esql.index.IndexResolution;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
+import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
+import org.elasticsearch.xpack.esql.parser.EsqlParser;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration;
+import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
+import static org.hamcrest.Matchers.equalTo;
+
+public class ClusterRequestTests extends AbstractWireSerializingTestCase<ClusterComputeRequest> {
+
+    @Override
+    protected Writeable.Reader<ClusterComputeRequest> instanceReader() {
+        return ClusterComputeRequest::new;
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> writeables = new ArrayList<>();
+        writeables.addAll(new SearchModule(Settings.EMPTY, List.of()).getNamedWriteables());
+        writeables.addAll(new EsqlPlugin().getNamedWriteables());
+        return new NamedWriteableRegistry(writeables);
+    }
+
+    @Override
+    protected ClusterComputeRequest createTestInstance() {
+        var sessionId = randomAlphaOfLength(10);
+        String query = randomQuery();
+        PhysicalPlan physicalPlan = DataNodeRequestTests.mapAndMaybeOptimize(parse(query));
+        OriginalIndices originalIndices = new OriginalIndices(
+            generateRandomStringArray(10, 10, false, false),
+            IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean())
+        );
+        String[] targetIndices = generateRandomStringArray(10, 10, false, false);
+        ClusterComputeRequest request = new ClusterComputeRequest(
+            randomAlphaOfLength(10),
+            sessionId,
+            randomConfiguration(query, randomTables()),
+            new RemoteClusterPlan(physicalPlan, targetIndices, originalIndices)
+        );
+        request.setParentTask(randomAlphaOfLength(10), randomNonNegativeLong());
+        return request;
+    }
+
+    @Override
+    protected ClusterComputeRequest mutateInstance(ClusterComputeRequest in) throws IOException {
+        return switch (between(0, 4)) {
+            case 0 -> {
+                var request = new ClusterComputeRequest(
+                    randomValueOtherThan(in.clusterAlias(), () -> randomAlphaOfLength(10)),
+                    in.sessionId(),
+                    in.configuration(),
+                    in.remoteClusterPlan()
+                );
+                request.setParentTask(in.getParentTask());
+                yield request;
+            }
+            case 1 -> {
+                var request = new ClusterComputeRequest(
+                    in.clusterAlias(),
+                    randomValueOtherThan(in.sessionId(), () -> randomAlphaOfLength(10)),
+                    in.configuration(),
+                    in.remoteClusterPlan()
+                );
+                request.setParentTask(in.getParentTask());
+                yield request;
+            }
+            case 2 -> {
+                var request = new ClusterComputeRequest(
+                    in.clusterAlias(),
+                    in.sessionId(),
+                    randomValueOtherThan(in.configuration(), ConfigurationTestUtils::randomConfiguration),
+                    in.remoteClusterPlan()
+                );
+                request.setParentTask(in.getParentTask());
+                yield request;
+            }
+            case 3 -> {
+                RemoteClusterPlan plan = in.remoteClusterPlan();
+                var request = new ClusterComputeRequest(
+                    in.clusterAlias(),
+                    in.sessionId(),
+                    in.configuration(),
+                    new RemoteClusterPlan(
+                        plan.plan(),
+                        randomValueOtherThan(plan.targetIndices(), () -> generateRandomStringArray(10, 10, false, false)),
+                        plan.originalIndices()
+                    )
+                );
+                request.setParentTask(in.getParentTask());
+                yield request;
+            }
+            case 4 -> {
+                RemoteClusterPlan plan = in.remoteClusterPlan();
+                var request = new ClusterComputeRequest(
+                    in.clusterAlias(),
+                    in.sessionId(),
+                    in.configuration(),
+                    new RemoteClusterPlan(
+                        plan.plan(),
+                        plan.targetIndices(),
+                        new OriginalIndices(
+                            plan.originalIndices().indices(),
+                            randomValueOtherThan(
+                                plan.originalIndices().indicesOptions(),
+                                () -> IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean())
+                            )
+                        )
+                    )
+                );
+                request.setParentTask(in.getParentTask());
+                yield request;
+            }
+            default -> throw new AssertionError("invalid value");
+        };
+    }
+
+    public void testFallbackIndicesOptions() throws Exception {
+        ClusterComputeRequest request = createTestInstance();
+        var version = TransportVersionUtils.randomVersionBetween(
+            random(),
+            TransportVersions.V_8_14_0,
+            TransportVersions.ESQL_ORIGINAL_INDICES
+        );
+        ClusterComputeRequest cloned = copyInstance(request, version);
+        assertThat(cloned.clusterAlias(), equalTo(request.clusterAlias()));
+        assertThat(cloned.sessionId(), equalTo(request.sessionId()));
+        assertThat(cloned.configuration(), equalTo(request.configuration()));
+        RemoteClusterPlan plan = cloned.remoteClusterPlan();
+        assertThat(plan.plan(), equalTo(request.remoteClusterPlan().plan()));
+        assertThat(plan.targetIndices(), equalTo(request.remoteClusterPlan().targetIndices()));
+        OriginalIndices originalIndices = plan.originalIndices();
+        assertThat(originalIndices.indices(), equalTo(request.remoteClusterPlan().originalIndices().indices()));
+        assertThat(originalIndices.indicesOptions(), equalTo(SearchRequest.DEFAULT_INDICES_OPTIONS));
+    }
+
+    private static String randomQuery() {
+        return randomFrom("""
+            from test
+            | where round(emp_no) > 10
+            | limit 10
+            """, """
+            from test
+            | sort last_name
+            | limit 10
+            | where round(emp_no) > 10
+            | eval c = first_name
+            """);
+    }
+
+    static LogicalPlan parse(String query) {
+        Map<String, EsField> mapping = loadMapping("mapping-basic.json");
+        EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD));
+        IndexResolution getIndexResult = IndexResolution.valid(test);
+        var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG));
+        var analyzer = new Analyzer(
+            new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()),
+            TEST_VERIFIER
+        );
+        return logicalOptimizer.optimize(analyzer.analyze(new EsqlParser().createStatement(query)));
+    }
+
+    @Override
+    protected List<String> filteredWarnings() {
+        return withDefaultLimitWarning(super.filteredWarnings());
+    }
+}
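
Annotation: mutateInstance drives the base test's equality checks: each case flips exactly one component, so a missing equals/hashCode on a nested type (here, RemoteClusterPlan with its String[] field) surfaces immediately as a spurious "equal after mutation" failure. A minimal, self-contained sketch of the pattern with toy types:

import java.util.Random;

final class MutateOneFieldSketch {
    record Request(String clusterAlias, String sessionId) {}

    // Each branch changes exactly one component, mirroring the switch above.
    static Request mutate(Request in, Random random) {
        return switch (random.nextInt(2)) {
            case 0 -> new Request(in.clusterAlias() + "-x", in.sessionId());
            case 1 -> new Request(in.clusterAlias(), in.sessionId() + "-x");
            default -> throw new AssertionError("invalid value");
        };
    }

    public static void main(String[] args) {
        Request base = new Request("remote", "session-1");
        System.out.println(base.equals(mutate(base, new Random()))); // always false: one field differs
    }
}
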
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
index 587da844fceda..e28f5ce9a76a6 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
@@ -48,7 +48,6 @@
 import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.index.analysis.CharFilterFactory;
 import org.elasticsearch.index.analysis.TokenizerFactory;
-import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.indices.AssociatedIndexDescriptor;
 import org.elasticsearch.indices.SystemIndexDescriptor;
 import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider;
@@ -376,8 +375,6 @@
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.process.NativeController;
 import org.elasticsearch.xpack.ml.process.NativeStorageProvider;
-import org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder;
-import org.elasticsearch.xpack.ml.queries.TextExpansionQueryBuilder;
 import org.elasticsearch.xpack.ml.rest.RestDeleteExpiredDataAction;
 import org.elasticsearch.xpack.ml.rest.RestMlInfoAction;
 import org.elasticsearch.xpack.ml.rest.RestMlMemoryAction;
@@ -1764,22 +1761,6 @@ public List<QueryVectorBuilderSpec<?>> getQueryVectorBuilders() {
         );
     }
 
-    @Override
-    public List<QuerySpec<?>> getQueries() {
-        return List.of(
-            new QuerySpec<>(
-                TextExpansionQueryBuilder.NAME,
-                TextExpansionQueryBuilder::new,
-                TextExpansionQueryBuilder::fromXContent
-            ),
-            new QuerySpec<>(
-                SparseVectorQueryBuilder.NAME,
-                SparseVectorQueryBuilder::new,
-                SparseVectorQueryBuilder::fromXContent
-            )
-        );
-    }
-
     private <T> ContextParser<String, T> checkAggLicense(ContextParser<String, T> realParser, LicensedFeature.Momentary feature) {
         return (parser, name) -> {
             if (feature.check(getLicenseState()) == false) {
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSetUpgradeModeAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSetUpgradeModeAction.java
index 744d5dbd6974f..5912619e892ed 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSetUpgradeModeAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSetUpgradeModeAction.java
@@ -9,35 +9,27 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchException;
-import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.ElasticsearchTimeoutException;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
-import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.client.internal.OriginSettingClient;
-import org.elasticsearch.cluster.AckedClusterStateUpdateTask;
 import org.elasticsearch.cluster.ClusterState;
-import org.elasticsearch.cluster.ClusterStateUpdateTask;
-import org.elasticsearch.cluster.block.ClusterBlockException;
-import org.elasticsearch.cluster.block.ClusterBlockLevel;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.core.Predicates;
-import org.elasticsearch.core.SuppressForbidden;
 import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.persistent.PersistentTasksClusterService;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata.PersistentTask;
 import org.elasticsearch.persistent.PersistentTasksService;
-import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.action.AbstractTransportSetUpgradeModeAction;
+import org.elasticsearch.xpack.core.action.SetUpgradeModeActionRequest;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
@@ -48,7 +40,6 @@
 import java.util.Comparator;
 import java.util.List;
 import java.util.Set;
-import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.ExceptionsHelper.rethrowAndSuppress;
@@ -58,12 +49,11 @@
 import static org.elasticsearch.xpack.core.ml.MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME;
 import static org.elasticsearch.xpack.core.ml.MlTasks.JOB_TASK_NAME;
 
-public class TransportSetUpgradeModeAction extends AcknowledgedTransportMasterNodeAction<SetUpgradeModeAction.Request> {
+public class TransportSetUpgradeModeAction extends AbstractTransportSetUpgradeModeAction {
 
     private static final Set<String> ML_TASK_NAMES = Set.of(JOB_TASK_NAME, DATAFEED_TASK_NAME, DATA_FRAME_ANALYTICS_TASK_NAME);
 
     private static final Logger logger = LogManager.getLogger(TransportSetUpgradeModeAction.class);
 
-    private final AtomicBoolean isRunning = new AtomicBoolean(false);
     private final PersistentTasksClusterService persistentTasksClusterService;
     private final PersistentTasksService persistentTasksService;
     private final OriginSettingClient client;
@@ -79,69 +69,38 @@ public TransportSetUpgradeModeAction(
         Client client,
         PersistentTasksService persistentTasksService
     ) {
-        super(
-            SetUpgradeModeAction.NAME,
-            transportService,
-            clusterService,
-            threadPool,
-            actionFilters,
-            SetUpgradeModeAction.Request::new,
-            indexNameExpressionResolver,
-            EsExecutors.DIRECT_EXECUTOR_SERVICE
-        );
+        super(SetUpgradeModeAction.NAME, "ml", transportService, clusterService, threadPool, actionFilters, indexNameExpressionResolver);
         this.persistentTasksClusterService = persistentTasksClusterService;
         this.client = new OriginSettingClient(client, ML_ORIGIN);
         this.persistentTasksService = persistentTasksService;
     }
 
     @Override
-    protected void masterOperation(
-        Task task,
-        SetUpgradeModeAction.Request request,
-        ClusterState state,
-        ActionListener<AcknowledgedResponse> listener
-    ) throws Exception {
-
-        // Don't want folks spamming this endpoint while it is in progress, only allow one request to be handled at a time
-        if (isRunning.compareAndSet(false, true) == false) {
-            String msg = "Attempted to set [upgrade_mode] to ["
-                + request.isEnabled()
-                + "] from ["
-                + MlMetadata.getMlMetadata(state).isUpgradeMode()
-                + "] while previous request was processing.";
-            logger.info(msg);
-            Exception detail = new IllegalStateException(msg);
-            listener.onFailure(
-                new ElasticsearchStatusException(
-                    "Cannot change [upgrade_mode]. Previous request is still being processed.",
-                    RestStatus.TOO_MANY_REQUESTS,
-                    detail
-                )
-            );
-            return;
-        }
+    protected String featureName() {
+        return "ml-set-upgrade-mode";
+    }
 
-        // Noop, nothing for us to do, simply return fast to the caller
-        if (request.isEnabled() == MlMetadata.getMlMetadata(state).isUpgradeMode()) {
-            logger.info("Upgrade mode noop");
-            isRunning.set(false);
-            listener.onResponse(AcknowledgedResponse.TRUE);
-            return;
-        }
+    @Override
+    protected boolean upgradeMode(ClusterState state) {
+        return MlMetadata.getMlMetadata(state).isUpgradeMode();
     }
 
-        logger.info(
-            "Starting to set [upgrade_mode] to [" + request.isEnabled() + "] from [" + MlMetadata.getMlMetadata(state).isUpgradeMode() + "]"
-        );
+    @Override
+    protected ClusterState createUpdatedState(SetUpgradeModeActionRequest request, ClusterState currentState) {
+        logger.trace("Executing cluster state update");
+        MlMetadata.Builder builder = new MlMetadata.Builder(currentState.metadata().custom(MlMetadata.TYPE));
+        builder.isUpgradeMode(request.enabled());
+        ClusterState.Builder newState = ClusterState.builder(currentState);
+        newState.metadata(Metadata.builder(currentState.getMetadata()).putCustom(MlMetadata.TYPE, builder.build()).build());
+        return newState.build();
+    }
 
-        ActionListener<AcknowledgedResponse> wrappedListener = ActionListener.wrap(r -> {
-            logger.info("Completed upgrade mode request");
-            isRunning.set(false);
-            listener.onResponse(r);
-        }, e -> {
-            logger.info("Completed upgrade mode request but with failure", e);
-            isRunning.set(false);
-            listener.onFailure(e);
-        });
+    protected void upgradeModeSuccessfullyChanged(
+        Task task,
+        SetUpgradeModeActionRequest request,
+        ClusterState state,
+        ActionListener<AcknowledgedResponse> wrappedListener
+    ) {
         final PersistentTasksCustomMetadata tasksCustomMetadata = state.metadata().custom(PersistentTasksCustomMetadata.TYPE);
 
         // <4> We have unassigned the tasks, respond to the listener.
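
Annotation: the overrides above imply a template-method base class. Its source is not part of this diff, so the following is a hedged sketch of the shape suggested by the subclass: the base presumably owns the request dedup, no-op short-circuit, and cluster-state submission that the removed masterOperation used to do inline, and calls back into the subclass at these points:

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.core.action.SetUpgradeModeActionRequest;

// Illustrative shape only; the real AbstractTransportSetUpgradeModeAction lives in x-pack core.
public abstract class AbstractTransportSetUpgradeModeActionSketch {
    protected abstract String featureName();                    // source/dedup key for the state update task

    protected abstract boolean upgradeMode(ClusterState state); // read the current flag (for the no-op check)

    protected abstract ClusterState createUpdatedState(SetUpgradeModeActionRequest request, ClusterState currentState);

    // Invoked once the cluster state change is acknowledged.
    protected abstract void upgradeModeSuccessfullyChanged(
        Task task,
        SetUpgradeModeActionRequest request,
        ClusterState state,
        ActionListener<AcknowledgedResponse> listener
    );
}
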
@@ -201,71 +160,29 @@ protected void masterOperation(
          */
 
-        ActionListener<AcknowledgedResponse> clusterStateUpdateListener = ActionListener.wrap(acknowledgedResponse -> {
-            // State change was not acknowledged, we either timed out or ran into some exception
-            // We should not continue and alert failure to the end user
-            if (acknowledgedResponse.isAcknowledged() == false) {
-                logger.info("Cluster state update is NOT acknowledged");
-                wrappedListener.onFailure(new ElasticsearchTimeoutException("Unknown error occurred while updating cluster state"));
-                return;
-            }
-
-            // There are no tasks to worry about starting/stopping
-            if (tasksCustomMetadata == null || tasksCustomMetadata.tasks().isEmpty()) {
-                logger.info("No tasks to worry about after state update");
-                wrappedListener.onResponse(AcknowledgedResponse.TRUE);
-                return;
-            }
-
-            // Did we change from disabled -> enabled?
-            if (request.isEnabled()) {
-                logger.info("Enabling upgrade mode, must isolate datafeeds");
-                isolateDatafeeds(tasksCustomMetadata, isolateDatafeedListener);
-            } else {
-                logger.info("Disabling upgrade mode, must wait for tasks to not have AWAITING_UPGRADE assignment");
-                persistentTasksService.waitForPersistentTasksCondition(
-                    // Wait for jobs, datafeeds and analytics not to be "Awaiting upgrade"
-                    persistentTasksCustomMetadata -> persistentTasksCustomMetadata.tasks()
-                        .stream()
-                        .noneMatch(t -> ML_TASK_NAMES.contains(t.getTaskName()) && t.getAssignment().equals(AWAITING_UPGRADE)),
-                    request.ackTimeout(),
-                    ActionListener.wrap(r -> {
-                        logger.info("Done waiting for tasks to be out of AWAITING_UPGRADE");
-                        wrappedListener.onResponse(AcknowledgedResponse.TRUE);
-                    }, wrappedListener::onFailure)
-                );
-            }
-        }, wrappedListener::onFailure);
-
-        // <1> Change MlMetadata to indicate that upgrade_mode is now enabled
-        submitUnbatchedTask("ml-set-upgrade-mode", new AckedClusterStateUpdateTask(request, clusterStateUpdateListener) {
-
-            @Override
-            protected AcknowledgedResponse newResponse(boolean acknowledged) {
-                logger.trace("Cluster update response built: " + acknowledged);
-                return AcknowledgedResponse.of(acknowledged);
-            }
-
-            @Override
-            public ClusterState execute(ClusterState currentState) throws Exception {
-                logger.trace("Executing cluster state update");
-                MlMetadata.Builder builder = new MlMetadata.Builder(currentState.metadata().custom(MlMetadata.TYPE));
-                builder.isUpgradeMode(request.isEnabled());
-                ClusterState.Builder newState = ClusterState.builder(currentState);
-                newState.metadata(Metadata.builder(currentState.getMetadata()).putCustom(MlMetadata.TYPE, builder.build()).build());
-                return newState.build();
-            }
-        });
-    }
-
-    @SuppressForbidden(reason = "legacy usage of unbatched task") // TODO add support for batching here
-    private void submitUnbatchedTask(@SuppressWarnings("SameParameterValue") String source, ClusterStateUpdateTask task) {
-        clusterService.submitUnbatchedStateUpdateTask(source, task);
-    }
+        if (tasksCustomMetadata == null || tasksCustomMetadata.tasks().isEmpty()) {
+            logger.info("No tasks to worry about after state update");
+            wrappedListener.onResponse(AcknowledgedResponse.TRUE);
+            return;
+        }
 
-    @Override
-    protected ClusterBlockException checkBlock(SetUpgradeModeAction.Request request, ClusterState state) {
-        return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
+        if (request.enabled()) {
+            logger.info("Enabling upgrade mode, must isolate datafeeds");
+            isolateDatafeeds(tasksCustomMetadata, isolateDatafeedListener);
+        } else {
+            logger.info("Disabling upgrade mode, must wait for tasks to not have AWAITING_UPGRADE assignment");
+            persistentTasksService.waitForPersistentTasksCondition(
+                // Wait for jobs, datafeeds and analytics not to be "Awaiting upgrade"
+                persistentTasksCustomMetadata -> persistentTasksCustomMetadata.tasks()
+                    .stream()
+                    .noneMatch(t -> ML_TASK_NAMES.contains(t.getTaskName()) && t.getAssignment().equals(AWAITING_UPGRADE)),
+                request.ackTimeout(),
+                ActionListener.wrap(r -> {
+                    logger.info("Done waiting for tasks to be out of AWAITING_UPGRADE");
+                    wrappedListener.onResponse(AcknowledgedResponse.TRUE);
+                }, wrappedListener::onFailure)
+            );
+        }
     }
 
     /**