From 7aaa3b6d458796a288ba05ff2ee6a28fc48f2bc3 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 7 Mar 2024 14:29:33 +0000 Subject: [PATCH 01/26] Revert "Extract interface from ModelRegistry so it can be used from server (#105012)" This reverts commit f4d3ab9df2a89879ac56262e618c0058fcb70d0e. --- .../action/bulk/BulkOperation.java | 114 +-- .../BulkShardRequestInferenceProvider.java | 319 --------- .../action/bulk/TransportBulkAction.java | 38 +- .../bulk/TransportSimulateBulkAction.java | 4 +- .../inference/InferenceServiceRegistry.java | 62 +- .../InferenceServiceRegistryImpl.java | 64 -- .../inference/ModelRegistry.java | 99 --- .../elasticsearch/node/NodeConstruction.java | 15 - .../plugins/InferenceRegistryPlugin.java | 22 - .../action/bulk/BulkOperationTests.java | 657 ------------------ ...ActionIndicesThatCannotBeCreatedTests.java | 8 +- .../bulk/TransportBulkActionIngestTests.java | 8 +- .../action/bulk/TransportBulkActionTests.java | 4 +- .../bulk/TransportBulkActionTookTests.java | 16 +- ...gistryImplIT.java => ModelRegistryIT.java} | 52 +- .../xpack/inference/InferencePlugin.java | 37 +- .../TransportDeleteInferenceModelAction.java | 2 +- .../TransportGetInferenceModelAction.java | 2 +- .../action/TransportInferenceAction.java | 2 +- .../TransportPutInferenceModelAction.java | 2 +- ...emanticTextInferenceResultFieldMapper.java | 16 +- .../mapper}/SemanticTextModelSettings.java | 13 +- ...elRegistryImpl.java => ModelRegistry.java} | 82 ++- ...icTextInferenceResultFieldMapperTests.java | 6 +- ...ImplTests.java => ModelRegistryTests.java} | 34 +- 25 files changed, 212 insertions(+), 1466 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java delete mode 100644 server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java delete mode 100644 server/src/main/java/org/elasticsearch/inference/ModelRegistry.java delete mode 100644 server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java delete mode 100644 server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java rename x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/{ModelRegistryImplIT.java => ModelRegistryIT.java} (86%) rename {server/src/main/java/org/elasticsearch/inference => x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper}/SemanticTextModelSettings.java (89%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/{ModelRegistryImpl.java => ModelRegistry.java} (86%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/{ModelRegistryImplTests.java => ModelRegistryTests.java} (92%) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 2b84ec8746cd2..1d95f430d5c7e 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; @@ -36,8 +35,6 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.shard.ShardId; import 
org.elasticsearch.indices.IndexClosedException; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; @@ -47,7 +44,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.function.BiConsumer; import java.util.function.LongSupplier; import static org.elasticsearch.cluster.metadata.IndexNameExpressionResolver.EXCLUDED_DATA_STREAMS_KEY; @@ -73,8 +69,6 @@ final class BulkOperation extends ActionRunnable { private final LongSupplier relativeTimeProvider; private IndexNameExpressionResolver indexNameExpressionResolver; private NodeClient client; - private final InferenceServiceRegistry inferenceServiceRegistry; - private final ModelRegistry modelRegistry; BulkOperation( Task task, @@ -88,8 +82,6 @@ final class BulkOperation extends ActionRunnable { IndexNameExpressionResolver indexNameExpressionResolver, LongSupplier relativeTimeProvider, long startTimeNanos, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry, ActionListener listener ) { super(listener); @@ -105,8 +97,6 @@ final class BulkOperation extends ActionRunnable { this.relativeTimeProvider = relativeTimeProvider; this.indexNameExpressionResolver = indexNameExpressionResolver; this.client = client; - this.inferenceServiceRegistry = inferenceServiceRegistry; - this.modelRegistry = modelRegistry; this.observer = new ClusterStateObserver(clusterService, bulkRequest.timeout(), logger, threadPool.getThreadContext()); } @@ -199,30 +189,7 @@ private void executeBulkRequestsByShard(Map> requ return; } - BulkShardRequestInferenceProvider.getInstance( - inferenceServiceRegistry, - modelRegistry, - clusterState, - requestsByShard.keySet(), - new ActionListener() { - @Override - public void onResponse(BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider) { - processRequestsByShards(requestsByShard, clusterState, bulkShardRequestInferenceProvider); - } - - @Override - public void onFailure(Exception e) { - throw new ElasticsearchException("Error loading inference models", e); - } - } - ); - } - - void processRequestsByShards( - Map> requestsByShard, - ClusterState clusterState, - BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider - ) { + String nodeId = clusterService.localNode().getId(); Runnable onBulkItemsComplete = () -> { listener.onResponse( new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) @@ -230,68 +197,29 @@ void processRequestsByShards( // Allow memory for bulk shard request items to be reclaimed before all items have been completed bulkRequest = null; }; + try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) { for (Map.Entry> entry : requestsByShard.entrySet()) { final ShardId shardId = entry.getKey(); final List requests = entry.getValue(); - BulkShardRequest bulkShardRequest = createBulkShardRequest(clusterState, shardId, requests); - - Releasable ref = bulkItemRequestCompleteRefCount.acquire(); - final BiConsumer bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed(itemReq, e); - bulkShardRequestInferenceProvider.processBulkShardRequest(bulkShardRequest, new ActionListener<>() { - @Override - public void onResponse(BulkShardRequest inferenceBulkShardRequest) { - executeBulkShardRequest( - 
inferenceBulkShardRequest, - ActionListener.releaseAfter(ActionListener.noop(), ref), - bulkItemFailedListener - ); - } - @Override - public void onFailure(Exception e) { - throw new ElasticsearchException("Error performing inference", e); - } - }, bulkItemFailedListener); + BulkShardRequest bulkShardRequest = new BulkShardRequest( + shardId, + bulkRequest.getRefreshPolicy(), + requests.toArray(new BulkItemRequest[0]) + ); + bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); + bulkShardRequest.timeout(bulkRequest.timeout()); + bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); + if (task != null) { + bulkShardRequest.setParentTask(nodeId, task.getId()); + } + executeBulkShardRequest(bulkShardRequest, bulkItemRequestCompleteRefCount.acquire()); } } } - private BulkShardRequest createBulkShardRequest(ClusterState clusterState, ShardId shardId, List requests) { - BulkShardRequest bulkShardRequest = new BulkShardRequest( - shardId, - bulkRequest.getRefreshPolicy(), - requests.toArray(new BulkItemRequest[0]) - ); - bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); - bulkShardRequest.timeout(bulkRequest.timeout()); - bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); - if (task != null) { - bulkShardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); - } - return bulkShardRequest; - } - - // When an item fails, store the failure in the responses array - private void markBulkItemRequestFailed(BulkItemRequest itemRequest, Exception e) { - final String indexName = itemRequest.index(); - - DocWriteRequest docWriteRequest = itemRequest.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); - responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure)); - } - - private void executeBulkShardRequest( - BulkShardRequest bulkShardRequest, - ActionListener listener, - BiConsumer bulkItemErrorListener - ) { - if (bulkShardRequest.items().length == 0) { - // No requests to execute due to previous errors, terminate early - listener.onResponse(bulkShardRequest); - return; - } - + private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { @Override public void onResponse(BulkShardResponse bulkShardResponse) { @@ -302,17 +230,19 @@ public void onResponse(BulkShardResponse bulkShardResponse) { } responses.set(bulkItemResponse.getItemId(), bulkItemResponse); } - listener.onResponse(bulkShardRequest); + releaseOnFinish.close(); } @Override public void onFailure(Exception e) { // create failures for all relevant requests - BulkItemRequest[] items = bulkShardRequest.items(); - for (BulkItemRequest item : items) { - bulkItemErrorListener.accept(item, e); + for (BulkItemRequest request : bulkShardRequest.items()) { + final String indexName = request.index(); + DocWriteRequest docWriteRequest = request.request(); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); + responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); } - listener.onFailure(e); + releaseOnFinish.close(); } }); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java 
deleted file mode 100644 index 4b7a67e9ca0e3..0000000000000 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ /dev/null @@ -1,319 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action.bulk; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.support.RefCountingRunnable; -import org.elasticsearch.action.update.UpdateRequest; -import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.common.TriConsumer; -import org.elasticsearch.core.Releasable; -import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelRegistry; -import org.elasticsearch.inference.SemanticTextModelSettings; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.BiConsumer; -import java.util.stream.Collectors; - -/** - * Performs inference on a {@link BulkShardRequest}, updating the source of each document with the inference results. 
- */ -public class BulkShardRequestInferenceProvider { - - // Root field name for storing inference results - public static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference"; - - // Contains the original text for the field - - public static final String INFERENCE_RESULTS = "inference_results"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - private final ClusterState clusterState; - private final Map inferenceProvidersMap; - - private record InferenceProvider(Model model, InferenceService service) { - private InferenceProvider { - Objects.requireNonNull(model); - Objects.requireNonNull(service); - } - } - - BulkShardRequestInferenceProvider(ClusterState clusterState, Map inferenceProvidersMap) { - this.clusterState = clusterState; - this.inferenceProvidersMap = inferenceProvidersMap; - } - - public static void getInstance( - InferenceServiceRegistry inferenceServiceRegistry, - ModelRegistry modelRegistry, - ClusterState clusterState, - Set shardIds, - ActionListener listener - ) { - Set inferenceIds = new HashSet<>(); - shardIds.stream().map(ShardId::getIndex).collect(Collectors.toSet()).stream().forEach(index -> { - var fieldsForModels = clusterState.metadata().index(index).getFieldsForModels(); - inferenceIds.addAll(fieldsForModels.keySet()); - }); - final Map inferenceProviderMap = new ConcurrentHashMap<>(); - Runnable onModelLoadingComplete = () -> listener.onResponse( - new BulkShardRequestInferenceProvider(clusterState, inferenceProviderMap) - ); - try (var refs = new RefCountingRunnable(onModelLoadingComplete)) { - for (var inferenceId : inferenceIds) { - ActionListener modelLoadingListener = new ActionListener<>() { - @Override - public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { - var service = inferenceServiceRegistry.getService(unparsedModel.service()); - if (service.isEmpty() == false) { - InferenceProvider inferenceProvider = new InferenceProvider( - service.get() - .parsePersistedConfigWithSecrets( - inferenceId, - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ), - service.get() - ); - inferenceProviderMap.put(inferenceId, inferenceProvider); - } - } - - @Override - public void onFailure(Exception e) { - // Failure on loading a model should not prevent the rest from being loaded and used. - // When the model is actually retrieved via the inference ID in the inference process, it will fail - // and the user will get the details on the inference failure. - } - }; - - modelRegistry.getModelWithSecrets(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire())); - } - } - } - - /** - * Performs inference on the fields that have inference models for a bulk shard request. Bulk items from - * the original request will be modified with the inference results, to avoid copying the entire requests from - * the original bulk request. - * - * @param bulkShardRequest original BulkShardRequest that will be modified with inference results. 
- * @param listener listener to be called when the inference process is finished with the new BulkShardRequest, - * which may have fewer items than the original because of inference failures - * @param onBulkItemFailure invoked when a bulk item fails inference - */ - public void processBulkShardRequest( - BulkShardRequest bulkShardRequest, - ActionListener listener, - BiConsumer onBulkItemFailure - ) { - - Map> fieldsForModels = clusterState.metadata() - .index(bulkShardRequest.shardId().getIndex()) - .getFieldsForModels(); - // No inference fields? Terminate early - if (fieldsForModels.isEmpty()) { - listener.onResponse(bulkShardRequest); - return; - } - - Set failedItems = Collections.synchronizedSet(new HashSet<>()); - Runnable onInferenceComplete = () -> { - if (failedItems.isEmpty()) { - listener.onResponse(bulkShardRequest); - return; - } - // Remove failed items from the original bulk shard request - BulkItemRequest[] originalItems = bulkShardRequest.items(); - BulkItemRequest[] newItems = new BulkItemRequest[originalItems.length - failedItems.size()]; - for (int i = 0, j = 0; i < originalItems.length; i++) { - if (failedItems.contains(i) == false) { - newItems[j++] = originalItems[i]; - } - } - BulkShardRequest newBulkShardRequest = new BulkShardRequest( - bulkShardRequest.shardId(), - bulkShardRequest.getRefreshPolicy(), - newItems - ); - listener.onResponse(newBulkShardRequest); - }; - TriConsumer onBulkItemFailureWithIndex = (bulkItemRequest, i, e) -> { - failedItems.add(i); - onBulkItemFailure.accept(bulkItemRequest, e); - }; - try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { - BulkItemRequest[] items = bulkShardRequest.items(); - for (int i = 0; i < items.length; i++) { - BulkItemRequest bulkItemRequest = items[i]; - // Bulk item might be null because of previous errors, skip in that case - if (bulkItemRequest != null) { - performInferenceOnBulkItemRequest( - bulkItemRequest, - fieldsForModels, - i, - onBulkItemFailureWithIndex, - bulkItemReqRef.acquire() - ); - } - } - } - } - - @SuppressWarnings("unchecked") - private void performInferenceOnBulkItemRequest( - BulkItemRequest bulkItemRequest, - Map> fieldsForModels, - Integer itemIndex, - TriConsumer onBulkItemFailure, - Releasable releaseOnFinish - ) { - - DocWriteRequest docWriteRequest = bulkItemRequest.request(); - Map sourceMap = null; - if (docWriteRequest instanceof IndexRequest indexRequest) { - sourceMap = indexRequest.sourceAsMap(); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - sourceMap = updateRequest.docAsUpsert() ? 
updateRequest.upsertRequest().sourceAsMap() : updateRequest.doc().sourceAsMap(); - } - if (sourceMap == null || sourceMap.isEmpty()) { - releaseOnFinish.close(); - return; - } - final Map docMap = new ConcurrentHashMap<>(sourceMap); - - // When a document completes processing, update the source with the inference - try (var docRef = new RefCountingRunnable(() -> { - if (docWriteRequest instanceof IndexRequest indexRequest) { - indexRequest.source(docMap); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - if (updateRequest.docAsUpsert()) { - updateRequest.upsertRequest().source(docMap); - } else { - updateRequest.doc().source(docMap); - } - } - releaseOnFinish.close(); - })) { - - Map rootInferenceFieldMap; - try { - rootInferenceFieldMap = (Map) docMap.computeIfAbsent( - ROOT_INFERENCE_FIELD, - k -> new HashMap() - ); - } catch (ClassCastException e) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException("Inference result field [" + ROOT_INFERENCE_FIELD + "] is not an object") - ); - return; - } - - for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) { - String modelId = fieldModelsEntrySet.getKey(); - List inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap); - if (inferenceFieldNames.isEmpty()) { - continue; - } - - InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId); - if (inferenceProvider == null) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException("No inference provider found for model ID " + modelId) - ); - return; - } - ActionListener inferenceResultsListener = new ActionListener<>() { - @Override - public void onResponse(InferenceServiceResults results) { - if (results == null) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException( - "No inference results retrieved for model ID " + modelId + " in document " + docWriteRequest.id() - ) - ); - } - - int i = 0; - for (InferenceResults inferenceResults : results.transformToCoordinationFormat()) { - String inferenceFieldName = inferenceFieldNames.get(i++); - Map inferenceFieldResult = new LinkedHashMap<>(); - inferenceFieldResult.putAll(new SemanticTextModelSettings(inferenceProvider.model).asMap()); - inferenceFieldResult.put( - INFERENCE_RESULTS, - List.of( - Map.of( - INFERENCE_CHUNKS_RESULTS, - inferenceResults.asMap("output").get("output"), - INFERENCE_CHUNKS_TEXT, - docMap.get(inferenceFieldName) - ) - ) - ); - rootInferenceFieldMap.put(inferenceFieldName, inferenceFieldResult); - } - } - - @Override - public void onFailure(Exception e) { - onBulkItemFailure.apply(bulkItemRequest, itemIndex, e); - } - }; - inferenceProvider.service() - .infer( - inferenceProvider.model, - inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), - // TODO check for additional settings needed - Map.of(), - InputType.INGEST, - ActionListener.releaseAfter(inferenceResultsListener, docRef.acquire()) - ); - } - } - } - - private static List getFieldNamesForInference(Map.Entry> fieldModelsEntrySet, Map docMap) { - List inferenceFieldNames = new ArrayList<>(); - for (String inferenceField : fieldModelsEntrySet.getValue()) { - Object fieldValue = docMap.get(inferenceField); - - // Perform inference on string, non-null values - if (fieldValue instanceof String) { - inferenceFieldNames.add(inferenceField); - } - } - return inferenceFieldNames; - } -}
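(Context for the file removed above: BulkShardRequestInferenceProvider produced no separate response object; it rewrote each bulk document's source in place, which is why its javadoc stresses that items of the original request are modified rather than copied. Each field with an inference model gained an entry under the "_semantic_text_inference" root holding the model settings plus an "inference_results" list of chunks. The sketch below is not part of the patch; it reconstructs that source layout from the constants above, and the model-settings keys "task_type" and "inference_id" as well as all literal values are illustrative assumptions, not code from the repository.)

    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;

    // Hypothetical shape of one document's source after the removed provider ran.
    class SemanticTextSourceSketch {
        static Map<String, Object> exampleSource() {
            // One chunk per inference result: the raw embedding plus the text it was computed from.
            Map<String, Object> chunk = Map.of(
                "inference", Map.of("hello", 0.53f), // INFERENCE_CHUNKS_RESULTS
                "text", "hello world"                // INFERENCE_CHUNKS_TEXT
            );
            Map<String, Object> fieldResult = new LinkedHashMap<>();
            fieldResult.put("task_type", "sparse_embedding");     // assumed SemanticTextModelSettings key
            fieldResult.put("inference_id", "my-elser");          // assumed SemanticTextModelSettings key
            fieldResult.put("inference_results", List.of(chunk)); // INFERENCE_RESULTS
            Map<String, Object> source = new LinkedHashMap<>();
            source.put("my_semantic_field", "hello world");       // original field value is preserved
            source.put("_semantic_text_inference",                // ROOT_INFERENCE_FIELD
                Map.of("my_semantic_field", fieldResult));
            return source;
        }
    }

diff --git 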
a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 43076daf5bffe..a2445e95a572f 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -57,8 +57,6 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.VersionType; import org.elasticsearch.indices.SystemIndices; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -100,8 +98,6 @@ public class TransportBulkAction extends HandledTransportAction releasingListener) { - threadPool.executor(Names.WRITE).execute(new ActionRunnable<>(releasingListener) { + threadPool.executor(executorName).execute(new ActionRunnable<>(releasingListener) { @Override protected void doRun() { doInternalExecute(task, bulkRequest, executorName, releasingListener); @@ -425,13 +409,13 @@ protected void createMissingIndicesAndIndexData( final AtomicArray responses = new AtomicArray<>(bulkRequest.requests.size()); // Optimizing when there are no prerequisite actions if (indicesToAutoCreate.isEmpty() && dataStreamsToBeRolledOver.isEmpty()) { - executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); + executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); return; } Runnable executeBulkRunnable = () -> threadPool.executor(executorName).execute(new ActionRunnable<>(listener) { @Override protected void doRun() { - executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); + executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); } }); try (RefCountingRunnable refs = new RefCountingRunnable(executeBulkRunnable)) { @@ -649,10 +633,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { new BulkOperation( task, @@ -666,8 +650,6 @@ void executeBulk( indexNameExpressionResolver, relativeTimeProvider, startTimeNanos, - modelRegistry, - inferenceServiceRegistry, listener ).run(); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java index c8dc3e7b7ffd5..f65d0f462fde6 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java @@ -58,9 +58,7 @@ public TransportSimulateBulkAction( indexNameExpressionResolver, indexingPressure, systemIndices, - System::nanoTime, - null, - null + System::nanoTime ); } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index ce6f1b21b734c..d5973807d9d78 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -13,41 +13,49 @@ import 
java.io.Closeable; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class InferenceServiceRegistry implements Closeable { + + private final Map services; + private final List namedWriteables = new ArrayList<>(); + + public InferenceServiceRegistry( + List inferenceServicePlugins, + InferenceServiceExtension.InferenceServiceFactoryContext factoryContext + ) { + // TODO check names are unique + services = inferenceServicePlugins.stream() + .flatMap(r -> r.getInferenceServiceFactories().stream()) + .map(factory -> factory.create(factoryContext)) + .collect(Collectors.toMap(InferenceService::name, Function.identity())); + } -public interface InferenceServiceRegistry extends Closeable { - void init(Client client); - - Map getServices(); - - Optional getService(String serviceName); - - List getNamedWriteables(); - - class NoopInferenceServiceRegistry implements InferenceServiceRegistry { - public NoopInferenceServiceRegistry() {} + public void init(Client client) { + services.values().forEach(s -> s.init(client)); + } - @Override - public void init(Client client) {} + public Map getServices() { + return services; + } - @Override - public Map getServices() { - return Map.of(); - } + public Optional getService(String serviceName) { + return Optional.ofNullable(services.get(serviceName)); + } - @Override - public Optional getService(String serviceName) { - return Optional.empty(); - } + public List getNamedWriteables() { + return namedWriteables; + } - @Override - public List getNamedWriteables() { - return List.of(); + @Override + public void close() throws IOException { + for (var service : services.values()) { + service.close(); } - - @Override - public void close() throws IOException {} } } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java deleted file mode 100644 index f0a990ded98ce..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. 
- */ - -package org.elasticsearch.inference; - -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Function; -import java.util.stream.Collectors; - -public class InferenceServiceRegistryImpl implements InferenceServiceRegistry { - - private final Map services; - private final List namedWriteables = new ArrayList<>(); - - public InferenceServiceRegistryImpl( - List inferenceServicePlugins, - InferenceServiceExtension.InferenceServiceFactoryContext factoryContext - ) { - // TODO check names are unique - services = inferenceServicePlugins.stream() - .flatMap(r -> r.getInferenceServiceFactories().stream()) - .map(factory -> factory.create(factoryContext)) - .collect(Collectors.toMap(InferenceService::name, Function.identity())); - } - - @Override - public void init(Client client) { - services.values().forEach(s -> s.init(client)); - } - - @Override - public Map getServices() { - return services; - } - - @Override - public Optional getService(String serviceName) { - return Optional.ofNullable(services.get(serviceName)); - } - - @Override - public List getNamedWriteables() { - return namedWriteables; - } - - @Override - public void close() throws IOException { - for (var service : services.values()) { - service.close(); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java b/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java deleted file mode 100644 index fa90d5ba6f756..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.inference; - -import org.elasticsearch.action.ActionListener; - -import java.util.List; -import java.util.Map; - -public interface ModelRegistry { - - /** - * Get a model. - * Secret settings are not included - * @param inferenceEntityId Model to get - * @param listener Model listener - */ - void getModel(String inferenceEntityId, ActionListener listener); - - /** - * Get a model with its secret settings - * @param inferenceEntityId Model to get - * @param listener Model listener - */ - void getModelWithSecrets(String inferenceEntityId, ActionListener listener); - - /** - * Get all models of a particular task type. - * Secret settings are not included - * @param taskType The task type - * @param listener Models listener - */ - void getModelsByTaskType(TaskType taskType, ActionListener> listener); - - /** - * Get all models. - * Secret settings are not included - * @param listener Models listener - */ - void getAllModels(ActionListener> listener); - - void storeModel(Model model, ActionListener listener); - - void deleteModel(String modelId, ActionListener listener); - - /** - * Semi parsed model where inference entity id, task type and service - * are known but the settings are not parsed. 
- */ - record UnparsedModel( - String inferenceEntityId, - TaskType taskType, - String service, - Map settings, - Map secrets - ) {} - - class NoopModelRegistry implements ModelRegistry { - @Override - public void getModel(String modelId, ActionListener listener) { - fail(listener); - } - - @Override - public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { - listener.onResponse(List.of()); - } - - @Override - public void getAllModels(ActionListener> listener) { - listener.onResponse(List.of()); - } - - @Override - public void storeModel(Model model, ActionListener listener) { - fail(listener); - } - - @Override - public void deleteModel(String modelId, ActionListener listener) { - fail(listener); - } - - @Override - public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { - fail(listener); - } - - private static void fail(ActionListener listener) { - listener.onFailure(new IllegalArgumentException("No model registry configured")); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 19a6d200189f2..abc23b63aa1b6 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -126,8 +126,6 @@ import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.MonitorService; import org.elasticsearch.monitor.fs.FsHealthService; @@ -146,7 +144,6 @@ import org.elasticsearch.plugins.ClusterPlugin; import org.elasticsearch.plugins.DiscoveryPlugin; import org.elasticsearch.plugins.HealthPlugin; -import org.elasticsearch.plugins.InferenceRegistryPlugin; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.MetadataUpgrader; @@ -1104,18 +1101,6 @@ record PluginServiceInstances( ); } - // Register noop versions of inference services if Inference plugin is not available - Optional inferenceRegistryPlugin = getSinglePlugin(InferenceRegistryPlugin.class); - modules.bindToInstance( - InferenceServiceRegistry.class, - inferenceRegistryPlugin.map(InferenceRegistryPlugin::getInferenceServiceRegistry) - .orElse(new InferenceServiceRegistry.NoopInferenceServiceRegistry()) - ); - modules.bindToInstance( - ModelRegistry.class, - inferenceRegistryPlugin.map(InferenceRegistryPlugin::getModelRegistry).orElse(new ModelRegistry.NoopModelRegistry()) - ); - injector = modules.createInjector(); postInjection(clusterModule, actionModule, clusterService, transportService, featureService); diff --git a/server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java b/server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java deleted file mode 100644 index 696c3a067dad1..0000000000000 --- a/server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.plugins; - -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; - -/** - * Plugins that provide inference services should implement this interface. - * There should be a single one in the classpath, as we currently support a single instance for ModelRegistry / InferenceServiceRegistry. - */ -public interface InferenceRegistryPlugin { - InferenceServiceRegistry getInferenceServiceRegistry(); - - ModelRegistry getModelRegistry(); -} diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java deleted file mode 100644 index 2ce7b161d3dd1..0000000000000 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ /dev/null @@ -1,657 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action.bulk; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.index.IndexResponse; -import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.cluster.ClusterName; -import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.metadata.IndexAbstraction; -import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.metadata.Metadata; -import org.elasticsearch.cluster.node.DiscoveryNodeUtils; -import org.elasticsearch.cluster.service.ClusterApplierService; -import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.AtomicArray; -import org.elasticsearch.core.Tuple; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelRegistry; -import org.elasticsearch.inference.SemanticTextModelSettings; -import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.tasks.Task; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatcher; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map;
-import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.function.Function; -import java.util.stream.Collectors; - -import static java.util.Collections.emptyMap; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; -import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyMap; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - -public class BulkOperationTests extends ESTestCase { - - private static final String INDEX_NAME = "test-index"; - private static final String INFERENCE_SERVICE_1_ID = "inference_service_1_id"; - private static final String INFERENCE_SERVICE_2_ID = "inference_service_2_id"; - private static final String FIRST_INFERENCE_FIELD_SERVICE_1 = "first_inference_field_service_1"; - private static final String SECOND_INFERENCE_FIELD_SERVICE_1 = "second_inference_field_service_1"; - private static final String INFERENCE_FIELD_SERVICE_2 = "inference_field_service_2"; - private static final String SERVICE_1_ID = "elser_v2"; - private static final String SERVICE_2_ID = "e5"; - private static final String INFERENCE_FAILED_MSG = "Inference failed"; - private static TestThreadPool threadPool; - - public void testNoInference() { - - Map> fieldsForModels = Map.of(); - ModelRegistry modelRegistry = createModelRegistry( - Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) - ); - - Model model1 = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService1 = createInferenceService(model1); - Model model2 = mockModel(INFERENCE_SERVICE_2_ID); - InferenceService inferenceService2 = createInferenceService(model2); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( - Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) - ); - - Map originalSource = Map.of( - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - BulkShardRequest bulkShardRequest = runBulkOperation( - originalSource, - fieldsForModels, - modelRegistry, - inferenceServiceRegistry, - true, - bulkOperationListener - ); - verify(bulkOperationListener).onResponse(any()); - - BulkItemRequest[] items = bulkShardRequest.items(); - assertThat(items.length, equalTo(1)); - - Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); - // Original doc source is preserved - originalSource.forEach((key, value) -> assertThat(writtenDocSource.get(key), equalTo(value))); - - // Check inference not invoked - 
verifyNoMoreInteractions(modelRegistry); - verifyNoMoreInteractions(inferenceServiceRegistry); - } - - private static Model mockModel(String inferenceServiceId) { - Model model = mock(Model.class); - - when(model.getInferenceEntityId()).thenReturn(inferenceServiceId); - TaskType taskType = randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING; - when(model.getTaskType()).thenReturn(taskType); - - ServiceSettings serviceSettings = mock(ServiceSettings.class); - when(model.getServiceSettings()).thenReturn(serviceSettings); - SimilarityMeasure similarity = switch (randomInt(2)) { - case 0 -> SimilarityMeasure.COSINE; - case 1 -> SimilarityMeasure.DOT_PRODUCT; - default -> null; - }; - when(serviceSettings.similarity()).thenReturn(similarity); - when(serviceSettings.dimensions()).thenReturn(randomBoolean() ? null : randomIntBetween(1, 1000)); - - return model; - } - - public void testFailedBulkShardRequest() { - - Map> fieldsForModels = Map.of(); - ModelRegistry modelRegistry = createModelRegistry(Map.of()); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of()); - - Map originalSource = Map.of( - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - - runBulkOperation( - originalSource, - fieldsForModels, - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener, - true, - request -> new BulkShardResponse( - request.shardId(), - new BulkItemResponse[] { - BulkItemResponse.failure( - 0, - DocWriteRequest.OpType.INDEX, - new BulkItemResponse.Failure( - INDEX_NAME, - randomIdentifier(), - new IllegalArgumentException("Error on bulk shard request") - ) - ) } - ) - ); - verify(bulkOperationListener).onResponse(any()); - - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse[] items = bulkResponse.getItems(); - assertTrue(items[0].isFailed()); - } - - @SuppressWarnings("unchecked") - public void testInference() { - - Map> fieldsForModels = Map.of( - INFERENCE_SERVICE_1_ID, - Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), - INFERENCE_SERVICE_2_ID, - Set.of(INFERENCE_FIELD_SERVICE_2) - ); - - ModelRegistry modelRegistry = createModelRegistry( - Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) - ); - - Model model1 = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService1 = createInferenceService(model1); - Model model2 = mockModel(INFERENCE_SERVICE_2_ID); - InferenceService inferenceService2 = createInferenceService(model2); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( - Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) - ); - - String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); - String secondInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); - String inferenceTextService2 = randomAlphaOfLengthBetween(1, 100); - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - firstInferenceTextService1, - SECOND_INFERENCE_FIELD_SERVICE_1, - secondInferenceTextService1, - INFERENCE_FIELD_SERVICE_2, - 
inferenceTextService2, - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - ActionListener bulkOperationListener = mock(ActionListener.class); - BulkShardRequest bulkShardRequest = runBulkOperation( - originalSource, - fieldsForModels, - modelRegistry, - inferenceServiceRegistry, - true, - bulkOperationListener - ); - verify(bulkOperationListener).onResponse(any()); - - BulkItemRequest[] items = bulkShardRequest.items(); - assertThat(items.length, equalTo(1)); - - Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); - // Original doc source is preserved - originalSource.forEach((key, value) -> assertThat(writtenDocSource.get(key), equalTo(value))); - - // Check inference results - verifyInferenceServiceInvoked( - modelRegistry, - INFERENCE_SERVICE_1_ID, - inferenceService1, - model1, - List.of(firstInferenceTextService1, secondInferenceTextService1) - ); - verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_2_ID, inferenceService2, model2, List.of(inferenceTextService2)); - checkInferenceResults( - originalSource, - writtenDocSource, - FIRST_INFERENCE_FIELD_SERVICE_1, - SECOND_INFERENCE_FIELD_SERVICE_1, - INFERENCE_FIELD_SERVICE_2 - ); - } - - public void testFailedInference() { - - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService = createInferenceServiceThatFails(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - firstInferenceTextService1, - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat(item.getFailure().getCause().getMessage(), equalTo(INFERENCE_FAILED_MSG)); - - verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_1_ID, inferenceService, model, List.of(firstInferenceTextService1)); - - } - - public void testInferenceFailsForIncorrectRootObject() { - - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService = createInferenceServiceThatFails(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - randomAlphaOfLengthBetween(1, 100), - ROOT_INFERENCE_FIELD, - "incorrect_root_object" - ); - - ArgumentCaptor 
bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat(item.getFailure().getCause().getMessage(), containsString("[_semantic_text_inference] is not an object")); - } - - public void testInferenceIdNotFound() { - - Map> fieldsForModels = Map.of( - INFERENCE_SERVICE_1_ID, - Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), - INFERENCE_SERVICE_2_ID, - Set.of(INFERENCE_FIELD_SERVICE_2) - ); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService = createInferenceService(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - Map originalSource = Map.of( - INFERENCE_FIELD_SERVICE_2, - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat( - item.getFailure().getCause().getMessage(), - equalTo("No inference provider found for model ID " + INFERENCE_SERVICE_2_ID) - ); - } - - @SuppressWarnings("unchecked") - private static void checkInferenceResults( - Map docSource, - Map writtenDocSource, - String... 
inferenceFieldNames - ) { - - Map inferenceRootResultField = (Map) writtenDocSource.get( - BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD - ); - - for (String inferenceFieldName : inferenceFieldNames) { - Map inferenceService1FieldResults = (Map) inferenceRootResultField.get(inferenceFieldName); - assertNotNull(inferenceService1FieldResults); - assertThat(inferenceService1FieldResults.size(), equalTo(2)); - Map modelSettings = (Map) inferenceService1FieldResults.get(SemanticTextModelSettings.NAME); - assertNotNull(modelSettings); - assertNotNull(modelSettings.get(SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName())); - assertNotNull(modelSettings.get(SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName())); - - List> inferenceResultElement = (List>) inferenceService1FieldResults.get( - INFERENCE_RESULTS - ); - assertFalse(inferenceResultElement.isEmpty()); - assertNotNull(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_RESULTS)); - assertThat(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_TEXT), equalTo(docSource.get(inferenceFieldName))); - } - } - - private static void verifyInferenceServiceInvoked( - ModelRegistry modelRegistry, - String inferenceService1Id, - InferenceService inferenceService, - Model model, - Collection inferenceTexts - ) { - verify(modelRegistry).getModelWithSecrets(eq(inferenceService1Id), any()); - verify(inferenceService).parsePersistedConfigWithSecrets( - eq(inferenceService1Id), - eq(TaskType.SPARSE_EMBEDDING), - anyMap(), - anyMap() - ); - verify(inferenceService).infer(eq(model), argThat(containsInAnyOrder(inferenceTexts)), anyMap(), eq(InputType.INGEST), any()); - verifyNoMoreInteractions(inferenceService); - } - - private static ArgumentMatcher> containsInAnyOrder(Collection expected) { - return new ArgumentMatcher<>() { - @Override - public boolean matches(List argument) { - return argument.containsAll(expected) && argument.size() == expected.size(); - } - - @Override - public String toString() { - return "containsAll(" + expected.stream().collect(Collectors.joining(", ")) + ")"; - } - }; - } - - private static BulkShardRequest runBulkOperation( - Map docSource, - Map> fieldsForModels, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry, - boolean expectTransportShardBulkActionToExecute, - ActionListener bulkOperationListener - ) { - return runBulkOperation( - docSource, - fieldsForModels, - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener, - expectTransportShardBulkActionToExecute, - successfulBulkShardResponse - ); - } - - private static BulkShardRequest runBulkOperation( - Map docSource, - Map> fieldsForModels, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry, - ActionListener bulkOperationListener, - boolean expectTransportShardBulkActionToExecute, - Function bulkShardResponseSupplier - ) { - Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); - IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME) - .fieldsForModels(fieldsForModels) - .settings(settings) - .numberOfShards(1) - .numberOfReplicas(0) - .build(); - ClusterService clusterService = createClusterService(indexMetadata); - - IndexNameExpressionResolver indexResolver = mock(IndexNameExpressionResolver.class); - when(indexResolver.resolveWriteIndexAbstraction(any(), any())).thenReturn(new IndexAbstraction.ConcreteIndex(indexMetadata)); - - BulkRequest bulkRequest = new BulkRequest(); - bulkRequest.add(new 
IndexRequest(INDEX_NAME).source(docSource)); - - NodeClient client = mock(NodeClient.class); - - ArgumentCaptor bulkShardRequestCaptor = ArgumentCaptor.forClass(BulkShardRequest.class); - doAnswer(invocation -> { - BulkShardRequest request = invocation.getArgument(1); - ActionListener bulkShardResponseListener = invocation.getArgument(2); - bulkShardResponseListener.onResponse(bulkShardResponseSupplier.apply(request)); - return null; - }).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any()); - - Task task = new Task(randomLong(), "transport", "action", "", null, emptyMap()); - BulkOperation bulkOperation = new BulkOperation( - task, - threadPool, - ThreadPool.Names.WRITE, - clusterService, - bulkRequest, - client, - new AtomicArray<>(bulkRequest.requests.size()), - new HashMap<>(), - indexResolver, - () -> System.nanoTime(), - System.nanoTime(), - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener - ); - - bulkOperation.doRun(); - if (expectTransportShardBulkActionToExecute) { - verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any()); - return bulkShardRequestCaptor.getValue(); - } - - return null; - } - - private static final Function successfulBulkShardResponse = (request) -> { - return new BulkShardResponse( - request.shardId(), - Arrays.stream(request.items()) - .filter(Objects::nonNull) - .map( - item -> BulkItemResponse.success( - item.id(), - DocWriteRequest.OpType.INDEX, - new IndexResponse(request.shardId(), randomIdentifier(), randomLong(), randomLong(), randomLong(), randomBoolean()) - ) - ) - .toArray(BulkItemResponse[]::new) - ); - }; - - private static InferenceService createInferenceService(Model model) { - InferenceService inferenceService = mock(InferenceService.class); - when( - inferenceService.parsePersistedConfigWithSecrets( - eq(model.getInferenceEntityId()), - eq(TaskType.SPARSE_EMBEDDING), - anyMap(), - anyMap() - ) - ).thenReturn(model); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(4); - InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class); - List texts = invocation.getArgument(1); - List inferenceResults = new ArrayList<>(); - for (int i = 0; i < texts.size(); i++) { - inferenceResults.add(createInferenceResults()); - } - doReturn(inferenceResults).when(inferenceServiceResults).transformToCoordinationFormat(); - - listener.onResponse(inferenceServiceResults); - return null; - }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); - return inferenceService; - } - - private static InferenceService createInferenceServiceThatFails(Model model) { - InferenceService inferenceService = mock(InferenceService.class); - when( - inferenceService.parsePersistedConfigWithSecrets( - eq(model.getInferenceEntityId()), - eq(TaskType.SPARSE_EMBEDDING), - anyMap(), - anyMap() - ) - ).thenReturn(model); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(4); - listener.onFailure(new IllegalArgumentException(INFERENCE_FAILED_MSG)); - return null; - }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); - return inferenceService; - } - - private static InferenceResults createInferenceResults() { - InferenceResults inferenceResults = mock(InferenceResults.class); - when(inferenceResults.asMap(any())).then( - invocation -> Map.of( - (String) invocation.getArguments()[0], - Map.of("sparse_embedding", randomMap(1, 10, () -> new 
Tuple<>(randomAlphaOfLength(10), randomFloat()))) - ) - ); - return inferenceResults; - } - - private static InferenceServiceRegistry createInferenceServiceRegistry(Map inferenceServices) { - InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); - inferenceServices.forEach((id, service) -> when(inferenceServiceRegistry.getService(id)).thenReturn(Optional.of(service))); - return inferenceServiceRegistry; - } - - private static ModelRegistry createModelRegistry(Map inferenceIdsToServiceIds) { - ModelRegistry modelRegistry = mock(ModelRegistry.class); - // Fails for unknown inference ids - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IllegalArgumentException("Model not found")); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> { - ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel( - inferenceId, - TaskType.SPARSE_EMBEDDING, - serviceId, - emptyMap(), - emptyMap() - ); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(eq(inferenceId), any()); - }); - - return modelRegistry; - } - - private static ClusterService createClusterService(IndexMetadata indexMetadata) { - Metadata metadata = Metadata.builder().indices(Map.of(INDEX_NAME, indexMetadata)).build(); - - ClusterService clusterService = mock(ClusterService.class); - when(clusterService.localNode()).thenReturn(DiscoveryNodeUtils.create(randomIdentifier())); - - ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata).version(randomNonNegativeLong()).build(); - when(clusterService.state()).thenReturn(clusterState); - - ClusterApplierService clusterApplierService = mock(ClusterApplierService.class); - when(clusterApplierService.state()).thenReturn(clusterState); - when(clusterApplierService.threadPool()).thenReturn(threadPool); - when(clusterService.getClusterApplierService()).thenReturn(clusterApplierService); - return clusterService; - } - - @BeforeClass - public static void createThreadPool() { - threadPool = new TestThreadPool(getTestClass().getName()); - } - - @AfterClass - public static void stopThreadPool() { - if (threadPool != null) { - threadPool.shutdownNow(); - threadPool = null; - } - } - -} diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java index 988a92352649a..3057b00553a22 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java @@ -129,19 +129,17 @@ public boolean hasIndexAbstraction(String indexAbstraction, ClusterState state) mock(ActionFilters.class), indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ) { @Override void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { assertEquals(expected, indicesThatCannotBeCreated.keySet()); } diff 
--git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java index 2d6492e4e73a4..6815d634292a4 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java @@ -148,9 +148,7 @@ class TestTransportBulkAction extends TransportBulkAction { new ActionFilters(Collections.emptySet()), TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(SETTINGS), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ); } @@ -159,10 +157,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { assertTrue(indexCreated); isExecuted = true; diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java index ad522e36f9bd9..1a16d9083df55 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java @@ -98,9 +98,7 @@ class TestTransportBulkAction extends TransportBulkAction { new ActionFilters(Collections.emptySet()), new Resolver(), new IndexingPressure(Settings.EMPTY), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java index a2e54a1c7c3b8..cb9bdd1f3a827 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java @@ -139,13 +139,13 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { expected.set(1000000); - super.executeBulk(task, bulkRequest, startTimeNanos, executorName, responses, indicesThatCannotBeCreated, listener); + super.executeBulk(task, bulkRequest, startTimeNanos, listener, executorName, responses, indicesThatCannotBeCreated); } }; } else { @@ -164,14 +164,14 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { long elapsed = spinForAtLeastOneMillisecond(); expected.set(elapsed); - super.executeBulk(task, bulkRequest, startTimeNanos, executorName, responses, indicesThatCannotBeCreated, listener); + super.executeBulk(task, bulkRequest, startTimeNanos, listener, executorName, responses, indicesThatCannotBeCreated); } }; } @@ -253,9 +253,7 @@ static class TestTransportBulkAction extends TransportBulkAction { indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, - relativeTimeProvider, - null, - null + relativeTimeProvider ); } } diff --git 
a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java similarity index 86% rename from x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java rename to x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index ccda986a8d280..0f23e0b33d774 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -26,7 +26,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests; @@ -55,13 +55,13 @@ import static org.hamcrest.Matchers.nullValue; import static org.mockito.Mockito.mock; -public class ModelRegistryImplIT extends ESSingleNodeTestCase { +public class ModelRegistryIT extends ESSingleNodeTestCase { - private ModelRegistryImpl ModelRegistryImpl; + private ModelRegistry modelRegistry; @Before public void createComponents() { - ModelRegistryImpl = new ModelRegistryImpl(client()); + modelRegistry = new ModelRegistry(client()); } @Override @@ -75,7 +75,7 @@ public void testStoreModel() throws Exception { AtomicReference storeModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), storeModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder); assertThat(storeModelHolder.get(), is(true)); assertThat(exceptionHolder.get(), is(nullValue())); @@ -87,7 +87,7 @@ public void testStoreModelWithUnknownFields() throws Exception { AtomicReference storeModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), storeModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder); assertNull(storeModelHolder.get()); assertNotNull(exceptionHolder.get()); @@ -106,12 +106,12 @@ public void testGetModel() throws Exception { AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); // now get the model - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new 
AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(nullValue())); assertThat(modelHolder.get(), not(nullValue())); @@ -133,13 +133,13 @@ public void testStoreModelFailsWhenModelExists() throws Exception { AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertThat(exceptionHolder.get(), is(nullValue())); putModelHolder.set(false); // an model with the same id exists - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(false)); assertThat(exceptionHolder.get(), not(nullValue())); assertThat( @@ -154,20 +154,20 @@ public void testDeleteModel() throws Exception { Model model = buildElserModelConfig(id, TaskType.SPARSE_EMBEDDING); AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); } AtomicReference deleteResponseHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.deleteModel("model1", listener), deleteResponseHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.deleteModel("model1", listener), deleteResponseHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(nullValue())); assertTrue(deleteResponseHolder.get()); // get should fail deleteResponseHolder.set(false); - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), not(nullValue())); assertFalse(deleteResponseHolder.get()); @@ -187,13 +187,13 @@ public void testGetModelsByTaskType() throws InterruptedException { AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); } AtomicReference exceptionHolder = new AtomicReference<>(); - AtomicReference> modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(3)); var 
sparseIds = sparseAndTextEmbeddingModels.stream() .filter(m -> m.getConfigurations().getTaskType() == TaskType.SPARSE_EMBEDDING) @@ -204,7 +204,7 @@ public void testGetModelsByTaskType() throws InterruptedException { assertThat(m.secrets().keySet(), empty()); }); - blockingCall(listener -> ModelRegistryImpl.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(2)); var denseIds = sparseAndTextEmbeddingModels.stream() .filter(m -> m.getConfigurations().getTaskType() == TaskType.TEXT_EMBEDDING) @@ -228,13 +228,13 @@ public void testGetAllModels() throws InterruptedException { var model = createModel(randomAlphaOfLength(5), randomFrom(TaskType.values()), service); createdModels.add(model); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); } - AtomicReference> modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getAllModels(listener), modelHolder, exceptionHolder); + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(modelCount)); var getAllModels = modelHolder.get(); @@ -258,18 +258,18 @@ public void testGetModelWithSecrets() throws InterruptedException { AtomicReference exceptionHolder = new AtomicReference<>(); var modelWithSecrets = createModelWithSecrets(inferenceEntityId, randomFrom(TaskType.values()), service, secret); - blockingCall(listener -> ModelRegistryImpl.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), hasSize(1)); var secretSettings = (Map) modelHolder.get().secrets().get("secret_settings"); assertThat(secretSettings.get("secret"), equalTo(secret)); // get model without secrets - blockingCall(listener -> ModelRegistryImpl.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), empty()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 821a804596cff..a73dc0261faa8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -26,11 +26,8 @@ import 
org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceRegistryImpl; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; -import org.elasticsearch.plugins.InferenceRegistryPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; @@ -58,7 +55,7 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; -import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestInferenceAction; @@ -80,13 +77,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class InferencePlugin extends Plugin - implements - ActionPlugin, - ExtensiblePlugin, - SystemIndexPlugin, - InferenceRegistryPlugin, - MapperPlugin { +public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, MapperPlugin { /** * When this setting is true the verification check that @@ -111,8 +102,6 @@ public class InferencePlugin extends Plugin private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); - private final SetOnce modelRegistry = new SetOnce<>(); - private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -164,7 +153,7 @@ public Collection createComponents(PluginServices services) { ); httpFactory.set(httpRequestSenderFactory); - ModelRegistry modelReg = new ModelRegistryImpl(services.client()); + ModelRegistry modelRegistry = new ModelRegistry(services.client()); if (inferenceServiceExtensions == null) { inferenceServiceExtensions = new ArrayList<>(); @@ -175,13 +164,11 @@ public Collection createComponents(PluginServices services) { var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client()); // This must be done after the HttpRequestSenderFactory is created so that the services can get the // reference correctly - var inferenceRegistry = new InferenceServiceRegistryImpl(inferenceServices, factoryContext); - inferenceRegistry.init(services.client()); - inferenceServiceRegistry.set(inferenceRegistry); - modelRegistry.set(modelReg); + var registry = new InferenceServiceRegistry(inferenceServices, factoryContext); + registry.init(services.client()); + inferenceServiceRegistry.set(registry); - // Don't return components as they will be registered using InferenceRegistryPlugin methods to retrieve them - return List.of(); + return List.of(modelRegistry, registry); } @Override @@ -280,16 +267,6 @@ public void close() { IOUtils.closeWhileHandlingException(inferenceServiceRegistry.get(), throttlerToClose); } - @Override - public InferenceServiceRegistry getInferenceServiceRegistry() { - return inferenceServiceRegistry.get(); - } - - @Override - public ModelRegistry getModelRegistry() { - return modelRegistry.get(); - } - @Override public Map 
getMappers() { if (SemanticTextFeature.isEnabled()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java index ad6042581f264..b55e2e6f8ebed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java @@ -23,12 +23,12 @@ import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportDeleteInferenceModelAction extends AcknowledgedTransportMasterNodeAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index 0f7e48c4f8140..2de1aecea118c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -17,7 +17,6 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -25,6 +24,7 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index ece4fee1c935f..fb3974fc12e8b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -16,11 +16,11 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportInferenceAction extends 
HandledTransportAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 6667e314a62b8..07d28f8e5b0a8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -29,7 +29,6 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -44,6 +43,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.io.IOException; import java.util.Map; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java index ad1e0f8c8cb81..928bb0ea47179 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; -import org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.DocumentParserContext; @@ -41,10 +40,6 @@ import java.util.Set; import java.util.stream.Collectors; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; - /** * A mapper for the {@code _semantic_text_inference} field. *
@@ -101,8 +96,13 @@ * */ public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { - public static final String CONTENT_TYPE = "_semantic_text_inference"; - public static final String NAME = ROOT_INFERENCE_FIELD; + public static final String NAME = "_inference"; + public static final String CONTENT_TYPE = "_inference"; + + public static final String INFERENCE_RESULTS = "results"; + public static final String INFERENCE_CHUNKS_RESULTS = "inference"; + public static final String INFERENCE_CHUNKS_TEXT = "text"; + public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); @@ -173,7 +173,7 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); String currentName = parser.currentName(); - if (BulkShardRequestInferenceProvider.INFERENCE_RESULTS.equals(currentName)) { + if (INFERENCE_RESULTS.equals(currentName)) { NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( context, mapperBuilderContext, diff --git a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java similarity index 89% rename from server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index 3561c2351427c..aedaa1cd34c12 100644 --- a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -1,3 +1,10 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License @@ -6,8 +13,11 @@ * Side Public License, v 1. */ -package org.elasticsearch.inference; +package org.elasticsearch.xpack.inference.mapper; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; @@ -19,7 +29,6 @@ /** * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. 
- * See {@link org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider} */ public class SemanticTextModelSettings { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java similarity index 86% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 40921cd38f181..0f3aa5b82b189 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -24,7 +24,6 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; -import org.elasticsearch.common.inject.Inject; import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -32,7 +31,6 @@ import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; @@ -57,21 +55,49 @@ import static org.elasticsearch.core.Strings.format; -public class ModelRegistryImpl implements ModelRegistry { +public class ModelRegistry { public record ModelConfigMap(Map config, Map secrets) {} + /** + * Semi parsed model where inference entity id, task type and service + * are known but the settings are not parsed. 
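+ * The retained settings and secrets maps are parsed into a concrete {@code Model} later, once the owning
+ * {@code InferenceService} has been resolved and {@code parsePersistedConfigWithSecrets} is invoked.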
+ */ + public record UnparsedModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map settings, + Map secrets + ) { + + public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) { + if (modelConfigMap.config() == null) { + throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); + } + String inferenceEntityId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID); + String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); + String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); + TaskType taskType = TaskType.fromString(taskTypeStr); + + return new UnparsedModel(inferenceEntityId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); + } + } + private static final String TASK_TYPE_FIELD = "task_type"; private static final String MODEL_ID_FIELD = "model_id"; - private static final Logger logger = LogManager.getLogger(ModelRegistryImpl.class); + private static final Logger logger = LogManager.getLogger(ModelRegistry.class); private final OriginSettingClient client; - @Inject - public ModelRegistryImpl(Client client) { + public ModelRegistry(Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); } - @Override + /** + * Get a model with its secret settings + * @param inferenceEntityId Model to get + * @param listener Model listener + */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets @@ -80,7 +106,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets @@ -101,7 +132,7 @@ public void getModel(String inferenceEntityId, ActionListener lis return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistryImpl::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); assert modelConfigs.size() == 1; delegate.onResponse(modelConfigs.get(0)); }); @@ -116,7 +147,12 @@ public void getModel(String inferenceEntityId, ActionListener lis client.search(modelSearch, searchListener); } - @Override + /** + * Get all models of a particular task type. 
+ * Secret settings are not included + * @param taskType The task type + * @param listener Models listener + */ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // Not an error if no models of this task_type @@ -125,7 +161,7 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // Not an error if no models of this task_type @@ -150,7 +190,7 @@ public void getAllModels(ActionListener> listener) { return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistryImpl::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); delegate.onResponse(modelConfigs); }); @@ -217,7 +257,6 @@ private ModelConfigMap createModelConfigMap(SearchHits hits, String inferenceEnt ); } - @Override public void storeModel(Model model, ActionListener listener) { ActionListener bulkResponseActionListener = getStoreModelListener(model, listener); @@ -314,7 +353,6 @@ private static BulkItemResponse.Failure getFirstBulkFailure(BulkResponse bulkRes return null; } - @Override public void deleteModel(String inferenceEntityId, ActionListener listener) { DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); request.indices(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN); @@ -339,16 +377,4 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T private QueryBuilder documentIdQuery(String inferenceEntityId) { return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(inferenceEntityId))); } - - private static UnparsedModel unparsedModelFromMap(ModelRegistryImpl.ModelConfigMap modelConfigMap) { - if (modelConfigMap.config() == null) { - throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); - } - String modelId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID); - String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); - String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); - TaskType taskType = TaskType.fromString(taskTypeStr); - - return new UnparsedModel(modelId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java index 319f6ef73fa56..27d0949e933ee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -53,9 +53,9 @@ import java.util.Set; import java.util.function.Consumer; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; -import static 
org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_RESULTS; import static org.hamcrest.Matchers.containsString; public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java similarity index 92% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index fd6a203450c12..2417148c84ac2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -45,7 +45,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class ModelRegistryImplTests extends ESTestCase { +public class ModelRegistryTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); @@ -65,9 +65,9 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() var client = mockClient(); mockClientExecuteSearch(client, mockSearchResponse(SearchHits.EMPTY)); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT)); @@ -79,9 +79,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalArgumentException_WhenInvalidIn var unknownIndexHit = SearchHit.createFromMap(Map.of("_index", "unknown_index")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { unknownIndexHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> listener.actionGet(TIMEOUT)); @@ -96,9 +96,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var inferenceSecretsHit = SearchHit.createFromMap(Map.of("_index", ".secrets-inference")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceSecretsHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); @@ -113,9 +113,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var inferenceHit = SearchHit.createFromMap(Map.of("_index", 
".inference")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); @@ -147,9 +147,9 @@ public void testGetModelWithSecrets() { mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit, inferenceSecretsHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); var modelConfig = listener.actionGet(TIMEOUT); @@ -176,9 +176,9 @@ public void testGetModelNoSecrets() { mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModel("1", listener); registry.getModel("1", listener); @@ -201,7 +201,7 @@ public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -218,7 +218,7 @@ public void testStoreModel_ThrowsException_WhenBulkResponseIsEmpty() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -249,7 +249,7 @@ public void testStoreModel_ThrowsResourceAlreadyExistsException_WhenFailureIsAVe mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -275,7 +275,7 @@ public void testStoreModel_ThrowsException_WhenFailureIsNotAVersionConflict() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); From ebc26d2b29fd9dda9b896c061c67e6f223255c5d Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 14 Mar 2024 10:38:14 +0000 Subject: [PATCH 02/26] inference as an action filter --- .../action/bulk/BulkOperation.java | 4 + .../action/bulk/BulkShardRequest.java | 20 + .../snapshots/SnapshotResiliencyTests.java | 4 +- .../TestSparseInferenceServiceExtension.java | 8 +- .../xpack/inference/InferencePlugin.java | 13 + .../ShardBulkInferenceActionFilter.java | 346 ++++++++++++++++++ ...emanticTextInferenceResultFieldMapper.java | 1 - .../mapper/SemanticTextModelSettings.java | 8 - ...icTextInferenceResultFieldMapperTests.java | 1 - .../inference/10_semantic_text_inference.yml | 48 +-- .../20_semantic_text_field_mapper.yml | 20 +- 11 files changed, 423 insertions(+), 50 deletions(-) create mode 100644 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 1d95f430d5c7e..b7a6387045e3d 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -208,6 +208,10 @@ private void executeBulkRequestsByShard(Map> requ bulkRequest.getRefreshPolicy(), requests.toArray(new BulkItemRequest[0]) ); + var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName()); + if (indexMetadata != null && indexMetadata.getFieldsForModels().isEmpty() == false) { + bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldsForModels()); + } bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); bulkShardRequest.timeout(bulkRequest.timeout()); bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index bd929b9a2204e..f6dd7902f1672 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -22,6 +22,7 @@ import org.elasticsearch.transport.RawIndexingDataTransportRequest; import java.io.IOException; +import java.util.Map; import java.util.Set; public final class BulkShardRequest extends ReplicatedWriteRequest @@ -33,6 +34,8 @@ public final class BulkShardRequest extends ReplicatedWriteRequest> fieldsInferenceMetadata = null; + public BulkShardRequest(StreamInput in) throws IOException { super(in); items = in.readArray(i -> i.readOptionalWriteable(inpt -> new BulkItemRequest(shardId, inpt)), BulkItemRequest[]::new); @@ -44,6 +47,23 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe setRefreshPolicy(refreshPolicy); } + /** + * Set the transient metadata indicating that this request requires running inference + * before proceeding. + */ + void setFieldInferenceMetadata(Map> fieldsInferenceMetadata) { + this.fieldsInferenceMetadata = fieldsInferenceMetadata; + } + + /** + * Consumes the inference metadata to execute inference on the bulk items just once. 
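+ * Reading the metadata also clears it, so inference runs at most once even if the request is re-examined; a
+ * minimal sketch of the intended contract (the caller code is illustrative):
+ * <pre>{@code
+ * Map<String, Set<String>> fieldsByInferenceId = request.consumeFieldInferenceMetadata(); // metadata, or null
+ * assert request.consumeFieldInferenceMetadata() == null; // subsequent calls observe null
+ * }</pre>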
+ */ + public Map> consumeFieldInferenceMetadata() { + var ret = fieldsInferenceMetadata; + fieldsInferenceMetadata = null; + return ret; + } + public long totalSizeInBytes() { long totalSizeInBytes = 0; for (int i = 0; i < items.length; i++) { diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index c13a6be6d3386..edde9f0164a6e 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -2355,9 +2355,7 @@ protected void assertSnapshotOrGenericThread() { actionFilters, indexNameExpressionResolver, new IndexingPressure(settings), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ) ); final TransportShardBulkAction transportShardBulkAction = new TransportShardBulkAction( diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 33bbc94901e9d..b6e48d3b1c29a 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -123,15 +123,17 @@ private SparseEmbeddingResults makeResults(List input) { } private List makeChunkedResults(List input) { - var chunks = new ArrayList(); + List results = new ArrayList<>(); for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, j + 1.0F)); } - chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)); + results.add( + new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens))) + ); } - return List.of(new ChunkedSparseEmbeddingResults(chunks)); + return results; } protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index a73dc0261faa8..7bfa06ecb9a20 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -46,6 +47,7 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import 
org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpSettings; @@ -77,6 +79,8 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static java.util.Collections.singletonList; + public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, MapperPlugin { /** @@ -102,6 +106,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); + private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -168,6 +173,9 @@ public Collection createComponents(PluginServices services) { registry.init(services.client()); inferenceServiceRegistry.set(registry); + var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry); + shardBulkInferenceActionFilter.set(actionFilter); + return List.of(modelRegistry, registry); } @@ -279,4 +287,9 @@ public Map getMappers() { public Map getMetadataMappers() { return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); } + + @Override + public Collection getActionFilters() { + return singletonList(shardBulkInferenceActionFilter.get()); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java new file mode 100644 index 0000000000000..176d2917b0b2a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -0,0 +1,346 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action.filter; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkItemRequest; +import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.TransportShardBulkAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.ActionFilter; +import org.elasticsearch.action.support.ActionFilterChain; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextModelSettings; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in + * the individual {@link BulkItemRequest}. The results are then consumed by the {@link SemanticTextInferenceResultFieldMapper} + * in the subsequent {@link TransportShardBulkAction} downstream. 
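+ * <p>
+ * In outline, mirroring the implementation below: the transient field-to-inference-id metadata is consumed from
+ * the {@link BulkShardRequest}, the matching source fields of each bulk item are grouped into one batch per
+ * inference id, each batch is passed to {@code InferenceService#chunkedInfer}, and the chunked results are
+ * written back into the document source under {@link SemanticTextInferenceResultFieldMapper#NAME} before the
+ * request proceeds down the filter chain.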
+ */ +public class ShardBulkInferenceActionFilter implements ActionFilter { + private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class); + + private final InferenceServiceRegistry inferenceServiceRegistry; + private final ModelRegistry modelRegistry; + + public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) { + this.inferenceServiceRegistry = inferenceServiceRegistry; + this.modelRegistry = modelRegistry; + } + + @Override + public int order() { + // must execute last (after the security action filter) + return Integer.MAX_VALUE; + } + + @Override + public void apply( + Task task, + String action, + Request request, + ActionListener listener, + ActionFilterChain chain + ) { + switch (action) { + case TransportShardBulkAction.ACTION_NAME: + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata(); + if (fieldInferenceMetadata != null) { + Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); + processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); + } else { + chain.proceed(task, action, request, listener); + } + break; + + default: + chain.proceed(task, action, request, listener); + break; + } + } + + private void processBulkShardRequest( + Map> fieldInferenceMetadata, + BulkShardRequest bulkShardRequest, + Runnable onCompletion + ) { + new AsyncBulkShardInferenceAction(fieldInferenceMetadata, bulkShardRequest, onCompletion).run(); + } + + private record InferenceProvider(InferenceService service, Model model) {} + + private record FieldInferenceRequest(int id, String field, String input) {} + + private record FieldInferenceResponse(String field, Model model, ChunkedInferenceServiceResults chunkedResults) {} + + private record FieldInferenceResponseAccumulator(int id, List responses, List failures) { + Exception createFailureOrNull() { + if (failures.isEmpty()) { + return null; + } + Exception main = failures.get(0); + for (int i = 1; i < failures.size(); i++) { + main.addSuppressed(failures.get(i)); + } + return main; + } + } + + private class AsyncBulkShardInferenceAction implements Runnable { + private final Map> fieldInferenceMetadata; + private final BulkShardRequest bulkShardRequest; + private final Runnable onCompletion; + private final AtomicArray inferenceResults; + + private AsyncBulkShardInferenceAction( + Map> fieldInferenceMetadata, + BulkShardRequest bulkShardRequest, + Runnable onCompletion + ) { + this.fieldInferenceMetadata = fieldInferenceMetadata; + this.bulkShardRequest = bulkShardRequest; + this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); + this.onCompletion = onCompletion; + } + + @Override + public void run() { + Map> inferenceRequests = createFieldInferenceRequests(bulkShardRequest); + Runnable onInferenceCompletion = () -> { + try { + for (var inferenceResponse : inferenceResults.asList()) { + var request = bulkShardRequest.items()[inferenceResponse.id]; + applyInference(request, inferenceResponse); + } + } finally { + onCompletion.run(); + } + }; + try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) { + for (var entry : inferenceRequests.entrySet()) { + executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire()); + } + } + } + + private void executeShardBulkInferenceAsync( + final String inferenceId, + @Nullable 
InferenceProvider inferenceProvider, + final List<FieldInferenceRequest> requests, + final Releasable onFinish + ) { + if (inferenceProvider == null) { + ActionListener<ModelRegistry.UnparsedModel> modelLoadingListener = new ActionListener<>() { + @Override + public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { + var service = inferenceServiceRegistry.getService(unparsedModel.service()); + if (service.isEmpty() == false) { + var provider = new InferenceProvider( + service.get(), + service.get() + .parsePersistedConfigWithSecrets( + inferenceId, + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ) + ); + executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish); + } else { + try (onFinish) { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.id).failures.add( + new ResourceNotFoundException( + "Inference service [{}] not found for field [{}]", + unparsedModel.service(), + request.field + ) + ); + } + } + } + } + + @Override + public void onFailure(Exception exc) { + try (onFinish) { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.id).failures.add( + new ResourceNotFoundException("Inference id [{}] not found for field [{}]", inferenceId, request.field) + ); + } + } + } + }; + modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); + return; + } + final List<String> inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + ActionListener<List<ChunkedInferenceServiceResults>> completionListener = new ActionListener<>() { + @Override + public void onResponse(List<ChunkedInferenceServiceResults> results) { + for (int i = 0; i < results.size(); i++) { + var request = requests.get(i); + var result = results.get(i); + inferenceResults.get(request.id).responses.add( + new FieldInferenceResponse(request.field, inferenceProvider.model, result) + ); + } + } + + @Override + public void onFailure(Exception exc) { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.id).failures.add( + new ElasticsearchException( + "Exception when running inference id [{}] on field [{}]", + exc, + inferenceProvider.model.getInferenceEntityId(), + request.field + ) + ); + } + } + }; + inferenceProvider.service() + .chunkedInfer( + inferenceProvider.model(), + inputs, + Map.of(), + InputType.INGEST, + new ChunkingOptions(null, null), + ActionListener.runAfter(completionListener, onFinish::close) + ); + } + + /** + * Apply the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}. + * If the response contains failures, the bulk item request is marked as failed for the downstream action. + * Otherwise, the source of the request is augmented with the field inference results.
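The chunkedInfer call above composes its listener with ActionListener.runAfter so the acquired reference is released on success and on failure alike. The helper in isolation (a self-contained sketch; the printouts stand in for real work):

    import org.elasticsearch.action.ActionListener;

    public class RunAfterSketch {
        public static void main(String[] args) {
            ActionListener<String> inner = ActionListener.wrap(
                response -> System.out.println("handled " + response),
                failure -> System.out.println("handled failure: " + failure)
            );
            // The trailing runnable fires after onResponse and after onFailure,
            // which is what guarantees the filter's ref count always reaches zero.
            ActionListener<String> listener = ActionListener.runAfter(inner, () -> System.out.println("released"));
            listener.onResponse("ok"); // prints "handled ok" then "released"
        }
    }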
+ */ + private void applyInference(BulkItemRequest request, FieldInferenceResponseAccumulator inferenceResponse) { + Exception failure = inferenceResponse.createFailureOrNull(); + if (failure != null) { + request.abort(bulkShardRequest.index(), failure); + return; + } + final IndexRequest indexRequest = getIndexRequestOrNull(request.request()); + final Map newDocMap = indexRequest.sourceAsMap(); + final Map inferenceMetadataMap = new LinkedHashMap<>(); + newDocMap.put(SemanticTextInferenceResultFieldMapper.NAME, inferenceMetadataMap); + for (FieldInferenceResponse fieldResponse : inferenceResponse.responses) { + List> chunks = new ArrayList<>(); + if (fieldResponse.chunkedResults instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add(chunk.asMap()); + } + } else if (fieldResponse.chunkedResults instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(chunk.asMap()); + } + } else { + request.abort(bulkShardRequest.index(), new IllegalArgumentException("TODO")); + return; + } + Map fieldMap = new LinkedHashMap<>(); + fieldMap.putAll(new SemanticTextModelSettings(fieldResponse.model).asMap()); + fieldMap.put(SemanticTextInferenceResultFieldMapper.INFERENCE_RESULTS, chunks); + inferenceMetadataMap.put(fieldResponse.field, fieldMap); + } + indexRequest.source(newDocMap); + } + + private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { + Map> fieldRequestsMap = new LinkedHashMap<>(); + for (var item : bulkShardRequest.items()) { + if (item.getPrimaryResponse() != null) { + // item was already aborted/processed by a filter in the chain upstream (e.g. security). + continue; + } + final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); + if (indexRequest == null) { + continue; + } + final Map docMap = indexRequest.sourceAsMap(); + List fieldRequests = null; + for (var pair : fieldInferenceMetadata.entrySet()) { + String inferenceId = pair.getKey(); + for (var field : pair.getValue()) { + var value = XContentMapValues.extractValue(field, docMap); + if (value == null) { + continue; + } + if (value instanceof String valueStr) { + if (inferenceResults.get(item.id()) == null) { + inferenceResults.set( + item.id(), + new FieldInferenceResponseAccumulator( + item.id(), + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ) + ); + } + if (fieldRequests == null) { + fieldRequests = new ArrayList<>(); + fieldRequestsMap.put(inferenceId, fieldRequests); + } + fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + } + } + } + } + return fieldRequestsMap; + } + } + + static IndexRequest getIndexRequestOrNull(DocWriteRequest docWriteRequest) { + if (docWriteRequest instanceof IndexRequest indexRequest) { + return indexRequest; + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + return updateRequest.doc(); + } else { + return null; + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java index 928bb0ea47179..cee6395185060 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ 
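The createFieldInferenceRequests method shown above resolves each configured field against the document source with XContentMapValues.extractValue, which walks dotted paths through nested maps and returns null on a miss. A small sketch of that behavior (field names and values invented):

    import org.elasticsearch.common.xcontent.support.XContentMapValues;

    import java.util.Map;

    public class ExtractValueSketch {
        public static void main(String[] args) {
            Map<String, Object> source = Map.of(
                "title", "inference test",
                "nested", Map.of("field", "another inference test")
            );
            // A dotted path traverses intermediate objects; null means the
            // document simply has nothing to infer for that field.
            System.out.println(XContentMapValues.extractValue("title", source));
            System.out.println(XContentMapValues.extractValue("nested.field", source));
            System.out.println(XContentMapValues.extractValue("missing.field", source));
        }
    }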
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java @@ -27,7 +27,6 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.SemanticTextModelSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.logging.LogManager; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index aedaa1cd34c12..1b6bb22c0d6b5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -5,14 +5,6 @@ * 2.0. */ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - package org.elasticsearch.xpack.inference.mapper; import org.elasticsearch.inference.Model; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java index 27d0949e933ee..5dc245298838f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -31,7 +31,6 @@ import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; -import org.elasticsearch.inference.SemanticTextModelSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.LeafNestedDocuments; diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index ead7f904ad57b..6008ebbcbedf8 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -83,11 +83,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { 
_source._inference.another_inference_field.results.0.text: "another inference test" } - - exists: _source._semantic_text_inference.inference_field.inference_results.0.inference - - exists: _source._semantic_text_inference.another_inference_field.inference_results.0.inference + - exists: _source._inference.inference_field.results.0.inference + - exists: _source._inference.another_inference_field.results.0.inference --- "text expansion documents do not create new mappings": @@ -120,11 +120,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } - - exists: _source._semantic_text_inference.inference_field.inference_results.0.inference - - exists: _source._semantic_text_inference.another_inference_field.inference_results.0.inference + - exists: _source._inference.inference_field.results.0.inference + - exists: _source._inference.another_inference_field.results.0.inference --- @@ -154,8 +154,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._semantic_text_inference.inference_field.inference_results.0.inference: inference_field_embedding } - - set: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } - do: update: @@ -174,11 +174,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "another non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.inference: $inference_field_embedding } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": @@ -214,8 +214,8 @@ setup: - match: { _source.another_inference_field: "another updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "updated inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another updated inference test" } + - match: { _source._inference.inference_field.results.0.text: "updated inference test" } + - 
match: { _source._inference.another_inference_field.results.0.text: "another updated inference test" } --- "Reindex works for semantic_text fields": @@ -233,8 +233,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._semantic_text_inference.inference_field.inference_results.0.inference: inference_field_embedding } - - set: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } - do: indices.refresh: { } @@ -271,11 +271,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.inference: $inference_field_embedding } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } --- "Fails for non-existent model": @@ -292,7 +292,7 @@ setup: type: text - do: - catch: bad_request + catch: missing index: index: incorrect-test-sparse-index id: doc_1 @@ -300,7 +300,7 @@ setup: inference_field: "inference test" non_inference_field: "non inference test" - - match: { error.reason: "No inference provider found for model ID non-existing-inference-id" } + - match: { error.reason: "Inference id [non-existing-inference-id] not found for field [inference_field]" } # Succeeds when semantic_text field is not used - do: diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index da61e6e403ed8..2c69f49218091 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -56,12 +56,12 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: sparse_field: model_settings: inference_id: sparse-inference-id task_type: sparse_embedding - inference_results: + results: - text: "inference test" inference: feature_1: 0.1 @@ -83,14 +83,14 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: dense_field: model_settings: inference_id: sparse-inference-id task_type: text_embedding dimensions: 5 similarity: cosine - inference_results: + results: - text: "inference test" inference: [0.1, 0.2, 0.3, 0.4, 0.5] - text: "another inference test" @@ -105,11 +105,11 @@ setup: id: doc_1 body: non_inference_field: 
"you know, for testing" - _semantic_text_inference: + _inference: sparse_field: model_settings: task_type: sparse_embedding - inference_results: + results: - text: "inference test" inference: feature_1: 0.1 @@ -123,11 +123,11 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: sparse_field: model_settings: inference_id: sparse-inference-id - inference_results: + results: - text: "inference test" inference: feature_1: 0.1 @@ -141,12 +141,12 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: dense_field: model_settings: inference_id: sparse-inference-id task_type: text_embedding - inference_results: + results: - text: "inference test" inference: [0.1, 0.2, 0.3, 0.4, 0.5] - text: "another inference test" From 86ddc9d8b8aff0285890732d2164c589b87ce3dc Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Mon, 18 Mar 2024 15:12:23 +0000 Subject: [PATCH 03/26] add more tests --- .../action/bulk/BulkShardRequest.java | 13 +- .../vectors/DenseVectorFieldMapper.java | 4 + .../xpack/inference/InferencePlugin.java | 4 +- .../ShardBulkInferenceActionFilter.java | 133 ++++--- ...r.java => InferenceResultFieldMapper.java} | 55 ++- .../mapper/SemanticTextFieldMapper.java | 2 +- .../ShardBulkInferenceActionFilterTests.java | 340 ++++++++++++++++++ ...a => InferenceResultFieldMapperTests.java} | 146 ++++---- 8 files changed, 544 insertions(+), 153 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/{SemanticTextInferenceResultFieldMapper.java => InferenceResultFieldMapper.java} (86%) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/{SemanticTextInferenceResultFieldMapperTests.java => InferenceResultFieldMapperTests.java} (79%) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index f6dd7902f1672..1b5494c6a68f5 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -48,10 +48,10 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe } /** - * Set the transient metadata indicating that this request requires running inference - * before proceeding. + * Public for test + * Set the transient metadata indicating that this request requires running inference before proceeding. 
*/ - void setFieldInferenceMetadata(Map> fieldsInferenceMetadata) { + public void setFieldInferenceMetadata(Map> fieldsInferenceMetadata) { this.fieldsInferenceMetadata = fieldsInferenceMetadata; } @@ -64,6 +64,13 @@ public Map> consumeFieldInferenceMetadata() { return ret; } + /** + * Public for test + */ + public Map> getFieldsInferenceMetadata() { + return fieldsInferenceMetadata; + } + public long totalSizeInBytes() { long totalSizeInBytes = 0; for (int i = 0; i < items.length; i++) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index c6e4d4af926a2..53cc803fc5a2f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1086,6 +1086,10 @@ public String typeName() { return CONTENT_TYPE; } + public Integer getDims() { + return dims; + } + @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { if (format != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 7bfa06ecb9a20..994207766f2a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -55,8 +55,8 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; @@ -285,7 +285,7 @@ public Map getMappers() { @Override public Map getMetadataMappers() { - return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); + return Map.of(InferenceResultFieldMapper.NAME, InferenceResultFieldMapper.PARSER); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 176d2917b0b2a..e679d3c970abf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; @@ -33,11 +34,9 @@ import 
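The getMetadataMappers change above is the MapperPlugin hook through which the renamed mapper is registered under its reserved source field. In sketch form (a hypothetical plugin class; the name and parser constants are the ones the patch references):

    import org.elasticsearch.index.mapper.MetadataFieldMapper;
    import org.elasticsearch.plugins.MapperPlugin;
    import org.elasticsearch.plugins.Plugin;
    import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper;

    import java.util.Map;

    public class SketchPlugin extends Plugin implements MapperPlugin {
        @Override
        public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
            // One entry per reserved metadata field: "_inference" maps to the
            // parser that builds InferenceResultFieldMapper instances.
            return Map.of(InferenceResultFieldMapper.NAME, InferenceResultFieldMapper.PARSER);
        }
    }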
org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; -import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; -import org.elasticsearch.xpack.inference.mapper.SemanticTextModelSettings; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; @@ -50,7 +49,7 @@ /** * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in - * the individual {@link BulkItemRequest}. The results are then consumed by the {@link SemanticTextInferenceResultFieldMapper} + * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceResultFieldMapper} * in the subsequent {@link TransportShardBulkAction} downstream. */ public class ShardBulkInferenceActionFilter implements ActionFilter { @@ -82,7 +81,7 @@ public void app case TransportShardBulkAction.ACTION_NAME: BulkShardRequest bulkShardRequest = (BulkShardRequest) request; var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata(); - if (fieldInferenceMetadata != null) { + if (fieldInferenceMetadata != null && fieldInferenceMetadata.size() > 0) { Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); } else { @@ -110,18 +109,7 @@ private record FieldInferenceRequest(int id, String field, String input) {} private record FieldInferenceResponse(String field, Model model, ChunkedInferenceServiceResults chunkedResults) {} - private record FieldInferenceResponseAccumulator(int id, List responses, List failures) { - Exception createFailureOrNull() { - if (failures.isEmpty()) { - return null; - } - Exception main = failures.get(0); - for (int i = 1; i < failures.size(); i++) { - main.addSuppressed(failures.get(i)); - } - return main; - } - } + private record FieldInferenceResponseAccumulator(int id, List responses, List failures) {} private class AsyncBulkShardInferenceAction implements Runnable { private final Map> fieldInferenceMetadata; @@ -147,7 +135,11 @@ public void run() { try { for (var inferenceResponse : inferenceResults.asList()) { var request = bulkShardRequest.items()[inferenceResponse.id]; - applyInference(request, inferenceResponse); + try { + applyInferenceResponses(request, inferenceResponse); + } catch (Exception exc) { + request.abort(bulkShardRequest.index(), exc); + } } } finally { onCompletion.run(); @@ -189,8 +181,8 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { var request = requests.get(i); inferenceResults.get(request.id).failures.add( new ResourceNotFoundException( - "Inference service [{}] not found for field [{}]", - unparsedModel.service(), + "Inference id [{}] not found for field [{}]", + inferenceId, request.field ) ); @@ -221,9 +213,8 @@ public void onResponse(List results) { for (int i = 0; i < results.size(); i++) { var request = requests.get(i); var result = results.get(i); - inferenceResults.get(request.id).responses.add( - new FieldInferenceResponse(request.field, inferenceProvider.model, result) - ); + 
var acc = inferenceResults.get(request.id); + acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result)); } } @@ -254,38 +245,34 @@ public void onFailure(Exception exc) { } /** - * Apply the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}. + * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}. * If the response contains failures, the bulk item request is marked as failed for the downstream action. * Otherwise, the source of the request is augmented with the field inference results. */ - private void applyInference(BulkItemRequest request, FieldInferenceResponseAccumulator inferenceResponse) { - Exception failure = inferenceResponse.createFailureOrNull(); - if (failure != null) { - request.abort(bulkShardRequest.index(), failure); + private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { + if (response.failures().isEmpty() == false) { + for (var failure : response.failures()) { + item.abort(item.index(), failure); + } return; } - final IndexRequest indexRequest = getIndexRequestOrNull(request.request()); - final Map<String, Object> newDocMap = indexRequest.sourceAsMap(); - final Map<String, Object> inferenceMetadataMap = new LinkedHashMap<>(); - newDocMap.put(SemanticTextInferenceResultFieldMapper.NAME, inferenceMetadataMap); - for (FieldInferenceResponse fieldResponse : inferenceResponse.responses) { - List<Map<String, Object>> chunks = new ArrayList<>(); - if (fieldResponse.chunkedResults instanceof ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add(chunk.asMap()); - } - } else if (fieldResponse.chunkedResults instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add(chunk.asMap()); - } - } else { - request.abort(bulkShardRequest.index(), new IllegalArgumentException("TODO")); - return; + + final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); + Map<String, Object> newDocMap = indexRequest.sourceAsMap(); + Map<String, Object> inferenceMap = new LinkedHashMap<>(); + // ignore the existing inference map if any + newDocMap.put(InferenceResultFieldMapper.NAME, inferenceMap); + for (FieldInferenceResponse fieldResponse : response.responses()) { + try { + InferenceResultFieldMapper.applyFieldInference( + inferenceMap, + fieldResponse.field(), + fieldResponse.model(), + fieldResponse.chunkedResults() + ); + } catch (Exception exc) { + item.abort(item.index(), exc); } - Map<String, Object> fieldMap = new LinkedHashMap<>(); - fieldMap.putAll(new SemanticTextModelSettings(fieldResponse.model).asMap()); - fieldMap.put(SemanticTextInferenceResultFieldMapper.INFERENCE_RESULTS, chunks); - inferenceMetadataMap.put(fieldResponse.field, fieldMap); } indexRequest.source(newDocMap); } @@ -294,7 +281,7 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new LinkedHashMap<>(); for (var item : bulkShardRequest.items()) { if (item.getPrimaryResponse() != null) { - // item was already aborted/processed by a filter in the chain upstream (e.g. security). + // item was already aborted/processed by a filter in the chain upstream (e.g.
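The rewritten request-collection loop (continued just below) swaps the manually null-checked list for Map.computeIfAbsent, which reads more cleanly when several fields share one inference id. The idiom in isolation (record and names invented for the sketch):

    import java.util.ArrayList;
    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;

    public class GroupingSketch {
        record FieldInferenceRequest(int id, String field, String input) {}

        public static void main(String[] args) {
            Map<String, List<FieldInferenceRequest>> byInferenceId = new LinkedHashMap<>();
            // One list per inference id, created lazily on first use.
            byInferenceId.computeIfAbsent("my-id", k -> new ArrayList<>())
                .add(new FieldInferenceRequest(0, "title", "some text"));
            byInferenceId.computeIfAbsent("my-id", k -> new ArrayList<>())
                .add(new FieldInferenceRequest(0, "body", "more text"));
            System.out.println(byInferenceId.get("my-id").size()); // prints 2
        }
    }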
security) continue; } final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); @@ -302,30 +289,38 @@ private Map> createFieldInferenceRequests(Bu continue; } final Map docMap = indexRequest.sourceAsMap(); - List fieldRequests = null; - for (var pair : fieldInferenceMetadata.entrySet()) { - String inferenceId = pair.getKey(); - for (var field : pair.getValue()) { + for (var entry : fieldInferenceMetadata.entrySet()) { + String inferenceId = entry.getKey(); + for (var field : entry.getValue()) { var value = XContentMapValues.extractValue(field, docMap); if (value == null) { continue; } - if (value instanceof String valueStr) { - if (inferenceResults.get(item.id()) == null) { - inferenceResults.set( + if (inferenceResults.get(item.id()) == null) { + inferenceResults.set( + item.id(), + new FieldInferenceResponseAccumulator( item.id(), - new FieldInferenceResponseAccumulator( - item.id(), - Collections.synchronizedList(new ArrayList<>()), - Collections.synchronizedList(new ArrayList<>()) - ) - ); - } - if (fieldRequests == null) { - fieldRequests = new ArrayList<>(); - fieldRequestsMap.put(inferenceId, fieldRequests); - } + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ) + ); + } + if (value instanceof String valueStr) { + List fieldRequests = fieldRequestsMap.computeIfAbsent( + inferenceId, + k -> new ArrayList<>() + ); fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + } else { + inferenceResults.get(item.id()).failures.add( + new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + value.getClass().getSimpleName() + ) + ); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java similarity index 86% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java index cee6395185060..2ede5419ab74e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java @@ -8,6 +8,8 @@ package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.DocumentParserContext; @@ -27,15 +29,24 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentParser; +import 
org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -51,7 +62,7 @@ * { * "_source": { * "my_semantic_text_field": "these are not the droids you're looking for", - * "_semantic_text_inference": { + * "_inference": { * "my_semantic_text_field": [ * { * "sparse_embedding": { @@ -94,17 +105,17 @@ * } * */ -public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { +public class InferenceResultFieldMapper extends MetadataFieldMapper { public static final String NAME = "_inference"; public static final String CONTENT_TYPE = "_inference"; - public static final String INFERENCE_RESULTS = "results"; + public static final String RESULTS = "results"; public static final String INFERENCE_CHUNKS_RESULTS = "inference"; public static final String INFERENCE_CHUNKS_TEXT = "text"; - public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); + public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceResultFieldMapper()); - private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); + private static final Logger logger = LogManager.getLogger(InferenceResultFieldMapper.class); private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); @@ -131,7 +142,7 @@ public Query termQuery(Object value, SearchExecutionContext context) { } } - private SemanticTextInferenceResultFieldMapper() { + public InferenceResultFieldMapper() { super(SemanticTextInferenceFieldType.INSTANCE); } @@ -172,7 +183,7 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); String currentName = parser.currentName(); - if (INFERENCE_RESULTS.equals(currentName)) { + if (RESULTS.equals(currentName)) { NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( context, mapperBuilderContext, @@ -328,4 +339,34 @@ protected String contentType() { public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { return SourceLoader.SyntheticFieldLoader.NOTHING; } + + public static void applyFieldInference( + Map inferenceMap, + String field, + Model model, + ChunkedInferenceServiceResults results + ) throws ElasticsearchException { + List> chunks = new ArrayList<>(); + if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add(chunk.asMap()); + } + } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(chunk.asMap()); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + model.getInferenceEntityId(), + results.getWriteableName() + ); + } + Map fieldMap = new LinkedHashMap<>(); + fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); + fieldMap.put(InferenceResultFieldMapper.RESULTS, chunks); + inferenceMap.put(field, fieldMap); + } } diff --git 
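For a single sparse-embedding field, the applyFieldInference method above yields a nested source structure roughly like the following (ids, text, and weights invented; the exact model_settings keys come from SemanticTextModelSettings):

    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;

    public class InferenceMapShape {
        public static void main(String[] args) {
            Map<String, Object> inferenceMap = new LinkedHashMap<>();
            inferenceMap.put("my_field", Map.of(
                "model_settings", Map.of(
                    "inference_id", "my-inference-id",
                    "task_type", "sparse_embedding"
                ),
                "results", List.of(Map.of(
                    "text", "inference test",
                    "inference", Map.of("inference", 0.1f, "test", 0.2f)
                ))
            ));
            // The whole map is stored under the reserved "_inference" field,
            // which is the layout the updated YAML tests assert on.
            System.out.println(Map.of("_inference", inferenceMap));
        }
    }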
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 027b85a9a9f45..4caa3d68ba877 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -30,7 +30,7 @@ * at ingestion and query time. * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well. * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will - * be indexed using {@link SemanticTextInferenceResultFieldMapper}. + * be indexed using {@link InferenceResultFieldMapper}. */ public class SemanticTextFieldMapper extends FieldMapper { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java new file mode 100644 index 0000000000000..7f3ffbe596543 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -0,0 +1,340 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action.filter; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.bulk.BulkItemRequest; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.TransportShardBulkAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.ActionFilterChain; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.model.TestModel; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.After; +import org.junit.Before; +import org.mockito.stubbing.Answer; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import 
java.util.Optional; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomSparseEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomTextEmbeddings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ShardBulkInferenceActionFilterTests extends ESTestCase { + private ThreadPool threadPool; + + @Before + public void setupThreadPool() { + threadPool = new TestThreadPool(getTestName()); + } + + @After + public void tearDownThreadPool() throws Exception { + terminate(threadPool); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testFilterNoop() throws Exception { + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of()); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertNull(((BulkShardRequest) request).getFieldsInferenceMetadata()); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest request = new BulkShardRequest( + new ShardId("test", "test", 0), + WriteRequest.RefreshPolicy.NONE, + new BulkItemRequest[0] + ); + request.setFieldInferenceMetadata(Map.of("foo", Set.of("bar"))); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testInferenceNotFound() throws Exception { + StaticModel model = randomStaticModel(); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(model.getInferenceEntityId(), model)); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getFieldsInferenceMetadata()); + for (BulkItemRequest item : bulkShardRequest.items()) { + assertNotNull(item.getPrimaryResponse()); + assertTrue(item.getPrimaryResponse().isFailed()); + BulkItemResponse.Failure failure = item.getPrimaryResponse().getFailure(); + assertThat(failure.getStatus(), equalTo(RestStatus.NOT_FOUND)); + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map> inferenceFields = Map.of( + model.getInferenceEntityId(), + Set.of("field1"), + "inference_0", + Set.of("field2", "field3") + ); + BulkItemRequest[] items = new BulkItemRequest[10]; + for (int i = 0; i < items.length; i++) { + items[i] = randomBulkItemRequest(i, Map.of(), inferenceFields)[0]; + } + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setFieldInferenceMetadata(inferenceFields); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, 
actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testManyRandomDocs() throws Exception { + Map inferenceModelMap = new HashMap<>(); + int numModels = randomIntBetween(1, 5); + for (int i = 0; i < numModels; i++) { + StaticModel model = randomStaticModel(); + inferenceModelMap.put(model.getInferenceEntityId(), model); + } + + int numInferenceFields = randomIntBetween(1, 5); + Map> inferenceFields = new HashMap<>(); + for (int i = 0; i < numInferenceFields; i++) { + String inferenceId = randomFrom(inferenceModelMap.keySet()); + String field = randomAlphaOfLengthBetween(5, 10); + var res = inferenceFields.computeIfAbsent(inferenceId, k -> new HashSet<>()); + res.add(field); + } + + int numRequests = randomIntBetween(100, 1000); + BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; + BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; + for (int id = 0; id < numRequests; id++) { + BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, inferenceFields); + originalRequests[id] = res[0]; + modifiedRequests[id] = res[1]; + } + + ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertThat(request, instanceOf(BulkShardRequest.class)); + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getFieldsInferenceMetadata()); + BulkItemRequest[] items = bulkShardRequest.items(); + assertThat(items.length, equalTo(originalRequests.length)); + for (int id = 0; id < items.length; id++) { + IndexRequest actualRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(items[id].request()); + IndexRequest expectedRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(modifiedRequests[id].request()); + try { + assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), actualRequest.getContentType()); + } catch (Exception exc) { + throw new IllegalStateException(exc); + } + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); + original.setFieldInferenceMetadata(inferenceFields); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings("unchecked") + private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap) { + ModelRegistry modelRegistry = mock(ModelRegistry.class); + Answer unparsedModelAnswer = invocationOnMock -> { + String id = (String) invocationOnMock.getArguments()[0]; + ActionListener listener = (ActionListener) invocationOnMock + .getArguments()[1]; + var model = modelMap.get(id); + if (model != null) { + listener.onResponse( + new ModelRegistry.UnparsedModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getServiceSettings().model(), + XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getTaskSettings()), false), + XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getSecretSettings()), false) + ) + ); + } else { + listener.onFailure(new 
ResourceNotFoundException("model id [{}] not found", id)); + } + return null; + }; + doAnswer(unparsedModelAnswer).when(modelRegistry).getModelWithSecrets(any(), any()); + + InferenceService inferenceService = mock(InferenceService.class); + Answer chunkedInferAnswer = invocationOnMock -> { + StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; + List inputs = (List) invocationOnMock.getArguments()[1]; + ActionListener> listener = (ActionListener< + List>) invocationOnMock.getArguments()[5]; + Runnable runnable = () -> { + List results = new ArrayList<>(); + for (String input : inputs) { + results.add(model.getResults(input)); + } + listener.onResponse(results); + }; + if (randomBoolean()) { + try { + threadPool.generic().execute(runnable); + } catch (Exception exc) { + listener.onFailure(exc); + } + } else { + runnable.run(); + } + return null; + }; + doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any()); + + Answer modelAnswer = invocationOnMock -> { + String inferenceId = (String) invocationOnMock.getArguments()[0]; + return modelMap.get(inferenceId); + }; + doAnswer(modelAnswer).when(inferenceService).parsePersistedConfigWithSecrets(any(), any(), any(), any()); + + InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); + when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); + ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry); + return filter; + } + + private static BulkItemRequest[] randomBulkItemRequest( + int id, + Map modelMap, + Map> inferenceFieldMap + ) { + Map docMap = new LinkedHashMap<>(); + Map inferenceResultsMap = new LinkedHashMap<>(); + for (var entry : inferenceFieldMap.entrySet()) { + String inferenceId = entry.getKey(); + var model = modelMap.get(inferenceId); + for (var field : entry.getValue()) { + String text = randomAlphaOfLengthBetween(10, 100); + docMap.put(field, text); + if (model == null) { + // ignore results, the doc should fail with a resource not found exception + continue; + } + int numChunks = randomIntBetween(1, 5); + List chunks = new ArrayList<>(); + for (int i = 0; i < numChunks; i++) { + chunks.add(randomAlphaOfLengthBetween(5, 10)); + } + TaskType taskType = model.getTaskType(); + final ChunkedInferenceServiceResults results; + switch (taskType) { + case TEXT_EMBEDDING: + results = randomTextEmbeddings(chunks); + break; + + case SPARSE_EMBEDDING: + results = randomSparseEmbeddings(chunks); + break; + + default: + throw new AssertionError("Unknown task type " + taskType.name()); + } + model.putResult(text, results); + InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); + } + } + Map expectedDocMap = new LinkedHashMap<>(docMap); + expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap); + return new BulkItemRequest[] { + new BulkItemRequest(id, new IndexRequest("index").source(docMap)), + new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap)) }; + } + + private static StaticModel randomStaticModel() { + String serviceName = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomAlphaOfLengthBetween(5, 10); + return new StaticModel( + inferenceId, + randomBoolean() ? 
TaskType.TEXT_EMBEDDING : TaskType.SPARSE_EMBEDDING, + serviceName, + new TestModel.TestServiceSettings("my-model"), + new TestModel.TestTaskSettings(randomIntBetween(1, 100)), + new TestModel.TestSecretSettings(randomAlphaOfLength(10)) + ); + } + + private static class StaticModel extends TestModel { + private final Map resultMap; + + StaticModel( + String inferenceEntityId, + TaskType taskType, + String service, + TestServiceSettings serviceSettings, + TestTaskSettings taskSettings, + TestSecretSettings secretSettings + ) { + super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, secretSettings); + this.resultMap = new HashMap<>(); + } + + ChunkedInferenceServiceResults getResults(String text) { + return resultMap.getOrDefault(text, new ChunkedSparseEmbeddingResults(List.of())); + } + + void putResult(String text, ChunkedInferenceServiceResults results) { + resultMap.put(text, results); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java similarity index 79% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java index 5dc245298838f..b5d75b528c6ab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java @@ -31,48 +31,46 @@ import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.LeafNestedDocuments; import org.elasticsearch.search.NestedDocuments; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.model.TestModel; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_RESULTS; +import 
static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.RESULTS; import static org.hamcrest.Matchers.containsString; -public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase { - private record SemanticTextInferenceResults(String fieldName, SparseEmbeddingResults sparseEmbeddingResults, List text) { - private SemanticTextInferenceResults { - if (sparseEmbeddingResults.embeddings().size() != text.size()) { - throw new IllegalArgumentException("Sparse embeddings and text must be the same size"); - } - } - } +public class InferenceResultFieldMapperTests extends MetadataMapperTestCase { + private record SemanticTextInferenceResults(String fieldName, ChunkedInferenceServiceResults results, List text) {} - private record VisitedChildDocInfo(String path, int sparseVectorDims) {} + private record VisitedChildDocInfo(String path, int numChunks) {} private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {} @Override protected String fieldName() { - return SemanticTextInferenceResultFieldMapper.NAME; + return InferenceResultFieldMapper.NAME; } @Override @@ -108,8 +106,8 @@ public void testSuccessfulParse() throws IOException { b -> addSemanticTextInferenceResults( b, List.of( - generateSemanticTextinferenceResults(fieldName1, List.of("a b", "c")), - generateSemanticTextinferenceResults(fieldName2, List.of("d e f")) + randomSemanticTextInferenceResults(fieldName1, List.of("a b", "c")), + randomSemanticTextInferenceResults(fieldName2, List.of("d e f")) ) ) ) @@ -208,10 +206,10 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), true, - null + Map.of() ) ) ) @@ -226,10 +224,10 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), new SparseVectorSubfieldOptions(true, true, true), false, - null + Map.of() ) ) ) @@ -244,10 +242,10 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), false, - null + Map.of() ) ) ) @@ -262,7 +260,7 @@ public void testMissingSubfields() throws IOException { public void testExtraSubfields() throws IOException { final String fieldName = randomAlphaOfLengthBetween(5, 15); final List semanticTextInferenceResultsList = List.of( - generateSemanticTextinferenceResults(fieldName, List.of("a b")) + randomSemanticTextInferenceResults(fieldName, List.of("a b")) ); DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); @@ -360,7 +358,7 @@ public void testMissingSemanticTextMapping() throws IOException { DocumentParsingException.class, DocumentParsingException.class, () -> 
documentMapper.parse( - source(b -> addSemanticTextInferenceResults(b, List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))))) + source(b -> addSemanticTextInferenceResults(b, List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))))) ) ); assertThat( @@ -378,18 +376,32 @@ private static void addSemanticTextMapping(XContentBuilder mappingBuilder, Strin mappingBuilder.endObject(); } - private static SemanticTextInferenceResults generateSemanticTextinferenceResults(String semanticTextFieldName, List chunks) { - List embeddings = new ArrayList<>(chunks.size()); - for (String chunk : chunks) { - String[] tokens = chunk.split("\\s+"); - List weightedTokens = Arrays.stream(tokens) - .map(t -> new SparseEmbeddingResults.WeightedToken(t, randomFloat())) - .toList(); + public static ChunkedTextEmbeddingResults randomTextEmbeddings(List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + double[] values = new double[5]; + for (int j = 0; j < values.length; j++) { + values[j] = randomDouble(); + } + chunks.add(new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, values)); + } + return new ChunkedTextEmbeddingResults(chunks); + } - embeddings.add(new SparseEmbeddingResults.Embedding(weightedTokens, false)); + public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + var tokens = new ArrayList(); + for (var token : input.split("\\s+")) { + tokens.add(new TextExpansionResults.WeightedToken(token, randomFloat())); + } + chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input, tokens)); } + return new ChunkedSparseEmbeddingResults(chunks); + } - return new SemanticTextInferenceResults(semanticTextFieldName, new SparseEmbeddingResults(embeddings), chunks); + private static SemanticTextInferenceResults randomSemanticTextInferenceResults(String semanticTextFieldName, List chunks) { + return new SemanticTextInferenceResults(semanticTextFieldName, randomSparseEmbeddings(chunks), chunks); } private static void addSemanticTextInferenceResults( @@ -401,10 +413,11 @@ private static void addSemanticTextInferenceResults( semanticTextInferenceResults, new SparseVectorSubfieldOptions(true, true, true), true, - null + Map.of() ); } + @SuppressWarnings("unchecked") private static void addSemanticTextInferenceResults( XContentBuilder sourceBuilder, List semanticTextInferenceResults, @@ -412,48 +425,39 @@ private static void addSemanticTextInferenceResults( boolean includeTextSubfield, Map extraSubfields ) throws IOException { - - Map> inferenceResultsMap = new HashMap<>(); + Map inferenceResultsMap = new HashMap<>(); for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { - Map fieldMap = new HashMap<>(); - fieldMap.put(SemanticTextModelSettings.NAME, modelSettingsMap()); - List> parsedInferenceResults = new ArrayList<>(semanticTextInferenceResult.text().size()); - - Iterator embeddingsIterator = semanticTextInferenceResult.sparseEmbeddingResults() - .embeddings() - .iterator(); - Iterator textIterator = semanticTextInferenceResult.text().iterator(); - while (embeddingsIterator.hasNext() && textIterator.hasNext()) { - SparseEmbeddingResults.Embedding embedding = embeddingsIterator.next(); - String text = textIterator.next(); - - Map subfieldMap = new HashMap<>(); - if (sparseVectorSubfieldOptions.include()) { - subfieldMap.put(INFERENCE_CHUNKS_RESULTS, 
embedding.asMap().get(SparseEmbeddingResults.Embedding.EMBEDDING)); - } - if (includeTextSubfield) { - subfieldMap.put(INFERENCE_CHUNKS_TEXT, text); + InferenceResultFieldMapper.applyFieldInference( + inferenceResultsMap, + semanticTextInferenceResult.fieldName, + randomModel(), + semanticTextInferenceResult.results + ); + Map<String, Object> optionsMap = (Map<String, Object>) inferenceResultsMap.get(semanticTextInferenceResult.fieldName); + List<Map<String, Object>> fieldResultList = (List<Map<String, Object>>) optionsMap.get(RESULTS); + for (var entry : fieldResultList) { + if (includeTextSubfield == false) { + entry.remove(INFERENCE_CHUNKS_TEXT); } - if (extraSubfields != null) { - subfieldMap.putAll(extraSubfields); + if (sparseVectorSubfieldOptions.include == false) { + entry.remove(INFERENCE_CHUNKS_RESULTS); } - - parsedInferenceResults.add(subfieldMap); + entry.putAll(extraSubfields); } - - fieldMap.put(INFERENCE_RESULTS, parsedInferenceResults); - inferenceResultsMap.put(semanticTextInferenceResult.fieldName(), fieldMap); } - - sourceBuilder.field(SemanticTextInferenceResultFieldMapper.NAME, inferenceResultsMap); + sourceBuilder.field(InferenceResultFieldMapper.NAME, inferenceResultsMap); } - private static Map<String, Object> modelSettingsMap() { - return Map.of( - SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(), - TaskType.SPARSE_EMBEDDING.toString(), - SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(), - randomAlphaOfLength(8) + private static Model randomModel() { + String serviceName = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomAlphaOfLengthBetween(5, 10); + return new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + serviceName, + new TestModel.TestServiceSettings("my-model"), + new TestModel.TestTaskSettings(randomIntBetween(1, 100)), + new TestModel.TestSecretSettings(randomAlphaOfLength(10)) ); }
From c5de0da930dbc329f2000fae7475fe1d90b82488 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 19 Mar 2024 14:22:05 +0100 Subject: [PATCH 04/26] Merge from feature branch --- .../action/bulk/BulkOperation.java | 4 +- .../action/bulk/BulkShardRequest.java | 18 ++-- .../ShardBulkInferenceActionFilter.java | 70 +++++++------- .../ShardBulkInferenceActionFilterTests.java | 96 ++++++++++--------- 4 files changed, 94 insertions(+), 94 deletions(-)
diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index b7a6387045e3d..452a9ec90443a 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -209,8 +209,8 @@ private void executeBulkRequestsByShard(Map<ShardId, List<BulkItemRequest>> requ requests.toArray(new BulkItemRequest[0]) ); var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName()); - if (indexMetadata != null && indexMetadata.getFieldsForModels().isEmpty() == false) { - bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldsForModels()); + if (indexMetadata != null && indexMetadata.getFieldInferenceMetadata().isEmpty() == false) { + bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldInferenceMetadata()); } bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); bulkShardRequest.timeout(bulkRequest.timeout());
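Note on the hand-off in this patch: the BulkOperation hunk above attaches the index's field-inference metadata to each shard-level request, and the BulkShardRequest hunk below keeps it as transient, consume-once state, so inference runs exactly once per request and is never serialized across nodes. A minimal sketch of that consume-once idiom, using an illustrative holder class rather than the real BulkShardRequest:

    // Hypothetical stand-in for the transient metadata slot on BulkShardRequest.
    class ConsumeOnce<T> {
        private transient T value; // transient: dropped on serialization, set per node

        void set(T value) {
            this.value = value;
        }

        // The first caller gets the metadata; later callers see null and skip inference.
        T consume() {
            T ret = value;
            value = null;
            return ret;
        }
    }

diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index 1b5494c6a68f5..39fa791a3e27d 100644 ---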
a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.support.replication.ReplicatedWriteRequest; import org.elasticsearch.action.support.replication.ReplicationRequest; import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.set.Sets; @@ -22,7 +23,6 @@ import org.elasticsearch.transport.RawIndexingDataTransportRequest; import java.io.IOException; -import java.util.Map; import java.util.Set; public final class BulkShardRequest extends ReplicatedWriteRequest @@ -34,7 +34,7 @@ public final class BulkShardRequest extends ReplicatedWriteRequest> fieldsInferenceMetadata = null; + private transient FieldInferenceMetadata fieldsInferenceMetadataMap = null; public BulkShardRequest(StreamInput in) throws IOException { super(in); @@ -51,24 +51,24 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe * Public for test * Set the transient metadata indicating that this request requires running inference before proceeding. */ - public void setFieldInferenceMetadata(Map> fieldsInferenceMetadata) { - this.fieldsInferenceMetadata = fieldsInferenceMetadata; + public void setFieldInferenceMetadata(FieldInferenceMetadata fieldsInferenceMetadata) { + this.fieldsInferenceMetadataMap = fieldsInferenceMetadata; } /** * Consumes the inference metadata to execute inference on the bulk items just once. */ - public Map> consumeFieldInferenceMetadata() { - var ret = fieldsInferenceMetadata; - fieldsInferenceMetadata = null; + public FieldInferenceMetadata consumeFieldInferenceMetadata() { + FieldInferenceMetadata ret = fieldsInferenceMetadataMap; + fieldsInferenceMetadataMap = null; return ret; } /** * Public for test */ - public Map> getFieldsInferenceMetadata() { - return fieldsInferenceMetadata; + public FieldInferenceMetadata getFieldsInferenceMetadataMap() { + return fieldsInferenceMetadataMap; } public long totalSizeInBytes() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index e679d3c970abf..984a20419b2c8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -24,6 +24,7 @@ import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; @@ -44,7 +45,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; /** @@ -81,7 +81,7 @@ public void app case TransportShardBulkAction.ACTION_NAME: BulkShardRequest bulkShardRequest = (BulkShardRequest) request; var fieldInferenceMetadata = 
bulkShardRequest.consumeFieldInferenceMetadata(); - if (fieldInferenceMetadata != null && fieldInferenceMetadata.size() > 0) { + if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) { Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); } else { @@ -96,7 +96,7 @@ public void app } private void processBulkShardRequest( - Map> fieldInferenceMetadata, + FieldInferenceMetadata fieldInferenceMetadata, BulkShardRequest bulkShardRequest, Runnable onCompletion ) { @@ -112,13 +112,13 @@ private record FieldInferenceResponse(String field, Model model, ChunkedInferenc private record FieldInferenceResponseAccumulator(int id, List responses, List failures) {} private class AsyncBulkShardInferenceAction implements Runnable { - private final Map> fieldInferenceMetadata; + private final FieldInferenceMetadata fieldInferenceMetadata; private final BulkShardRequest bulkShardRequest; private final Runnable onCompletion; private final AtomicArray inferenceResults; private AsyncBulkShardInferenceAction( - Map> fieldInferenceMetadata, + FieldInferenceMetadata fieldInferenceMetadata, BulkShardRequest bulkShardRequest, Runnable onCompletion ) { @@ -289,39 +289,35 @@ private Map> createFieldInferenceRequests(Bu continue; } final Map docMap = indexRequest.sourceAsMap(); - for (var entry : fieldInferenceMetadata.entrySet()) { - String inferenceId = entry.getKey(); - for (var field : entry.getValue()) { - var value = XContentMapValues.extractValue(field, docMap); - if (value == null) { - continue; - } - if (inferenceResults.get(item.id()) == null) { - inferenceResults.set( + for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { + String field = entry.getKey(); + String inferenceId = entry.getValue().inferenceId(); + var value = XContentMapValues.extractValue(field, docMap); + if (value == null) { + continue; + } + if (inferenceResults.get(item.id()) == null) { + inferenceResults.set( + item.id(), + new FieldInferenceResponseAccumulator( item.id(), - new FieldInferenceResponseAccumulator( - item.id(), - Collections.synchronizedList(new ArrayList<>()), - Collections.synchronizedList(new ArrayList<>()) - ) - ); - } - if (value instanceof String valueStr) { - List fieldRequests = fieldRequestsMap.computeIfAbsent( - inferenceId, - k -> new ArrayList<>() - ); - fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); - } else { - inferenceResults.get(item.id()).failures.add( - new ElasticsearchStatusException( - "Invalid format for field [{}], expected [String] got [{}]", - RestStatus.BAD_REQUEST, - field, - value.getClass().getSimpleName() - ) - ); - } + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ) + ); + } + if (value instanceof String valueStr) { + List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); + fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + } else { + inferenceResults.get(item.id()).failures.add( + new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + value.getClass().getSimpleName() + ) + ); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 7f3ffbe596543..4a1825303b5a7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.shard.ShardId; @@ -40,7 +41,6 @@ import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -79,7 +79,7 @@ public void testFilterNoop() throws Exception { CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { - assertNull(((BulkShardRequest) request).getFieldsInferenceMetadata()); + assertNull(((BulkShardRequest) request).getFieldsInferenceMetadataMap()); } finally { chainExecuted.countDown(); } @@ -91,7 +91,9 @@ public void testFilterNoop() throws Exception { WriteRequest.RefreshPolicy.NONE, new BulkItemRequest[0] ); - request.setFieldInferenceMetadata(Map.of("foo", Set.of("bar"))); + request.setFieldInferenceMetadata( + new FieldInferenceMetadata(Map.of("foo", new FieldInferenceMetadata.FieldInferenceOptions("bar", Set.of()))) + ); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } @@ -104,7 +106,7 @@ public void testInferenceNotFound() throws Exception { ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getFieldsInferenceMetadata()); + assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); for (BulkItemRequest item : bulkShardRequest.items()) { assertNotNull(item.getPrimaryResponse()); assertTrue(item.getPrimaryResponse().isFailed()); @@ -118,11 +120,15 @@ public void testInferenceNotFound() throws Exception { ActionListener actionListener = mock(ActionListener.class); Task task = mock(Task.class); - Map> inferenceFields = Map.of( - model.getInferenceEntityId(), - Set.of("field1"), - "inference_0", - Set.of("field2", "field3") + FieldInferenceMetadata inferenceFields = new FieldInferenceMetadata( + Map.of( + "field1", + new FieldInferenceMetadata.FieldInferenceOptions(model.getInferenceEntityId(), Set.of()), + "field2", + new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()), + "field3", + new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()) + ) ); BulkItemRequest[] items = new BulkItemRequest[10]; for (int i = 0; i < items.length; i++) { @@ -144,19 +150,19 @@ public void testManyRandomDocs() throws Exception { } int numInferenceFields = randomIntBetween(1, 5); - Map> inferenceFields = new HashMap<>(); + Map inferenceFieldsMap = new HashMap<>(); for (int i = 0; i < numInferenceFields; i++) { - String inferenceId = randomFrom(inferenceModelMap.keySet()); String field = randomAlphaOfLengthBetween(5, 10); - var res = 
inferenceFields.computeIfAbsent(inferenceId, k -> new HashSet<>()); - res.add(field); + String inferenceId = randomFrom(inferenceModelMap.keySet()); + inferenceFieldsMap.put(field, new FieldInferenceMetadata.FieldInferenceOptions(inferenceId, Set.of())); } + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(inferenceFieldsMap); int numRequests = randomIntBetween(100, 1000); BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; for (int id = 0; id < numRequests; id++) { - BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, inferenceFields); + BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, fieldInferenceMetadata); originalRequests[id] = res[0]; modifiedRequests[id] = res[1]; } @@ -167,7 +173,7 @@ public void testManyRandomDocs() throws Exception { try { assertThat(request, instanceOf(BulkShardRequest.class)); BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getFieldsInferenceMetadata()); + assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); BulkItemRequest[] items = bulkShardRequest.items(); assertThat(items.length, equalTo(originalRequests.length)); for (int id = 0; id < items.length; id++) { @@ -186,7 +192,7 @@ public void testManyRandomDocs() throws Exception { ActionListener actionListener = mock(ActionListener.class); Task task = mock(Task.class); BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); - original.setFieldInferenceMetadata(inferenceFields); + original.setFieldInferenceMetadata(fieldInferenceMetadata); filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } @@ -257,42 +263,40 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool private static BulkItemRequest[] randomBulkItemRequest( int id, Map modelMap, - Map> inferenceFieldMap + FieldInferenceMetadata fieldInferenceMetadata ) { Map docMap = new LinkedHashMap<>(); Map inferenceResultsMap = new LinkedHashMap<>(); - for (var entry : inferenceFieldMap.entrySet()) { - String inferenceId = entry.getKey(); - var model = modelMap.get(inferenceId); - for (var field : entry.getValue()) { - String text = randomAlphaOfLengthBetween(10, 100); - docMap.put(field, text); - if (model == null) { - // ignore results, the doc should fail with a resource not found exception - continue; - } - int numChunks = randomIntBetween(1, 5); - List chunks = new ArrayList<>(); - for (int i = 0; i < numChunks; i++) { - chunks.add(randomAlphaOfLengthBetween(5, 10)); - } - TaskType taskType = model.getTaskType(); - final ChunkedInferenceServiceResults results; - switch (taskType) { - case TEXT_EMBEDDING: - results = randomTextEmbeddings(chunks); - break; + for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { + String field = entry.getKey(); + var model = modelMap.get(entry.getValue().inferenceId()); + String text = randomAlphaOfLengthBetween(10, 100); + docMap.put(field, text); + if (model == null) { + // ignore results, the doc should fail with a resource not found exception + continue; + } + int numChunks = randomIntBetween(1, 5); + List chunks = new ArrayList<>(); + for (int i = 0; i < numChunks; i++) { + chunks.add(randomAlphaOfLengthBetween(5, 10)); + } + TaskType taskType = model.getTaskType(); + final 
ChunkedInferenceServiceResults results; + switch (taskType) { + case TEXT_EMBEDDING: + results = randomTextEmbeddings(chunks); + break; - case SPARSE_EMBEDDING: - results = randomSparseEmbeddings(chunks); - break; + case SPARSE_EMBEDDING: + results = randomSparseEmbeddings(chunks); + break; - default: - throw new AssertionError("Unknown task type " + taskType.name()); - } - model.putResult(text, results); - InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); + default: + throw new AssertionError("Unknown task type " + taskType.name()); } + model.putResult(text, results); + InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); } Map expectedDocMap = new LinkedHashMap<>(docMap); expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap); From cf62b1b79b5ccb6d369e400837a806c13d9b86b7 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 19 Mar 2024 22:02:59 +0100 Subject: [PATCH 05/26] Allow SemanticTextFieldMapper to be a multifield --- .../xpack/inference/mapper/SemanticTextFieldMapper.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 83272a10f98d4..504752a6fbd23 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -40,7 +40,7 @@ private static SemanticTextFieldMapper toType(FieldMapper in) { return (SemanticTextFieldMapper) in; } - public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n), notInMultiFields(CONTENT_TYPE)); + public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n)); private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) { super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); @@ -89,7 +89,7 @@ protected Parameter[] getParameters() { @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { - return new SemanticTextFieldMapper(name(), new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()), copyTo); + return new SemanticTextFieldMapper(name(), new SemanticTextFieldType(context.buildFullName(name()), modelId.getValue(), meta.getValue()), copyTo); } } From f029015c60247de1eb876ae917ad646a7816fbf5 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 19 Mar 2024 22:03:42 +0100 Subject: [PATCH 06/26] Add multifields / copy_to tests to lookup and metadata --- .../index/mapper/MultiFieldTests.java | 3 + .../SemanticTextClusterMetadataTests.java | 118 +++++++++++++++++- 2 files changed, 117 insertions(+), 4 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java index d7df41131414e..6446033c07c5b 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java @@ -224,6 +224,9 @@ public void testSourcePathFields() throws IOException { final Set fieldsUsingSourcePath = new HashSet<>(); ((FieldMapper) mapper).sourcePathUsedBy().forEachRemaining(mapper1 -> fieldsUsingSourcePath.add(mapper1.name())); 
assertThat(fieldsUsingSourcePath, equalTo(Set.of("field.subfield1", "field.subfield2"))); + + assertThat(mapperService.mappingLookup().sourcePaths("field.subfield1"), equalTo(Set.of("field"))); + assertThat(mapperService.mappingLookup().sourcePaths("field.subfield2"), equalTo(Set.of("field"))); } public void testUnknownLegacyFieldsUnderKnownRootField() throws Exception { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index a7d3fcce26116..19a9acb01ee8a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -20,6 +20,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Set; public class SemanticTextClusterMetadataTests extends ESSingleNodeTestCase { @@ -39,7 +40,7 @@ public void testCreateIndexWithSemanticTextField() { ); } - public void testAddSemanticTextField() throws Exception { + public void testSingleSourceSemanticTextField() throws Exception { final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test")); final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class); final MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor(); @@ -53,10 +54,119 @@ public void testAddSemanticTextField() throws Exception { putMappingExecutor, singleTask(request) ); - assertEquals( - resultingState.metadata().index("test").getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), - "test_model" + FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = resultingState.metadata().index("test").getFieldInferenceMetadata().getFieldInferenceOptions().get("field"); + assertEquals(fieldInferenceOptions.inferenceId(), "test_model"); + assertEquals(fieldInferenceOptions.sourceFields(), Set.of("field")); + } + + public void testMultiFieldsSemanticTextField() throws Exception { + final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test")); + final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class); + final MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor(); + final ClusterService clusterService = getInstanceFromNode(ClusterService.class); + + final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest(""" + { + "properties": { + "top_field": { + "type": "text", + "fields": { + "semantic": { + "type": "semantic_text", + "model_id": "test_model" + } + } + } + } + } + """); + request.indices(new Index[] { indexService.index() }); + final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful( + clusterService.state(), + putMappingExecutor, + singleTask(request) + ); + IndexMetadata indexMetadata = resultingState.metadata().index("test"); + FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata().getFieldInferenceOptions().get("top_field.semantic"); + assertEquals(fieldInferenceOptions.inferenceId(), "test_model"); + assertEquals(fieldInferenceOptions.sourceFields(), Set.of("top_field")); + } + + 
public void testCopyToSemanticTextField() throws Exception { + final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test")); + final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class); + final MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor(); + final ClusterService clusterService = getInstanceFromNode(ClusterService.class); + + final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest(""" + { + "properties": { + "semantic": { + "type": "semantic_text", + "model_id": "test_model" + }, + "copy_origin_1": { + "type": "text", + "copy_to": "semantic" + }, + "copy_origin_2": { + "type": "text", + "copy_to": "semantic" + } + } + } + """); + request.indices(new Index[] { indexService.index() }); + final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful( + clusterService.state(), + putMappingExecutor, + singleTask(request) + ); + IndexMetadata indexMetadata = resultingState.metadata().index("test"); + FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata().getFieldInferenceOptions().get("semantic"); + assertEquals(fieldInferenceOptions.inferenceId(), "test_model"); + assertEquals(fieldInferenceOptions.sourceFields(), Set.of("semantic", "copy_origin_1", "copy_origin_2")); + } + + public void testCopyToAndMultiFieldsSemanticTextField() throws Exception { + final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test")); + final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class); + final MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor(); + final ClusterService clusterService = getInstanceFromNode(ClusterService.class); + + final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest(""" + { + "properties": { + "top_field": { + "type": "text", + "fields": { + "semantic": { + "type": "semantic_text", + "model_id": "test_model" + } + } + }, + "copy_origin_1": { + "type": "text", + "copy_to": "top_field" + }, + "copy_origin_2": { + "type": "text", + "copy_to": "top_field" + } + } + } + """); + request.indices(new Index[] { indexService.index() }); + final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful( + clusterService.state(), + putMappingExecutor, + singleTask(request) + ); + IndexMetadata indexMetadata = resultingState.metadata().index("test"); + FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata().getFieldInferenceOptions().get("top_field.semantic"); + assertEquals(fieldInferenceOptions.inferenceId(), "test_model"); + assertEquals(fieldInferenceOptions.sourceFields(), Set.of("top_field", "copy_origin_1", "copy_origin_2")); + } private static List singleTask(PutMappingClusterStateUpdateRequest request) {
From d3f9d86c2ab2c4832ae01c6492a859026040b1c7 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 19 Mar 2024 22:04:19 +0100 Subject: [PATCH 07/26] First iteration for adding inference support for copy_to / multifields --- .../ShardBulkInferenceActionFilter.java | 56 ++++++++++--------- .../ShardBulkInferenceActionFilterTests.java | 10 ++-- 2 files changed, 35 insertions(+), 31 deletions(-)
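Note before the diff below: the loop being reworked here is driven by FieldInferenceMetadata.FieldInferenceOptions, which maps each destination field to its inference endpoint plus every source field that feeds it (the field itself, a multi-field parent, or copy_to origins). A self-contained sketch of that shape — the nested record mirrors the one used in these patches, but the surrounding class is illustrative only:

    import java.util.Map;
    import java.util.Set;

    class FieldInferenceMetadataSketch {
        // destination field -> (inference endpoint id, fields whose text feeds it)
        record FieldInferenceOptions(String inferenceId, Set<String> sourceFields) {}

        public static void main(String[] args) {
            Map<String, FieldInferenceOptions> options = Map.of(
                // plain semantic_text field with two copy_to origins
                "semantic", new FieldInferenceOptions("test_model", Set.of("semantic", "copy_origin_1", "copy_origin_2")),
                // semantic_text multi-field: the parent text field is the source
                "top_field.semantic", new FieldInferenceOptions("test_model", Set.of("top_field"))
            );
            System.out.println(options);
        }
    }

diff --git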
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 984a20419b2c8..35f830057a58f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -290,34 +290,38 @@ private Map> createFieldInferenceRequests(Bu } final Map docMap = indexRequest.sourceAsMap(); for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { - String field = entry.getKey(); String inferenceId = entry.getValue().inferenceId(); - var value = XContentMapValues.extractValue(field, docMap); - if (value == null) { - continue; - } - if (inferenceResults.get(item.id()) == null) { - inferenceResults.set( - item.id(), - new FieldInferenceResponseAccumulator( + for (var field : entry.getValue().sourceFields()) { + var value = XContentMapValues.extractValue(field, docMap); + if (value == null) { + continue; + } + if (inferenceResults.get(item.id()) == null) { + inferenceResults.set( item.id(), - Collections.synchronizedList(new ArrayList<>()), - Collections.synchronizedList(new ArrayList<>()) - ) - ); - } - if (value instanceof String valueStr) { - List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); - fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); - } else { - inferenceResults.get(item.id()).failures.add( - new ElasticsearchStatusException( - "Invalid format for field [{}], expected [String] got [{}]", - RestStatus.BAD_REQUEST, - field, - value.getClass().getSimpleName() - ) - ); + new FieldInferenceResponseAccumulator( + item.id(), + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ) + ); + } + if (value instanceof String valueStr) { + List fieldRequests = fieldRequestsMap.computeIfAbsent( + inferenceId, + k -> new ArrayList<>() + ); + fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + } else { + inferenceResults.get(item.id()).failures.add( + new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + value.getClass().getSimpleName() + ) + ); + } } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 4a1825303b5a7..4d816ae256e90 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -92,7 +92,7 @@ public void testFilterNoop() throws Exception { new BulkItemRequest[0] ); request.setFieldInferenceMetadata( - new FieldInferenceMetadata(Map.of("foo", new FieldInferenceMetadata.FieldInferenceOptions("bar", Set.of()))) + new FieldInferenceMetadata(Map.of("foo", new FieldInferenceMetadata.FieldInferenceOptions("bar", Set.of("foo")))) ); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 
10, TimeUnit.SECONDS); @@ -123,11 +123,11 @@ public void testInferenceNotFound() throws Exception { FieldInferenceMetadata inferenceFields = new FieldInferenceMetadata( Map.of( "field1", - new FieldInferenceMetadata.FieldInferenceOptions(model.getInferenceEntityId(), Set.of()), + new FieldInferenceMetadata.FieldInferenceOptions(model.getInferenceEntityId(), Set.of("field1")), "field2", - new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()), + new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of("field2")), "field3", - new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()) + new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of("field3")) ) ); BulkItemRequest[] items = new BulkItemRequest[10]; @@ -154,7 +154,7 @@ public void testManyRandomDocs() throws Exception { for (int i = 0; i < numInferenceFields; i++) { String field = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomFrom(inferenceModelMap.keySet()); - inferenceFieldsMap.put(field, new FieldInferenceMetadata.FieldInferenceOptions(inferenceId, Set.of())); + inferenceFieldsMap.put(field, new FieldInferenceMetadata.FieldInferenceOptions(inferenceId, Set.of(field))); } FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(inferenceFieldsMap); From 140caa3af27ca48611ded6d4979f26f7b8bb9244 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 19 Mar 2024 22:59:15 +0100 Subject: [PATCH 08/26] Add tests for copy_to --- .../org/elasticsearch/index/mapper/CopyToMapperTests.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java index 5eacfe6f2e3ab..33341e6b36987 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; import static org.hamcrest.Matchers.equalTo; @@ -106,6 +107,12 @@ public void testCopyToFieldsParsing() throws Exception { fieldMapper = mapperService.documentMapper().mappers().getMapper("new_field"); assertThat(fieldMapper.typeName(), equalTo("long")); + + MappingLookup mappingLookup = mapperService.mappingLookup(); + assertThat(mappingLookup.sourcePaths("another_field"), equalTo(Set.of("copy_test", "int_to_str_test", "another_field"))); + assertThat(mappingLookup.sourcePaths("new_field"), equalTo(Set.of("new_field", "int_to_str_test"))); + assertThat(mappingLookup.sourcePaths("copy_test"), equalTo(Set.of("copy_test", "cyclic_test"))); + assertThat(mappingLookup.sourcePaths("cyclic_test"), equalTo(Set.of("cyclic_test", "copy_test"))); } public void testCopyToFieldsInnerObjectParsing() throws Exception { From 7d1c92aae1f0aa60e7f8da99feeb015a2fb59108 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 19 Mar 2024 22:59:49 +0100 Subject: [PATCH 09/26] Spotless --- .../mapper/SemanticTextFieldMapper.java | 6 +++++- .../SemanticTextClusterMetadataTests.java | 18 ++++++++++++++---- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 
504752a6fbd23..e1e46aa228980 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -89,7 +89,11 @@ protected Parameter[] getParameters() { @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { - return new SemanticTextFieldMapper(name(), new SemanticTextFieldType(context.buildFullName(name()), modelId.getValue(), meta.getValue()), copyTo); + return new SemanticTextFieldMapper( + name(), + new SemanticTextFieldType(context.buildFullName(name()), modelId.getValue(), meta.getValue()), + copyTo + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index 19a9acb01ee8a..d36fcb4cf0a4e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -54,7 +54,11 @@ public void testSingleSourceSemanticTextField() throws Exception { putMappingExecutor, singleTask(request) ); - FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = resultingState.metadata().index("test").getFieldInferenceMetadata().getFieldInferenceOptions().get("field"); + FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = resultingState.metadata() + .index("test") + .getFieldInferenceMetadata() + .getFieldInferenceOptions() + .get("field"); assertEquals(fieldInferenceOptions.inferenceId(), "test_model"); assertEquals(fieldInferenceOptions.sourceFields(), Set.of("field")); } @@ -87,7 +91,9 @@ public void testMultiFieldsSemanticTextField() throws Exception { singleTask(request) ); IndexMetadata indexMetadata = resultingState.metadata().index("test"); - FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata().getFieldInferenceOptions().get("top_field.semantic"); + FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata() + .getFieldInferenceOptions() + .get("top_field.semantic"); assertEquals(fieldInferenceOptions.inferenceId(), "test_model"); assertEquals(fieldInferenceOptions.sourceFields(), Set.of("top_field")); } @@ -123,7 +129,9 @@ public void testCopyToSemanticTextField() throws Exception { singleTask(request) ); IndexMetadata indexMetadata = resultingState.metadata().index("test"); - FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata().getFieldInferenceOptions().get("semantic"); + FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata() + .getFieldInferenceOptions() + .get("semantic"); assertEquals(fieldInferenceOptions.inferenceId(), "test_model"); assertEquals(fieldInferenceOptions.sourceFields(), Set.of("semantic", "copy_origin_1", "copy_origin_2")); } @@ -164,7 +172,9 @@ public void testCopyToAndMultiFieldsSemanticTextField() throws Exception { singleTask(request) ); IndexMetadata indexMetadata = resultingState.metadata().index("test"); - FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata().getFieldInferenceOptions().get("top_field.semantic"); + 
FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata() + .getFieldInferenceOptions() + .get("top_field.semantic"); assertEquals(fieldInferenceOptions.inferenceId(), "test_model"); assertEquals(fieldInferenceOptions.sourceFields(), Set.of("top_field", "copy_origin_1", "copy_origin_2")); } From e023a19d5156bae7a6f546d0034a40901391e4d6 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 20 Mar 2024 09:15:45 +0100 Subject: [PATCH 10/26] Minor changes from previous PR --- .../index/mapper/FieldTypeLookupTests.java | 12 +++++----- .../SemanticTextClusterMetadataTests.java | 22 ++++++++++--------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java index 932eac3e60d27..2bfd1d9db385f 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java @@ -37,9 +37,9 @@ public void testEmpty() { assertNotNull(names); assertThat(names, hasSize(0)); - Map fieldsForModels = lookup.getInferenceIdsForFields(); - assertNotNull(fieldsForModels); - assertTrue(fieldsForModels.isEmpty()); + Map inferenceForFields = lookup.getInferenceIdsForFields(); + assertNotNull(inferenceForFields); + assertTrue(inferenceForFields.isEmpty()); } public void testAddNewField() { @@ -48,9 +48,9 @@ public void testAddNewField() { assertNull(lookup.get("bar")); assertEquals(f.fieldType(), lookup.get("foo")); - Map fieldsForModels = lookup.getInferenceIdsForFields(); - assertNotNull(fieldsForModels); - assertTrue(fieldsForModels.isEmpty()); + Map inferenceForFields = lookup.getInferenceIdsForFields(); + assertNotNull(inferenceForFields); + assertTrue(inferenceForFields.isEmpty()); } public void testAddFieldAlias() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index d36fcb4cf0a4e..13123433e32cc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -22,6 +22,8 @@ import java.util.List; import java.util.Set; +import static org.hamcrest.CoreMatchers.equalTo; + public class SemanticTextClusterMetadataTests extends ESSingleNodeTestCase { @Override @@ -34,9 +36,9 @@ public void testCreateIndexWithSemanticTextField() { "test", client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,model_id=test_model") ); - assertEquals( + assertThat( indexService.getMetadata().getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), - "test_model" + equalTo("test_model") ); } @@ -59,8 +61,8 @@ public void testSingleSourceSemanticTextField() throws Exception { .getFieldInferenceMetadata() .getFieldInferenceOptions() .get("field"); - assertEquals(fieldInferenceOptions.inferenceId(), "test_model"); - assertEquals(fieldInferenceOptions.sourceFields(), Set.of("field")); + assertThat(fieldInferenceOptions.inferenceId(), equalTo("test_model")); + assertThat(fieldInferenceOptions.sourceFields(), equalTo(Set.of("field"))); } public void testMultiFieldsSemanticTextField() throws Exception { @@ -94,8 
+96,8 @@ public void testMultiFieldsSemanticTextField() throws Exception { FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata() .getFieldInferenceOptions() .get("top_field.semantic"); - assertEquals(fieldInferenceOptions.inferenceId(), "test_model"); - assertEquals(fieldInferenceOptions.sourceFields(), Set.of("top_field")); + assertThat(fieldInferenceOptions.inferenceId(), equalTo("test_model")); + assertThat(fieldInferenceOptions.sourceFields(), equalTo(Set.of("top_field"))); } public void testCopyToSemanticTextField() throws Exception { @@ -132,8 +134,8 @@ public void testCopyToSemanticTextField() throws Exception { FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata() .getFieldInferenceOptions() .get("semantic"); - assertEquals(fieldInferenceOptions.inferenceId(), "test_model"); - assertEquals(fieldInferenceOptions.sourceFields(), Set.of("semantic", "copy_origin_1", "copy_origin_2")); + assertThat(fieldInferenceOptions.inferenceId(), equalTo("test_model")); + assertThat(fieldInferenceOptions.sourceFields(), equalTo(Set.of("semantic", "copy_origin_1", "copy_origin_2"))); } public void testCopyToAndMultiFieldsSemanticTextField() throws Exception { @@ -175,8 +177,8 @@ public void testCopyToAndMultiFieldsSemanticTextField() throws Exception { FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata() .getFieldInferenceOptions() .get("top_field.semantic"); - assertEquals(fieldInferenceOptions.inferenceId(), "test_model"); - assertEquals(fieldInferenceOptions.sourceFields(), Set.of("top_field", "copy_origin_1", "copy_origin_2")); + assertThat(fieldInferenceOptions.inferenceId(), equalTo("test_model")); + assertThat(fieldInferenceOptions.sourceFields(), equalTo(Set.of("top_field", "copy_origin_1", "copy_origin_2"))); } private static List singleTask(PutMappingClusterStateUpdateRequest request) { From 05aa06f88bc2741ee886a9e67fd744e976170812 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 21 Mar 2024 13:54:50 +0100 Subject: [PATCH 11/26] Add getMapper(field) to Mapper, so both field with multifields and object mappers can provide underlying field mappers --- .../elasticsearch/index/mapper/FieldAliasMapper.java | 5 +++++ .../org/elasticsearch/index/mapper/FieldMapper.java | 11 +++++++++++ .../java/org/elasticsearch/index/mapper/Mapper.java | 8 ++++++++ 3 files changed, 24 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldAliasMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldAliasMapper.java index 8aa29e6317d51..2cbdb79c4ce45 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldAliasMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldAliasMapper.java @@ -69,6 +69,11 @@ public Iterator iterator() { return Collections.emptyIterator(); } + @Override + public Mapper getMapper(String field) { + return null; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { return builder.startObject(simpleName()).field("type", CONTENT_TYPE).field(Names.PATH, path).endObject(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java index 71fd9edd49903..53985393aa42f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java +++ 
b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java @@ -306,6 +306,17 @@ public Iterator<Mapper> iterator() { return multiFieldsIterator(); } + @Override + public Mapper getMapper(String field) { + // a field mapper's only sub-mappers are its multi-fields + for (FieldMapper mapper : multiFields.mappers) { + if (mapper.simpleName().equals(field)) { + return mapper; + } + } + return null; + } + protected Iterator<Mapper> multiFieldsIterator() { return Iterators.forArray(multiFields.mappers); }
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/Mapper.java b/server/src/main/java/org/elasticsearch/index/mapper/Mapper.java index 7c047125a80d3..64d36c1928499 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/Mapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/Mapper.java @@ -145,4 +145,12 @@ public static FieldType freezeAndDeduplicateFieldType(FieldType fieldType) { * Defines how this mapper counts towards {@link MapperService#INDEX_MAPPING_TOTAL_FIELDS_LIMIT_SETTING}. */ public abstract int getTotalFieldsCount(); + + /** + * Returns the direct sub-mapper with the given simple name, or {@code null} if there is none. + * + * @param field simple name of the sub-mapper to look up + * @return the sub-mapper, or {@code null} if it does not exist + */ + public abstract Mapper getMapper(String field); }
From c80677f69d829a0da2b0608299645fd7f46299bb Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Mar 2024 17:11:29 +0100 Subject: [PATCH 12/26] multifields support --- .../ShardBulkInferenceActionFilter.java | 9 ++-- .../mapper/InferenceResultFieldMapper.java | 43 ++++++++++++++----- .../inference/10_semantic_text_inference.yml | 34 +++++++++++++++ 3 files changed, 72 insertions(+), 14 deletions(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 35f830057a58f..792c6d9945021 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -291,8 +291,9 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu final Map<String, Object> docMap = indexRequest.sourceAsMap(); for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { String inferenceId = entry.getValue().inferenceId(); - for (var field : entry.getValue().sourceFields()) { - var value = XContentMapValues.extractValue(field, docMap); + String fieldName = entry.getKey(); + for (var sourceField : entry.getValue().sourceFields()) { + var value = XContentMapValues.extractValue(sourceField, docMap); if (value == null) { continue; } @@ -311,13 +312,13 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu inferenceId, k -> new ArrayList<>() ); - fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + fieldRequests.add(new FieldInferenceRequest(item.id(), fieldName, valueStr)); } else { inferenceResults.get(item.id()).failures.add( new ElasticsearchStatusException( "Invalid format for field [{}], expected [String] got [{}]", RestStatus.BAD_REQUEST, - field, + fieldName, value.getClass().getSimpleName() ) ); } } }
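Note between the two diffs of this patch: the filter change above keys every request on the destination field while extracting text from each configured source field, so multi-field and copy_to inputs all accumulate under one result entry; the mapper change below then resolves those dotted destination names (for example "top_field.semantic") via the new Mapper#getMapper. A rough, self-contained sketch of the request fan-out, with simplified stand-ins for the Elasticsearch types:

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;

    class InferenceFanOutSketch {
        record Options(String inferenceId, Set<String> sourceFields) {} // mirrors FieldInferenceOptions
        record FieldInferenceRequest(int itemId, String field, String input) {}

        static Map<String, List<FieldInferenceRequest>> fanOut(int itemId, Map<String, Object> doc, Map<String, Options> metadata) {
            Map<String, List<FieldInferenceRequest>> requestsPerEndpoint = new HashMap<>();
            for (var entry : metadata.entrySet()) {
                String destinationField = entry.getKey();
                for (String sourceField : entry.getValue().sourceFields()) {
                    // stand-in for XContentMapValues.extractValue(sourceField, docMap)
                    if (doc.get(sourceField) instanceof String text) {
                        requestsPerEndpoint.computeIfAbsent(entry.getValue().inferenceId(), k -> new ArrayList<>())
                            .add(new FieldInferenceRequest(itemId, destinationField, text));
                    }
                }
            }
            return requestsPerEndpoint;
        }
    }

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java index 2ede5419ab74e..4988734400330 100644 ---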
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java @@ -148,10 +148,15 @@ public InferenceResultFieldMapper() { @Override protected void parseCreateField(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - parseAllFields(context); + boolean withinLeafObject = context.path().isWithinLeafObject(); + try { + context.path().setWithinLeafObject(true); + XContentParser parser = context.parser(); + failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); + parseAllFields(context); + } finally { + context.path().setWithinLeafObject(withinLeafObject); + } } private static void parseAllFields(DocumentParserContext context) throws IOException { @@ -164,12 +169,18 @@ private static void parseAllFields(DocumentParserContext context) throws IOExcep } } - private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { + private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) + throws IOException { XContentParser parser = context.parser(); String fieldName = parser.currentName(); - Mapper mapper = context.getMapper(fieldName); - if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { + Mapper mapper = findMapper(context.root(), fieldName); + if (mapper == null) { + throw new DocumentParsingException( + parser.getTokenLocation(), + Strings.format("Field [%s] is not registered as a field type", fieldName) + ); + } else if (SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { throw new DocumentParsingException( parser.getTokenLocation(), Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) @@ -190,7 +201,7 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde fieldName, modelSettings ); - parseFieldInferenceChunks(context, mapperBuilderContext, fieldName, modelSettings, nestedObjectMapper); + parseFieldInferenceChunks(context, modelSettings, nestedObjectMapper); } else { logger.debug("Skipping unrecognized field name [" + currentName + "]"); advancePastCurrentFieldName(parser); @@ -198,10 +209,20 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde } } + private static Mapper findMapper(Mapper mapper, String fullPath) { + String[] pathElements = fullPath.split("\\."); + for (int i = 0; i < pathElements.length; i++) { + Mapper next = mapper.getMapper(pathElements[i]); + if (next == null) { + return null; + } + mapper = next; + } + return mapper; + } + private static void parseFieldInferenceChunks( DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, SemanticTextModelSettings modelSettings, NestedObjectMapper nestedObjectMapper ) throws IOException { @@ -243,6 +264,8 @@ private static void parseFieldInferenceChunkElement( if (childMapper instanceof FieldMapper fieldMapper) { parser.nextToken(); fieldMapper.parse(childContext); + // Reset leaf object after parsing the field + context.path().setWithinLeafObject(true); } else { // This should never happen, but fail parsing if it does so that it's not a silent failure throw new DocumentParsingException( diff --git 
a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 6008ebbcbedf8..b22df64e83622 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -309,3 +309,37 @@ setup: id: doc_1 body: non_inference_field: "non inference test" + + +--- +"semantic_text multifields calculate inference for parent field": + - do: + indices.create: + index: test-multifield-index + body: + mappings: + properties: + top_level_field: + type: text + fields: + semantic_multifield: + type: semantic_text + model_id: dense-inference-id + - do: + index: + index: test-multifield-index + id: doc_1 + body: + top_level_field: "multifield inference test" + + - do: + get: + index: test-multifield-index + id: doc_1 + + - match: { _source.top_level_field: "multifield inference test" } + - match: { _source._inference.top_level_field\.semantic_multifield.results.0.text: "multifield inference test" } + - exists: _source._inference.top_level_field\.semantic_multifield.results.0.inference + - exists: _source._inference.top_level_field\.semantic_multifield.results.0.inference + +
From 5896c6022e1ba4658ac7c1a5ca4953e176f26533 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Mar 2024 17:37:21 +0100 Subject: [PATCH 13/26] Add copy_to support --- .../mapper/InferenceResultFieldMapper.java | 8 +++-- .../inference/10_semantic_text_inference.yml | 35 +++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java index 4988734400330..dcb5c86b5eff7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java @@ -363,6 +363,7 @@ public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { return SourceLoader.SyntheticFieldLoader.NOTHING; } + @SuppressWarnings("unchecked") public static void applyFieldInference( Map<String, Object> inferenceMap, String field, @@ -387,9 +388,10 @@ public static void applyFieldInference( results.getWriteableName() ); } - Map<String, Object> fieldMap = new LinkedHashMap<>(); + + Map<String, Object> fieldMap = (Map<String, Object>) inferenceMap.computeIfAbsent(field, s -> new LinkedHashMap<>()); fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - fieldMap.put(InferenceResultFieldMapper.RESULTS, chunks); - inferenceMap.put(field, fieldMap); + List<Map<String, Object>> fieldChunks = (List<Map<String, Object>>) fieldMap.computeIfAbsent(InferenceResultFieldMapper.RESULTS, k -> new ArrayList<>()); + fieldChunks.addAll(chunks); } }
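Note on the applyFieldInference change above: the per-field "_inference" entry is now fetched or created and its results list appended to, instead of being rebuilt from scratch, which is what lets a second source (for example a copy_to origin) add its chunks after the ones already stored. A condensed sketch of just that accumulation over plain collections (the "results" key matches the RESULTS constant; everything else is illustrative):

    import java.util.ArrayList;
    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;

    class InferenceAccumulatorSketch {
        @SuppressWarnings("unchecked")
        static void append(Map<String, Object> inferenceMap, String field, List<Map<String, Object>> chunks) {
            Map<String, Object> fieldMap = (Map<String, Object>) inferenceMap.computeIfAbsent(field, k -> new LinkedHashMap<>());
            List<Map<String, Object>> results = (List<Map<String, Object>>) fieldMap.computeIfAbsent("results", k -> new ArrayList<>());
            results.addAll(chunks); // chunks from "copy_to inference test" land after those from "inference test"
        }
    }

diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index b22df64e83622..7f3107110c621 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++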
b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -343,3 +343,38 @@ setup: - exists: _source._inference.top_level_field\.semantic_multifield.results.0.inference +--- +"semantic_text copy_to calculate inference for source fields": + - do: + indices.create: + index: test-copy-to-index + body: + mappings: + properties: + inference_field: + type: semantic_text + model_id: dense-inference-id + source_field: + type: text + copy_to: inference_field + + - do: + index: + index: test-copy-to-index + id: doc_1 + body: + source_field: "copy_to inference test" + inference_field: "inference test" + + - do: + get: + index: test-copy-to-index + id: doc_1 + + - match: { _source.inference_field: "inference test" } + - length: {_source._inference.inference_field.results: 2} + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.inference_field.results.1.text: "copy_to inference test" } + + - exists: _source._inference.inference_field.results.0.inference + - exists: _source._inference.inference_field.results.1.inference From df0cc9006290b58e686639e72308c97dc9bf426b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Mar 2024 17:37:57 +0100 Subject: [PATCH 14/26] Spotless --- .../inference/mapper/InferenceResultFieldMapper.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java index dcb5c86b5eff7..3e7d7ecdb69ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java @@ -169,8 +169,7 @@ private static void parseAllFields(DocumentParserContext context) throws IOExcep } } - private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) - throws IOException { + private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { XContentParser parser = context.parser(); String fieldName = parser.currentName(); @@ -180,7 +179,7 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde parser.getTokenLocation(), Strings.format("Field [%s] is not registered as a field type", fieldName) ); - } else if (SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { + } else if (SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { throw new DocumentParsingException( parser.getTokenLocation(), Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) @@ -391,7 +390,10 @@ public static void applyFieldInference( Map fieldMap = (Map) inferenceMap.computeIfAbsent(field, s -> new LinkedHashMap<>()); fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - List> fieldChunks = (List>) fieldMap.computeIfAbsent(InferenceResultFieldMapper.RESULTS, k -> new ArrayList<>()); + List> fieldChunks = (List>) fieldMap.computeIfAbsent( + InferenceResultFieldMapper.RESULTS, + k -> new ArrayList<>() + ); fieldChunks.addAll(chunks); } } From 6d4bbf3a062dd8545663280c3b40115dbecffc14 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Mar 2024 
17:42:29 +0100 Subject: [PATCH 15/26] Fix tests --- .../inference/mapper/InferenceResultFieldMapper.java | 7 +------ .../mapper/SemanticTextFieldMapperTests.java | 12 ------------ .../test/inference/10_semantic_text_inference.yml | 5 ++--- 3 files changed, 3 insertions(+), 21 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java index 3e7d7ecdb69ea..ffb342fea966f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java @@ -174,12 +174,7 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde XContentParser parser = context.parser(); String fieldName = parser.currentName(); Mapper mapper = findMapper(context.root(), fieldName); - if (mapper == null) { - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Field [%s] is not registered as a field type", fieldName) - ); - } else if (SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { + if ((mapper == null) || (SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false)) { throw new DocumentParsingException( parser.getTokenLocation(), Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index a3a705c9cc902..c2aefecebbace 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -49,18 +49,6 @@ public void testModelIdNotPresent() throws IOException { assertThat(e.getMessage(), containsString("field [model_id] must be specified")); } - public void testCannotBeUsedInMultiFields() { - Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { - b.field("type", "text"); - b.startObject("fields"); - b.startObject("semantic"); - b.field("type", "semantic_text"); - b.endObject(); - b.endObject(); - }))); - assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); - } - public void testUpdatesToModelIdNotSupported() throws IOException { MapperService mapperService = createMapperService( fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "test_model")) diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 7f3107110c621..13b4d922dfc43 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -373,8 +373,7 @@ setup: - match: { _source.inference_field: "inference test" } - length: 
{_source._inference.inference_field.results: 2} - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.inference_field.results.1.text: "copy_to inference test" } - - exists: _source._inference.inference_field.results.0.inference + - exists: _source._inference.inference_field.results.0.text - exists: _source._inference.inference_field.results.1.inference + - exists: _source._inference.inference_field.results.1.text From 068615af6fc39313379fb69a5ae60ca5ddec5d89 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Mar 2024 18:03:46 +0100 Subject: [PATCH 16/26] Remove the need to get mappings by using MappingLookup.getFieldType() --- .../mapper/InferenceResultFieldMapper.java | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java index ffb342fea966f..cb9b87f9935a2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java @@ -173,8 +173,8 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde XContentParser parser = context.parser(); String fieldName = parser.currentName(); - Mapper mapper = findMapper(context.root(), fieldName); - if ((mapper == null) || (SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false)) { + MappedFieldType fieldType = context.mappingLookup().getFieldType(fieldName); + if ((fieldType == null) || (SemanticTextFieldMapper.CONTENT_TYPE.equals(fieldType.typeName()) == false)) { throw new DocumentParsingException( parser.getTokenLocation(), Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) @@ -203,18 +203,6 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde } } - private static Mapper findMapper(Mapper mapper, String fullPath) { - String[] pathElements = fullPath.split("\\."); - for (int i = 0; i < pathElements.length; i++) { - Mapper next = mapper.getMapper(pathElements[i]); - if (next == null) { - return null; - } - mapper = next; - } - return mapper; - } - private static void parseFieldInferenceChunks( DocumentParserContext context, SemanticTextModelSettings modelSettings, From 5c3d9c450b702062f7e72c76045b22ba8e097816 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Mar 2024 18:04:06 +0100 Subject: [PATCH 17/26] Revert "Add getMapper(field) to Mapper, so both field with multifields and object mappers can provide underlying field mappers" This reverts commit 05aa06f88bc2741ee886a9e67fd744e976170812. 
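Context for this revert: PATCH 16 switched the _inference metadata mapper to
MappingLookup#getFieldType, which already resolves full dotted paths (for
example "top_level_field.semantic_multifield"), so the tree-walking
Mapper#getMapper API is no longer needed. The sketch below is illustrative
only, not part of this patch: it assumes Elasticsearch's internal mapper
types, the class and method names are hypothetical, and the walk() variant
compiles only against the pre-revert API.

    import org.elasticsearch.index.mapper.MappedFieldType;
    import org.elasticsearch.index.mapper.Mapper;
    import org.elasticsearch.index.mapper.MappingLookup;

    final class FieldLookupSketch {
        // Pre-revert approach: walk the mapper tree one path segment at a
        // time, using the Mapper#getMapper(String) method this commit removes.
        static Mapper walk(Mapper root, String fullPath) {
            Mapper current = root;
            for (String segment : fullPath.split("\\.")) {
                current = current.getMapper(segment);
                if (current == null) {
                    return null; // unknown path segment
                }
            }
            return current;
        }

        // Post-PATCH-16 approach: MappingLookup indexes every field type by
        // its full dotted path, so a single lookup replaces the traversal.
        static boolean isSemanticText(MappingLookup lookup, String fullPath) {
            MappedFieldType fieldType = lookup.getFieldType(fullPath);
            return fieldType != null && "semantic_text".equals(fieldType.typeName());
        }
    }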
--- .../elasticsearch/index/mapper/FieldAliasMapper.java | 5 ----- .../org/elasticsearch/index/mapper/FieldMapper.java | 11 ----------- .../java/org/elasticsearch/index/mapper/Mapper.java | 8 -------- 3 files changed, 24 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldAliasMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldAliasMapper.java index 2cbdb79c4ce45..8aa29e6317d51 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldAliasMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldAliasMapper.java @@ -69,11 +69,6 @@ public Iterator iterator() { return Collections.emptyIterator(); } - @Override - public Mapper getMapper(String field) { - return null; - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { return builder.startObject(simpleName()).field("type", CONTENT_TYPE).field(Names.PATH, path).endObject(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java index 53985393aa42f..71fd9edd49903 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java @@ -306,17 +306,6 @@ public Iterator iterator() { return multiFieldsIterator(); } - @Override - public Mapper getMapper(String field) { - while (iterator().hasNext()) { - Mapper mapper = iterator().next(); - if (mapper.simpleName().equals(field)) { - return mapper; - } - } - return null; - } - protected Iterator multiFieldsIterator() { return Iterators.forArray(multiFields.mappers); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/Mapper.java b/server/src/main/java/org/elasticsearch/index/mapper/Mapper.java index 64d36c1928499..7c047125a80d3 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/Mapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/Mapper.java @@ -145,12 +145,4 @@ public static FieldType freezeAndDeduplicateFieldType(FieldType fieldType) { * Defines how this mapper counts towards {@link MapperService#INDEX_MAPPING_TOTAL_FIELDS_LIMIT_SETTING}. */ public abstract int getTotalFieldsCount(); - - /** - * Returns a submapper for this mapper, if it exists. 
- * - * @param field field name from which to obtain the mapper - * @return submapper - */ - public abstract Mapper getMapper(String field); } From ace17dd67473cde7111619a9d6444dd156857222 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 26 Mar 2024 13:31:25 +0100 Subject: [PATCH 18/26] PR feedback --- .../xpack/inference/mapper/InferenceResultFieldMapper.java | 1 + .../test/inference/10_semantic_text_inference.yml | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java index cb9b87f9935a2..a0065e2f1cf77 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java @@ -150,6 +150,7 @@ public InferenceResultFieldMapper() { protected void parseCreateField(DocumentParserContext context) throws IOException { boolean withinLeafObject = context.path().isWithinLeafObject(); try { + // Disable dot expansion so there is no need to traverse subobjects for retrieving the field type context.path().setWithinLeafObject(true); XContentParser parser = context.parser(); failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 13b4d922dfc43..70ea12a45371f 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -337,9 +337,9 @@ setup: index: test-multifield-index id: doc_1 - - match: { _source.top_level_field: "multifield inference test" } - - match: { _source._inference.top_level_field\.semantic_multifield.results.0.text: "multifield inference test" } - - exists: _source._inference.top_level_field\.semantic_multifield.results.0.inference + - match: { _source.top_level_field: "multifield inference test" } + - length: { _source._inference.top_level_field\.semantic_multifield.results: 1 } + - match: { _source._inference.top_level_field\.semantic_multifield.results.0.text: "multifield inference test" } - exists: _source._inference.top_level_field\.semantic_multifield.results.0.inference From 47a22d7aee1db05372106f01d8fb344474948e9e Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 26 Mar 2024 16:40:19 +0100 Subject: [PATCH 19/26] Fix merge with feature branch --- .../ShardBulkInferenceActionFilter.java | 67 +-- .../mapper/InferenceMetadataFieldMapper.java | 12 +- .../mapper/InferenceResultFieldMapper.java | 383 ------------- .../mapper/SemanticTextFieldMapper.java | 5 +- .../SemanticTextClusterMetadataTests.java | 6 +- .../InferenceResultFieldMapperTests.java | 527 ------------------ .../mapper/SemanticTextFieldMapperTests.java | 12 - 7 files changed, 47 insertions(+), 965 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index a4d3f1f64cace..b12ed5dfad2e5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -301,42 +301,45 @@ private Map> createFieldInferenceRequests(Bu String fieldName = entry.getKey(); for (var sourceField : entry.getValue().sourceFields()) { - var value = XContentMapValues.extractValue(sourceField, docMap); - if (value == null) { - continue; - } - if (inferenceResults.get(item.id()) == null) { - inferenceResults.set( - item.id(), - new FieldInferenceResponseAccumulator( + var value = XContentMapValues.extractValue(sourceField, docMap); + if (value == null) { + continue; + } + if (inferenceResults.get(item.id()) == null) { + inferenceResults.set( item.id(), - Collections.synchronizedList(new ArrayList<>()), - Collections.synchronizedList(new ArrayList<>()) - ) - ); - } - if (value instanceof String valueStr) { - List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); - fieldRequests.add(new FieldInferenceRequest(item.id(), fieldName, valueStr)); - hasInput = true; - } else { - inferenceResults.get(item.id()).failures.add( - new ElasticsearchStatusException( - "Invalid format for field [{}], expected [String] got [{}]", - RestStatus.BAD_REQUEST, - fieldName, - value.getClass().getSimpleName() - ) - ); + new FieldInferenceResponseAccumulator( + item.id(), + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ) + ); + } + if (value instanceof String valueStr) { + List fieldRequests = fieldRequestsMap.computeIfAbsent( + inferenceId, + k -> new ArrayList<>() + ); + fieldRequests.add(new FieldInferenceRequest(item.id(), fieldName, valueStr)); + hasInput = true; + } else { + inferenceResults.get(item.id()).failures.add( + new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + fieldName, + value.getClass().getSimpleName() + ) + ); + } } - } - if (hasInput == false) { - // remove the existing _inference field (if present) since none of the content require inference. - if (docMap.remove(InferenceMetadataFieldMapper.NAME) != null) { - indexRequest.source(docMap); + if (hasInput == false) { + // remove the existing _inference field (if present) since none of the content require inference. 
+ if (docMap.remove(InferenceMetadataFieldMapper.NAME) != null) { + indexRequest.source(docMap); + } } } - } } return fieldRequestsMap; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java index 9eeb7a5407bc4..1b370ead21a2e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -345,6 +345,8 @@ private void parseResultsObject( } parser.nextToken(); fieldMapper.parse(context); + // Reset leaf object after parsing the field + context.path().setWithinLeafObject(true); } if (visited.containsAll(REQUIRED_SUBFIELDS) == false) { Set missingSubfields = REQUIRED_SUBFIELDS.stream() @@ -380,6 +382,7 @@ public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { return SourceLoader.SyntheticFieldLoader.NOTHING; } + @SuppressWarnings("unchecked") public static void applyFieldInference( Map inferenceMap, String field, @@ -404,11 +407,12 @@ public static void applyFieldInference( results.getWriteableName() ); } - Map fieldMap = new LinkedHashMap<>(); - fieldMap.put(INFERENCE_ID, model.getInferenceEntityId()); + + Map fieldMap = (Map) inferenceMap.computeIfAbsent(field, s -> new LinkedHashMap<>()); fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - fieldMap.put(CHUNKS, chunks); - inferenceMap.put(field, fieldMap); + List> fieldChunks = (List>) fieldMap.computeIfAbsent(CHUNKS, k -> new ArrayList<>()); + fieldChunks.addAll(chunks); + fieldMap.put(INFERENCE_ID, model.getInferenceEntityId()); } record SemanticTextMapperContext(MapperBuilderContext context, SemanticTextFieldMapper mapper) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java deleted file mode 100644 index a0065e2f1cf77..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java +++ /dev/null @@ -1,383 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.search.Query; -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.common.Strings; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.mapper.DocumentParserContext; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MapperBuilderContext; -import org.elasticsearch.index.mapper.MetadataFieldMapper; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ObjectMapper; -import org.elasticsearch.index.mapper.SourceLoader; -import org.elasticsearch.index.mapper.SourceValueFetcher; -import org.elasticsearch.index.mapper.TextFieldMapper; -import org.elasticsearch.index.mapper.TextSearchInfo; -import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * A mapper for the {@code _semantic_text_inference} field. - *
- *
- * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results. - * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so: - *
- *
- * <pre>
- * {
- *     "_source": {
- *         "my_semantic_text_field": "these are not the droids you're looking for",
- *         "_inference": {
- *             "my_semantic_text_field": [
- *                 {
- *                     "sparse_embedding": {
- *                          "lucas": 0.05212344,
- *                          "ty": 0.041213956,
- *                          "dragon": 0.50991,
- *                          "type": 0.23241979,
- *                          "dr": 1.9312073,
- *                          "##o": 0.2797593
- *                     },
- *                     "text": "these are not the droids you're looking for"
- *                 }
- *             ]
- *         }
- *     }
- * }
- * </pre>
- * - * This mapper parses the contents of the {@code _semantic_text_inference} field and indexes it as if the mapping were configured like so: - *
- *
- * <pre>
- * {
- *     "mappings": {
- *         "properties": {
- *             "my_semantic_text_field": {
- *                 "type": "nested",
- *                 "properties": {
- *                     "sparse_embedding": {
- *                         "type": "sparse_vector"
- *                     },
- *                     "text": {
- *                         "type": "text",
- *                         "index": false
- *                     }
- *                 }
- *             }
- *         }
- *     }
- * }
- * </pre>
- */ -public class InferenceResultFieldMapper extends MetadataFieldMapper { - public static final String NAME = "_inference"; - public static final String CONTENT_TYPE = "_inference"; - - public static final String RESULTS = "results"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceResultFieldMapper()); - - private static final Logger logger = LogManager.getLogger(InferenceResultFieldMapper.class); - - private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); - - static class SemanticTextInferenceFieldType extends MappedFieldType { - private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); - - SemanticTextInferenceFieldType() { - super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.identity(name(), context, format); - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - return null; - } - } - - public InferenceResultFieldMapper() { - super(SemanticTextInferenceFieldType.INSTANCE); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - boolean withinLeafObject = context.path().isWithinLeafObject(); - try { - // Disable dot expansion so there is no need to traverse subobjects for retrieving the field type - context.path().setWithinLeafObject(true); - XContentParser parser = context.parser(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - parseAllFields(context); - } finally { - context.path().setWithinLeafObject(withinLeafObject); - } - } - - private static void parseAllFields(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - parseSingleField(context, mapperBuilderContext); - } - } - - private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { - - XContentParser parser = context.parser(); - String fieldName = parser.currentName(); - MappedFieldType fieldType = context.mappingLookup().getFieldType(fieldName); - if ((fieldType == null) || (SemanticTextFieldMapper.CONTENT_TYPE.equals(fieldType.typeName()) == false)) { - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) - ); - } - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - parser.nextToken(); - SemanticTextModelSettings modelSettings = SemanticTextModelSettings.parse(parser); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - if (RESULTS.equals(currentName)) { - NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( - 
context, - mapperBuilderContext, - fieldName, - modelSettings - ); - parseFieldInferenceChunks(context, modelSettings, nestedObjectMapper); - } else { - logger.debug("Skipping unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - } - } - } - - private static void parseFieldInferenceChunks( - DocumentParserContext context, - SemanticTextModelSettings modelSettings, - NestedObjectMapper nestedObjectMapper - ) throws IOException { - XContentParser parser = context.parser(); - - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_ARRAY); - - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { - DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper); - parseFieldInferenceChunkElement(nestedContext, nestedObjectMapper, modelSettings); - } - } - - private static void parseFieldInferenceChunkElement( - DocumentParserContext context, - ObjectMapper objectMapper, - SemanticTextModelSettings modelSettings - ) throws IOException { - XContentParser parser = context.parser(); - DocumentParserContext childContext = context.createChildContext(objectMapper); - - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - Set visitedSubfields = new HashSet<>(); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - visitedSubfields.add(currentName); - - Mapper childMapper = objectMapper.getMapper(currentName); - if (childMapper == null) { - logger.debug("Skipping indexing of unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - continue; - } - - if (childMapper instanceof FieldMapper fieldMapper) { - parser.nextToken(); - fieldMapper.parse(childContext); - // Reset leaf object after parsing the field - context.path().setWithinLeafObject(true); - } else { - // This should never happen, but fail parsing if it does so that it's not a silent failure - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Unhandled mapper type [%s] for field [%s]", childMapper.getClass(), currentName) - ); - } - } - - if (visitedSubfields.containsAll(REQUIRED_SUBFIELDS) == false) { - Set missingSubfields = REQUIRED_SUBFIELDS.stream() - .filter(s -> visitedSubfields.contains(s) == false) - .collect(Collectors.toSet()); - throw new DocumentParsingException(parser.getTokenLocation(), "Missing required subfields: " + missingSubfields); - } - } - - private static NestedObjectMapper createInferenceResultsObjectMapper( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings - ) { - IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated(); - FieldMapper.Builder resultsBuilder; - if (modelSettings.taskType() == TaskType.SPARSE_EMBEDDING) { - resultsBuilder = new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); - } else if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING) { - DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( - INFERENCE_CHUNKS_RESULTS, - indexVersionCreated - ); - SimilarityMeasure similarity = modelSettings.similarity(); - if (similarity != null) { - switch (similarity) { - case COSINE -> 
denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); - case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); - default -> throw new IllegalArgumentException( - "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity - ); - } - } - Integer dimensions = modelSettings.dimensions(); - if (dimensions == null) { - throw new IllegalArgumentException("Model settings for field [" + fieldName + "] must contain dimensions"); - } - denseVectorMapperBuilder.dimensions(dimensions); - resultsBuilder = denseVectorMapperBuilder; - } else { - throw new IllegalArgumentException("Unknown task type for field [" + fieldName + "]: " + modelSettings.taskType()); - } - - TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( - INFERENCE_CHUNKS_TEXT, - indexVersionCreated, - context.indexAnalyzers() - ).index(false).store(false); - - NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder( - fieldName, - context.indexSettings().getIndexVersionCreated() - ); - nestedBuilder.add(resultsBuilder).add(textMapperBuilder); - - return nestedBuilder.build(mapperBuilderContext); - } - - private static void advancePastCurrentFieldName(XContentParser parser) throws IOException { - assert parser.currentToken() == XContentParser.Token.FIELD_NAME; - - XContentParser.Token token = parser.nextToken(); - if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { - parser.skipChildren(); - } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_* or VALUE_*, got " + token); - } - } - - private static void failIfTokenIsNot(XContentParser parser, XContentParser.Token expected) { - if (parser.currentToken() != expected) { - throw new DocumentParsingException( - parser.getTokenLocation(), - "Expected a " + expected.toString() + ", got " + parser.currentToken() - ); - } - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - return SourceLoader.SyntheticFieldLoader.NOTHING; - } - - @SuppressWarnings("unchecked") - public static void applyFieldInference( - Map inferenceMap, - String field, - Model model, - ChunkedInferenceServiceResults results - ) throws ElasticsearchException { - List> chunks = new ArrayList<>(); - if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add(chunk.asMap()); - } - } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add(chunk.asMap()); - } - } else { - throw new ElasticsearchStatusException( - "Invalid inference results format for field [{}] with inference id [{}], got {}", - RestStatus.BAD_REQUEST, - field, - model.getInferenceEntityId(), - results.getWriteableName() - ); - } - - Map fieldMap = (Map) inferenceMap.computeIfAbsent(field, s -> new LinkedHashMap<>()); - fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - List> fieldChunks = (List>) fieldMap.computeIfAbsent( - InferenceResultFieldMapper.RESULTS, - k -> new ArrayList<>() - ); - fieldChunks.addAll(chunks); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 2445d5c8751a5..63ecb2a7e440e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -60,10 +60,7 @@ private static SemanticTextFieldMapper toType(FieldMapper in) { return (SemanticTextFieldMapper) in; } - public static final TypeParser PARSER = new TypeParser( - (n, c) -> new Builder(n, c.indexVersionCreated()), - notInMultiFields(CONTENT_TYPE) - ); + public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n, c.indexVersionCreated())); private final IndexVersion indexVersionCreated; private final SemanticTextModelSettings modelSettings; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index 39d94be7e509a..29c1e8ded0601 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -79,7 +79,7 @@ public void testMultiFieldsSemanticTextField() throws Exception { "fields": { "semantic": { "type": "semantic_text", - "model_id": "test_model" + "inference_id": "test_model" } } } @@ -111,7 +111,7 @@ public void testCopyToSemanticTextField() throws Exception { "properties": { "semantic": { "type": "semantic_text", - "model_id": "test_model" + "inference_id": "test_model" }, "copy_origin_1": { "type": "text", @@ -152,7 +152,7 @@ public void testCopyToAndMultiFieldsSemanticTextField() throws Exception { "fields": { "semantic": { "type": "semantic_text", - "model_id": "test_model" + "inference_id": "test_model" } } }, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java deleted file mode 100644 index b5d75b528c6ab..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java +++ /dev/null @@ -1,527 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.index.Term; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.TermQuery; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.QueryBitSetProducer; -import org.apache.lucene.search.join.ScoreMode; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.lucene.search.Queries; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.IndexVersions; -import org.elasticsearch.index.mapper.DocumentMapper; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.LuceneDocument; -import org.elasticsearch.index.mapper.MapperService; -import org.elasticsearch.index.mapper.MetadataMapperTestCase; -import org.elasticsearch.index.mapper.NestedLookup; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ParsedDocument; -import org.elasticsearch.index.search.ESToParentBlockJoinQuery; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.search.LeafNestedDocuments; -import org.elasticsearch.search.NestedDocuments; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; -import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.model.TestModel; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.Consumer; - -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.RESULTS; -import static org.hamcrest.Matchers.containsString; - -public class InferenceResultFieldMapperTests extends MetadataMapperTestCase { - private record SemanticTextInferenceResults(String fieldName, ChunkedInferenceServiceResults results, List text) {} - - private record VisitedChildDocInfo(String path, int numChunks) {} - - private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {} - - @Override - protected String fieldName() { - return InferenceResultFieldMapper.NAME; - } - - @Override - protected boolean isConfigurable() { - return false; - } - - @Override - protected boolean isSupportedOn(IndexVersion version) { - return version.onOrAfter(IndexVersions.ES_VERSION_8_12_1); // TODO: Switch to ES_VERSION_8_14 when available - } - - @Override - protected void 
registerParameters(ParameterChecker checker) throws IOException { - - } - - @Override - protected Collection getPlugins() { - return List.of(new InferencePlugin(Settings.EMPTY)); - } - - public void testSuccessfulParse() throws IOException { - final String fieldName1 = randomAlphaOfLengthBetween(5, 15); - final String fieldName2 = randomAlphaOfLengthBetween(5, 15); - - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> { - addSemanticTextMapping(b, fieldName1, randomAlphaOfLength(8)); - addSemanticTextMapping(b, fieldName2, randomAlphaOfLength(8)); - })); - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of( - randomSemanticTextInferenceResults(fieldName1, List.of("a b", "c")), - randomSemanticTextInferenceResults(fieldName2, List.of("d e f")) - ) - ) - ) - ); - - Set visitedChildDocs = new HashSet<>(); - Set expectedVisitedChildDocs = Set.of( - new VisitedChildDocInfo(fieldName1, 2), - new VisitedChildDocInfo(fieldName1, 1), - new VisitedChildDocInfo(fieldName2, 3) - ); - - List luceneDocs = doc.docs(); - assertEquals(4, luceneDocs.size()); - assertValidChildDoc(luceneDocs.get(0), doc.rootDoc(), visitedChildDocs); - assertValidChildDoc(luceneDocs.get(1), doc.rootDoc(), visitedChildDocs); - assertValidChildDoc(luceneDocs.get(2), doc.rootDoc(), visitedChildDocs); - assertEquals(doc.rootDoc(), luceneDocs.get(3)); - assertNull(luceneDocs.get(3).getParent()); - assertEquals(expectedVisitedChildDocs, visitedChildDocs); - - MapperService nestedMapperService = createMapperService(mapping(b -> { - addInferenceResultsNestedMapping(b, fieldName1); - addInferenceResultsNestedMapping(b, fieldName2); - })); - withLuceneIndex(nestedMapperService, iw -> iw.addDocuments(doc.docs()), reader -> { - NestedDocuments nested = new NestedDocuments( - nestedMapperService.mappingLookup(), - QueryBitSetProducer::new, - IndexVersion.current() - ); - LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); - - Set visitedNestedIdentities = new HashSet<>(); - Set expectedVisitedNestedIdentities = Set.of( - new SearchHit.NestedIdentity(fieldName1, 0, null), - new SearchHit.NestedIdentity(fieldName1, 1, null), - new SearchHit.NestedIdentity(fieldName2, 0, null) - ); - - assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); - assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); - - assertNull(leaf.advance(3)); - assertEquals(3, leaf.doc()); - assertEquals(3, leaf.rootDoc()); - assertNull(leaf.nestedIdentity()); - - IndexSearcher searcher = newSearcher(reader); - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = 
searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")), - 10 - ); - assertEquals(0, topDocs.totalHits.value); - } - }); - } - - public void testMissingSubfields() throws IOException { - final String fieldName = randomAlphaOfLengthBetween(5, 15); - - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); - - { - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), - new SparseVectorSubfieldOptions(false, true, true), - true, - Map.of() - ) - ) - ) - ); - assertThat(ex.getMessage(), containsString("Missing required subfields: [" + INFERENCE_CHUNKS_RESULTS + "]")); - } - { - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), - new SparseVectorSubfieldOptions(true, true, true), - false, - Map.of() - ) - ) - ) - ); - assertThat(ex.getMessage(), containsString("Missing required subfields: [" + INFERENCE_CHUNKS_TEXT + "]")); - } - { - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), - new SparseVectorSubfieldOptions(false, true, true), - false, - Map.of() - ) - ) - ) - ); - assertThat( - ex.getMessage(), - containsString("Missing required subfields: [" + INFERENCE_CHUNKS_RESULTS + ", " + INFERENCE_CHUNKS_TEXT + "]") - ); - } - } - - public void testExtraSubfields() throws IOException { - final String fieldName = randomAlphaOfLengthBetween(5, 15); - final List semanticTextInferenceResultsList = List.of( - randomSemanticTextInferenceResults(fieldName, List.of("a b")) - ); - - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); - - Consumer checkParsedDocument = d -> { - Set visitedChildDocs = new HashSet<>(); - Set expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName, 2)); - - List luceneDocs = d.docs(); - assertEquals(2, luceneDocs.size()); - assertValidChildDoc(luceneDocs.get(0), d.rootDoc(), visitedChildDocs); - assertEquals(d.rootDoc(), luceneDocs.get(1)); - assertNull(luceneDocs.get(1).getParent()); - assertEquals(expectedVisitedChildDocs, visitedChildDocs); - }; - - { - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - semanticTextInferenceResultsList, - new SparseVectorSubfieldOptions(true, true, true), - true, - Map.of("extra_key", "extra_value") - ) - ) - ); - - checkParsedDocument.accept(doc); - LuceneDocument childDoc = doc.docs().get(0); - assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); - } - { - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - semanticTextInferenceResultsList, - new SparseVectorSubfieldOptions(true, true, true), - true, - Map.of("extra_key", Map.of("k1", "v1")) - ) - ) - ); - - checkParsedDocument.accept(doc); - LuceneDocument childDoc = 
doc.docs().get(0); - assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); - } - { - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - semanticTextInferenceResultsList, - new SparseVectorSubfieldOptions(true, true, true), - true, - Map.of("extra_key", List.of("v1")) - ) - ) - ); - - checkParsedDocument.accept(doc); - LuceneDocument childDoc = doc.docs().get(0); - assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); - } - { - Map extraSubfields = new HashMap<>(); - extraSubfields.put("extra_key", null); - - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - semanticTextInferenceResultsList, - new SparseVectorSubfieldOptions(true, true, true), - true, - extraSubfields - ) - ) - ); - - checkParsedDocument.accept(doc); - LuceneDocument childDoc = doc.docs().get(0); - assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); - } - } - - public void testMissingSemanticTextMapping() throws IOException { - final String fieldName = randomAlphaOfLengthBetween(5, 15); - - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> {})); - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source(b -> addSemanticTextInferenceResults(b, List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))))) - ) - ); - assertThat( - ex.getMessage(), - containsString( - Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) - ) - ); - } - - private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { - mappingBuilder.startObject(fieldName); - mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); - mappingBuilder.field("model_id", modelId); - mappingBuilder.endObject(); - } - - public static ChunkedTextEmbeddingResults randomTextEmbeddings(List inputs) { - List chunks = new ArrayList<>(); - for (String input : inputs) { - double[] values = new double[5]; - for (int j = 0; j < values.length; j++) { - values[j] = randomDouble(); - } - chunks.add(new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, values)); - } - return new ChunkedTextEmbeddingResults(chunks); - } - - public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { - List chunks = new ArrayList<>(); - for (String input : inputs) { - var tokens = new ArrayList(); - for (var token : input.split("\\s+")) { - tokens.add(new TextExpansionResults.WeightedToken(token, randomFloat())); - } - chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input, tokens)); - } - return new ChunkedSparseEmbeddingResults(chunks); - } - - private static SemanticTextInferenceResults randomSemanticTextInferenceResults(String semanticTextFieldName, List chunks) { - return new SemanticTextInferenceResults(semanticTextFieldName, randomSparseEmbeddings(chunks), chunks); - } - - private static void addSemanticTextInferenceResults( - XContentBuilder sourceBuilder, - List semanticTextInferenceResults - ) throws IOException { - addSemanticTextInferenceResults( - sourceBuilder, - semanticTextInferenceResults, - new SparseVectorSubfieldOptions(true, true, true), - true, - Map.of() - ); - } - - @SuppressWarnings("unchecked") - private static void addSemanticTextInferenceResults( - XContentBuilder 
sourceBuilder, - List semanticTextInferenceResults, - SparseVectorSubfieldOptions sparseVectorSubfieldOptions, - boolean includeTextSubfield, - Map extraSubfields - ) throws IOException { - Map inferenceResultsMap = new HashMap<>(); - for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { - InferenceResultFieldMapper.applyFieldInference( - inferenceResultsMap, - semanticTextInferenceResult.fieldName, - randomModel(), - semanticTextInferenceResult.results - ); - Map optionsMap = (Map) inferenceResultsMap.get(semanticTextInferenceResult.fieldName); - List> fieldResultList = (List>) optionsMap.get(RESULTS); - for (var entry : fieldResultList) { - if (includeTextSubfield == false) { - entry.remove(INFERENCE_CHUNKS_TEXT); - } - if (sparseVectorSubfieldOptions.include == false) { - entry.remove(INFERENCE_CHUNKS_RESULTS); - } - entry.putAll(extraSubfields); - } - } - sourceBuilder.field(InferenceResultFieldMapper.NAME, inferenceResultsMap); - } - - private static Model randomModel() { - String serviceName = randomAlphaOfLengthBetween(5, 10); - String inferenceId = randomAlphaOfLengthBetween(5, 10); - return new TestModel( - inferenceId, - TaskType.SPARSE_EMBEDDING, - serviceName, - new TestModel.TestServiceSettings("my-model"), - new TestModel.TestTaskSettings(randomIntBetween(1, 100)), - new TestModel.TestSecretSettings(randomAlphaOfLength(10)) - ); - } - - private static void addInferenceResultsNestedMapping(XContentBuilder mappingBuilder, String semanticTextFieldName) throws IOException { - mappingBuilder.startObject(semanticTextFieldName); - { - mappingBuilder.field("type", "nested"); - mappingBuilder.startObject("properties"); - { - mappingBuilder.startObject(INFERENCE_CHUNKS_RESULTS); - { - mappingBuilder.field("type", "sparse_vector"); - } - mappingBuilder.endObject(); - mappingBuilder.startObject(INFERENCE_CHUNKS_TEXT); - { - mappingBuilder.field("type", "text"); - mappingBuilder.field("index", false); - } - mappingBuilder.endObject(); - } - mappingBuilder.endObject(); - } - mappingBuilder.endObject(); - } - - private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String path, List tokens) { - NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(path); - assertNotNull(mapper); - - BitSetProducer parentFilter = new QueryBitSetProducer(Queries.newNonNestedFilter(IndexVersion.current())); - BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); - for (String token : tokens) { - queryBuilder.add( - new BooleanClause(new TermQuery(new Term(path + "." + INFERENCE_CHUNKS_RESULTS, token)), BooleanClause.Occur.MUST) - ); - } - queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); - - return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, null); - } - - private static void assertValidChildDoc( - LuceneDocument childDoc, - LuceneDocument expectedParent, - Set visitedChildDocs - ) { - assertEquals(expectedParent, childDoc.getParent()); - visitedChildDocs.add( - new VisitedChildDocInfo(childDoc.getPath(), childDoc.getFields(childDoc.getPath() + "." 
+ INFERENCE_CHUNKS_RESULTS).size()) - ); - } - - private static void assertChildLeafNestedDocument( - LeafNestedDocuments leaf, - int advanceToDoc, - int expectedRootDoc, - Set visitedNestedIdentities - ) throws IOException { - - assertNotNull(leaf.advance(advanceToDoc)); - assertEquals(advanceToDoc, leaf.doc()); - assertEquals(expectedRootDoc, leaf.rootDoc()); - assertNotNull(leaf.nestedIdentity()); - visitedNestedIdentities.add(leaf.nestedIdentity()); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 1b5311ac9effb..6d63cbec08a07 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -106,18 +106,6 @@ public void testInferenceIdNotPresent() throws IOException { assertThat(e.getMessage(), containsString("field [inference_id] must be specified")); } - public void testCannotBeUsedInMultiFields() { - Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { - b.field("type", "text"); - b.startObject("fields"); - b.startObject("semantic"); - b.field("type", "semantic_text"); - b.endObject(); - b.endObject(); - }))); - assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); - } - public void testUpdatesToInferenceIdNotSupported() throws IOException { String fieldName = randomAlphaOfLengthBetween(5, 15); MapperService mapperService = createMapperService( From 5311424027041d6c3778908ac7c4a275eac11870 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 27 Mar 2024 16:26:51 +0100 Subject: [PATCH 20/26] Fix merge with feature branch --- .../test/inference/10_semantic_text_inference.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index c5c86b9e874ef..852cc0e9e8e7e 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -325,7 +325,7 @@ setup: fields: semantic_multifield: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id - do: index: index: test-multifield-index @@ -354,7 +354,7 @@ setup: properties: inference_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id source_field: type: text copy_to: inference_field From 707b3f118a6f7c7991f4914f2cfd9a772638094f Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 27 Mar 2024 16:30:44 +0100 Subject: [PATCH 21/26] Remove multifield testing --- .../inference/10_semantic_text_inference.yml | 37 +------------------ 1 file changed, 2 insertions(+), 35 deletions(-) diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 852cc0e9e8e7e..9330366995fb1 100644 --- 
--- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
+++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
@@ -311,39 +311,6 @@ setup:
       body:
         non_inference_field: "non inference test"
 
----
-"semantic_text multifields calculate inference for parent field":
-  - do:
-      indices.create:
-        index: test-multifield-index
-        body:
-          mappings:
-            properties:
-              top_level_field:
-                type: text
-                fields:
-                  semantic_multifield:
-                    type: semantic_text
-                    inference_id: dense-inference-id
-  - do:
-      index:
-        index: test-multifield-index
-        id: doc_1
-        body:
-          top_level_field: "multifield inference test"
-
-  - do:
-      get:
-        index: test-multifield-index
-        id: doc_1
-
-  - match: { _source.top_level_field: "multifield inference test" }
-  - length: { _source._inference.top_level_field\.semantic_multifield.chunks: 1 }
-  - match: { _source._inference.top_level_field\.semantic_multifield.chunks.0.text: "multifield inference test" }
-  - exists: _source._inference.top_level_field\.semantic_multifield.results.0.inference
-
 ---
 "semantic_text copy_to calculate inference for source fields":
   - do:
@@ -372,8 +339,8 @@ setup:
         index: test-copy-to-index
         id: doc_1
 
-  - match: { _source.inference_field: "inference test" }
-  - length: {_source._inference.inference_field.results: 2}
+  - match: { _source.inference_field: "inference test" }
+  - length: {_source._inference.inference_field.chunks: 2}
   - exists: _source._inference.inference_field.chunks.0.inference
   - exists: _source._inference.inference_field.chunks.0.text
   - exists: _source._inference.inference_field.chunks.1.inference

From 5e5b32afff3dce30ff5b1bdfced0de5e779cd80a Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Wed, 27 Mar 2024 16:47:37 +0100
Subject: [PATCH 22/26] Now, actually drop multifields support

---
 .../mapper/SemanticTextFieldMapper.java      |  5 +-
 .../SemanticTextClusterMetadataTests.java    | 78 -------------------
 .../mapper/SemanticTextFieldMapperTests.java | 12 +++
 3 files changed, 16 insertions(+), 79 deletions(-)

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
index 63ecb2a7e440e..2445d5c8751a5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -60,7 +60,10 @@ private static SemanticTextFieldMapper toType(FieldMapper in) {
         return (SemanticTextFieldMapper) in;
     }
 
-    public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n, c.indexVersionCreated()));
+    public static final TypeParser PARSER = new TypeParser(
+        (n, c) -> new Builder(n, c.indexVersionCreated()),
+        notInMultiFields(CONTENT_TYPE)
+    );
 
     private final IndexVersion indexVersionCreated;
     private final SemanticTextModelSettings modelSettings;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
index 29c1e8ded0601..3496c32ab2cbc 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
@@ -65,41 +65,6 @@ public void testSingleSourceSemanticTextField() throws Exception {
         assertThat(fieldInferenceOptions.sourceFields(), equalTo(Set.of("field")));
     }
 
-    public void testMultiFieldsSemanticTextField() throws Exception {
-        final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test"));
-        final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class);
-        final MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor();
-        final ClusterService clusterService = getInstanceFromNode(ClusterService.class);
-
-        final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest("""
-            {
-              "properties": {
-                "top_field": {
-                  "type": "text",
-                  "fields": {
-                    "semantic": {
-                      "type": "semantic_text",
-                      "inference_id": "test_model"
-                    }
-                  }
-                }
-              }
-            }
-            """);
-        request.indices(new Index[] { indexService.index() });
-        final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful(
-            clusterService.state(),
-            putMappingExecutor,
-            singleTask(request)
-        );
-        IndexMetadata indexMetadata = resultingState.metadata().index("test");
-        FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata()
-            .getFieldInferenceOptions()
-            .get("top_field.semantic");
-        assertThat(fieldInferenceOptions.inferenceId(), equalTo("test_model"));
-        assertThat(fieldInferenceOptions.sourceFields(), equalTo(Set.of("top_field")));
-    }
-
     public void testCopyToSemanticTextField() throws Exception {
         final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test"));
         final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class);
@@ -138,49 +103,6 @@ public void testCopyToSemanticTextField() throws Exception {
         assertThat(fieldInferenceOptions.sourceFields(), equalTo(Set.of("semantic", "copy_origin_1", "copy_origin_2")));
     }
 
-    public void testCopyToAndMultiFieldsSemanticTextField() throws Exception {
-        final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test"));
-        final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class);
-        final MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor();
-        final ClusterService clusterService = getInstanceFromNode(ClusterService.class);
-
-        final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest("""
-            {
-              "properties": {
-                "top_field": {
-                  "type": "text",
-                  "fields": {
-                    "semantic": {
-                      "type": "semantic_text",
-                      "inference_id": "test_model"
-                    }
-                  }
-                },
-                "copy_origin_1": {
-                  "type": "text",
-                  "copy_to": "top_field"
-                },
-                "copy_origin_2": {
-                  "type": "text",
-                  "copy_to": "top_field"
-                }
-              }
-            }
-            """);
-        request.indices(new Index[] { indexService.index() });
-        final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful(
-            clusterService.state(),
-            putMappingExecutor,
-            singleTask(request)
-        );
-        IndexMetadata indexMetadata = resultingState.metadata().index("test");
-        FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata()
-            .getFieldInferenceOptions()
-            .get("top_field.semantic");
-        assertThat(fieldInferenceOptions.inferenceId(), equalTo("test_model"));
-        assertThat(fieldInferenceOptions.sourceFields(), equalTo(Set.of("top_field", "copy_origin_1", "copy_origin_2")));
-    }
-
     private static List<MetadataMappingService.PutMappingClusterStateUpdateTask> singleTask(PutMappingClusterStateUpdateRequest request) {
         return Collections.singletonList(new MetadataMappingService.PutMappingClusterStateUpdateTask(request, ActionListener.running(() -> {
             throw new AssertionError("task should not complete publication");
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
index 6d63cbec08a07..1b5311ac9effb 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
@@ -106,6 +106,18 @@ public void testInferenceIdNotPresent() throws IOException {
         assertThat(e.getMessage(), containsString("field [inference_id] must be specified"));
     }
 
+    public void testCannotBeUsedInMultiFields() {
+        Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> {
+            b.field("type", "text");
+            b.startObject("fields");
+            b.startObject("semantic");
+            b.field("type", "semantic_text");
+            b.endObject();
+            b.endObject();
+        })));
+        assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields"));
+    }
+
     public void testUpdatesToInferenceIdNotSupported() throws IOException {
         String fieldName = randomAlphaOfLengthBetween(5, 15);
         MapperService mapperService = createMapperService(

From 29dd33ed1b0939b479616162579cbfd7f8187ce8 Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Wed, 27 Mar 2024 17:17:29 +0100
Subject: [PATCH 23/26] Fix merge with main

---
 .../java/org/elasticsearch/index/mapper/MapperTestCase.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java
index fa0f0e1b95f54..34ccc4599811b 100644
--- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java
+++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java
@@ -1030,7 +1030,7 @@ public final void testMinimalIsInvalidInRoutingPath() throws IOException {
         }
     }
 
-    private String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) {
+    protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) {
         if (mapper instanceof FieldMapper fieldMapper && fieldMapper.fieldType().isDimension() == false) {
             return "All fields that match routing_path must be configured with [time_series_dimension: true] "
                 + "or flattened fields with a list of dimensions in [time_series_dimensions] and "

From 82ffb5beca0d5f3b21aab3a7c70f2b79579830a9 Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Mon, 1 Apr 2024 10:16:18 +0200
Subject: [PATCH 24/26] Fix merge with feature branch

---
 .../filter/ShardBulkInferenceActionFilter.java |  7 +++++--
 .../SemanticTextClusterMetadataTests.java      | 15 ++++++++-------
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
index f1d521f6e9c5a..0fb75bddedc31 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
@@ -416,7 +416,7 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
         for (var entry : fieldInferenceMap.values()) {
             String field = entry.getName();
             String inferenceId = entry.getInferenceId();
-            for (var sourceField : entry.getValue().sourceFields()) {
+            for (var sourceField : entry.getSourceFields()) {
                 Object inferenceResult = inferenceMap.remove(field);
                 var value = XContentMapValues.extractValue(sourceField, docMap);
                 if (value == null) {
@@ -435,7 +435,10 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
                 }
                 ensureResponseAccumulatorSlot(item.id());
                 if (value instanceof String valueStr) {
-                    List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
+                    List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(
+                        inferenceId,
+                        k -> new ArrayList<>()
+                    );
                     fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr));
                 } else {
                     addInferenceResponseFailure(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
index 8a721e7fa6285..1c4a2f561ad4a 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
@@ -16,11 +16,12 @@
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.test.ESSingleNodeTestCase;
 import org.elasticsearch.xpack.inference.InferencePlugin;
+import org.hamcrest.Matchers;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
-import java.util.Set;
 
 import static org.hamcrest.CoreMatchers.equalTo;
 
@@ -56,7 +57,6 @@ public void testSingleSourceSemanticTextField() throws Exception {
         assertEquals(resultingState.metadata().index("test").getInferenceFields().get("field").getInferenceId(), "test_model");
     }
 
-
     public void testCopyToSemanticTextField() throws Exception {
         final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test"));
         final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class);
@@ -88,11 +88,12 @@ public void testCopyToSemanticTextField() throws Exception {
             singleTask(request)
         );
         IndexMetadata indexMetadata = resultingState.metadata().index("test");
-        FieldInferenceMetadata.FieldInferenceOptions fieldInferenceOptions = indexMetadata.getFieldInferenceMetadata()
-            .getFieldInferenceOptions()
-            .get("semantic");
-        assertThat(fieldInferenceOptions.inferenceId(), equalTo("test_model"));
-        assertThat(fieldInferenceOptions.sourceFields(), equalTo(Set.of("semantic", "copy_origin_1", "copy_origin_2")));
+        InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get("semantic");
+        assertThat(inferenceFieldMetadata.getInferenceId(), equalTo("test_model"));
+        assertThat(
+            Arrays.asList(inferenceFieldMetadata.getSourceFields()),
+            Matchers.containsInAnyOrder("semantic", "copy_origin_1", "copy_origin_2")
+        );
     }
 
     private static List<MetadataMappingService.PutMappingClusterStateUpdateTask> singleTask(PutMappingClusterStateUpdateRequest request) {

From 6b262006970a3a4e53d135b0c837c41bbfd77e99 Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Mon, 1 Apr 2024 20:34:53 +0200
Subject: [PATCH 25/26] Check that bulk updates use all source fields for updating a semantic_text field

---
 .../ShardBulkInferenceActionFilter.java       | 15 +++---
 .../inference/10_semantic_text_inference.yml  | 46 ++++++++++++++++++-
 2 files changed, 51 insertions(+), 10 deletions(-)

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
index 0fb75bddedc31..ebe480a68c3fd 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
@@ -388,10 +388,12 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
                 // item was already aborted/processed by a filter in the chain upstream (e.g. security)
                 continue;
             }
+            boolean isUpdateRequest = false;
             final IndexRequest indexRequest;
             if (item.request() instanceof IndexRequest ir) {
                 indexRequest = ir;
             } else if (item.request() instanceof UpdateRequest updateRequest) {
+                isUpdateRequest = true;
                 if (updateRequest.script() != null) {
                     addInferenceResponseFailure(
                         item.id(),
@@ -409,25 +411,20 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
                 continue;
             }
             final Map<String, Object> docMap = indexRequest.sourceAsMap();
-            final Map<String, Object> inferenceMap = XContentMapValues.nodeMapValue(
-                docMap.computeIfAbsent(InferenceMetadataFieldMapper.NAME, k -> new LinkedHashMap()),
-                InferenceMetadataFieldMapper.NAME
-            );
             for (var entry : fieldInferenceMap.values()) {
                 String field = entry.getName();
                 String inferenceId = entry.getInferenceId();
                 for (var sourceField : entry.getSourceFields()) {
-                    Object inferenceResult = inferenceMap.remove(field);
                     var value = XContentMapValues.extractValue(sourceField, docMap);
                     if (value == null) {
-                        if (inferenceResult != null) {
+                        if (isUpdateRequest) {
                             addInferenceResponseFailure(
                                 item.id(),
                                 new ElasticsearchStatusException(
-                                    "The field [{}] is referenced in the [{}] metadata field but has no value",
+                                    "Field [{}] must be specified on an update request to calculate inference for field [{}]",
                                     RestStatus.BAD_REQUEST,
-                                    field,
-                                    InferenceMetadataFieldMapper.NAME
+                                    sourceField,
+                                    field
                                 )
                             );
                         }
diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
index 2bde3af2777fe..08ba32d08bd62 100644
--- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
+++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
@@ -392,6 +392,9 @@ setup:
               source_field:
                 type: text
                 copy_to: inference_field
+              another_source_field:
+                type: text
+                copy_to: inference_field
   - do:
       index:
         index: test-copy-to-index
@@ -400,6 +403,7 @@ setup:
         body:
           source_field: "copy_to inference test"
           inference_field: "inference test"
+          another_source_field: "another copy_to inference test"
 
   - do:
       get:
         index: test-copy-to-index
         id: doc_1
 
   - match: { _source.inference_field: "inference test" }
-  - length: {_source._inference.inference_field.chunks: 2}
+  - length: { _source._inference.inference_field.chunks: 3 }
   - exists: _source._inference.inference_field.chunks.0.inference
   - exists: _source._inference.inference_field.chunks.0.text
   - exists: _source._inference.inference_field.chunks.1.inference
   - exists: _source._inference.inference_field.chunks.1.text
+  - exists: _source._inference.inference_field.chunks.2.inference
+  - exists: _source._inference.inference_field.chunks.2.text
+
+
+---
+"semantic_text copy_to needs values for every source field for updates":
+  - do:
+      indices.create:
+        index: test-copy-to-index
+        body:
+          mappings:
+            properties:
+              inference_field:
+                type: semantic_text
+                inference_id: dense-inference-id
+              source_field:
+                type: text
+                copy_to: inference_field
+              another_source_field:
+                type: text
+                copy_to: inference_field
+
+  # Not every source field needed on creation
+  - do:
+      index:
+        index: test-copy-to-index
+        id: doc_1
+        body:
+          source_field: "a single source field provided"
+          inference_field: "inference test"
+
+  # Every source field needed on bulk updates
+  - do:
+      bulk:
+        body:
+          - '{"update": {"_index": "test-copy-to-index", "_id": "doc_1"}}'
+          - '{"doc": {"source_field": "a single source field is kept as provided via bulk", "inference_field": "updated inference test" }}'
+
+  - match: { items.0.update.status: 400 }
+  - match: { items.0.update.error.reason: "Field [another_source_field] must be specified on an update request to calculate inference for field [inference_field]" }

From bf5b83711c77502b3dcde71107306a619120da59 Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Mon, 1 Apr 2024 20:55:19 +0200
Subject: [PATCH 26/26] Add check for inference results when no value is provided

---
 .../filter/ShardBulkInferenceActionFilter.java | 15 +++++++++++++++
 .../inference/10_semantic_text_inference.yml   | 16 ++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
index ebe480a68c3fd..2e6f66c64fa95 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
@@ -411,9 +411,14 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
                 continue;
             }
             final Map<String, Object> docMap = indexRequest.sourceAsMap();
+            final Map<String, Object> inferenceMap = XContentMapValues.nodeMapValue(
+                docMap.computeIfAbsent(InferenceMetadataFieldMapper.NAME, k -> new LinkedHashMap()),
+                InferenceMetadataFieldMapper.NAME
+            );
             for (var entry : fieldInferenceMap.values()) {
                 String field = entry.getName();
                 String inferenceId = entry.getInferenceId();
+                Object inferenceResult = inferenceMap.remove(field);
                 for (var sourceField : entry.getSourceFields()) {
                     var value = XContentMapValues.extractValue(sourceField, docMap);
                     if (value == null) {
@@ -427,6 +432,16 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
                                     field
                                 )
                             );
+                        } else if (inferenceResult != null) {
+                            addInferenceResponseFailure(
+                                item.id(),
+                                new ElasticsearchStatusException(
+                                    "The field [{}] is referenced in the [{}] metadata field but has no value",
+                                    RestStatus.BAD_REQUEST,
+                                    field,
+                                    InferenceMetadataFieldMapper.NAME
+                                )
+                            );
                         }
                         continue;
                     }
diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
index 08ba32d08bd62..0a07a88d230ef 100644
--- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
+++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
@@ -378,6 +378,22 @@ setup:
   - match: { items.0.update.status: 400 }
   - match: { items.0.update.error.reason: "Cannot apply update with a script on indices that contain [semantic_text] field(s)" }
 
+---
+"Fails when providing inference results and there is no value for field":
+  - do:
+      catch: /The field \[inference_field\] is referenced in the \[_inference\] metadata field but has no value/
+      index:
+        index: test-sparse-index
+        id: doc_1
+        body:
+          _inference:
+            inference_field:
+              chunks:
+                - text: "inference test"
+                  inference:
+                    "hello": 0.123
+
+
 ---
 "semantic_text copy_to calculate inference for source fields":
   - do: