diff --git a/build-tools/src/main/java/org/elasticsearch/gradle/testclusters/ElasticsearchCluster.java b/build-tools/src/main/java/org/elasticsearch/gradle/testclusters/ElasticsearchCluster.java index bf539efaf3c30..54962ac241f75 100644 --- a/build-tools/src/main/java/org/elasticsearch/gradle/testclusters/ElasticsearchCluster.java +++ b/build-tools/src/main/java/org/elasticsearch/gradle/testclusters/ElasticsearchCluster.java @@ -433,7 +433,7 @@ private void commonNodeConfig() { if (node.getTestDistribution().equals(TestDistribution.INTEG_TEST)) { node.defaultConfig.put("xpack.security.enabled", "false"); } else { - if (node.getVersion().onOrAfter("7.16.0")) { + if (hasDeprecationIndexing(node)) { node.defaultConfig.put("cluster.deprecation_indexing.enabled", "false"); } } @@ -474,13 +474,17 @@ public void nextNodeToNextVersion() { commonNodeConfig(); nodeIndex += 1; if (node.getTestDistribution().equals(TestDistribution.DEFAULT)) { - if (node.getVersion().onOrAfter("7.16.0")) { + if (hasDeprecationIndexing(node)) { node.setting("cluster.deprecation_indexing.enabled", "false"); } } node.start(); } + private static boolean hasDeprecationIndexing(ElasticsearchNode node) { + return node.getVersion().onOrAfter("7.16.0") && node.getSettingKeys().contains("stateless.enabled") == false; + } + @Override public void extraConfigFile(String destination, File from) { nodes.all(node -> node.extraConfigFile(destination, from)); diff --git a/docs/reference/mapping/params/format.asciidoc b/docs/reference/mapping/params/format.asciidoc index dff7bb4a11ee4..5babb4def2320 100644 --- a/docs/reference/mapping/params/format.asciidoc +++ b/docs/reference/mapping/params/format.asciidoc @@ -70,6 +70,11 @@ The following tables lists all the defaults ISO formats supported: (separated by `T`), is optional. Examples: `yyyy-MM-dd'T'HH:mm:ss.SSSZ` or `yyyy-MM-dd`. + NOTE: When using `date_optional_time`, parsing is lenient and will attempt to parse + numbers as a year (e.g. `292278994` will be parsed as a year). This can lead to unexpected results + when paired with a numeric-focused format such as `epoch_second` or `epoch_millis`. + It is recommended that you use `strict_date_optional_time` when pairing with a numeric-focused format.
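As an illustrative sketch of the recommendation in the note above (not part of this patch; the index name is hypothetical), a date field can pair the strict ISO format with an epoch format so that purely numeric values are handled by `epoch_millis` rather than being leniently parsed as a year:

[source,console]
----
PUT my-index-000001
{
  "mappings": {
    "properties": {
      "date": {
        "type": "date",
        "format": "strict_date_optional_time||epoch_millis"
      }
    }
  }
}
----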
+ [[strict-date-time-nanos]]`strict_date_optional_time_nanos`:: A generic ISO datetime parser, where the date must include the year at a minimum, and the time diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/SourceConfirmedTextQuery.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/SourceConfirmedTextQuery.java index dc51afe5d420d..3d0f26e8cc130 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/SourceConfirmedTextQuery.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/SourceConfirmedTextQuery.java @@ -9,9 +9,7 @@ package org.elasticsearch.index.mapper.extras; import org.apache.lucene.analysis.Analyzer; -import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInvertState; -import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; @@ -300,19 +298,23 @@ public RuntimePhraseScorer scorer(LeafReaderContext context) throws IOException @Override public Matches matches(LeafReaderContext context, int doc) throws IOException { - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - if (fi == null) { + var terms = context.reader().terms(field); + if (terms == null) { return null; } - // Some highlighters will already have reindexed the source with positions and offsets, + // Some highlighters will already have re-indexed the source with positions and offsets, // so rather than doing it again we check to see if this data is available on the // current context and if so delegate directly to the inner query - if (fi.getIndexOptions().compareTo(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS) > 0) { + if (terms.hasOffsets()) { Weight innerWeight = in.createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, 1); return innerWeight.matches(context, doc); } RuntimePhraseScorer scorer = scorer(context); - if (scorer == null || scorer.iterator().advance(doc) != doc) { + if (scorer == null) { + return null; + } + final TwoPhaseIterator twoPhase = scorer.twoPhaseIterator(); + if (twoPhase.approximation().advance(doc) != doc || scorer.twoPhaseIterator().matches() == false) { return null; } return scorer.matches(); @@ -321,13 +323,14 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { } private class RuntimePhraseScorer extends Scorer { - private final LeafSimScorer scorer; private final CheckedIntFunction, IOException> valueFetcher; private final String field; private final Query query; private final TwoPhaseIterator twoPhase; + private final MemoryIndexEntry cacheEntry = new MemoryIndexEntry(); + private int doc = -1; private float freq; @@ -357,7 +360,6 @@ public float matchCost() { // Defaults to a high-ish value so that it likely runs last. 
return 10_000f; } - }; } @@ -394,35 +396,35 @@ private float freq() throws IOException { return freq; } - private float computeFreq() throws IOException { - MemoryIndex index = new MemoryIndex(); - index.setSimilarity(FREQ_SIMILARITY); - List values = valueFetcher.apply(docID()); - float frequency = 0; - for (Object value : values) { - if (value == null) { - continue; + private MemoryIndex getOrCreateMemoryIndex() throws IOException { + if (cacheEntry.docID != docID()) { + cacheEntry.docID = docID(); + cacheEntry.memoryIndex = new MemoryIndex(true, false); + cacheEntry.memoryIndex.setSimilarity(FREQ_SIMILARITY); + List values = valueFetcher.apply(docID()); + for (Object value : values) { + if (value == null) { + continue; + } + cacheEntry.memoryIndex.addField(field, value.toString(), indexAnalyzer); } - index.addField(field, value.toString(), indexAnalyzer); - frequency += index.search(query); - index.reset(); } - return frequency; + return cacheEntry.memoryIndex; + } + + private float computeFreq() throws IOException { + return getOrCreateMemoryIndex().search(query); } private Matches matches() throws IOException { - MemoryIndex index = new MemoryIndex(true, false); - List values = valueFetcher.apply(docID()); - for (Object value : values) { - if (value == null) { - continue; - } - index.addField(field, value.toString(), indexAnalyzer); - } - IndexSearcher searcher = index.createSearcher(); + IndexSearcher searcher = getOrCreateMemoryIndex().createSearcher(); Weight w = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1); return w.matches(searcher.getLeafContexts().get(0), 0); } } + private static class MemoryIndexEntry { + private int docID = -1; + private MemoryIndex memoryIndex; + } } diff --git a/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/extras/SourceConfirmedTextQueryTests.java b/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/extras/SourceConfirmedTextQueryTests.java index 2b8d5870cb8aa..81e1dd7099860 100644 --- a/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/extras/SourceConfirmedTextQueryTests.java +++ b/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/extras/SourceConfirmedTextQueryTests.java @@ -49,13 +49,19 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; public class SourceConfirmedTextQueryTests extends ESTestCase { + private static final AtomicInteger sourceFetchCount = new AtomicInteger(); private static final IOFunction, IOException>> SOURCE_FETCHER_PROVIDER = - context -> docID -> Collections.singletonList(context.reader().document(docID).get("body")); + context -> docID -> { + sourceFetchCount.incrementAndGet(); + return Collections.singletonList(context.reader().document(docID).get("body")); + }; public void testTerm() throws Exception { try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig(Lucene.STANDARD_ANALYZER))) { @@ -440,11 +446,11 @@ public void testEmptyIndex() throws Exception { } public void testMatches() throws Exception { - checkMatches(new TermQuery(new Term("body", "d")), "a b c d e", new int[] { 3, 3 }); - checkMatches(new PhraseQuery("body", "b", "c"), "a b c d c b c a", new int[] { 1, 2, 5, 6 }); + checkMatches(new TermQuery(new Term("body", "d")), "a b c d e", new int[] { 3, 3 }, false); + checkMatches(new 
PhraseQuery("body", "b", "c"), "a b c d c b c a", new int[] { 1, 2, 5, 6 }, true); } - private static void checkMatches(Query query, String inputDoc, int[] expectedMatches) throws IOException { + private static void checkMatches(Query query, String inputDoc, int[] expectedMatches, boolean expectedFetch) throws IOException { try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig(Lucene.STANDARD_ANALYZER))) { Document doc = new Document(); doc.add(new TextField("body", "xxxxxnomatchxxxx", Store.YES)); @@ -464,30 +470,48 @@ private static void checkMatches(Query query, String inputDoc, int[] expectedMat Query sourceConfirmedQuery = new SourceConfirmedTextQuery(query, SOURCE_FETCHER_PROVIDER, Lucene.STANDARD_ANALYZER); try (IndexReader ir = DirectoryReader.open(w)) { - - IndexSearcher searcher = new IndexSearcher(ir); - TopDocs td = searcher.search( - sourceConfirmedQuery, - 3, - new Sort(KeywordField.newSortField("sort", false, SortedSetSelector.Type.MAX)) - ); - - Weight weight = searcher.createWeight(searcher.rewrite(sourceConfirmedQuery), ScoreMode.COMPLETE_NO_SCORES, 1); - - int firstDoc = td.scoreDocs[0].doc; - LeafReaderContext firstCtx = searcher.getLeafContexts().get(ReaderUtil.subIndex(firstDoc, searcher.getLeafContexts())); - checkMatches(weight, firstCtx, firstDoc - firstCtx.docBase, expectedMatches, 0); - - int secondDoc = td.scoreDocs[1].doc; - LeafReaderContext secondCtx = searcher.getLeafContexts().get(ReaderUtil.subIndex(secondDoc, searcher.getLeafContexts())); - checkMatches(weight, secondCtx, secondDoc - secondCtx.docBase, expectedMatches, 1); - + { + IndexSearcher searcher = new IndexSearcher(ir); + TopDocs td = searcher.search( + sourceConfirmedQuery, + 3, + new Sort(KeywordField.newSortField("sort", false, SortedSetSelector.Type.MAX)) + ); + + Weight weight = searcher.createWeight(searcher.rewrite(sourceConfirmedQuery), ScoreMode.COMPLETE_NO_SCORES, 1); + + int firstDoc = td.scoreDocs[0].doc; + LeafReaderContext firstCtx = searcher.getLeafContexts().get(ReaderUtil.subIndex(firstDoc, searcher.getLeafContexts())); + checkMatches(weight, firstCtx, firstDoc - firstCtx.docBase, expectedMatches, 0, expectedFetch); + + int secondDoc = td.scoreDocs[1].doc; + LeafReaderContext secondCtx = searcher.getLeafContexts() + .get(ReaderUtil.subIndex(secondDoc, searcher.getLeafContexts())); + checkMatches(weight, secondCtx, secondDoc - secondCtx.docBase, expectedMatches, 1, expectedFetch); + } + + { + IndexSearcher searcher = new IndexSearcher(ir); + TopDocs td = searcher.search(KeywordField.newExactQuery("sort", "0"), 1); + + Weight weight = searcher.createWeight(searcher.rewrite(sourceConfirmedQuery), ScoreMode.COMPLETE_NO_SCORES, 1); + int firstDoc = td.scoreDocs[0].doc; + LeafReaderContext firstCtx = searcher.getLeafContexts().get(ReaderUtil.subIndex(firstDoc, searcher.getLeafContexts())); + checkMatches(weight, firstCtx, firstDoc - firstCtx.docBase, new int[0], 0, false); + } } } } - private static void checkMatches(Weight w, LeafReaderContext ctx, int doc, int[] expectedMatches, int offset) throws IOException { + private static void checkMatches(Weight w, LeafReaderContext ctx, int doc, int[] expectedMatches, int offset, boolean expectedFetch) + throws IOException { + int count = sourceFetchCount.get(); Matches matches = w.matches(ctx, doc); + if (expectedMatches.length == 0) { + assertNull(matches); + assertThat(sourceFetchCount.get() - count, equalTo(expectedFetch ? 
1 : 0)); + return; + } assertNotNull(matches); MatchesIterator mi = matches.getMatches("body"); int i = 0; @@ -498,6 +522,7 @@ private static void checkMatches(Weight w, LeafReaderContext ctx, int doc, int[] i += 2; } assertEquals(expectedMatches.length, i); + assertThat(sourceFetchCount.get() - count, equalTo(expectedFetch ? 1 : 0)); } } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 02f905f7cd87a..fdf3af80b8526 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -24,11 +24,13 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelRegistry; +import org.elasticsearch.inference.ModelSettings; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -46,10 +48,10 @@ public class BulkShardRequestInferenceProvider { public static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference"; // Contains the original text for the field - public static final String TEXT_SUBFIELD_NAME = "text"; - // Contains the inference result when it's a sparse vector - public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding"; + public static final String INFERENCE_RESULTS = "inference_results"; + public static final String INFERENCE_CHUNKS_RESULTS = "inference"; + public static final String INFERENCE_CHUNKS_TEXT = "text"; private final ClusterState clusterState; private final Map inferenceProvidersMap; @@ -90,7 +92,13 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { var service = inferenceServiceRegistry.getService(unparsedModel.service()); if (service.isEmpty() == false) { InferenceProvider inferenceProvider = new InferenceProvider( - service.get().parsePersistedConfig(inferenceId, unparsedModel.taskType(), unparsedModel.settings()), + service.get() + .parsePersistedConfigWithSecrets( + inferenceId, + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ), service.get() ); inferenceProviderMap.put(inferenceId, inferenceProvider); @@ -105,7 +113,7 @@ public void onFailure(Exception e) { } }; - modelRegistry.getModel(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire())); + modelRegistry.getModelWithSecrets(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire())); } } } @@ -259,35 +267,22 @@ public void onResponse(InferenceServiceResults results) { } int i = 0; - for (InferenceResults inferenceResults : results.transformToLegacyFormat()) { - String fieldName = inferenceFieldNames.get(i++); - List> inferenceFieldResultList; - try { - inferenceFieldResultList = (List>) rootInferenceFieldMap.computeIfAbsent( - fieldName, - k -> new ArrayList<>() - ); - } catch (ClassCastException e) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException( - "Inference result field [" + ROOT_INFERENCE_FIELD + "." 
+ fieldName + "] is not an object" + for (InferenceResults inferenceResults : results.transformToCoordinationFormat()) { + String inferenceFieldName = inferenceFieldNames.get(i++); + Map inferenceFieldResult = new LinkedHashMap<>(); + inferenceFieldResult.putAll(new ModelSettings(inferenceProvider.model).asMap()); + inferenceFieldResult.put( + INFERENCE_RESULTS, + List.of( + Map.of( + INFERENCE_CHUNKS_RESULTS, + inferenceResults.asMap("output").get("output"), + INFERENCE_CHUNKS_TEXT, + docMap.get(inferenceFieldName) ) - ); - return; - } - // Remove previous inference results if any - inferenceFieldResultList.clear(); - - // TODO Check inference result type to change subfield name - var inferenceFieldMap = Map.of( - SPARSE_VECTOR_SUBFIELD_NAME, - inferenceResults.asMap("output").get("output"), - TEXT_SUBFIELD_NAME, - docMap.get(fieldName) + ) ); - inferenceFieldResultList.add(inferenceFieldMap); + rootInferenceFieldMap.put(inferenceFieldName, inferenceFieldResult); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 47efa0ca49771..c6e4d4af926a2 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -210,6 +210,16 @@ protected Parameter[] getParameters() { return new Parameter[] { elementType, dims, indexed, similarity, indexOptions, meta }; } + public Builder similarity(VectorSimilarity vectorSimilarity) { + similarity.setValue(vectorSimilarity); + return this; + } + + public Builder dimensions(int dimensions) { + this.dims.setValue(dimensions); + return this; + } + @Override public DenseVectorFieldMapper build(MapperBuilderContext context) { return new DenseVectorFieldMapper( @@ -708,7 +718,7 @@ static Function errorByteElementsAppender(byte[] v ElementType.FLOAT ); - enum VectorSimilarity { + public enum VectorSimilarity { L2_NORM { @Override float score(float similarity, ElementType elementType, int dim) { diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java index 62166115820f5..14cfeacf76139 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java @@ -35,6 +35,8 @@ public interface InferenceServiceResults extends NamedWriteable, ToXContentFragm /** * Convert the result to a map to aid with test assertions + * + * @return a map */ Map asMap(); } diff --git a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java similarity index 61% rename from server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java rename to server/src/main/java/org/elasticsearch/inference/ModelSettings.java index 78773bfb72a95..957e2f44d5813 100644 --- a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java @@ -8,7 +8,6 @@ package org.elasticsearch.inference; -import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; @@ -19,28 +18,22 @@ import 
java.util.Objects; /** - * Model settings that are interesting for semantic_text inference fields. This class is used to serialize common - * ServiceSettings methods when building inference for semantic_text fields. - * - * @param taskType task type - * @param inferenceId inference id - * @param dimensions number of dimensions. May be null if not applicable - * @param similarity similarity used by the service. May be null if not applicable + * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. + * See {@link org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider} */ -public record SemanticTextModelSettings( - TaskType taskType, - String inferenceId, - @Nullable Integer dimensions, - @Nullable SimilarityMeasure similarity -) { +public class ModelSettings { public static final String NAME = "model_settings"; - private static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); - private static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); - private static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); - private static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); + public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); + public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); + public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); + public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); + private final TaskType taskType; + private final String inferenceId; + private final Integer dimensions; + private final SimilarityMeasure similarity; - public SemanticTextModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { + public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { Objects.requireNonNull(taskType, "task type must not be null"); Objects.requireNonNull(inferenceId, "inferenceId must not be null"); this.taskType = taskType; @@ -49,7 +42,7 @@ public SemanticTextModelSettings(TaskType taskType, String inferenceId, Integer this.similarity = similarity; } - public SemanticTextModelSettings(Model model) { + public ModelSettings(Model model) { this( model.getTaskType(), model.getInferenceEntityId(), @@ -58,16 +51,16 @@ public SemanticTextModelSettings(Model model) { ); } - public static SemanticTextModelSettings parse(XContentParser parser) throws IOException { + public static ModelSettings parse(XContentParser parser) throws IOException { return PARSER.apply(parser, null); } - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { TaskType taskType = TaskType.fromString((String) args[0]); String inferenceId = (String) args[1]; Integer dimensions = (Integer) args[2]; - SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[2]); - return new SemanticTextModelSettings(taskType, inferenceId, dimensions, similarity); + SimilarityMeasure similarity = args[3] == null ? 
null : SimilarityMeasure.fromString((String) args[3]); + return new ModelSettings(taskType, inferenceId, dimensions, similarity); }); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); @@ -88,4 +81,20 @@ public Map asMap() { } return Map.of(NAME, attrsMap); } + + public TaskType taskType() { + return taskType; + } + + public String inferenceId() { + return inferenceId; + } + + public Integer dimensions() { + return dimensions; + } + + public SimilarityMeasure similarity() { + return similarity; + } } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index f8ed331d358b2..4b81e089ed2b2 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -33,6 +33,9 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelRegistry; +import org.elasticsearch.inference.ModelSettings; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; @@ -56,6 +59,9 @@ import java.util.stream.Collectors; import static java.util.Collections.emptyMap; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS; import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; @@ -91,10 +97,10 @@ public void testNoInference() { Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) ); - Model model1 = mock(Model.class); - InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID); - Model model2 = mock(Model.class); - InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID); + Model model1 = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService1 = createInferenceService(model1); + Model model2 = mockModel(INFERENCE_SERVICE_2_ID); + InferenceService inferenceService2 = createInferenceService(model2); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) ); @@ -130,6 +136,26 @@ public void testNoInference() { verifyNoMoreInteractions(inferenceServiceRegistry); } + private static Model mockModel(String inferenceServiceId) { + Model model = mock(Model.class); + + when(model.getInferenceEntityId()).thenReturn(inferenceServiceId); + TaskType taskType = randomBoolean() ? 
TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING; + when(model.getTaskType()).thenReturn(taskType); + + ServiceSettings serviceSettings = mock(ServiceSettings.class); + when(model.getServiceSettings()).thenReturn(serviceSettings); + SimilarityMeasure similarity = switch (randomInt(2)) { + case 0 -> SimilarityMeasure.COSINE; + case 1 -> SimilarityMeasure.DOT_PRODUCT; + default -> null; + }; + when(serviceSettings.similarity()).thenReturn(similarity); + when(serviceSettings.dimensions()).thenReturn(randomBoolean() ? null : randomIntBetween(1, 1000)); + + return model; + } + public void testFailedBulkShardRequest() { Map> fieldsForModels = Map.of(); @@ -191,10 +217,10 @@ public void testInference() { Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) ); - Model model1 = mock(Model.class); - InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID); - Model model2 = mock(Model.class); - InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID); + Model model1 = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService1 = createInferenceService(model1); + Model model2 = mockModel(INFERENCE_SERVICE_2_ID); + InferenceService inferenceService2 = createInferenceService(model2); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) ); @@ -257,8 +283,8 @@ public void testFailedInference() { ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - Model model = mock(Model.class); - InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID); + Model model = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService = createInferenceServiceThatFails(model); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); @@ -291,8 +317,8 @@ public void testInferenceFailsForIncorrectRootObject() { ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - Model model = mock(Model.class); - InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID); + Model model = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService = createInferenceServiceThatFails(model); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); Map originalSource = Map.of( @@ -315,39 +341,6 @@ public void testInferenceFailsForIncorrectRootObject() { assertThat(item.getFailure().getCause().getMessage(), containsString("[_semantic_text_inference] is not an object")); } - public void testInferenceFailsForIncorrectInferenceFieldObject() { - - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mock(Model.class); - InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - randomAlphaOfLengthBetween(1, 100), - ROOT_INFERENCE_FIELD, - 
Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, "incorrect_inference_field_value") - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat( - item.getFailure().getCause().getMessage(), - containsString("Inference result field [_semantic_text_inference.first_inference_field_service_1] is not an object") - ); - } - public void testInferenceIdNotFound() { Map> fieldsForModels = Map.of( @@ -359,8 +352,8 @@ public void testInferenceIdNotFound() { ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - Model model = mock(Model.class); - InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID); + Model model = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService = createInferenceService(model); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); Map originalSource = Map.of( @@ -400,17 +393,20 @@ private static void checkInferenceResults( ); for (String inferenceFieldName : inferenceFieldNames) { - List> inferenceService1FieldResults = (List>) inferenceRootResultField.get( - inferenceFieldName - ); + Map inferenceService1FieldResults = (Map) inferenceRootResultField.get(inferenceFieldName); assertNotNull(inferenceService1FieldResults); - assertThat(inferenceService1FieldResults.size(), equalTo(1)); - Map inferenceResultElement = inferenceService1FieldResults.get(0); - assertNotNull(inferenceResultElement.get(BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME)); - assertThat( - inferenceResultElement.get(BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME), - equalTo(docSource.get(inferenceFieldName)) + assertThat(inferenceService1FieldResults.size(), equalTo(2)); + Map modelSettings = (Map) inferenceService1FieldResults.get(ModelSettings.NAME); + assertNotNull(modelSettings); + assertNotNull(modelSettings.get(ModelSettings.TASK_TYPE_FIELD.getPreferredName())); + assertNotNull(modelSettings.get(ModelSettings.INFERENCE_ID_FIELD.getPreferredName())); + + List> inferenceResultElement = (List>) inferenceService1FieldResults.get( + INFERENCE_RESULTS ); + assertFalse(inferenceResultElement.isEmpty()); + assertNotNull(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_RESULTS)); + assertThat(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_TEXT), equalTo(docSource.get(inferenceFieldName))); } } @@ -421,8 +417,13 @@ private static void verifyInferenceServiceInvoked( Model model, Collection inferenceTexts ) { - verify(modelRegistry).getModel(eq(inferenceService1Id), any()); - verify(inferenceService).parsePersistedConfig(eq(inferenceService1Id), eq(TaskType.SPARSE_EMBEDDING), anyMap()); + verify(modelRegistry).getModelWithSecrets(eq(inferenceService1Id), any()); + verify(inferenceService).parsePersistedConfigWithSecrets( + eq(inferenceService1Id), + eq(TaskType.SPARSE_EMBEDDING), + anyMap(), + anyMap() + ); verify(inferenceService).infer(eq(model), argThat(containsInAnyOrder(inferenceTexts)), anyMap(), 
eq(InputType.INGEST), any()); verifyNoMoreInteractions(inferenceService); } @@ -537,9 +538,16 @@ private static BulkShardRequest runBulkOperation( ); }; - private static InferenceService createInferenceService(Model model, String inferenceServiceId) { + private static InferenceService createInferenceService(Model model) { InferenceService inferenceService = mock(InferenceService.class); - when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); + when( + inferenceService.parsePersistedConfigWithSecrets( + eq(model.getInferenceEntityId()), + eq(TaskType.SPARSE_EMBEDDING), + anyMap(), + anyMap() + ) + ).thenReturn(model); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(4); InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class); @@ -548,7 +556,7 @@ private static InferenceService createInferenceService(Model model, String infer for (int i = 0; i < texts.size(); i++) { inferenceResults.add(createInferenceResults()); } - doReturn(inferenceResults).when(inferenceServiceResults).transformToLegacyFormat(); + doReturn(inferenceResults).when(inferenceServiceResults).transformToCoordinationFormat(); listener.onResponse(inferenceServiceResults); return null; @@ -556,9 +564,16 @@ private static InferenceService createInferenceService(Model model, String infer return inferenceService; } - private static InferenceService createInferenceServiceThatFails(Model model, String inferenceServiceId) { + private static InferenceService createInferenceServiceThatFails(Model model) { InferenceService inferenceService = mock(InferenceService.class); - when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); + when( + inferenceService.parsePersistedConfigWithSecrets( + eq(model.getInferenceEntityId()), + eq(TaskType.SPARSE_EMBEDDING), + anyMap(), + anyMap() + ) + ).thenReturn(model); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(4); listener.onFailure(new IllegalArgumentException(INFERENCE_FAILED_MSG)); @@ -591,7 +606,7 @@ private static ModelRegistry createModelRegistry(Map inferenceId ActionListener listener = invocation.getArgument(1); listener.onFailure(new IllegalArgumentException("Model not found")); return null; - }).when(modelRegistry).getModel(any(), any()); + }).when(modelRegistry).getModelWithSecrets(any(), any()); inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> { ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel( inferenceId, @@ -604,7 +619,7 @@ private static ModelRegistry createModelRegistry(Map inferenceId ActionListener listener = invocation.getArgument(1); listener.onResponse(unparsedModel); return null; - }).when(modelRegistry).getModel(eq(inferenceId), any()); + }).when(modelRegistry).getModelWithSecrets(eq(inferenceId), any()); }); return modelRegistry; diff --git a/test/yaml-rest-runner/src/main/java/org/elasticsearch/test/rest/yaml/section/ClientYamlTestSuite.java b/test/yaml-rest-runner/src/main/java/org/elasticsearch/test/rest/yaml/section/ClientYamlTestSuite.java index 65a23bd376212..e5f46ff135171 100644 --- a/test/yaml-rest-runner/src/main/java/org/elasticsearch/test/rest/yaml/section/ClientYamlTestSuite.java +++ b/test/yaml-rest-runner/src/main/java/org/elasticsearch/test/rest/yaml/section/ClientYamlTestSuite.java @@ -177,7 +177,7 @@ private static Stream validateExecutableSections( .filter(section -> false == 
section.getExpectedWarningHeaders().isEmpty()) .filter(section -> false == hasYamlRunnerFeature("warnings", testSection, setupSection, teardownSection)) .map(section -> String.format(Locale.ROOT, """ - attempted to add a [do] with a [warnings] section without a corresponding ["skip": "features": "warnings"] \ + attempted to add a [do] with a [warnings] section without a corresponding ["requires": "test_runner_features": "warnings"] \ so runners that do not support the [warnings] section can skip the test at line [%d]\ """, section.getLocation().lineNumber())); @@ -190,7 +190,7 @@ private static Stream validateExecutableSections( .filter(section -> false == hasYamlRunnerFeature("warnings_regex", testSection, setupSection, teardownSection)) .map(section -> String.format(Locale.ROOT, """ attempted to add a [do] with a [warnings_regex] section without a corresponding \ - ["skip": "features": "warnings_regex"] so runners that do not support the [warnings_regex] \ + ["requires": "test_runner_features": "warnings_regex"] so runners that do not support the [warnings_regex] \ section can skip the test at line [%d]\ """, section.getLocation().lineNumber())) ); @@ -204,7 +204,7 @@ private static Stream validateExecutableSections( .filter(section -> false == hasYamlRunnerFeature("allowed_warnings", testSection, setupSection, teardownSection)) .map(section -> String.format(Locale.ROOT, """ attempted to add a [do] with a [allowed_warnings] section without a corresponding \ - ["skip": "features": "allowed_warnings"] so runners that do not support the [allowed_warnings] \ + ["requires": "test_runner_features": "allowed_warnings"] so runners that do not support the [allowed_warnings] \ section can skip the test at line [%d]\ """, section.getLocation().lineNumber())) ); @@ -218,8 +218,8 @@ private static Stream validateExecutableSections( .filter(section -> false == hasYamlRunnerFeature("allowed_warnings_regex", testSection, setupSection, teardownSection)) .map(section -> String.format(Locale.ROOT, """ attempted to add a [do] with a [allowed_warnings_regex] section without a corresponding \ - ["skip": "features": "allowed_warnings_regex"] so runners that do not support the [allowed_warnings_regex] \ - section can skip the test at line [%d]\ + ["requires": "test_runner_features": "allowed_warnings_regex"] so runners that do not support the \ + [allowed_warnings_regex] section can skip the test at line [%d]\ """, section.getLocation().lineNumber())) ); @@ -232,7 +232,7 @@ private static Stream validateExecutableSections( .filter(section -> false == hasYamlRunnerFeature("node_selector", testSection, setupSection, teardownSection)) .map(section -> String.format(Locale.ROOT, """ attempted to add a [do] with a [node_selector] section without a corresponding \ - ["skip": "features": "node_selector"] so runners that do not support the [node_selector] section \ + ["requires": "test_runner_features": "node_selector"] so runners that do not support the [node_selector] section \ can skip the test at line [%d]\ """, section.getLocation().lineNumber())) ); @@ -243,7 +243,7 @@ private static Stream validateExecutableSections( .filter(section -> section instanceof ContainsAssertion) .filter(section -> false == hasYamlRunnerFeature("contains", testSection, setupSection, teardownSection)) .map(section -> String.format(Locale.ROOT, """ - attempted to add a [contains] assertion without a corresponding ["skip": "features": "contains"] \ + attempted to add a [contains] assertion without a corresponding ["requires": 
"test_runner_features": "contains"] \ so runners that do not support the [contains] assertion can skip the test at line [%d]\ """, section.getLocation().lineNumber())) ); @@ -256,8 +256,9 @@ private static Stream validateExecutableSections( .filter(section -> false == section.getApiCallSection().getHeaders().isEmpty()) .filter(section -> false == hasYamlRunnerFeature("headers", testSection, setupSection, teardownSection)) .map(section -> String.format(Locale.ROOT, """ - attempted to add a [do] with a [headers] section without a corresponding ["skip": "features": "headers"] \ - so runners that do not support the [headers] section can skip the test at line [%d]\ + attempted to add a [do] with a [headers] section without a corresponding \ + ["requires": "test_runner_features": "headers"] so runners that do not support the [headers] section \ + can skip the test at line [%d]\ """, section.getLocation().lineNumber())) ); @@ -267,7 +268,7 @@ private static Stream validateExecutableSections( .filter(section -> section instanceof CloseToAssertion) .filter(section -> false == hasYamlRunnerFeature("close_to", testSection, setupSection, teardownSection)) .map(section -> String.format(Locale.ROOT, """ - attempted to add a [close_to] assertion without a corresponding ["skip": "features": "close_to"] \ + attempted to add a [close_to] assertion without a corresponding ["requires": "test_runner_features": "close_to"] \ so runners that do not support the [close_to] assertion can skip the test at line [%d]\ """, section.getLocation().lineNumber())) ); @@ -278,7 +279,7 @@ private static Stream validateExecutableSections( .filter(section -> section instanceof IsAfterAssertion) .filter(section -> false == hasYamlRunnerFeature("is_after", testSection, setupSection, teardownSection)) .map(section -> String.format(Locale.ROOT, """ - attempted to add an [is_after] assertion without a corresponding ["skip": "features": "is_after"] \ + attempted to add an [is_after] assertion without a corresponding ["requires": "test_runner_features": "is_after"] \ so runners that do not support the [is_after] assertion can skip the test at line [%d]\ """, section.getLocation().lineNumber())) ); diff --git a/test/yaml-rest-runner/src/main/java/org/elasticsearch/test/rest/yaml/section/PrerequisiteSection.java b/test/yaml-rest-runner/src/main/java/org/elasticsearch/test/rest/yaml/section/PrerequisiteSection.java index 7f65a29e510b6..f4c9aaa619911 100644 --- a/test/yaml-rest-runner/src/main/java/org/elasticsearch/test/rest/yaml/section/PrerequisiteSection.java +++ b/test/yaml-rest-runner/src/main/java/org/elasticsearch/test/rest/yaml/section/PrerequisiteSection.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.test.rest.yaml.ClientYamlTestExecutionContext; import org.elasticsearch.test.rest.yaml.Features; import org.elasticsearch.xcontent.XContentLocation; @@ -17,7 +18,9 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.function.Predicate; /** @@ -34,9 +37,13 @@ public class PrerequisiteSection { static class PrerequisiteSectionBuilder { String skipVersionRange = null; String skipReason = null; + String requiresReason = null; List requiredYamlRunnerFeatures = new ArrayList<>(); List skipOperatingSystems = new ArrayList<>(); + Set skipClusterFeatures = new HashSet<>(); + Set requiredClusterFeatures = new 
HashSet<>(); + enum XPackRequired { NOT_SPECIFIED, YES, @@ -56,6 +63,11 @@ public PrerequisiteSectionBuilder setSkipReason(String skipReason) { return this; } + public PrerequisiteSectionBuilder setRequiresReason(String requiresReason) { + this.requiresReason = requiresReason; + return this; + } + public PrerequisiteSectionBuilder requireYamlRunnerFeature(String featureName) { requiredYamlRunnerFeatures.add(featureName); return this; @@ -79,6 +91,16 @@ public PrerequisiteSectionBuilder skipIfXPack() { return this; } + public PrerequisiteSectionBuilder skipIfClusterFeature(String featureName) { + skipClusterFeatures.add(featureName); + return this; + } + + public PrerequisiteSectionBuilder requireClusterFeature(String featureName) { + requiredClusterFeatures.add(featureName); + return this; + } + public PrerequisiteSectionBuilder skipIfOs(String osName) { this.skipOperatingSystems.add(osName); return this; @@ -88,7 +110,9 @@ void validate(XContentLocation contentLocation) { if ((Strings.hasLength(skipVersionRange) == false) && requiredYamlRunnerFeatures.isEmpty() && skipOperatingSystems.isEmpty() - && xpackRequired == XPackRequired.NOT_SPECIFIED) { + && xpackRequired == XPackRequired.NOT_SPECIFIED + && requiredClusterFeatures.isEmpty() + && skipClusterFeatures.isEmpty()) { throw new ParsingException( contentLocation, "at least one criteria (version, cluster features, runner features, os) is mandatory within a skip section" @@ -100,6 +124,12 @@ void validate(XContentLocation contentLocation) { if (skipOperatingSystems.isEmpty() == false && Strings.hasLength(skipReason) == false) { throw new ParsingException(contentLocation, "reason is mandatory within skip os section"); } + if (skipClusterFeatures.isEmpty() == false && Strings.hasLength(skipReason) == false) { + throw new ParsingException(contentLocation, "reason is mandatory within skip cluster_features section"); + } + if (requiredClusterFeatures.isEmpty() == false && Strings.hasLength(requiresReason) == false) { + throw new ParsingException(contentLocation, "reason is mandatory within requires cluster_features section"); + } // make feature "skip_os" mandatory if os is given, this is a temporary solution until language client tests know about os if (skipOperatingSystems.isEmpty() == false && requiredYamlRunnerFeatures.contains("skip_os") == false) { throw new ParsingException(contentLocation, "if os is specified, test runner feature [skip_os] must be set"); @@ -107,6 +137,9 @@ void validate(XContentLocation contentLocation) { if (xpackRequired == XPackRequired.MISMATCHED) { throw new ParsingException(contentLocation, "either [xpack] or [no_xpack] can be present, not both"); } + if (Sets.haveNonEmptyIntersection(skipClusterFeatures, requiredClusterFeatures)) { + throw new ParsingException(contentLocation, "a cluster feature can be specified either in [requires] or [skip], not both"); + } } public PrerequisiteSection build() { @@ -131,8 +164,14 @@ public PrerequisiteSection build() { if (skipOperatingSystems.isEmpty() == false) { skipCriteriaList.add(Prerequisites.skipOnOsList(skipOperatingSystems)); } + if (requiredClusterFeatures.isEmpty() == false) { + requiresCriteriaList.add(Prerequisites.requireClusterFeatures(requiredClusterFeatures)); + } + if (skipClusterFeatures.isEmpty() == false) { + skipCriteriaList.add(Prerequisites.skipOnClusterFeatures(skipClusterFeatures)); + } } - return new PrerequisiteSection(skipCriteriaList, skipReason, requiresCriteriaList, null, requiredYamlRunnerFeatures); + return new 
PrerequisiteSection(skipCriteriaList, skipReason, requiresCriteriaList, requiresReason, requiredYamlRunnerFeatures); } } @@ -160,6 +199,10 @@ static PrerequisiteSectionBuilder parseInternal(XContentParser parser) throws IO parseSkipSection(parser, builder); hasPrerequisiteSection = true; maybeAdvanceToNextField(parser); + } else if ("requires".equals(parser.currentName())) { + parseRequiresSection(parser, builder); + hasPrerequisiteSection = true; + maybeAdvanceToNextField(parser); } else { unknownFieldName = true; } @@ -209,6 +252,8 @@ static void parseSkipSection(XContentParser parser, PrerequisiteSectionBuilder b parseFeatureField(parser.text(), builder); } else if ("os".equals(currentFieldName)) { builder.skipIfOs(parser.text()); + } else if ("cluster_features".equals(currentFieldName)) { + builder.skipIfClusterFeature(parser.text()); } else { throw new ParsingException( parser.getTokenLocation(), @@ -224,6 +269,54 @@ static void parseSkipSection(XContentParser parser, PrerequisiteSectionBuilder b while (parser.nextToken() != XContentParser.Token.END_ARRAY) { builder.skipIfOs(parser.text()); } + } else if ("cluster_features".equals(currentFieldName)) { + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + builder.skipIfClusterFeature(parser.text()); + } + } + } + } + parser.nextToken(); + } + + static void parseRequiresSection(XContentParser parser, PrerequisiteSectionBuilder builder) throws IOException { + if (parser.nextToken() != XContentParser.Token.START_OBJECT) { + throw new IllegalArgumentException( + "Expected [" + + XContentParser.Token.START_OBJECT + + ", found [" + + parser.currentToken() + + "], the requires section is not properly indented" + ); + } + String currentFieldName = null; + XContentParser.Token token; + + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token.isValue()) { + if ("reason".equals(currentFieldName)) { + builder.setRequiresReason(parser.text()); + } else if ("test_runner_features".equals(currentFieldName)) { + parseFeatureField(parser.text(), builder); + } else if ("cluster_features".equals(currentFieldName)) { + builder.requireClusterFeature(parser.text()); + } else { + throw new ParsingException( + parser.getTokenLocation(), + "field " + currentFieldName + " not supported within requires section" + ); + } + } else if (token == XContentParser.Token.START_ARRAY) { + if ("test_runner_features".equals(currentFieldName)) { + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + parseFeatureField(parser.text(), builder); + } + } else if ("cluster_features".equals(currentFieldName)) { + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + builder.requireClusterFeature(parser.text()); + } } } } diff --git a/test/yaml-rest-runner/src/test/java/org/elasticsearch/test/rest/yaml/section/ClientYamlTestSuiteTests.java b/test/yaml-rest-runner/src/test/java/org/elasticsearch/test/rest/yaml/section/ClientYamlTestSuiteTests.java index edc043e15527d..1f5bdc71dde37 100644 --- a/test/yaml-rest-runner/src/test/java/org/elasticsearch/test/rest/yaml/section/ClientYamlTestSuiteTests.java +++ b/test/yaml-rest-runner/src/test/java/org/elasticsearch/test/rest/yaml/section/ClientYamlTestSuiteTests.java @@ -468,6 +468,41 @@ public void testParseSkipOs() throws Exception { assertThat(restTestSuite.getTestSections().get(0).getPrerequisiteSection().hasYamlRunnerFeature("skip_os"), equalTo(true)); } + public void 
testParseSkipAndRequireClusterFeatures() throws Exception { + parser = createParser(YamlXContent.yamlXContent, """ + "Broken on some os": + + - skip: + cluster_features: [unsupported-feature1, unsupported-feature2] + reason: "unsupported-features are not supported" + - requires: + cluster_features: required-feature1 + reason: "required-feature1 is required" + - do: + indices.get_mapping: + index: test_index + type: test_type + + - match: {test_type.properties.text.type: string} + - match: {test_type.properties.text.analyzer: whitespace} + """); + + ClientYamlTestSuite restTestSuite = ClientYamlTestSuite.parse(getTestClass().getName(), getTestName(), Optional.empty(), parser); + + assertThat(restTestSuite, notNullValue()); + assertThat(restTestSuite.getName(), equalTo(getTestName())); + assertThat(restTestSuite.getFile().isPresent(), equalTo(false)); + assertThat(restTestSuite.getTestSections().size(), equalTo(1)); + + assertThat(restTestSuite.getTestSections().get(0).getName(), equalTo("Broken on some os")); + assertThat(restTestSuite.getTestSections().get(0).getPrerequisiteSection().isEmpty(), equalTo(false)); + assertThat( + restTestSuite.getTestSections().get(0).getPrerequisiteSection().skipReason, + equalTo("unsupported-features are not supported") + ); + assertThat(restTestSuite.getTestSections().get(0).getPrerequisiteSection().requireReason, equalTo("required-feature1 is required")); + } + public void testParseFileWithSingleTestSection() throws Exception { final Path filePath = createTempFile("tyf", ".yml"); Files.writeString(filePath, """ @@ -541,7 +576,7 @@ public void testAddingDoWithWarningWithoutSkipWarnings() { Exception e = expectThrows(IllegalArgumentException.class, testSuite::validate); assertThat(e.getMessage(), containsString(Strings.format(""" api/name: - attempted to add a [do] with a [warnings] section without a corresponding ["skip": "features": "warnings"] \ + attempted to add a [do] with a [warnings] section without a corresponding ["requires": "test_runner_features": "warnings"] \ so runners that do not support the [warnings] section can skip the test at line [%d]\ """, lineNumber))); } @@ -555,7 +590,8 @@ public void testAddingDoWithWarningRegexWithoutSkipWarnings() { Exception e = expectThrows(IllegalArgumentException.class, testSuite::validate); assertThat(e.getMessage(), containsString(Strings.format(""" api/name: - attempted to add a [do] with a [warnings_regex] section without a corresponding ["skip": "features": "warnings_regex"] \ + attempted to add a [do] with a [warnings_regex] section without a corresponding \ + ["requires": "test_runner_features": "warnings_regex"] \ so runners that do not support the [warnings_regex] section can skip the test at line [%d]\ """, lineNumber))); } @@ -569,7 +605,7 @@ public void testAddingDoWithAllowedWarningWithoutSkipAllowedWarnings() { Exception e = expectThrows(IllegalArgumentException.class, testSuite::validate); assertThat(e.getMessage(), containsString(Strings.format(""" api/name: - attempted to add a [do] with a [allowed_warnings] section without a corresponding ["skip": "features": \ + attempted to add a [do] with a [allowed_warnings] section without a corresponding ["requires": "test_runner_features": \ "allowed_warnings"] so runners that do not support the [allowed_warnings] section can skip the test at \ line [%d]\ """, lineNumber))); @@ -584,7 +620,7 @@ public void testAddingDoWithAllowedWarningRegexWithoutSkipAllowedWarnings() { Exception e = expectThrows(IllegalArgumentException.class, testSuite::validate); 
assertThat(e.getMessage(), containsString(Strings.format(""" api/name: - attempted to add a [do] with a [allowed_warnings_regex] section without a corresponding ["skip": "features": \ + attempted to add a [do] with a [allowed_warnings_regex] section without a corresponding ["requires": "test_runner_features": \ "allowed_warnings_regex"] so runners that do not support the [allowed_warnings_regex] section can skip the test \ at line [%d]\ """, lineNumber))); @@ -600,7 +636,7 @@ public void testAddingDoWithHeaderWithoutSkipHeaders() { Exception e = expectThrows(IllegalArgumentException.class, testSuite::validate); assertThat(e.getMessage(), containsString(Strings.format(""" api/name: - attempted to add a [do] with a [headers] section without a corresponding ["skip": "features": "headers"] \ + attempted to add a [do] with a [headers] section without a corresponding ["requires": "test_runner_features": "headers"] \ so runners that do not support the [headers] section can skip the test at line [%d]\ """, lineNumber))); } @@ -615,7 +651,8 @@ public void testAddingDoWithNodeSelectorWithoutSkipNodeSelector() { Exception e = expectThrows(IllegalArgumentException.class, testSuite::validate); assertThat(e.getMessage(), containsString(Strings.format(""" api/name: - attempted to add a [do] with a [node_selector] section without a corresponding ["skip": "features": "node_selector"] \ + attempted to add a [do] with a [node_selector] section without a corresponding \ + ["requires": "test_runner_features": "node_selector"] \ so runners that do not support the [node_selector] section can skip the test at line [%d]\ """, lineNumber))); } @@ -631,7 +668,7 @@ public void testAddingContainsWithoutSkipContains() { Exception e = expectThrows(IllegalArgumentException.class, testSuite::validate); assertThat(e.getMessage(), containsString(Strings.format(""" api/name: - attempted to add a [contains] assertion without a corresponding ["skip": "features": "contains"] \ + attempted to add a [contains] assertion without a corresponding ["requires": "test_runner_features": "contains"] \ so runners that do not support the [contains] assertion can skip the test at line [%d]\ """, lineNumber))); } @@ -683,13 +720,15 @@ public void testMultipleValidationErrors() { Exception e = expectThrows(IllegalArgumentException.class, testSuite::validate); assertEquals(Strings.format(""" api/name: - attempted to add a [contains] assertion without a corresponding ["skip": "features": "contains"] so runners that \ - do not support the [contains] assertion can skip the test at line [%d], - attempted to add a [do] with a [warnings] section without a corresponding ["skip": "features": "warnings"] so runners \ - that do not support the [warnings] section can skip the test at line [%d], - attempted to add a [do] with a [node_selector] section without a corresponding ["skip": "features": "node_selector"] so \ - runners that do not support the [node_selector] section can skip the test \ - at line [%d]\ + attempted to add a [contains] assertion without a corresponding \ + ["requires": "test_runner_features": "contains"] \ + so runners that do not support the [contains] assertion can skip the test at line [%d], + attempted to add a [do] with a [warnings] section without a corresponding \ + ["requires": "test_runner_features": "warnings"] \ + so runners that do not support the [warnings] section can skip the test at line [%d], + attempted to add a [do] with a [node_selector] section without a corresponding \ + ["requires": "test_runner_features": 
"node_selector"] \ + so runners that do not support the [node_selector] section can skip the test at line [%d]\ """, firstLineNumber, secondLineNumber, thirdLineNumber), e.getMessage()); } diff --git a/test/yaml-rest-runner/src/test/java/org/elasticsearch/test/rest/yaml/section/PrerequisiteSectionTests.java b/test/yaml-rest-runner/src/test/java/org/elasticsearch/test/rest/yaml/section/PrerequisiteSectionTests.java index b02658694d82f..181ec34fefb7e 100644 --- a/test/yaml-rest-runner/src/test/java/org/elasticsearch/test/rest/yaml/section/PrerequisiteSectionTests.java +++ b/test/yaml-rest-runner/src/test/java/org/elasticsearch/test/rest/yaml/section/PrerequisiteSectionTests.java @@ -363,8 +363,10 @@ public void testParseSkipSectionOsListNoVersion() throws Exception { public void testParseSkipSectionOsListTestFeaturesInRequires() throws Exception { parser = createParser(YamlXContent.yamlXContent, """ + - requires: + test_runner_features: skip_os + reason: skip_os is needed for skip based on os - skip: - features: [skip_os] os: [debian-9,windows-95,ms-dos] reason: see gh#xyz """); @@ -391,6 +393,95 @@ public void testParseSkipSectionOsNoFeatureNoVersion() throws Exception { assertThat(e.getMessage(), is("if os is specified, test runner feature [skip_os] must be set")); } + public void testParseRequireSectionClusterFeatures() throws Exception { + parser = createParser(YamlXContent.yamlXContent, """ + cluster_features: needed-feature + reason: test skipped when cluster lacks needed-feature + """); + + var skipSectionBuilder = new PrerequisiteSection.PrerequisiteSectionBuilder(); + PrerequisiteSection.parseRequiresSection(parser, skipSectionBuilder); + assertThat(skipSectionBuilder, notNullValue()); + assertThat(skipSectionBuilder.skipVersionRange, emptyOrNullString()); + assertThat(skipSectionBuilder.requiredClusterFeatures, contains("needed-feature")); + assertThat(skipSectionBuilder.requiresReason, is("test skipped when cluster lacks needed-feature")); + } + + public void testParseSkipSectionClusterFeatures() throws Exception { + parser = createParser(YamlXContent.yamlXContent, """ + cluster_features: undesired-feature + reason: test skipped when undesired-feature is present + """); + + var skipSectionBuilder = new PrerequisiteSection.PrerequisiteSectionBuilder(); + PrerequisiteSection.parseSkipSection(parser, skipSectionBuilder); + assertThat(skipSectionBuilder, notNullValue()); + assertThat(skipSectionBuilder.skipVersionRange, emptyOrNullString()); + assertThat(skipSectionBuilder.skipClusterFeatures, contains("undesired-feature")); + assertThat(skipSectionBuilder.skipReason, is("test skipped when undesired-feature is present")); + } + + public void testParseRequireAndSkipSectionsClusterFeatures() throws Exception { + parser = createParser(YamlXContent.yamlXContent, """ + - requires: + cluster_features: needed-feature + reason: test needs needed-feature to run + - skip: + cluster_features: undesired-feature + reason: test cannot run when undesired-feature are present + """); + + var skipSectionBuilder = PrerequisiteSection.parseInternal(parser); + assertThat(skipSectionBuilder, notNullValue()); + assertThat(skipSectionBuilder.skipVersionRange, emptyOrNullString()); + assertThat(skipSectionBuilder.skipClusterFeatures, contains("undesired-feature")); + assertThat(skipSectionBuilder.requiredClusterFeatures, contains("needed-feature")); + assertThat(skipSectionBuilder.skipReason, is("test cannot run when undesired-feature are present")); + assertThat(skipSectionBuilder.requiresReason, is("test 
needs needed-feature to run")); + + assertThat(parser.currentToken(), equalTo(XContentParser.Token.END_ARRAY)); + assertThat(parser.nextToken(), nullValue()); + } + + public void testParseRequireAndSkipSectionMultipleClusterFeatures() throws Exception { + parser = createParser(YamlXContent.yamlXContent, """ + - requires: + cluster_features: [needed-feature-1, needed-feature-2] + reason: test needs some to run + - skip: + cluster_features: [undesired-feature-1, undesired-feature-2] + reason: test cannot run when some are present + """); + + var skipSectionBuilder = PrerequisiteSection.parseInternal(parser); + assertThat(skipSectionBuilder, notNullValue()); + assertThat(skipSectionBuilder.skipVersionRange, emptyOrNullString()); + assertThat(skipSectionBuilder.skipClusterFeatures, containsInAnyOrder("undesired-feature-1", "undesired-feature-2")); + assertThat(skipSectionBuilder.requiredClusterFeatures, containsInAnyOrder("needed-feature-1", "needed-feature-2")); + assertThat(skipSectionBuilder.skipReason, is("test cannot run when some are present")); + assertThat(skipSectionBuilder.requiresReason, is("test needs some to run")); + + assertThat(parser.currentToken(), equalTo(XContentParser.Token.END_ARRAY)); + assertThat(parser.nextToken(), nullValue()); + } + + public void testParseSameRequireAndSkipClusterFeatures() throws Exception { + parser = createParser(YamlXContent.yamlXContent, """ + - requires: + cluster_features: some-feature + reason: test needs some-feature to run + - skip: + cluster_features: some-feature + reason: test cannot run with some-feature + """); + + var e = expectThrows(ParsingException.class, () -> PrerequisiteSection.parseInternal(parser)); + assertThat(e.getMessage(), is("a cluster feature can be specified either in [requires] or [skip], not both")); + + assertThat(parser.currentToken(), equalTo(XContentParser.Token.END_ARRAY)); + assertThat(parser.nextToken(), nullValue()); + } + public void testSkipClusterFeaturesAllRequiredMatch() { PrerequisiteSection section = new PrerequisiteSection( emptyList(), diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQueryTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQueryTests.java index 1d62bc0b6eaaa..55e8ba164ba70 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQueryTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQueryTests.java @@ -77,6 +77,7 @@ public void testMatchAll() throws IOException { testCase(new SingleValueQuery(new MatchAll(Source.EMPTY), "foo").asBuilder(), false, false, this::runCase); } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105952") public void testMatchSome() throws IOException { int max = between(1, 100); testCase( diff --git a/x-pack/plugin/ilm/qa/multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ilm/actions/DownsampleActionIT.java b/x-pack/plugin/ilm/qa/multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ilm/actions/DownsampleActionIT.java index ec9fad3e5077d..6d34fb0eced79 100644 --- a/x-pack/plugin/ilm/qa/multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ilm/actions/DownsampleActionIT.java +++ b/x-pack/plugin/ilm/qa/multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ilm/actions/DownsampleActionIT.java @@ -23,6 +23,7 @@ import org.elasticsearch.index.IndexSettings; import 
org.elasticsearch.rest.action.admin.indices.RestPutIndexTemplateAction; import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; @@ -395,7 +396,7 @@ public void testILMWaitsForTimeSeriesEndTimeToLapse() throws Exception { }, 30, TimeUnit.SECONDS); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/103981") + @TestLogging(value = "org.elasticsearch.xpack.ilm:TRACE", reason = "https://github.com/elastic/elasticsearch/issues/103981") public void testRollupNonTSIndex() throws Exception { createIndex(index, alias, false); index(client(), index, true, null, "@timestamp", "2020-01-01T05:10:00Z", "volume", 11.0, "metricset", randomAlphaOfLength(5)); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 99dfc9582eb05..a65b8e43e6adf 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -101,11 +101,6 @@ public TestServiceModel( super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings)); } - @Override - public TestDenseInferenceServiceExtension.TestServiceSettings getServiceSettings() { - return (TestDenseInferenceServiceExtension.TestServiceSettings) super.getServiceSettings(); - } - @Override public TestTaskSettings getTaskSettings() { return (TestTaskSettings) super.getTaskSettings(); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index e5020774a70f3..33bbc94901e9d 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -115,7 +115,7 @@ private SparseEmbeddingResults makeResults(List input) { for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { - tokens.add(new SparseEmbeddingResults.WeightedToken(Integer.toString(j), (float) j)); + tokens.add(new SparseEmbeddingResults.WeightedToken("feature_" + j, j + 1.0F)); } embeddings.add(new SparseEmbeddingResults.Embedding(tokens, false)); } @@ -127,7 +127,7 @@ private List makeChunkedResults(List inp for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { - tokens.add(new TextExpansionResults.WeightedToken(Integer.toString(j), (float) j)); + tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, j + 1.0F)); } chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java index 9e6c1eb0a6586..dbde641d8f757 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; +import org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.mapper.BooleanFieldMapper; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.FieldMapper; @@ -25,25 +25,25 @@ import org.elasticsearch.index.mapper.TextFieldMapper; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.ModelSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; -import org.elasticsearch.script.ScriptCompiler; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import java.io.IOException; import java.util.Collections; import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; /** * A mapper for the {@code _semantic_text_inference} field. 
@@ -102,16 +102,13 @@ */ public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { public static final String CONTENT_TYPE = "_semantic_text_inference"; - public static final String NAME = "_semantic_text_inference"; + public static final String NAME = ROOT_INFERENCE_FIELD; public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); - private static final Map, Set> REQUIRED_SUBFIELDS_MAP = Map.of( - List.of(), - Set.of(SPARSE_VECTOR_SUBFIELD_NAME, TEXT_SUBFIELD_NAME) - ); - private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); + private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); + static class SemanticTextInferenceFieldType extends MappedFieldType { private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); @@ -142,75 +139,86 @@ private SemanticTextInferenceResultFieldMapper() { @Override protected void parseCreateField(DocumentParserContext context) throws IOException { XContentParser parser = context.parser(); - if (parser.currentToken() != XContentParser.Token.START_OBJECT) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_OBJECT, got " + parser.currentToken()); - } + failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - parseInferenceResults(context); + parseAllFields(context); } - private static void parseInferenceResults(DocumentParserContext context) throws IOException { + private static void parseAllFields(DocumentParserContext context) throws IOException { XContentParser parser = context.parser(); MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - if (token != XContentParser.Token.FIELD_NAME) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a FIELD_NAME, got " + token); - } + failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - parseFieldInferenceResults(context, mapperBuilderContext); + parseSingleField(context, mapperBuilderContext); } } - private static void parseFieldInferenceResults(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) - throws IOException { + private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { - String fieldName = context.parser().currentName(); + XContentParser parser = context.parser(); + String fieldName = parser.currentName(); Mapper mapper = context.getMapper(fieldName); if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { throw new DocumentParsingException( - context.parser().getTokenLocation(), + parser.getTokenLocation(), Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) ); } + parser.nextToken(); + failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); + parser.nextToken(); + ModelSettings modelSettings = ModelSettings.parse(parser); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - parseFieldInferenceResultsArray(context, mapperBuilderContext, fieldName); + String currentName = parser.currentName(); + if 
(BulkShardRequestInferenceProvider.INFERENCE_RESULTS.equals(currentName)) { + NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( + context, + mapperBuilderContext, + fieldName, + modelSettings + ); + parseFieldInferenceChunks(context, mapperBuilderContext, fieldName, modelSettings, nestedObjectMapper); + } else { + logger.debug("Skipping unrecognized field name [" + currentName + "]"); + advancePastCurrentFieldName(parser); + } + } } - private static void parseFieldInferenceResultsArray( + private static void parseFieldInferenceChunks( DocumentParserContext context, MapperBuilderContext mapperBuilderContext, - String fieldName + String fieldName, + ModelSettings modelSettings, + NestedObjectMapper nestedObjectMapper ) throws IOException { XContentParser parser = context.parser(); - NestedObjectMapper nestedObjectMapper = createNestedObjectMapper(context, mapperBuilderContext, fieldName); - if (parser.nextToken() != XContentParser.Token.START_ARRAY) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_ARRAY, got " + parser.currentToken()); - } + parser.nextToken(); + failIfTokenIsNot(parser, XContentParser.Token.START_ARRAY); for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper); - parseFieldInferenceResultElement(nestedContext, nestedObjectMapper, new LinkedList<>()); + parseFieldInferenceChunkElement(nestedContext, nestedObjectMapper, modelSettings); } } - private static void parseFieldInferenceResultElement( + private static void parseFieldInferenceChunkElement( DocumentParserContext context, ObjectMapper objectMapper, - LinkedList subfieldPath + ModelSettings modelSettings ) throws IOException { XContentParser parser = context.parser(); DocumentParserContext childContext = context.createChildContext(objectMapper); - if (parser.currentToken() != XContentParser.Token.START_OBJECT) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_OBJECT, got " + parser.currentToken()); - } + failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); Set visitedSubfields = new HashSet<>(); for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - if (token != XContentParser.Token.FIELD_NAME) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a FIELD_NAME, got " + parser.currentToken()); - } + failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); String currentName = parser.currentName(); visitedSubfields.add(currentName); @@ -222,14 +230,9 @@ private static void parseFieldInferenceResultElement( continue; } - if (childMapper instanceof FieldMapper) { + if (childMapper instanceof FieldMapper fieldMapper) { parser.nextToken(); - ((FieldMapper) childMapper).parse(childContext); - } else if (childMapper instanceof ObjectMapper) { - parser.nextToken(); - subfieldPath.push(currentName); - parseFieldInferenceResultElement(childContext, (ObjectMapper) childMapper, subfieldPath); - subfieldPath.pop(); + fieldMapper.parse(childContext); } else { // This should never happen, but fail parsing if it does so that it's not a silent failure throw new DocumentParsingException( @@ -239,29 +242,51 @@ private static void parseFieldInferenceResultElement( } } - Set requiredSubfields = REQUIRED_SUBFIELDS_MAP.get(subfieldPath); - if (requiredSubfields != null && 
visitedSubfields.containsAll(requiredSubfields) == false) { - Set missingSubfields = requiredSubfields.stream() + if (visitedSubfields.containsAll(REQUIRED_SUBFIELDS) == false) { + Set missingSubfields = REQUIRED_SUBFIELDS.stream() .filter(s -> visitedSubfields.contains(s) == false) .collect(Collectors.toSet()); throw new DocumentParsingException(parser.getTokenLocation(), "Missing required subfields: " + missingSubfields); } } - private static NestedObjectMapper createNestedObjectMapper( + private static NestedObjectMapper createInferenceResultsObjectMapper( DocumentParserContext context, MapperBuilderContext mapperBuilderContext, - String fieldName + String fieldName, + ModelSettings modelSettings ) { IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated(); - ObjectMapper.Builder sparseVectorMapperBuilder = new ObjectMapper.Builder( - SPARSE_VECTOR_SUBFIELD_NAME, - ObjectMapper.Defaults.SUBOBJECTS - ).add( - new BooleanFieldMapper.Builder(SparseEmbeddingResults.Embedding.IS_TRUNCATED, ScriptCompiler.NONE, false, indexVersionCreated) - ).add(new SparseVectorFieldMapper.Builder(SparseEmbeddingResults.Embedding.EMBEDDING)); + FieldMapper.Builder resultsBuilder; + if (modelSettings.taskType() == TaskType.SPARSE_EMBEDDING) { + resultsBuilder = new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); + } else if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING) { + DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( + INFERENCE_CHUNKS_RESULTS, + indexVersionCreated + ); + SimilarityMeasure similarity = modelSettings.similarity(); + if (similarity != null) { + switch (similarity) { + case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); + case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); + default -> throw new IllegalArgumentException( + "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity + ); + } + } + Integer dimensions = modelSettings.dimensions(); + if (dimensions == null) { + throw new IllegalArgumentException("Model settings for field [" + fieldName + "] must contain dimensions"); + } + denseVectorMapperBuilder.dimensions(dimensions); + resultsBuilder = denseVectorMapperBuilder; + } else { + throw new IllegalArgumentException("Unknown task type for field [" + fieldName + "]: " + modelSettings.taskType()); + } + TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( - TEXT_SUBFIELD_NAME, + INFERENCE_CHUNKS_TEXT, indexVersionCreated, context.indexAnalyzers() ).index(false).store(false); @@ -270,7 +295,7 @@ private static NestedObjectMapper createNestedObjectMapper( fieldName, context.indexSettings().getIndexVersionCreated() ); - nestedBuilder.add(sparseVectorMapperBuilder).add(textMapperBuilder); + nestedBuilder.add(resultsBuilder).add(textMapperBuilder); return nestedBuilder.build(mapperBuilderContext); } @@ -286,6 +311,15 @@ private static void advancePastCurrentFieldName(XContentParser parser) throws IO } } + private static void failIfTokenIsNot(XContentParser parser, XContentParser.Token expected) { + if (parser.currentToken() != expected) { + throw new DocumentParsingException( + parser.getTokenLocation(), + "Expected a " + expected.toString() + ", got " + parser.currentToken() + ); + } + } + @Override protected String contentType() { return CONTENT_TYPE; diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java index aa2ad72941e0e..06a665ade3ab4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -31,6 +31,8 @@ import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.ModelSettings; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.LeafNestedDocuments; import org.elasticsearch.search.NestedDocuments; @@ -51,8 +53,9 @@ import java.util.Set; import java.util.function.Consumer; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS; import static org.hamcrest.Matchers.containsString; public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase { @@ -214,7 +217,7 @@ public void testMissingSubfields() throws IOException { ) ) ); - assertThat(ex.getMessage(), containsString("Missing required subfields: [" + SPARSE_VECTOR_SUBFIELD_NAME + "]")); + assertThat(ex.getMessage(), containsString("Missing required subfields: [" + INFERENCE_CHUNKS_RESULTS + "]")); } { DocumentParsingException ex = expectThrows( @@ -232,7 +235,7 @@ public void testMissingSubfields() throws IOException { ) ) ); - assertThat(ex.getMessage(), containsString("Missing required subfields: [" + TEXT_SUBFIELD_NAME + "]")); + assertThat(ex.getMessage(), containsString("Missing required subfields: [" + INFERENCE_CHUNKS_TEXT + "]")); } { DocumentParsingException ex = expectThrows( @@ -252,7 +255,7 @@ public void testMissingSubfields() throws IOException { ); assertThat( ex.getMessage(), - containsString("Missing required subfields: [" + SPARSE_VECTOR_SUBFIELD_NAME + ", " + TEXT_SUBFIELD_NAME + "]") + containsString("Missing required subfields: [" + INFERENCE_CHUNKS_RESULTS + ", " + INFERENCE_CHUNKS_TEXT + "]") ); } } @@ -411,8 +414,10 @@ private static void addSemanticTextInferenceResults( Map extraSubfields ) throws IOException { - Map>> inferenceResultsMap = new HashMap<>(); + Map> inferenceResultsMap = new HashMap<>(); for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { + Map fieldMap = new HashMap<>(); + fieldMap.put(ModelSettings.NAME, modelSettingsMap()); List> parsedInferenceResults = new ArrayList<>(semanticTextInferenceResult.text().size()); Iterator embeddingsIterator = semanticTextInferenceResult.sparseEmbeddingResults() @@ -425,17 +430,10 @@ private static void addSemanticTextInferenceResults( Map subfieldMap = new HashMap<>(); if (sparseVectorSubfieldOptions.include()) { - Map embeddingMap = embedding.asMap(); - if 
(sparseVectorSubfieldOptions.includeIsTruncated() == false) { - embeddingMap.remove(SparseEmbeddingResults.Embedding.IS_TRUNCATED); - } - if (sparseVectorSubfieldOptions.includeEmbedding() == false) { - embeddingMap.remove(SparseEmbeddingResults.Embedding.EMBEDDING); - } - subfieldMap.put(SPARSE_VECTOR_SUBFIELD_NAME, embeddingMap); + subfieldMap.put(INFERENCE_CHUNKS_RESULTS, embedding.asMap().get(SparseEmbeddingResults.Embedding.EMBEDDING)); } if (includeTextSubfield) { - subfieldMap.put(TEXT_SUBFIELD_NAME, text); + subfieldMap.put(INFERENCE_CHUNKS_TEXT, text); } if (extraSubfields != null) { subfieldMap.putAll(extraSubfields); @@ -444,28 +442,42 @@ private static void addSemanticTextInferenceResults( parsedInferenceResults.add(subfieldMap); } - inferenceResultsMap.put(semanticTextInferenceResult.fieldName(), parsedInferenceResults); + fieldMap.put(INFERENCE_RESULTS, parsedInferenceResults); + inferenceResultsMap.put(semanticTextInferenceResult.fieldName(), fieldMap); } sourceBuilder.field(SemanticTextInferenceResultFieldMapper.NAME, inferenceResultsMap); } + private static Map modelSettingsMap() { + return Map.of( + ModelSettings.TASK_TYPE_FIELD.getPreferredName(), + TaskType.SPARSE_EMBEDDING.toString(), + ModelSettings.INFERENCE_ID_FIELD.getPreferredName(), + randomAlphaOfLength(8) + ); + } + private static void addInferenceResultsNestedMapping(XContentBuilder mappingBuilder, String semanticTextFieldName) throws IOException { mappingBuilder.startObject(semanticTextFieldName); - mappingBuilder.field("type", "nested"); - mappingBuilder.startObject("properties"); - mappingBuilder.startObject(SPARSE_VECTOR_SUBFIELD_NAME); - mappingBuilder.startObject("properties"); - mappingBuilder.startObject(SparseEmbeddingResults.Embedding.EMBEDDING); - mappingBuilder.field("type", "sparse_vector"); - mappingBuilder.endObject(); - mappingBuilder.endObject(); - mappingBuilder.endObject(); - mappingBuilder.startObject(TEXT_SUBFIELD_NAME); - mappingBuilder.field("type", "text"); - mappingBuilder.field("index", false); - mappingBuilder.endObject(); - mappingBuilder.endObject(); + { + mappingBuilder.field("type", "nested"); + mappingBuilder.startObject("properties"); + { + mappingBuilder.startObject(INFERENCE_CHUNKS_RESULTS); + { + mappingBuilder.field("type", "sparse_vector"); + } + mappingBuilder.endObject(); + mappingBuilder.startObject(INFERENCE_CHUNKS_TEXT); + { + mappingBuilder.field("type", "text"); + mappingBuilder.field("index", false); + } + mappingBuilder.endObject(); + } + mappingBuilder.endObject(); + } mappingBuilder.endObject(); } @@ -477,12 +489,7 @@ private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLook BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); for (String token : tokens) { queryBuilder.add( - new BooleanClause( - new TermQuery( - new Term(path + "." + SPARSE_VECTOR_SUBFIELD_NAME + "." + SparseEmbeddingResults.Embedding.EMBEDDING, token) - ), - BooleanClause.Occur.MUST - ) + new BooleanClause(new TermQuery(new Term(path + "." + INFERENCE_CHUNKS_RESULTS, token)), BooleanClause.Occur.MUST) ); } queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); @@ -497,12 +504,7 @@ private static void assertValidChildDoc( ) { assertEquals(expectedParent, childDoc.getParent()); visitedChildDocs.add( - new VisitedChildDocInfo( - childDoc.getPath(), - childDoc.getFields( - childDoc.getPath() + "." + SPARSE_VECTOR_SUBFIELD_NAME + "." 
+ SparseEmbeddingResults.Embedding.EMBEDDING - ).size() - ) + new VisitedChildDocInfo(childDoc.getPath(), childDoc.getFields(childDoc.getPath() + "." + INFERENCE_CHUNKS_RESULTS).size()) ); } diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java index 933e696d29d83..a397d9864d23d 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java +++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -21,7 +21,7 @@ public class InferenceRestIT extends ESClientYamlSuiteTestCase { public static ElasticsearchCluster cluster = ElasticsearchCluster.local() .setting("xpack.security.enabled", "false") .setting("xpack.security.http.ssl.enabled", "false") - .plugin("org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin") + .plugin("inference-service-test") .distribution(DistributionType.DEFAULT) .build(); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 0e1b33252153b..ead7f904ad57b 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -6,7 +6,7 @@ setup: - do: inference.put_model: task_type: sparse_embedding - inference_id: test-inference-id + inference_id: sparse-inference-id body: > { "service": "test_service", @@ -17,27 +17,57 @@ setup: "task_settings": { } } + - do: + inference.put_model: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-sparse-index + body: + mappings: + properties: + inference_field: + type: semantic_text + model_id: sparse-inference-id + another_inference_field: + type: semantic_text + model_id: sparse-inference-id + non_inference_field: + type: text - do: indices.create: - index: test-index + index: test-dense-index body: mappings: properties: inference_field: type: semantic_text - model_id: test-inference-id + model_id: dense-inference-id another_inference_field: type: semantic_text - model_id: test-inference-id + model_id: dense-inference-id non_inference_field: type: text --- -"Calculates embeddings for new documents": +"Calculates text expansion results for new documents": - do: index: - index: test-index + index: test-sparse-index id: doc_1 body: inference_field: "inference test" @@ -46,24 +76,73 @@ setup: - do: get: - index: test-index + index: test-sparse-index id: doc_1 - match: { _source.inference_field: "inference test" } - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" } + - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } + - match: { 
_source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + + - exists: _source._semantic_text_inference.inference_field.inference_results.0.inference + - exists: _source._semantic_text_inference.another_inference_field.inference_results.0.inference + +--- +"text expansion documents do not create new mappings": + - do: + indices.get_mapping: + index: test-sparse-index + + - match: {test-sparse-index.mappings.properties.inference_field.type: semantic_text} + - match: {test-sparse-index.mappings.properties.another_inference_field.type: semantic_text} + - match: {test-sparse-index.mappings.properties.non_inference_field.type: text} + - length: {test-sparse-index.mappings.properties: 3} + +--- +"Calculates text embeddings results for new documents": + - do: + index: + index: test-dense-index + id: doc_1 + body: + inference_field: "inference test" + another_inference_field: "another inference test" + non_inference_field: "non inference test" + + - do: + get: + index: test-dense-index + id: doc_1 + + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "non inference test" } + + - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } - - exists: _source._semantic_text_inference.inference_field.0.sparse_embedding - - exists: _source._semantic_text_inference.another_inference_field.0.sparse_embedding + - exists: _source._semantic_text_inference.inference_field.inference_results.0.inference + - exists: _source._semantic_text_inference.another_inference_field.inference_results.0.inference + + +--- +"text embeddings documents do not create new mappings": + - do: + indices.get_mapping: + index: test-dense-index + + - match: {test-dense-index.mappings.properties.inference_field.type: semantic_text} + - match: {test-dense-index.mappings.properties.another_inference_field.type: semantic_text} + - match: {test-dense-index.mappings.properties.non_inference_field.type: text} + - length: {test-dense-index.mappings.properties: 3} --- "Updating non semantic_text fields does not recalculate embeddings": - do: index: - index: test-index + index: test-sparse-index id: doc_1 body: inference_field: "inference test" @@ -72,15 +151,15 @@ setup: - do: get: - index: test-index + index: test-sparse-index id: doc_1 - - set: { _source._semantic_text_inference.inference_field.0.sparse_embedding: inference_field_embedding } - - set: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: another_inference_field_embedding } + - set: { _source._semantic_text_inference.inference_field.inference_results.0.inference: inference_field_embedding } + - set: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: another_inference_field_embedding } - do: update: - index: test-index + index: test-sparse-index id: doc_1 body: doc: @@ -88,24 +167,24 @@ setup: - do: get: - index: test-index + index: test-sparse-index id: doc_1 - match: { _source.inference_field: "inference test" } - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "another non inference test" } - - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" } - - match: { 
_source._semantic_text_inference.another_inference_field.0.text: "another inference test" } + - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } - - match: { _source._semantic_text_inference.inference_field.0.sparse_embedding: $inference_field_embedding } - - match: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: $another_inference_field_embedding } + - match: { _source._semantic_text_inference.inference_field.inference_results.0.inference: $inference_field_embedding } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": - do: index: - index: test-index + index: test-sparse-index id: doc_1 body: inference_field: "inference test" @@ -114,12 +193,12 @@ setup: - do: get: - index: test-index + index: test-sparse-index id: doc_1 - do: update: - index: test-index + index: test-sparse-index id: doc_1 body: doc: @@ -128,22 +207,21 @@ setup: - do: get: - index: test-index + index: test-sparse-index id: doc_1 - match: { _source.inference_field: "updated inference test" } - match: { _source.another_inference_field: "another updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.0.text: "updated inference test" } - - match: { _source._semantic_text_inference.another_inference_field.0.text: "another updated inference test" } - + - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "updated inference test" } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another updated inference test" } --- "Reindex works for semantic_text fields": - do: index: - index: test-index + index: test-sparse-index id: doc_1 body: inference_field: "inference test" @@ -152,11 +230,11 @@ setup: - do: get: - index: test-index + index: test-sparse-index id: doc_1 - - set: { _source._semantic_text_inference.inference_field.0.sparse_embedding: inference_field_embedding } - - set: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: another_inference_field_embedding } + - set: { _source._semantic_text_inference.inference_field.inference_results.0.inference: inference_field_embedding } + - set: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: another_inference_field_embedding } - do: indices.refresh: { } @@ -169,10 +247,10 @@ setup: properties: inference_field: type: semantic_text - model_id: test-inference-id + model_id: sparse-inference-id another_inference_field: type: semantic_text - model_id: test-inference-id + model_id: sparse-inference-id non_inference_field: type: text @@ -181,7 +259,7 @@ setup: wait_for_completion: true body: source: - index: test-index + index: test-sparse-index dest: index: destination-index - do: @@ -193,17 +271,17 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" } + - match: { 
_source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } - - match: { _source._semantic_text_inference.inference_field.0.sparse_embedding: $inference_field_embedding } - - match: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: $another_inference_field_embedding } + - match: { _source._semantic_text_inference.inference_field.inference_results.0.inference: $inference_field_embedding } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: $another_inference_field_embedding } --- "Fails for non-existent model": - do: indices.create: - index: incorrect-test-index + index: incorrect-test-sparse-index body: mappings: properties: @@ -216,7 +294,7 @@ setup: - do: catch: bad_request index: - index: incorrect-test-index + index: incorrect-test-sparse-index id: doc_1 body: inference_field: "inference test" @@ -227,7 +305,7 @@ setup: # Succeeds when semantic_text field is not used - do: index: - index: incorrect-test-index + index: incorrect-test-sparse-index id: doc_1 body: non_inference_field: "non inference test" diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml new file mode 100644 index 0000000000000..da61e6e403ed8 --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -0,0 +1,153 @@ +setup: + - skip: + version: " - 8.12.99" + reason: semantic_text introduced in 8.13.0 # TODO change when 8.13.0 is released + + - do: + inference.put_model: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: + inference.put_model: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-index + body: + mappings: + properties: + sparse_field: + type: semantic_text + model_id: sparse-inference-id + dense_field: + type: semantic_text + model_id: dense-inference-id + non_inference_field: + type: text + +--- +"Sparse vector results format": + - do: + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _semantic_text_inference: + sparse_field: + model_settings: + inference_id: sparse-inference-id + task_type: sparse_embedding + inference_results: + - text: "inference test" + inference: + feature_1: 0.1 + feature_2: 0.2 + feature_3: 0.3 + feature_4: 0.4 + - text: "another inference test" + inference: + feature_1: 0.1 + feature_2: 0.2 + feature_3: 0.3 + feature_4: 0.4 + +--- +"Dense vector results format": + - do: + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _semantic_text_inference: + dense_field: + model_settings: + inference_id: sparse-inference-id + task_type: text_embedding + dimensions: 5 + similarity: cosine + inference_results: + - text: "inference test" + inference: [0.1, 0.2, 0.3, 0.4, 0.5] + - text: "another inference test" + 
inference: [-0.1, -0.2, -0.3, -0.4, -0.5] + +--- +"Model settings inference id not included": + - do: + catch: /Required \[inference_id\]/ + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _semantic_text_inference: + sparse_field: + model_settings: + task_type: sparse_embedding + inference_results: + - text: "inference test" + inference: + feature_1: 0.1 + +--- +"Model settings task type not included": + - do: + catch: /Required \[task_type\]/ + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _semantic_text_inference: + sparse_field: + model_settings: + inference_id: sparse-inference-id + inference_results: + - text: "inference test" + inference: + feature_1: 0.1 + +--- +"Model settings dense vector dimensions not included": + - do: + catch: /Model settings for field \[dense_field\] must contain dimensions/ + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _semantic_text_inference: + dense_field: + model_settings: + inference_id: sparse-inference-id + task_type: text_embedding + inference_results: + - text: "inference test" + inference: [0.1, 0.2, 0.3, 0.4, 0.5] + - text: "another inference test" + inference: [-0.1, -0.2, -0.3, -0.4, -0.5] diff --git a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java index 4d90d2a186858..d8c9dc2efd927 100644 --- a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java @@ -59,10 +59,10 @@ public void testIngestWithMultipleModelTypes() throws IOException { assertThat(simulatedDocs, hasSize(2)); assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(0))); var sparseEmbedding = (Map) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(0)); - assertEquals(Double.valueOf(1.0), sparseEmbedding.get("1")); + assertEquals(Double.valueOf(2.0), sparseEmbedding.get("feature_1")); assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(1))); sparseEmbedding = (Map) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(1)); - assertEquals(Double.valueOf(1.0), sparseEmbedding.get("1")); + assertEquals(Double.valueOf(2.0), sparseEmbedding.get("feature_1")); } { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index 675d062fdb3af..f6fa7ca9005c5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -18,6 +18,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; import 
org.elasticsearch.xcontent.ParseField; @@ -67,12 +68,7 @@ public String getTypeName() { } public static boolean isFieldTypeAllowed(String typeName) { - for (AllowedFieldType fieldType : values()) { - if (fieldType.getTypeName().equals(typeName)) { - return true; - } - } - return false; + return Arrays.stream(values()).anyMatch(value -> value.typeName.equals(typeName)); } public static String getAllowedFieldTypesAsString() { @@ -168,8 +164,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep } @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { if (weightedTokensSupplier != null) { if (weightedTokensSupplier.get() == null) { return this; @@ -188,8 +183,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH); SetOnce textExpansionResultsSupplier = new SetOnce<>(); - queryRewriteContext.registerAsyncAction((client, listener) -> { - executeAsyncWithOrigin( + queryRewriteContext.registerAsyncAction( + (client, listener) -> executeAsyncWithOrigin( client, ML_ORIGIN, CoordinatedInferenceAction.INSTANCE, @@ -220,21 +215,34 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws ); } }, listener::onFailure) - ); - }); + ) + ); return new TextExpansionQueryBuilder(this, textExpansionResultsSupplier); } private QueryBuilder weightedTokensToQuery(String fieldName, TextExpansionResults textExpansionResults) { - WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder( - fieldName, - textExpansionResults.getWeightedTokens(), - tokenPruningConfig - ); - weightedTokensQueryBuilder.queryName(queryName); - weightedTokensQueryBuilder.boost(boost); - return weightedTokensQueryBuilder; + if (tokenPruningConfig != null) { + WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder( + fieldName, + textExpansionResults.getWeightedTokens(), + tokenPruningConfig + ); + weightedTokensQueryBuilder.queryName(queryName); + weightedTokensQueryBuilder.boost(boost); + return weightedTokensQueryBuilder; + } + // Note: Weighted tokens queries were introduced in 8.13.0. To support mixed version clusters prior to 8.13.0, + // if no token pruning configuration is specified we fall back to a boolean query. + // TODO this should be updated to always use a WeightedTokensQueryBuilder once it's in all supported versions. 
+ var boolQuery = QueryBuilders.boolQuery(); + for (var weightedToken : textExpansionResults.getWeightedTokens()) { + boolQuery.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight())); + } + boolQuery.minimumShouldMatch(1); + boolQuery.boost(boost); + boolQuery.queryName(queryName); + return boolQuery; } @Override diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index 50561d92f5d37..13f12f3cdc1e1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.plugins.Plugin; @@ -259,6 +260,10 @@ public void testThatTokensAreCorrectlyPruned() { SearchExecutionContext searchExecutionContext = createSearchExecutionContext(); TextExpansionQueryBuilder queryBuilder = createTestQueryBuilder(); QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, searchExecutionContext); - assertTrue(rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder); + if (queryBuilder.getTokenPruningConfig() == null) { + assertTrue(rewrittenQueryBuilder instanceof BoolQueryBuilder); + } else { + assertTrue(rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder); + } } } diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/FrozenSearchableSnapshotsIntegTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/FrozenSearchableSnapshotsIntegTests.java index 18b4e6ed7cb31..4b9e1b0d9211e 100644 --- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/FrozenSearchableSnapshotsIntegTests.java +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/FrozenSearchableSnapshotsIntegTests.java @@ -102,7 +102,10 @@ public void testCreateAndRestorePartialSearchableSnapshot() throws Exception { // we can bypass this by forcing soft deletes to be used. TODO this restriction can be lifted when #55142 is resolved. final Settings.Builder originalIndexSettings = Settings.builder().put(INDEX_SOFT_DELETES_SETTING.getKey(), true); if (randomBoolean()) { - originalIndexSettings.put(IndexSettings.INDEX_CHECK_ON_STARTUP.getKey(), randomFrom("false", "true", "checksum")); + // INDEX_CHECK_ON_STARTUP requires expensive processing due to verifying the integrity of many important files during + // a shard recovery or relocation. Therefore, it can take a long time for the files to be cleaned up, and the assertShardFolder + // check may not complete in 30s.
+ originalIndexSettings.put(IndexSettings.INDEX_CHECK_ON_STARTUP.getKey(), "false"); } assertAcked(prepareCreate(indexName, originalIndexSettings)); assertAcked(indicesAdmin().prepareAliases().addAlias(indexName, aliasName)); diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsIntegTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsIntegTests.java index 38222f64b282b..ddd9f40b5404c 100644 --- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsIntegTests.java +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsIntegTests.java @@ -111,7 +111,7 @@ public void testCreateAndRestoreSearchableSnapshot() throws Exception { // we can bypass this by forcing soft deletes to be used. TODO this restriction can be lifted when #55142 is resolved. final Settings.Builder originalIndexSettings = Settings.builder().put(INDEX_SOFT_DELETES_SETTING.getKey(), true); if (randomBoolean()) { - originalIndexSettings.put(IndexSettings.INDEX_CHECK_ON_STARTUP.getKey(), randomFrom("false", "true", "checksum")); + originalIndexSettings.put(IndexSettings.INDEX_CHECK_ON_STARTUP.getKey(), "false"); } assertAcked(prepareCreate(indexName, originalIndexSettings)); assertAcked(indicesAdmin().prepareAliases().addAlias(indexName, aliasName)); diff --git a/x-pack/plugin/slm/src/test/java/org/elasticsearch/xpack/slm/TransportSLMGetExpiredSnapshotsActionTests.java b/x-pack/plugin/slm/src/test/java/org/elasticsearch/xpack/slm/TransportSLMGetExpiredSnapshotsActionTests.java index 573edc6e517bf..e6d7a66a2bdb3 100644 --- a/x-pack/plugin/slm/src/test/java/org/elasticsearch/xpack/slm/TransportSLMGetExpiredSnapshotsActionTests.java +++ b/x-pack/plugin/slm/src/test/java/org/elasticsearch/xpack/slm/TransportSLMGetExpiredSnapshotsActionTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.cluster.metadata.RepositoryMetadata; import org.elasticsearch.common.settings.Settings; @@ -286,7 +287,7 @@ private static Repository createMockRepository(ThreadPool threadPool, List consumer = invocation.getArgument(3); final ActionListener listener = invocation.getArgument(4); - final Set snapshotIds = new HashSet<>(snapshotIdCollection); - for (SnapshotInfo snapshotInfo : snapshotInfos) { - if (snapshotIds.remove(snapshotInfo.snapshotId())) { - threadPool.generic().execute(() -> { - try { - consumer.accept(snapshotInfo); - } catch (Exception e) { - fail(e); - } - }); + try (var refs = new RefCountingRunnable(() -> listener.onResponse(null))) { + final Set snapshotIds = new HashSet<>(snapshotIdCollection); + for (SnapshotInfo snapshotInfo : snapshotInfos) { + if (snapshotIds.remove(snapshotInfo.snapshotId())) { + threadPool.generic().execute(ActionRunnable.run(refs.acquireListener(), () -> { + try { + consumer.accept(snapshotInfo); + } catch (Exception e) { + fail(e); + } + })); + } } } - listener.onResponse(null); return null; }).when(repository).getSnapshotInfo(any(), anyBoolean(), any(), any(), any()); diff --git 
a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/120_profile.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/120_profile.yml index 81d87435ad39e..c2e728535a408 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/120_profile.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/120_profile.yml @@ -121,11 +121,10 @@ setup: --- avg 8.14 or after: - skip: - features: ["node_selector"] + version: " - 8.13.99" + reason: "avg changed starting 8.14" - do: - node_selector: - version: "8.13.99 - " esql.query: body: query: 'FROM test | STATS AVG(data) | LIMIT 1' diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml index dc4e1751ccdee..f92870b61f1b1 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml @@ -304,3 +304,4 @@ setup: source_text: model_id: text_expansion_model model_text: "octopus comforter smells" + pruning_config: {}
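
For reference, the prerequisite layout exercised by the parser tests above can be summarised with a minimal sketch. This is illustrative only: the test name, feature names, and reason strings below are placeholders rather than values taken from this change, and the shape simply mirrors what ClientYamlTestSuiteTests and PrerequisiteSectionTests parse:

  "Example test with prerequisites":
    - skip:
        cluster_features: [unwanted-feature]
        reason: "skipped while unwanted-feature is present"
    - requires:
        cluster_features: [needed-feature]
        reason: "needed-feature must be available"
    - do:
        indices.get_mapping:
          index: test_index

A cluster feature may appear under either requires or skip, but listing the same feature in both is rejected by the parser ("a cluster feature can be specified either in [requires] or [skip], not both"), and runner capabilities such as warnings, headers, or node_selector are now declared via ["requires": "test_runner_features": ...] rather than the old ["skip": "features": ...] form.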