From ba06e01b6f0d090642d1d79182cfb60a2aa26fe2 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 30 Apr 2024 12:09:51 +0200 Subject: [PATCH] Add inference calculation for semantic_text --- .../action/bulk/BulkOperation.java | 4 + .../action/bulk/BulkShardRequest.java | 28 + x-pack/plugin/inference/build.gradle | 2 +- .../mock/AbstractTestInferenceService.java | 5 - .../TestDenseInferenceServiceExtension.java | 2 +- .../TestSparseInferenceServiceExtension.java | 12 +- .../xpack/inference/InferencePlugin.java | 16 + .../ShardBulkInferenceActionFilter.java | 534 +++++++++++++++++ .../inference/mapper/SemanticTextField.java | 76 +++ .../mapper/SemanticTextFieldMapper.java | 81 --- .../ShardBulkInferenceActionFilterTests.java | 386 ++++++++++++ .../mapper/SemanticTextFieldTests.java | 2 +- .../xpack/inference/InferenceRestIT.java | 3 +- .../inference/30_semantic_text_inference.yml | 551 ++++++++++++++++++ .../CoordinatedInferenceIngestIT.java | 4 +- 15 files changed, 1608 insertions(+), 98 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java create mode 100644 x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference.yml diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index fcad07d0696f3..7356dc0ea140e 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -298,6 +298,10 @@ private void executeBulkRequestsByShard( bulkRequest.getRefreshPolicy(), requests.toArray(new BulkItemRequest[0]) ); + var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName()); + if (indexMetadata != null && indexMetadata.getInferenceFields().isEmpty() == false) { + bulkShardRequest.setInferenceFieldMap(indexMetadata.getInferenceFields()); + } bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); bulkShardRequest.timeout(bulkRequest.timeout()); bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index bd929b9a2204e..8d1618b443ace 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.support.replication.ReplicatedWriteRequest; import org.elasticsearch.action.support.replication.ReplicationRequest; import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.set.Sets; @@ -22,6 +23,7 @@ import org.elasticsearch.transport.RawIndexingDataTransportRequest; import java.io.IOException; +import java.util.Map; import java.util.Set; public final class BulkShardRequest extends ReplicatedWriteRequest @@ -33,6 +35,8 @@ public final class BulkShardRequest extends ReplicatedWriteRequest inferenceFieldMap = null; + public BulkShardRequest(StreamInput in) throws 
IOException { super(in); items = in.readArray(i -> i.readOptionalWriteable(inpt -> new BulkItemRequest(shardId, inpt)), BulkItemRequest[]::new); @@ -44,6 +48,30 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe setRefreshPolicy(refreshPolicy); } + /** + * Public for test + * Set the transient metadata indicating that this request requires running inference before proceeding. + */ + public void setInferenceFieldMap(Map fieldInferenceMap) { + this.inferenceFieldMap = fieldInferenceMap; + } + + /** + * Consumes the inference metadata to execute inference on the bulk items just once. + */ + public Map consumeInferenceFieldMap() { + Map ret = inferenceFieldMap; + inferenceFieldMap = null; + return ret; + } + + /** + * Public for test + */ + public Map getInferenceFieldMap() { + return inferenceFieldMap; + } + public long totalSizeInBytes() { long totalSizeInBytes = 0; for (int i = 0; i < items.length; i++) { diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 0aef8601ffcc6..48b6156a43039 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -10,7 +10,7 @@ apply plugin: 'elasticsearch.internal-yaml-rest-test' restResources { restApi { - include '_common', 'indices', 'inference', 'index' + include '_common', 'bulk', 'indices', 'inference', 'index', 'get', 'update', 'reindex', 'search' } } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 99dfc9582eb05..a65b8e43e6adf 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -101,11 +101,6 @@ public TestServiceModel( super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings)); } - @Override - public TestDenseInferenceServiceExtension.TestServiceSettings getServiceSettings() { - return (TestDenseInferenceServiceExtension.TestServiceSettings) super.getServiceSettings(); - } - @Override public TestTaskSettings getTaskSettings() { return (TestTaskSettings) super.getTaskSettings(); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index c81dbdc45463c..562b58887432a 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -172,7 +172,7 @@ public static TestServiceSettings fromMap(Map map) { SimilarityMeasure similarity = null; String similarityStr = (String) map.remove("similarity"); if (similarityStr != null) { - similarity = SimilarityMeasure.valueOf(similarityStr); + similarity = SimilarityMeasure.fromString(similarityStr); } return new TestServiceSettings(model, dimensions, similarity); diff --git 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index b13e65d1ba802..38a21209f59b4 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -121,7 +121,7 @@ private SparseEmbeddingResults makeResults(List input) { for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { - tokens.add(new SparseEmbeddingResults.WeightedToken(Integer.toString(j), (float) j)); + tokens.add(new SparseEmbeddingResults.WeightedToken("feature_" + j, j + 1.0F)); } embeddings.add(new SparseEmbeddingResults.Embedding(tokens, false)); } @@ -129,15 +129,17 @@ private SparseEmbeddingResults makeResults(List input) { } private List makeChunkedResults(List input) { - var chunks = new ArrayList(); + List results = new ArrayList<>(); for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { - tokens.add(new TextExpansionResults.WeightedToken(Integer.toString(j), (float) j)); + tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, j + 1.0F)); } - chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)); + results.add( + new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens))) + ); } - return List.of(new ChunkedSparseEmbeddingResults(chunks)); + return results; } protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 1afe3c891db80..34459c3beff95 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -45,6 +46,7 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpSettings; @@ -76,6 +78,8 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static java.util.Collections.singletonList; + public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, MapperPlugin { /** @@ -101,6 +105,7 @@ 
public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); + private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -166,6 +171,9 @@ public Collection createComponents(PluginServices services) { registry.init(services.client()); inferenceServiceRegistry.set(registry); + var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry); + shardBulkInferenceActionFilter.set(actionFilter); + return List.of(modelRegistry, registry); } @@ -272,4 +280,12 @@ public Map getMappers() { } return Map.of(); } + + @Override + public Collection getActionFilters() { + if (SemanticTextFeature.isEnabled()) { + return singletonList(shardBulkInferenceActionFilter.get()); + } + return List.of(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java new file mode 100644 index 0000000000000..6a8a278d85f75 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -0,0 +1,534 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action.filter; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkItemRequest; +import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.TransportShardBulkAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.ActionFilterChain; +import org.elasticsearch.action.support.MappedActionFilter; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; +import 
org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; + +/** + * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified + * as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in + * the request source, we generate embeddings and include the results in the source under the semantic text field + * name as a {@link SemanticTextField}. + * This transformation happens on the bulk coordinator node, and the {@link SemanticTextFieldMapper} parses the + * results during indexing on the shard. + * + * TODO: batchSize should be configurable via a cluster setting + */ +public class ShardBulkInferenceActionFilter implements MappedActionFilter { + private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class); + protected static final int DEFAULT_BATCH_SIZE = 512; + + private final InferenceServiceRegistry inferenceServiceRegistry; + private final ModelRegistry modelRegistry; + private final int batchSize; + + public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) { + this(inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE); + } + + public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, int batchSize) { + this.inferenceServiceRegistry = inferenceServiceRegistry; + this.modelRegistry = modelRegistry; + this.batchSize = batchSize; + } + + @Override + public int order() { + // must execute last (after the security action filter) + return Integer.MAX_VALUE; + } + + @Override + public String actionName() { + return TransportShardBulkAction.ACTION_NAME; + } + + @Override + public void apply( + Task task, + String action, + Request request, + ActionListener listener, + ActionFilterChain chain + ) { + switch (action) { + case TransportShardBulkAction.ACTION_NAME: + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap(); + if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) { + Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); + processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); + } else { + chain.proceed(task, action, request, listener); + } + break; + + default: + chain.proceed(task, action, request, listener); + break; + } + } + + private void processBulkShardRequest( + Map fieldInferenceMap, + BulkShardRequest bulkShardRequest, + Runnable onCompletion + ) { + new AsyncBulkShardInferenceAction(fieldInferenceMap, bulkShardRequest, onCompletion).run(); + } + + private record InferenceProvider(InferenceService service, Model model) {} + + /** + * A field inference request on a single input. + * @param index The index of the request in the original bulk request. + * @param field The target field. + * @param input The input to run inference on. + * @param inputOrder The original order of the input. 
+ * @param isOriginalFieldInput Whether the input is part of the original values of the field. + */ + private record FieldInferenceRequest(int index, String field, String input, int inputOrder, boolean isOriginalFieldInput) {} + + /** + * The field inference response. + * @param field The target field. + * @param input The input that was used to run inference. + * @param inputOrder The original order of the input. + * @param isOriginalFieldInput Whether the input is part of the original values of the field. + * @param model The model used to run inference. + * @param chunkedResults The actual results. + */ + private record FieldInferenceResponse( + String field, + String input, + int inputOrder, + boolean isOriginalFieldInput, + Model model, + ChunkedInferenceServiceResults chunkedResults + ) {} + + private record FieldInferenceResponseAccumulator( + int id, + Map> responses, + List failures + ) { + void addOrUpdateResponse(FieldInferenceResponse response) { + synchronized (this) { + var list = responses.computeIfAbsent(response.field, k -> new ArrayList<>()); + list.add(response); + } + } + + void addFailure(Exception exc) { + synchronized (this) { + failures.add(exc); + } + } + } + + private class AsyncBulkShardInferenceAction implements Runnable { + private final Map fieldInferenceMap; + private final BulkShardRequest bulkShardRequest; + private final Runnable onCompletion; + private final AtomicArray inferenceResults; + + private AsyncBulkShardInferenceAction( + Map fieldInferenceMap, + BulkShardRequest bulkShardRequest, + Runnable onCompletion + ) { + this.fieldInferenceMap = fieldInferenceMap; + this.bulkShardRequest = bulkShardRequest; + this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); + this.onCompletion = onCompletion; + } + + @Override + public void run() { + Map> inferenceRequests = createFieldInferenceRequests(bulkShardRequest); + Runnable onInferenceCompletion = () -> { + try { + for (var inferenceResponse : inferenceResults.asList()) { + var request = bulkShardRequest.items()[inferenceResponse.id]; + try { + applyInferenceResponses(request, inferenceResponse); + } catch (Exception exc) { + request.abort(bulkShardRequest.index(), exc); + } + } + } finally { + onCompletion.run(); + } + }; + try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) { + for (var entry : inferenceRequests.entrySet()) { + executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire()); + } + } + } + + private void executeShardBulkInferenceAsync( + final String inferenceId, + @Nullable InferenceProvider inferenceProvider, + final List requests, + final Releasable onFinish + ) { + if (inferenceProvider == null) { + ActionListener modelLoadingListener = new ActionListener<>() { + @Override + public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { + var service = inferenceServiceRegistry.getService(unparsedModel.service()); + if (service.isEmpty() == false) { + var provider = new InferenceProvider( + service.get(), + service.get() + .parsePersistedConfigWithSecrets( + inferenceId, + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ) + ); + executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish); + } else { + try (onFinish) { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.index).failures.add( + new ResourceNotFoundException( + "Inference service [{}] not found for field [{}]", + unparsedModel.service(), 
+ request.field + ) + ); + } + } + } + } + + @Override + public void onFailure(Exception exc) { + try (onFinish) { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.index).failures.add( + new ResourceNotFoundException("Inference id [{}] not found for field [{}]", inferenceId, request.field) + ); + } + } + } + }; + modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); + return; + } + int currentBatchSize = Math.min(requests.size(), batchSize); + final List currentBatch = requests.subList(0, currentBatchSize); + final List nextBatch = requests.subList(currentBatchSize, requests.size()); + final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + ActionListener> completionListener = new ActionListener<>() { + @Override + public void onResponse(List results) { + try { + for (int i = 0; i < results.size(); i++) { + var request = requests.get(i); + var result = results.get(i); + var acc = inferenceResults.get(request.index); + if (result instanceof ErrorChunkedInferenceResults error) { + acc.addFailure( + new ElasticsearchException( + "Exception when running inference id [{}] on field [{}]", + error.getException(), + inferenceProvider.model.getInferenceEntityId(), + request.field + ) + ); + } else { + acc.addOrUpdateResponse( + new FieldInferenceResponse( + request.field(), + request.input(), + request.inputOrder(), + request.isOriginalFieldInput(), + inferenceProvider.model, + result + ) + ); + } + } + } finally { + onFinish(); + } + } + + @Override + public void onFailure(Exception exc) { + try { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + addInferenceResponseFailure( + request.index, + new ElasticsearchException( + "Exception when running inference id [{}] on field [{}]", + exc, + inferenceProvider.model.getInferenceEntityId(), + request.field + ) + ); + } + } finally { + onFinish(); + } + } + + private void onFinish() { + if (nextBatch.isEmpty()) { + onFinish.close(); + } else { + executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish); + } + } + }; + inferenceProvider.service() + .chunkedInfer( + inferenceProvider.model(), + null, + inputs, + Map.of(), + InputType.INGEST, + new ChunkingOptions(null, null), + TimeValue.MAX_VALUE, + completionListener + ); + } + + private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { + FieldInferenceResponseAccumulator acc = inferenceResults.get(id); + if (acc == null) { + acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>()); + inferenceResults.set(id, acc); + } + return acc; + } + + private void addInferenceResponseFailure(int id, Exception failure) { + var acc = ensureResponseAccumulatorSlot(id); + acc.addFailure(failure); + } + + /** + * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}. + * If the response contains failures, the bulk item request is marked as failed for the downstream action. + * Otherwise, the source of the request is augmented with the field inference results under the + * {@link SemanticTextField#INFERENCE_FIELD} field. 
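+ * <p>
+ * For illustration, a bulk item whose source is {"inference_field": "inference test"} (the same shape the
+ * YAML REST tests below index) leaves this method with a source equivalent to
+ * {"inference_field": {"text": "inference test", "inference": {..., "chunks": [{"text": "inference test", "embeddings": ...}]}}},
+ * which is the layout those tests assert on via _source.inference_field.inference.chunks.0.text.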
+ */ + private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { + if (response.failures().isEmpty() == false) { + for (var failure : response.failures()) { + item.abort(item.index(), failure); + } + return; + } + + final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); + var newDocMap = indexRequest.sourceAsMap(); + for (var entry : response.responses.entrySet()) { + var fieldName = entry.getKey(); + var responses = entry.getValue(); + var model = responses.get(0).model(); + // ensure that the order in the original field is consistent in case of multiple inputs + Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); + List inputs = responses.stream().filter(r -> r.isOriginalFieldInput).map(r -> r.input).collect(Collectors.toList()); + List results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList()); + var result = new SemanticTextField( + fieldName, + inputs, + new SemanticTextField.InferenceResult( + model.getInferenceEntityId(), + new SemanticTextField.ModelSettings(model), + toSemanticTextFieldChunks(fieldName, model.getInferenceEntityId(), results, indexRequest.getContentType()) + ), + indexRequest.getContentType() + ); + newDocMap.put(fieldName, result); + } + indexRequest.source(newDocMap, indexRequest.getContentType()); + } + + /** + * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index. + * If results are already populated for fields in the original index request, the inference request for this specific + * field is skipped, and the existing results remain unchanged. + * Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing, + * where an error will be thrown if they mismatch or if the content is malformed. + *
<p>
+ * TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ? + */ + private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { + Map> fieldRequestsMap = new LinkedHashMap<>(); + int itemIndex = 0; + for (var item : bulkShardRequest.items()) { + if (item.getPrimaryResponse() != null) { + // item was already aborted/processed by a filter in the chain upstream (e.g. security) + continue; + } + boolean isUpdateRequest = false; + final IndexRequest indexRequest; + if (item.request() instanceof IndexRequest ir) { + indexRequest = ir; + } else if (item.request() instanceof UpdateRequest updateRequest) { + isUpdateRequest = true; + if (updateRequest.script() != null) { + addInferenceResponseFailure( + item.id(), + new ElasticsearchStatusException( + "Cannot apply update with a script on indices that contain [{}] field(s)", + RestStatus.BAD_REQUEST, + SemanticTextFieldMapper.CONTENT_TYPE + ) + ); + continue; + } + indexRequest = updateRequest.doc(); + } else { + // ignore delete request + continue; + } + final Map docMap = indexRequest.sourceAsMap(); + for (var entry : fieldInferenceMap.values()) { + String field = entry.getName(); + String inferenceId = entry.getInferenceId(); + var originalFieldValue = XContentMapValues.extractValue(field, docMap); + if (originalFieldValue instanceof Map) { + continue; + } + int order = 0; + for (var sourceField : entry.getSourceFields()) { + boolean isOriginalFieldInput = sourceField.equals(field); + var valueObj = XContentMapValues.extractValue(sourceField, docMap); + if (valueObj == null) { + if (isUpdateRequest) { + addInferenceResponseFailure( + item.id(), + new ElasticsearchStatusException( + "Field [{}] must be specified on an update request to calculate inference for field [{}]", + RestStatus.BAD_REQUEST, + sourceField, + field + ) + ); + break; + } + continue; + } + ensureResponseAccumulatorSlot(itemIndex); + final List values; + try { + values = nodeStringValues(field, valueObj); + } catch (Exception exc) { + addInferenceResponseFailure(item.id(), exc); + break; + } + List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); + for (var v : values) { + fieldRequests.add(new FieldInferenceRequest(itemIndex, field, v, order++, isOriginalFieldInput)); + } + } + } + itemIndex++; + } + return fieldRequestsMap; + } + } + + /** + * This method converts the given {@code valueObj} into a list of strings. + * If {@code valueObj} is not a string or a collection of strings, it throws an ElasticsearchStatusException. 
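+ * For example, "foo" becomes List.of("foo") and List.of("foo", "bar") is returned as-is, while
+ * List.of("foo", 42) is rejected because 42 is not a string.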
+ */ + private static List nodeStringValues(String field, Object valueObj) { + if (valueObj instanceof String value) { + return List.of(value); + } else if (valueObj instanceof Collection values) { + List valuesString = new ArrayList<>(); + for (var v : values) { + if (v instanceof String value) { + valuesString.add(value); + } else { + throw new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + valueObj.getClass().getSimpleName() + ); + } + } + return valuesString; + } + throw new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + valueObj.getClass().getSimpleName() + ); + } + + static IndexRequest getIndexRequestOrNull(DocWriteRequest docWriteRequest) { + if (docWriteRequest instanceof IndexRequest indexRequest) { + return indexRequest; + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + return updateRequest.doc(); + } else { + return null; + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 3a435f3e3276d..33ef5cd0b17e8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -8,26 +8,35 @@ package org.elasticsearch.xpack.inference.mapper; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.DeprecationHandler; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -271,4 +280,71 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws MODEL_SETTINGS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(DIMENSIONS_FIELD)); MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new 
ParseField(SIMILARITY_FIELD)); } + + /** + * Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link Chunk}. + */ + public static List toSemanticTextFieldChunks( + String field, + String inferenceId, + List results, + XContentType contentType + ) { + List chunks = new ArrayList<>(); + for (var result : results) { + if (result instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.weightedTokens()))); + } + } else if (result instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.embedding()))); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + inferenceId, + result.getWriteableName() + ); + } + } + return chunks; + } + + /** + * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. + */ + private static BytesReference toBytesReference(XContent xContent, double[] value) { + try { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startArray(); + for (double v : value) { + b.value(v); + } + b.endArray(); + return BytesReference.bytes(b); + } catch (IOException exc) { + throw new RuntimeException(exc); + } + } + + /** + * Serialises the {@link TextExpansionResults.WeightedToken} list, according to the provided {@link XContent}, + * into a {@link BytesReference}. + */ + private static BytesReference toBytesReference(XContent xContent, List tokens) { + try { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startObject(); + for (var weightedToken : tokens) { + weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS); + } + b.endObject(); + return BytesReference.bytes(b); + } catch (IOException exc) { + throw new RuntimeException(exc); + } + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index cd2fbc9c7be10..c4293d16ce6a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -8,11 +8,9 @@ package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.Explicit; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; @@ -37,19 +35,11 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.ToXContent; -import 
org.elasticsearch.xcontent.XContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentLocation; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; import java.util.ArrayList; @@ -275,77 +265,6 @@ public InferenceFieldMetadata getMetadata(Set sourcePaths) { return new InferenceFieldMetadata(name(), fieldType().inferenceId, copyFields); } - /** - * Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link SemanticTextField.Chunk}. - */ - static List toSemanticTextFieldChunks( - String field, - String inferenceId, - List results, - XContentType contentType - ) { - List chunks = new ArrayList<>(); - for (var result : results) { - if (result instanceof ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add( - new SemanticTextField.Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.weightedTokens())) - ); - } - } else if (result instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add( - new SemanticTextField.Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.embedding())) - ); - } - } else { - throw new ElasticsearchStatusException( - "Invalid inference results format for field [{}] with inference id [{}], got {}", - RestStatus.BAD_REQUEST, - field, - inferenceId, - result.getWriteableName() - ); - } - } - return chunks; - } - - /** - * Serialises the {@link TextExpansionResults.WeightedToken} list, according to the provided {@link XContent}, - * into a {@link BytesReference}. - */ - static BytesReference toBytesReference(XContent xContent, List tokens) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startObject(); - for (var weightedToken : tokens) { - weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS); - } - b.endObject(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } - - /** - * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. 
- */ - private static BytesReference toBytesReference(XContent xContent, double[] value) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startArray(); - for (double v : value) { - b.value(v); - } - b.endArray(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } - public static class SemanticTextFieldType extends SimpleMappedFieldType { private final String inferenceId; private final SemanticTextField.ModelSettings modelSettings; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java new file mode 100644 index 0000000000000..c87faa2b52cc8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -0,0 +1,386 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action.filter; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.bulk.BulkItemRequest; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.TransportShardBulkAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.ActionFilterChain; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.inference.model.TestModel; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.After; +import org.junit.Before; +import org.mockito.stubbing.Answer; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; +import static 
org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; +import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSparseEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ShardBulkInferenceActionFilterTests extends ESTestCase { + private ThreadPool threadPool; + + @Before + public void setupThreadPool() { + threadPool = new TestThreadPool(getTestName()); + } + + @After + public void tearDownThreadPool() throws Exception { + terminate(threadPool); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testFilterNoop() throws Exception { + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertNull(((BulkShardRequest) request).getInferenceFieldMap()); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest request = new BulkShardRequest( + new ShardId("test", "test", 0), + WriteRequest.RefreshPolicy.NONE, + new BulkItemRequest[0] + ); + request.setInferenceFieldMap( + Map.of("foo", new InferenceFieldMetadata("foo", "bar", generateRandomStringArray(5, 10, false, false))) + ); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testInferenceNotFound() throws Exception { + StaticModel model = StaticModel.createRandomInstance(); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(model.getInferenceEntityId(), model), + randomIntBetween(1, 10) + ); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getInferenceFieldMap()); + for (BulkItemRequest item : bulkShardRequest.items()) { + assertNotNull(item.getPrimaryResponse()); + assertTrue(item.getPrimaryResponse().isFailed()); + BulkItemResponse.Failure failure = item.getPrimaryResponse().getFailure(); + assertThat(failure.getStatus(), equalTo(RestStatus.NOT_FOUND)); + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "field1", + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }), + "field2", + new InferenceFieldMetadata("field2", "inference_0", new String[] { "field2" }), + "field3", + new InferenceFieldMetadata("field3", "inference_0", new String[] { "field3" }) + ); + BulkItemRequest[] items = new BulkItemRequest[10]; + 
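+ // the items are generated with an empty model map, so their sources stay plain text; field2 and field3 reference
+ // the unregistered id [inference_0], and the filter is expected to fail every item with a resource-not-found error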
for (int i = 0; i < items.length; i++) { + items[i] = randomBulkItemRequest(Map.of(), inferenceFieldMap)[0]; + } + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testItemFailures() throws Exception { + StaticModel model = StaticModel.createRandomInstance(); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(model.getInferenceEntityId(), model), + randomIntBetween(1, 10) + ); + model.putResult("I am a failure", new ErrorChunkedInferenceResults(new IllegalArgumentException("boom"))); + model.putResult("I am a success", randomSparseEmbeddings(List.of("I am a success"))); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getInferenceFieldMap()); + assertThat(bulkShardRequest.items().length, equalTo(3)); + + // item 0 is a failure + assertNotNull(bulkShardRequest.items()[0].getPrimaryResponse()); + assertTrue(bulkShardRequest.items()[0].getPrimaryResponse().isFailed()); + BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure(); + assertThat(failure.getCause().getCause().getMessage(), containsString("boom")); + + // item 1 is a success + assertNull(bulkShardRequest.items()[1].getPrimaryResponse()); + IndexRequest actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request()); + assertThat(XContentMapValues.extractValue("field1.text", actualRequest.sourceAsMap()), equalTo("I am a success")); + + // item 2 is a failure + assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse()); + assertTrue(bulkShardRequest.items()[2].getPrimaryResponse().isFailed()); + failure = bulkShardRequest.items()[2].getPrimaryResponse().getFailure(); + assertThat(failure.getCause().getCause().getMessage(), containsString("boom")); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "field1", + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }) + ); + BulkItemRequest[] items = new BulkItemRequest[3]; + items[0] = new BulkItemRequest(0, new IndexRequest("index").source("field1", "I am a failure")); + items[1] = new BulkItemRequest(1, new IndexRequest("index").source("field1", "I am a success")); + items[2] = new BulkItemRequest(2, new IndexRequest("index").source("field1", "I am a failure")); + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testManyRandomDocs() throws Exception { + Map inferenceModelMap = new HashMap<>(); + int numModels = randomIntBetween(1, 5); + for (int i = 0; i < numModels; i++) { + StaticModel model = StaticModel.createRandomInstance(); + 
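+ // key each random model by its inference id so the mocked ModelRegistry and InferenceService can resolve it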
inferenceModelMap.put(model.getInferenceEntityId(), model); + } + + int numInferenceFields = randomIntBetween(1, 5); + Map inferenceFieldMap = new HashMap<>(); + for (int i = 0; i < numInferenceFields; i++) { + String field = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomFrom(inferenceModelMap.keySet()); + inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field })); + } + + int numRequests = randomIntBetween(100, 1000); + BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; + BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; + for (int id = 0; id < numRequests; id++) { + BulkItemRequest[] res = randomBulkItemRequest(inferenceModelMap, inferenceFieldMap); + originalRequests[id] = res[0]; + modifiedRequests[id] = res[1]; + } + + ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, randomIntBetween(10, 30)); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertThat(request, instanceOf(BulkShardRequest.class)); + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getInferenceFieldMap()); + BulkItemRequest[] items = bulkShardRequest.items(); + assertThat(items.length, equalTo(originalRequests.length)); + for (int id = 0; id < items.length; id++) { + IndexRequest actualRequest = getIndexRequestOrNull(items[id].request()); + IndexRequest expectedRequest = getIndexRequestOrNull(modifiedRequests[id].request()); + try { + assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), expectedRequest.getContentType()); + } catch (Exception exc) { + throw new IllegalStateException(exc); + } + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); + original.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings("unchecked") + private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap, int batchSize) { + ModelRegistry modelRegistry = mock(ModelRegistry.class); + Answer unparsedModelAnswer = invocationOnMock -> { + String id = (String) invocationOnMock.getArguments()[0]; + ActionListener listener = (ActionListener) invocationOnMock + .getArguments()[1]; + var model = modelMap.get(id); + if (model != null) { + listener.onResponse( + new ModelRegistry.UnparsedModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getServiceSettings().model(), + XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getTaskSettings()), false), + XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getSecretSettings()), false) + ) + ); + } else { + listener.onFailure(new ResourceNotFoundException("model id [{}] not found", id)); + } + return null; + }; + doAnswer(unparsedModelAnswer).when(modelRegistry).getModelWithSecrets(any(), any()); + + InferenceService inferenceService = mock(InferenceService.class); + Answer chunkedInferAnswer = invocationOnMock -> { + StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; + List inputs = (List) 
invocationOnMock.getArguments()[2]; + ActionListener> listener = (ActionListener< + List>) invocationOnMock.getArguments()[7]; + Runnable runnable = () -> { + List results = new ArrayList<>(); + for (String input : inputs) { + results.add(model.getResults(input)); + } + listener.onResponse(results); + }; + if (randomBoolean()) { + try { + threadPool.generic().execute(runnable); + } catch (Exception exc) { + listener.onFailure(exc); + } + } else { + runnable.run(); + } + return null; + }; + doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any(), any(), any()); + + Answer modelAnswer = invocationOnMock -> { + String inferenceId = (String) invocationOnMock.getArguments()[0]; + return modelMap.get(inferenceId); + }; + doAnswer(modelAnswer).when(inferenceService).parsePersistedConfigWithSecrets(any(), any(), any(), any()); + + InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); + when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); + ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry, batchSize); + return filter; + } + + private static BulkItemRequest[] randomBulkItemRequest( + Map modelMap, + Map fieldInferenceMap + ) { + Map docMap = new LinkedHashMap<>(); + Map expectedDocMap = new LinkedHashMap<>(); + XContentType requestContentType = randomFrom(XContentType.values()); + for (var entry : fieldInferenceMap.values()) { + String field = entry.getName(); + var model = modelMap.get(entry.getInferenceId()); + String text = randomAlphaOfLengthBetween(10, 100); + docMap.put(field, text); + expectedDocMap.put(field, text); + if (model == null) { + // ignore results, the doc should fail with a resource not found exception + continue; + } + var result = randomSemanticText(field, model, List.of(text), requestContentType); + model.putResult(text, toChunkedResult(result)); + expectedDocMap.put(field, result); + } + + int requestId = randomIntBetween(0, Integer.MAX_VALUE); + return new BulkItemRequest[] { + new BulkItemRequest(requestId, new IndexRequest("index").source(docMap, requestContentType)), + new BulkItemRequest(requestId, new IndexRequest("index").source(expectedDocMap, requestContentType)) }; + } + + private static class StaticModel extends TestModel { + private final Map resultMap; + + StaticModel( + String inferenceEntityId, + TaskType taskType, + String service, + TestServiceSettings serviceSettings, + TestTaskSettings taskSettings, + TestSecretSettings secretSettings + ) { + super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, secretSettings); + this.resultMap = new HashMap<>(); + } + + public static StaticModel createRandomInstance() { + TestModel testModel = TestModel.createRandomInstance(); + return new StaticModel( + testModel.getInferenceEntityId(), + testModel.getTaskType(), + randomAlphaOfLength(10), + testModel.getServiceSettings(), + testModel.getTaskSettings(), + testModel.getSecretSettings() + ); + } + + ChunkedInferenceServiceResults getResults(String text) { + return resultMap.getOrDefault(text, new ChunkedSparseEmbeddingResults(List.of())); + } + + void putResult(String text, ChunkedInferenceServiceResults result) { + resultMap.put(text, result); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 9f7b80dd61979..3885563720484 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -31,7 +31,7 @@ import java.util.function.Predicate; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.toSemanticTextFieldChunks; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java index a594c577dcdd2..a397d9864d23d 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java +++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -21,9 +21,8 @@ public class InferenceRestIT extends ESClientYamlSuiteTestCase { public static ElasticsearchCluster cluster = ElasticsearchCluster.local() .setting("xpack.security.enabled", "false") .setting("xpack.security.http.ssl.enabled", "false") - .plugin("x-pack-inference") .plugin("inference-service-test") - .distribution(DistributionType.INTEG_TEST) + .distribution(DistributionType.DEFAULT) .build(); public InferenceRestIT(final ClientYamlTestCandidate testCandidate) { diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference.yml new file mode 100644 index 0000000000000..ee3a5c28376b2 --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference.yml @@ -0,0 +1,551 @@ +setup: + - skip: + version: " - 8.14.99" + reason: semantic_text introduced in 8.15.0 + + - do: + inference.put_model: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: + inference.put_model: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-sparse-index + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + another_inference_field: + type: semantic_text + inference_id: sparse-inference-id + non_inference_field: + type: text + + - do: + indices.create: + index: test-dense-index + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + another_inference_field: + type: semantic_text + inference_id: dense-inference-id + non_inference_field: + type: text + +--- +"Calculates text expansion results for new documents": + - do: + index: + index: test-sparse-index + id: 
+---
+"Calculates text expansion results for new documents":
+  - do:
+      index:
+        index: test-sparse-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+          another_inference_field: "another inference test"
+          non_inference_field: "non inference test"
+
+  - do:
+      get:
+        index: test-sparse-index
+        id: doc_1
+
+  - match: { _source.inference_field.text: "inference test" }
+  - exists: _source.inference_field.inference.chunks.0.embeddings
+  - match: { _source.inference_field.inference.chunks.0.text: "inference test" }
+  - match: { _source.another_inference_field.text: "another inference test" }
+  - exists: _source.another_inference_field.inference.chunks.0.embeddings
+  - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" }
+  - match: { _source.non_inference_field: "non inference test" }
+
+---
+"text expansion documents do not create new mappings":
+  - do:
+      indices.get_mapping:
+        index: test-sparse-index
+
+  - match: {test-sparse-index.mappings.properties.inference_field.type: semantic_text}
+  - match: {test-sparse-index.mappings.properties.another_inference_field.type: semantic_text}
+  - match: {test-sparse-index.mappings.properties.non_inference_field.type: text}
+  - length: {test-sparse-index.mappings.properties: 3}
+
+---
+"Calculates text embeddings results for new documents":
+  - do:
+      index:
+        index: test-dense-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+          another_inference_field: "another inference test"
+          non_inference_field: "non inference test"
+
+  - do:
+      get:
+        index: test-dense-index
+        id: doc_1
+
+  - match: { _source.inference_field.text: "inference test" }
+  - exists: _source.inference_field.inference.chunks.0.embeddings
+  - match: { _source.inference_field.inference.chunks.0.text: "inference test" }
+  - match: { _source.another_inference_field.text: "another inference test" }
+  - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" }
+  - exists: _source.another_inference_field.inference.chunks.0.embeddings
+  - match: { _source.non_inference_field: "non inference test" }
+
+
+---
+"text embeddings documents do not create new mappings":
+  - do:
+      indices.get_mapping:
+        index: test-dense-index
+
+  - match: {test-dense-index.mappings.properties.inference_field.type: semantic_text}
+  - match: {test-dense-index.mappings.properties.another_inference_field.type: semantic_text}
+  - match: {test-dense-index.mappings.properties.non_inference_field.type: text}
+  - length: {test-dense-index.mappings.properties: 3}
+
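+# Chunked embeddings are indexed as nested documents, so they are searchable through nested
+# queries: text_expansion for sparse vectors, knn for dense vectors, including inner_hits access.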
+---
+"Sparse vector results are indexed as nested chunks and searchable":
+  - do:
+      bulk:
+        index: test-sparse-index
+        refresh: true
+        body: |
+          {"index":{}}
+          {"inference_field": ["you know, for testing", "now with chunks"]}
+          {"index":{}}
+          {"inference_field": ["some more tests", "that include chunks"]}
+
+  - do:
+      search:
+        index: test-sparse-index
+        body:
+          query:
+            nested:
+              path: inference_field.inference.chunks
+              query:
+                text_expansion:
+                  inference_field.inference.chunks.embeddings:
+                    model_id: sparse-inference-id
+                    model_text: "you know, for testing"
+
+  - match: { hits.total.value: 2 }
+  - match: { hits.total.relation: eq }
+
+  # Search with inner hits
+  - do:
+      search:
+        _source: false
+        index: test-sparse-index
+        body:
+          query:
+            nested:
+              path: inference_field.inference.chunks
+              inner_hits:
+                _source: false
+                fields: [inference_field.inference.chunks.text]
+              query:
+                text_expansion:
+                  inference_field.inference.chunks.embeddings:
+                    model_id: sparse-inference-id
+                    model_text: "you know, for testing"
+
+  - match: { hits.total.value: 2 }
+  - match: { hits.total.relation: eq }
+  - match: { hits.hits.0.inner_hits.inference_field\.inference\.chunks.hits.total.value: 2 }
+  - exists: hits.hits.0.inner_hits.inference_field\.inference\.chunks.hits.hits.0.fields.inference_field\.inference\.chunks.0.text
+  - exists: hits.hits.0.inner_hits.inference_field\.inference\.chunks.hits.hits.1.fields.inference_field\.inference\.chunks.0.text
+
+
+---
+"Dense vector results are indexed as nested chunks and searchable":
+  - do:
+      bulk:
+        index: test-dense-index
+        refresh: true
+        body: |
+          {"index":{}}
+          {"inference_field": ["you know, for testing", "now with chunks"]}
+          {"index":{}}
+          {"inference_field": ["some more tests", "that include chunks"]}
+
+  - do:
+      search:
+        index: test-dense-index
+        body:
+          query:
+            nested:
+              path: inference_field.inference.chunks
+              query:
+                knn:
+                  field: inference_field.inference.chunks.embeddings
+                  query_vector_builder:
+                    text_embedding:
+                      model_id: dense-inference-id
+                      model_text: "you know, for testing"
+
+  # Search with inner hits
+  - do:
+      search:
+        _source: false
+        index: test-dense-index
+        body:
+          query:
+            nested:
+              path: inference_field.inference.chunks
+              inner_hits:
+                _source: false
+                fields: [inference_field.inference.chunks.text]
+              query:
+                knn:
+                  field: inference_field.inference.chunks.embeddings
+                  query_vector_builder:
+                    text_embedding:
+                      model_id: dense-inference-id
+                      model_text: "you know, for testing"
+
+  - match: { hits.total.value: 2 }
+  - match: { hits.total.relation: eq }
+  - match: { hits.hits.0.inner_hits.inference_field\.inference\.chunks.hits.total.value: 2 }
+  - exists: hits.hits.0.inner_hits.inference_field\.inference\.chunks.hits.hits.0.fields.inference_field\.inference\.chunks.0.text
+  - exists: hits.hits.0.inner_hits.inference_field\.inference\.chunks.hits.hits.1.fields.inference_field\.inference\.chunks.0.text
+
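+# Update semantics: existing embeddings are carried over unchanged when the semantic_text
+# inputs are untouched, and recalculated whenever those fields change (doc updates and bulk updates).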
+---
+"Updating non semantic_text fields does not recalculate embeddings":
+  - do:
+      index:
+        index: test-sparse-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+          another_inference_field: "another inference test"
+          non_inference_field: "non inference test"
+
+  - do:
+      get:
+        index: test-sparse-index
+        id: doc_1
+
+  - set: { _source.inference_field.inference.chunks.0.embeddings: inference_field_embedding }
+  - set: { _source.another_inference_field.inference.chunks.0.embeddings: another_inference_field_embedding }
+
+  - do:
+      update:
+        index: test-sparse-index
+        id: doc_1
+        body:
+          doc:
+            non_inference_field: "another non inference test"
+
+  - do:
+      get:
+        index: test-sparse-index
+        id: doc_1
+
+  - match: { _source.inference_field.text: "inference test" }
+  - match: { _source.inference_field.inference.chunks.0.text: "inference test" }
+  - match: { _source.inference_field.inference.chunks.0.embeddings: $inference_field_embedding }
+  - match: { _source.another_inference_field.text: "another inference test" }
+  - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" }
+  - match: { _source.another_inference_field.inference.chunks.0.embeddings: $another_inference_field_embedding }
+  - match: { _source.non_inference_field: "another non inference test" }
+
+---
+"Updating semantic_text fields recalculates embeddings":
+  - do:
+      index:
+        index: test-sparse-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+          another_inference_field: "another inference test"
+          non_inference_field: "non inference test"
+
+  - do:
+      get:
+        index: test-sparse-index
+        id: doc_1
+
+  - match: { _source.inference_field.text: "inference test" }
+  - match: { _source.inference_field.inference.chunks.0.text: "inference test" }
+  - match: { _source.another_inference_field.text: "another inference test" }
+  - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" }
+  - match: { _source.non_inference_field: "non inference test" }
+
+  - do:
+      bulk:
+        index: test-sparse-index
+        body:
+          - '{"update": {"_id": "doc_1"}}'
+          - '{"doc":{"inference_field": "I am a test", "another_inference_field": "I am a teapot"}}'
+
+  - do:
+      get:
+        index: test-sparse-index
+        id: doc_1
+
+  - match: { _source.inference_field.text: "I am a test" }
+  - match: { _source.inference_field.inference.chunks.0.text: "I am a test" }
+  - match: { _source.another_inference_field.text: "I am a teapot" }
+  - match: { _source.another_inference_field.inference.chunks.0.text: "I am a teapot" }
+  - match: { _source.non_inference_field: "non inference test" }
+
+  - do:
+      update:
+        index: test-sparse-index
+        id: doc_1
+        body:
+          doc:
+            inference_field: "updated inference test"
+            another_inference_field: "another updated inference test"
+
+  - do:
+      get:
+        index: test-sparse-index
+        id: doc_1
+
+  - match: { _source.inference_field.text: "updated inference test" }
+  - match: { _source.inference_field.inference.chunks.0.text: "updated inference test" }
+  - match: { _source.another_inference_field.text: "another updated inference test" }
+  - match: { _source.another_inference_field.inference.chunks.0.text: "another updated inference test" }
+  - match: { _source.non_inference_field: "non inference test" }
+
+  - do:
+      bulk:
+        index: test-sparse-index
+        body:
+          - '{"update": {"_id": "doc_1"}}'
+          - '{"doc":{"inference_field": "bulk inference test", "another_inference_field": "bulk updated inference test"}}'
+
+  - do:
+      get:
+        index: test-sparse-index
+        id: doc_1
+
+  - match: { _source.inference_field.text: "bulk inference test" }
+  - match: { _source.inference_field.inference.chunks.0.text: "bulk inference test" }
+  - match: { _source.another_inference_field.text: "bulk updated inference test" }
+  - match: { _source.another_inference_field.inference.chunks.0.text: "bulk updated inference test" }
+  - match: { _source.non_inference_field: "non inference test" }
+
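+# Reindex copies the { text, inference } structure verbatim: the embeddings captured below
+# must arrive in the destination index unchanged.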
+---
+"Reindex works for semantic_text fields":
+  - do:
+      index:
+        index: test-sparse-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+          another_inference_field: "another inference test"
+          non_inference_field: "non inference test"
+
+  - do:
+      get:
+        index: test-sparse-index
+        id: doc_1
+
+  - set: { _source.inference_field.inference.chunks.0.embeddings: inference_field_embedding }
+  - set: { _source.another_inference_field.inference.chunks.0.embeddings: another_inference_field_embedding }
+
+  - do:
+      indices.refresh: { }
+
+  - do:
+      indices.create:
+        index: destination-index
+        body:
+          mappings:
+            properties:
+              inference_field:
+                type: semantic_text
+                inference_id: sparse-inference-id
+              another_inference_field:
+                type: semantic_text
+                inference_id: sparse-inference-id
+              non_inference_field:
+                type: text
+
+  - do:
+      reindex:
+        wait_for_completion: true
+        body:
+          source:
+            index: test-sparse-index
+          dest:
+            index: destination-index
+  - do:
+      get:
+        index: destination-index
+        id: doc_1
+
+  - match: { _source.inference_field.text: "inference test" }
+  - match: { _source.inference_field.inference.chunks.0.text: "inference test" }
+  - match: { _source.inference_field.inference.chunks.0.embeddings: $inference_field_embedding }
+  - match: { _source.another_inference_field.text: "another inference test" }
+  - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" }
+  - match: { _source.another_inference_field.inference.chunks.0.embeddings: $another_inference_field_embedding }
+  - match: { _source.non_inference_field: "non inference test" }
+
+---
+"Fails for non-existent inference":
+  - do:
+      indices.create:
+        index: incorrect-test-sparse-index
+        body:
+          mappings:
+            properties:
+              inference_field:
+                type: semantic_text
+                inference_id: non-existing-inference-id
+              non_inference_field:
+                type: text
+
+  - do:
+      catch: missing
+      index:
+        index: incorrect-test-sparse-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+          non_inference_field: "non inference test"
+
+  - match: { error.reason: "Inference id [non-existing-inference-id] not found for field [inference_field]" }
+
+  # Succeeds when semantic_text field is not used
+  - do:
+      index:
+        index: incorrect-test-sparse-index
+        id: doc_1
+        body:
+          non_inference_field: "non inference test"
+
+---
+"Updates with script are not allowed":
+  - do:
+      bulk:
+        index: test-sparse-index
+        body:
+          - '{"index": {"_id": "doc_1"}}'
+          - '{"doc":{"inference_field": "I am a test", "another_inference_field": "I am a teapot"}}'
+
+  - do:
+      bulk:
+        index: test-sparse-index
+        body:
+          - '{"update": {"_id": "doc_1"}}'
+          - '{"script": "ctx._source.new_field = \"hello\"", "scripted_upsert": true}'
+
+  - match: { errors: true }
+  - match: { items.0.update.status: 400 }
+  - match: { items.0.update.error.reason: "Cannot apply update with a script on indices that contain [semantic_text] field(s)" }
+
+---
+"semantic_text copy_to calculate inference for source fields":
+  - do:
+      indices.create:
+        index: test-copy-to-index
+        body:
+          mappings:
+            properties:
+              inference_field:
+                type: semantic_text
+                inference_id: dense-inference-id
+              source_field:
+                type: text
+                copy_to: inference_field
+              another_source_field:
+                type: text
+                copy_to: inference_field
+
+  - do:
+      index:
+        index: test-copy-to-index
+        id: doc_1
+        body:
+          source_field: "copy_to inference test"
+          inference_field: "inference test"
+          another_source_field: "another copy_to inference test"
+
+  - do:
+      get:
+        index: test-copy-to-index
+        id: doc_1
+
+  - match: { _source.inference_field.text: "inference test" }
+  - length: { _source.inference_field.inference.chunks: 3 }
+  - match: { _source.inference_field.inference.chunks.0.text: "another copy_to inference test" }
+  - exists: _source.inference_field.inference.chunks.0.embeddings
+  - match: { _source.inference_field.inference.chunks.1.text: "inference test" }
+  - exists: _source.inference_field.inference.chunks.1.embeddings
+  - match: { _source.inference_field.inference.chunks.2.text: "copy_to inference test" }
+  - exists: _source.inference_field.inference.chunks.2.embeddings
+
+
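+# A partial update cannot recompute inference when a copy_to source field is absent from the
+# request, so bulk updates must provide values for every field that feeds a semantic_text field.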
+---
+"semantic_text copy_to needs values for every source field for updates":
+  - do:
+      indices.create:
+        index: test-copy-to-index
+        body:
+          mappings:
+            properties:
+              inference_field:
+                type: semantic_text
+                inference_id: dense-inference-id
+              source_field:
+                type: text
+                copy_to: inference_field
+              another_source_field:
+                type: text
+                copy_to: inference_field
+
+  # Not every source field needed on creation
+  - do:
+      index:
+        index: test-copy-to-index
+        id: doc_1
+        body:
+          source_field: "a single source field provided"
+          inference_field: "inference test"
+
+  # Every source field needed on bulk updates
+  - do:
+      bulk:
+        body:
+          - '{"update": {"_index": "test-copy-to-index", "_id": "doc_1"}}'
+          - '{"doc": {"source_field": "a single source field is kept as provided via bulk", "inference_field": "updated inference test" }}'
+
+  - match: { items.0.update.status: 400 }
+  - match: { items.0.update.error.reason: "Field [another_source_field] must be specified on an update request to calculate inference for field [inference_field]" }
diff --git a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java
index 4d90d2a186858..d8c9dc2efd927 100644
--- a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java
+++ b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java
@@ -59,10 +59,10 @@ public void testIngestWithMultipleModelTypes() throws IOException {
         assertThat(simulatedDocs, hasSize(2));
         assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(0)));
         var sparseEmbedding = (Map) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(0));
-        assertEquals(Double.valueOf(1.0), sparseEmbedding.get("1"));
+        assertEquals(Double.valueOf(2.0), sparseEmbedding.get("feature_1"));
 
         assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(1)));
         sparseEmbedding = (Map) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(1));
-        assertEquals(Double.valueOf(1.0), sparseEmbedding.get("1"));
+        assertEquals(Double.valueOf(2.0), sparseEmbedding.get("feature_1"));
 }
 {