From da4a5506b38008462526dd45c9873b55c26812ac Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 9 Feb 2024 12:39:25 -0500 Subject: [PATCH 01/40] Added skeleton code for SemanticQueryBuilder --- .../queries/SemanticQueryBuilder.java | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java new file mode 100644 index 0000000000000..6808075ac69cb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.apache.lucene.search.Query; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class SemanticQueryBuilder extends AbstractQueryBuilder { + public static final String NAME = "semantic_query"; + + private static final ParseField QUERY_FIELD = new ParseField("query"); + + private final String fieldName; + private final String query; + + public SemanticQueryBuilder(String fieldName, String query) { + if (fieldName == null) { + throw new IllegalArgumentException("[" + NAME + "] requires a fieldName"); + } + if (query == null) { + throw new IllegalArgumentException("[" + NAME + "] requires a " + QUERY_FIELD.getPreferredName() + " value"); + } + this.fieldName = fieldName; + this.query = query; + } + + public SemanticQueryBuilder(StreamInput in) throws IOException { + super(in); + this.fieldName = in.readString(); + this.query = in.readString(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.SEMANTIC_TEXT_FIELD_ADDED; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(fieldName); + out.writeString(query); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + builder.startObject(fieldName); + builder.field(QUERY_FIELD.getPreferredName(), query); + builder.endObject(); + builder.endObject(); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + // TODO: Implement + return null; + } + + @Override + protected boolean doEquals(SemanticQueryBuilder other) { + return Objects.equals(fieldName, other.fieldName) && Objects.equals(query, other.query); + } + + @Override + protected int doHashCode() { + return Objects.hash(fieldName, query); + } +} From 330f867a6901351d27a33c26ca8730395a7efeb7 Mon Sep 17 00:00:00 2001 From: 
Mike Pellegrini Date: Fri, 9 Feb 2024 12:41:33 -0500 Subject: [PATCH 02/40] Add boost and query name to XContent --- .../xpack/inference/queries/SemanticQueryBuilder.java | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 6808075ac69cb..c44a7125a6fd7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -66,6 +66,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep builder.startObject(NAME); builder.startObject(fieldName); builder.field(QUERY_FIELD.getPreferredName(), query); + boostAndQueryNameToXContent(builder); builder.endObject(); builder.endObject(); } From ed2abf7932505eeab81c82658e4a231810edc180 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 14 Feb 2024 18:17:55 +0100 Subject: [PATCH 03/40] Add dimensions and similarity to ServiceSettings, create ModelSettings class --- .../inference/ModelSettings.java | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 server/src/main/java/org/elasticsearch/inference/ModelSettings.java diff --git a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java new file mode 100644 index 0000000000000..f11f5e67f80fe --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.inference; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +public record ModelSettings(TaskType taskType, String inferenceId, @Nullable Integer dimensions, @Nullable SimilarityMeasure similarity) { + + public static final String NAME = "model_settings"; + private static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); + private static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); + private static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); + private static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); + + public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { + Objects.requireNonNull(taskType, "task type must not be null"); + Objects.requireNonNull(inferenceId, "inferenceId must not be null"); + this.taskType = taskType; + this.inferenceId = inferenceId; + this.dimensions = dimensions; + this.similarity = similarity; + } + + public ModelSettings(Model model) { + this( + model.getTaskType(), + model.getInferenceEntityId(), + model.getServiceSettings().dimensions(), + model.getServiceSettings().similarity() + ); + } + + public static ModelSettings parse(XContentParser parser) throws IOException { + return PARSER.apply(parser, null); + } + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { + TaskType taskType = TaskType.fromString((String) args[0]); + String inferenceId = (String) args[1]; + Integer dimensions = (Integer) args[2]; + SimilarityMeasure similarity = args[3] == null ? 
null : SimilarityMeasure.fromString((String) args[2]); + return new ModelSettings(taskType, inferenceId, dimensions, similarity); + }); + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); + } + + public Map asMap() { + Map attrsMap = new HashMap<>(); + attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); + attrsMap.put(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); + if (dimensions != null) { + attrsMap.put(DIMENSIONS_FIELD.getPreferredName(), dimensions); + } + if (similarity != null) { + attrsMap.put(SIMILARITY_FIELD.getPreferredName(), similarity); + } + return Map.of(NAME, attrsMap); + } +} From 1ee76531d29e3fae4b374bc921d1d023b2fdebc7 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 14 Feb 2024 18:19:05 +0100 Subject: [PATCH 04/40] Change implementation of asMap() to avoid extra nesting in inference results --- .../core/inference/results/LegacyTextEmbeddingResults.java | 4 ++-- .../xpack/core/inference/results/TextEmbeddingResults.java | 6 +----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java index 72a24fd916763..31394c31df09e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java @@ -80,7 +80,7 @@ public String getResultsField() { @Override public Map asMap() { Map map = new LinkedHashMap<>(); - map.put(getResultsField(), embeddings.stream().map(Embedding::asMap).collect(Collectors.toList())); + map.put(getResultsField(), embeddings.stream().flatMap(v -> v.values.stream()).collect(Collectors.toList())); return map; } @@ -88,7 +88,7 @@ public Map asMap() { @Override public Map asMap(String outputField) { Map map = new LinkedHashMap<>(); - map.put(outputField, embeddings.stream().map(Embedding::asMap).collect(Collectors.toList())); + map.put(outputField, embeddings.stream().flatMap(v -> v.values.stream()).collect(Collectors.toList())); return map; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java index 15271c1da58fa..24e53ef822a8c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -129,10 +128,7 @@ public List transformToLegacyFormat() { } public Map asMap() { - Map map = new LinkedHashMap<>(); - map.put(TEXT_EMBEDDING, embeddings.stream().map(Embedding::asMap).collect(Collectors.toList())); - - return map; + return Map.of(TEXT_EMBEDDING, embeddings.stream().flatMap(v -> 
v.values.stream()).collect(Collectors.toList())); } public record Embedding(List values) implements Writeable, ToXContentObject, EmbeddingInt { From 80c01c7c34e1ec7e9d82e42e38760775fba1e9e0 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 14 Feb 2024 19:11:56 +0100 Subject: [PATCH 05/40] Fix BulkOperationTests --- .../java/org/elasticsearch/inference/ModelSettings.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java index f11f5e67f80fe..154d4d34ba74d 100644 --- a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java @@ -21,10 +21,10 @@ public record ModelSettings(TaskType taskType, String inferenceId, @Nullable Integer dimensions, @Nullable SimilarityMeasure similarity) { public static final String NAME = "model_settings"; - private static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); - private static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); - private static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); - private static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); + public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); + public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); + public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); + public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { Objects.requireNonNull(taskType, "task type must not be null"); From a4c1184618a9b9e3d629a1dfc626596cc074ff5a Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 14 Feb 2024 20:32:31 +0100 Subject: [PATCH 06/40] Fix tests of SemanticTextInferenceResultFieldMapper --- .../inference/ModelSettings.java | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java index 154d4d34ba74d..654a65cd2f489 100644 --- a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java @@ -8,7 +8,6 @@ package org.elasticsearch.inference; -import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; @@ -18,13 +17,17 @@ import java.util.Map; import java.util.Objects; -public record ModelSettings(TaskType taskType, String inferenceId, @Nullable Integer dimensions, @Nullable SimilarityMeasure similarity) { +public class ModelSettings { public static final String NAME = "model_settings"; public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); + private final TaskType taskType; + private final String inferenceId; + private final Integer dimensions; + private final SimilarityMeasure similarity; public ModelSettings(TaskType taskType, String 
inferenceId, Integer dimensions, SimilarityMeasure similarity) { Objects.requireNonNull(taskType, "task type must not be null"); @@ -74,4 +77,20 @@ public Map asMap() { } return Map.of(NAME, attrsMap); } + + public TaskType taskType() { + return taskType; + } + + public String inferenceId() { + return inferenceId; + } + + public Integer dimensions() { + return dimensions; + } + + public SimilarityMeasure similarity() { + return similarity; + } } From f28e0512943c70aa25ba3d7f99a88cee6555a129 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 15 Feb 2024 10:23:23 +0100 Subject: [PATCH 07/40] Revert "Change implementation of asMap() to avoid extra nesting in inference results" This reverts commit bd4e19e13f3ad6d54d9a2c38d34562a91b47ac8d. --- .../core/inference/results/LegacyTextEmbeddingResults.java | 4 ++-- .../xpack/core/inference/results/TextEmbeddingResults.java | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java index 31394c31df09e..72a24fd916763 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java @@ -80,7 +80,7 @@ public String getResultsField() { @Override public Map asMap() { Map map = new LinkedHashMap<>(); - map.put(getResultsField(), embeddings.stream().flatMap(v -> v.values.stream()).collect(Collectors.toList())); + map.put(getResultsField(), embeddings.stream().map(Embedding::asMap).collect(Collectors.toList())); return map; } @@ -88,7 +88,7 @@ public Map asMap() { @Override public Map asMap(String outputField) { Map map = new LinkedHashMap<>(); - map.put(outputField, embeddings.stream().flatMap(v -> v.values.stream()).collect(Collectors.toList())); + map.put(outputField, embeddings.stream().map(Embedding::asMap).collect(Collectors.toList())); return map; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java index 24e53ef822a8c..15271c1da58fa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -128,7 +129,10 @@ public List transformToLegacyFormat() { } public Map asMap() { - return Map.of(TEXT_EMBEDDING, embeddings.stream().flatMap(v -> v.values.stream()).collect(Collectors.toList())); + Map map = new LinkedHashMap<>(); + map.put(TEXT_EMBEDDING, embeddings.stream().map(Embedding::asMap).collect(Collectors.toList())); + + return map; } public record Embedding(List values) implements Writeable, ToXContentObject, EmbeddingInt { From 0d515a492eb9b651cb6f6eb5fd0e15124993a36e Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 15 Feb 2024 12:29:05 +0100 Subject: [PATCH 08/40] Add service extension for dense vector embeddings --- ...nferenceTextEmbeddingServiceExtension.java | 392 ++++++++++++++++++ 1 file changed, 
392 insertions(+) create mode 100644 x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceTextEmbeddingServiceExtension.java diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceTextEmbeddingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceTextEmbeddingServiceExtension.java new file mode 100644 index 0000000000000..e7fe171842946 --- /dev/null +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceTextEmbeddingServiceExtension.java @@ -0,0 +1,392 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mock; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class TestInferenceTextEmbeddingServiceExtension implements InferenceServiceExtension { + @Override + public List getInferenceServiceFactories() { + return List.of(TestInferenceService::new); + } + + public static class TestInferenceService implements InferenceService { + private static final String NAME = "text_embedding_test_service"; + + public TestInferenceService(InferenceServiceFactoryContext context) {} + + @Override + public String name() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @SuppressWarnings("unchecked") + private static Map getTaskSettingsMap(Map settings) { + Map taskSettingsMap; + // task settings are optional + if (settings.containsKey(ModelConfigurations.TASK_SETTINGS)) { + taskSettingsMap = (Map) settings.remove(ModelConfigurations.TASK_SETTINGS); + } else { + taskSettingsMap = Map.of(); + 
} + + return taskSettingsMap; + } + + @Override + @SuppressWarnings("unchecked") + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + Set platformArchitectures, + ActionListener parsedModelListener + ) { + var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); + var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); + var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap); + + var taskSettingsMap = getTaskSettingsMap(config); + var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); + + parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings)); + } + + @Override + @SuppressWarnings("unchecked") + public TestServiceModel parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); + var secretSettingsMap = (Map) secrets.remove(ModelSecrets.SECRET_SETTINGS); + + var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); + var secretSettings = TestSecretSettings.fromMap(secretSettingsMap); + + var taskSettingsMap = getTaskSettingsMap(config); + var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); + + return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings); + } + + @Override + @SuppressWarnings("unchecked") + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); + + var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); + + var taskSettingsMap = getTaskSettingsMap(config); + var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); + + return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, null); + } + + @Override + public void infer( + Model model, + List input, + Map taskSettings, + InputType inputType, + ActionListener listener + ) { + switch (model.getConfigurations().getTaskType()) { + case ANY, TEXT_EMBEDDING -> listener.onResponse( + makeResults(input, ((TestServiceModel) model).getServiceSettings().dimensions()) + ); + default -> listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + } + + @Override + public void chunkedInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + ActionListener> listener + ) { + switch (model.getConfigurations().getTaskType()) { + case ANY, TEXT_EMBEDDING -> listener.onResponse( + makeChunkedResults(input, ((TestServiceModel) model).getServiceSettings().dimensions()) + ); + default -> listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + } + + private TextEmbeddingResults makeResults(List input, int dimensions) { + List embeddings = new ArrayList<>(); + for (int i = 0; i < input.size(); i++) { + List values = new ArrayList<>(); + for (int j = 0; j < dimensions; j++) { + values.add((float) j); + } + embeddings.add(new TextEmbeddingResults.Embedding(values)); + } + return new TextEmbeddingResults(embeddings); + } + + private List makeChunkedResults(List input, int dimensions) { + var results = new ArrayList(); + for (int 
i = 0; i < input.size(); i++) { + double[] values = new double[dimensions]; + for (int j = 0; j < 5; j++) { + values[j] = j; + } + results.add( + new org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults( + List.of(new ChunkedTextEmbeddingResults.EmbeddingChunk(input.get(i), values)) + ) + ); + } + return results; + } + + @Override + public void start(Model model, ActionListener listener) { + listener.onResponse(true); + } + + @Override + public void close() throws IOException {} + } + + public static class TestServiceModel extends Model { + + public TestServiceModel( + String modelId, + TaskType taskType, + String service, + TestServiceSettings serviceSettings, + TestTaskSettings taskSettings, + TestSecretSettings secretSettings + ) { + super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings)); + } + + @Override + public TestServiceSettings getServiceSettings() { + return (TestServiceSettings) super.getServiceSettings(); + } + + @Override + public TestTaskSettings getTaskSettings() { + return (TestTaskSettings) super.getTaskSettings(); + } + + @Override + public TestSecretSettings getSecretSettings() { + return (TestSecretSettings) super.getSecretSettings(); + } + } + + public record TestServiceSettings(String model, Integer dimensions, SimilarityMeasure similarity) implements ServiceSettings { + + static final String NAME = "test_text_embedding_service_settings"; + + public static TestServiceSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String model = (String) map.remove("model"); + if (model == null) { + validationException.addValidationError("missing model"); + } + + Integer dimensions = (Integer) map.remove("dimensions"); + if (dimensions == null) { + validationException.addValidationError("missing dimensions"); + } + + SimilarityMeasure similarity = null; + String similarityStr = (String) map.remove("similarity"); + if (similarityStr != null) { + similarity = SimilarityMeasure.valueOf(similarityStr); + } + + return new TestServiceSettings(model, dimensions, similarity); + } + + public TestServiceSettings(StreamInput in) throws IOException { + this(in.readString(), in.readOptionalInt(), in.readOptionalEnum(SimilarityMeasure.class)); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("model", model); + builder.field("dimensions", dimensions); + if (similarity != null) { + builder.field("similarity", similarity); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(model); + out.writeInt(dimensions); + out.writeOptionalEnum(similarity); + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return (builder, params) -> { + builder.startObject(); + builder.field("model", model); + builder.field("dimensions", dimensions); + if (similarity != null) { + builder.field("similarity", similarity); + } + builder.endObject(); + return builder; + }; + } + } + + public record TestTaskSettings() implements TaskSettings { + + static final String NAME = "test_text_embedding_task_settings"; + + public 
static TestTaskSettings fromMap(Map map) { + return new TestTaskSettings(); + } + + public TestTaskSettings(StreamInput in) throws IOException { + this(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + } + + public record TestSecretSettings(String apiKey) implements SecretSettings { + + static final String NAME = "test_text_embedding_secret_settings"; + + public static TestSecretSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String apiKey = (String) map.remove("api_key"); + + if (apiKey == null) { + validationException.addValidationError("missing api_key"); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new TestSecretSettings(apiKey); + } + + public TestSecretSettings(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(apiKey); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("api_key", apiKey); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + } +} From d53583e2b3be7cf608c958146c56de023b851c50 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 15 Feb 2024 13:09:45 +0100 Subject: [PATCH 09/40] Fix bug in model settings --- .../main/java/org/elasticsearch/inference/ModelSettings.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java index 654a65cd2f489..10466114873bd 100644 --- a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java @@ -55,7 +55,7 @@ public static ModelSettings parse(XContentParser parser) throws IOException { TaskType taskType = TaskType.fromString((String) args[0]); String inferenceId = (String) args[1]; Integer dimensions = (Integer) args[2]; - SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[2]); + SimilarityMeasure similarity = args[3] == null ? 
null : SimilarityMeasure.fromString((String) args[3]); return new ModelSettings(taskType, inferenceId, dimensions, similarity); }); static { From 1edab71fc9e9434dab098e35b9ae571d896c0491 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 15 Feb 2024 13:29:54 +0100 Subject: [PATCH 10/40] Fix spotless --- .../mock/TestInferenceTextEmbeddingServiceExtension.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceTextEmbeddingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceTextEmbeddingServiceExtension.java index e7fe171842946..efdc7f80d7175 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceTextEmbeddingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceTextEmbeddingServiceExtension.java @@ -321,8 +321,7 @@ public TestTaskSettings(StreamInput in) throws IOException { } @Override - public void writeTo(StreamOutput out) throws IOException { - } + public void writeTo(StreamOutput out) throws IOException {} @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { From 3c14e12f17e5aa559b97ab811221799e78b244f7 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 15 Feb 2024 13:44:35 +0100 Subject: [PATCH 11/40] Refactored inference services with common abstract class --- ...nferenceTextEmbeddingServiceExtension.java | 391 ------------------ 1 file changed, 391 deletions(-) delete mode 100644 x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceTextEmbeddingServiceExtension.java diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceTextEmbeddingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceTextEmbeddingServiceExtension.java deleted file mode 100644 index efdc7f80d7175..0000000000000 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceTextEmbeddingServiceExtension.java +++ /dev/null @@ -1,391 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.mock; - -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.ChunkingOptions; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceExtension; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelSecrets; -import org.elasticsearch.inference.SecretSettings; -import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskSettings; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Set; - -public class TestInferenceTextEmbeddingServiceExtension implements InferenceServiceExtension { - @Override - public List getInferenceServiceFactories() { - return List.of(TestInferenceService::new); - } - - public static class TestInferenceService implements InferenceService { - private static final String NAME = "text_embedding_test_service"; - - public TestInferenceService(InferenceServiceFactoryContext context) {} - - @Override - public String name() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests - } - - @SuppressWarnings("unchecked") - private static Map getTaskSettingsMap(Map settings) { - Map taskSettingsMap; - // task settings are optional - if (settings.containsKey(ModelConfigurations.TASK_SETTINGS)) { - taskSettingsMap = (Map) settings.remove(ModelConfigurations.TASK_SETTINGS); - } else { - taskSettingsMap = Map.of(); - } - - return taskSettingsMap; - } - - @Override - @SuppressWarnings("unchecked") - public void parseRequestConfig( - String modelId, - TaskType taskType, - Map config, - Set platformArchitectures, - ActionListener parsedModelListener - ) { - var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); - var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); - var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap); - - var taskSettingsMap = getTaskSettingsMap(config); - var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); - - parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings)); - } - - @Override - @SuppressWarnings("unchecked") - public TestServiceModel parsePersistedConfigWithSecrets( - String modelId, - TaskType taskType, - Map config, - Map secrets - ) { - var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); - 
var secretSettingsMap = (Map) secrets.remove(ModelSecrets.SECRET_SETTINGS); - - var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); - var secretSettings = TestSecretSettings.fromMap(secretSettingsMap); - - var taskSettingsMap = getTaskSettingsMap(config); - var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); - - return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings); - } - - @Override - @SuppressWarnings("unchecked") - public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { - var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); - - var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); - - var taskSettingsMap = getTaskSettingsMap(config); - var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); - - return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, null); - } - - @Override - public void infer( - Model model, - List input, - Map taskSettings, - InputType inputType, - ActionListener listener - ) { - switch (model.getConfigurations().getTaskType()) { - case ANY, TEXT_EMBEDDING -> listener.onResponse( - makeResults(input, ((TestServiceModel) model).getServiceSettings().dimensions()) - ); - default -> listener.onFailure( - new ElasticsearchStatusException( - TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), - RestStatus.BAD_REQUEST - ) - ); - } - } - - @Override - public void chunkedInfer( - Model model, - List input, - Map taskSettings, - InputType inputType, - ChunkingOptions chunkingOptions, - ActionListener> listener - ) { - switch (model.getConfigurations().getTaskType()) { - case ANY, TEXT_EMBEDDING -> listener.onResponse( - makeChunkedResults(input, ((TestServiceModel) model).getServiceSettings().dimensions()) - ); - default -> listener.onFailure( - new ElasticsearchStatusException( - TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), - RestStatus.BAD_REQUEST - ) - ); - } - } - - private TextEmbeddingResults makeResults(List input, int dimensions) { - List embeddings = new ArrayList<>(); - for (int i = 0; i < input.size(); i++) { - List values = new ArrayList<>(); - for (int j = 0; j < dimensions; j++) { - values.add((float) j); - } - embeddings.add(new TextEmbeddingResults.Embedding(values)); - } - return new TextEmbeddingResults(embeddings); - } - - private List makeChunkedResults(List input, int dimensions) { - var results = new ArrayList(); - for (int i = 0; i < input.size(); i++) { - double[] values = new double[dimensions]; - for (int j = 0; j < 5; j++) { - values[j] = j; - } - results.add( - new org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults( - List.of(new ChunkedTextEmbeddingResults.EmbeddingChunk(input.get(i), values)) - ) - ); - } - return results; - } - - @Override - public void start(Model model, ActionListener listener) { - listener.onResponse(true); - } - - @Override - public void close() throws IOException {} - } - - public static class TestServiceModel extends Model { - - public TestServiceModel( - String modelId, - TaskType taskType, - String service, - TestServiceSettings serviceSettings, - TestTaskSettings taskSettings, - TestSecretSettings secretSettings - ) { - super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings)); - } - - @Override - public TestServiceSettings getServiceSettings() { - return 
(TestServiceSettings) super.getServiceSettings(); - } - - @Override - public TestTaskSettings getTaskSettings() { - return (TestTaskSettings) super.getTaskSettings(); - } - - @Override - public TestSecretSettings getSecretSettings() { - return (TestSecretSettings) super.getSecretSettings(); - } - } - - public record TestServiceSettings(String model, Integer dimensions, SimilarityMeasure similarity) implements ServiceSettings { - - static final String NAME = "test_text_embedding_service_settings"; - - public static TestServiceSettings fromMap(Map map) { - ValidationException validationException = new ValidationException(); - - String model = (String) map.remove("model"); - if (model == null) { - validationException.addValidationError("missing model"); - } - - Integer dimensions = (Integer) map.remove("dimensions"); - if (dimensions == null) { - validationException.addValidationError("missing dimensions"); - } - - SimilarityMeasure similarity = null; - String similarityStr = (String) map.remove("similarity"); - if (similarityStr != null) { - similarity = SimilarityMeasure.valueOf(similarityStr); - } - - return new TestServiceSettings(model, dimensions, similarity); - } - - public TestServiceSettings(StreamInput in) throws IOException { - this(in.readString(), in.readOptionalInt(), in.readOptionalEnum(SimilarityMeasure.class)); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("model", model); - builder.field("dimensions", dimensions); - if (similarity != null) { - builder.field("similarity", similarity); - } - builder.endObject(); - return builder; - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(model); - out.writeInt(dimensions); - out.writeOptionalEnum(similarity); - } - - @Override - public ToXContentObject getFilteredXContentObject() { - return (builder, params) -> { - builder.startObject(); - builder.field("model", model); - builder.field("dimensions", dimensions); - if (similarity != null) { - builder.field("similarity", similarity); - } - builder.endObject(); - return builder; - }; - } - } - - public record TestTaskSettings() implements TaskSettings { - - static final String NAME = "test_text_embedding_task_settings"; - - public static TestTaskSettings fromMap(Map map) { - return new TestTaskSettings(); - } - - public TestTaskSettings(StreamInput in) throws IOException { - this(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException {} - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.endObject(); - return builder; - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests - } - } - - public record TestSecretSettings(String apiKey) implements SecretSettings { - - static final String NAME = "test_text_embedding_secret_settings"; - - public static TestSecretSettings fromMap(Map map) { - ValidationException validationException = new ValidationException(); - - String apiKey = (String) 
map.remove("api_key"); - - if (apiKey == null) { - validationException.addValidationError("missing api_key"); - } - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new TestSecretSettings(apiKey); - } - - public TestSecretSettings(StreamInput in) throws IOException { - this(in.readString()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(apiKey); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("api_key", apiKey); - builder.endObject(); - return builder; - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests - } - } -} From adb77110ff6630ebe8d1cc21601579072c72c15a Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Tue, 20 Feb 2024 16:15:31 -0500 Subject: [PATCH 12/40] Added modelsForFields to QueryRewriteContext --- .../index/query/QueryRewriteContext.java | 39 ++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java index e36c4d608d59f..ad0987d399fd7 100644 --- a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java @@ -37,6 +37,7 @@ import java.util.function.BooleanSupplier; import java.util.function.LongSupplier; import java.util.function.Predicate; +import java.util.stream.Collectors; /** * Context object used to rewrite {@link QueryBuilder} instances into simplified version. @@ -59,6 +60,7 @@ public class QueryRewriteContext { protected boolean allowUnmappedFields; protected boolean mapUnmappedFieldAsString; protected Predicate allowedFields; + private final Map> modelsForFields; public QueryRewriteContext( final XContentParserConfiguration parserConfiguration, @@ -74,7 +76,8 @@ public QueryRewriteContext( final NamedWriteableRegistry namedWriteableRegistry, final ValuesSourceRegistry valuesSourceRegistry, final BooleanSupplier allowExpensiveQueries, - final ScriptCompiler scriptService + final ScriptCompiler scriptService, + final Map> modelsForFields ) { this.parserConfiguration = parserConfiguration; @@ -92,6 +95,9 @@ public QueryRewriteContext( this.valuesSourceRegistry = valuesSourceRegistry; this.allowExpensiveQueries = allowExpensiveQueries; this.scriptService = scriptService; + this.modelsForFields = modelsForFields != null ? 
+ modelsForFields.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> Set.copyOf(e.getValue()))) : + Collections.emptyMap(); } public QueryRewriteContext(final XContentParserConfiguration parserConfiguration, final Client client, final LongSupplier nowInMillis) { @@ -109,10 +115,36 @@ public QueryRewriteContext(final XContentParserConfiguration parserConfiguration null, null, null, + null, null ); } + public QueryRewriteContext( + final XContentParserConfiguration parserConfiguration, + final Client client, + final LongSupplier nowInMillis, + final Map> modelsForFields + ) { + this( + parserConfiguration, + client, + nowInMillis, + null, + MappingLookup.EMPTY, + Collections.emptyMap(), + null, + null, + null, + null, + null, + null, + null, + null, + modelsForFields + ); + } + /** * The registry used to build new {@link XContentParser}s. Contains registered named parsers needed to parse the query. * @@ -345,4 +377,9 @@ public Iterable getAllFieldNames() { ? allFromMapping : () -> Iterators.concat(allFromMapping.iterator(), runtimeMappings.keySet().iterator()); } + + public Set getModelsForField(String fieldName) { + Set models = modelsForFields.get(fieldName); + return models != null ? models : Collections.emptySet(); + } } From ebb382701eeb405a31b0a028b35bf107d1f1b741 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Tue, 20 Feb 2024 17:02:14 -0500 Subject: [PATCH 13/40] Updated IndicesService to create the modelsForFields map --- .../query/TransportValidateQueryAction.java | 2 +- .../action/explain/TransportExplainAction.java | 2 +- .../action/search/TransportSearchAction.java | 6 +++++- .../search/TransportSearchShardsAction.java | 2 +- .../org/elasticsearch/index/IndexService.java | 3 ++- .../index/query/CoordinatorRewriteContext.java | 1 + .../index/query/SearchExecutionContext.java | 3 ++- .../elasticsearch/indices/IndicesService.java | 18 ++++++++++++++++-- .../elasticsearch/search/SearchService.java | 5 +++-- .../search/TransportSearchActionTests.java | 3 ++- .../test/AbstractBuilderTestCase.java | 3 ++- 11 files changed, 36 insertions(+), 12 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java index d4832fa0d14e1..64c1faf0401c0 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java @@ -107,7 +107,7 @@ protected void doExecute(Task task, ValidateQueryRequest request, ActionListener if (request.query() == null) { rewriteListener.onResponse(request.query()); } else { - Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider), rewriteListener); + Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider, request), rewriteListener); } } diff --git a/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java b/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java index d2d7a945520c1..6af5ac813cd43 100644 --- a/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java +++ b/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java @@ -84,7 +84,7 @@ protected void doExecute(Task task, ExplainRequest request, ActionListener 
request.nowInMillis; - Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider), rewriteListener); + Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider, request), rewriteListener); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 0922e15999e8c..083b89a5cae04 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -453,7 +453,11 @@ void executeRequest( } } }); - Rewriteable.rewriteAndFetch(original, searchService.getRewriteContext(timeProvider::absoluteStartMillis), rewriteListener); + Rewriteable.rewriteAndFetch( + original, + searchService.getRewriteContext(timeProvider::absoluteStartMillis, original), + rewriteListener + ); } static void adjustSearchType(SearchRequest searchRequest, boolean singleShard) { diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java index 60efb910a5269..068a5caac237a 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java @@ -104,7 +104,7 @@ protected void doExecute(Task task, SearchShardsRequest searchShardsRequest, Act ClusterState clusterState = clusterService.state(); Rewriteable.rewriteAndFetch( original, - searchService.getRewriteContext(timeProvider::absoluteStartMillis), + searchService.getRewriteContext(timeProvider::absoluteStartMillis, original), listener.delegateFailureAndWrap((delegate, searchRequest) -> { Map groupedIndices = remoteClusterService.groupIndices( searchRequest.indicesOptions(), diff --git a/server/src/main/java/org/elasticsearch/index/IndexService.java b/server/src/main/java/org/elasticsearch/index/IndexService.java index 16a5d153a3c19..21d3ea932c28d 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexService.java +++ b/server/src/main/java/org/elasticsearch/index/IndexService.java @@ -725,7 +725,8 @@ public QueryRewriteContext newQueryRewriteContext( namedWriteableRegistry, valuesSourceRegistry, allowExpensiveQueries, - scriptService + scriptService, + null ); } diff --git a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java index 2a1062f8876d2..ac6512b0839e6 100644 --- a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java @@ -51,6 +51,7 @@ public CoordinatorRewriteContext( null, null, null, + null, null ); this.indexLongFieldRange = indexLongFieldRange; diff --git a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java index 86af6d21b7a09..be175dee804b1 100644 --- a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java @@ -265,7 +265,8 @@ private SearchExecutionContext( namedWriteableRegistry, valuesSourceRegistry, allowExpensiveQueries, - scriptService + scriptService, + null ); 
this.shardId = shardId; this.shardRequestIndex = shardRequestIndex; diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesService.java b/server/src/main/java/org/elasticsearch/indices/IndicesService.java index b47d10882a5c1..43e294d9a2658 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesService.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesService.java @@ -18,6 +18,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; import org.elasticsearch.action.admin.indices.mapping.put.TransportAutoPutMappingAction; import org.elasticsearch.action.admin.indices.mapping.put.TransportPutMappingAction; @@ -151,6 +152,7 @@ import java.util.Collection; import java.util.EnumMap; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -1695,8 +1697,20 @@ public AliasFilter buildAliasFilter(ClusterState state, String index, Set> modelsForFields = new HashMap<>(); + for (Index index : indices) { + Map> fieldsForModels = indexService(index).getMetadata().getFieldsForModels(); + for (Map.Entry> entry : fieldsForModels.entrySet()) { + for (String fieldName : entry.getValue()) { + Set models = modelsForFields.computeIfAbsent(fieldName, v -> new HashSet<>()); + models.add(entry.getKey()); + } + } + } + + return new QueryRewriteContext(parserConfig, client, nowInMillis, modelsForFields); } public DataRewriteContext getDataRewriteContext(LongSupplier nowInMillis) { diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 70a002d676235..129022b96c451 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -19,6 +19,7 @@ import org.elasticsearch.ElasticsearchTimeoutException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRunnable; +import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.search.CanMatchNodeRequest; import org.elasticsearch.action.search.CanMatchNodeResponse; import org.elasticsearch.action.search.SearchRequest; @@ -1759,8 +1760,8 @@ private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest re /** * Returns a new {@link QueryRewriteContext} with the given {@code now} provider */ - public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis) { - return indicesService.getRewriteContext(nowInMillis); + public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, IndicesRequest indicesRequest) { + return indicesService.getRewriteContext(nowInMillis, indicesRequest); } public CoordinatorRewriteContextProvider getCoordinatorRewriteContextProvider(LongSupplier nowInMillis) { diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java index 604d404c2f519..4b4f1490179e4 100644 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java @@ -124,6 +124,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static 
org.hamcrest.Matchers.hasSize; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -1717,7 +1718,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { NodeClient client = new NodeClient(settings, threadPool); SearchService searchService = mock(SearchService.class); - when(searchService.getRewriteContext(any())).thenReturn(new QueryRewriteContext(null, null, null)); + when(searchService.getRewriteContext(any(), eq(searchRequest))).thenReturn(new QueryRewriteContext(null, null, null)); ClusterService clusterService = new ClusterService( settings, new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java index 76b836ba7e2a7..1d163b2ee7d33 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java @@ -601,7 +601,8 @@ QueryRewriteContext createQueryRewriteContext() { namedWriteableRegistry, null, () -> true, - scriptService + scriptService, + null ); } From 8a466a98ec399d2be071a07a15a7e4adee096a19 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 21 Feb 2024 13:40:24 -0500 Subject: [PATCH 14/40] Updated SemanticQueryBuilder to implement doRewrite --- .../queries/SemanticQueryBuilder.java | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index c44a7125a6fd7..6819b3f097720 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -8,17 +8,30 @@ package org.elasticsearch.xpack.inference.queries; import org.apache.lucene.search.Query; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import java.io.IOException; +import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.Set; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; public class SemanticQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "semantic_query"; @@ -28,6 +41,8 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder inferenceResultsSupplier; + public SemanticQueryBuilder(String 
fieldName, String query) { if (fieldName == null) { throw new IllegalArgumentException("[" + NAME + "] requires a fieldName"); @@ -45,6 +60,12 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { this.query = in.readString(); } + private SemanticQueryBuilder(SemanticQueryBuilder other, SetOnce inferenceResultsSupplier) { + this.fieldName = other.fieldName; + this.query = other.query; + this.inferenceResultsSupplier = inferenceResultsSupplier; + } + @Override public String getWriteableName() { return NAME; @@ -71,6 +92,46 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep builder.endObject(); } + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { + if (inferenceResultsSupplier != null) { + return this; + } + + Set modelsForField = queryRewriteContext.getModelsForField(fieldName); + if (modelsForField.isEmpty()) { + throw new IllegalArgumentException("field [" + fieldName + "] is not a semantic_text field type"); + } + + if (modelsForField.size() > 1) { + // TODO: Handle multi-index semantic queries + throw new IllegalArgumentException("field [" + fieldName + "] has multiple models associated with it"); + } + + // TODO: How to determine task type? + InferenceAction.Request inferenceRequest = new InferenceAction.Request( + TaskType.SPARSE_EMBEDDING, + modelsForField.iterator().next(), + List.of(query), + Map.of(), + InputType.SEARCH + ); + + SetOnce inferenceResultsSupplier = new SetOnce<>(); + queryRewriteContext.registerAsyncAction((client, listener) -> executeAsyncWithOrigin( + client, + ML_ORIGIN, + InferenceAction.INSTANCE, + inferenceRequest, + listener.delegateFailureAndWrap((l, inferenceResponse) -> { + inferenceResultsSupplier.set(inferenceResponse.getResults()); + l.onResponse(null); + }) + )); + + return new SemanticQueryBuilder(this, inferenceResultsSupplier); + } + @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { // TODO: Implement From 5d27f6a0f6dbb872363dcc93e91c77dd2f46accc Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 21 Feb 2024 15:23:53 -0500 Subject: [PATCH 15/40] Added semanticQuery to SemanticTextFieldMapper --- .../inference/src/main/java/module-info.java | 1 + .../mapper/SemanticTextFieldMapper.java | 32 +++++++++++++++++++ .../queries/SemanticQueryBuilder.java | 3 ++ 3 files changed, 36 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index ddd56c758d67c..09a0adb384c2d 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -18,6 +18,7 @@ requires org.apache.httpcomponents.httpcore.nio; requires org.apache.lucene.core; requires org.elasticsearch.logging; + requires org.apache.lucene.join; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 027b85a9a9f45..69dad8e46a113 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -8,7 +8,10 @@ package org.elasticsearch.xpack.inference.mapper; import 
org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.DocumentParserContext; @@ -20,11 +23,18 @@ import org.elasticsearch.index.mapper.SourceValueFetcher; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; import java.util.Map; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; + /** * A {@link FieldMapper} for semantic text fields. These fields have a model id reference, that is used for performing inference * at ingestion and query time. @@ -126,5 +136,27 @@ public ValueFetcher valueFetcher(SearchExecutionContext context, String format) public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext) { throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating"); } + + public Query semanticQuery( + InferenceResults inferenceResults, + SearchExecutionContext context, + float boost, + String queryName + ) throws IOException { + String fieldName = name() + "." + INFERENCE_CHUNKS_RESULTS; + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery().minimumShouldMatch(1).boost(boost).queryName(queryName); + + // TODO: Support dense vectors + if (inferenceResults instanceof TextExpansionResults textExpansionResults) { + for (TextExpansionResults.WeightedToken weightedToken : textExpansionResults.getWeightedTokens()) { + queryBuilder.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight())); + } + } else { + throw new IllegalArgumentException("Unsupported inference results type [" + inferenceResults.getWriteableName() + "]"); + } + + BitSetProducer parentFilter = context.bitsetFilter(Queries.newNonNestedFilter(context.indexVersionCreated())); + return new ESToParentBlockJoinQuery(queryBuilder.toQuery(context), parentFilter, ScoreMode.Total, name()); + } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 6819b3f097720..3c7ae64ce9fdf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -63,6 +63,8 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { private SemanticQueryBuilder(SemanticQueryBuilder other, SetOnce inferenceResultsSupplier) { this.fieldName = other.fieldName; this.query = other.query; + this.boost = other.boost; + this.queryName = other.queryName; this.inferenceResultsSupplier = inferenceResultsSupplier; } @@ -135,6 +137,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext 
queryRewriteContext) { @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { // TODO: Implement + // TODO: Pass boost to generated query return null; } From 6abf24ce3b38d8ebc7dd110daf244ffc827b3175 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 21 Feb 2024 15:51:46 -0500 Subject: [PATCH 16/40] Updated SemanticQueryBuilder to implement doToQuery --- .../queries/SemanticQueryBuilder.java | 39 +++++++++++++++---- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 3c7ae64ce9fdf..b69861943117b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -13,16 +13,19 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import java.io.IOException; import java.util.List; @@ -102,12 +105,12 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { Set modelsForField = queryRewriteContext.getModelsForField(fieldName); if (modelsForField.isEmpty()) { - throw new IllegalArgumentException("field [" + fieldName + "] is not a semantic_text field type"); + throw new IllegalArgumentException("Field [" + fieldName + "] is not a semantic_text field type"); } if (modelsForField.size() > 1) { // TODO: Handle multi-index semantic queries - throw new IllegalArgumentException("field [" + fieldName + "] has multiple models associated with it"); + throw new IllegalArgumentException("Field [" + fieldName + "] has multiple models associated with it"); } // TODO: How to determine task type? @@ -136,18 +139,40 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { - // TODO: Implement - // TODO: Pass boost to generated query - return null; + InferenceServiceResults inferenceServiceResults = inferenceResultsSupplier.get(); + if (inferenceServiceResults == null) { + throw new IllegalArgumentException("Inference results supplier for field [" + fieldName + "] is empty"); + } + + List inferenceResultsList = inferenceServiceResults.transformToCoordinationFormat(); + if (inferenceResultsList.isEmpty()) { + throw new IllegalArgumentException("No inference results retrieved for field [" + fieldName + "]"); + } else if (inferenceResultsList.size() > 1) { + // TODO: How to handle multiple inference results? 
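// For reference, a self-contained sketch of the token-weighted query shape that the
// SemanticTextFieldType#semanticQuery method from the previous patch builds out of
// sparse-embedding (text expansion) output. Only the Lucene calls mirror the mapper code;
// the class name, field name and weights below are made up for illustration.
import java.util.Map;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;

class WeightedTokenQuerySketch {
    static Query build(String field, Map<String, Float> weightedTokens) {
        // At least one predicted token must match for a document to be a hit.
        BooleanQuery.Builder builder = new BooleanQuery.Builder().setMinimumNumberShouldMatch(1);
        for (Map.Entry<String, Float> token : weightedTokens.entrySet()) {
            // Each token becomes a SHOULD clause whose score contribution is scaled by its weight.
            builder.add(
                new BoostQuery(new TermQuery(new Term(field, token.getKey())), token.getValue()),
                BooleanClause.Occur.SHOULD
            );
        }
        return builder.build();
    }
}
// e.g. build("my_semantic_field", Map.of("paris", 2.1f, "france", 1.3f))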
+ throw new IllegalArgumentException(inferenceResultsList.size() + " inference results retrieved for field [" + fieldName + "]"); + } + + InferenceResults inferenceResults = inferenceResultsList.get(0); + MappedFieldType fieldType = context.getFieldType(fieldName); + if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType == false) { + // TODO: Better exception type to throw here? + throw new IllegalArgumentException( + "Field [" + fieldName + "] is not registered as a " + SemanticTextFieldMapper.CONTENT_TYPE + " field type" + ); + } + + return ((SemanticTextFieldMapper.SemanticTextFieldType) fieldType).semanticQuery(inferenceResults, context, boost, queryName); } @Override protected boolean doEquals(SemanticQueryBuilder other) { - return Objects.equals(fieldName, other.fieldName) && Objects.equals(query, other.query); + return Objects.equals(fieldName, other.fieldName) + && Objects.equals(query, other.query) + && Objects.equals(inferenceResultsSupplier, other.inferenceResultsSupplier); } @Override protected int doHashCode() { - return Objects.hash(fieldName, query); + return Objects.hash(fieldName, query, inferenceResultsSupplier); } } From b434da5c075fe1a516262b1f370507de433df06b Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 21 Feb 2024 16:16:36 -0500 Subject: [PATCH 17/40] Updated SemanticQueryBuilder to add fromXContent --- .../queries/SemanticQueryBuilder.java | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index b69861943117b..96c274fef526c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.mapper.MappedFieldType; @@ -24,6 +25,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; @@ -175,4 +177,56 @@ protected boolean doEquals(SemanticQueryBuilder other) { protected int doHashCode() { return Objects.hash(fieldName, query, inferenceResultsSupplier); } + + public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IOException { + String fieldName = null; + String query = null; + float boost = AbstractQueryBuilder.DEFAULT_BOOST; + String queryName = null; + + String currentFieldName = null; + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token == XContentParser.Token.START_OBJECT) { + throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName); + fieldName = currentFieldName; + for (token = 
parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token.isValue()) { + if (QUERY_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + query = parser.text(); + } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + boost = parser.floatValue(); + } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + queryName = parser.text(); + } else { + throw new ParsingException( + parser.getTokenLocation(), + "[" + NAME + "] query does not support [" + currentFieldName + "]" + ); + } + } else { + throw new ParsingException( + parser.getTokenLocation(), + "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" + ); + } + } + } + } + + if (fieldName == null) { + throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] no field name specified"); + } + if (query == null) { + throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] no query specified"); + } + + SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(fieldName, query); + queryBuilder.queryName(queryName); + queryBuilder.boost(boost); + return queryBuilder; + } } From ba852b8e97db9797c5fb199ed1402d8fa18f5555 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 21 Feb 2024 16:25:30 -0500 Subject: [PATCH 18/40] Added SemanticQueryBuilder to inference plugin --- .../xpack/inference/InferencePlugin.java | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 821a804596cff..4893af14096fd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -23,6 +23,7 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MetadataFieldMapper; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; @@ -33,6 +34,7 @@ import org.elasticsearch.plugins.InferenceRegistryPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; @@ -58,6 +60,7 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; @@ -86,7 +89,8 @@ public class InferencePlugin extends Plugin ExtensiblePlugin, SystemIndexPlugin, InferenceRegistryPlugin, - MapperPlugin { + MapperPlugin, + 
SearchPlugin { /** * When this setting is true the verification check that @@ -302,4 +306,13 @@ public Map getMappers() { public Map getMetadataMappers() { return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); } + + @Override + public List> getQueries() { + return List.of(new QuerySpec( + SemanticQueryBuilder.NAME, + SemanticQueryBuilder::new, + SemanticQueryBuilder::fromXContent + )); + } } From b2c4574a654df2ebed3cc04c7c0c24fecb3875ca Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 21 Feb 2024 17:47:37 -0500 Subject: [PATCH 19/40] Use Lucene queries to build semantic query --- .../mapper/SemanticTextFieldMapper.java | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 69dad8e46a113..6d991db7f22ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -7,7 +7,12 @@ package org.elasticsearch.xpack.inference.mapper; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.common.Strings; @@ -23,8 +28,6 @@ import org.elasticsearch.index.mapper.SourceValueFetcher; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.query.BoolQueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; import org.elasticsearch.inference.InferenceResults; @@ -142,21 +145,32 @@ public Query semanticQuery( SearchExecutionContext context, float boost, String queryName - ) throws IOException { + ) { + // Cant use QueryBuilders.boolQuery() because a mapper is not registered for .inference, causing + // TermQueryBuilder#doToQuery to fail (at TermQueryBuilder:202) + // TODO: Handle boost and queryName String fieldName = name() + "." 
+ INFERENCE_CHUNKS_RESULTS; - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery().minimumShouldMatch(1).boost(boost).queryName(queryName); + BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder().setMinimumNumberShouldMatch(1); // TODO: Support dense vectors if (inferenceResults instanceof TextExpansionResults textExpansionResults) { for (TextExpansionResults.WeightedToken weightedToken : textExpansionResults.getWeightedTokens()) { - queryBuilder.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight())); + queryBuilder.add( + new BoostQuery( + new TermQuery( + new Term(fieldName, weightedToken.token()) + ), + weightedToken.weight() + ), + BooleanClause.Occur.SHOULD + ); } } else { throw new IllegalArgumentException("Unsupported inference results type [" + inferenceResults.getWriteableName() + "]"); } BitSetProducer parentFilter = context.bitsetFilter(Queries.newNonNestedFilter(context.indexVersionCreated())); - return new ESToParentBlockJoinQuery(queryBuilder.toQuery(context), parentFilter, ScoreMode.Total, name()); + return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, name()); } } } From 1361f0865ab89eeae63333b9919096eaaac85a50 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 11 Mar 2024 19:24:21 +0100 Subject: [PATCH 20/40] Spotless --- .../elasticsearch/xpack/inference/InferencePlugin.java | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 4893af14096fd..7daa53148bbd7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -309,10 +309,8 @@ public Map getMetadataMappers() { @Override public List> getQueries() { - return List.of(new QuerySpec( - SemanticQueryBuilder.NAME, - SemanticQueryBuilder::new, - SemanticQueryBuilder::fromXContent - )); + return List.of( + new QuerySpec(SemanticQueryBuilder.NAME, SemanticQueryBuilder::new, SemanticQueryBuilder::fromXContent) + ); } } From ddf374c420d3ee9e81b4c5f68c233ad2a3872fc7 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Mar 2024 17:23:25 +0100 Subject: [PATCH 21/40] New class for dealing with field inference metadata --- .../metadata/FieldInferenceMetadata.java | 235 ++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java new file mode 100644 index 0000000000000..f3043da3e1c7f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java @@ -0,0 +1,235 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.cluster.metadata; + +import org.elasticsearch.cluster.Diff; +import org.elasticsearch.cluster.Diffable; +import org.elasticsearch.cluster.DiffableUtils; +import org.elasticsearch.common.collect.ImmutableOpenMap; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.mapper.MappingLookup; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Contains field inference information. This is necessary to add to cluster state as inference can be calculated in the coordinator + * node, which not necessarily has mapping information. + */ +public class FieldInferenceMetadata implements Diffable, ToXContentFragment { + + // Keys: field names. Values: Inference ID associated to the field name for calculating inference + private final ImmutableOpenMap inferenceIdForField; + + // Keys: field names. Values: Field names that provide source for this field (either as copy_to or multifield sources) + private final ImmutableOpenMap> sourceFields; + + // Keys: inference IDs. Values: Field names that use the inference id for calculating inference. Reverse of inferenceForFields. + private Map> fieldsForInferenceIds; + + public static final FieldInferenceMetadata EMPTY = new FieldInferenceMetadata(ImmutableOpenMap.of(), ImmutableOpenMap.of()); + + public static final ParseField INFERENCE_FOR_FIELDS_FIELD = new ParseField("inference_for_fields"); + public static final ParseField COPY_FROM_FIELDS_FIELD = new ParseField("copy_from_fields"); + + public FieldInferenceMetadata( + ImmutableOpenMap inferenceIdsForFields, + ImmutableOpenMap> sourceFields + ) { + this.inferenceIdForField = Objects.requireNonNull(inferenceIdsForFields); + this.sourceFields = Objects.requireNonNull(sourceFields); + } + + public FieldInferenceMetadata( + Map inferenceIdsForFields, + Map> sourceFields + ) { + this.inferenceIdForField = ImmutableOpenMap.builder(Objects.requireNonNull(inferenceIdsForFields)).build(); + this.sourceFields = ImmutableOpenMap.builder(Objects.requireNonNull(sourceFields)).build(); + } + + public FieldInferenceMetadata(MappingLookup mappingLookup) { + this.inferenceIdForField = ImmutableOpenMap.builder(mappingLookup.getInferenceIdsForFields()).build(); + ImmutableOpenMap.Builder> sourcePathsBuilder = ImmutableOpenMap.builder(inferenceIdForField.size()); + inferenceIdForField.keySet().forEach(fieldName -> sourcePathsBuilder.put(fieldName, mappingLookup.sourcePaths(fieldName))); + this.sourceFields = sourcePathsBuilder.build(); + } + + public FieldInferenceMetadata(StreamInput in) throws IOException { + inferenceIdForField = in.readImmutableOpenMap(StreamInput::readString, StreamInput::readString); + sourceFields = in.readImmutableOpenMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(inferenceIdForField, StreamOutput::writeString); + out.writeMap(sourceFields, StreamOutput::writeStringCollection); + } + + @Override + public 
XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(INFERENCE_FOR_FIELDS_FIELD.getPreferredName(), inferenceIdForField); + builder.field(COPY_FROM_FIELDS_FIELD.getPreferredName(), sourceFields); + + return builder; + } + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "field_inference_metadata_parser", + false, + (args, unused) -> new FieldInferenceMetadata((Map) args[0], (Map>) args[1]) + ); + + static { + PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), INFERENCE_FOR_FIELDS_FIELD); + PARSER.declareObject( + ConstructingObjectParser.constructorArg(), + (p, c) -> p.map( + HashMap::new, + v -> v.list().stream().map(Object::toString).collect(Collectors.toSet()) + ), + COPY_FROM_FIELDS_FIELD + ); + } + + public static FieldInferenceMetadata fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + @Override + public Diff diff(FieldInferenceMetadata previousState) { + return new FieldInferenceMetadataDiff(previousState, this); + } + + public String getInferenceIdForField(String field) { + return inferenceIdForField.get(field); + } + + public Map getInferenceIdForFields() { + return inferenceIdForField; + } + + public Set getSourceFields(String field) { + return sourceFields.get(field); + } + + public Map> getFieldsForInferenceIds() { + if (fieldsForInferenceIds != null) { + return fieldsForInferenceIds; + } + + // Cache the result as a field + Map> fieldsForInferenceIdsMap = new HashMap<>(); + for (Map.Entry entry : inferenceIdForField.entrySet()) { + String inferenceId = entry.getValue(); + String fieldName = entry.getKey(); + + // Get or create the set associated with the inferenceId + Set fields = fieldsForInferenceIdsMap.computeIfAbsent(inferenceId, k -> new HashSet<>()); + fields.add(fieldName); + } + + fieldsForInferenceIds = Collections.unmodifiableMap(fieldsForInferenceIdsMap); + return fieldsForInferenceIds; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FieldInferenceMetadata that = (FieldInferenceMetadata) o; + return Objects.equals(inferenceIdForField, that.inferenceIdForField) && Objects.equals(sourceFields, that.sourceFields); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceIdForField, sourceFields); + } + + public static class FieldInferenceMetadataDiff implements Diff { + + private final Diff> inferenceForFields; + private final Diff>> copyFromFields; + + private static final DiffableUtils.NonDiffableValueSerializer STRING_VALUE_SERIALIZER = + new DiffableUtils.NonDiffableValueSerializer<>() { + @Override + public void write(String value, StreamOutput out) throws IOException { + out.writeString(value); + } + + @Override + public String read(StreamInput in, String key) throws IOException { + return in.readString(); + } + }; + + FieldInferenceMetadataDiff(FieldInferenceMetadata before, FieldInferenceMetadata after) { + inferenceForFields = DiffableUtils.diff( + before.inferenceIdForField, + after.inferenceIdForField, + DiffableUtils.getStringKeySerializer(), + STRING_VALUE_SERIALIZER); + copyFromFields = DiffableUtils.diff( + before.sourceFields, + after.sourceFields, + DiffableUtils.getStringKeySerializer(), + DiffableUtils.StringSetValueSerializer.getInstance() + ); + } + + FieldInferenceMetadataDiff(StreamInput in) throws IOException { + 
inferenceForFields = DiffableUtils.readImmutableOpenMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + STRING_VALUE_SERIALIZER + ); + copyFromFields = DiffableUtils.readImmutableOpenMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + DiffableUtils.StringSetValueSerializer.getInstance() + ); + } + + public static final FieldInferenceMetadataDiff EMPTY = new FieldInferenceMetadataDiff( + FieldInferenceMetadata.EMPTY, + FieldInferenceMetadata.EMPTY + ) { + @Override + public FieldInferenceMetadata apply(FieldInferenceMetadata part) { + return part; + } + }; + @Override + public FieldInferenceMetadata apply(FieldInferenceMetadata part) { + ImmutableOpenMap modelForFields = this.inferenceForFields.apply(part.inferenceIdForField); + ImmutableOpenMap> copyFromFields = this.copyFromFields.apply(part.sourceFields); + return new FieldInferenceMetadata(modelForFields, copyFromFields); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + inferenceForFields.writeTo(out); + copyFromFields.writeTo(out); + } + } +} From 354bb09857237177a212203e2e4f0fa5a5e0b847 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Mar 2024 18:17:13 +0100 Subject: [PATCH 22/40] Include FieldInferenceMetadata into IndexMetadata --- .../cluster/metadata/IndexMetadata.java | 106 +++++++----------- 1 file changed, 42 insertions(+), 64 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 81406f0a74ce5..42b60afa07e35 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -78,7 +78,6 @@ import java.util.OptionalLong; import java.util.Set; import java.util.function.Function; -import java.util.stream.Collectors; import static org.elasticsearch.cluster.metadata.Metadata.CONTEXT_MODE_PARAM; import static org.elasticsearch.cluster.metadata.Metadata.DEDUPLICATED_MAPPINGS_PARAM; @@ -541,7 +540,7 @@ public Iterator> settings() { public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast"; - public static final String KEY_FIELDS_FOR_MODELS = "fields_for_models"; + public static final String KEY_FIELD_INFERENCE_METADATA = "field_inference_metadata"; public static final String INDEX_STATE_FILE_PREFIX = "state-"; @@ -632,8 +631,7 @@ public Iterator> settings() { private final Double writeLoadForecast; @Nullable private final Long shardSizeInBytesForecast; - // Key: model ID, Value: Fields that use model - private final ImmutableOpenMap> fieldsForModels; + private final FieldInferenceMetadata fieldInferenceMetadata; private IndexMetadata( final Index index, @@ -680,7 +678,7 @@ private IndexMetadata( @Nullable final IndexMetadataStats stats, @Nullable final Double writeLoadForecast, @Nullable Long shardSizeInBytesForecast, - final ImmutableOpenMap> fieldsForModels + @Nullable FieldInferenceMetadata fieldInferenceMetadata ) { this.index = index; this.version = version; @@ -736,7 +734,7 @@ private IndexMetadata( this.writeLoadForecast = writeLoadForecast; this.shardSizeInBytesForecast = shardSizeInBytesForecast; assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards; - this.fieldsForModels = Objects.requireNonNull(fieldsForModels); + this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); } IndexMetadata 
withMappingMetadata(MappingMetadata mapping) { @@ -788,7 +786,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -847,7 +845,7 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -904,7 +902,7 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -961,7 +959,7 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -1014,7 +1012,7 @@ public IndexMetadata withIncrementedVersion() { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -1218,8 +1216,8 @@ public OptionalLong getForecastedShardSizeInBytes() { return shardSizeInBytesForecast == null ? OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast); } - public Map> getFieldsForModels() { - return fieldsForModels; + public FieldInferenceMetadata getFieldInferenceMetadata() { + return fieldInferenceMetadata; } public static final String INDEX_RESIZE_SOURCE_UUID_KEY = "index.resize.source.uuid"; @@ -1419,7 +1417,7 @@ public boolean equals(Object o) { if (rolloverInfos.equals(that.rolloverInfos) == false) { return false; } - if (fieldsForModels.equals(that.fieldsForModels) == false) { + if (fieldInferenceMetadata.equals(that.fieldInferenceMetadata) == false) { return false; } if (isSystem != that.isSystem) { @@ -1442,7 +1440,7 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(primaryTerms); result = 31 * result + inSyncAllocationIds.hashCode(); result = 31 * result + rolloverInfos.hashCode(); - result = 31 * result + fieldsForModels.hashCode(); + result = 31 * result + fieldInferenceMetadata.hashCode(); result = 31 * result + Boolean.hashCode(isSystem); return result; } @@ -1498,7 +1496,7 @@ private static class IndexMetadataDiff implements Diff { private final IndexMetadataStats stats; private final Double indexWriteLoadForecast; private final Long shardSizeInBytesForecast; - private final Diff>> fieldsForModels; + private final FieldInferenceMetadata.FieldInferenceMetadataDiff fieldInferenceMetadata; IndexMetadataDiff(IndexMetadata before, IndexMetadata after) { index = after.index.getName(); @@ -1535,11 +1533,9 @@ private static class IndexMetadataDiff implements Diff { stats = after.stats; indexWriteLoadForecast = after.writeLoadForecast; shardSizeInBytesForecast = after.shardSizeInBytesForecast; - fieldsForModels = DiffableUtils.diff( - before.fieldsForModels, - after.fieldsForModels, - DiffableUtils.getStringKeySerializer(), - DiffableUtils.StringSetValueSerializer.getInstance() + fieldInferenceMetadata = new FieldInferenceMetadata.FieldInferenceMetadataDiff( + before.fieldInferenceMetadata, + after.fieldInferenceMetadata ); } @@ -1601,13 +1597,9 @@ private static class IndexMetadataDiff implements Diff { shardSizeInBytesForecast = null; } if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldsForModels = DiffableUtils.readJdkMapDiff( - in, - DiffableUtils.getStringKeySerializer(), - 
DiffableUtils.StringSetValueSerializer.getInstance() - ); + fieldInferenceMetadata = new FieldInferenceMetadata.FieldInferenceMetadataDiff(in); } else { - fieldsForModels = DiffableUtils.emptyDiff(); + fieldInferenceMetadata = FieldInferenceMetadata.FieldInferenceMetadataDiff.EMPTY; } } @@ -1645,7 +1637,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalLong(shardSizeInBytesForecast); } if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldsForModels.writeTo(out); + fieldInferenceMetadata.writeTo(out); } } @@ -1676,7 +1668,7 @@ public IndexMetadata apply(IndexMetadata part) { builder.stats(stats); builder.indexWriteLoadForecast(indexWriteLoadForecast); builder.shardSizeInBytesForecast(shardSizeInBytesForecast); - builder.fieldsForModels(fieldsForModels.apply(part.fieldsForModels)); + builder.fieldInferenceMetadata(fieldInferenceMetadata.apply(part.fieldInferenceMetadata)); return builder.build(true); } } @@ -1745,9 +1737,9 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function i.readCollectionAsImmutableSet(StreamInput::readString)) - ); + if (in.readBoolean()) { + builder.fieldInferenceMetadata(new FieldInferenceMetadata(in)); + } } return builder.build(true); } @@ -1796,7 +1788,12 @@ public void writeTo(StreamOutput out, boolean mappingsAsHash) throws IOException out.writeOptionalLong(shardSizeInBytesForecast); } if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); + if (fieldInferenceMetadata == FieldInferenceMetadata.EMPTY) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + fieldInferenceMetadata.writeTo(out); + } } } @@ -1847,7 +1844,7 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; - private final ImmutableOpenMap.Builder> fieldsForModels; + private FieldInferenceMetadata fieldInferenceMetadata; public Builder(String index) { this.index = index; @@ -1855,7 +1852,7 @@ public Builder(String index) { this.customMetadata = ImmutableOpenMap.builder(); this.inSyncAllocationIds = new HashMap<>(); this.rolloverInfos = ImmutableOpenMap.builder(); - this.fieldsForModels = ImmutableOpenMap.builder(); + this.fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; this.isSystem = false; } @@ -1880,7 +1877,7 @@ public Builder(IndexMetadata indexMetadata) { this.stats = indexMetadata.stats; this.indexWriteLoadForecast = indexMetadata.writeLoadForecast; this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast; - this.fieldsForModels = ImmutableOpenMap.builder(indexMetadata.fieldsForModels); + this.fieldInferenceMetadata = indexMetadata.fieldInferenceMetadata; } public Builder index(String index) { @@ -2110,8 +2107,8 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { return this; } - public Builder fieldsForModels(Map> fieldsForModels) { - processFieldsForModels(this.fieldsForModels, fieldsForModels); + public Builder fieldInferenceMetadata(FieldInferenceMetadata fieldInferenceMetadata) { + this.fieldInferenceMetadata = fieldInferenceMetadata; return this; } @@ -2310,7 +2307,7 @@ IndexMetadata build(boolean repair) { stats, indexWriteLoadForecast, shardSizeInBytesForecast, - fieldsForModels.build() + fieldInferenceMetadata ); } @@ -2436,8 +2433,10 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build 
builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } - if (indexMetadata.fieldsForModels.isEmpty() == false) { - builder.field(KEY_FIELDS_FOR_MODELS, indexMetadata.fieldsForModels); + if (indexMetadata.fieldInferenceMetadata != FieldInferenceMetadata.EMPTY) { + builder.startObject(KEY_FIELD_INFERENCE_METADATA); + indexMetadata.fieldInferenceMetadata.toXContent(builder, params); + builder.endObject(); } builder.endObject(); @@ -2517,18 +2516,8 @@ public static IndexMetadata fromXContent(XContentParser parser, Map> fieldsForModels = parser.map(HashMap::new, XContentParser::list) - .entrySet() - .stream() - .collect( - Collectors.toMap( - Map.Entry::getKey, - v -> v.getValue().stream().map(Object::toString).collect(Collectors.toUnmodifiableSet()) - ) - ); - builder.fieldsForModels(fieldsForModels); + case KEY_FIELD_INFERENCE_METADATA: + builder.fieldInferenceMetadata(FieldInferenceMetadata.fromXContent(parser)); break; default: // assume it's custom index metadata @@ -2726,17 +2715,6 @@ private static void handleLegacyMapping(Builder builder, Map map builder.putMapping(new MappingMetadata(MapperService.SINGLE_MAPPING_NAME, mapping)); } } - - private static void processFieldsForModels( - ImmutableOpenMap.Builder> builder, - Map> fieldsForModels - ) { - builder.clear(); - if (fieldsForModels != null) { - // Ensure that all field sets contained in the processed map are immutable - fieldsForModels.forEach((k, v) -> builder.put(k, Set.copyOf(v))); - } - } } /** From c88b9cd49b4b6ef45424e2958347a0a31d1595ec Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Mar 2024 18:18:24 +0100 Subject: [PATCH 23/40] Create new FieldInferenceMetadata structure --- .../cluster/metadata/MetadataCreateIndexService.java | 4 ++-- .../cluster/metadata/MetadataMappingService.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java index d8fe0b0c19e52..96ca7a15edc30 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java @@ -1267,8 +1267,8 @@ static IndexMetadata buildIndexMetadata( if (mapper != null) { MappingMetadata mappingMd = new MappingMetadata(mapper); mappingsMetadata.put(mapper.type(), mappingMd); - - indexMetadataBuilder.fieldsForModels(mapper.mappers().getFieldsForModels()); + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(mapper.mappers()); + indexMetadataBuilder.fieldInferenceMetadata(fieldInferenceMetadata); } for (MappingMetadata mappingMd : mappingsMetadata.values()) { diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java index d913a6465482d..0e31592991369 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java @@ -204,7 +204,7 @@ private static ClusterState applyRequest( DocumentMapper mapper = mapperService.documentMapper(); if (mapper != null) { indexMetadataBuilder.putMapping(new MappingMetadata(mapper)); - indexMetadataBuilder.fieldsForModels(mapper.mappers().getFieldsForModels()); + indexMetadataBuilder.fieldInferenceMetadata(new 
FieldInferenceMetadata(mapper.mappers())); } if (updatedMapping) { indexMetadataBuilder.mappingVersion(1 + indexMetadataBuilder.mappingVersion()); From db7f53115a45e609c8445086b706f815f041b43f Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Mar 2024 18:19:56 +0100 Subject: [PATCH 24/40] Use FieldInferenceMetadata structure in lookups, some renaming --- .../index/mapper/FieldTypeLookup.java | 17 ++++++----------- .../index/mapper/InferenceModelFieldType.java | 2 +- .../index/mapper/MappingLookup.java | 4 ++-- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java index 564e6f903a2ae..774d25e726149 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java @@ -39,7 +39,7 @@ final class FieldTypeLookup { /** * A map from inference model ID to all fields that use the model to generate embeddings. */ - private final Map> fieldsForModels; + private final Map inferenceIdsForFields; private final int maxParentPathDots; @@ -53,7 +53,7 @@ final class FieldTypeLookup { final Map fullSubfieldNameToParentPath = new HashMap<>(); final Map dynamicFieldTypes = new HashMap<>(); final Map> fieldToCopiedFields = new HashMap<>(); - final Map> fieldsForModels = new HashMap<>(); + final Map inferenceIdsForFields = new HashMap<>(); for (FieldMapper fieldMapper : fieldMappers) { String fieldName = fieldMapper.name(); MappedFieldType fieldType = fieldMapper.fieldType(); @@ -72,11 +72,7 @@ final class FieldTypeLookup { fieldToCopiedFields.get(targetField).add(fieldName); } if (fieldType instanceof InferenceModelFieldType inferenceModelFieldType) { - String inferenceModel = inferenceModelFieldType.getInferenceModel(); - if (inferenceModel != null) { - Set fields = fieldsForModels.computeIfAbsent(inferenceModel, v -> new HashSet<>()); - fields.add(fieldName); - } + inferenceIdsForFields.put(fieldName, inferenceModelFieldType.getInferenceId()); } } @@ -110,8 +106,7 @@ final class FieldTypeLookup { // make values into more compact immutable sets to save memory fieldToCopiedFields.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); this.fieldToCopiedFields = Map.copyOf(fieldToCopiedFields); - fieldsForModels.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); - this.fieldsForModels = Map.copyOf(fieldsForModels); + this.inferenceIdsForFields = Map.copyOf(inferenceIdsForFields); } public static int dotCount(String path) { @@ -220,8 +215,8 @@ Set sourcePaths(String field) { return fieldToCopiedFields.containsKey(resolvedField) ? 
fieldToCopiedFields.get(resolvedField) : Set.of(resolvedField); } - Map> getFieldsForModels() { - return fieldsForModels; + Map getInferenceIdsForFields() { + return inferenceIdsForFields; } /** diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java index 490d7f36219cf..6e12a204ed7d0 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java @@ -17,5 +17,5 @@ public interface InferenceModelFieldType { * * @return model id used by the field type */ - String getInferenceModel(); + String getInferenceId(); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index 96eb0211a4a0c..66b2f2f171af6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -517,7 +517,7 @@ public void validateDoesNotShadow(String name) { } } - public Map> getFieldsForModels() { - return fieldTypeLookup.getFieldsForModels(); + public Map getInferenceIdsForFields() { + return fieldTypeLookup.getInferenceIdsForFields(); } } From 32293d0ef857c51182913934951cfb46d0d32126 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Mar 2024 18:21:18 +0100 Subject: [PATCH 25/40] Use FieldInferenceMetadata structure in dependencies --- .../elasticsearch/action/bulk/BulkOperationTests.java | 2 +- .../cluster/metadata/IndexMetadataTests.java | 4 ++-- .../index/mapper/FieldTypeLookupTests.java | 6 +++--- .../elasticsearch/index/mapper/MappingLookupTests.java | 6 +++--- .../index/mapper/MockInferenceModelFieldType.java | 2 +- .../src/main/java/org/elasticsearch/node/MockNode.java | 10 ++++++++-- .../org/elasticsearch/search/MockSearchService.java | 6 +++++- .../xpack/inference/queries/SemanticQueryBuilder.java | 2 +- 8 files changed, 24 insertions(+), 14 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index 2ce7b161d3dd1..38860a6ce91cd 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -472,7 +472,7 @@ private static BulkShardRequest runBulkOperation( ) { Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME) - .fieldsForModels(fieldsForModels) + .fieldInferenceMetadata(fieldsForModels) .settings(settings) .numberOfShards(1) .numberOfReplicas(0) diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index 58b8adcf53538..dbe3c25044f82 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -108,7 +108,7 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) - .fieldsForModels(fieldsForModels) + .fieldInferenceMetadata(fieldsForModels) .build(); 
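// Aside: FieldInferenceMetadata, added earlier in this series, carries two maps keyed by field
// name: field -> inference id used to generate its embeddings, and field -> source fields that
// feed it (copy_to / multi-field sources). A hypothetical instance (values made up):
//   inferenceIdForField: { "title_semantic" -> "my-elser-endpoint" }
//   sourceFields:        { "title_semantic" -> ["title"] }
// The reverse mapping (inference id -> fields) is derived lazily by getFieldsForInferenceIds().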
assertEquals(system, metadata.isSystem()); @@ -556,7 +556,7 @@ public void testFieldsForModels() { assertThat(idxMeta1.getFieldsForModels(), equalTo(Map.of())); Map> fieldsForModels = randomFieldsForModels(false); - IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldsForModels(fieldsForModels).build(); + IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldInferenceMetadata(fieldsForModels).build(); assertThat(idxMeta2.getFieldsForModels(), equalTo(fieldsForModels)); } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java index 27663edde945c..9cd33e36b1d71 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java @@ -37,7 +37,7 @@ public void testEmpty() { assertNotNull(names); assertThat(names, hasSize(0)); - Map> fieldsForModels = lookup.getFieldsForModels(); + Map> fieldsForModels = lookup.getInferenceIdsForFields(); assertNotNull(fieldsForModels); assertTrue(fieldsForModels.isEmpty()); } @@ -48,7 +48,7 @@ public void testAddNewField() { assertNull(lookup.get("bar")); assertEquals(f.fieldType(), lookup.get("foo")); - Map> fieldsForModels = lookup.getFieldsForModels(); + Map> fieldsForModels = lookup.getInferenceIdsForFields(); assertNotNull(fieldsForModels); assertTrue(fieldsForModels.isEmpty()); } @@ -440,7 +440,7 @@ public void testInferenceModelFieldType() { assertEquals(f2.fieldType(), lookup.get("foo2")); assertEquals(f3.fieldType(), lookup.get("foo3")); - Map> fieldsForModels = lookup.getFieldsForModels(); + Map> fieldsForModels = lookup.getInferenceIdsForFields(); assertNotNull(fieldsForModels); assertEquals(2, fieldsForModels.size()); assertEquals(Set.of("foo1", "foo2"), fieldsForModels.get("bar1")); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java index f512f5d352a43..028bf5c864d11 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java @@ -122,8 +122,8 @@ public void testEmptyMappingLookup() { assertEquals(0, mappingLookup.getMapping().getMetadataMappersMap().size()); assertFalse(mappingLookup.fieldMappers().iterator().hasNext()); assertEquals(0, mappingLookup.getMatchingFieldNames("*").size()); - assertNotNull(mappingLookup.getFieldsForModels()); - assertTrue(mappingLookup.getFieldsForModels().isEmpty()); + assertNotNull(mappingLookup.getInferenceIdsForFields()); + assertTrue(mappingLookup.getInferenceIdsForFields().isEmpty()); } public void testValidateDoesNotShadow() { @@ -201,7 +201,7 @@ public void testFieldsForModels() { assertEquals(1, size(mappingLookup.fieldMappers())); assertEquals(fieldType, mappingLookup.getFieldType("test_field_name")); - Map> fieldsForModels = mappingLookup.getFieldsForModels(); + Map> fieldsForModels = mappingLookup.getInferenceIdsForFields(); assertNotNull(fieldsForModels); assertEquals(1, fieldsForModels.size()); assertEquals(Collections.singleton("test_field_name"), fieldsForModels.get("test_model_id")); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java index 854749d6308db..0d21134b5d9a9 100644 --- 
a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java @@ -39,7 +39,7 @@ public ValueFetcher valueFetcher(SearchExecutionContext context, String format) } @Override - public String getInferenceModel() { + public String getInferenceId() { return modelId; } } diff --git a/test/framework/src/main/java/org/elasticsearch/node/MockNode.java b/test/framework/src/main/java/org/elasticsearch/node/MockNode.java index ef29f9fca4f93..4d4e35c332afd 100644 --- a/test/framework/src/main/java/org/elasticsearch/node/MockNode.java +++ b/test/framework/src/main/java/org/elasticsearch/node/MockNode.java @@ -28,6 +28,8 @@ import org.elasticsearch.indices.IndicesService; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.recovery.RecoverySettings; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.plugins.MockPluginsService; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.PluginsService; @@ -101,7 +103,9 @@ SearchService newSearchService( ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, ExecutorSelector executorSelector, - Tracer tracer + Tracer tracer, + ModelRegistry modelRegistry, + InferenceServiceRegistry inferenceServiceRegistry ) { if (pluginsService.filterPlugins(MockSearchService.TestPlugin.class).findAny().isEmpty()) { return super.newSearchService( @@ -115,7 +119,9 @@ SearchService newSearchService( responseCollectorService, circuitBreakerService, executorSelector, - tracer + tracer, + modelRegistry, + inferenceServiceRegistry ); } return new MockSearchService( diff --git a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java index aa1889e15d594..3252986501367 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java +++ b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java @@ -15,6 +15,8 @@ import org.elasticsearch.indices.ExecutorSelector; import org.elasticsearch.indices.IndicesService; import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.node.MockNode; import org.elasticsearch.node.ResponseCollectorService; import org.elasticsearch.plugins.Plugin; @@ -97,7 +99,9 @@ public MockSearchService( responseCollectorService, circuitBreakerService, executorSelector, - tracer + tracer, + null, + null ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 96c274fef526c..51c449067f18f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -105,7 +105,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { return this; } - Set modelsForField = queryRewriteContext.getModelsForField(fieldName); + Set modelsForField = queryRewriteContext.getInferenceIdsForField(fieldName); if 
(modelsForField.isEmpty()) { throw new IllegalArgumentException("Field [" + fieldName + "] is not a semantic_text field type"); } From 949a8d029463cbbc176976264be2d8dc4837747b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Mar 2024 18:23:47 +0100 Subject: [PATCH 26/40] Use FieldInferenceMetadata structure in dependencies --- .../bulk/BulkShardRequestInferenceProvider.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 4b7a67e9ca0e3..7fe760f402e0f 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -77,8 +77,8 @@ public static void getInstance( ) { Set inferenceIds = new HashSet<>(); shardIds.stream().map(ShardId::getIndex).collect(Collectors.toSet()).stream().forEach(index -> { - var fieldsForModels = clusterState.metadata().index(index).getFieldsForModels(); - inferenceIds.addAll(fieldsForModels.keySet()); + var fieldsForInferenceIds = clusterState.metadata().index(index).getFieldInferenceMetadata().getFieldsForInferenceIds(); + inferenceIds.addAll(fieldsForInferenceIds.keySet()); }); final Map inferenceProviderMap = new ConcurrentHashMap<>(); Runnable onModelLoadingComplete = () -> listener.onResponse( @@ -134,11 +134,11 @@ public void processBulkShardRequest( BiConsumer onBulkItemFailure ) { - Map> fieldsForModels = clusterState.metadata() + Map> fieldsForInferenceIds = clusterState.metadata() .index(bulkShardRequest.shardId().getIndex()) - .getFieldsForModels(); + .getFieldInferenceMetadata().getFieldsForInferenceIds(); // No inference fields? 
Terminate early - if (fieldsForModels.isEmpty()) { + if (fieldsForInferenceIds.isEmpty()) { listener.onResponse(bulkShardRequest); return; } @@ -176,7 +176,7 @@ public void processBulkShardRequest( if (bulkItemRequest != null) { performInferenceOnBulkItemRequest( bulkItemRequest, - fieldsForModels, + fieldsForInferenceIds, i, onBulkItemFailureWithIndex, bulkItemReqRef.acquire() From 21bf90b8338a09d093eab30863f7c0c90a75e75a Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Mar 2024 18:28:35 +0100 Subject: [PATCH 27/40] Renaming fields --- .../cluster/metadata/FieldInferenceMetadata.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java index f3043da3e1c7f..5be3c33fca4e4 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java @@ -48,7 +48,7 @@ public class FieldInferenceMetadata implements Diffable, public static final FieldInferenceMetadata EMPTY = new FieldInferenceMetadata(ImmutableOpenMap.of(), ImmutableOpenMap.of()); public static final ParseField INFERENCE_FOR_FIELDS_FIELD = new ParseField("inference_for_fields"); - public static final ParseField COPY_FROM_FIELDS_FIELD = new ParseField("copy_from_fields"); + public static final ParseField SOURCE_FIELDS_FIELD = new ParseField("source_fields"); public FieldInferenceMetadata( ImmutableOpenMap inferenceIdsForFields, @@ -87,7 +87,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field(INFERENCE_FOR_FIELDS_FIELD.getPreferredName(), inferenceIdForField); - builder.field(COPY_FROM_FIELDS_FIELD.getPreferredName(), sourceFields); + builder.field(SOURCE_FIELDS_FIELD.getPreferredName(), sourceFields); return builder; } @@ -107,7 +107,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws HashMap::new, v -> v.list().stream().map(Object::toString).collect(Collectors.toSet()) ), - COPY_FROM_FIELDS_FIELD + SOURCE_FIELDS_FIELD ); } From a21e9e4b9a08f9f9f7ac16cb49a8d5cf6e09fa0d Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Mar 2024 19:50:32 +0100 Subject: [PATCH 28/40] Fix rebasing --- .../cluster/ClusterStateDiffIT.java | 2 +- .../query/TransportValidateQueryAction.java | 2 +- .../explain/TransportExplainAction.java | 2 +- .../action/search/TransportSearchAction.java | 2 +- .../search/TransportSearchShardsAction.java | 2 +- .../elasticsearch/indices/IndicesService.java | 18 +- .../elasticsearch/search/SearchService.java | 4 +- .../java/org/elasticsearch/node/MockNode.java | 10 +- .../search/MockSearchService.java | 6 +- .../xpack/inference/InferencePlugin.java | 13 +- .../mapper/SemanticTextFieldMapper.java | 48 +--- .../queries/SemanticQueryBuilder.java | 232 ------------------ 12 files changed, 14 insertions(+), 327 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index 3a1f6e20bb288..404cfe3dc560b 100644 --- 
a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -587,7 +587,7 @@ public IndexMetadata randomChange(IndexMetadata part) { builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)); break; case 3: - builder.fieldsForModels(randomFieldsForModels()); + builder.fieldInferenceMetadata(randomFieldsForModels()); break; default: throw new IllegalArgumentException("Shouldn't be here"); diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java index 64c1faf0401c0..d4832fa0d14e1 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java @@ -107,7 +107,7 @@ protected void doExecute(Task task, ValidateQueryRequest request, ActionListener if (request.query() == null) { rewriteListener.onResponse(request.query()); } else { - Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider, request), rewriteListener); + Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider), rewriteListener); } } diff --git a/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java b/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java index 6af5ac813cd43..d2d7a945520c1 100644 --- a/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java +++ b/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java @@ -84,7 +84,7 @@ protected void doExecute(Task task, ExplainRequest request, ActionListener request.nowInMillis; - Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider, request), rewriteListener); + Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider), rewriteListener); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 083b89a5cae04..d1bd5bd38d1b5 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -455,7 +455,7 @@ void executeRequest( }); Rewriteable.rewriteAndFetch( original, - searchService.getRewriteContext(timeProvider::absoluteStartMillis, original), + searchService.getRewriteContext(timeProvider::absoluteStartMillis), rewriteListener ); } diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java index 068a5caac237a..60efb910a5269 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java @@ -104,7 +104,7 @@ protected void doExecute(Task task, SearchShardsRequest searchShardsRequest, Act ClusterState clusterState = clusterService.state(); Rewriteable.rewriteAndFetch( original, - 
searchService.getRewriteContext(timeProvider::absoluteStartMillis, original), + searchService.getRewriteContext(timeProvider::absoluteStartMillis), listener.delegateFailureAndWrap((delegate, searchRequest) -> { Map groupedIndices = remoteClusterService.groupIndices( searchRequest.indicesOptions(), diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesService.java b/server/src/main/java/org/elasticsearch/indices/IndicesService.java index 43e294d9a2658..b47d10882a5c1 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesService.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesService.java @@ -18,7 +18,6 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; import org.elasticsearch.action.admin.indices.mapping.put.TransportAutoPutMappingAction; import org.elasticsearch.action.admin.indices.mapping.put.TransportPutMappingAction; @@ -152,7 +151,6 @@ import java.util.Collection; import java.util.EnumMap; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -1697,20 +1695,8 @@ public AliasFilter buildAliasFilter(ClusterState state, String index, Set> modelsForFields = new HashMap<>(); - for (Index index : indices) { - Map> fieldsForModels = indexService(index).getMetadata().getFieldsForModels(); - for (Map.Entry> entry : fieldsForModels.entrySet()) { - for (String fieldName : entry.getValue()) { - Set models = modelsForFields.computeIfAbsent(fieldName, v -> new HashSet<>()); - models.add(entry.getKey()); - } - } - } - - return new QueryRewriteContext(parserConfig, client, nowInMillis, modelsForFields); + public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis) { + return new QueryRewriteContext(parserConfig, client, nowInMillis); } public DataRewriteContext getDataRewriteContext(LongSupplier nowInMillis) { diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 129022b96c451..fa4b12f56dd18 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -1760,8 +1760,8 @@ private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest re /** * Returns a new {@link QueryRewriteContext} with the given {@code now} provider */ - public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, IndicesRequest indicesRequest) { - return indicesService.getRewriteContext(nowInMillis, indicesRequest); + public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis) { + return indicesService.getRewriteContext(nowInMillis); } public CoordinatorRewriteContextProvider getCoordinatorRewriteContextProvider(LongSupplier nowInMillis) { diff --git a/test/framework/src/main/java/org/elasticsearch/node/MockNode.java b/test/framework/src/main/java/org/elasticsearch/node/MockNode.java index 4d4e35c332afd..ef29f9fca4f93 100644 --- a/test/framework/src/main/java/org/elasticsearch/node/MockNode.java +++ b/test/framework/src/main/java/org/elasticsearch/node/MockNode.java @@ -28,8 +28,6 @@ import org.elasticsearch.indices.IndicesService; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.recovery.RecoverySettings; 
-import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.plugins.MockPluginsService; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.PluginsService; @@ -103,9 +101,7 @@ SearchService newSearchService( ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, ExecutorSelector executorSelector, - Tracer tracer, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry + Tracer tracer ) { if (pluginsService.filterPlugins(MockSearchService.TestPlugin.class).findAny().isEmpty()) { return super.newSearchService( @@ -119,9 +115,7 @@ SearchService newSearchService( responseCollectorService, circuitBreakerService, executorSelector, - tracer, - modelRegistry, - inferenceServiceRegistry + tracer ); } return new MockSearchService( diff --git a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java index 3252986501367..aa1889e15d594 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java +++ b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java @@ -15,8 +15,6 @@ import org.elasticsearch.indices.ExecutorSelector; import org.elasticsearch.indices.IndicesService; import org.elasticsearch.indices.breaker.CircuitBreakerService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.node.MockNode; import org.elasticsearch.node.ResponseCollectorService; import org.elasticsearch.plugins.Plugin; @@ -99,9 +97,7 @@ public MockSearchService( responseCollectorService, circuitBreakerService, executorSelector, - tracer, - null, - null + tracer ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 7daa53148bbd7..821a804596cff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -23,7 +23,6 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MetadataFieldMapper; -import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; @@ -34,7 +33,6 @@ import org.elasticsearch.plugins.InferenceRegistryPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; @@ -60,7 +58,6 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; -import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import 
org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; @@ -89,8 +86,7 @@ public class InferencePlugin extends Plugin ExtensiblePlugin, SystemIndexPlugin, InferenceRegistryPlugin, - MapperPlugin, - SearchPlugin { + MapperPlugin { /** * When this setting is true the verification check that @@ -306,11 +302,4 @@ public Map getMappers() { public Map getMetadataMappers() { return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); } - - @Override - public List> getQueries() { - return List.of( - new QuerySpec(SemanticQueryBuilder.NAME, SemanticQueryBuilder::new, SemanticQueryBuilder::fromXContent) - ); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 6d991db7f22ff..d9e18728615ba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -7,16 +7,8 @@ package org.elasticsearch.xpack.inference.mapper; -import org.apache.lucene.index.Term; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.search.TermQuery; -import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.DocumentParserContext; @@ -29,15 +21,10 @@ import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.index.search.ESToParentBlockJoinQuery; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; import java.util.Map; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; - /** * A {@link FieldMapper} for semantic text fields. These fields have a model id reference, that is used for performing inference * at ingestion and query time. @@ -121,7 +108,7 @@ public String typeName() { } @Override - public String getInferenceModel() { + public String getInferenceId() { return modelId; } @@ -139,38 +126,5 @@ public ValueFetcher valueFetcher(SearchExecutionContext context, String format) public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext) { throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating"); } - - public Query semanticQuery( - InferenceResults inferenceResults, - SearchExecutionContext context, - float boost, - String queryName - ) { - // Cant use QueryBuilders.boolQuery() because a mapper is not registered for .inference, causing - // TermQueryBuilder#doToQuery to fail (at TermQueryBuilder:202) - // TODO: Handle boost and queryName - String fieldName = name() + "." 
+ INFERENCE_CHUNKS_RESULTS; - BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder().setMinimumNumberShouldMatch(1); - - // TODO: Support dense vectors - if (inferenceResults instanceof TextExpansionResults textExpansionResults) { - for (TextExpansionResults.WeightedToken weightedToken : textExpansionResults.getWeightedTokens()) { - queryBuilder.add( - new BoostQuery( - new TermQuery( - new Term(fieldName, weightedToken.token()) - ), - weightedToken.weight() - ), - BooleanClause.Occur.SHOULD - ); - } - } else { - throw new IllegalArgumentException("Unsupported inference results type [" + inferenceResults.getWriteableName() + "]"); - } - - BitSetProducer parentFilter = context.bitsetFilter(Queries.newNonNestedFilter(context.indexVersionCreated())); - return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, name()); - } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java deleted file mode 100644 index 51c449067f18f..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.queries; - -import org.apache.lucene.search.Query; -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.ParsingException; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.query.AbstractQueryBuilder; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryRewriteContext; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; - -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; - -import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; -import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; - -public class SemanticQueryBuilder extends AbstractQueryBuilder { - public static final String NAME = "semantic_query"; - - private static final ParseField QUERY_FIELD = new ParseField("query"); - - private final String fieldName; - private final String query; - - private SetOnce inferenceResultsSupplier; - - public SemanticQueryBuilder(String fieldName, String query) { - if (fieldName == null) { - throw new IllegalArgumentException("[" + NAME + "] requires a fieldName"); - } - if (query == null) { - throw new 
IllegalArgumentException("[" + NAME + "] requires a " + QUERY_FIELD.getPreferredName() + " value"); - } - this.fieldName = fieldName; - this.query = query; - } - - public SemanticQueryBuilder(StreamInput in) throws IOException { - super(in); - this.fieldName = in.readString(); - this.query = in.readString(); - } - - private SemanticQueryBuilder(SemanticQueryBuilder other, SetOnce inferenceResultsSupplier) { - this.fieldName = other.fieldName; - this.query = other.query; - this.boost = other.boost; - this.queryName = other.queryName; - this.inferenceResultsSupplier = inferenceResultsSupplier; - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.SEMANTIC_TEXT_FIELD_ADDED; - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(fieldName); - out.writeString(query); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(NAME); - builder.startObject(fieldName); - builder.field(QUERY_FIELD.getPreferredName(), query); - boostAndQueryNameToXContent(builder); - builder.endObject(); - builder.endObject(); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { - if (inferenceResultsSupplier != null) { - return this; - } - - Set modelsForField = queryRewriteContext.getInferenceIdsForField(fieldName); - if (modelsForField.isEmpty()) { - throw new IllegalArgumentException("Field [" + fieldName + "] is not a semantic_text field type"); - } - - if (modelsForField.size() > 1) { - // TODO: Handle multi-index semantic queries - throw new IllegalArgumentException("Field [" + fieldName + "] has multiple models associated with it"); - } - - // TODO: How to determine task type? - InferenceAction.Request inferenceRequest = new InferenceAction.Request( - TaskType.SPARSE_EMBEDDING, - modelsForField.iterator().next(), - List.of(query), - Map.of(), - InputType.SEARCH - ); - - SetOnce inferenceResultsSupplier = new SetOnce<>(); - queryRewriteContext.registerAsyncAction((client, listener) -> executeAsyncWithOrigin( - client, - ML_ORIGIN, - InferenceAction.INSTANCE, - inferenceRequest, - listener.delegateFailureAndWrap((l, inferenceResponse) -> { - inferenceResultsSupplier.set(inferenceResponse.getResults()); - l.onResponse(null); - }) - )); - - return new SemanticQueryBuilder(this, inferenceResultsSupplier); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - InferenceServiceResults inferenceServiceResults = inferenceResultsSupplier.get(); - if (inferenceServiceResults == null) { - throw new IllegalArgumentException("Inference results supplier for field [" + fieldName + "] is empty"); - } - - List inferenceResultsList = inferenceServiceResults.transformToCoordinationFormat(); - if (inferenceResultsList.isEmpty()) { - throw new IllegalArgumentException("No inference results retrieved for field [" + fieldName + "]"); - } else if (inferenceResultsList.size() > 1) { - // TODO: How to handle multiple inference results? 
- throw new IllegalArgumentException(inferenceResultsList.size() + " inference results retrieved for field [" + fieldName + "]"); - } - - InferenceResults inferenceResults = inferenceResultsList.get(0); - MappedFieldType fieldType = context.getFieldType(fieldName); - if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType == false) { - // TODO: Better exception type to throw here? - throw new IllegalArgumentException( - "Field [" + fieldName + "] is not registered as a " + SemanticTextFieldMapper.CONTENT_TYPE + " field type" - ); - } - - return ((SemanticTextFieldMapper.SemanticTextFieldType) fieldType).semanticQuery(inferenceResults, context, boost, queryName); - } - - @Override - protected boolean doEquals(SemanticQueryBuilder other) { - return Objects.equals(fieldName, other.fieldName) - && Objects.equals(query, other.query) - && Objects.equals(inferenceResultsSupplier, other.inferenceResultsSupplier); - } - - @Override - protected int doHashCode() { - return Objects.hash(fieldName, query, inferenceResultsSupplier); - } - - public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IOException { - String fieldName = null; - String query = null; - float boost = AbstractQueryBuilder.DEFAULT_BOOST; - String queryName = null; - - String currentFieldName = null; - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - if (token == XContentParser.Token.FIELD_NAME) { - currentFieldName = parser.currentName(); - } else if (token == XContentParser.Token.START_OBJECT) { - throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName); - fieldName = currentFieldName; - for (token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - if (token == XContentParser.Token.FIELD_NAME) { - currentFieldName = parser.currentName(); - } else if (token.isValue()) { - if (QUERY_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - query = parser.text(); - } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - boost = parser.floatValue(); - } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - queryName = parser.text(); - } else { - throw new ParsingException( - parser.getTokenLocation(), - "[" + NAME + "] query does not support [" + currentFieldName + "]" - ); - } - } else { - throw new ParsingException( - parser.getTokenLocation(), - "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" - ); - } - } - } - } - - if (fieldName == null) { - throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] no field name specified"); - } - if (query == null) { - throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] no query specified"); - } - - SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(fieldName, query); - queryBuilder.queryName(queryName); - queryBuilder.boost(boost); - return queryBuilder; - } -} From d14cc536e135f65afd4658148f9fe73714c01c70 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Mar 2024 20:01:48 +0100 Subject: [PATCH 29/40] Fix rebasing --- .../org/elasticsearch/index/IndexService.java | 3 +- .../query/CoordinatorRewriteContext.java | 1 - .../index/query/QueryRewriteContext.java | 39 +------- .../index/query/SearchExecutionContext.java | 3 +- .../inference/ModelSettings.java | 96 ------------------- 
.../elasticsearch/search/SearchService.java | 1 - .../search/TransportSearchActionTests.java | 3 +- .../test/AbstractBuilderTestCase.java | 3 +- 8 files changed, 5 insertions(+), 144 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/inference/ModelSettings.java diff --git a/server/src/main/java/org/elasticsearch/index/IndexService.java b/server/src/main/java/org/elasticsearch/index/IndexService.java index 21d3ea932c28d..16a5d153a3c19 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexService.java +++ b/server/src/main/java/org/elasticsearch/index/IndexService.java @@ -725,8 +725,7 @@ public QueryRewriteContext newQueryRewriteContext( namedWriteableRegistry, valuesSourceRegistry, allowExpensiveQueries, - scriptService, - null + scriptService ); } diff --git a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java index ac6512b0839e6..2a1062f8876d2 100644 --- a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java @@ -51,7 +51,6 @@ public CoordinatorRewriteContext( null, null, null, - null, null ); this.indexLongFieldRange = indexLongFieldRange; diff --git a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java index ad0987d399fd7..e36c4d608d59f 100644 --- a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java @@ -37,7 +37,6 @@ import java.util.function.BooleanSupplier; import java.util.function.LongSupplier; import java.util.function.Predicate; -import java.util.stream.Collectors; /** * Context object used to rewrite {@link QueryBuilder} instances into simplified version. @@ -60,7 +59,6 @@ public class QueryRewriteContext { protected boolean allowUnmappedFields; protected boolean mapUnmappedFieldAsString; protected Predicate allowedFields; - private final Map> modelsForFields; public QueryRewriteContext( final XContentParserConfiguration parserConfiguration, @@ -76,8 +74,7 @@ public QueryRewriteContext( final NamedWriteableRegistry namedWriteableRegistry, final ValuesSourceRegistry valuesSourceRegistry, final BooleanSupplier allowExpensiveQueries, - final ScriptCompiler scriptService, - final Map> modelsForFields + final ScriptCompiler scriptService ) { this.parserConfiguration = parserConfiguration; @@ -95,9 +92,6 @@ public QueryRewriteContext( this.valuesSourceRegistry = valuesSourceRegistry; this.allowExpensiveQueries = allowExpensiveQueries; this.scriptService = scriptService; - this.modelsForFields = modelsForFields != null ? 
- modelsForFields.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> Set.copyOf(e.getValue()))) : - Collections.emptyMap(); } public QueryRewriteContext(final XContentParserConfiguration parserConfiguration, final Client client, final LongSupplier nowInMillis) { @@ -115,36 +109,10 @@ public QueryRewriteContext(final XContentParserConfiguration parserConfiguration null, null, null, - null, null ); } - public QueryRewriteContext( - final XContentParserConfiguration parserConfiguration, - final Client client, - final LongSupplier nowInMillis, - final Map> modelsForFields - ) { - this( - parserConfiguration, - client, - nowInMillis, - null, - MappingLookup.EMPTY, - Collections.emptyMap(), - null, - null, - null, - null, - null, - null, - null, - null, - modelsForFields - ); - } - /** * The registry used to build new {@link XContentParser}s. Contains registered named parsers needed to parse the query. * @@ -377,9 +345,4 @@ public Iterable getAllFieldNames() { ? allFromMapping : () -> Iterators.concat(allFromMapping.iterator(), runtimeMappings.keySet().iterator()); } - - public Set getModelsForField(String fieldName) { - Set models = modelsForFields.get(fieldName); - return models != null ? models : Collections.emptySet(); - } } diff --git a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java index be175dee804b1..86af6d21b7a09 100644 --- a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java @@ -265,8 +265,7 @@ private SearchExecutionContext( namedWriteableRegistry, valuesSourceRegistry, allowExpensiveQueries, - scriptService, - null + scriptService ); this.shardId = shardId; this.shardRequestIndex = shardRequestIndex; diff --git a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java deleted file mode 100644 index 10466114873bd..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. 
- */ - -package org.elasticsearch.inference; - -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.XContentParser; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; - -public class ModelSettings { - - public static final String NAME = "model_settings"; - public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); - public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); - public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); - public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); - private final TaskType taskType; - private final String inferenceId; - private final Integer dimensions; - private final SimilarityMeasure similarity; - - public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { - Objects.requireNonNull(taskType, "task type must not be null"); - Objects.requireNonNull(inferenceId, "inferenceId must not be null"); - this.taskType = taskType; - this.inferenceId = inferenceId; - this.dimensions = dimensions; - this.similarity = similarity; - } - - public ModelSettings(Model model) { - this( - model.getTaskType(), - model.getInferenceEntityId(), - model.getServiceSettings().dimensions(), - model.getServiceSettings().similarity() - ); - } - - public static ModelSettings parse(XContentParser parser) throws IOException { - return PARSER.apply(parser, null); - } - - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { - TaskType taskType = TaskType.fromString((String) args[0]); - String inferenceId = (String) args[1]; - Integer dimensions = (Integer) args[2]; - SimilarityMeasure similarity = args[3] == null ? 
null : SimilarityMeasure.fromString((String) args[3]); - return new ModelSettings(taskType, inferenceId, dimensions, similarity); - }); - static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); - PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); - PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); - } - - public Map asMap() { - Map attrsMap = new HashMap<>(); - attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); - attrsMap.put(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); - if (dimensions != null) { - attrsMap.put(DIMENSIONS_FIELD.getPreferredName(), dimensions); - } - if (similarity != null) { - attrsMap.put(SIMILARITY_FIELD.getPreferredName(), similarity); - } - return Map.of(NAME, attrsMap); - } - - public TaskType taskType() { - return taskType; - } - - public String inferenceId() { - return inferenceId; - } - - public Integer dimensions() { - return dimensions; - } - - public SimilarityMeasure similarity() { - return similarity; - } -} diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index fa4b12f56dd18..70a002d676235 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -19,7 +19,6 @@ import org.elasticsearch.ElasticsearchTimeoutException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRunnable; -import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.search.CanMatchNodeRequest; import org.elasticsearch.action.search.CanMatchNodeResponse; import org.elasticsearch.action.search.SearchRequest; diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java index 4b4f1490179e4..604d404c2f519 100644 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java @@ -124,7 +124,6 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.hasSize; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -1718,7 +1717,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { NodeClient client = new NodeClient(settings, threadPool); SearchService searchService = mock(SearchService.class); - when(searchService.getRewriteContext(any(), eq(searchRequest))).thenReturn(new QueryRewriteContext(null, null, null)); + when(searchService.getRewriteContext(any())).thenReturn(new QueryRewriteContext(null, null, null)); ClusterService clusterService = new ClusterService( settings, new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java index 1d163b2ee7d33..76b836ba7e2a7 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java +++ 
b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java @@ -601,8 +601,7 @@ QueryRewriteContext createQueryRewriteContext() { namedWriteableRegistry, null, () -> true, - scriptService, - null + scriptService ); } From 053080aa15da784f986da3ec3d542601d885e4d9 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Mar 2024 20:04:02 +0100 Subject: [PATCH 30/40] Fix rebasing --- x-pack/plugin/inference/src/main/java/module-info.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 09a0adb384c2d..ddd56c758d67c 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -18,7 +18,6 @@ requires org.apache.httpcomponents.httpcore.nio; requires org.apache.lucene.core; requires org.elasticsearch.logging; - requires org.apache.lucene.join; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; From bb76d53845a9a31dc1387e4bd9b6e5415dd3303b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 14 Mar 2024 17:05:03 +0100 Subject: [PATCH 31/40] Rework FieldInferenceMetadata to have a single map instead of multiple maps --- .../metadata/FieldInferenceMetadata.java | 250 ++++++++---------- 1 file changed, 113 insertions(+), 137 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java index 5be3c33fca4e4..0d02f0aa43db1 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java @@ -11,6 +11,7 @@ import org.elasticsearch.cluster.Diff; import org.elasticsearch.cluster.Diffable; import org.elasticsearch.cluster.DiffableUtils; +import org.elasticsearch.cluster.SimpleDiffable; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -28,7 +29,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.stream.Collectors; +import java.util.function.Function; /** * Contains field inference information. This is necessary to add to cluster state as inference can be calculated in the coordinator @@ -36,100 +37,143 @@ */ public class FieldInferenceMetadata implements Diffable, ToXContentFragment { - // Keys: field names. Values: Inference ID associated to the field name for calculating inference - private final ImmutableOpenMap inferenceIdForField; - - // Keys: field names. Values: Field names that provide source for this field (either as copy_to or multifield sources) - private final ImmutableOpenMap> sourceFields; - - // Keys: inference IDs. Values: Field names that use the inference id for calculating inference. Reverse of inferenceForFields. 
+ private final ImmutableOpenMap fieldInferenceMap; private Map> fieldsForInferenceIds; - public static final FieldInferenceMetadata EMPTY = new FieldInferenceMetadata(ImmutableOpenMap.of(), ImmutableOpenMap.of()); + public static final FieldInferenceMetadata EMPTY = new FieldInferenceMetadata(ImmutableOpenMap.of()); - public static final ParseField INFERENCE_FOR_FIELDS_FIELD = new ParseField("inference_for_fields"); - public static final ParseField SOURCE_FIELDS_FIELD = new ParseField("source_fields"); + public FieldInferenceMetadata(MappingLookup mappingLookup) { + ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(); + mappingLookup.getInferenceIdsForFields().entrySet().forEach(entry -> { + builder.put(entry.getKey(), new FieldInference(entry.getValue(), mappingLookup.sourcePaths(entry.getKey()))); + }); + fieldInferenceMap = builder.build(); + } - public FieldInferenceMetadata( - ImmutableOpenMap inferenceIdsForFields, - ImmutableOpenMap> sourceFields - ) { - this.inferenceIdForField = Objects.requireNonNull(inferenceIdsForFields); - this.sourceFields = Objects.requireNonNull(sourceFields); + public FieldInferenceMetadata(StreamInput in) throws IOException { + fieldInferenceMap = in.readImmutableOpenMap(StreamInput::readString, FieldInference::new); } - public FieldInferenceMetadata( - Map inferenceIdsForFields, - Map> sourceFields - ) { - this.inferenceIdForField = ImmutableOpenMap.builder(Objects.requireNonNull(inferenceIdsForFields)).build(); - this.sourceFields = ImmutableOpenMap.builder(Objects.requireNonNull(sourceFields)).build(); + public FieldInferenceMetadata(Map fieldsToInferenceMap) { + fieldInferenceMap = ImmutableOpenMap.builder(fieldsToInferenceMap).build(); } - public FieldInferenceMetadata(MappingLookup mappingLookup) { - this.inferenceIdForField = ImmutableOpenMap.builder(mappingLookup.getInferenceIdsForFields()).build(); - ImmutableOpenMap.Builder> sourcePathsBuilder = ImmutableOpenMap.builder(inferenceIdForField.size()); - inferenceIdForField.keySet().forEach(fieldName -> sourcePathsBuilder.put(fieldName, mappingLookup.sourcePaths(fieldName))); - this.sourceFields = sourcePathsBuilder.build(); + public record FieldInference(String inferenceId, Set sourceFields) + implements + SimpleDiffable, + ToXContentFragment { + + public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); + public static final ParseField SOURCE_FIELDS_FIELD = new ParseField("source_fields"); + + FieldInference(StreamInput in) throws IOException { + this(in.readString(), in.readCollectionAsImmutableSet(StreamInput::readString)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(inferenceId); + out.writeStringCollection(sourceFields); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); + builder.field(SOURCE_FIELDS_FIELD.getPreferredName(), sourceFields); + builder.endObject(); + return builder; + } + + public static FieldInference fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "field_inference_parser", + false, + (args, unused) -> new FieldInference((String) args[0], (Set) args[1]) + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); + 
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), SOURCE_FIELDS_FIELD); + } } - public FieldInferenceMetadata(StreamInput in) throws IOException { - inferenceIdForField = in.readImmutableOpenMap(StreamInput::readString, StreamInput::readString); - sourceFields = in.readImmutableOpenMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString)); + @Override + public Diff diff(FieldInferenceMetadata previousState) { + return new FieldInferenceMetadataDiff(previousState, this); + } + + static class FieldInferenceMetadataDiff implements Diff { + private final Diff> fieldInferenceMapDiff; + + private static final DiffableUtils.DiffableValueReader FIELD_INFERENCE_DIFF_VALUE_READER = + new DiffableUtils.DiffableValueReader<>(FieldInference::new, FieldInferenceMetadataDiff::readDiffFrom); + + FieldInferenceMetadataDiff(FieldInferenceMetadata before, FieldInferenceMetadata after) { + fieldInferenceMapDiff = DiffableUtils.diff( + before.fieldInferenceMap, + after.fieldInferenceMap, + DiffableUtils.getStringKeySerializer(), + FIELD_INFERENCE_DIFF_VALUE_READER + ); + } + + FieldInferenceMetadataDiff(StreamInput in) throws IOException { + fieldInferenceMapDiff = DiffableUtils.readImmutableOpenMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + FIELD_INFERENCE_DIFF_VALUE_READER + ); + } + + public static Diff readDiffFrom(StreamInput in) throws IOException { + return SimpleDiffable.readDiffFrom(FieldInference::new, in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + fieldInferenceMapDiff.writeTo(out); + } + + @Override + public FieldInferenceMetadata apply(FieldInferenceMetadata part) { + return new FieldInferenceMetadata(fieldInferenceMapDiff.apply(part.fieldInferenceMap)); + } } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeMap(inferenceIdForField, StreamOutput::writeString); - out.writeMap(sourceFields, StreamOutput::writeStringCollection); + out.writeMap(fieldInferenceMap, (o, v) -> v.writeTo(o)); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(INFERENCE_FOR_FIELDS_FIELD.getPreferredName(), inferenceIdForField); - builder.field(SOURCE_FIELDS_FIELD.getPreferredName(), sourceFields); - + builder.map(fieldInferenceMap); return builder; } - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "field_inference_metadata_parser", - false, - (args, unused) -> new FieldInferenceMetadata((Map) args[0], (Map>) args[1]) - ); - - static { - PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), INFERENCE_FOR_FIELDS_FIELD); - PARSER.declareObject( - ConstructingObjectParser.constructorArg(), - (p, c) -> p.map( - HashMap::new, - v -> v.list().stream().map(Object::toString).collect(Collectors.toSet()) - ), - SOURCE_FIELDS_FIELD - ); - } - public static FieldInferenceMetadata fromXContent(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); - } - - @Override - public Diff diff(FieldInferenceMetadata previousState) { - return new FieldInferenceMetadataDiff(previousState, this); + return new FieldInferenceMetadata(parser.map(HashMap::new, FieldInference::fromXContent)); } public String getInferenceIdForField(String field) { - return inferenceIdForField.get(field); + return getInferenceSafe(field, FieldInference::inferenceId); } - public Map getInferenceIdForFields() { - return 
inferenceIdForField; + private T getInferenceSafe(String field, Function fieldInferenceFunction) { + FieldInference fieldInference = fieldInferenceMap.get(field); + if (fieldInference == null) { + return null; + } + return fieldInferenceFunction.apply(fieldInference); } public Set getSourceFields(String field) { - return sourceFields.get(field); + return getInferenceSafe(field, FieldInference::sourceFields); } public Map> getFieldsForInferenceIds() { @@ -139,9 +183,9 @@ public Map> getFieldsForInferenceIds() { // Cache the result as a field Map> fieldsForInferenceIdsMap = new HashMap<>(); - for (Map.Entry entry : inferenceIdForField.entrySet()) { - String inferenceId = entry.getValue(); + for (Map.Entry entry : fieldInferenceMap.entrySet()) { String fieldName = entry.getKey(); + String inferenceId = entry.getValue().inferenceId(); // Get or create the set associated with the inferenceId Set fields = fieldsForInferenceIdsMap.computeIfAbsent(inferenceId, k -> new HashSet<>()); @@ -157,79 +201,11 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; FieldInferenceMetadata that = (FieldInferenceMetadata) o; - return Objects.equals(inferenceIdForField, that.inferenceIdForField) && Objects.equals(sourceFields, that.sourceFields); + return Objects.equals(fieldInferenceMap, that.fieldInferenceMap); } @Override public int hashCode() { - return Objects.hash(inferenceIdForField, sourceFields); - } - - public static class FieldInferenceMetadataDiff implements Diff { - - private final Diff> inferenceForFields; - private final Diff>> copyFromFields; - - private static final DiffableUtils.NonDiffableValueSerializer STRING_VALUE_SERIALIZER = - new DiffableUtils.NonDiffableValueSerializer<>() { - @Override - public void write(String value, StreamOutput out) throws IOException { - out.writeString(value); - } - - @Override - public String read(StreamInput in, String key) throws IOException { - return in.readString(); - } - }; - - FieldInferenceMetadataDiff(FieldInferenceMetadata before, FieldInferenceMetadata after) { - inferenceForFields = DiffableUtils.diff( - before.inferenceIdForField, - after.inferenceIdForField, - DiffableUtils.getStringKeySerializer(), - STRING_VALUE_SERIALIZER); - copyFromFields = DiffableUtils.diff( - before.sourceFields, - after.sourceFields, - DiffableUtils.getStringKeySerializer(), - DiffableUtils.StringSetValueSerializer.getInstance() - ); - } - - FieldInferenceMetadataDiff(StreamInput in) throws IOException { - inferenceForFields = DiffableUtils.readImmutableOpenMapDiff( - in, - DiffableUtils.getStringKeySerializer(), - STRING_VALUE_SERIALIZER - ); - copyFromFields = DiffableUtils.readImmutableOpenMapDiff( - in, - DiffableUtils.getStringKeySerializer(), - DiffableUtils.StringSetValueSerializer.getInstance() - ); - } - - public static final FieldInferenceMetadataDiff EMPTY = new FieldInferenceMetadataDiff( - FieldInferenceMetadata.EMPTY, - FieldInferenceMetadata.EMPTY - ) { - @Override - public FieldInferenceMetadata apply(FieldInferenceMetadata part) { - return part; - } - }; - @Override - public FieldInferenceMetadata apply(FieldInferenceMetadata part) { - ImmutableOpenMap modelForFields = this.inferenceForFields.apply(part.inferenceIdForField); - ImmutableOpenMap> copyFromFields = this.copyFromFields.apply(part.sourceFields); - return new FieldInferenceMetadata(modelForFields, copyFromFields); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - inferenceForFields.writeTo(out); - 
copyFromFields.writeTo(out); - } + return Objects.hash(fieldInferenceMap); } } From 83ccfd2d75ac18d4bde835d1f32f72045b595ed5 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 14 Mar 2024 17:18:11 +0100 Subject: [PATCH 32/40] Serialization fixes --- .../cluster/metadata/IndexMetadata.java | 37 ++++++------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 42b60afa07e35..1060addd008d1 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -734,7 +734,7 @@ private IndexMetadata( this.writeLoadForecast = writeLoadForecast; this.shardSizeInBytesForecast = shardSizeInBytesForecast; assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards; - this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); + this.fieldInferenceMetadata = fieldInferenceMetadata; } IndexMetadata withMappingMetadata(MappingMetadata mapping) { @@ -1496,7 +1496,7 @@ private static class IndexMetadataDiff implements Diff { private final IndexMetadataStats stats; private final Double indexWriteLoadForecast; private final Long shardSizeInBytesForecast; - private final FieldInferenceMetadata.FieldInferenceMetadataDiff fieldInferenceMetadata; + private final Diff fieldInferenceMetadata; IndexMetadataDiff(IndexMetadata before, IndexMetadata after) { index = after.index.getName(); @@ -1533,10 +1533,7 @@ private static class IndexMetadataDiff implements Diff { stats = after.stats; indexWriteLoadForecast = after.writeLoadForecast; shardSizeInBytesForecast = after.shardSizeInBytesForecast; - fieldInferenceMetadata = new FieldInferenceMetadata.FieldInferenceMetadataDiff( - before.fieldInferenceMetadata, - after.fieldInferenceMetadata - ); + fieldInferenceMetadata = after.fieldInferenceMetadata.diff(before.fieldInferenceMetadata); } private static final DiffableUtils.DiffableValueReader ALIAS_METADATA_DIFF_VALUE_READER = @@ -1597,9 +1594,9 @@ private static class IndexMetadataDiff implements Diff { shardSizeInBytesForecast = null; } if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldInferenceMetadata = new FieldInferenceMetadata.FieldInferenceMetadataDiff(in); + fieldInferenceMetadata = in.readOptionalWriteable(FieldInferenceMetadata.FieldInferenceMetadataDiff::new); } else { - fieldInferenceMetadata = FieldInferenceMetadata.FieldInferenceMetadataDiff.EMPTY; + fieldInferenceMetadata = null; } } @@ -1637,7 +1634,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalLong(shardSizeInBytesForecast); } if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldInferenceMetadata.writeTo(out); + out.writeOptionalWriteable(fieldInferenceMetadata); } } @@ -1737,9 +1734,7 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function(); this.rolloverInfos = ImmutableOpenMap.builder(); - this.fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; this.isSystem = false; } @@ -2108,7 +2097,7 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { } public Builder fieldInferenceMetadata(FieldInferenceMetadata fieldInferenceMetadata) { - this.fieldInferenceMetadata = fieldInferenceMetadata; + 
this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); return this; } @@ -2433,11 +2422,7 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } - if (indexMetadata.fieldInferenceMetadata != FieldInferenceMetadata.EMPTY) { - builder.startObject(KEY_FIELD_INFERENCE_METADATA); - indexMetadata.fieldInferenceMetadata.toXContent(builder, params); - builder.endObject(); - } + builder.field(KEY_FIELD_INFERENCE_METADATA, indexMetadata.fieldInferenceMetadata); builder.endObject(); } From a222d7922cea4c80263c37db6273a75667a82718 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 14 Mar 2024 17:18:47 +0100 Subject: [PATCH 33/40] Test fixes --- .../cluster/ClusterStateDiffIT.java | 20 +----- .../action/bulk/BulkOperationTests.java | 61 +++++++++++-------- .../cluster/metadata/IndexMetadataTests.java | 42 ++++++------- .../index/mapper/FieldTypeLookupTests.java | 16 ++--- .../index/mapper/MappingLookupTests.java | 11 ++-- 5 files changed, 74 insertions(+), 76 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index 404cfe3dc560b..0ca959f3b49e4 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -18,6 +18,7 @@ import org.elasticsearch.cluster.metadata.IndexGraveyard; import org.elasticsearch.cluster.metadata.IndexGraveyardTests; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.IndexMetadataTests; import org.elasticsearch.cluster.metadata.IndexTemplateMetadata; import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.metadata.RepositoriesMetadata; @@ -54,7 +55,6 @@ import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -587,7 +587,7 @@ public IndexMetadata randomChange(IndexMetadata part) { builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)); break; case 3: - builder.fieldInferenceMetadata(randomFieldsForModels()); + builder.fieldInferenceMetadata(IndexMetadataTests.randomFieldInferenceMetadata(true)); break; default: throw new IllegalArgumentException("Shouldn't be here"); @@ -598,22 +598,6 @@ public IndexMetadata randomChange(IndexMetadata part) { /** * Generates a random fieldsForModels map */ - private Map> randomFieldsForModels() { - if (randomBoolean()) { - return null; - } - - Map> fieldsForModels = new HashMap<>(); - for (int i = 0; i < randomIntBetween(0, 5); i++) { - Set fields = new HashSet<>(); - for (int j = 0; j < randomIntBetween(1, 4); j++) { - fields.add(randomAlphaOfLengthBetween(4, 10)); - } - fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); - } - - return fieldsForModels; - } }); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index 38860a6ce91cd..ceb6db0f6014a 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -15,6 +15,7 @@ import 
org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.cluster.metadata.IndexAbstraction; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; @@ -92,7 +93,7 @@ public class BulkOperationTests extends ESTestCase { public void testNoInference() { - Map> fieldsForModels = Map.of(); + FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; ModelRegistry modelRegistry = createModelRegistry( Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) ); @@ -116,7 +117,7 @@ public void testNoInference() { ActionListener bulkOperationListener = mock(ActionListener.class); BulkShardRequest bulkShardRequest = runBulkOperation( originalSource, - fieldsForModels, + fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, true, @@ -158,7 +159,7 @@ private static Model mockModel(String inferenceServiceId) { public void testFailedBulkShardRequest() { - Map> fieldsForModels = Map.of(); + FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; ModelRegistry modelRegistry = createModelRegistry(Map.of()); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of()); @@ -176,7 +177,7 @@ public void testFailedBulkShardRequest() { runBulkOperation( originalSource, - fieldsForModels, + fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, bulkOperationListener, @@ -206,11 +207,15 @@ public void testFailedBulkShardRequest() { @SuppressWarnings("unchecked") public void testInference() { - Map> fieldsForModels = Map.of( - INFERENCE_SERVICE_1_ID, - Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), - INFERENCE_SERVICE_2_ID, - Set.of(INFERENCE_FIELD_SERVICE_2) + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( + Map.of( + FIRST_INFERENCE_FIELD_SERVICE_1, + new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_1_ID, Set.of()), + SECOND_INFERENCE_FIELD_SERVICE_1, + new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_1_ID, Set.of()), + INFERENCE_FIELD_SERVICE_2, + new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_2_ID, Set.of()) + ) ); ModelRegistry modelRegistry = createModelRegistry( @@ -244,7 +249,7 @@ public void testInference() { ActionListener bulkOperationListener = mock(ActionListener.class); BulkShardRequest bulkShardRequest = runBulkOperation( originalSource, - fieldsForModels, + fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, true, @@ -279,7 +284,9 @@ public void testInference() { public void testFailedInference() { - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( + Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_1_ID, Set.of())) + ); ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); @@ -298,7 +305,7 @@ public void testFailedInference() { ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); @SuppressWarnings("unchecked") ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, 
bulkOperationListener); + runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); BulkResponse bulkResponse = bulkResponseCaptor.getValue(); @@ -313,7 +320,9 @@ public void testFailedInference() { public void testInferenceFailsForIncorrectRootObject() { - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( + Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_1_ID, Set.of())) + ); ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); @@ -331,7 +340,7 @@ public void testInferenceFailsForIncorrectRootObject() { ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); @SuppressWarnings("unchecked") ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); BulkResponse bulkResponse = bulkResponseCaptor.getValue(); @@ -343,11 +352,15 @@ public void testInferenceFailsForIncorrectRootObject() { public void testInferenceIdNotFound() { - Map> fieldsForModels = Map.of( - INFERENCE_SERVICE_1_ID, - Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), - INFERENCE_SERVICE_2_ID, - Set.of(INFERENCE_FIELD_SERVICE_2) + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( + Map.of( + FIRST_INFERENCE_FIELD_SERVICE_1, + new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_1_ID, Set.of()), + SECOND_INFERENCE_FIELD_SERVICE_1, + new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_1_ID, Set.of()), + INFERENCE_FIELD_SERVICE_2, + new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_2_ID, Set.of()) + ) ); ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); @@ -368,7 +381,7 @@ public void testInferenceIdNotFound() { ActionListener bulkOperationListener = mock(ActionListener.class); doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); BulkResponse bulkResponse = bulkResponseCaptor.getValue(); @@ -444,7 +457,7 @@ public String toString() { private static BulkShardRequest runBulkOperation( Map docSource, - Map> fieldsForModels, + FieldInferenceMetadata fieldInferenceMetadata, ModelRegistry modelRegistry, InferenceServiceRegistry inferenceServiceRegistry, boolean expectTransportShardBulkActionToExecute, @@ -452,7 +465,7 @@ private static BulkShardRequest runBulkOperation( ) { return runBulkOperation( docSource, - fieldsForModels, + fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, bulkOperationListener, @@ -463,7 +476,7 @@ private static BulkShardRequest runBulkOperation( private static BulkShardRequest runBulkOperation( Map 
docSource, - Map> fieldsForModels, + FieldInferenceMetadata fieldInferenceMetadata, ModelRegistry modelRegistry, InferenceServiceRegistry inferenceServiceRegistry, ActionListener bulkOperationListener, @@ -472,7 +485,7 @@ private static BulkShardRequest runBulkOperation( ) { Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME) - .fieldInferenceMetadata(fieldsForModels) + .fieldInferenceMetadata(fieldInferenceMetadata) .settings(settings) .numberOfShards(1) .numberOfReplicas(0) diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index dbe3c25044f82..9e567e7af3f80 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.shard.ShardId; @@ -40,7 +41,6 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -83,7 +83,7 @@ public void testIndexMetadataSerialization() throws IOException { IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null; Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null; Long shardSizeInBytesForecast = randomBoolean() ? 
randomLongBetween(1024, 10240) : null; - Map> fieldsForModels = randomFieldsForModels(true); + FieldInferenceMetadata fieldInferenceMetadata = randomFieldInferenceMetadata(true); IndexMetadata metadata = IndexMetadata.builder("foo") .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1)) @@ -108,7 +108,7 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) - .fieldInferenceMetadata(fieldsForModels) + .fieldInferenceMetadata(fieldInferenceMetadata) .build(); assertEquals(system, metadata.isSystem()); @@ -142,7 +142,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), fromXContentMeta.getStats()); assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); - assertEquals(metadata.getFieldsForModels(), fromXContentMeta.getFieldsForModels()); + assertEquals(metadata.getFieldInferenceMetadata(), fromXContentMeta.getFieldInferenceMetadata()); final BytesStreamOutput out = new BytesStreamOutput(); metadata.writeTo(out); @@ -166,7 +166,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), deserialized.getStats()); assertEquals(metadata.getForecastedWriteLoad(), deserialized.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), deserialized.getForecastedShardSizeInBytes()); - assertEquals(metadata.getFieldsForModels(), deserialized.getFieldsForModels()); + assertEquals(metadata.getFieldInferenceMetadata(), deserialized.getFieldInferenceMetadata()); } } @@ -550,35 +550,35 @@ public void testPartialIndexReceivesDataFrozenTierPreference() { } } - public void testFieldsForModels() { + public void testFieldInferenceMetadata() { Settings.Builder settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0); IndexMetadata idxMeta1 = IndexMetadata.builder("test").settings(settings).build(); - assertThat(idxMeta1.getFieldsForModels(), equalTo(Map.of())); + assertSame(idxMeta1.getFieldInferenceMetadata(), FieldInferenceMetadata.EMPTY); - Map> fieldsForModels = randomFieldsForModels(false); - IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldInferenceMetadata(fieldsForModels).build(); - assertThat(idxMeta2.getFieldsForModels(), equalTo(fieldsForModels)); + FieldInferenceMetadata fieldInferenceMetadata = randomFieldInferenceMetadata(false); + IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldInferenceMetadata(fieldInferenceMetadata).build(); + assertThat(idxMeta2.getFieldInferenceMetadata(), equalTo(fieldInferenceMetadata)); } private static Settings indexSettingsWithDataTier(String dataTier) { return indexSettings(IndexVersion.current(), 1, 0).put(DataTier.TIER_PREFERENCE, dataTier).build(); } - private static Map> randomFieldsForModels(boolean allowNull) { - if (allowNull && randomBoolean()) { + public static FieldInferenceMetadata randomFieldInferenceMetadata(boolean allowNull) { + if (randomBoolean() && allowNull) { return null; } - Map> fieldsForModels = new HashMap<>(); - for (int i = 0; i < randomIntBetween(0, 5); i++) { - Set fields = new HashSet<>(); - for (int j = 0; j < randomIntBetween(1, 4); j++) { - fields.add(randomAlphaOfLengthBetween(4, 10)); - } - fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); - } + Map 
fieldInferenceMap = randomMap( + 0, + 10, + () -> new Tuple<>(randomIdentifier(), randomFieldInference()) + ); + return new FieldInferenceMetadata(fieldInferenceMap); + } - return fieldsForModels; + private static FieldInferenceMetadata.FieldInference randomFieldInference() { + return new FieldInferenceMetadata.FieldInference(randomAlphaOfLength(5), randomSet(0, 5, () -> randomIdentifier())); } private IndexMetadataStats randomIndexStats(int numberOfShards) { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java index 9cd33e36b1d71..932eac3e60d27 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java @@ -37,7 +37,7 @@ public void testEmpty() { assertNotNull(names); assertThat(names, hasSize(0)); - Map> fieldsForModels = lookup.getInferenceIdsForFields(); + Map fieldsForModels = lookup.getInferenceIdsForFields(); assertNotNull(fieldsForModels); assertTrue(fieldsForModels.isEmpty()); } @@ -48,7 +48,7 @@ public void testAddNewField() { assertNull(lookup.get("bar")); assertEquals(f.fieldType(), lookup.get("foo")); - Map> fieldsForModels = lookup.getInferenceIdsForFields(); + Map fieldsForModels = lookup.getInferenceIdsForFields(); assertNotNull(fieldsForModels); assertTrue(fieldsForModels.isEmpty()); } @@ -440,11 +440,13 @@ public void testInferenceModelFieldType() { assertEquals(f2.fieldType(), lookup.get("foo2")); assertEquals(f3.fieldType(), lookup.get("foo3")); - Map> fieldsForModels = lookup.getInferenceIdsForFields(); - assertNotNull(fieldsForModels); - assertEquals(2, fieldsForModels.size()); - assertEquals(Set.of("foo1", "foo2"), fieldsForModels.get("bar1")); - assertEquals(Set.of("foo3"), fieldsForModels.get("bar2")); + Map inferenceIdsForFields = lookup.getInferenceIdsForFields(); + assertNotNull(inferenceIdsForFields); + assertEquals(3, inferenceIdsForFields.size()); + + assertEquals("bar1", inferenceIdsForFields.get("foo1")); + assertEquals("bar1", inferenceIdsForFields.get("foo2")); + assertEquals("bar2", inferenceIdsForFields.get("foo3")); } private static FlattenedFieldMapper createFlattenedMapper(String fieldName) { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java index 028bf5c864d11..bb337d0c61c93 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java @@ -26,7 +26,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -191,7 +190,7 @@ public MetricType getMetricType() { ); } - public void testFieldsForModels() { + public void testInferenceIdsForFields() { MockInferenceModelFieldType fieldType = new MockInferenceModelFieldType("test_field_name", "test_model_id"); MappingLookup mappingLookup = createMappingLookup( Collections.singletonList(new MockFieldMapper(fieldType)), @@ -201,10 +200,10 @@ public void testFieldsForModels() { assertEquals(1, size(mappingLookup.fieldMappers())); assertEquals(fieldType, mappingLookup.getFieldType("test_field_name")); - Map> fieldsForModels = mappingLookup.getInferenceIdsForFields(); - assertNotNull(fieldsForModels); - assertEquals(1, 
fieldsForModels.size()); - assertEquals(Collections.singleton("test_field_name"), fieldsForModels.get("test_model_id")); + Map inferenceIdsForFields = mappingLookup.getInferenceIdsForFields(); + assertNotNull(inferenceIdsForFields); + assertEquals(1, inferenceIdsForFields.size()); + assertEquals("test_model_id", inferenceIdsForFields.get("test_field_name")); } private void assertAnalyzes(Analyzer analyzer, String field, String output) throws IOException { From a026612b101067796674f3dec14211f70992cc96 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 14 Mar 2024 17:19:08 +0100 Subject: [PATCH 34/40] Spotless --- .../action/bulk/BulkShardRequestInferenceProvider.java | 3 ++- .../elasticsearch/action/search/TransportSearchAction.java | 6 +----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 7fe760f402e0f..5e77f5c1faffd 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -136,7 +136,8 @@ public void processBulkShardRequest( Map> fieldsForInferenceIds = clusterState.metadata() .index(bulkShardRequest.shardId().getIndex()) - .getFieldInferenceMetadata().getFieldsForInferenceIds(); + .getFieldInferenceMetadata() + .getFieldsForInferenceIds(); // No inference fields? Terminate early if (fieldsForInferenceIds.isEmpty()) { listener.onResponse(bulkShardRequest); diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index d1bd5bd38d1b5..0922e15999e8c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -453,11 +453,7 @@ void executeRequest( } } }); - Rewriteable.rewriteAndFetch( - original, - searchService.getRewriteContext(timeProvider::absoluteStartMillis), - rewriteListener - ); + Rewriteable.rewriteAndFetch(original, searchService.getRewriteContext(timeProvider::absoluteStartMillis), rewriteListener); } static void adjustSearchType(SearchRequest searchRequest, boolean singleShard) { From f56db05bc0b350ca6cb21b90e24596a680493bd1 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 14 Mar 2024 17:44:44 +0100 Subject: [PATCH 35/40] Fix serialization issues when inference is empty --- .../cluster/metadata/FieldInferenceMetadata.java | 4 ++++ .../org/elasticsearch/cluster/metadata/IndexMetadata.java | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java index 0d02f0aa43db1..08a570b9980dc 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java @@ -58,6 +58,10 @@ public FieldInferenceMetadata(Map fieldsToInferenceMap) fieldInferenceMap = ImmutableOpenMap.builder(fieldsToInferenceMap).build(); } + public boolean isEmpty() { + return fieldInferenceMap.isEmpty(); + } + public record FieldInference(String inferenceId, Set sourceFields) implements SimpleDiffable, diff --git 
a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 1060addd008d1..842edc5bf9f1e 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -734,7 +734,7 @@ private IndexMetadata( this.writeLoadForecast = writeLoadForecast; this.shardSizeInBytesForecast = shardSizeInBytesForecast; assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards; - this.fieldInferenceMetadata = fieldInferenceMetadata; + this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); } IndexMetadata withMappingMetadata(MappingMetadata mapping) { @@ -2422,7 +2422,9 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } - builder.field(KEY_FIELD_INFERENCE_METADATA, indexMetadata.fieldInferenceMetadata); + if (indexMetadata.fieldInferenceMetadata.isEmpty() == false) { + builder.field(KEY_FIELD_INFERENCE_METADATA, indexMetadata.fieldInferenceMetadata); + } builder.endObject(); } From 0edc8d3668ab747d83bf5f8888a1ca074f0e14f5 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 14 Mar 2024 19:12:30 +0100 Subject: [PATCH 36/40] Fix test --- .../cluster/metadata/SemanticTextClusterMetadataTests.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index 69fa64ffa6d1c..deddbc60ef10f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -20,8 +20,6 @@ import java.util.Collection; import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.Set; public class SemanticTextClusterMetadataTests extends ESSingleNodeTestCase { @@ -35,7 +33,7 @@ public void testCreateIndexWithSemanticTextField() { "test", client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,model_id=test_model") ); - assertEquals(Map.of("test_model", Set.of("field")), indexService.getMetadata().getFieldsForModels()); + assertEquals(indexService.getMetadata().getFieldInferenceMetadata().getInferenceIdForField("field"), "test_model"); } public void testAddSemanticTextField() throws Exception { @@ -52,7 +50,7 @@ public void testAddSemanticTextField() throws Exception { putMappingExecutor, singleTask(request) ); - assertEquals(Map.of("test_model", Set.of("field")), resultingState.metadata().index("test").getFieldsForModels()); + assertEquals(resultingState.metadata().index("test").getFieldInferenceMetadata().getInferenceIdForField("field"), "test_model"); } private static List singleTask(PutMappingClusterStateUpdateRequest request) { From ba6f00fe3b9fb57d69d4d39124caf30846e319bb Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 14 Mar 2024 19:48:05 +0100 Subject: [PATCH 37/40] Use empty diff state to avoid bwc errors --- .../cluster/metadata/FieldInferenceMetadata.java | 9 +++++++++ 
.../elasticsearch/cluster/metadata/IndexMetadata.java | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java index 08a570b9980dc..844c35cdc42f2 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java @@ -108,10 +108,19 @@ public static FieldInference fromXContent(XContentParser parser) throws IOExcept @Override public Diff diff(FieldInferenceMetadata previousState) { + if (previousState == null) { + previousState = EMPTY; + } return new FieldInferenceMetadataDiff(previousState, this); } static class FieldInferenceMetadataDiff implements Diff { + + public static final FieldInferenceMetadataDiff EMPTY = new FieldInferenceMetadataDiff( + FieldInferenceMetadata.EMPTY, + FieldInferenceMetadata.EMPTY + ); + private final Diff> fieldInferenceMapDiff; private static final DiffableUtils.DiffableValueReader FIELD_INFERENCE_DIFF_VALUE_READER = diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 842edc5bf9f1e..3e04406f68409 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -1596,7 +1596,7 @@ private static class IndexMetadataDiff implements Diff { if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { fieldInferenceMetadata = in.readOptionalWriteable(FieldInferenceMetadata.FieldInferenceMetadataDiff::new); } else { - fieldInferenceMetadata = null; + fieldInferenceMetadata = FieldInferenceMetadata.FieldInferenceMetadataDiff.EMPTY; } } From f3a6af0b372dc067c68fdedc1af93529385b2b8a Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 14 Mar 2024 20:55:16 +0100 Subject: [PATCH 38/40] Fix parsing error, styling --- .../cluster/ClusterStateDiffIT.java | 8 +- .../metadata/FieldInferenceMetadata.java | 173 +++++++++--------- .../cluster/metadata/IndexMetadata.java | 6 +- .../cluster/metadata/IndexMetadataTests.java | 2 +- 4 files changed, 94 insertions(+), 95 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index 0ca959f3b49e4..fbb3016b925da 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -18,7 +18,6 @@ import org.elasticsearch.cluster.metadata.IndexGraveyard; import org.elasticsearch.cluster.metadata.IndexGraveyardTests; import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.IndexMetadataTests; import org.elasticsearch.cluster.metadata.IndexTemplateMetadata; import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.metadata.RepositoriesMetadata; @@ -62,6 +61,7 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptySet; import static org.elasticsearch.cluster.metadata.AliasMetadata.newAliasMetadataBuilder; +import static org.elasticsearch.cluster.metadata.IndexMetadataTests.randomFieldInferenceMetadata; import static 
org.elasticsearch.cluster.routing.RandomShardRoutingMutator.randomChange; import static org.elasticsearch.cluster.routing.TestShardRouting.shardRoutingBuilder; import static org.elasticsearch.cluster.routing.UnassignedInfoTests.randomUnassignedInfo; @@ -587,17 +587,13 @@ public IndexMetadata randomChange(IndexMetadata part) { builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)); break; case 3: - builder.fieldInferenceMetadata(IndexMetadataTests.randomFieldInferenceMetadata(true)); + builder.fieldInferenceMetadata(randomFieldInferenceMetadata(true)); break; default: throw new IllegalArgumentException("Shouldn't be here"); } return builder.build(); } - - /** - * Generates a random fieldsForModels map - */ }); } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java index 844c35cdc42f2..0e36dd6246c98 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -38,7 +39,9 @@ public class FieldInferenceMetadata implements Diffable, ToXContentFragment { private final ImmutableOpenMap fieldInferenceMap; - private Map> fieldsForInferenceIds; + + // Contains a lazily cached, reversed map of inferenceId -> fields + private volatile Map> fieldsForInferenceIds; public static final FieldInferenceMetadata EMPTY = new FieldInferenceMetadata(ImmutableOpenMap.of()); @@ -62,48 +65,68 @@ public boolean isEmpty() { return fieldInferenceMap.isEmpty(); } - public record FieldInference(String inferenceId, Set sourceFields) - implements - SimpleDiffable, - ToXContentFragment { + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(fieldInferenceMap, (o, v) -> v.writeTo(o)); + } - public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); - public static final ParseField SOURCE_FIELDS_FIELD = new ParseField("source_fields"); + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.map(fieldInferenceMap); + return builder; + } - FieldInference(StreamInput in) throws IOException { - this(in.readString(), in.readCollectionAsImmutableSet(StreamInput::readString)); - } + public static FieldInferenceMetadata fromXContent(XContentParser parser) throws IOException { + return new FieldInferenceMetadata(parser.map(HashMap::new, FieldInference::fromXContent)); + } - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(inferenceId); - out.writeStringCollection(sourceFields); - } + public String getInferenceIdForField(String field) { + return getInferenceSafe(field, FieldInference::inferenceId); + } - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); - builder.field(SOURCE_FIELDS_FIELD.getPreferredName(), sourceFields); - builder.endObject(); - return builder; + private T getInferenceSafe(String field, Function fieldInferenceFunction) { + FieldInference fieldInference = fieldInferenceMap.get(field); + if (fieldInference == 
null) { + return null; } + return fieldInferenceFunction.apply(fieldInference); + } - public static FieldInference fromXContent(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); + public Set getSourceFields(String field) { + return getInferenceSafe(field, FieldInference::sourceFields); + } + + public Map> getFieldsForInferenceIds() { + if (fieldsForInferenceIds != null) { + return fieldsForInferenceIds; } - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "field_inference_parser", - false, - (args, unused) -> new FieldInference((String) args[0], (Set) args[1]) - ); + // Cache the result as a field + Map> fieldsForInferenceIdsMap = new HashMap<>(); + for (Map.Entry entry : fieldInferenceMap.entrySet()) { + String fieldName = entry.getKey(); + String inferenceId = entry.getValue().inferenceId(); - static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); - PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), SOURCE_FIELDS_FIELD); + // Get or create the set associated with the inferenceId + Set fields = fieldsForInferenceIdsMap.computeIfAbsent(inferenceId, k -> new HashSet<>()); + fields.add(fieldName); } + + fieldsForInferenceIds = Collections.unmodifiableMap(fieldsForInferenceIdsMap); + return fieldsForInferenceIds; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FieldInferenceMetadata that = (FieldInferenceMetadata) o; + return Objects.equals(fieldInferenceMap, that.fieldInferenceMap); + } + + @Override + public int hashCode() { + return Objects.hash(fieldInferenceMap); } @Override @@ -158,67 +181,47 @@ public FieldInferenceMetadata apply(FieldInferenceMetadata part) { } } - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeMap(fieldInferenceMap, (o, v) -> v.writeTo(o)); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.map(fieldInferenceMap); - return builder; - } - - public static FieldInferenceMetadata fromXContent(XContentParser parser) throws IOException { - return new FieldInferenceMetadata(parser.map(HashMap::new, FieldInference::fromXContent)); - } + public record FieldInference(String inferenceId, Set sourceFields) + implements + SimpleDiffable, + ToXContentFragment { - public String getInferenceIdForField(String field) { - return getInferenceSafe(field, FieldInference::inferenceId); - } + public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); + public static final ParseField SOURCE_FIELDS_FIELD = new ParseField("source_fields"); - private T getInferenceSafe(String field, Function fieldInferenceFunction) { - FieldInference fieldInference = fieldInferenceMap.get(field); - if (fieldInference == null) { - return null; + FieldInference(StreamInput in) throws IOException { + this(in.readString(), in.readCollectionAsImmutableSet(StreamInput::readString)); } - return fieldInferenceFunction.apply(fieldInference); - } - public Set getSourceFields(String field) { - return getInferenceSafe(field, FieldInference::sourceFields); - } - - public Map> getFieldsForInferenceIds() { - if (fieldsForInferenceIds != null) { - return fieldsForInferenceIds; + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(inferenceId); + out.writeStringCollection(sourceFields); } - // Cache the 
result as a field - Map> fieldsForInferenceIdsMap = new HashMap<>(); - for (Map.Entry entry : fieldInferenceMap.entrySet()) { - String fieldName = entry.getKey(); - String inferenceId = entry.getValue().inferenceId(); - - // Get or create the set associated with the inferenceId - Set fields = fieldsForInferenceIdsMap.computeIfAbsent(inferenceId, k -> new HashSet<>()); - fields.add(fieldName); + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); + builder.field(SOURCE_FIELDS_FIELD.getPreferredName(), sourceFields); + builder.endObject(); + return builder; } - fieldsForInferenceIds = Collections.unmodifiableMap(fieldsForInferenceIdsMap); - return fieldsForInferenceIds; - } + public static FieldInference fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - FieldInferenceMetadata that = (FieldInferenceMetadata) o; - return Objects.equals(fieldInferenceMap, that.fieldInferenceMap); - } + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "field_inference_parser", + false, + (args, unused) -> new FieldInference((String) args[0], new HashSet<>((List) args[1])) + ); - @Override - public int hashCode() { - return Objects.hash(fieldInferenceMap); + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); + PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), SOURCE_FIELDS_FIELD); + } } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 3e04406f68409..89c925427cf88 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -540,7 +540,7 @@ public Iterator> settings() { public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast"; - public static final String KEY_FIELD_INFERENCE_METADATA = "field_inference_metadata"; + public static final String KEY_FIELD_INFERENCE = "field_inference"; public static final String INDEX_STATE_FILE_PREFIX = "state-"; @@ -2423,7 +2423,7 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build } if (indexMetadata.fieldInferenceMetadata.isEmpty() == false) { - builder.field(KEY_FIELD_INFERENCE_METADATA, indexMetadata.fieldInferenceMetadata); + builder.field(KEY_FIELD_INFERENCE, indexMetadata.fieldInferenceMetadata); } builder.endObject(); @@ -2503,7 +2503,7 @@ public static IndexMetadata fromXContent(XContentParser parser, Map randomIdentifier())); + return new FieldInferenceMetadata.FieldInference(randomIdentifier(), randomSet(0, 5, ESTestCase::randomIdentifier)); } private IndexMetadataStats randomIndexStats(int numberOfShards) { From 88af8612f03393623c0f0acf4ef8a7b1098f7f14 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 18 Mar 2024 10:48:23 +0100 Subject: [PATCH 39/40] Remove accessors for FieldInferenceMetadata, use the map instead. 
Rename FieldInference to FieldInferenceOptions --- .../BulkShardRequestInferenceProvider.java | 33 +++++-- .../metadata/FieldInferenceMetadata.java | 99 ++++++------------- .../action/bulk/BulkOperationTests.java | 16 +-- .../cluster/metadata/IndexMetadataTests.java | 6 +- .../SemanticTextClusterMetadataTests.java | 10 +- 5 files changed, 74 insertions(+), 90 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 5e77f5c1faffd..6dc4804eee9fe 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.common.TriConsumer; import org.elasticsearch.core.Releasable; import org.elasticsearch.index.shard.ShardId; @@ -75,11 +76,10 @@ public static void getInstance( Set shardIds, ActionListener listener ) { - Set inferenceIds = new HashSet<>(); - shardIds.stream().map(ShardId::getIndex).collect(Collectors.toSet()).stream().forEach(index -> { - var fieldsForInferenceIds = clusterState.metadata().index(index).getFieldInferenceMetadata().getFieldsForInferenceIds(); - inferenceIds.addAll(fieldsForInferenceIds.keySet()); - }); + Set inferenceIds = + shardIds.stream().map(ShardId::getIndex).collect(Collectors.toSet()).stream() + .map(index -> clusterState.metadata().index(index).getFieldInferenceMetadata().getFieldInferenceOptions().values()) + .flatMap(o -> o.stream().map(FieldInferenceMetadata.FieldInferenceOptions::inferenceId)).collect(Collectors.toSet()); final Map inferenceProviderMap = new ConcurrentHashMap<>(); Runnable onModelLoadingComplete = () -> listener.onResponse( new BulkShardRequestInferenceProvider(clusterState, inferenceProviderMap) @@ -134,10 +134,9 @@ public void processBulkShardRequest( BiConsumer onBulkItemFailure ) { - Map> fieldsForInferenceIds = clusterState.metadata() - .index(bulkShardRequest.shardId().getIndex()) - .getFieldInferenceMetadata() - .getFieldsForInferenceIds(); + Map> fieldsForInferenceIds = getFieldsForInferenceIds( + clusterState.metadata().index(bulkShardRequest.shardId().getIndex()).getFieldInferenceMetadata().getFieldInferenceOptions() + ); // No inference fields? 
Terminate early if (fieldsForInferenceIds.isEmpty()) { listener.onResponse(bulkShardRequest); @@ -187,6 +186,22 @@ public void processBulkShardRequest( } } + private static Map> getFieldsForInferenceIds( + Map fieldInferenceMap + ) { + Map> fieldsForInferenceIdsMap = new HashMap<>(); + for (Map.Entry entry : fieldInferenceMap.entrySet()) { + String fieldName = entry.getKey(); + String inferenceId = entry.getValue().inferenceId(); + + // Get or create the set associated with the inferenceId + Set fields = fieldsForInferenceIdsMap.computeIfAbsent(inferenceId, k -> new HashSet<>()); + fields.add(fieldName); + } + + return fieldsForInferenceIdsMap; + } + @SuppressWarnings("unchecked") private void performInferenceOnBulkItemRequest( BulkItemRequest bulkItemRequest, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java index 0e36dd6246c98..349706c139127 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java @@ -23,14 +23,12 @@ import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.function.Function; /** * Contains field inference information. This is necessary to add to cluster state as inference can be calculated in the coordinator @@ -38,82 +36,47 @@ */ public class FieldInferenceMetadata implements Diffable, ToXContentFragment { - private final ImmutableOpenMap fieldInferenceMap; - - // Contains a lazily cached, reversed map of inferenceId -> fields - private volatile Map> fieldsForInferenceIds; + private final ImmutableOpenMap fieldInferenceOptions; public static final FieldInferenceMetadata EMPTY = new FieldInferenceMetadata(ImmutableOpenMap.of()); public FieldInferenceMetadata(MappingLookup mappingLookup) { - ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(); + ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(); mappingLookup.getInferenceIdsForFields().entrySet().forEach(entry -> { - builder.put(entry.getKey(), new FieldInference(entry.getValue(), mappingLookup.sourcePaths(entry.getKey()))); + builder.put(entry.getKey(), new FieldInferenceOptions(entry.getValue(), mappingLookup.sourcePaths(entry.getKey()))); }); - fieldInferenceMap = builder.build(); + fieldInferenceOptions = builder.build(); } public FieldInferenceMetadata(StreamInput in) throws IOException { - fieldInferenceMap = in.readImmutableOpenMap(StreamInput::readString, FieldInference::new); + fieldInferenceOptions = in.readImmutableOpenMap(StreamInput::readString, FieldInferenceOptions::new); + } + + public FieldInferenceMetadata(Map fieldsToInferenceMap) { + fieldInferenceOptions = ImmutableOpenMap.builder(fieldsToInferenceMap).build(); } - public FieldInferenceMetadata(Map fieldsToInferenceMap) { - fieldInferenceMap = ImmutableOpenMap.builder(fieldsToInferenceMap).build(); + public ImmutableOpenMap getFieldInferenceOptions() { + return fieldInferenceOptions; } public boolean isEmpty() { - return fieldInferenceMap.isEmpty(); + return fieldInferenceOptions.isEmpty(); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeMap(fieldInferenceMap, (o, v) -> v.writeTo(o)); + out.writeMap(fieldInferenceOptions, (o, 
v) -> v.writeTo(o)); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.map(fieldInferenceMap); + builder.map(fieldInferenceOptions); return builder; } public static FieldInferenceMetadata fromXContent(XContentParser parser) throws IOException { - return new FieldInferenceMetadata(parser.map(HashMap::new, FieldInference::fromXContent)); - } - - public String getInferenceIdForField(String field) { - return getInferenceSafe(field, FieldInference::inferenceId); - } - - private T getInferenceSafe(String field, Function fieldInferenceFunction) { - FieldInference fieldInference = fieldInferenceMap.get(field); - if (fieldInference == null) { - return null; - } - return fieldInferenceFunction.apply(fieldInference); - } - - public Set getSourceFields(String field) { - return getInferenceSafe(field, FieldInference::sourceFields); - } - - public Map> getFieldsForInferenceIds() { - if (fieldsForInferenceIds != null) { - return fieldsForInferenceIds; - } - - // Cache the result as a field - Map> fieldsForInferenceIdsMap = new HashMap<>(); - for (Map.Entry entry : fieldInferenceMap.entrySet()) { - String fieldName = entry.getKey(); - String inferenceId = entry.getValue().inferenceId(); - - // Get or create the set associated with the inferenceId - Set fields = fieldsForInferenceIdsMap.computeIfAbsent(inferenceId, k -> new HashSet<>()); - fields.add(fieldName); - } - - fieldsForInferenceIds = Collections.unmodifiableMap(fieldsForInferenceIdsMap); - return fieldsForInferenceIds; + return new FieldInferenceMetadata(parser.map(HashMap::new, FieldInferenceOptions::fromXContent)); } @Override @@ -121,12 +84,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; FieldInferenceMetadata that = (FieldInferenceMetadata) o; - return Objects.equals(fieldInferenceMap, that.fieldInferenceMap); + return Objects.equals(fieldInferenceOptions, that.fieldInferenceOptions); } @Override public int hashCode() { - return Objects.hash(fieldInferenceMap); + return Objects.hash(fieldInferenceOptions); } @Override @@ -144,15 +107,15 @@ static class FieldInferenceMetadataDiff implements Diff FieldInferenceMetadata.EMPTY ); - private final Diff> fieldInferenceMapDiff; + private final Diff> fieldInferenceMapDiff; - private static final DiffableUtils.DiffableValueReader FIELD_INFERENCE_DIFF_VALUE_READER = - new DiffableUtils.DiffableValueReader<>(FieldInference::new, FieldInferenceMetadataDiff::readDiffFrom); + private static final DiffableUtils.DiffableValueReader FIELD_INFERENCE_DIFF_VALUE_READER = + new DiffableUtils.DiffableValueReader<>(FieldInferenceOptions::new, FieldInferenceMetadataDiff::readDiffFrom); FieldInferenceMetadataDiff(FieldInferenceMetadata before, FieldInferenceMetadata after) { fieldInferenceMapDiff = DiffableUtils.diff( - before.fieldInferenceMap, - after.fieldInferenceMap, + before.fieldInferenceOptions, + after.fieldInferenceOptions, DiffableUtils.getStringKeySerializer(), FIELD_INFERENCE_DIFF_VALUE_READER ); @@ -166,8 +129,8 @@ static class FieldInferenceMetadataDiff implements Diff ); } - public static Diff readDiffFrom(StreamInput in) throws IOException { - return SimpleDiffable.readDiffFrom(FieldInference::new, in); + public static Diff readDiffFrom(StreamInput in) throws IOException { + return SimpleDiffable.readDiffFrom(FieldInferenceOptions::new, in); } @Override @@ -177,19 +140,19 @@ public void writeTo(StreamOutput out) throws IOException { @Override public 
FieldInferenceMetadata apply(FieldInferenceMetadata part) { - return new FieldInferenceMetadata(fieldInferenceMapDiff.apply(part.fieldInferenceMap)); + return new FieldInferenceMetadata(fieldInferenceMapDiff.apply(part.fieldInferenceOptions)); } } - public record FieldInference(String inferenceId, Set sourceFields) + public record FieldInferenceOptions(String inferenceId, Set sourceFields) implements - SimpleDiffable, + SimpleDiffable, ToXContentFragment { public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); public static final ParseField SOURCE_FIELDS_FIELD = new ParseField("source_fields"); - FieldInference(StreamInput in) throws IOException { + FieldInferenceOptions(StreamInput in) throws IOException { this(in.readString(), in.readCollectionAsImmutableSet(StreamInput::readString)); } @@ -208,15 +171,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public static FieldInference fromXContent(XContentParser parser) throws IOException { + public static FieldInferenceOptions fromXContent(XContentParser parser) throws IOException { return PARSER.parse(parser, null); } @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "field_inference_parser", false, - (args, unused) -> new FieldInference((String) args[0], new HashSet<>((List) args[1])) + (args, unused) -> new FieldInferenceOptions((String) args[0], new HashSet<>((List) args[1])) ); static { diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index ceb6db0f6014a..c3887f506b891 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -210,11 +210,11 @@ public void testInference() { FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( Map.of( FIRST_INFERENCE_FIELD_SERVICE_1, - new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_1_ID, Set.of()), + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), SECOND_INFERENCE_FIELD_SERVICE_1, - new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_1_ID, Set.of()), + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), INFERENCE_FIELD_SERVICE_2, - new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_2_ID, Set.of()) + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_2_ID, Set.of()) ) ); @@ -285,7 +285,7 @@ public void testInference() { public void testFailedInference() { FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( - Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_1_ID, Set.of())) + Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of())) ); ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); @@ -321,7 +321,7 @@ public void testFailedInference() { public void testInferenceFailsForIncorrectRootObject() { FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( - Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInference(INFERENCE_SERVICE_1_ID, Set.of())) + 
diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java
index 7e9550b1d8039..2392cd1040870 100644
--- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java
@@ -569,7 +569,7 @@ public static FieldInferenceMetadata randomFieldInferenceMetadata(boolean allowN
             return null;
         }
 
-        Map<String, FieldInferenceMetadata.FieldInference> fieldInferenceMap = randomMap(
+        Map<String, FieldInferenceMetadata.FieldInferenceOptions> fieldInferenceMap = randomMap(
             0,
             10,
             () -> new Tuple<>(randomIdentifier(), randomFieldInference())
@@ -577,8 +577,8 @@ public static FieldInferenceMetadata randomFieldInferenceMetadata(boolean allowN
         return new FieldInferenceMetadata(fieldInferenceMap);
     }
 
-    private static FieldInferenceMetadata.FieldInference randomFieldInference() {
-        return new FieldInferenceMetadata.FieldInference(randomIdentifier(), randomSet(0, 5, ESTestCase::randomIdentifier));
+    private static FieldInferenceMetadata.FieldInferenceOptions randomFieldInference() {
+        return new FieldInferenceMetadata.FieldInferenceOptions(randomIdentifier(), randomSet(0, 5, ESTestCase::randomIdentifier));
     }
 
     private IndexMetadataStats randomIndexStats(int numberOfShards) {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
index deddbc60ef10f..a7d3fcce26116 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
@@ -33,7 +33,10 @@ public void testCreateIndexWithSemanticTextField() {
             "test",
             client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,model_id=test_model")
         );
-        assertEquals(indexService.getMetadata().getFieldInferenceMetadata().getInferenceIdForField("field"), "test_model");
+        assertEquals(
+            indexService.getMetadata().getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(),
+            "test_model"
+        );
     }
 
     public void testAddSemanticTextField() throws Exception {
@@ -50,7 +53,10 @@ public void testAddSemanticTextField() throws Exception {
             putMappingExecutor,
             singleTask(request)
         );
-        assertEquals(resultingState.metadata().index("test").getFieldInferenceMetadata().getInferenceIdForField("field"), "test_model");
+        assertEquals(
+            resultingState.metadata().index("test").getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(),
+            "test_model"
+        );
     }
 
     private static List singleTask(PutMappingClusterStateUpdateRequest request) {

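Both tests now inline the accessor chain that replaced the removed getInferenceIdForField(String). Where that chain recurs outside tests, a small null-safe wrapper could restore the old convenience; the sketch below uses only the accessors introduced by this patch, and the wrapper itself is hypothetical:

    // Convenience lookup equivalent to the removed getInferenceIdForField(String).
    // Returns null when no inference options are registered for the field,
    // matching the old helper's behaviour for unknown fields.
    static String inferenceIdForField(FieldInferenceMetadata metadata, String fieldName) {
        FieldInferenceMetadata.FieldInferenceOptions options = metadata.getFieldInferenceOptions().get(fieldName);
        return options == null ? null : options.inferenceId();
    }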
From 3b8db712a65393bde110b85ecb84cc7d5554773d Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Mon, 18 Mar 2024 11:25:54 +0100
Subject: [PATCH 40/40] Spotless

---
 .../bulk/BulkShardRequestInferenceProvider.java | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
index 6dc4804eee9fe..e80530f75cf4b 100644
--- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
+++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
@@ -76,10 +76,13 @@ public static void getInstance(
         Set<ShardId> shardIds,
         ActionListener<BulkShardRequestInferenceProvider> listener
     ) {
-        Set<String> inferenceIds =
-            shardIds.stream().map(ShardId::getIndex).collect(Collectors.toSet()).stream()
-            .map(index -> clusterState.metadata().index(index).getFieldInferenceMetadata().getFieldInferenceOptions().values())
-            .flatMap(o -> o.stream().map(FieldInferenceMetadata.FieldInferenceOptions::inferenceId)).collect(Collectors.toSet());
+        Set<String> inferenceIds = shardIds.stream()
+            .map(ShardId::getIndex)
+            .collect(Collectors.toSet())
+            .stream()
+            .map(index -> clusterState.metadata().index(index).getFieldInferenceMetadata().getFieldInferenceOptions().values())
+            .flatMap(o -> o.stream().map(FieldInferenceMetadata.FieldInferenceOptions::inferenceId))
+            .collect(Collectors.toSet());
         final Map inferenceProviderMap = new ConcurrentHashMap<>();
         Runnable onModelLoadingComplete = () -> listener.onResponse(
             new BulkShardRequestInferenceProvider(clusterState, inferenceProviderMap