From 7aaa3b6d458796a288ba05ff2ee6a28fc48f2bc3 Mon Sep 17 00:00:00 2001
From: Jim Ferenczi
Date: Thu, 7 Mar 2024 14:29:33 +0000
Subject: [PATCH 01/13] Revert "Extract interface from ModelRegistry so it can be used from server (#105012)"

This reverts commit f4d3ab9df2a89879ac56262e618c0058fcb70d0e.
---
 .../action/bulk/BulkOperation.java            | 114 +--
 .../BulkShardRequestInferenceProvider.java    | 319 ---------
 .../action/bulk/TransportBulkAction.java      |  38 +-
 .../bulk/TransportSimulateBulkAction.java     |   4 +-
 .../inference/InferenceServiceRegistry.java   |  62 +-
 .../InferenceServiceRegistryImpl.java         |  64 --
 .../inference/ModelRegistry.java              |  99 ---
 .../elasticsearch/node/NodeConstruction.java  |  15 -
 .../plugins/InferenceRegistryPlugin.java      |  22 -
 .../action/bulk/BulkOperationTests.java       | 657 ------------------
 ...ActionIndicesThatCannotBeCreatedTests.java |   8 +-
 .../bulk/TransportBulkActionIngestTests.java  |   8 +-
 .../action/bulk/TransportBulkActionTests.java |   4 +-
 .../bulk/TransportBulkActionTookTests.java    |  16 +-
 ...gistryImplIT.java => ModelRegistryIT.java} |  52 +-
 .../xpack/inference/InferencePlugin.java      |  37 +-
 .../TransportDeleteInferenceModelAction.java  |   2 +-
 .../TransportGetInferenceModelAction.java     |   2 +-
 .../action/TransportInferenceAction.java      |   2 +-
 .../TransportPutInferenceModelAction.java     |   2 +-
 ...emanticTextInferenceResultFieldMapper.java |  16 +-
 .../mapper}/SemanticTextModelSettings.java    |  13 +-
 ...elRegistryImpl.java => ModelRegistry.java} |  82 ++-
 ...icTextInferenceResultFieldMapperTests.java |   6 +-
 ...ImplTests.java => ModelRegistryTests.java} |  34 +-
 25 files changed, 212 insertions(+), 1466 deletions(-)
 delete mode 100644 server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
 delete mode 100644 server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java
 delete mode 100644 server/src/main/java/org/elasticsearch/inference/ModelRegistry.java
 delete mode 100644 server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java
 delete mode 100644 server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java
 rename x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/{ModelRegistryImplIT.java => ModelRegistryIT.java} (86%)
 rename {server/src/main/java/org/elasticsearch/inference => x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper}/SemanticTextModelSettings.java (89%)
 rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/{ModelRegistryImpl.java => ModelRegistry.java} (86%)
 rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/{ModelRegistryImplTests.java => ModelRegistryTests.java} (92%)

diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 2b84ec8746cd2..1d95f430d5c7e 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; @@ -36,8 +35,6 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.shard.ShardId; import
org.elasticsearch.indices.IndexClosedException; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; @@ -47,7 +44,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.function.BiConsumer; import java.util.function.LongSupplier; import static org.elasticsearch.cluster.metadata.IndexNameExpressionResolver.EXCLUDED_DATA_STREAMS_KEY; @@ -73,8 +69,6 @@ final class BulkOperation extends ActionRunnable { private final LongSupplier relativeTimeProvider; private IndexNameExpressionResolver indexNameExpressionResolver; private NodeClient client; - private final InferenceServiceRegistry inferenceServiceRegistry; - private final ModelRegistry modelRegistry; BulkOperation( Task task, @@ -88,8 +82,6 @@ final class BulkOperation extends ActionRunnable { IndexNameExpressionResolver indexNameExpressionResolver, LongSupplier relativeTimeProvider, long startTimeNanos, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry, ActionListener listener ) { super(listener); @@ -105,8 +97,6 @@ final class BulkOperation extends ActionRunnable { this.relativeTimeProvider = relativeTimeProvider; this.indexNameExpressionResolver = indexNameExpressionResolver; this.client = client; - this.inferenceServiceRegistry = inferenceServiceRegistry; - this.modelRegistry = modelRegistry; this.observer = new ClusterStateObserver(clusterService, bulkRequest.timeout(), logger, threadPool.getThreadContext()); } @@ -199,30 +189,7 @@ private void executeBulkRequestsByShard(Map> requ return; } - BulkShardRequestInferenceProvider.getInstance( - inferenceServiceRegistry, - modelRegistry, - clusterState, - requestsByShard.keySet(), - new ActionListener() { - @Override - public void onResponse(BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider) { - processRequestsByShards(requestsByShard, clusterState, bulkShardRequestInferenceProvider); - } - - @Override - public void onFailure(Exception e) { - throw new ElasticsearchException("Error loading inference models", e); - } - } - ); - } - - void processRequestsByShards( - Map> requestsByShard, - ClusterState clusterState, - BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider - ) { + String nodeId = clusterService.localNode().getId(); Runnable onBulkItemsComplete = () -> { listener.onResponse( new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) @@ -230,68 +197,29 @@ void processRequestsByShards( // Allow memory for bulk shard request items to be reclaimed before all items have been completed bulkRequest = null; }; + try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) { for (Map.Entry> entry : requestsByShard.entrySet()) { final ShardId shardId = entry.getKey(); final List requests = entry.getValue(); - BulkShardRequest bulkShardRequest = createBulkShardRequest(clusterState, shardId, requests); - - Releasable ref = bulkItemRequestCompleteRefCount.acquire(); - final BiConsumer bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed(itemReq, e); - bulkShardRequestInferenceProvider.processBulkShardRequest(bulkShardRequest, new ActionListener<>() { - @Override - public void onResponse(BulkShardRequest inferenceBulkShardRequest) { - executeBulkShardRequest( - 
inferenceBulkShardRequest, - ActionListener.releaseAfter(ActionListener.noop(), ref), - bulkItemFailedListener - ); - } - @Override - public void onFailure(Exception e) { - throw new ElasticsearchException("Error performing inference", e); - } - }, bulkItemFailedListener); + BulkShardRequest bulkShardRequest = new BulkShardRequest( + shardId, + bulkRequest.getRefreshPolicy(), + requests.toArray(new BulkItemRequest[0]) + ); + bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); + bulkShardRequest.timeout(bulkRequest.timeout()); + bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); + if (task != null) { + bulkShardRequest.setParentTask(nodeId, task.getId()); + } + executeBulkShardRequest(bulkShardRequest, bulkItemRequestCompleteRefCount.acquire()); } } } - private BulkShardRequest createBulkShardRequest(ClusterState clusterState, ShardId shardId, List requests) { - BulkShardRequest bulkShardRequest = new BulkShardRequest( - shardId, - bulkRequest.getRefreshPolicy(), - requests.toArray(new BulkItemRequest[0]) - ); - bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); - bulkShardRequest.timeout(bulkRequest.timeout()); - bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); - if (task != null) { - bulkShardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); - } - return bulkShardRequest; - } - - // When an item fails, store the failure in the responses array - private void markBulkItemRequestFailed(BulkItemRequest itemRequest, Exception e) { - final String indexName = itemRequest.index(); - - DocWriteRequest docWriteRequest = itemRequest.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); - responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure)); - } - - private void executeBulkShardRequest( - BulkShardRequest bulkShardRequest, - ActionListener listener, - BiConsumer bulkItemErrorListener - ) { - if (bulkShardRequest.items().length == 0) { - // No requests to execute due to previous errors, terminate early - listener.onResponse(bulkShardRequest); - return; - } - + private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { @Override public void onResponse(BulkShardResponse bulkShardResponse) { @@ -302,17 +230,19 @@ public void onResponse(BulkShardResponse bulkShardResponse) { } responses.set(bulkItemResponse.getItemId(), bulkItemResponse); } - listener.onResponse(bulkShardRequest); + releaseOnFinish.close(); } @Override public void onFailure(Exception e) { // create failures for all relevant requests - BulkItemRequest[] items = bulkShardRequest.items(); - for (BulkItemRequest item : items) { - bulkItemErrorListener.accept(item, e); + for (BulkItemRequest request : bulkShardRequest.items()) { + final String indexName = request.index(); + DocWriteRequest docWriteRequest = request.request(); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); + responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); } - listener.onFailure(e); + releaseOnFinish.close(); } }); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java 
deleted file mode 100644 index 4b7a67e9ca0e3..0000000000000 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ /dev/null @@ -1,319 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action.bulk; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.support.RefCountingRunnable; -import org.elasticsearch.action.update.UpdateRequest; -import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.common.TriConsumer; -import org.elasticsearch.core.Releasable; -import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelRegistry; -import org.elasticsearch.inference.SemanticTextModelSettings; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.BiConsumer; -import java.util.stream.Collectors; - -/** - * Performs inference on a {@link BulkShardRequest}, updating the source of each document with the inference results. 
- */ -public class BulkShardRequestInferenceProvider { - - // Root field name for storing inference results - public static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference"; - - // Contains the original text for the field - - public static final String INFERENCE_RESULTS = "inference_results"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - private final ClusterState clusterState; - private final Map inferenceProvidersMap; - - private record InferenceProvider(Model model, InferenceService service) { - private InferenceProvider { - Objects.requireNonNull(model); - Objects.requireNonNull(service); - } - } - - BulkShardRequestInferenceProvider(ClusterState clusterState, Map inferenceProvidersMap) { - this.clusterState = clusterState; - this.inferenceProvidersMap = inferenceProvidersMap; - } - - public static void getInstance( - InferenceServiceRegistry inferenceServiceRegistry, - ModelRegistry modelRegistry, - ClusterState clusterState, - Set shardIds, - ActionListener listener - ) { - Set inferenceIds = new HashSet<>(); - shardIds.stream().map(ShardId::getIndex).collect(Collectors.toSet()).stream().forEach(index -> { - var fieldsForModels = clusterState.metadata().index(index).getFieldsForModels(); - inferenceIds.addAll(fieldsForModels.keySet()); - }); - final Map inferenceProviderMap = new ConcurrentHashMap<>(); - Runnable onModelLoadingComplete = () -> listener.onResponse( - new BulkShardRequestInferenceProvider(clusterState, inferenceProviderMap) - ); - try (var refs = new RefCountingRunnable(onModelLoadingComplete)) { - for (var inferenceId : inferenceIds) { - ActionListener modelLoadingListener = new ActionListener<>() { - @Override - public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { - var service = inferenceServiceRegistry.getService(unparsedModel.service()); - if (service.isEmpty() == false) { - InferenceProvider inferenceProvider = new InferenceProvider( - service.get() - .parsePersistedConfigWithSecrets( - inferenceId, - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ), - service.get() - ); - inferenceProviderMap.put(inferenceId, inferenceProvider); - } - } - - @Override - public void onFailure(Exception e) { - // Failure on loading a model should not prevent the rest from being loaded and used. - // When the model is actually retrieved via the inference ID in the inference process, it will fail - // and the user will get the details on the inference failure. - } - }; - - modelRegistry.getModelWithSecrets(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire())); - } - } - } - - /** - * Performs inference on the fields that have inference models for a bulk shard request. Bulk items from - * the original request will be modified with the inference results, to avoid copying the entire requests from - * the original bulk request. - * - * @param bulkShardRequest original BulkShardRequest that will be modified with inference results. 
- * @param listener listener to be called when the inference process is finished with the new BulkShardRequest, - * which may have fewer items than the original because of inference failures - * @param onBulkItemFailure invoked when a bulk item fails inference - */ - public void processBulkShardRequest( - BulkShardRequest bulkShardRequest, - ActionListener listener, - BiConsumer onBulkItemFailure - ) { - - Map> fieldsForModels = clusterState.metadata() - .index(bulkShardRequest.shardId().getIndex()) - .getFieldsForModels(); - // No inference fields? Terminate early - if (fieldsForModels.isEmpty()) { - listener.onResponse(bulkShardRequest); - return; - } - - Set failedItems = Collections.synchronizedSet(new HashSet<>()); - Runnable onInferenceComplete = () -> { - if (failedItems.isEmpty()) { - listener.onResponse(bulkShardRequest); - return; - } - // Remove failed items from the original bulk shard request - BulkItemRequest[] originalItems = bulkShardRequest.items(); - BulkItemRequest[] newItems = new BulkItemRequest[originalItems.length - failedItems.size()]; - for (int i = 0, j = 0; i < originalItems.length; i++) { - if (failedItems.contains(i) == false) { - newItems[j++] = originalItems[i]; - } - } - BulkShardRequest newBulkShardRequest = new BulkShardRequest( - bulkShardRequest.shardId(), - bulkShardRequest.getRefreshPolicy(), - newItems - ); - listener.onResponse(newBulkShardRequest); - }; - TriConsumer onBulkItemFailureWithIndex = (bulkItemRequest, i, e) -> { - failedItems.add(i); - onBulkItemFailure.accept(bulkItemRequest, e); - }; - try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { - BulkItemRequest[] items = bulkShardRequest.items(); - for (int i = 0; i < items.length; i++) { - BulkItemRequest bulkItemRequest = items[i]; - // Bulk item might be null because of previous errors, skip in that case - if (bulkItemRequest != null) { - performInferenceOnBulkItemRequest( - bulkItemRequest, - fieldsForModels, - i, - onBulkItemFailureWithIndex, - bulkItemReqRef.acquire() - ); - } - } - } - } - - @SuppressWarnings("unchecked") - private void performInferenceOnBulkItemRequest( - BulkItemRequest bulkItemRequest, - Map> fieldsForModels, - Integer itemIndex, - TriConsumer onBulkItemFailure, - Releasable releaseOnFinish - ) { - - DocWriteRequest docWriteRequest = bulkItemRequest.request(); - Map sourceMap = null; - if (docWriteRequest instanceof IndexRequest indexRequest) { - sourceMap = indexRequest.sourceAsMap(); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - sourceMap = updateRequest.docAsUpsert() ? 
updateRequest.upsertRequest().sourceAsMap() : updateRequest.doc().sourceAsMap(); - } - if (sourceMap == null || sourceMap.isEmpty()) { - releaseOnFinish.close(); - return; - } - final Map docMap = new ConcurrentHashMap<>(sourceMap); - - // When a document completes processing, update the source with the inference - try (var docRef = new RefCountingRunnable(() -> { - if (docWriteRequest instanceof IndexRequest indexRequest) { - indexRequest.source(docMap); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - if (updateRequest.docAsUpsert()) { - updateRequest.upsertRequest().source(docMap); - } else { - updateRequest.doc().source(docMap); - } - } - releaseOnFinish.close(); - })) { - - Map rootInferenceFieldMap; - try { - rootInferenceFieldMap = (Map) docMap.computeIfAbsent( - ROOT_INFERENCE_FIELD, - k -> new HashMap() - ); - } catch (ClassCastException e) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException("Inference result field [" + ROOT_INFERENCE_FIELD + "] is not an object") - ); - return; - } - - for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) { - String modelId = fieldModelsEntrySet.getKey(); - List inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap); - if (inferenceFieldNames.isEmpty()) { - continue; - } - - InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId); - if (inferenceProvider == null) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException("No inference provider found for model ID " + modelId) - ); - return; - } - ActionListener inferenceResultsListener = new ActionListener<>() { - @Override - public void onResponse(InferenceServiceResults results) { - if (results == null) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException( - "No inference results retrieved for model ID " + modelId + " in document " + docWriteRequest.id() - ) - ); - } - - int i = 0; - for (InferenceResults inferenceResults : results.transformToCoordinationFormat()) { - String inferenceFieldName = inferenceFieldNames.get(i++); - Map inferenceFieldResult = new LinkedHashMap<>(); - inferenceFieldResult.putAll(new SemanticTextModelSettings(inferenceProvider.model).asMap()); - inferenceFieldResult.put( - INFERENCE_RESULTS, - List.of( - Map.of( - INFERENCE_CHUNKS_RESULTS, - inferenceResults.asMap("output").get("output"), - INFERENCE_CHUNKS_TEXT, - docMap.get(inferenceFieldName) - ) - ) - ); - rootInferenceFieldMap.put(inferenceFieldName, inferenceFieldResult); - } - } - - @Override - public void onFailure(Exception e) { - onBulkItemFailure.apply(bulkItemRequest, itemIndex, e); - } - }; - inferenceProvider.service() - .infer( - inferenceProvider.model, - inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), - // TODO check for additional settings needed - Map.of(), - InputType.INGEST, - ActionListener.releaseAfter(inferenceResultsListener, docRef.acquire()) - ); - } - } - } - - private static List getFieldNamesForInference(Map.Entry> fieldModelsEntrySet, Map docMap) { - List inferenceFieldNames = new ArrayList<>(); - for (String inferenceField : fieldModelsEntrySet.getValue()) { - Object fieldValue = docMap.get(inferenceField); - - // Perform inference on string, non-null values - if (fieldValue instanceof String) { - inferenceFieldNames.add(inferenceField); - } - } - return inferenceFieldNames; - } -} diff --git 
a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 43076daf5bffe..a2445e95a572f 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -57,8 +57,6 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.VersionType; import org.elasticsearch.indices.SystemIndices; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -100,8 +98,6 @@ public class TransportBulkAction extends HandledTransportAction releasingListener) { - threadPool.executor(Names.WRITE).execute(new ActionRunnable<>(releasingListener) { + threadPool.executor(executorName).execute(new ActionRunnable<>(releasingListener) { @Override protected void doRun() { doInternalExecute(task, bulkRequest, executorName, releasingListener); @@ -425,13 +409,13 @@ protected void createMissingIndicesAndIndexData( final AtomicArray responses = new AtomicArray<>(bulkRequest.requests.size()); // Optimizing when there are no prerequisite actions if (indicesToAutoCreate.isEmpty() && dataStreamsToBeRolledOver.isEmpty()) { - executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); + executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); return; } Runnable executeBulkRunnable = () -> threadPool.executor(executorName).execute(new ActionRunnable<>(listener) { @Override protected void doRun() { - executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); + executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); } }); try (RefCountingRunnable refs = new RefCountingRunnable(executeBulkRunnable)) { @@ -649,10 +633,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { new BulkOperation( task, @@ -666,8 +650,6 @@ void executeBulk( indexNameExpressionResolver, relativeTimeProvider, startTimeNanos, - modelRegistry, - inferenceServiceRegistry, listener ).run(); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java index c8dc3e7b7ffd5..f65d0f462fde6 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java @@ -58,9 +58,7 @@ public TransportSimulateBulkAction( indexNameExpressionResolver, indexingPressure, systemIndices, - System::nanoTime, - null, - null + System::nanoTime ); } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index ce6f1b21b734c..d5973807d9d78 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -13,41 +13,49 @@ import 
java.io.Closeable; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class InferenceServiceRegistry implements Closeable { + + private final Map services; + private final List namedWriteables = new ArrayList<>(); + + public InferenceServiceRegistry( + List inferenceServicePlugins, + InferenceServiceExtension.InferenceServiceFactoryContext factoryContext + ) { + // TODO check names are unique + services = inferenceServicePlugins.stream() + .flatMap(r -> r.getInferenceServiceFactories().stream()) + .map(factory -> factory.create(factoryContext)) + .collect(Collectors.toMap(InferenceService::name, Function.identity())); + } -public interface InferenceServiceRegistry extends Closeable { - void init(Client client); - - Map getServices(); - - Optional getService(String serviceName); - - List getNamedWriteables(); - - class NoopInferenceServiceRegistry implements InferenceServiceRegistry { - public NoopInferenceServiceRegistry() {} + public void init(Client client) { + services.values().forEach(s -> s.init(client)); + } - @Override - public void init(Client client) {} + public Map getServices() { + return services; + } - @Override - public Map getServices() { - return Map.of(); - } + public Optional getService(String serviceName) { + return Optional.ofNullable(services.get(serviceName)); + } - @Override - public Optional getService(String serviceName) { - return Optional.empty(); - } + public List getNamedWriteables() { + return namedWriteables; + } - @Override - public List getNamedWriteables() { - return List.of(); + @Override + public void close() throws IOException { + for (var service : services.values()) { + service.close(); } - - @Override - public void close() throws IOException {} } } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java deleted file mode 100644 index f0a990ded98ce..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. 
- */ - -package org.elasticsearch.inference; - -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Function; -import java.util.stream.Collectors; - -public class InferenceServiceRegistryImpl implements InferenceServiceRegistry { - - private final Map services; - private final List namedWriteables = new ArrayList<>(); - - public InferenceServiceRegistryImpl( - List inferenceServicePlugins, - InferenceServiceExtension.InferenceServiceFactoryContext factoryContext - ) { - // TODO check names are unique - services = inferenceServicePlugins.stream() - .flatMap(r -> r.getInferenceServiceFactories().stream()) - .map(factory -> factory.create(factoryContext)) - .collect(Collectors.toMap(InferenceService::name, Function.identity())); - } - - @Override - public void init(Client client) { - services.values().forEach(s -> s.init(client)); - } - - @Override - public Map getServices() { - return services; - } - - @Override - public Optional getService(String serviceName) { - return Optional.ofNullable(services.get(serviceName)); - } - - @Override - public List getNamedWriteables() { - return namedWriteables; - } - - @Override - public void close() throws IOException { - for (var service : services.values()) { - service.close(); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java b/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java deleted file mode 100644 index fa90d5ba6f756..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.inference; - -import org.elasticsearch.action.ActionListener; - -import java.util.List; -import java.util.Map; - -public interface ModelRegistry { - - /** - * Get a model. - * Secret settings are not included - * @param inferenceEntityId Model to get - * @param listener Model listener - */ - void getModel(String inferenceEntityId, ActionListener listener); - - /** - * Get a model with its secret settings - * @param inferenceEntityId Model to get - * @param listener Model listener - */ - void getModelWithSecrets(String inferenceEntityId, ActionListener listener); - - /** - * Get all models of a particular task type. - * Secret settings are not included - * @param taskType The task type - * @param listener Models listener - */ - void getModelsByTaskType(TaskType taskType, ActionListener> listener); - - /** - * Get all models. - * Secret settings are not included - * @param listener Models listener - */ - void getAllModels(ActionListener> listener); - - void storeModel(Model model, ActionListener listener); - - void deleteModel(String modelId, ActionListener listener); - - /** - * Semi parsed model where inference entity id, task type and service - * are known but the settings are not parsed. 
- */ - record UnparsedModel( - String inferenceEntityId, - TaskType taskType, - String service, - Map settings, - Map secrets - ) {} - - class NoopModelRegistry implements ModelRegistry { - @Override - public void getModel(String modelId, ActionListener listener) { - fail(listener); - } - - @Override - public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { - listener.onResponse(List.of()); - } - - @Override - public void getAllModels(ActionListener> listener) { - listener.onResponse(List.of()); - } - - @Override - public void storeModel(Model model, ActionListener listener) { - fail(listener); - } - - @Override - public void deleteModel(String modelId, ActionListener listener) { - fail(listener); - } - - @Override - public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { - fail(listener); - } - - private static void fail(ActionListener listener) { - listener.onFailure(new IllegalArgumentException("No model registry configured")); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 19a6d200189f2..abc23b63aa1b6 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -126,8 +126,6 @@ import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.MonitorService; import org.elasticsearch.monitor.fs.FsHealthService; @@ -146,7 +144,6 @@ import org.elasticsearch.plugins.ClusterPlugin; import org.elasticsearch.plugins.DiscoveryPlugin; import org.elasticsearch.plugins.HealthPlugin; -import org.elasticsearch.plugins.InferenceRegistryPlugin; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.MetadataUpgrader; @@ -1104,18 +1101,6 @@ record PluginServiceInstances( ); } - // Register noop versions of inference services if Inference plugin is not available - Optional inferenceRegistryPlugin = getSinglePlugin(InferenceRegistryPlugin.class); - modules.bindToInstance( - InferenceServiceRegistry.class, - inferenceRegistryPlugin.map(InferenceRegistryPlugin::getInferenceServiceRegistry) - .orElse(new InferenceServiceRegistry.NoopInferenceServiceRegistry()) - ); - modules.bindToInstance( - ModelRegistry.class, - inferenceRegistryPlugin.map(InferenceRegistryPlugin::getModelRegistry).orElse(new ModelRegistry.NoopModelRegistry()) - ); - injector = modules.createInjector(); postInjection(clusterModule, actionModule, clusterService, transportService, featureService); diff --git a/server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java b/server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java deleted file mode 100644 index 696c3a067dad1..0000000000000 --- a/server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.plugins; - -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; - -/** - * Plugins that provide inference services should implement this interface. - * There should be a single one in the classpath, as we currently support a single instance for ModelRegistry / InferenceServiceRegistry. - */ -public interface InferenceRegistryPlugin { - InferenceServiceRegistry getInferenceServiceRegistry(); - - ModelRegistry getModelRegistry(); -} diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java deleted file mode 100644 index 2ce7b161d3dd1..0000000000000 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ /dev/null @@ -1,657 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action.bulk; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.index.IndexResponse; -import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.cluster.ClusterName; -import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.metadata.IndexAbstraction; -import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.metadata.Metadata; -import org.elasticsearch.cluster.node.DiscoveryNodeUtils; -import org.elasticsearch.cluster.service.ClusterApplierService; -import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.AtomicArray; -import org.elasticsearch.core.Tuple; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelRegistry; -import org.elasticsearch.inference.SemanticTextModelSettings; -import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.tasks.Task; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatcher; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map;
-import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.function.Function; -import java.util.stream.Collectors; - -import static java.util.Collections.emptyMap; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; -import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyMap; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - -public class BulkOperationTests extends ESTestCase { - - private static final String INDEX_NAME = "test-index"; - private static final String INFERENCE_SERVICE_1_ID = "inference_service_1_id"; - private static final String INFERENCE_SERVICE_2_ID = "inference_service_2_id"; - private static final String FIRST_INFERENCE_FIELD_SERVICE_1 = "first_inference_field_service_1"; - private static final String SECOND_INFERENCE_FIELD_SERVICE_1 = "second_inference_field_service_1"; - private static final String INFERENCE_FIELD_SERVICE_2 = "inference_field_service_2"; - private static final String SERVICE_1_ID = "elser_v2"; - private static final String SERVICE_2_ID = "e5"; - private static final String INFERENCE_FAILED_MSG = "Inference failed"; - private static TestThreadPool threadPool; - - public void testNoInference() { - - Map> fieldsForModels = Map.of(); - ModelRegistry modelRegistry = createModelRegistry( - Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) - ); - - Model model1 = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService1 = createInferenceService(model1); - Model model2 = mockModel(INFERENCE_SERVICE_2_ID); - InferenceService inferenceService2 = createInferenceService(model2); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( - Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) - ); - - Map originalSource = Map.of( - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - BulkShardRequest bulkShardRequest = runBulkOperation( - originalSource, - fieldsForModels, - modelRegistry, - inferenceServiceRegistry, - true, - bulkOperationListener - ); - verify(bulkOperationListener).onResponse(any()); - - BulkItemRequest[] items = bulkShardRequest.items(); - assertThat(items.length, equalTo(1)); - - Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); - // Original doc source is preserved - originalSource.forEach((key, value) -> assertThat(writtenDocSource.get(key), equalTo(value))); - - // Check inference not invoked - 
verifyNoMoreInteractions(modelRegistry); - verifyNoMoreInteractions(inferenceServiceRegistry); - } - - private static Model mockModel(String inferenceServiceId) { - Model model = mock(Model.class); - - when(model.getInferenceEntityId()).thenReturn(inferenceServiceId); - TaskType taskType = randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING; - when(model.getTaskType()).thenReturn(taskType); - - ServiceSettings serviceSettings = mock(ServiceSettings.class); - when(model.getServiceSettings()).thenReturn(serviceSettings); - SimilarityMeasure similarity = switch (randomInt(2)) { - case 0 -> SimilarityMeasure.COSINE; - case 1 -> SimilarityMeasure.DOT_PRODUCT; - default -> null; - }; - when(serviceSettings.similarity()).thenReturn(similarity); - when(serviceSettings.dimensions()).thenReturn(randomBoolean() ? null : randomIntBetween(1, 1000)); - - return model; - } - - public void testFailedBulkShardRequest() { - - Map> fieldsForModels = Map.of(); - ModelRegistry modelRegistry = createModelRegistry(Map.of()); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of()); - - Map originalSource = Map.of( - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - - runBulkOperation( - originalSource, - fieldsForModels, - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener, - true, - request -> new BulkShardResponse( - request.shardId(), - new BulkItemResponse[] { - BulkItemResponse.failure( - 0, - DocWriteRequest.OpType.INDEX, - new BulkItemResponse.Failure( - INDEX_NAME, - randomIdentifier(), - new IllegalArgumentException("Error on bulk shard request") - ) - ) } - ) - ); - verify(bulkOperationListener).onResponse(any()); - - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse[] items = bulkResponse.getItems(); - assertTrue(items[0].isFailed()); - } - - @SuppressWarnings("unchecked") - public void testInference() { - - Map> fieldsForModels = Map.of( - INFERENCE_SERVICE_1_ID, - Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), - INFERENCE_SERVICE_2_ID, - Set.of(INFERENCE_FIELD_SERVICE_2) - ); - - ModelRegistry modelRegistry = createModelRegistry( - Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) - ); - - Model model1 = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService1 = createInferenceService(model1); - Model model2 = mockModel(INFERENCE_SERVICE_2_ID); - InferenceService inferenceService2 = createInferenceService(model2); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( - Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) - ); - - String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); - String secondInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); - String inferenceTextService2 = randomAlphaOfLengthBetween(1, 100); - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - firstInferenceTextService1, - SECOND_INFERENCE_FIELD_SERVICE_1, - secondInferenceTextService1, - INFERENCE_FIELD_SERVICE_2, - 
inferenceTextService2, - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - ActionListener bulkOperationListener = mock(ActionListener.class); - BulkShardRequest bulkShardRequest = runBulkOperation( - originalSource, - fieldsForModels, - modelRegistry, - inferenceServiceRegistry, - true, - bulkOperationListener - ); - verify(bulkOperationListener).onResponse(any()); - - BulkItemRequest[] items = bulkShardRequest.items(); - assertThat(items.length, equalTo(1)); - - Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); - // Original doc source is preserved - originalSource.forEach((key, value) -> assertThat(writtenDocSource.get(key), equalTo(value))); - - // Check inference results - verifyInferenceServiceInvoked( - modelRegistry, - INFERENCE_SERVICE_1_ID, - inferenceService1, - model1, - List.of(firstInferenceTextService1, secondInferenceTextService1) - ); - verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_2_ID, inferenceService2, model2, List.of(inferenceTextService2)); - checkInferenceResults( - originalSource, - writtenDocSource, - FIRST_INFERENCE_FIELD_SERVICE_1, - SECOND_INFERENCE_FIELD_SERVICE_1, - INFERENCE_FIELD_SERVICE_2 - ); - } - - public void testFailedInference() { - - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService = createInferenceServiceThatFails(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - firstInferenceTextService1, - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat(item.getFailure().getCause().getMessage(), equalTo(INFERENCE_FAILED_MSG)); - - verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_1_ID, inferenceService, model, List.of(firstInferenceTextService1)); - - } - - public void testInferenceFailsForIncorrectRootObject() { - - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService = createInferenceServiceThatFails(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - randomAlphaOfLengthBetween(1, 100), - ROOT_INFERENCE_FIELD, - "incorrect_root_object" - ); - - ArgumentCaptor 
bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat(item.getFailure().getCause().getMessage(), containsString("[_semantic_text_inference] is not an object")); - } - - public void testInferenceIdNotFound() { - - Map> fieldsForModels = Map.of( - INFERENCE_SERVICE_1_ID, - Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), - INFERENCE_SERVICE_2_ID, - Set.of(INFERENCE_FIELD_SERVICE_2) - ); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService = createInferenceService(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - Map originalSource = Map.of( - INFERENCE_FIELD_SERVICE_2, - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat( - item.getFailure().getCause().getMessage(), - equalTo("No inference provider found for model ID " + INFERENCE_SERVICE_2_ID) - ); - } - - @SuppressWarnings("unchecked") - private static void checkInferenceResults( - Map docSource, - Map writtenDocSource, - String... 
inferenceFieldNames - ) { - - Map inferenceRootResultField = (Map) writtenDocSource.get( - BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD - ); - - for (String inferenceFieldName : inferenceFieldNames) { - Map inferenceService1FieldResults = (Map) inferenceRootResultField.get(inferenceFieldName); - assertNotNull(inferenceService1FieldResults); - assertThat(inferenceService1FieldResults.size(), equalTo(2)); - Map modelSettings = (Map) inferenceService1FieldResults.get(SemanticTextModelSettings.NAME); - assertNotNull(modelSettings); - assertNotNull(modelSettings.get(SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName())); - assertNotNull(modelSettings.get(SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName())); - - List> inferenceResultElement = (List>) inferenceService1FieldResults.get( - INFERENCE_RESULTS - ); - assertFalse(inferenceResultElement.isEmpty()); - assertNotNull(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_RESULTS)); - assertThat(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_TEXT), equalTo(docSource.get(inferenceFieldName))); - } - } - - private static void verifyInferenceServiceInvoked( - ModelRegistry modelRegistry, - String inferenceService1Id, - InferenceService inferenceService, - Model model, - Collection inferenceTexts - ) { - verify(modelRegistry).getModelWithSecrets(eq(inferenceService1Id), any()); - verify(inferenceService).parsePersistedConfigWithSecrets( - eq(inferenceService1Id), - eq(TaskType.SPARSE_EMBEDDING), - anyMap(), - anyMap() - ); - verify(inferenceService).infer(eq(model), argThat(containsInAnyOrder(inferenceTexts)), anyMap(), eq(InputType.INGEST), any()); - verifyNoMoreInteractions(inferenceService); - } - - private static ArgumentMatcher> containsInAnyOrder(Collection expected) { - return new ArgumentMatcher<>() { - @Override - public boolean matches(List argument) { - return argument.containsAll(expected) && argument.size() == expected.size(); - } - - @Override - public String toString() { - return "containsAll(" + expected.stream().collect(Collectors.joining(", ")) + ")"; - } - }; - } - - private static BulkShardRequest runBulkOperation( - Map docSource, - Map> fieldsForModels, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry, - boolean expectTransportShardBulkActionToExecute, - ActionListener bulkOperationListener - ) { - return runBulkOperation( - docSource, - fieldsForModels, - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener, - expectTransportShardBulkActionToExecute, - successfulBulkShardResponse - ); - } - - private static BulkShardRequest runBulkOperation( - Map docSource, - Map> fieldsForModels, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry, - ActionListener bulkOperationListener, - boolean expectTransportShardBulkActionToExecute, - Function bulkShardResponseSupplier - ) { - Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); - IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME) - .fieldsForModels(fieldsForModels) - .settings(settings) - .numberOfShards(1) - .numberOfReplicas(0) - .build(); - ClusterService clusterService = createClusterService(indexMetadata); - - IndexNameExpressionResolver indexResolver = mock(IndexNameExpressionResolver.class); - when(indexResolver.resolveWriteIndexAbstraction(any(), any())).thenReturn(new IndexAbstraction.ConcreteIndex(indexMetadata)); - - BulkRequest bulkRequest = new BulkRequest(); - bulkRequest.add(new 
IndexRequest(INDEX_NAME).source(docSource)); - - NodeClient client = mock(NodeClient.class); - - ArgumentCaptor bulkShardRequestCaptor = ArgumentCaptor.forClass(BulkShardRequest.class); - doAnswer(invocation -> { - BulkShardRequest request = invocation.getArgument(1); - ActionListener bulkShardResponseListener = invocation.getArgument(2); - bulkShardResponseListener.onResponse(bulkShardResponseSupplier.apply(request)); - return null; - }).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any()); - - Task task = new Task(randomLong(), "transport", "action", "", null, emptyMap()); - BulkOperation bulkOperation = new BulkOperation( - task, - threadPool, - ThreadPool.Names.WRITE, - clusterService, - bulkRequest, - client, - new AtomicArray<>(bulkRequest.requests.size()), - new HashMap<>(), - indexResolver, - () -> System.nanoTime(), - System.nanoTime(), - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener - ); - - bulkOperation.doRun(); - if (expectTransportShardBulkActionToExecute) { - verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any()); - return bulkShardRequestCaptor.getValue(); - } - - return null; - } - - private static final Function successfulBulkShardResponse = (request) -> { - return new BulkShardResponse( - request.shardId(), - Arrays.stream(request.items()) - .filter(Objects::nonNull) - .map( - item -> BulkItemResponse.success( - item.id(), - DocWriteRequest.OpType.INDEX, - new IndexResponse(request.shardId(), randomIdentifier(), randomLong(), randomLong(), randomLong(), randomBoolean()) - ) - ) - .toArray(BulkItemResponse[]::new) - ); - }; - - private static InferenceService createInferenceService(Model model) { - InferenceService inferenceService = mock(InferenceService.class); - when( - inferenceService.parsePersistedConfigWithSecrets( - eq(model.getInferenceEntityId()), - eq(TaskType.SPARSE_EMBEDDING), - anyMap(), - anyMap() - ) - ).thenReturn(model); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(4); - InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class); - List texts = invocation.getArgument(1); - List inferenceResults = new ArrayList<>(); - for (int i = 0; i < texts.size(); i++) { - inferenceResults.add(createInferenceResults()); - } - doReturn(inferenceResults).when(inferenceServiceResults).transformToCoordinationFormat(); - - listener.onResponse(inferenceServiceResults); - return null; - }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); - return inferenceService; - } - - private static InferenceService createInferenceServiceThatFails(Model model) { - InferenceService inferenceService = mock(InferenceService.class); - when( - inferenceService.parsePersistedConfigWithSecrets( - eq(model.getInferenceEntityId()), - eq(TaskType.SPARSE_EMBEDDING), - anyMap(), - anyMap() - ) - ).thenReturn(model); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(4); - listener.onFailure(new IllegalArgumentException(INFERENCE_FAILED_MSG)); - return null; - }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); - return inferenceService; - } - - private static InferenceResults createInferenceResults() { - InferenceResults inferenceResults = mock(InferenceResults.class); - when(inferenceResults.asMap(any())).then( - invocation -> Map.of( - (String) invocation.getArguments()[0], - Map.of("sparse_embedding", randomMap(1, 10, () -> new 
Tuple<>(randomAlphaOfLength(10), randomFloat()))) - ) - ); - return inferenceResults; - } - - private static InferenceServiceRegistry createInferenceServiceRegistry(Map inferenceServices) { - InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); - inferenceServices.forEach((id, service) -> when(inferenceServiceRegistry.getService(id)).thenReturn(Optional.of(service))); - return inferenceServiceRegistry; - } - - private static ModelRegistry createModelRegistry(Map inferenceIdsToServiceIds) { - ModelRegistry modelRegistry = mock(ModelRegistry.class); - // Fails for unknown inference ids - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IllegalArgumentException("Model not found")); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> { - ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel( - inferenceId, - TaskType.SPARSE_EMBEDDING, - serviceId, - emptyMap(), - emptyMap() - ); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(eq(inferenceId), any()); - }); - - return modelRegistry; - } - - private static ClusterService createClusterService(IndexMetadata indexMetadata) { - Metadata metadata = Metadata.builder().indices(Map.of(INDEX_NAME, indexMetadata)).build(); - - ClusterService clusterService = mock(ClusterService.class); - when(clusterService.localNode()).thenReturn(DiscoveryNodeUtils.create(randomIdentifier())); - - ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata).version(randomNonNegativeLong()).build(); - when(clusterService.state()).thenReturn(clusterState); - - ClusterApplierService clusterApplierService = mock(ClusterApplierService.class); - when(clusterApplierService.state()).thenReturn(clusterState); - when(clusterApplierService.threadPool()).thenReturn(threadPool); - when(clusterService.getClusterApplierService()).thenReturn(clusterApplierService); - return clusterService; - } - - @BeforeClass - public static void createThreadPool() { - threadPool = new TestThreadPool(getTestClass().getName()); - } - - @AfterClass - public static void stopThreadPool() { - if (threadPool != null) { - threadPool.shutdownNow(); - threadPool = null; - } - } - -} diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java index 988a92352649a..3057b00553a22 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java @@ -129,19 +129,17 @@ public boolean hasIndexAbstraction(String indexAbstraction, ClusterState state) mock(ActionFilters.class), indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ) { @Override void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { assertEquals(expected, indicesThatCannotBeCreated.keySet()); } diff 
--git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java index 2d6492e4e73a4..6815d634292a4 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java @@ -148,9 +148,7 @@ class TestTransportBulkAction extends TransportBulkAction { new ActionFilters(Collections.emptySet()), TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(SETTINGS), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ); } @@ -159,10 +157,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { assertTrue(indexCreated); isExecuted = true; diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java index ad522e36f9bd9..1a16d9083df55 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java @@ -98,9 +98,7 @@ class TestTransportBulkAction extends TransportBulkAction { new ActionFilters(Collections.emptySet()), new Resolver(), new IndexingPressure(Settings.EMPTY), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java index a2e54a1c7c3b8..cb9bdd1f3a827 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java @@ -139,13 +139,13 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { expected.set(1000000); - super.executeBulk(task, bulkRequest, startTimeNanos, executorName, responses, indicesThatCannotBeCreated, listener); + super.executeBulk(task, bulkRequest, startTimeNanos, listener, executorName, responses, indicesThatCannotBeCreated); } }; } else { @@ -164,14 +164,14 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { long elapsed = spinForAtLeastOneMillisecond(); expected.set(elapsed); - super.executeBulk(task, bulkRequest, startTimeNanos, executorName, responses, indicesThatCannotBeCreated, listener); + super.executeBulk(task, bulkRequest, startTimeNanos, listener, executorName, responses, indicesThatCannotBeCreated); } }; } @@ -253,9 +253,7 @@ static class TestTransportBulkAction extends TransportBulkAction { indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, - relativeTimeProvider, - null, - null + relativeTimeProvider ); } } diff --git 
a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java similarity index 86% rename from x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java rename to x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index ccda986a8d280..0f23e0b33d774 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -26,7 +26,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests; @@ -55,13 +55,13 @@ import static org.hamcrest.Matchers.nullValue; import static org.mockito.Mockito.mock; -public class ModelRegistryImplIT extends ESSingleNodeTestCase { +public class ModelRegistryIT extends ESSingleNodeTestCase { - private ModelRegistryImpl ModelRegistryImpl; + private ModelRegistry modelRegistry; @Before public void createComponents() { - ModelRegistryImpl = new ModelRegistryImpl(client()); + modelRegistry = new ModelRegistry(client()); } @Override @@ -75,7 +75,7 @@ public void testStoreModel() throws Exception { AtomicReference storeModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), storeModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder); assertThat(storeModelHolder.get(), is(true)); assertThat(exceptionHolder.get(), is(nullValue())); @@ -87,7 +87,7 @@ public void testStoreModelWithUnknownFields() throws Exception { AtomicReference storeModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), storeModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder); assertNull(storeModelHolder.get()); assertNotNull(exceptionHolder.get()); @@ -106,12 +106,12 @@ public void testGetModel() throws Exception { AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); // now get the model - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new 
AtomicReference<>();
+ blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
assertThat(exceptionHolder.get(), is(nullValue()));
assertThat(modelHolder.get(), not(nullValue()));
@@ -133,13 +133,13 @@ public void testStoreModelFailsWhenModelExists() throws Exception {
AtomicReference putModelHolder = new AtomicReference<>();
AtomicReference exceptionHolder = new AtomicReference<>();
- blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
+ blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
putModelHolder.set(false);
// a model with the same id exists
- blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
+ blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(false));
assertThat(exceptionHolder.get(), not(nullValue()));
assertThat(
@@ -154,20 +154,20 @@ public void testDeleteModel() throws Exception {
Model model = buildElserModelConfig(id, TaskType.SPARSE_EMBEDDING);
AtomicReference putModelHolder = new AtomicReference<>();
AtomicReference exceptionHolder = new AtomicReference<>();
- blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
+ blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
}
AtomicReference deleteResponseHolder = new AtomicReference<>();
AtomicReference exceptionHolder = new AtomicReference<>();
- blockingCall(listener -> ModelRegistryImpl.deleteModel("model1", listener), deleteResponseHolder, exceptionHolder);
+ blockingCall(listener -> modelRegistry.deleteModel("model1", listener), deleteResponseHolder, exceptionHolder);
assertThat(exceptionHolder.get(), is(nullValue()));
assertTrue(deleteResponseHolder.get());
// get should fail
deleteResponseHolder.set(false);
- AtomicReference modelHolder = new AtomicReference<>();
- blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder);
+ AtomicReference modelHolder = new AtomicReference<>();
+ blockingCall(listener -> modelRegistry.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder);
assertThat(exceptionHolder.get(), not(nullValue()));
assertFalse(deleteResponseHolder.get());
@@ -187,13 +187,13 @@ public void testGetModelsByTaskType() throws InterruptedException {
AtomicReference putModelHolder = new AtomicReference<>();
AtomicReference exceptionHolder = new AtomicReference<>();
- blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
+ blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
}
AtomicReference exceptionHolder = new AtomicReference<>();
- AtomicReference> modelHolder = new AtomicReference<>();
- blockingCall(listener -> ModelRegistryImpl.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder);
+ AtomicReference> modelHolder = new AtomicReference<>();
+ blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder);
assertThat(modelHolder.get(), hasSize(3));
var
sparseIds = sparseAndTextEmbeddingModels.stream() .filter(m -> m.getConfigurations().getTaskType() == TaskType.SPARSE_EMBEDDING) @@ -204,7 +204,7 @@ public void testGetModelsByTaskType() throws InterruptedException { assertThat(m.secrets().keySet(), empty()); }); - blockingCall(listener -> ModelRegistryImpl.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(2)); var denseIds = sparseAndTextEmbeddingModels.stream() .filter(m -> m.getConfigurations().getTaskType() == TaskType.TEXT_EMBEDDING) @@ -228,13 +228,13 @@ public void testGetAllModels() throws InterruptedException { var model = createModel(randomAlphaOfLength(5), randomFrom(TaskType.values()), service); createdModels.add(model); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); } - AtomicReference> modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getAllModels(listener), modelHolder, exceptionHolder); + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(modelCount)); var getAllModels = modelHolder.get(); @@ -258,18 +258,18 @@ public void testGetModelWithSecrets() throws InterruptedException { AtomicReference exceptionHolder = new AtomicReference<>(); var modelWithSecrets = createModelWithSecrets(inferenceEntityId, randomFrom(TaskType.values()), service, secret); - blockingCall(listener -> ModelRegistryImpl.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), hasSize(1)); var secretSettings = (Map) modelHolder.get().secrets().get("secret_settings"); assertThat(secretSettings.get("secret"), equalTo(secret)); // get model without secrets - blockingCall(listener -> ModelRegistryImpl.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), empty()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 821a804596cff..a73dc0261faa8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -26,11 +26,8 @@ import 
org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceRegistryImpl; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; -import org.elasticsearch.plugins.InferenceRegistryPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; @@ -58,7 +55,7 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; -import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestInferenceAction; @@ -80,13 +77,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class InferencePlugin extends Plugin - implements - ActionPlugin, - ExtensiblePlugin, - SystemIndexPlugin, - InferenceRegistryPlugin, - MapperPlugin { +public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, MapperPlugin { /** * When this setting is true the verification check that @@ -111,8 +102,6 @@ public class InferencePlugin extends Plugin private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); - private final SetOnce modelRegistry = new SetOnce<>(); - private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -164,7 +153,7 @@ public Collection createComponents(PluginServices services) { ); httpFactory.set(httpRequestSenderFactory); - ModelRegistry modelReg = new ModelRegistryImpl(services.client()); + ModelRegistry modelRegistry = new ModelRegistry(services.client()); if (inferenceServiceExtensions == null) { inferenceServiceExtensions = new ArrayList<>(); @@ -175,13 +164,11 @@ public Collection createComponents(PluginServices services) { var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client()); // This must be done after the HttpRequestSenderFactory is created so that the services can get the // reference correctly - var inferenceRegistry = new InferenceServiceRegistryImpl(inferenceServices, factoryContext); - inferenceRegistry.init(services.client()); - inferenceServiceRegistry.set(inferenceRegistry); - modelRegistry.set(modelReg); + var registry = new InferenceServiceRegistry(inferenceServices, factoryContext); + registry.init(services.client()); + inferenceServiceRegistry.set(registry); - // Don't return components as they will be registered using InferenceRegistryPlugin methods to retrieve them - return List.of(); + return List.of(modelRegistry, registry); } @Override @@ -280,16 +267,6 @@ public void close() { IOUtils.closeWhileHandlingException(inferenceServiceRegistry.get(), throttlerToClose); } - @Override - public InferenceServiceRegistry getInferenceServiceRegistry() { - return inferenceServiceRegistry.get(); - } - - @Override - public ModelRegistry getModelRegistry() { - return modelRegistry.get(); - } - @Override public Map 
getMappers() { if (SemanticTextFeature.isEnabled()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java index ad6042581f264..b55e2e6f8ebed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java @@ -23,12 +23,12 @@ import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportDeleteInferenceModelAction extends AcknowledgedTransportMasterNodeAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index 0f7e48c4f8140..2de1aecea118c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -17,7 +17,6 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -25,6 +24,7 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index ece4fee1c935f..fb3974fc12e8b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -16,11 +16,11 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportInferenceAction extends 
HandledTransportAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 6667e314a62b8..07d28f8e5b0a8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -29,7 +29,6 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -44,6 +43,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.io.IOException; import java.util.Map; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java index ad1e0f8c8cb81..928bb0ea47179 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; -import org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.DocumentParserContext; @@ -41,10 +40,6 @@ import java.util.Set; import java.util.stream.Collectors; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; - /** * A mapper for the {@code _semantic_text_inference} field. *
@@ -101,8 +96,13 @@ * */ public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { - public static final String CONTENT_TYPE = "_semantic_text_inference"; - public static final String NAME = ROOT_INFERENCE_FIELD; + public static final String NAME = "_inference"; + public static final String CONTENT_TYPE = "_inference"; + + public static final String INFERENCE_RESULTS = "results"; + public static final String INFERENCE_CHUNKS_RESULTS = "inference"; + public static final String INFERENCE_CHUNKS_TEXT = "text"; + public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); @@ -173,7 +173,7 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); String currentName = parser.currentName(); - if (BulkShardRequestInferenceProvider.INFERENCE_RESULTS.equals(currentName)) { + if (INFERENCE_RESULTS.equals(currentName)) { NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( context, mapperBuilderContext, diff --git a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java similarity index 89% rename from server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index 3561c2351427c..aedaa1cd34c12 100644 --- a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -1,3 +1,10 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License @@ -6,8 +13,11 @@ * Side Public License, v 1. */ -package org.elasticsearch.inference; +package org.elasticsearch.xpack.inference.mapper; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; @@ -19,7 +29,6 @@ /** * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. 
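+ * The settings travel as the map produced by {@code asMap()}; the task type and inference id entries declared
+ * below form the core of that map, and service-specific model settings may contribute additional keys.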
- * See {@link org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider}
*/
public class SemanticTextModelSettings {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
similarity index 86%
rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java
rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
index 40921cd38f181..0f3aa5b82b189 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
@@ -24,7 +24,6 @@
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
-import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
@@ -32,7 +31,6 @@
import org.elasticsearch.index.reindex.DeleteByQueryRequest;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
-import org.elasticsearch.inference.ModelRegistry;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
@@ -57,21 +55,49 @@
import static org.elasticsearch.core.Strings.format;
-public class ModelRegistryImpl implements ModelRegistry {
+public class ModelRegistry {
public record ModelConfigMap(Map<String, Object> config, Map<String, Object> secrets) {}
+ /**
+ * Semi-parsed model where the inference entity id, task type and service
+ * are known but the settings are not parsed.
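+ * For example, {@link #unparsedModelFromMap} builds an instance from a raw config map by removing its
+ * {@code ModelConfigurations.MODEL_ID}, {@code ModelConfigurations.SERVICE} and {@code TaskType.NAME} entries,
+ * leaving the remaining settings and secrets for the service to parse later.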
+ */ + public record UnparsedModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map settings, + Map secrets + ) { + + public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) { + if (modelConfigMap.config() == null) { + throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); + } + String inferenceEntityId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID); + String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); + String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); + TaskType taskType = TaskType.fromString(taskTypeStr); + + return new UnparsedModel(inferenceEntityId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); + } + } + private static final String TASK_TYPE_FIELD = "task_type"; private static final String MODEL_ID_FIELD = "model_id"; - private static final Logger logger = LogManager.getLogger(ModelRegistryImpl.class); + private static final Logger logger = LogManager.getLogger(ModelRegistry.class); private final OriginSettingClient client; - @Inject - public ModelRegistryImpl(Client client) { + public ModelRegistry(Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); } - @Override + /** + * Get a model with its secret settings + * @param inferenceEntityId Model to get + * @param listener Model listener + */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets @@ -80,7 +106,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets @@ -101,7 +132,7 @@ public void getModel(String inferenceEntityId, ActionListener lis return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistryImpl::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); assert modelConfigs.size() == 1; delegate.onResponse(modelConfigs.get(0)); }); @@ -116,7 +147,12 @@ public void getModel(String inferenceEntityId, ActionListener lis client.search(modelSearch, searchListener); } - @Override + /** + * Get all models of a particular task type. 
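+ * A usage sketch (the listener wiring is illustrative):
+ * {@code registry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, ActionListener.wrap(models -> handle(models), e -> fail(e)))}
+ * where {@code handle} and {@code fail} stand in for caller code.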
+ * Secret settings are not included + * @param taskType The task type + * @param listener Models listener + */ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // Not an error if no models of this task_type @@ -125,7 +161,7 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // Not an error if no models of this task_type @@ -150,7 +190,7 @@ public void getAllModels(ActionListener> listener) { return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistryImpl::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); delegate.onResponse(modelConfigs); }); @@ -217,7 +257,6 @@ private ModelConfigMap createModelConfigMap(SearchHits hits, String inferenceEnt ); } - @Override public void storeModel(Model model, ActionListener listener) { ActionListener bulkResponseActionListener = getStoreModelListener(model, listener); @@ -314,7 +353,6 @@ private static BulkItemResponse.Failure getFirstBulkFailure(BulkResponse bulkRes return null; } - @Override public void deleteModel(String inferenceEntityId, ActionListener listener) { DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); request.indices(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN); @@ -339,16 +377,4 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T private QueryBuilder documentIdQuery(String inferenceEntityId) { return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(inferenceEntityId))); } - - private static UnparsedModel unparsedModelFromMap(ModelRegistryImpl.ModelConfigMap modelConfigMap) { - if (modelConfigMap.config() == null) { - throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); - } - String modelId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID); - String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); - String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); - TaskType taskType = TaskType.fromString(taskTypeStr); - - return new UnparsedModel(modelId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java index 319f6ef73fa56..27d0949e933ee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -53,9 +53,9 @@ import java.util.Set; import java.util.function.Consumer; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; -import static 
org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_RESULTS; import static org.hamcrest.Matchers.containsString; public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java similarity index 92% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index fd6a203450c12..2417148c84ac2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -45,7 +45,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class ModelRegistryImplTests extends ESTestCase { +public class ModelRegistryTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); @@ -65,9 +65,9 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() var client = mockClient(); mockClientExecuteSearch(client, mockSearchResponse(SearchHits.EMPTY)); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT)); @@ -79,9 +79,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalArgumentException_WhenInvalidIn var unknownIndexHit = SearchHit.createFromMap(Map.of("_index", "unknown_index")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { unknownIndexHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> listener.actionGet(TIMEOUT)); @@ -96,9 +96,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var inferenceSecretsHit = SearchHit.createFromMap(Map.of("_index", ".secrets-inference")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceSecretsHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); @@ -113,9 +113,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var inferenceHit = SearchHit.createFromMap(Map.of("_index", 
".inference")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); @@ -147,9 +147,9 @@ public void testGetModelWithSecrets() { mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit, inferenceSecretsHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); var modelConfig = listener.actionGet(TIMEOUT); @@ -176,9 +176,9 @@ public void testGetModelNoSecrets() { mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModel("1", listener); registry.getModel("1", listener); @@ -201,7 +201,7 @@ public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -218,7 +218,7 @@ public void testStoreModel_ThrowsException_WhenBulkResponseIsEmpty() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -249,7 +249,7 @@ public void testStoreModel_ThrowsResourceAlreadyExistsException_WhenFailureIsAVe mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -275,7 +275,7 @@ public void testStoreModel_ThrowsException_WhenFailureIsNotAVersionConflict() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); From ebc26d2b29fd9dda9b896c061c67e6f223255c5d Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 14 Mar 2024 10:38:14 +0000 Subject: [PATCH 02/13] inference as an action filter --- .../action/bulk/BulkOperation.java | 4 + .../action/bulk/BulkShardRequest.java | 20 + .../snapshots/SnapshotResiliencyTests.java | 4 +- .../TestSparseInferenceServiceExtension.java | 8 +- .../xpack/inference/InferencePlugin.java | 13 + .../ShardBulkInferenceActionFilter.java | 346 ++++++++++++++++++ ...emanticTextInferenceResultFieldMapper.java | 1 - .../mapper/SemanticTextModelSettings.java | 8 - ...icTextInferenceResultFieldMapperTests.java | 1 - .../inference/10_semantic_text_inference.yml | 48 +-- .../20_semantic_text_field_mapper.yml | 20 +- 11 files changed, 423 insertions(+), 50 deletions(-) create mode 100644 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 1d95f430d5c7e..b7a6387045e3d 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -208,6 +208,10 @@ private void executeBulkRequestsByShard(Map> requ bulkRequest.getRefreshPolicy(), requests.toArray(new BulkItemRequest[0]) ); + var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName()); + if (indexMetadata != null && indexMetadata.getFieldsForModels().isEmpty() == false) { + bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldsForModels()); + } bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); bulkShardRequest.timeout(bulkRequest.timeout()); bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index bd929b9a2204e..f6dd7902f1672 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -22,6 +22,7 @@ import org.elasticsearch.transport.RawIndexingDataTransportRequest; import java.io.IOException; +import java.util.Map; import java.util.Set; public final class BulkShardRequest extends ReplicatedWriteRequest @@ -33,6 +34,8 @@ public final class BulkShardRequest extends ReplicatedWriteRequest> fieldsInferenceMetadata = null; + public BulkShardRequest(StreamInput in) throws IOException { super(in); items = in.readArray(i -> i.readOptionalWriteable(inpt -> new BulkItemRequest(shardId, inpt)), BulkItemRequest[]::new); @@ -44,6 +47,23 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe setRefreshPolicy(refreshPolicy); } + /** + * Set the transient metadata indicating that this request requires running inference + * before proceeding. + */ + void setFieldInferenceMetadata(Map> fieldsInferenceMetadata) { + this.fieldsInferenceMetadata = fieldsInferenceMetadata; + } + + /** + * Consumes the inference metadata to execute inference on the bulk items just once. 
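+ * A second call returns {@code null}, so callers can run inference at most once with a simple guard, e.g.
+ * {@code var metadata = request.consumeFieldInferenceMetadata(); if (metadata != null) { runInference(metadata); }}
+ * (the {@code runInference} helper is hypothetical).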
+ */ + public Map> consumeFieldInferenceMetadata() { + var ret = fieldsInferenceMetadata; + fieldsInferenceMetadata = null; + return ret; + } + public long totalSizeInBytes() { long totalSizeInBytes = 0; for (int i = 0; i < items.length; i++) { diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index c13a6be6d3386..edde9f0164a6e 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -2355,9 +2355,7 @@ protected void assertSnapshotOrGenericThread() { actionFilters, indexNameExpressionResolver, new IndexingPressure(settings), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ) ); final TransportShardBulkAction transportShardBulkAction = new TransportShardBulkAction( diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 33bbc94901e9d..b6e48d3b1c29a 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -123,15 +123,17 @@ private SparseEmbeddingResults makeResults(List input) { } private List makeChunkedResults(List input) { - var chunks = new ArrayList(); + List results = new ArrayList<>(); for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, j + 1.0F)); } - chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)); + results.add( + new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens))) + ); } - return List.of(new ChunkedSparseEmbeddingResults(chunks)); + return results; } protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index a73dc0261faa8..7bfa06ecb9a20 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -46,6 +47,7 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import 
org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpSettings; @@ -77,6 +79,8 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static java.util.Collections.singletonList; + public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, MapperPlugin { /** @@ -102,6 +106,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); + private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -168,6 +173,9 @@ public Collection createComponents(PluginServices services) { registry.init(services.client()); inferenceServiceRegistry.set(registry); + var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry); + shardBulkInferenceActionFilter.set(actionFilter); + return List.of(modelRegistry, registry); } @@ -279,4 +287,9 @@ public Map getMappers() { public Map getMetadataMappers() { return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); } + + @Override + public Collection getActionFilters() { + return singletonList(shardBulkInferenceActionFilter.get()); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java new file mode 100644 index 0000000000000..176d2917b0b2a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -0,0 +1,346 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action.filter; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkItemRequest; +import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.TransportShardBulkAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.ActionFilter; +import org.elasticsearch.action.support.ActionFilterChain; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextModelSettings; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in + * the individual {@link BulkItemRequest}. The results are then consumed by the {@link SemanticTextInferenceResultFieldMapper} + * in the subsequent {@link TransportShardBulkAction} downstream. 
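+ * The rough sequence, mirroring the code below: consume the field-to-inference-id metadata from the
+ * {@link BulkShardRequest}, run {@code chunkedInfer} once per inference id, write the results into each
+ * document's source under the {@code _inference} metadata field, then let the filter chain proceed.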
+ */ +public class ShardBulkInferenceActionFilter implements ActionFilter { + private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class); + + private final InferenceServiceRegistry inferenceServiceRegistry; + private final ModelRegistry modelRegistry; + + public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) { + this.inferenceServiceRegistry = inferenceServiceRegistry; + this.modelRegistry = modelRegistry; + } + + @Override + public int order() { + // must execute last (after the security action filter) + return Integer.MAX_VALUE; + } + + @Override + public void apply( + Task task, + String action, + Request request, + ActionListener listener, + ActionFilterChain chain + ) { + switch (action) { + case TransportShardBulkAction.ACTION_NAME: + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata(); + if (fieldInferenceMetadata != null) { + Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); + processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); + } else { + chain.proceed(task, action, request, listener); + } + break; + + default: + chain.proceed(task, action, request, listener); + break; + } + } + + private void processBulkShardRequest( + Map> fieldInferenceMetadata, + BulkShardRequest bulkShardRequest, + Runnable onCompletion + ) { + new AsyncBulkShardInferenceAction(fieldInferenceMetadata, bulkShardRequest, onCompletion).run(); + } + + private record InferenceProvider(InferenceService service, Model model) {} + + private record FieldInferenceRequest(int id, String field, String input) {} + + private record FieldInferenceResponse(String field, Model model, ChunkedInferenceServiceResults chunkedResults) {} + + private record FieldInferenceResponseAccumulator(int id, List responses, List failures) { + Exception createFailureOrNull() { + if (failures.isEmpty()) { + return null; + } + Exception main = failures.get(0); + for (int i = 1; i < failures.size(); i++) { + main.addSuppressed(failures.get(i)); + } + return main; + } + } + + private class AsyncBulkShardInferenceAction implements Runnable { + private final Map> fieldInferenceMetadata; + private final BulkShardRequest bulkShardRequest; + private final Runnable onCompletion; + private final AtomicArray inferenceResults; + + private AsyncBulkShardInferenceAction( + Map> fieldInferenceMetadata, + BulkShardRequest bulkShardRequest, + Runnable onCompletion + ) { + this.fieldInferenceMetadata = fieldInferenceMetadata; + this.bulkShardRequest = bulkShardRequest; + this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); + this.onCompletion = onCompletion; + } + + @Override + public void run() { + Map> inferenceRequests = createFieldInferenceRequests(bulkShardRequest); + Runnable onInferenceCompletion = () -> { + try { + for (var inferenceResponse : inferenceResults.asList()) { + var request = bulkShardRequest.items()[inferenceResponse.id]; + applyInference(request, inferenceResponse); + } + } finally { + onCompletion.run(); + } + }; + try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) { + for (var entry : inferenceRequests.entrySet()) { + executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire()); + } + } + } + + private void executeShardBulkInferenceAsync( + final String inferenceId, + @Nullable 
InferenceProvider inferenceProvider,
+            final List<FieldInferenceRequest> requests,
+            final Releasable onFinish
+        ) {
+            if (inferenceProvider == null) {
+                ActionListener<ModelRegistry.UnparsedModel> modelLoadingListener = new ActionListener<>() {
+                    @Override
+                    public void onResponse(ModelRegistry.UnparsedModel unparsedModel) {
+                        var service = inferenceServiceRegistry.getService(unparsedModel.service());
+                        if (service.isEmpty() == false) {
+                            var provider = new InferenceProvider(
+                                service.get(),
+                                service.get()
+                                    .parsePersistedConfigWithSecrets(
+                                        inferenceId,
+                                        unparsedModel.taskType(),
+                                        unparsedModel.settings(),
+                                        unparsedModel.secrets()
+                                    )
+                            );
+                            executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish);
+                        } else {
+                            try (onFinish) {
+                                for (int i = 0; i < requests.size(); i++) {
+                                    var request = requests.get(i);
+                                    inferenceResults.get(request.id).failures.add(
+                                        new ResourceNotFoundException(
+                                            "Inference service [{}] not found for field [{}]",
+                                            unparsedModel.service(),
+                                            request.field
+                                        )
+                                    );
+                                }
+                            }
+                        }
+                    }
+
+                    @Override
+                    public void onFailure(Exception exc) {
+                        try (onFinish) {
+                            for (int i = 0; i < requests.size(); i++) {
+                                var request = requests.get(i);
+                                inferenceResults.get(request.id).failures.add(
+                                    new ResourceNotFoundException("Inference id [{}] not found for field [{}]", inferenceId, request.field)
+                                );
+                            }
+                        }
+                    }
+                };
+                modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
+                return;
+            }
+            final List<String> inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
+            ActionListener<List<ChunkedInferenceServiceResults>> completionListener = new ActionListener<>() {
+                @Override
+                public void onResponse(List<ChunkedInferenceServiceResults> results) {
+                    for (int i = 0; i < results.size(); i++) {
+                        var request = requests.get(i);
+                        var result = results.get(i);
+                        inferenceResults.get(request.id).responses.add(
+                            new FieldInferenceResponse(request.field, inferenceProvider.model, result)
+                        );
+                    }
+                }
+
+                @Override
+                public void onFailure(Exception exc) {
+                    for (int i = 0; i < requests.size(); i++) {
+                        var request = requests.get(i);
+                        inferenceResults.get(request.id).failures.add(
+                            new ElasticsearchException(
+                                "Exception when running inference id [{}] on field [{}]",
+                                exc,
+                                inferenceProvider.model.getInferenceEntityId(),
+                                request.field
+                            )
+                        );
+                    }
+                }
+            };
+            inferenceProvider.service()
+                .chunkedInfer(
+                    inferenceProvider.model(),
+                    inputs,
+                    Map.of(),
+                    InputType.INGEST,
+                    new ChunkingOptions(null, null),
+                    ActionListener.runAfter(completionListener, onFinish::close)
+                );
+        }
+
+        /**
+         * Apply the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}.
+         * If the response contains failures, the bulk item request is marked as failed for the downstream action.
+         * Otherwise, the source of the request is augmented with the field inference results.
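+         * <p>
+         * Shape sketch of an augmented source, using hypothetical field and id names; the layout mirrors the YAML
+         * tests in this patch (per-field {@code model_settings} plus per-chunk {@code results}):
+         * <pre>{@code
+         * {
+         *   "my_field": "inference test",
+         *   "_inference": {
+         *     "my_field": {
+         *       "model_settings": { "inference_id": "my-id", "task_type": "sparse_embedding" },
+         *       "results": [ { "text": "inference test", "inference": { "feature_1": 0.1 } } ]
+         *     }
+         *   }
+         * }
+         * }</pre>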
+ */ + private void applyInference(BulkItemRequest request, FieldInferenceResponseAccumulator inferenceResponse) { + Exception failure = inferenceResponse.createFailureOrNull(); + if (failure != null) { + request.abort(bulkShardRequest.index(), failure); + return; + } + final IndexRequest indexRequest = getIndexRequestOrNull(request.request()); + final Map newDocMap = indexRequest.sourceAsMap(); + final Map inferenceMetadataMap = new LinkedHashMap<>(); + newDocMap.put(SemanticTextInferenceResultFieldMapper.NAME, inferenceMetadataMap); + for (FieldInferenceResponse fieldResponse : inferenceResponse.responses) { + List> chunks = new ArrayList<>(); + if (fieldResponse.chunkedResults instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add(chunk.asMap()); + } + } else if (fieldResponse.chunkedResults instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(chunk.asMap()); + } + } else { + request.abort(bulkShardRequest.index(), new IllegalArgumentException("TODO")); + return; + } + Map fieldMap = new LinkedHashMap<>(); + fieldMap.putAll(new SemanticTextModelSettings(fieldResponse.model).asMap()); + fieldMap.put(SemanticTextInferenceResultFieldMapper.INFERENCE_RESULTS, chunks); + inferenceMetadataMap.put(fieldResponse.field, fieldMap); + } + indexRequest.source(newDocMap); + } + + private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { + Map> fieldRequestsMap = new LinkedHashMap<>(); + for (var item : bulkShardRequest.items()) { + if (item.getPrimaryResponse() != null) { + // item was already aborted/processed by a filter in the chain upstream (e.g. security). + continue; + } + final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); + if (indexRequest == null) { + continue; + } + final Map docMap = indexRequest.sourceAsMap(); + List fieldRequests = null; + for (var pair : fieldInferenceMetadata.entrySet()) { + String inferenceId = pair.getKey(); + for (var field : pair.getValue()) { + var value = XContentMapValues.extractValue(field, docMap); + if (value == null) { + continue; + } + if (value instanceof String valueStr) { + if (inferenceResults.get(item.id()) == null) { + inferenceResults.set( + item.id(), + new FieldInferenceResponseAccumulator( + item.id(), + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ) + ); + } + if (fieldRequests == null) { + fieldRequests = new ArrayList<>(); + fieldRequestsMap.put(inferenceId, fieldRequests); + } + fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + } + } + } + } + return fieldRequestsMap; + } + } + + static IndexRequest getIndexRequestOrNull(DocWriteRequest docWriteRequest) { + if (docWriteRequest instanceof IndexRequest indexRequest) { + return indexRequest; + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + return updateRequest.doc(); + } else { + return null; + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java index 928bb0ea47179..cee6395185060 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java @@ -27,7 +27,6 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.SemanticTextModelSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.logging.LogManager; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index aedaa1cd34c12..1b6bb22c0d6b5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -5,14 +5,6 @@ * 2.0. */ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - package org.elasticsearch.xpack.inference.mapper; import org.elasticsearch.inference.Model; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java index 27d0949e933ee..5dc245298838f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -31,7 +31,6 @@ import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; -import org.elasticsearch.inference.SemanticTextModelSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.LeafNestedDocuments; diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index ead7f904ad57b..6008ebbcbedf8 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -83,11 +83,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { 
_source._inference.another_inference_field.results.0.text: "another inference test" } - - exists: _source._semantic_text_inference.inference_field.inference_results.0.inference - - exists: _source._semantic_text_inference.another_inference_field.inference_results.0.inference + - exists: _source._inference.inference_field.results.0.inference + - exists: _source._inference.another_inference_field.results.0.inference --- "text expansion documents do not create new mappings": @@ -120,11 +120,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } - - exists: _source._semantic_text_inference.inference_field.inference_results.0.inference - - exists: _source._semantic_text_inference.another_inference_field.inference_results.0.inference + - exists: _source._inference.inference_field.results.0.inference + - exists: _source._inference.another_inference_field.results.0.inference --- @@ -154,8 +154,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._semantic_text_inference.inference_field.inference_results.0.inference: inference_field_embedding } - - set: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } - do: update: @@ -174,11 +174,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "another non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.inference: $inference_field_embedding } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": @@ -214,8 +214,8 @@ setup: - match: { _source.another_inference_field: "another updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "updated inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another updated inference test" } + - match: { _source._inference.inference_field.results.0.text: "updated inference test" } + - 
match: { _source._inference.another_inference_field.results.0.text: "another updated inference test" } --- "Reindex works for semantic_text fields": @@ -233,8 +233,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._semantic_text_inference.inference_field.inference_results.0.inference: inference_field_embedding } - - set: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } - do: indices.refresh: { } @@ -271,11 +271,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.inference: $inference_field_embedding } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } --- "Fails for non-existent model": @@ -292,7 +292,7 @@ setup: type: text - do: - catch: bad_request + catch: missing index: index: incorrect-test-sparse-index id: doc_1 @@ -300,7 +300,7 @@ setup: inference_field: "inference test" non_inference_field: "non inference test" - - match: { error.reason: "No inference provider found for model ID non-existing-inference-id" } + - match: { error.reason: "Inference id [non-existing-inference-id] not found for field [inference_field]" } # Succeeds when semantic_text field is not used - do: diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index da61e6e403ed8..2c69f49218091 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -56,12 +56,12 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: sparse_field: model_settings: inference_id: sparse-inference-id task_type: sparse_embedding - inference_results: + results: - text: "inference test" inference: feature_1: 0.1 @@ -83,14 +83,14 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: dense_field: model_settings: inference_id: sparse-inference-id task_type: text_embedding dimensions: 5 similarity: cosine - inference_results: + results: - text: "inference test" inference: [0.1, 0.2, 0.3, 0.4, 0.5] - text: "another inference test" @@ -105,11 +105,11 @@ setup: id: doc_1 body: non_inference_field: 
"you know, for testing" - _semantic_text_inference: + _inference: sparse_field: model_settings: task_type: sparse_embedding - inference_results: + results: - text: "inference test" inference: feature_1: 0.1 @@ -123,11 +123,11 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: sparse_field: model_settings: inference_id: sparse-inference-id - inference_results: + results: - text: "inference test" inference: feature_1: 0.1 @@ -141,12 +141,12 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: dense_field: model_settings: inference_id: sparse-inference-id task_type: text_embedding - inference_results: + results: - text: "inference test" inference: [0.1, 0.2, 0.3, 0.4, 0.5] - text: "another inference test" From 86ddc9d8b8aff0285890732d2164c589b87ce3dc Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Mon, 18 Mar 2024 15:12:23 +0000 Subject: [PATCH 03/13] add more tests --- .../action/bulk/BulkShardRequest.java | 13 +- .../vectors/DenseVectorFieldMapper.java | 4 + .../xpack/inference/InferencePlugin.java | 4 +- .../ShardBulkInferenceActionFilter.java | 133 ++++--- ...r.java => InferenceResultFieldMapper.java} | 55 ++- .../mapper/SemanticTextFieldMapper.java | 2 +- .../ShardBulkInferenceActionFilterTests.java | 340 ++++++++++++++++++ ...a => InferenceResultFieldMapperTests.java} | 146 ++++---- 8 files changed, 544 insertions(+), 153 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/{SemanticTextInferenceResultFieldMapper.java => InferenceResultFieldMapper.java} (86%) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/{SemanticTextInferenceResultFieldMapperTests.java => InferenceResultFieldMapperTests.java} (79%) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index f6dd7902f1672..1b5494c6a68f5 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -48,10 +48,10 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe } /** - * Set the transient metadata indicating that this request requires running inference - * before proceeding. + * Public for test + * Set the transient metadata indicating that this request requires running inference before proceeding. 
*/ - void setFieldInferenceMetadata(Map> fieldsInferenceMetadata) { + public void setFieldInferenceMetadata(Map> fieldsInferenceMetadata) { this.fieldsInferenceMetadata = fieldsInferenceMetadata; } @@ -64,6 +64,13 @@ public Map> consumeFieldInferenceMetadata() { return ret; } + /** + * Public for test + */ + public Map> getFieldsInferenceMetadata() { + return fieldsInferenceMetadata; + } + public long totalSizeInBytes() { long totalSizeInBytes = 0; for (int i = 0; i < items.length; i++) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index c6e4d4af926a2..53cc803fc5a2f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1086,6 +1086,10 @@ public String typeName() { return CONTENT_TYPE; } + public Integer getDims() { + return dims; + } + @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { if (format != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 7bfa06ecb9a20..994207766f2a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -55,8 +55,8 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; @@ -285,7 +285,7 @@ public Map getMappers() { @Override public Map getMetadataMappers() { - return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); + return Map.of(InferenceResultFieldMapper.NAME, InferenceResultFieldMapper.PARSER); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 176d2917b0b2a..e679d3c970abf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; @@ -33,11 +34,9 @@ import 
org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; -import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; -import org.elasticsearch.xpack.inference.mapper.SemanticTextModelSettings; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; @@ -50,7 +49,7 @@ /** * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in - * the individual {@link BulkItemRequest}. The results are then consumed by the {@link SemanticTextInferenceResultFieldMapper} + * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceResultFieldMapper} * in the subsequent {@link TransportShardBulkAction} downstream. */ public class ShardBulkInferenceActionFilter implements ActionFilter { @@ -82,7 +81,7 @@ public void app case TransportShardBulkAction.ACTION_NAME: BulkShardRequest bulkShardRequest = (BulkShardRequest) request; var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata(); - if (fieldInferenceMetadata != null) { + if (fieldInferenceMetadata != null && fieldInferenceMetadata.size() > 0) { Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); } else { @@ -110,18 +109,7 @@ private record FieldInferenceRequest(int id, String field, String input) {} private record FieldInferenceResponse(String field, Model model, ChunkedInferenceServiceResults chunkedResults) {} - private record FieldInferenceResponseAccumulator(int id, List responses, List failures) { - Exception createFailureOrNull() { - if (failures.isEmpty()) { - return null; - } - Exception main = failures.get(0); - for (int i = 1; i < failures.size(); i++) { - main.addSuppressed(failures.get(i)); - } - return main; - } - } + private record FieldInferenceResponseAccumulator(int id, List responses, List failures) {} private class AsyncBulkShardInferenceAction implements Runnable { private final Map> fieldInferenceMetadata; @@ -147,7 +135,11 @@ public void run() { try { for (var inferenceResponse : inferenceResults.asList()) { var request = bulkShardRequest.items()[inferenceResponse.id]; - applyInference(request, inferenceResponse); + try { + applyInferenceResponses(request, inferenceResponse); + } catch (Exception exc) { + request.abort(bulkShardRequest.index(), exc); + } } } finally { onCompletion.run(); @@ -189,8 +181,8 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { var request = requests.get(i); inferenceResults.get(request.id).failures.add( new ResourceNotFoundException( - "Inference service [{}] not found for field [{}]", - unparsedModel.service(), + "Inference id [{}] not found for field [{}]", + inferenceId, request.field ) ); @@ -221,9 +213,8 @@ public void onResponse(List results) { for (int i = 0; i < results.size(); i++) { var request = requests.get(i); var result = results.get(i); - inferenceResults.get(request.id).responses.add( - new FieldInferenceResponse(request.field, inferenceProvider.model, result) - ); + 
var acc = inferenceResults.get(request.id);
+                        acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result));
                     }
                 }
@@ -254,38 +245,34 @@ public void onFailure(Exception exc) {
         }

         /**
-         * Apply the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}.
+         * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}.
          * If the response contains failures, the bulk item request is marked as failed for the downstream action.
          * Otherwise, the source of the request is augmented with the field inference results.
          */
-        private void applyInference(BulkItemRequest request, FieldInferenceResponseAccumulator inferenceResponse) {
-            Exception failure = inferenceResponse.createFailureOrNull();
-            if (failure != null) {
-                request.abort(bulkShardRequest.index(), failure);
+        private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) {
+            if (response.failures().isEmpty() == false) {
+                for (var failure : response.failures()) {
+                    item.abort(item.index(), failure);
+                }
                 return;
             }
-            final IndexRequest indexRequest = getIndexRequestOrNull(request.request());
-            final Map<String, Object> newDocMap = indexRequest.sourceAsMap();
-            final Map<String, Object> inferenceMetadataMap = new LinkedHashMap<>();
-            newDocMap.put(SemanticTextInferenceResultFieldMapper.NAME, inferenceMetadataMap);
-            for (FieldInferenceResponse fieldResponse : inferenceResponse.responses) {
-                List<Map<String, Object>> chunks = new ArrayList<>();
-                if (fieldResponse.chunkedResults instanceof ChunkedSparseEmbeddingResults textExpansionResults) {
-                    for (var chunk : textExpansionResults.getChunkedResults()) {
-                        chunks.add(chunk.asMap());
-                    }
-                } else if (fieldResponse.chunkedResults instanceof ChunkedTextEmbeddingResults textEmbeddingResults) {
-                    for (var chunk : textEmbeddingResults.getChunks()) {
-                        chunks.add(chunk.asMap());
-                    }
-                } else {
-                    request.abort(bulkShardRequest.index(), new IllegalArgumentException("TODO"));
-                    return;
+
+            final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
+            Map<String, Object> newDocMap = indexRequest.sourceAsMap();
+            Map<String, Object> inferenceMap = new LinkedHashMap<>();
+            // ignore the existing inference map if any
+            newDocMap.put(InferenceResultFieldMapper.NAME, inferenceMap);
+            for (FieldInferenceResponse fieldResponse : response.responses()) {
+                try {
+                    InferenceResultFieldMapper.applyFieldInference(
+                        inferenceMap,
+                        fieldResponse.field(),
+                        fieldResponse.model(),
+                        fieldResponse.chunkedResults()
+                    );
+                } catch (Exception exc) {
+                    item.abort(item.index(), exc);
                 }
-                Map<String, Object> fieldMap = new LinkedHashMap<>();
-                fieldMap.putAll(new SemanticTextModelSettings(fieldResponse.model).asMap());
-                fieldMap.put(SemanticTextInferenceResultFieldMapper.INFERENCE_RESULTS, chunks);
-                inferenceMetadataMap.put(fieldResponse.field, fieldMap);
             }
             indexRequest.source(newDocMap);
         }
@@ -294,7 +281,7 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
             Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new LinkedHashMap<>();
             for (var item : bulkShardRequest.items()) {
                 if (item.getPrimaryResponse() != null) {
-                    // item was already aborted/processed by a filter in the chain upstream (e.g. security).
+                    // item was already aborted/processed by a filter in the chain upstream (e.g.
security) continue; } final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); @@ -302,30 +289,38 @@ private Map> createFieldInferenceRequests(Bu continue; } final Map docMap = indexRequest.sourceAsMap(); - List fieldRequests = null; - for (var pair : fieldInferenceMetadata.entrySet()) { - String inferenceId = pair.getKey(); - for (var field : pair.getValue()) { + for (var entry : fieldInferenceMetadata.entrySet()) { + String inferenceId = entry.getKey(); + for (var field : entry.getValue()) { var value = XContentMapValues.extractValue(field, docMap); if (value == null) { continue; } - if (value instanceof String valueStr) { - if (inferenceResults.get(item.id()) == null) { - inferenceResults.set( + if (inferenceResults.get(item.id()) == null) { + inferenceResults.set( + item.id(), + new FieldInferenceResponseAccumulator( item.id(), - new FieldInferenceResponseAccumulator( - item.id(), - Collections.synchronizedList(new ArrayList<>()), - Collections.synchronizedList(new ArrayList<>()) - ) - ); - } - if (fieldRequests == null) { - fieldRequests = new ArrayList<>(); - fieldRequestsMap.put(inferenceId, fieldRequests); - } + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ) + ); + } + if (value instanceof String valueStr) { + List fieldRequests = fieldRequestsMap.computeIfAbsent( + inferenceId, + k -> new ArrayList<>() + ); fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + } else { + inferenceResults.get(item.id()).failures.add( + new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + value.getClass().getSimpleName() + ) + ); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java similarity index 86% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java index cee6395185060..2ede5419ab74e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java @@ -8,6 +8,8 @@ package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.DocumentParserContext; @@ -27,15 +29,24 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentParser; +import 
org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -51,7 +62,7 @@ * { * "_source": { * "my_semantic_text_field": "these are not the droids you're looking for", - * "_semantic_text_inference": { + * "_inference": { * "my_semantic_text_field": [ * { * "sparse_embedding": { @@ -94,17 +105,17 @@ * } * */ -public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { +public class InferenceResultFieldMapper extends MetadataFieldMapper { public static final String NAME = "_inference"; public static final String CONTENT_TYPE = "_inference"; - public static final String INFERENCE_RESULTS = "results"; + public static final String RESULTS = "results"; public static final String INFERENCE_CHUNKS_RESULTS = "inference"; public static final String INFERENCE_CHUNKS_TEXT = "text"; - public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); + public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceResultFieldMapper()); - private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); + private static final Logger logger = LogManager.getLogger(InferenceResultFieldMapper.class); private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); @@ -131,7 +142,7 @@ public Query termQuery(Object value, SearchExecutionContext context) { } } - private SemanticTextInferenceResultFieldMapper() { + public InferenceResultFieldMapper() { super(SemanticTextInferenceFieldType.INSTANCE); } @@ -172,7 +183,7 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); String currentName = parser.currentName(); - if (INFERENCE_RESULTS.equals(currentName)) { + if (RESULTS.equals(currentName)) { NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( context, mapperBuilderContext, @@ -328,4 +339,34 @@ protected String contentType() { public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { return SourceLoader.SyntheticFieldLoader.NOTHING; } + + public static void applyFieldInference( + Map inferenceMap, + String field, + Model model, + ChunkedInferenceServiceResults results + ) throws ElasticsearchException { + List> chunks = new ArrayList<>(); + if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add(chunk.asMap()); + } + } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(chunk.asMap()); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + model.getInferenceEntityId(), + results.getWriteableName() + ); + } + Map fieldMap = new LinkedHashMap<>(); + fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); + fieldMap.put(InferenceResultFieldMapper.RESULTS, chunks); + inferenceMap.put(field, fieldMap); + } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 027b85a9a9f45..4caa3d68ba877 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -30,7 +30,7 @@ * at ingestion and query time. * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well. * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will - * be indexed using {@link SemanticTextInferenceResultFieldMapper}. + * be indexed using {@link InferenceResultFieldMapper}. */ public class SemanticTextFieldMapper extends FieldMapper { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java new file mode 100644 index 0000000000000..7f3ffbe596543 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -0,0 +1,340 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action.filter; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.bulk.BulkItemRequest; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.TransportShardBulkAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.ActionFilterChain; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.model.TestModel; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.After; +import org.junit.Before; +import org.mockito.stubbing.Answer; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import 
java.util.Optional; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomSparseEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomTextEmbeddings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ShardBulkInferenceActionFilterTests extends ESTestCase { + private ThreadPool threadPool; + + @Before + public void setupThreadPool() { + threadPool = new TestThreadPool(getTestName()); + } + + @After + public void tearDownThreadPool() throws Exception { + terminate(threadPool); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testFilterNoop() throws Exception { + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of()); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertNull(((BulkShardRequest) request).getFieldsInferenceMetadata()); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest request = new BulkShardRequest( + new ShardId("test", "test", 0), + WriteRequest.RefreshPolicy.NONE, + new BulkItemRequest[0] + ); + request.setFieldInferenceMetadata(Map.of("foo", Set.of("bar"))); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testInferenceNotFound() throws Exception { + StaticModel model = randomStaticModel(); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(model.getInferenceEntityId(), model)); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getFieldsInferenceMetadata()); + for (BulkItemRequest item : bulkShardRequest.items()) { + assertNotNull(item.getPrimaryResponse()); + assertTrue(item.getPrimaryResponse().isFailed()); + BulkItemResponse.Failure failure = item.getPrimaryResponse().getFailure(); + assertThat(failure.getStatus(), equalTo(RestStatus.NOT_FOUND)); + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map> inferenceFields = Map.of( + model.getInferenceEntityId(), + Set.of("field1"), + "inference_0", + Set.of("field2", "field3") + ); + BulkItemRequest[] items = new BulkItemRequest[10]; + for (int i = 0; i < items.length; i++) { + items[i] = randomBulkItemRequest(i, Map.of(), inferenceFields)[0]; + } + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setFieldInferenceMetadata(inferenceFields); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, 
actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testManyRandomDocs() throws Exception { + Map inferenceModelMap = new HashMap<>(); + int numModels = randomIntBetween(1, 5); + for (int i = 0; i < numModels; i++) { + StaticModel model = randomStaticModel(); + inferenceModelMap.put(model.getInferenceEntityId(), model); + } + + int numInferenceFields = randomIntBetween(1, 5); + Map> inferenceFields = new HashMap<>(); + for (int i = 0; i < numInferenceFields; i++) { + String inferenceId = randomFrom(inferenceModelMap.keySet()); + String field = randomAlphaOfLengthBetween(5, 10); + var res = inferenceFields.computeIfAbsent(inferenceId, k -> new HashSet<>()); + res.add(field); + } + + int numRequests = randomIntBetween(100, 1000); + BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; + BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; + for (int id = 0; id < numRequests; id++) { + BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, inferenceFields); + originalRequests[id] = res[0]; + modifiedRequests[id] = res[1]; + } + + ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertThat(request, instanceOf(BulkShardRequest.class)); + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getFieldsInferenceMetadata()); + BulkItemRequest[] items = bulkShardRequest.items(); + assertThat(items.length, equalTo(originalRequests.length)); + for (int id = 0; id < items.length; id++) { + IndexRequest actualRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(items[id].request()); + IndexRequest expectedRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(modifiedRequests[id].request()); + try { + assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), actualRequest.getContentType()); + } catch (Exception exc) { + throw new IllegalStateException(exc); + } + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); + original.setFieldInferenceMetadata(inferenceFields); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings("unchecked") + private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap) { + ModelRegistry modelRegistry = mock(ModelRegistry.class); + Answer unparsedModelAnswer = invocationOnMock -> { + String id = (String) invocationOnMock.getArguments()[0]; + ActionListener listener = (ActionListener) invocationOnMock + .getArguments()[1]; + var model = modelMap.get(id); + if (model != null) { + listener.onResponse( + new ModelRegistry.UnparsedModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getServiceSettings().model(), + XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getTaskSettings()), false), + XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getSecretSettings()), false) + ) + ); + } else { + listener.onFailure(new 
ResourceNotFoundException("model id [{}] not found", id)); + } + return null; + }; + doAnswer(unparsedModelAnswer).when(modelRegistry).getModelWithSecrets(any(), any()); + + InferenceService inferenceService = mock(InferenceService.class); + Answer chunkedInferAnswer = invocationOnMock -> { + StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; + List inputs = (List) invocationOnMock.getArguments()[1]; + ActionListener> listener = (ActionListener< + List>) invocationOnMock.getArguments()[5]; + Runnable runnable = () -> { + List results = new ArrayList<>(); + for (String input : inputs) { + results.add(model.getResults(input)); + } + listener.onResponse(results); + }; + if (randomBoolean()) { + try { + threadPool.generic().execute(runnable); + } catch (Exception exc) { + listener.onFailure(exc); + } + } else { + runnable.run(); + } + return null; + }; + doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any()); + + Answer modelAnswer = invocationOnMock -> { + String inferenceId = (String) invocationOnMock.getArguments()[0]; + return modelMap.get(inferenceId); + }; + doAnswer(modelAnswer).when(inferenceService).parsePersistedConfigWithSecrets(any(), any(), any(), any()); + + InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); + when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); + ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry); + return filter; + } + + private static BulkItemRequest[] randomBulkItemRequest( + int id, + Map modelMap, + Map> inferenceFieldMap + ) { + Map docMap = new LinkedHashMap<>(); + Map inferenceResultsMap = new LinkedHashMap<>(); + for (var entry : inferenceFieldMap.entrySet()) { + String inferenceId = entry.getKey(); + var model = modelMap.get(inferenceId); + for (var field : entry.getValue()) { + String text = randomAlphaOfLengthBetween(10, 100); + docMap.put(field, text); + if (model == null) { + // ignore results, the doc should fail with a resource not found exception + continue; + } + int numChunks = randomIntBetween(1, 5); + List chunks = new ArrayList<>(); + for (int i = 0; i < numChunks; i++) { + chunks.add(randomAlphaOfLengthBetween(5, 10)); + } + TaskType taskType = model.getTaskType(); + final ChunkedInferenceServiceResults results; + switch (taskType) { + case TEXT_EMBEDDING: + results = randomTextEmbeddings(chunks); + break; + + case SPARSE_EMBEDDING: + results = randomSparseEmbeddings(chunks); + break; + + default: + throw new AssertionError("Unknown task type " + taskType.name()); + } + model.putResult(text, results); + InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); + } + } + Map expectedDocMap = new LinkedHashMap<>(docMap); + expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap); + return new BulkItemRequest[] { + new BulkItemRequest(id, new IndexRequest("index").source(docMap)), + new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap)) }; + } + + private static StaticModel randomStaticModel() { + String serviceName = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomAlphaOfLengthBetween(5, 10); + return new StaticModel( + inferenceId, + randomBoolean() ? 
TaskType.TEXT_EMBEDDING : TaskType.SPARSE_EMBEDDING, + serviceName, + new TestModel.TestServiceSettings("my-model"), + new TestModel.TestTaskSettings(randomIntBetween(1, 100)), + new TestModel.TestSecretSettings(randomAlphaOfLength(10)) + ); + } + + private static class StaticModel extends TestModel { + private final Map resultMap; + + StaticModel( + String inferenceEntityId, + TaskType taskType, + String service, + TestServiceSettings serviceSettings, + TestTaskSettings taskSettings, + TestSecretSettings secretSettings + ) { + super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, secretSettings); + this.resultMap = new HashMap<>(); + } + + ChunkedInferenceServiceResults getResults(String text) { + return resultMap.getOrDefault(text, new ChunkedSparseEmbeddingResults(List.of())); + } + + void putResult(String text, ChunkedInferenceServiceResults results) { + resultMap.put(text, results); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java similarity index 79% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java index 5dc245298838f..b5d75b528c6ab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java @@ -31,48 +31,46 @@ import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.LeafNestedDocuments; import org.elasticsearch.search.NestedDocuments; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.model.TestModel; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_RESULTS; +import 
static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.RESULTS; import static org.hamcrest.Matchers.containsString; -public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase { - private record SemanticTextInferenceResults(String fieldName, SparseEmbeddingResults sparseEmbeddingResults, List text) { - private SemanticTextInferenceResults { - if (sparseEmbeddingResults.embeddings().size() != text.size()) { - throw new IllegalArgumentException("Sparse embeddings and text must be the same size"); - } - } - } +public class InferenceResultFieldMapperTests extends MetadataMapperTestCase { + private record SemanticTextInferenceResults(String fieldName, ChunkedInferenceServiceResults results, List text) {} - private record VisitedChildDocInfo(String path, int sparseVectorDims) {} + private record VisitedChildDocInfo(String path, int numChunks) {} private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {} @Override protected String fieldName() { - return SemanticTextInferenceResultFieldMapper.NAME; + return InferenceResultFieldMapper.NAME; } @Override @@ -108,8 +106,8 @@ public void testSuccessfulParse() throws IOException { b -> addSemanticTextInferenceResults( b, List.of( - generateSemanticTextinferenceResults(fieldName1, List.of("a b", "c")), - generateSemanticTextinferenceResults(fieldName2, List.of("d e f")) + randomSemanticTextInferenceResults(fieldName1, List.of("a b", "c")), + randomSemanticTextInferenceResults(fieldName2, List.of("d e f")) ) ) ) @@ -208,10 +206,10 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), true, - null + Map.of() ) ) ) @@ -226,10 +224,10 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), new SparseVectorSubfieldOptions(true, true, true), false, - null + Map.of() ) ) ) @@ -244,10 +242,10 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), false, - null + Map.of() ) ) ) @@ -262,7 +260,7 @@ public void testMissingSubfields() throws IOException { public void testExtraSubfields() throws IOException { final String fieldName = randomAlphaOfLengthBetween(5, 15); final List semanticTextInferenceResultsList = List.of( - generateSemanticTextinferenceResults(fieldName, List.of("a b")) + randomSemanticTextInferenceResults(fieldName, List.of("a b")) ); DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); @@ -360,7 +358,7 @@ public void testMissingSemanticTextMapping() throws IOException { DocumentParsingException.class, DocumentParsingException.class, () -> 
documentMapper.parse( - source(b -> addSemanticTextInferenceResults(b, List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))))) + source(b -> addSemanticTextInferenceResults(b, List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))))) ) ); assertThat( @@ -378,18 +376,32 @@ private static void addSemanticTextMapping(XContentBuilder mappingBuilder, Strin mappingBuilder.endObject(); } - private static SemanticTextInferenceResults generateSemanticTextinferenceResults(String semanticTextFieldName, List chunks) { - List embeddings = new ArrayList<>(chunks.size()); - for (String chunk : chunks) { - String[] tokens = chunk.split("\\s+"); - List weightedTokens = Arrays.stream(tokens) - .map(t -> new SparseEmbeddingResults.WeightedToken(t, randomFloat())) - .toList(); + public static ChunkedTextEmbeddingResults randomTextEmbeddings(List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + double[] values = new double[5]; + for (int j = 0; j < values.length; j++) { + values[j] = randomDouble(); + } + chunks.add(new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, values)); + } + return new ChunkedTextEmbeddingResults(chunks); + } - embeddings.add(new SparseEmbeddingResults.Embedding(weightedTokens, false)); + public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + var tokens = new ArrayList(); + for (var token : input.split("\\s+")) { + tokens.add(new TextExpansionResults.WeightedToken(token, randomFloat())); + } + chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input, tokens)); } + return new ChunkedSparseEmbeddingResults(chunks); + } - return new SemanticTextInferenceResults(semanticTextFieldName, new SparseEmbeddingResults(embeddings), chunks); + private static SemanticTextInferenceResults randomSemanticTextInferenceResults(String semanticTextFieldName, List chunks) { + return new SemanticTextInferenceResults(semanticTextFieldName, randomSparseEmbeddings(chunks), chunks); } private static void addSemanticTextInferenceResults( @@ -401,10 +413,11 @@ private static void addSemanticTextInferenceResults( semanticTextInferenceResults, new SparseVectorSubfieldOptions(true, true, true), true, - null + Map.of() ); } + @SuppressWarnings("unchecked") private static void addSemanticTextInferenceResults( XContentBuilder sourceBuilder, List semanticTextInferenceResults, @@ -412,48 +425,39 @@ private static void addSemanticTextInferenceResults( boolean includeTextSubfield, Map extraSubfields ) throws IOException { - - Map> inferenceResultsMap = new HashMap<>(); + Map inferenceResultsMap = new HashMap<>(); for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { - Map fieldMap = new HashMap<>(); - fieldMap.put(SemanticTextModelSettings.NAME, modelSettingsMap()); - List> parsedInferenceResults = new ArrayList<>(semanticTextInferenceResult.text().size()); - - Iterator embeddingsIterator = semanticTextInferenceResult.sparseEmbeddingResults() - .embeddings() - .iterator(); - Iterator textIterator = semanticTextInferenceResult.text().iterator(); - while (embeddingsIterator.hasNext() && textIterator.hasNext()) { - SparseEmbeddingResults.Embedding embedding = embeddingsIterator.next(); - String text = textIterator.next(); - - Map subfieldMap = new HashMap<>(); - if (sparseVectorSubfieldOptions.include()) { - subfieldMap.put(INFERENCE_CHUNKS_RESULTS, 
embedding.asMap().get(SparseEmbeddingResults.Embedding.EMBEDDING)); - } - if (includeTextSubfield) { - subfieldMap.put(INFERENCE_CHUNKS_TEXT, text); + InferenceResultFieldMapper.applyFieldInference( + inferenceResultsMap, + semanticTextInferenceResult.fieldName, + randomModel(), + semanticTextInferenceResult.results + ); + Map optionsMap = (Map) inferenceResultsMap.get(semanticTextInferenceResult.fieldName); + List> fieldResultList = (List>) optionsMap.get(RESULTS); + for (var entry : fieldResultList) { + if (includeTextSubfield == false) { + entry.remove(INFERENCE_CHUNKS_TEXT); } - if (extraSubfields != null) { - subfieldMap.putAll(extraSubfields); + if (sparseVectorSubfieldOptions.include == false) { + entry.remove(INFERENCE_CHUNKS_RESULTS); } - - parsedInferenceResults.add(subfieldMap); + entry.putAll(extraSubfields); } - - fieldMap.put(INFERENCE_RESULTS, parsedInferenceResults); - inferenceResultsMap.put(semanticTextInferenceResult.fieldName(), fieldMap); } - - sourceBuilder.field(SemanticTextInferenceResultFieldMapper.NAME, inferenceResultsMap); + sourceBuilder.field(InferenceResultFieldMapper.NAME, inferenceResultsMap); } - private static Map modelSettingsMap() { - return Map.of( - SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(), - TaskType.SPARSE_EMBEDDING.toString(), - SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(), - randomAlphaOfLength(8) + private static Model randomModel() { + String serviceName = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomAlphaOfLengthBetween(5, 10); + return new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + serviceName, + new TestModel.TestServiceSettings("my-model"), + new TestModel.TestTaskSettings(randomIntBetween(1, 100)), + new TestModel.TestSecretSettings(randomAlphaOfLength(10)) ); } From c5de0da930dbc329f2000fae7475fe1d90b82488 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 19 Mar 2024 14:22:05 +0100 Subject: [PATCH 04/13] Merge from feature branch --- .../action/bulk/BulkOperation.java | 4 +- .../action/bulk/BulkShardRequest.java | 18 ++-- .../ShardBulkInferenceActionFilter.java | 70 +++++++------- .../ShardBulkInferenceActionFilterTests.java | 96 ++++++++++--------- 4 files changed, 94 insertions(+), 94 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index b7a6387045e3d..452a9ec90443a 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -209,8 +209,8 @@ private void executeBulkRequestsByShard(Map> requ requests.toArray(new BulkItemRequest[0]) ); var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName()); - if (indexMetadata != null && indexMetadata.getFieldsForModels().isEmpty() == false) { - bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldsForModels()); + if (indexMetadata != null && indexMetadata.getFieldInferenceMetadata().isEmpty() == false) { + bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldInferenceMetadata()); } bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); bulkShardRequest.timeout(bulkRequest.timeout()); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index 1b5494c6a68f5..39fa791a3e27d 100644 --- 
a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.support.replication.ReplicatedWriteRequest; import org.elasticsearch.action.support.replication.ReplicationRequest; import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.set.Sets; @@ -22,7 +23,6 @@ import org.elasticsearch.transport.RawIndexingDataTransportRequest; import java.io.IOException; -import java.util.Map; import java.util.Set; public final class BulkShardRequest extends ReplicatedWriteRequest @@ -34,7 +34,7 @@ public final class BulkShardRequest extends ReplicatedWriteRequest> fieldsInferenceMetadata = null; + private transient FieldInferenceMetadata fieldsInferenceMetadataMap = null; public BulkShardRequest(StreamInput in) throws IOException { super(in); @@ -51,24 +51,24 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe * Public for test * Set the transient metadata indicating that this request requires running inference before proceeding. */ - public void setFieldInferenceMetadata(Map> fieldsInferenceMetadata) { - this.fieldsInferenceMetadata = fieldsInferenceMetadata; + public void setFieldInferenceMetadata(FieldInferenceMetadata fieldsInferenceMetadata) { + this.fieldsInferenceMetadataMap = fieldsInferenceMetadata; } /** * Consumes the inference metadata to execute inference on the bulk items just once. */ - public Map> consumeFieldInferenceMetadata() { - var ret = fieldsInferenceMetadata; - fieldsInferenceMetadata = null; + public FieldInferenceMetadata consumeFieldInferenceMetadata() { + FieldInferenceMetadata ret = fieldsInferenceMetadataMap; + fieldsInferenceMetadataMap = null; return ret; } /** * Public for test */ - public Map> getFieldsInferenceMetadata() { - return fieldsInferenceMetadata; + public FieldInferenceMetadata getFieldsInferenceMetadataMap() { + return fieldsInferenceMetadataMap; } public long totalSizeInBytes() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index e679d3c970abf..984a20419b2c8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -24,6 +24,7 @@ import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; @@ -44,7 +45,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; /** @@ -81,7 +81,7 @@ public void app case TransportShardBulkAction.ACTION_NAME: BulkShardRequest bulkShardRequest = (BulkShardRequest) request; var fieldInferenceMetadata = 
bulkShardRequest.consumeFieldInferenceMetadata(); - if (fieldInferenceMetadata != null && fieldInferenceMetadata.size() > 0) { + if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) { Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); } else { @@ -96,7 +96,7 @@ public void app } private void processBulkShardRequest( - Map> fieldInferenceMetadata, + FieldInferenceMetadata fieldInferenceMetadata, BulkShardRequest bulkShardRequest, Runnable onCompletion ) { @@ -112,13 +112,13 @@ private record FieldInferenceResponse(String field, Model model, ChunkedInferenc private record FieldInferenceResponseAccumulator(int id, List responses, List failures) {} private class AsyncBulkShardInferenceAction implements Runnable { - private final Map> fieldInferenceMetadata; + private final FieldInferenceMetadata fieldInferenceMetadata; private final BulkShardRequest bulkShardRequest; private final Runnable onCompletion; private final AtomicArray inferenceResults; private AsyncBulkShardInferenceAction( - Map> fieldInferenceMetadata, + FieldInferenceMetadata fieldInferenceMetadata, BulkShardRequest bulkShardRequest, Runnable onCompletion ) { @@ -289,39 +289,35 @@ private Map> createFieldInferenceRequests(Bu continue; } final Map docMap = indexRequest.sourceAsMap(); - for (var entry : fieldInferenceMetadata.entrySet()) { - String inferenceId = entry.getKey(); - for (var field : entry.getValue()) { - var value = XContentMapValues.extractValue(field, docMap); - if (value == null) { - continue; - } - if (inferenceResults.get(item.id()) == null) { - inferenceResults.set( + for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { + String field = entry.getKey(); + String inferenceId = entry.getValue().inferenceId(); + var value = XContentMapValues.extractValue(field, docMap); + if (value == null) { + continue; + } + if (inferenceResults.get(item.id()) == null) { + inferenceResults.set( + item.id(), + new FieldInferenceResponseAccumulator( item.id(), - new FieldInferenceResponseAccumulator( - item.id(), - Collections.synchronizedList(new ArrayList<>()), - Collections.synchronizedList(new ArrayList<>()) - ) - ); - } - if (value instanceof String valueStr) { - List fieldRequests = fieldRequestsMap.computeIfAbsent( - inferenceId, - k -> new ArrayList<>() - ); - fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); - } else { - inferenceResults.get(item.id()).failures.add( - new ElasticsearchStatusException( - "Invalid format for field [{}], expected [String] got [{}]", - RestStatus.BAD_REQUEST, - field, - value.getClass().getSimpleName() - ) - ); - } + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ) + ); + } + if (value instanceof String valueStr) { + List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); + fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + } else { + inferenceResults.get(item.id()).failures.add( + new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + value.getClass().getSimpleName() + ) + ); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 7f3ffbe596543..4a1825303b5a7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.shard.ShardId; @@ -40,7 +41,6 @@ import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -79,7 +79,7 @@ public void testFilterNoop() throws Exception { CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { - assertNull(((BulkShardRequest) request).getFieldsInferenceMetadata()); + assertNull(((BulkShardRequest) request).getFieldsInferenceMetadataMap()); } finally { chainExecuted.countDown(); } @@ -91,7 +91,9 @@ public void testFilterNoop() throws Exception { WriteRequest.RefreshPolicy.NONE, new BulkItemRequest[0] ); - request.setFieldInferenceMetadata(Map.of("foo", Set.of("bar"))); + request.setFieldInferenceMetadata( + new FieldInferenceMetadata(Map.of("foo", new FieldInferenceMetadata.FieldInferenceOptions("bar", Set.of()))) + ); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } @@ -104,7 +106,7 @@ public void testInferenceNotFound() throws Exception { ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getFieldsInferenceMetadata()); + assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); for (BulkItemRequest item : bulkShardRequest.items()) { assertNotNull(item.getPrimaryResponse()); assertTrue(item.getPrimaryResponse().isFailed()); @@ -118,11 +120,15 @@ public void testInferenceNotFound() throws Exception { ActionListener actionListener = mock(ActionListener.class); Task task = mock(Task.class); - Map> inferenceFields = Map.of( - model.getInferenceEntityId(), - Set.of("field1"), - "inference_0", - Set.of("field2", "field3") + FieldInferenceMetadata inferenceFields = new FieldInferenceMetadata( + Map.of( + "field1", + new FieldInferenceMetadata.FieldInferenceOptions(model.getInferenceEntityId(), Set.of()), + "field2", + new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()), + "field3", + new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()) + ) ); BulkItemRequest[] items = new BulkItemRequest[10]; for (int i = 0; i < items.length; i++) { @@ -144,19 +150,19 @@ public void testManyRandomDocs() throws Exception { } int numInferenceFields = randomIntBetween(1, 5); - Map> inferenceFields = new HashMap<>(); + Map inferenceFieldsMap = new HashMap<>(); for (int i = 0; i < numInferenceFields; i++) { - String inferenceId = randomFrom(inferenceModelMap.keySet()); String field = randomAlphaOfLengthBetween(5, 10); - var res = 
inferenceFields.computeIfAbsent(inferenceId, k -> new HashSet<>()); - res.add(field); + String inferenceId = randomFrom(inferenceModelMap.keySet()); + inferenceFieldsMap.put(field, new FieldInferenceMetadata.FieldInferenceOptions(inferenceId, Set.of())); } + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(inferenceFieldsMap); int numRequests = randomIntBetween(100, 1000); BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; for (int id = 0; id < numRequests; id++) { - BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, inferenceFields); + BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, fieldInferenceMetadata); originalRequests[id] = res[0]; modifiedRequests[id] = res[1]; } @@ -167,7 +173,7 @@ public void testManyRandomDocs() throws Exception { try { assertThat(request, instanceOf(BulkShardRequest.class)); BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getFieldsInferenceMetadata()); + assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); BulkItemRequest[] items = bulkShardRequest.items(); assertThat(items.length, equalTo(originalRequests.length)); for (int id = 0; id < items.length; id++) { @@ -186,7 +192,7 @@ public void testManyRandomDocs() throws Exception { ActionListener actionListener = mock(ActionListener.class); Task task = mock(Task.class); BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); - original.setFieldInferenceMetadata(inferenceFields); + original.setFieldInferenceMetadata(fieldInferenceMetadata); filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } @@ -257,42 +263,40 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool private static BulkItemRequest[] randomBulkItemRequest( int id, Map modelMap, - Map> inferenceFieldMap + FieldInferenceMetadata fieldInferenceMetadata ) { Map docMap = new LinkedHashMap<>(); Map inferenceResultsMap = new LinkedHashMap<>(); - for (var entry : inferenceFieldMap.entrySet()) { - String inferenceId = entry.getKey(); - var model = modelMap.get(inferenceId); - for (var field : entry.getValue()) { - String text = randomAlphaOfLengthBetween(10, 100); - docMap.put(field, text); - if (model == null) { - // ignore results, the doc should fail with a resource not found exception - continue; - } - int numChunks = randomIntBetween(1, 5); - List chunks = new ArrayList<>(); - for (int i = 0; i < numChunks; i++) { - chunks.add(randomAlphaOfLengthBetween(5, 10)); - } - TaskType taskType = model.getTaskType(); - final ChunkedInferenceServiceResults results; - switch (taskType) { - case TEXT_EMBEDDING: - results = randomTextEmbeddings(chunks); - break; + for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { + String field = entry.getKey(); + var model = modelMap.get(entry.getValue().inferenceId()); + String text = randomAlphaOfLengthBetween(10, 100); + docMap.put(field, text); + if (model == null) { + // ignore results, the doc should fail with a resource not found exception + continue; + } + int numChunks = randomIntBetween(1, 5); + List chunks = new ArrayList<>(); + for (int i = 0; i < numChunks; i++) { + chunks.add(randomAlphaOfLengthBetween(5, 10)); + } + TaskType taskType = model.getTaskType(); + final 
ChunkedInferenceServiceResults results; + switch (taskType) { + case TEXT_EMBEDDING: + results = randomTextEmbeddings(chunks); + break; - case SPARSE_EMBEDDING: - results = randomSparseEmbeddings(chunks); - break; + case SPARSE_EMBEDDING: + results = randomSparseEmbeddings(chunks); + break; - default: - throw new AssertionError("Unknown task type " + taskType.name()); - } - model.putResult(text, results); - InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); + default: + throw new AssertionError("Unknown task type " + taskType.name()); } + model.putResult(text, results); + InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); } Map expectedDocMap = new LinkedHashMap<>(docMap); expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap); From b2b863579d0c301f3d781f503d7b35996b449e69 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Tue, 19 Mar 2024 18:44:29 +0000 Subject: [PATCH 05/13] Refactor the semantic_text field so that it can registers all the sub-fields in the mapping --- .../index/mapper/FieldMapper.java | 8 +- .../vectors/SparseVectorFieldMapper.java | 7 +- .../xpack/inference/InferencePlugin.java | 4 +- .../ShardBulkInferenceActionFilter.java | 8 +- .../mapper/InferenceMetadataFieldMapper.java | 385 ++++++++++++++++++ .../mapper/InferenceResultFieldMapper.java | 372 ----------------- .../mapper/SemanticTextFieldMapper.java | 197 ++++++++- .../mapper/SemanticTextModelSettings.java | 45 +- .../ShardBulkInferenceActionFilterTests.java | 10 +- ...=> InferenceMetadataFieldMapperTests.java} | 309 +++++++------- .../mapper/SemanticTextFieldMapperTests.java | 6 + .../20_semantic_text_field_mapper.yml | 4 +- 12 files changed, 815 insertions(+), 540 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/{InferenceResultFieldMapperTests.java => InferenceMetadataFieldMapperTests.java} (66%) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java index 71fd9edd49903..f9354025cab49 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java @@ -1176,7 +1176,7 @@ public static final class Conflicts { private final String mapperName; private final List conflicts = new ArrayList<>(); - Conflicts(String mapperName) { + public Conflicts(String mapperName) { this.mapperName = mapperName; } @@ -1188,7 +1188,11 @@ void addConflict(String parameter, String existing, String toMerge) { conflicts.add("Cannot update parameter [" + parameter + "] from [" + existing + "] to [" + toMerge + "]"); } - void check() { + public boolean hasConflicts() { + return conflicts.isEmpty() == false; + } + + public void check() { if (conflicts.isEmpty()) { return; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java index 6532abed19044..58286d34dada1 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java +++ 
b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java @@ -171,9 +171,12 @@ public void parse(DocumentParserContext context) throws IOException { } String feature = null; + boolean origIsWithLeafObject = context.path().isWithinLeafObject(); try { // make sure that we don't expand dots in field names while parsing - context.path().setWithinLeafObject(true); + if (context.path().isWithinLeafObject() == false) { + context.path().setWithinLeafObject(true); + } for (Token token = context.parser().nextToken(); token != Token.END_OBJECT; token = context.parser().nextToken()) { if (token == Token.FIELD_NAME) { feature = context.parser().currentName(); @@ -207,7 +210,7 @@ public void parse(DocumentParserContext context) throws IOException { context.addToFieldNames(fieldType().name()); } } finally { - context.path().setWithinLeafObject(false); + context.path().setWithinLeafObject(origIsWithLeafObject); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 994207766f2a6..24c1950be1915 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -55,7 +55,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; @@ -285,7 +285,7 @@ public Map getMappers() { @Override public Map getMetadataMappers() { - return Map.of(InferenceResultFieldMapper.NAME, InferenceResultFieldMapper.PARSER); + return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index e679d3c970abf..47fae274095e4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -36,7 +36,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; @@ -49,7 +49,7 @@ /** * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in - * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceResultFieldMapper} + * the individual {@link BulkItemRequest}. 
The results are then consumed by the {@link InferenceMetadataFieldMapper} * in the subsequent {@link TransportShardBulkAction} downstream. */ public class ShardBulkInferenceActionFilter implements ActionFilter { @@ -261,10 +261,10 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons Map newDocMap = indexRequest.sourceAsMap(); Map inferenceMap = new LinkedHashMap<>(); // ignore the existing inference map if any - newDocMap.put(InferenceResultFieldMapper.NAME, inferenceMap); + newDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap); for (FieldInferenceResponse fieldResponse : response.responses()) { try { - InferenceResultFieldMapper.applyFieldInference( + InferenceMetadataFieldMapper.applyFieldInference( inferenceMap, fieldResponse.field(), fieldResponse.model(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java new file mode 100644 index 0000000000000..831509288696f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -0,0 +1,385 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.apache.lucene.search.Query; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.mapper.DocumentParserContext; +import org.elasticsearch.index.mapper.DocumentParsingException; +import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.MetadataFieldMapper; +import org.elasticsearch.index.mapper.NestedObjectMapper; +import org.elasticsearch.index.mapper.ObjectMapper; +import org.elasticsearch.index.mapper.SourceLoader; +import org.elasticsearch.index.mapper.SourceValueFetcher; +import org.elasticsearch.index.mapper.TextSearchInfo; +import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.XContentLocation; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; + +import java.io.IOException; +import java.util.ArrayList; +import 
java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * A mapper for the {@code _inference} field. + *
+ * <br>
+ * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results.
+ * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so:
+ * <br>
+ * <br>
+ * <pre>
+ * {
+ *     "_source": {
+ *         "my_semantic_text_field": "these are not the droids you're looking for",
+ *         "_inference": {
+ *             "my_semantic_text_field": {
+ *                  "model_settings": {
+ *                      "inference_id": "my_inference_id",
+ *                      "task_type": "SPARSE_EMBEDDING"
+ *                  },
+ *                  "results": [
+ *                      {
+ *                          "inference": {
+ *                              "lucas": 0.05212344,
+ *                              "ty": 0.041213956,
+ *                              "dragon": 0.50991,
+ *                              "type": 0.23241979,
+ *                              "dr": 1.9312073,
+ *                              "##o": 0.2797593
+ *                          },
+ *                          "text": "these are not the droids you're looking for"
+ *                      }
+ *                  ]
+ *              }
+ *          }
+ *      }
+ * }
+ * </pre>
+ *
+ * This mapper parses the contents of the {@code _inference} field and indexes it as if the mapping were configured like so:
+ * <br>
+ * <br>
+ * <pre>
+ * {
+ *     "mappings": {
+ *         "properties": {
+ *             "my_semantic_text_field": {
+ *                 "type": "nested",
+ *                 "properties": {
+ *                     "sparse_embedding": {
+ *                         "type": "sparse_vector"
+ *                     },
+ *                     "text": {
+ *                         "type": "text",
+ *                         "index": false
+ *                     }
+ *                 }
+ *             }
+ *         }
+ *     }
+ * }
+ * </pre>
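+ *
+ * For illustration only, a minimal sketch of how an upstream component can attach this metadata to a document
+ * source via {@link #applyFieldInference}; the {@code sourceMap}, {@code model} and {@code results} values are
+ * hypothetical placeholders standing in for the caller's state:
+ * <pre>
+ * Map&lt;String, Object&gt; inferenceMap = new LinkedHashMap&lt;&gt;();
+ * InferenceMetadataFieldMapper.applyFieldInference(inferenceMap, "my_semantic_text_field", model, results);
+ * sourceMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap);
+ * </pre>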
+ */ +public class InferenceMetadataFieldMapper extends MetadataFieldMapper { + public static final String NAME = "_inference"; + public static final String CONTENT_TYPE = "_inference"; + + public static final String RESULTS = "results"; + public static final String INFERENCE_CHUNKS_RESULTS = "inference"; + public static final String INFERENCE_CHUNKS_TEXT = "text"; + + public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceMetadataFieldMapper()); + + private static final Logger logger = LogManager.getLogger(InferenceMetadataFieldMapper.class); + + private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); + + static class SemanticTextInferenceFieldType extends MappedFieldType { + private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); + + SemanticTextInferenceFieldType() { + super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.identity(name(), context, format); + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + return null; + } + } + + public InferenceMetadataFieldMapper() { + super(SemanticTextInferenceFieldType.INSTANCE); + } + + @Override + protected void parseCreateField(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); + MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); + boolean origWithLeafObject = context.path().isWithinLeafObject(); + try { + // make sure that we don't expand dots in field names while parsing + context.path().setWithinLeafObject(true); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.FIELD_NAME); + parseSingleField(context, mapperBuilderContext); + } + } finally { + context.path().setWithinLeafObject(origWithLeafObject); + } + } + + private SemanticTextFieldMapper updateSemanticTextFieldMapper( + DocumentParserContext docContext, + MapperBuilderContext mapperBuilderContext, + SemanticTextFieldMapper original, + SemanticTextModelSettings modelSettings, + XContentLocation xContentLocation + ) { + if (modelSettings.inferenceId().equals(original.fieldType().getInferenceModel()) == false) { + throw new DocumentParsingException( + xContentLocation, + "Model settings for field [" + + original.fieldType().name() + + "] is already set to [" + + original.fieldType().getInferenceModel() + + "], got [" + + modelSettings.inferenceId() + + "]" + ); + } + if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING && modelSettings.dimensions() == null) { + throw new DocumentParsingException( + xContentLocation, + "Model settings for field [" + original.fieldType().name() + "] must contain dimensions" + ); + } + + if (original.getModelSettings() == null) { + SemanticTextFieldMapper newMapper = new SemanticTextFieldMapper.Builder( + original.name(), + docContext.indexSettings().getIndexVersionCreated(), + docContext.indexAnalyzers() + ).setModelId(modelSettings.inferenceId()).setModelSettings(modelSettings).build(mapperBuilderContext); + docContext.addDynamicMapper(newMapper); + return newMapper; + } else { + 
var conflicts = new Conflicts(original.name()); + SemanticTextModelSettings.checkCompatibility(original.getModelSettings(), modelSettings, conflicts); + try { + conflicts.check(); + } catch (Exception exc) { + throw new DocumentParsingException(xContentLocation, "Failed to update field [" + original.name() + "]", exc); + } + } + return original; + } + + private record FieldMapperAndParent(ObjectMapper parent, Mapper mapper) {} + + private FieldMapperAndParent findFieldMapper(ObjectMapper mapper, String fullName) { + String[] pathElements = fullName.split("\\."); + for (int i = 0; i < pathElements.length - 1; i++) { + Mapper next = mapper.getMapper(pathElements[i]); + if (next == null || next instanceof ObjectMapper == false) { + return null; + } + mapper = (ObjectMapper) next; + } + return new FieldMapperAndParent(mapper, mapper.getMapper(pathElements[pathElements.length - 1])); + } + + @SuppressWarnings("unchecked") + private void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { + XContentParser parser = context.parser(); + String fieldName = parser.currentName(); + var res = findFieldMapper(context.root(), fieldName); + if (res == null || res.mapper == null || res.mapper instanceof SemanticTextFieldMapper == false) { + throw new DocumentParsingException( + parser.getTokenLocation(), + Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) + ); + } + parser.nextToken(); + failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); + XContentLocation xContentLocation = parser.getTokenLocation(); + + Map map = parser.mapOrdered(); + Map modelSettingsMap = (Map) map.remove(SemanticTextModelSettings.NAME); + var modelSettings = SemanticTextModelSettings.parse( + XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, modelSettingsMap) + ); + var fieldMapper = updateSemanticTextFieldMapper( + context, + mapperBuilderContext, + (SemanticTextFieldMapper) res.mapper, + modelSettings, + xContentLocation + ); + XContentParser subParser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + DocumentParserContext mapContext = context.switchParser(subParser); + parseFieldInferenceObject(xContentLocation, subParser, mapContext, fieldMapper.getNestedField()); + } + + private void parseFieldInferenceObject( + XContentLocation xContentLocation, + XContentParser parser, + DocumentParserContext context, + NestedObjectMapper nestedMapper + ) throws IOException { + parser.nextToken(); + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + switch (parser.currentName()) { + case RESULTS -> parseResultsList(xContentLocation, parser, context, nestedMapper); + default -> throw new DocumentParsingException(xContentLocation, "Unknown field name " + parser.currentName()); + } + } + } + + private void parseResultsList( + XContentLocation xContentLocation, + XContentParser parser, + DocumentParserContext context, + NestedObjectMapper nestedMapper + ) throws IOException { + parser.nextToken(); + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_ARRAY); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { + DocumentParserContext 
subContext = context.createNestedContext(nestedMapper); + parseResultsObject(xContentLocation, parser, subContext, nestedMapper); + } + } + + private void parseResultsObject( + XContentLocation xContentLocation, + XContentParser parser, + DocumentParserContext context, + NestedObjectMapper nestedMapper + ) throws IOException { + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT); + Set visited = new HashSet<>(); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.FIELD_NAME); + visited.add(parser.currentName()); + FieldMapper fieldMapper = (FieldMapper) nestedMapper.getMapper(parser.currentName()); + if (fieldMapper == null) { + logger.debug("Skipping indexing of unrecognized field name [" + parser.currentName() + "]"); + advancePastCurrentFieldName(xContentLocation, parser); + continue; + } + parser.nextToken(); + fieldMapper.parse(context); + } + if (visited.containsAll(REQUIRED_SUBFIELDS) == false) { + Set missingSubfields = REQUIRED_SUBFIELDS.stream() + .filter(s -> visited.contains(s) == false) + .collect(Collectors.toSet()); + throw new DocumentParsingException(xContentLocation, "Missing required subfields: " + missingSubfields); + } + } + + private static void failIfTokenIsNot(XContentLocation xContentLocation, XContentParser parser, XContentParser.Token expected) { + if (parser.currentToken() != expected) { + throw new DocumentParsingException(xContentLocation, "Expected a " + expected.toString() + ", got " + parser.currentToken()); + } + } + + private static void advancePastCurrentFieldName(XContentLocation xContentLocation, XContentParser parser) throws IOException { + assert parser.currentToken() == XContentParser.Token.FIELD_NAME; + XContentParser.Token token = parser.nextToken(); + if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { + parser.skipChildren(); + } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { + throw new DocumentParsingException(xContentLocation, "Expected a START_* or VALUE_*, got " + token); + } + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { + return SourceLoader.SyntheticFieldLoader.NOTHING; + } + + public static void applyFieldInference( + Map inferenceMap, + String field, + Model model, + ChunkedInferenceServiceResults results + ) throws ElasticsearchException { + List> chunks = new ArrayList<>(); + if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add(chunk.asMap()); + } + } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(chunk.asMap()); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + model.getInferenceEntityId(), + results.getWriteableName() + ); + } + Map fieldMap = new LinkedHashMap<>(); + fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); + fieldMap.put(InferenceMetadataFieldMapper.RESULTS, chunks); + inferenceMap.put(field, fieldMap); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java deleted file mode 100644 index 2ede5419ab74e..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.search.Query; -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.common.Strings; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.mapper.DocumentParserContext; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MapperBuilderContext; -import org.elasticsearch.index.mapper.MetadataFieldMapper; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ObjectMapper; -import org.elasticsearch.index.mapper.SourceLoader; -import org.elasticsearch.index.mapper.SourceValueFetcher; -import org.elasticsearch.index.mapper.TextFieldMapper; -import org.elasticsearch.index.mapper.TextSearchInfo; -import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * A mapper for the {@code _semantic_text_inference} field. - *
- * <br>
- * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results.
- * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so:
- * <br>
- * <br>
- * <pre>
- * {
- *     "_source": {
- *         "my_semantic_text_field": "these are not the droids you're looking for",
- *         "_inference": {
- *             "my_semantic_text_field": [
- *                 {
- *                     "sparse_embedding": {
- *                          "lucas": 0.05212344,
- *                          "ty": 0.041213956,
- *                          "dragon": 0.50991,
- *                          "type": 0.23241979,
- *                          "dr": 1.9312073,
- *                          "##o": 0.2797593
- *                     },
- *                     "text": "these are not the droids you're looking for"
- *                 }
- *             ]
- *         }
- *     }
- * }
- * </pre>
- *
- * This mapper parses the contents of the {@code _semantic_text_inference} field and indexes it as if the mapping were configured like so:
- * <br>
- * <br>
- * <pre>
- * {
- *     "mappings": {
- *         "properties": {
- *             "my_semantic_text_field": {
- *                 "type": "nested",
- *                 "properties": {
- *                     "sparse_embedding": {
- *                         "type": "sparse_vector"
- *                     },
- *                     "text": {
- *                         "type": "text",
- *                         "index": false
- *                     }
- *                 }
- *             }
- *         }
- *     }
- * }
- * </pre>
- */ -public class InferenceResultFieldMapper extends MetadataFieldMapper { - public static final String NAME = "_inference"; - public static final String CONTENT_TYPE = "_inference"; - - public static final String RESULTS = "results"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceResultFieldMapper()); - - private static final Logger logger = LogManager.getLogger(InferenceResultFieldMapper.class); - - private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); - - static class SemanticTextInferenceFieldType extends MappedFieldType { - private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); - - SemanticTextInferenceFieldType() { - super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.identity(name(), context, format); - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - return null; - } - } - - public InferenceResultFieldMapper() { - super(SemanticTextInferenceFieldType.INSTANCE); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - parseAllFields(context); - } - - private static void parseAllFields(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - parseSingleField(context, mapperBuilderContext); - } - } - - private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { - - XContentParser parser = context.parser(); - String fieldName = parser.currentName(); - Mapper mapper = context.getMapper(fieldName); - if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) - ); - } - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - parser.nextToken(); - SemanticTextModelSettings modelSettings = SemanticTextModelSettings.parse(parser); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - if (RESULTS.equals(currentName)) { - NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( - context, - mapperBuilderContext, - fieldName, - modelSettings - ); - parseFieldInferenceChunks(context, mapperBuilderContext, fieldName, modelSettings, nestedObjectMapper); - } else { - logger.debug("Skipping unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - } - } - } - - private static void 
parseFieldInferenceChunks( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings, - NestedObjectMapper nestedObjectMapper - ) throws IOException { - XContentParser parser = context.parser(); - - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_ARRAY); - - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { - DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper); - parseFieldInferenceChunkElement(nestedContext, nestedObjectMapper, modelSettings); - } - } - - private static void parseFieldInferenceChunkElement( - DocumentParserContext context, - ObjectMapper objectMapper, - SemanticTextModelSettings modelSettings - ) throws IOException { - XContentParser parser = context.parser(); - DocumentParserContext childContext = context.createChildContext(objectMapper); - - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - Set visitedSubfields = new HashSet<>(); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - visitedSubfields.add(currentName); - - Mapper childMapper = objectMapper.getMapper(currentName); - if (childMapper == null) { - logger.debug("Skipping indexing of unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - continue; - } - - if (childMapper instanceof FieldMapper fieldMapper) { - parser.nextToken(); - fieldMapper.parse(childContext); - } else { - // This should never happen, but fail parsing if it does so that it's not a silent failure - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Unhandled mapper type [%s] for field [%s]", childMapper.getClass(), currentName) - ); - } - } - - if (visitedSubfields.containsAll(REQUIRED_SUBFIELDS) == false) { - Set missingSubfields = REQUIRED_SUBFIELDS.stream() - .filter(s -> visitedSubfields.contains(s) == false) - .collect(Collectors.toSet()); - throw new DocumentParsingException(parser.getTokenLocation(), "Missing required subfields: " + missingSubfields); - } - } - - private static NestedObjectMapper createInferenceResultsObjectMapper( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings - ) { - IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated(); - FieldMapper.Builder resultsBuilder; - if (modelSettings.taskType() == TaskType.SPARSE_EMBEDDING) { - resultsBuilder = new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); - } else if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING) { - DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( - INFERENCE_CHUNKS_RESULTS, - indexVersionCreated - ); - SimilarityMeasure similarity = modelSettings.similarity(); - if (similarity != null) { - switch (similarity) { - case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); - case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); - default -> throw new IllegalArgumentException( - "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity - ); - } - } - Integer dimensions = 
modelSettings.dimensions(); - if (dimensions == null) { - throw new IllegalArgumentException("Model settings for field [" + fieldName + "] must contain dimensions"); - } - denseVectorMapperBuilder.dimensions(dimensions); - resultsBuilder = denseVectorMapperBuilder; - } else { - throw new IllegalArgumentException("Unknown task type for field [" + fieldName + "]: " + modelSettings.taskType()); - } - - TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( - INFERENCE_CHUNKS_TEXT, - indexVersionCreated, - context.indexAnalyzers() - ).index(false).store(false); - - NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder( - fieldName, - context.indexSettings().getIndexVersionCreated() - ); - nestedBuilder.add(resultsBuilder).add(textMapperBuilder); - - return nestedBuilder.build(mapperBuilderContext); - } - - private static void advancePastCurrentFieldName(XContentParser parser) throws IOException { - assert parser.currentToken() == XContentParser.Token.FIELD_NAME; - - XContentParser.Token token = parser.nextToken(); - if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { - parser.skipChildren(); - } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_* or VALUE_*, got " + token); - } - } - - private static void failIfTokenIsNot(XContentParser parser, XContentParser.Token expected) { - if (parser.currentToken() != expected) { - throw new DocumentParsingException( - parser.getTokenLocation(), - "Expected a " + expected.toString() + ", got " + parser.currentToken() - ); - } - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - return SourceLoader.SyntheticFieldLoader.NOTHING; - } - - public static void applyFieldInference( - Map inferenceMap, - String field, - Model model, - ChunkedInferenceServiceResults results - ) throws ElasticsearchException { - List> chunks = new ArrayList<>(); - if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add(chunk.asMap()); - } - } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add(chunk.asMap()); - } - } else { - throw new ElasticsearchStatusException( - "Invalid inference results format for field [{}] with inference id [{}], got {}", - RestStatus.BAD_REQUEST, - field, - model.getInferenceEntityId(), - results.getWriteableName() - ); - } - Map fieldMap = new LinkedHashMap<>(); - fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - fieldMap.put(InferenceResultFieldMapper.RESULTS, chunks); - inferenceMap.put(field, fieldMap); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 4caa3d68ba877..deeea81a46d92 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -9,30 +9,54 @@ import org.apache.lucene.search.Query; import org.elasticsearch.common.Strings; +import org.elasticsearch.index.IndexVersion; 
+import org.elasticsearch.index.analysis.IndexAnalyzers; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.index.mapper.InferenceModelFieldType; import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.NestedObjectMapper; +import org.elasticsearch.index.mapper.ObjectMapper; import org.elasticsearch.index.mapper.SimpleMappedFieldType; import org.elasticsearch.index.mapper.SourceValueFetcher; +import org.elasticsearch.index.mapper.TextFieldMapper; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.RESULTS; + /** - * A {@link FieldMapper} for semantic text fields. These fields have a model id reference, that is used for performing inference - * at ingestion and query time. - * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well. + * A {@link FieldMapper} for semantic text fields. + * These fields have a model id reference, that is used for performing inference at ingestion and query time. * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will - * be indexed using {@link InferenceResultFieldMapper}. + * be indexed using {@link InferenceMetadataFieldMapper}. 
*/ public class SemanticTextFieldMapper extends FieldMapper { + private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); public static final String CONTENT_TYPE = "semantic_text"; @@ -40,15 +64,47 @@ private static SemanticTextFieldMapper toType(FieldMapper in) { return (SemanticTextFieldMapper) in; } - public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n), notInMultiFields(CONTENT_TYPE)); + public static final TypeParser PARSER = new TypeParser( + (n, c) -> new Builder(n, c.indexVersionCreated(), c.getIndexAnalyzers()), + notInMultiFields(CONTENT_TYPE) + ); + + private final IndexVersion indexVersionCreated; + private final SemanticTextModelSettings modelSettings; + private final IndexAnalyzers indexAnalyzers; + private final NestedObjectMapper subMappers; - private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) { + private SemanticTextFieldMapper( + String simpleName, + MappedFieldType mappedFieldType, + CopyTo copyTo, + IndexVersion indexVersionCreated, + IndexAnalyzers indexAnalyzers, + SemanticTextModelSettings modelSettings, + NestedObjectMapper subMappers + ) { super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); + this.indexVersionCreated = indexVersionCreated; + this.indexAnalyzers = indexAnalyzers; + this.modelSettings = modelSettings; + this.subMappers = subMappers; + } + + @Override + public String name() { + return super.name(); + } + + @Override + public Iterator iterator() { + List subIterators = new ArrayList<>(); + subIterators.add(subMappers); + return subIterators.iterator(); } @Override public FieldMapper.Builder getMergeBuilder() { - return new Builder(simpleName()).init(this); + return new Builder(simpleName(), indexVersionCreated, indexAnalyzers).init(this); } @Override @@ -67,7 +123,17 @@ public SemanticTextFieldType fieldType() { return (SemanticTextFieldType) super.fieldType(); } + public SemanticTextModelSettings getModelSettings() { + return modelSettings; + } + + public NestedObjectMapper getNestedField() { + return subMappers; + } + public static class Builder extends FieldMapper.Builder { + private final IndexVersion indexVersionCreated; + private final IndexAnalyzers indexAnalyzers; private final Parameter modelId = Parameter.stringParam("model_id", false, m -> toType(m).fieldType().modelId, null) .addValidator(v -> { @@ -76,25 +142,84 @@ public static class Builder extends FieldMapper.Builder { } }); + @SuppressWarnings("unchecked") + private final Parameter modelSettings = new Parameter<>( + "model_settings", + true, + () -> null, + (name, context, node) -> { + if (node == null) { + return null; + } + try { + Map map = (Map) node; + XContentParser parser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + return SemanticTextModelSettings.parse(parser); + } catch (Exception exc) { + throw new IllegalArgumentException(exc); + } + }, + m -> ((SemanticTextFieldMapper) m).modelSettings, + XContentBuilder::field, + Strings::toString + ).acceptsNull().setMergeValidator(SemanticTextModelSettings::checkCompatibility); + private final Parameter> meta = Parameter.metaParam(); - public Builder(String name) { + public Builder(String name, IndexVersion indexVersionCreated, IndexAnalyzers indexAnalyzers) { super(name); + this.indexVersionCreated = indexVersionCreated; + this.indexAnalyzers = indexAnalyzers; + } + + public Builder setModelId(String id) { + 
this.modelId.setValue(id); + return this; + } + + public Builder setModelSettings(SemanticTextModelSettings value) { + this.modelSettings.setValue(value); + return this; } @Override protected Parameter[] getParameters() { - return new Parameter[] { modelId, meta }; + return new Parameter[] { modelId, meta, modelSettings }; } @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { - return new SemanticTextFieldMapper(name(), new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()), copyTo); + final String fullName = context.buildFullName(name()); + NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(RESULTS, indexVersionCreated); + nestedBuilder.dynamic(ObjectMapper.Dynamic.FALSE); + TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( + INFERENCE_CHUNKS_TEXT, + indexVersionCreated, + indexAnalyzers + ).index(false).store(false); + if (modelSettings.get() != null) { + nestedBuilder.add(createInferenceMapperBuilder(INFERENCE_CHUNKS_RESULTS, modelSettings.get(), indexVersionCreated)); + } + nestedBuilder.add(textMapperBuilder); + var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); + return new SemanticTextFieldMapper( + name(), + new SemanticTextFieldType(fullName, modelId.getValue(), meta.getValue()), + copyTo, + indexVersionCreated, + indexAnalyzers, + modelSettings.getValue(), + nestedBuilder.build(childContext) + ); } } public static class SemanticTextFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { - private final String modelId; public SemanticTextFieldType(String name, String modelId, Map meta) { @@ -127,4 +252,54 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating"); } } + + private static Mapper.Builder createInferenceMapperBuilder( + String fieldName, + SemanticTextModelSettings modelSettings, + IndexVersion indexVersionCreated + ) { + return switch (modelSettings.taskType()) { + case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); + case TEXT_EMBEDDING -> { + DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( + INFERENCE_CHUNKS_RESULTS, + indexVersionCreated + ); + SimilarityMeasure similarity = modelSettings.similarity(); + if (similarity != null) { + switch (similarity) { + case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); + case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); + default -> throw new IllegalArgumentException( + "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity + ); + } + } + Integer dimensions = modelSettings.dimensions(); + denseVectorMapperBuilder.dimensions(dimensions); + yield denseVectorMapperBuilder; + } + default -> throw new IllegalArgumentException( + "Invalid [task_type] for [" + fieldName + "] in model settings: " + modelSettings.taskType().name() + ); + }; + } + + @Override + protected void checkIncomingMergeType(FieldMapper mergeWith) { + if (mergeWith instanceof SemanticTextFieldMapper other) { + if (other.modelSettings != null && other.modelSettings.inferenceId().equals(other.fieldType().getInferenceModel()) == false) { + throw new IllegalArgumentException( + "mapper [" + + name() + + "] refers to different model ids [" + + 
other.modelSettings.inferenceId() + + "] and [" + + other.fieldType().getInferenceModel() + + "]" + ); + } + } + super.checkIncomingMergeType(mergeWith); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index 1b6bb22c0d6b5..8b49e420f16a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -7,11 +7,14 @@ package org.elasticsearch.xpack.inference.mapper; +import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; @@ -22,7 +25,7 @@ /** * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. */ -public class SemanticTextModelSettings { +public class SemanticTextModelSettings implements ToXContentObject { public static final String NAME = "model_settings"; public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); @@ -98,4 +101,44 @@ public Integer dimensions() { public SimilarityMeasure similarity() { return similarity; } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); + builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); + if (dimensions != null) { + builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); + } + if (similarity != null) { + builder.field(SIMILARITY_FIELD.getPreferredName(), similarity); + } + return builder.endObject(); + } + + public static boolean checkCompatibility( + SemanticTextModelSettings original, + SemanticTextModelSettings another, + FieldMapper.Conflicts conflicts + ) { + if (original == null) { + return true; + } + if (original != null && another == null) { + conflicts.addConflict("model_settings", "missing"); + } + if (original.inferenceId.equals(another.inferenceId) == false) { + conflicts.addConflict(INFERENCE_ID_FIELD.getPreferredName(), "values differ"); + } + if (original.taskType != another.taskType()) { + conflicts.addConflict(TASK_TYPE_FIELD.getPreferredName(), "values differ"); + } + if (original.dimensions != another.dimensions) { + conflicts.addConflict(DIMENSIONS_FIELD.getPreferredName(), "values differ"); + } + if (original.similarity != another.similarity) { + conflicts.addConflict(SIMILARITY_FIELD.getPreferredName(), "values differ"); + } + return conflicts.hasConflicts() == false; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 7f3ffbe596543..a7af1443dc0ca 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -31,7 +31,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.After; @@ -51,8 +51,8 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomSparseEmbeddings; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomTextEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomSparseEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomTextEmbeddings; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.any; @@ -291,11 +291,11 @@ private static BulkItemRequest[] randomBulkItemRequest( throw new AssertionError("Unknown task type " + taskType.name()); } model.putResult(text, results); - InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); + InferenceMetadataFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); } } Map expectedDocMap = new LinkedHashMap<>(docMap); - expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap); + expectedDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceResultsMap); return new BulkItemRequest[] { new BulkItemRequest(id, new IndexRequest("index").source(docMap)), new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap)) }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java similarity index 66% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java index b5d75b528c6ab..b212ce6a269ef 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.mapper; +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; @@ -51,26 +53,28 @@ import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import 
java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.RESULTS; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; -public class InferenceResultFieldMapperTests extends MetadataMapperTestCase { - private record SemanticTextInferenceResults(String fieldName, ChunkedInferenceServiceResults results, List text) {} +public class InferenceMetadataFieldMapperTests extends MetadataMapperTestCase { + private record SemanticTextInferenceResults(String fieldName, Model model, ChunkedInferenceServiceResults results, List text) {} - private record VisitedChildDocInfo(String path, int numChunks) {} + private record VisitedChildDocInfo(String path) {} private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {} @Override protected String fieldName() { - return InferenceResultFieldMapper.NAME; + return InferenceMetadataFieldMapper.NAME; } @Override @@ -94,109 +98,127 @@ protected Collection getPlugins() { } public void testSuccessfulParse() throws IOException { - final String fieldName1 = randomAlphaOfLengthBetween(5, 15); - final String fieldName2 = randomAlphaOfLengthBetween(5, 15); - - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> { - addSemanticTextMapping(b, fieldName1, randomAlphaOfLength(8)); - addSemanticTextMapping(b, fieldName2, randomAlphaOfLength(8)); - })); - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of( - randomSemanticTextInferenceResults(fieldName1, List.of("a b", "c")), - randomSemanticTextInferenceResults(fieldName2, List.of("d e f")) + for (int depth = 1; depth < 4; depth++) { + final String fieldName1 = randomFieldName(depth); + final String fieldName2 = randomFieldName(depth + 1); + + Model model1 = randomModel(); + Model model2 = randomModel(); + XContentBuilder mapping = mapping(b -> { + addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId()); + addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId()); + }); + + MapperService mapperService = createMapperService(mapping); + DocumentMapper documentMapper = mapperService.documentMapper(); + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + List.of( + randomSemanticTextInferenceResults(fieldName1, model1, List.of("a b", "c")), + randomSemanticTextInferenceResults(fieldName2, model2, List.of("d e f")) + ) ) ) - ) - ); - - Set visitedChildDocs = new HashSet<>(); - Set expectedVisitedChildDocs = Set.of( - new VisitedChildDocInfo(fieldName1, 2), - new VisitedChildDocInfo(fieldName1, 1), - new VisitedChildDocInfo(fieldName2, 3) - ); - - List luceneDocs = doc.docs(); - assertEquals(4, luceneDocs.size()); - assertValidChildDoc(luceneDocs.get(0), doc.rootDoc(), visitedChildDocs); - assertValidChildDoc(luceneDocs.get(1), doc.rootDoc(), 
visitedChildDocs); - assertValidChildDoc(luceneDocs.get(2), doc.rootDoc(), visitedChildDocs); - assertEquals(doc.rootDoc(), luceneDocs.get(3)); - assertNull(luceneDocs.get(3).getParent()); - assertEquals(expectedVisitedChildDocs, visitedChildDocs); - - MapperService nestedMapperService = createMapperService(mapping(b -> { - addInferenceResultsNestedMapping(b, fieldName1); - addInferenceResultsNestedMapping(b, fieldName2); - })); - withLuceneIndex(nestedMapperService, iw -> iw.addDocuments(doc.docs()), reader -> { - NestedDocuments nested = new NestedDocuments( - nestedMapperService.mappingLookup(), - QueryBitSetProducer::new, - IndexVersion.current() - ); - LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); - - Set visitedNestedIdentities = new HashSet<>(); - Set expectedVisitedNestedIdentities = Set.of( - new SearchHit.NestedIdentity(fieldName1, 0, null), - new SearchHit.NestedIdentity(fieldName1, 1, null), - new SearchHit.NestedIdentity(fieldName2, 0, null) ); - assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); - assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); - - assertNull(leaf.advance(3)); - assertEquals(3, leaf.doc()); - assertEquals(3, leaf.rootDoc()); - assertNull(leaf.nestedIdentity()); - - IndexSearcher searcher = newSearcher(reader); - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); + List luceneDocs = doc.docs(); + assertEquals(4, luceneDocs.size()); + for (int i = 0; i < 3; i++) { + assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")), - 10 + // nested docs are in reversed order + assertSparseFeatures(luceneDocs.get(0), fieldName1 + ".results.inference", 2); + assertSparseFeatures(luceneDocs.get(1), fieldName1 + ".results.inference", 1); + assertSparseFeatures(luceneDocs.get(2), fieldName2 + ".results.inference", 3); + assertEquals(doc.rootDoc(), luceneDocs.get(3)); + assertNull(luceneDocs.get(3).getParent()); + + withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> { + NestedDocuments nested = new NestedDocuments( + mapperService.mappingLookup(), + QueryBitSetProducer::new, + IndexVersion.current() ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")), - 10 + LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); + + Set visitedNestedIdentities = new HashSet<>(); + Set expectedVisitedNestedIdentities = Set.of( + new SearchHit.NestedIdentity(fieldName1 + "." + RESULTS, 0, null), + new SearchHit.NestedIdentity(fieldName1 + "." + RESULTS, 1, null), + new SearchHit.NestedIdentity(fieldName2 + "." 
+ RESULTS, 0, null) ); - assertEquals(0, topDocs.totalHits.value); - } - }); + + assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); + assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); + + assertNull(leaf.advance(3)); + assertEquals(3, leaf.doc()); + assertEquals(3, leaf.rootDoc()); + assertNull(leaf.nestedIdentity()); + + IndexSearcher searcher = newSearcher(reader); + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery( + mapperService.mappingLookup().nestedLookup(), + fieldName1 + "." + RESULTS, + List.of("a") + ), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery( + mapperService.mappingLookup().nestedLookup(), + fieldName1 + "." + RESULTS, + List.of("a", "b") + ), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery( + mapperService.mappingLookup().nestedLookup(), + fieldName2 + "." + RESULTS, + List.of("d") + ), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery( + mapperService.mappingLookup().nestedLookup(), + fieldName2 + "." + RESULTS, + List.of("z") + ), + 10 + ); + assertEquals(0, topDocs.totalHits.value); + } + }); + } } public void testMissingSubfields() throws IOException { final String fieldName = randomAlphaOfLengthBetween(5, 15); + final Model model = randomModel(); - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); + DocumentMapper documentMapper = createDocumentMapper( + mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId())) + ); { DocumentParsingException ex = expectThrows( @@ -206,7 +228,7 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), true, Map.of() @@ -224,7 +246,7 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), new SparseVectorSubfieldOptions(true, true, true), false, Map.of() @@ -242,7 +264,7 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), false, Map.of() @@ -259,15 +281,18 @@ public void testMissingSubfields() throws IOException { public void testExtraSubfields() throws IOException { final String fieldName = randomAlphaOfLengthBetween(5, 15); + final Model model = randomModel(); final List semanticTextInferenceResultsList = List.of( - randomSemanticTextInferenceResults(fieldName, List.of("a b")) + 
randomSemanticTextInferenceResults(fieldName, model, List.of("a b")) ); - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); + DocumentMapper documentMapper = createDocumentMapper( + mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId())) + ); Consumer checkParsedDocument = d -> { Set visitedChildDocs = new HashSet<>(); - Set expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName, 2)); + Set expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName + "." + RESULTS)); List luceneDocs = d.docs(); assertEquals(2, luceneDocs.size()); @@ -358,13 +383,18 @@ public void testMissingSemanticTextMapping() throws IOException { DocumentParsingException.class, DocumentParsingException.class, () -> documentMapper.parse( - source(b -> addSemanticTextInferenceResults(b, List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))))) + source( + b -> addSemanticTextInferenceResults( + b, + List.of(randomSemanticTextInferenceResults(fieldName, randomModel(), List.of("a b"))) + ) + ) ) ); assertThat( ex.getMessage(), containsString( - Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) + Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) ) ); } @@ -400,8 +430,12 @@ public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List return new ChunkedSparseEmbeddingResults(chunks); } - private static SemanticTextInferenceResults randomSemanticTextInferenceResults(String semanticTextFieldName, List chunks) { - return new SemanticTextInferenceResults(semanticTextFieldName, randomSparseEmbeddings(chunks), chunks); + private static SemanticTextInferenceResults randomSemanticTextInferenceResults( + String semanticTextFieldName, + Model model, + List chunks + ) { + return new SemanticTextInferenceResults(semanticTextFieldName, model, randomSparseEmbeddings(chunks), chunks); } private static void addSemanticTextInferenceResults( @@ -425,12 +459,12 @@ private static void addSemanticTextInferenceResults( boolean includeTextSubfield, Map extraSubfields ) throws IOException { - Map inferenceResultsMap = new HashMap<>(); + Map inferenceResultsMap = new LinkedHashMap<>(); for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { - InferenceResultFieldMapper.applyFieldInference( + InferenceMetadataFieldMapper.applyFieldInference( inferenceResultsMap, semanticTextInferenceResult.fieldName, - randomModel(), + semanticTextInferenceResult.model, semanticTextInferenceResult.results ); Map optionsMap = (Map) inferenceResultsMap.get(semanticTextInferenceResult.fieldName); @@ -445,7 +479,18 @@ private static void addSemanticTextInferenceResults( entry.putAll(extraSubfields); } } - sourceBuilder.field(InferenceResultFieldMapper.NAME, inferenceResultsMap); + sourceBuilder.field(InferenceMetadataFieldMapper.NAME, inferenceResultsMap); + } + + private String randomFieldName(int numLevel) { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < numLevel; i++) { + if (i > 0) { + builder.append('.'); + } + builder.append(randomAlphaOfLengthBetween(5, 15)); + } + return builder.toString(); } private static Model randomModel() { @@ -461,29 +506,6 @@ private static Model randomModel() { ); } - private static void addInferenceResultsNestedMapping(XContentBuilder mappingBuilder, String semanticTextFieldName) throws IOException 
{ - mappingBuilder.startObject(semanticTextFieldName); - { - mappingBuilder.field("type", "nested"); - mappingBuilder.startObject("properties"); - { - mappingBuilder.startObject(INFERENCE_CHUNKS_RESULTS); - { - mappingBuilder.field("type", "sparse_vector"); - } - mappingBuilder.endObject(); - mappingBuilder.startObject(INFERENCE_CHUNKS_TEXT); - { - mappingBuilder.field("type", "text"); - mappingBuilder.field("index", false); - } - mappingBuilder.endObject(); - } - mappingBuilder.endObject(); - } - mappingBuilder.endObject(); - } - private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String path, List tokens) { NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(path); assertNotNull(mapper); @@ -503,12 +525,10 @@ private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLook private static void assertValidChildDoc( LuceneDocument childDoc, LuceneDocument expectedParent, - Set visitedChildDocs + Collection visitedChildDocs ) { assertEquals(expectedParent, childDoc.getParent()); - visitedChildDocs.add( - new VisitedChildDocInfo(childDoc.getPath(), childDoc.getFields(childDoc.getPath() + "." + INFERENCE_CHUNKS_RESULTS).size()) - ); + visitedChildDocs.add(new VisitedChildDocInfo(childDoc.getPath())); } private static void assertChildLeafNestedDocument( @@ -524,4 +544,15 @@ private static void assertChildLeafNestedDocument( assertNotNull(leaf.nestedIdentity()); visitedNestedIdentities.add(leaf.nestedIdentity()); } + + private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) { + int count = 0; + for (IndexableField field : doc.getFields()) { + if (field instanceof FeatureField featureField) { + assertThat(featureField.name(), equalTo(fieldName)); + ++count; + } + } + assertThat(count, equalTo(expectedCount)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index a3a705c9cc902..274ef346e27e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperParsingException; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperTestCase; @@ -82,6 +83,11 @@ protected void minimalMapping(XContentBuilder b) throws IOException { b.field("type", "semantic_text").field("model_id", "test_model"); } + @Override + protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { + return "cannot have nested fields when index is in [index.mode=time_series]"; + } + @Override protected Object getSampleValueForDocument() { return "value"; diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index 2c69f49218091..6744b04014446 100644 --- 
a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -86,7 +86,7 @@ setup: _inference: dense_field: model_settings: - inference_id: sparse-inference-id + inference_id: dense-inference-id task_type: text_embedding dimensions: 5 similarity: cosine @@ -144,7 +144,7 @@ setup: _inference: dense_field: model_settings: - inference_id: sparse-inference-id + inference_id: dense-inference-id task_type: text_embedding results: - text: "inference test" From 1c18fbc8e0e5d0143f2f79f499141084a23ae4f5 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 20 Mar 2024 09:05:03 +0000 Subject: [PATCH 06/13] Refactor the semantic_text field to register its sub fields in the mapping instead of re-creating them each time. --- .../xcontent/support/XContentMapValues.java | 2 +- .../elasticsearch/index/mapper/Mapping.java | 2 +- .../mapper/InferenceMetadataFieldMapper.java | 128 ++++++---- .../mapper/SemanticTextFieldMapper.java | 109 +++++---- .../mapper/SemanticTextModelSettings.java | 73 ++++-- .../SemanticTextClusterMetadataTests.java | 4 +- .../InferenceMetadataFieldMapperTests.java | 6 +- .../mapper/SemanticTextFieldMapperTests.java | 228 ++++++++++++++---- .../inference/10_semantic_text_inference.yml | 14 +- .../20_semantic_text_field_mapper.yml | 4 +- 10 files changed, 399 insertions(+), 171 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java b/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java index 805931550ad62..f527b4cd8d684 100644 --- a/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java +++ b/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java @@ -555,7 +555,7 @@ public static Map nodeMapValue(Object node, String desc) { if (node instanceof Map) { return (Map) node; } else { - throw new ElasticsearchParseException(desc + " should be a hash but was of type: " + node.getClass()); + throw new ElasticsearchParseException(desc + " should be a map but was of type: " + node.getClass()); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java b/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java index 903e4e5da5b29..da184d6f7a45e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java @@ -76,7 +76,7 @@ public CompressedXContent toCompressedXContent() { /** * Returns the root object for the current mapping */ - RootObjectMapper getRoot() { + public RootObjectMapper getRoot() { return root; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java index 6b102e5218134..d03cbdeceaa56 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -12,6 +12,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.support.XContentMapValues;
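The refactor in this patch (06/13) changes when the sub fields come into existence: the user-supplied mapping declares only the endpoint via inference_id, and model_settings (with its required task_type and inference_id keys) is merged into the mapping when the first document carrying inference results is parsed. A rough sketch of the two mapping states, under the assumption of a sparse-embedding endpoint and made-up names (my_field, my-inference-id):

// Sketch of the mapping shapes implied by the patch; not code from the PR.
public class SemanticTextMappingStates {
    // What a user declares: just the field type and the inference endpoint id.
    static final String DECLARED = """
        {
          "properties": {
            "my_field": { "type": "semantic_text", "inference_id": "my-inference-id" }
          }
        }""";

    // What the mapping looks like after the first document reports its model
    // settings; task_type and inference_id are the two required keys.
    static final String AFTER_FIRST_INGEST = """
        {
          "properties": {
            "my_field": {
              "type": "semantic_text",
              "inference_id": "my-inference-id",
              "model_settings": { "task_type": "sparse_embedding", "inference_id": "my-inference-id" }
            }
          }
        }""";

    public static void main(String[] args) {
        System.out.println(DECLARED);
        System.out.println(AFTER_FIRST_INGEST);
    }
}

Once model_settings is present it is effectively write-once: any later change is rejected by the canMergeModelSettings merge validator introduced below.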
import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.FieldMapper; @@ -172,9 +173,10 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio } } - private SemanticTextFieldMapper updateSemanticTextFieldMapper( + private NestedObjectMapper updateSemanticTextFieldMapper( DocumentParserContext docContext, MapperBuilderContext mapperBuilderContext, + ObjectMapper parent, SemanticTextFieldMapper original, SemanticTextModelSettings modelSettings, XContentLocation xContentLocation @@ -182,61 +184,49 @@ private SemanticTextFieldMapper updateSemanticTextFieldMapper( if (modelSettings.inferenceId().equals(original.fieldType().getInferenceId()) == false) { throw new DocumentParsingException( xContentLocation, - "Model settings for field [" - + original.fieldType().name() - + "] is already set to [" - + original.fieldType().getInferenceId() - + "], got [" - + modelSettings.inferenceId() - + "]" + Strings.format( + "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", + SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(), + original.fieldType().getInferenceId(), + original.name(), + SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(), + modelSettings.inferenceId() + ) ); } if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING && modelSettings.dimensions() == null) { throw new DocumentParsingException( xContentLocation, - "Model settings for field [" + original.fieldType().name() + "] must contain dimensions" + "Model settings for field [" + original.name() + "] must contain dimensions" ); } - if (original.getModelSettings() == null) { + if (parent != docContext.root()) { + mapperBuilderContext = mapperBuilderContext.createChildContext(parent.name(), ObjectMapper.Dynamic.FALSE); + } SemanticTextFieldMapper newMapper = new SemanticTextFieldMapper.Builder( - original.name(), + original.simpleName(), docContext.indexSettings().getIndexVersionCreated(), docContext.indexAnalyzers() - ).setModelId(modelSettings.inferenceId()).setModelSettings(modelSettings).build(mapperBuilderContext); + ).setInferenceId(modelSettings.inferenceId()).setModelSettings(modelSettings).build(mapperBuilderContext); docContext.addDynamicMapper(newMapper); - return newMapper; + return newMapper.getSubMappers(); } else { - var conflicts = new Conflicts(original.name()); - SemanticTextModelSettings.checkCompatibility(original.getModelSettings(), modelSettings, conflicts); + SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(original.name()); + SemanticTextFieldMapper.canMergeModelSettings(original.getModelSettings(), modelSettings, conflicts); try { conflicts.check(); } catch (Exception exc) { - throw new DocumentParsingException(xContentLocation, "Failed to update field [" + original.name() + "]", exc); - } - } - return original; - } - - private record FieldMapperAndParent(ObjectMapper parent, Mapper mapper) {} - - private FieldMapperAndParent findFieldMapper(ObjectMapper mapper, String fullName) { - String[] pathElements = fullName.split("\\."); - for (int i = 0; i < pathElements.length - 1; i++) { - Mapper next = mapper.getMapper(pathElements[i]); - if (next == null || next instanceof ObjectMapper == false) { - return null; + throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc); } - } - return new FieldMapperAndParent(mapper, 
mapper.getMapper(pathElements[pathElements.length - 1])); + return original.getSubMappers(); } - @SuppressWarnings("unchecked") private void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { XContentParser parser = context.parser(); String fieldName = parser.currentName(); - var res = findFieldMapper(context.root(), fieldName); + var res = findMapper(context.mappingLookup().getMapping().getRoot(), fieldName); if (res == null || res.mapper == null || res.mapper instanceof SemanticTextFieldMapper == false) { throw new DocumentParsingException( parser.getTokenLocation(), @@ -245,20 +235,51 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex } parser.nextToken(); failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); - XContentLocation xContentLocation = parser.getTokenLocation(); + // record the location of the inference field in the original source + XContentLocation xContentLocation = parser.getTokenLocation(); + // parse eagerly to extract the model settings first Map map = parser.mapOrdered(); - Map modelSettingsMap = (Map) map.remove(SemanticTextModelSettings.NAME); - var modelSettings = SemanticTextModelSettings.parse( - XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, modelSettingsMap) - ); - var fieldMapper = updateSemanticTextFieldMapper( + Object modelSettingsObj = map.remove(SemanticTextModelSettings.NAME); + if (modelSettingsObj == null) { + throw new DocumentParsingException( + parser.getTokenLocation(), + Strings.format( + "Missing required [%s] for field [%s] of type [%s]", + SemanticTextModelSettings.NAME, + fieldName, + SemanticTextFieldMapper.CONTENT_TYPE + ) + ); + } + Map modelSettingsMap = XContentMapValues.nodeMapValue(modelSettingsObj, "model_settings"); + final SemanticTextModelSettings modelSettings; + try { + modelSettings = SemanticTextModelSettings.parse( + XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, modelSettingsMap) + ); + } catch (Exception exc) { + throw new DocumentParsingException( + xContentLocation, + Strings.format( + "Error parsing [%s] for field [%s] of type [%s]", + SemanticTextModelSettings.NAME, + fieldName, + SemanticTextFieldMapper.CONTENT_TYPE + ), + exc + ); + } + var nestedObjectMapper = updateSemanticTextFieldMapper( context, mapperBuilderContext, + res.parent, (SemanticTextFieldMapper) res.mapper, modelSettings, xContentLocation ); + + // we know the model settings, so we can (re) parse the results array now XContentParser subParser = new MapXContentParser( NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, @@ -266,7 +287,7 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex XContentType.JSON ); DocumentParserContext mapContext = context.switchParser(subParser); - parseFieldInferenceObject(xContentLocation, subParser, mapContext, fieldMapper.getNestedField()); + parseFieldInferenceObject(xContentLocation, subParser, mapContext, nestedObjectMapper); } private void parseFieldInferenceObject( @@ -312,9 +333,16 @@ private void parseResultsObject( visited.add(parser.currentName()); FieldMapper fieldMapper = (FieldMapper) nestedMapper.getMapper(parser.currentName()); if (fieldMapper == null) { - logger.debug("Skipping indexing of unrecognized field name [" + parser.currentName() + "]"); - advancePastCurrentFieldName(xContentLocation, parser); - continue; + if (REQUIRED_SUBFIELDS.contains(parser.currentName())) { + throw new 
DocumentParsingException( + xContentLocation, + "Missing sub-fields definition for [" + parser.currentName() + "]" + ); + } else { + logger.debug("Skipping indexing of unrecognized field name [" + parser.currentName() + "]"); + advancePastCurrentFieldName(xContentLocation, parser); + continue; + } } parser.nextToken(); fieldMapper.parse(context); @@ -382,4 +410,18 @@ public static void applyFieldInference( fieldMap.put(InferenceMetadataFieldMapper.RESULTS, chunks); inferenceMap.put(field, fieldMap); } + + record MapperAndParent(ObjectMapper parent, Mapper mapper) {} + + static MapperAndParent findMapper(ObjectMapper mapper, String fullPath) { + String[] pathElements = fullPath.split("\\."); + for (int i = 0; i < pathElements.length - 1; i++) { + Mapper next = mapper.getMapper(pathElements[i]); + if (next == null || next instanceof ObjectMapper == false) { + return null; + } + mapper = (ObjectMapper) next; + } + return new MapperAndParent(mapper, mapper.getMapper(pathElements[pathElements.length - 1])); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 8ebaf57cc7543..cacb2fc176f18 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -32,18 +32,14 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; -import org.elasticsearch.xcontent.DeprecationHandler; -import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xcontent.support.MapXContentParser; import java.io.IOException; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; @@ -51,7 +47,7 @@ /** * A {@link FieldMapper} for semantic text fields. - * These fields have a model id reference, that is used for performing inference at ingestion and query time. + * These fields have an inference id reference that is used for performing inference at ingestion and query time. * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will * be indexed using {@link InferenceMetadataFieldMapper}. 
*/ @@ -127,7 +123,7 @@ public SemanticTextModelSettings getModelSettings() { return modelSettings; } - public NestedObjectMapper getNestedField() { + public NestedObjectMapper getSubMappers() { return subMappers; } @@ -135,40 +131,27 @@ public static class Builder extends FieldMapper.Builder { private final IndexVersion indexVersionCreated; private final IndexAnalyzers indexAnalyzers; - private final Parameter modelId = Parameter.stringParam("model_id", false, m -> toType(m).fieldType().modelId, null) - .addValidator(v -> { - if (Strings.isEmpty(v)) { - throw new IllegalArgumentException("field [model_id] must be specified"); - } - }); + private final Parameter inferenceId = Parameter.stringParam( + "inference_id", + false, + m -> toType(m).fieldType().inferenceId, + null + ).addValidator(v -> { + if (Strings.isEmpty(v)) { + throw new IllegalArgumentException("field [inference_id] must be specified"); + } + }); @SuppressWarnings("unchecked") private final Parameter modelSettings = new Parameter<>( "model_settings", true, () -> null, - (name, context, node) -> { - if (node == null) { - return null; - } - try { - Map map = (Map) node; - XContentParser parser = new MapXContentParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.IGNORE_DEPRECATIONS, - map, - XContentType.JSON - ); - return SemanticTextModelSettings.parse(parser); - } catch (Exception exc) { - throw new IllegalArgumentException(exc); - } - }, - m -> ((SemanticTextFieldMapper) m).modelSettings, + (n, c, o) -> SemanticTextModelSettings.fromMap(o), + mapper -> ((SemanticTextFieldMapper) mapper).modelSettings, XContentBuilder::field, - Strings::toString - ).acceptsNull().setMergeValidator(SemanticTextModelSettings::checkCompatibility); - + (m) -> m == null ? "null" : Strings.toString(m) + ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); private final Parameter> meta = Parameter.metaParam(); public Builder(String name, IndexVersion indexVersionCreated, IndexAnalyzers indexAnalyzers) { @@ -177,8 +160,8 @@ public Builder(String name, IndexVersion indexVersionCreated, IndexAnalyzers ind this.indexAnalyzers = indexAnalyzers; } - public Builder setModelId(String id) { - this.modelId.setValue(id); + public Builder setInferenceId(String id) { + this.inferenceId.setValue(id); return this; } @@ -189,7 +172,7 @@ public Builder setModelSettings(SemanticTextModelSettings value) { @Override protected Parameter[] getParameters() { - return new Parameter[] { modelId, meta, modelSettings }; + return new Parameter[] { inferenceId, modelSettings, meta }; } @Override @@ -207,24 +190,35 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { } nestedBuilder.add(textMapperBuilder); var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); + var subMappers = nestedBuilder.build(childContext); return new SemanticTextFieldMapper( name(), - new SemanticTextFieldType(fullName, modelId.getValue(), meta.getValue()), + new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subMappers, meta.getValue()), copyTo, indexVersionCreated, indexAnalyzers, modelSettings.getValue(), - nestedBuilder.build(childContext) + subMappers ); } } public static class SemanticTextFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { - private final String modelId; - - public SemanticTextFieldType(String name, String modelId, Map meta) { + private final String inferenceId; + private final SemanticTextModelSettings modelSettings; + private final 
NestedObjectMapper subMappers; + + public SemanticTextFieldType( + String name, + String modelId, + SemanticTextModelSettings modelSettings, + NestedObjectMapper subMappers, + Map meta + ) { super(name, false, false, false, TextSearchInfo.NONE, meta); - this.modelId = modelId; + this.inferenceId = modelId; + this.modelSettings = modelSettings; + this.subMappers = subMappers; } @Override @@ -234,7 +228,15 @@ public String typeName() { @Override public String getInferenceId() { - return modelId; + return inferenceId; + } + + public SemanticTextModelSettings getModelSettings() { + return modelSettings; + } + + public NestedObjectMapper getSubMappers() { + return subMappers; } @Override @@ -302,4 +304,23 @@ protected void checkIncomingMergeType(FieldMapper mergeWith) { } super.checkIncomingMergeType(mergeWith); } + + static boolean canMergeModelSettings( + SemanticTextModelSettings previous, + SemanticTextModelSettings current, + FieldMapper.Conflicts conflicts + ) { + if (Objects.equals(previous, current)) { + return true; + } + if (previous == null) { + return true; + } + if (current == null) { + conflicts.addConflict("model_settings", ""); + return false; + } + conflicts.addConflict("model_settings", ""); + return false; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index 8b49e420f16a6..108dce33c7ffa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -7,15 +7,20 @@ package org.elasticsearch.xpack.inference.mapper; -import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; import java.io.IOException; import java.util.HashMap; @@ -73,6 +78,34 @@ public static SemanticTextModelSettings parse(XContentParser parser) throws IOEx PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); } + public static SemanticTextModelSettings fromMap(Object node) { + if (node == null) { + return null; + } + try { + Map map = XContentMapValues.nodeMapValue(node, NAME); + if (map.containsKey(INFERENCE_ID_FIELD.getPreferredName()) == false) { + throw new IllegalArgumentException( + "Failed to parse [" + NAME + "], required [" + INFERENCE_ID_FIELD.getPreferredName() + "] is missing" + ); + } + if (map.containsKey(TASK_TYPE_FIELD.getPreferredName()) == false) { + throw new IllegalArgumentException( + "Failed to parse [" + NAME + "], required [" + TASK_TYPE_FIELD.getPreferredName() + "] is missing" + ); + } + XContentParser parser = new MapXContentParser( + 
NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + return SemanticTextModelSettings.parse(parser); + } catch (Exception exc) { + throw new ElasticsearchException(exc); + } + } + public Map asMap() { Map attrsMap = new HashMap<>(); attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); @@ -116,29 +149,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder.endObject(); } - public static boolean checkCompatibility( - SemanticTextModelSettings original, - SemanticTextModelSettings another, - FieldMapper.Conflicts conflicts - ) { - if (original == null) { - return true; - } - if (original != null && another == null) { - conflicts.addConflict("model_settings", "missing"); - } - if (original.inferenceId.equals(another.inferenceId) == false) { - conflicts.addConflict(INFERENCE_ID_FIELD.getPreferredName(), "values differ"); - } - if (original.taskType != another.taskType()) { - conflicts.addConflict(TASK_TYPE_FIELD.getPreferredName(), "values differ"); - } - if (original.dimensions != another.dimensions) { - conflicts.addConflict(DIMENSIONS_FIELD.getPreferredName(), "values differ"); - } - if (original.similarity != another.similarity) { - conflicts.addConflict(SIMILARITY_FIELD.getPreferredName(), "values differ"); - } - return conflicts.hasConflicts() == false; + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SemanticTextModelSettings that = (SemanticTextModelSettings) o; + return taskType == that.taskType + && inferenceId.equals(that.inferenceId) + && Objects.equals(dimensions, that.dimensions) + && similarity == that.similarity; + } + + @Override + public int hashCode() { + return Objects.hash(taskType, inferenceId, dimensions, similarity); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index a7d3fcce26116..bf3cc6334433a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -31,7 +31,7 @@ protected Collection> getPlugins() { public void testCreateIndexWithSemanticTextField() { final IndexService indexService = createIndex( "test", - client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,model_id=test_model") + client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,inference_id=test_model") ); assertEquals( indexService.getMetadata().getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), @@ -46,7 +46,7 @@ public void testAddSemanticTextField() throws Exception { final ClusterService clusterService = getInstanceFromNode(ClusterService.class); final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest(""" - { "properties": { "field": { "type": "semantic_text", "model_id": "test_model" }}}"""); + { "properties": { "field": { "type": "semantic_text", "inference_id": "test_model" }}}"""); request.indices(new Index[] { indexService.index() }); final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful( clusterService.state(), diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java index b212ce6a269ef..aee2db47e18a3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java @@ -110,6 +110,8 @@ public void testSuccessfulParse() throws IOException { }); MapperService mapperService = createMapperService(mapping); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName1, false); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName2, false); DocumentMapper documentMapper = mapperService.documentMapper(); ParsedDocument doc = documentMapper.parse( source( @@ -402,7 +404,7 @@ public void testMissingSemanticTextMapping() throws IOException { private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { mappingBuilder.startObject(fieldName); mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); - mappingBuilder.field("model_id", modelId); + mappingBuilder.field("inference_id", modelId); mappingBuilder.endObject(); } @@ -482,7 +484,7 @@ private static void addSemanticTextInferenceResults( sourceBuilder.field(InferenceMetadataFieldMapper.NAME, inferenceResultsMap); } - private String randomFieldName(int numLevel) { + static String randomFieldName(int numLevel) { StringBuilder builder = new StringBuilder(); for (int i = 0; i < numLevel; i++) { if (i > 0) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 274ef346e27e4..adb1d93f2bffb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -16,7 +16,11 @@ import org.elasticsearch.index.mapper.MapperParsingException; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperTestCase; +import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.mapper.TextFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; @@ -27,52 +31,12 @@ import java.util.List; import static java.util.Collections.singletonList; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.findMapper; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; public class SemanticTextFieldMapperTests extends MapperTestCase { - - public void testDefaults() throws Exception { - DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); - assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), 
mapper.mappingSource().toString()); - - ParsedDocument doc1 = mapper.parse(source(this::writeField)); - List fields = doc1.rootDoc().getFields("field"); - - // No indexable fields - assertTrue(fields.isEmpty()); - } - - public void testModelIdNotPresent() throws IOException { - Exception e = expectThrows( - MapperParsingException.class, - () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) - ); - assertThat(e.getMessage(), containsString("field [model_id] must be specified")); - } - - public void testCannotBeUsedInMultiFields() { - Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { - b.field("type", "text"); - b.startObject("fields"); - b.startObject("semantic"); - b.field("type", "semantic_text"); - b.endObject(); - b.endObject(); - }))); - assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); - } - - public void testUpdatesToModelIdNotSupported() throws IOException { - MapperService mapperService = createMapperService( - fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "test_model")) - ); - Exception e = expectThrows( - IllegalArgumentException.class, - () -> merge(mapperService, fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "another_model"))) - ); - assertThat(e.getMessage(), containsString("Cannot update parameter [model_id] from [test_model] to [another_model]")); - } - @Override protected Collection getPlugins() { return singletonList(new InferencePlugin(Settings.EMPTY)); @@ -80,7 +44,7 @@ protected Collection getPlugins() { @Override protected void minimalMapping(XContentBuilder b) throws IOException { - b.field("type", "semantic_text").field("model_id", "test_model"); + b.field("type", "semantic_text").field("inference_id", "test_model"); } @Override @@ -121,4 +85,180 @@ protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) protected IngestScriptSupport ingestScriptSupport() { throw new AssumptionViolatedException("not supported"); } + + public void testDefaults() throws Exception { + DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); + assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString()); + + ParsedDocument doc1 = mapper.parse(source(this::writeField)); + List fields = doc1.rootDoc().getFields("field"); + + // No indexable fields + assertTrue(fields.isEmpty()); + } + + public void testInferenceIdNotPresent() throws IOException { + Exception e = expectThrows( + MapperParsingException.class, + () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) + ); + assertThat(e.getMessage(), containsString("field [inference_id] must be specified")); + } + + public void testCannotBeUsedInMultiFields() { + Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { + b.field("type", "text"); + b.startObject("fields"); + b.startObject("semantic"); + b.field("type", "semantic_text"); + b.endObject(); + b.endObject(); + }))); + assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); + } + + public void testUpdatesToInferenceIdNotSupported() throws IOException { + String fieldName = randomAlphaOfLengthBetween(5, 15); + MapperService mapperService = createMapperService( + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) 
+ ); + assertSemanticTextField(mapperService, fieldName, false); + Exception e = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "another_model").endObject()) + ) + ); + assertThat(e.getMessage(), containsString("Cannot update parameter [inference_id] from [test_model] to [another_model]")); + } + + public void testUpdateModelSettings() throws IOException { + for (int depth = 1; depth < 5; depth++) { + String fieldName = InferenceMetadataFieldMapperTests.randomFieldName(depth); + MapperService mapperService = createMapperService( + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) + ); + assertSemanticTextField(mapperService, fieldName, false); + { + Exception exc = expectThrows( + MapperParsingException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("inference_id", "test_model") + .endObject() + .endObject() + ) + ) + ); + assertThat(exc.getMessage(), containsString("Failed to parse [model_settings], required [task_type] is missing")); + } + { + merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("inference_id", "test_model") + .field("task_type", "sparse_embedding") + .endObject() + .endObject() + ) + ); + assertSemanticTextField(mapperService, fieldName, true); + } + { + Exception exc = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject() + ) + ) + ); + assertThat( + exc.getMessage(), + containsString( + "Cannot update parameter [model_settings] " + + "from [{\"task_type\":\"sparse_embedding\",\"inference_id\":\"test_model\"}] to [null]" + ) + ); + } + { + Exception exc = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("inference_id", "test_model") + .field("task_type", "text_embedding") + .field("dimensions", 10) + .endObject() + .endObject() + ) + ) + ); + assertThat( + exc.getMessage(), + containsString( + "Cannot update parameter [model_settings] " + + "from [{\"task_type\":\"sparse_embedding\",\"inference_id\":\"test_model\"}] " + + "to [{\"task_type\":\"text_embedding\",\"inference_id\":\"test_model\",\"dimensions\":10}]" + ) + ); + } + } + } + + static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { + var res = findMapper(mapperService.mappingLookup().getMapping().getRoot(), fieldName); + Mapper mapper = res.mapper(); + assertNotNull(mapper); + assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); + SemanticTextFieldMapper semanticFieldMapper = (SemanticTextFieldMapper) mapper; + + var fieldType = mapperService.fieldType(fieldName); + assertNotNull(fieldType); + assertThat(fieldType, instanceOf(SemanticTextFieldMapper.SemanticTextFieldType.class)); + SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldType; + 
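+        // sanity note (comment added for clarity): the field mapper and its field type are expected to expose the same model settings and sub-mappers instances, which the identity checks below verify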
assertTrue(semanticFieldMapper.fieldType() == semanticTextFieldType); + assertTrue(semanticFieldMapper.getSubMappers() == semanticTextFieldType.getSubMappers()); + assertTrue(semanticFieldMapper.getModelSettings() == semanticTextFieldType.getModelSettings()); + + NestedObjectMapper nestedObjectMapper = mapperService.mappingLookup() + .nestedLookup() + .getNestedMappers() + .get(fieldName + "." + InferenceMetadataFieldMapper.RESULTS); + assertThat(nestedObjectMapper, equalTo(semanticFieldMapper.getSubMappers())); + Mapper textMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT); + assertNotNull(textMapper); + assertThat(textMapper, instanceOf(TextFieldMapper.class)); + TextFieldMapper textFieldMapper = (TextFieldMapper) textMapper; + assertFalse(textFieldMapper.fieldType().isIndexed()); + assertFalse(textFieldMapper.fieldType().hasDocValues()); + if (expectedModelSettings) { + assertNotNull(semanticFieldMapper.getModelSettings()); + Mapper inferenceMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS); + assertNotNull(inferenceMapper); + switch (semanticFieldMapper.getModelSettings().taskType()) { + case SPARSE_EMBEDDING -> assertThat(inferenceMapper, instanceOf(SparseVectorFieldMapper.class)); + case TEXT_EMBEDDING -> assertThat(inferenceMapper, instanceOf(DenseVectorFieldMapper.class)); + default -> throw new AssertionError("Invalid task type"); + } + } else { + assertNull(semanticFieldMapper.getModelSettings()); + } + } } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 6008ebbcbedf8..eb330c60fd0d2 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -41,10 +41,10 @@ setup: properties: inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id another_inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id non_inference_field: type: text @@ -56,10 +56,10 @@ setup: properties: inference_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id another_inference_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id non_inference_field: type: text @@ -247,10 +247,10 @@ setup: properties: inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id another_inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id non_inference_field: type: text @@ -287,7 +287,7 @@ setup: properties: inference_field: type: semantic_text - model_id: non-existing-inference-id + inference_id: non-existing-inference-id non_inference_field: type: text diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index 6744b04014446..46e3f3cfa5be3 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ 
b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -41,10 +41,10 @@ setup: properties: sparse_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id dense_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id non_inference_field: type: text From 7b578d188ad0e6ad40d3e324d88fa1574f03ebc8 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 20 Mar 2024 15:36:20 +0000 Subject: [PATCH 07/13] add task_type validation --- .../mapper/SemanticTextModelSettings.java | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index 108dce33c7ffa..f4a170acb0649 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -27,6 +27,9 @@ import java.util.Map; import java.util.Objects; +import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; + /** * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. */ @@ -58,6 +61,7 @@ public SemanticTextModelSettings(Model model) { model.getServiceSettings().dimensions(), model.getServiceSettings().similarity() ); + validate(); } public static SemanticTextModelSettings parse(XContentParser parser) throws IOException { @@ -149,6 +153,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder.endObject(); } + public void validate() { + switch (taskType) { + case TEXT_EMBEDDING: + case SPARSE_EMBEDDING: + break; + + default: + throw new IllegalArgumentException("Wrong [" + TASK_TYPE_FIELD.getPreferredName() + "], expected " + + TEXT_EMBEDDING + "or " + SPARSE_EMBEDDING + ", got " + taskType.name()); + } + } + @Override public boolean equals(Object o) { if (this == o) return true; From 2be50d7338083a51f11c57e166d2412d56daeadb Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 20 Mar 2024 22:46:16 +0000 Subject: [PATCH 08/13] address review comments --- .../TestDenseInferenceServiceExtension.java | 2 +- .../ShardBulkInferenceActionFilter.java | 8 ++ .../mapper/InferenceMetadataFieldMapper.java | 40 +++++--- .../mapper/SemanticTextFieldMapper.java | 21 +---- .../mapper/SemanticTextModelSettings.java | 78 ++++++++-------- .../mapper/SemanticTextFieldMapperTests.java | 12 +-- .../xpack/inference/model/TestModel.java | 11 +++ .../inference/10_semantic_text_inference.yml | 1 + .../20_semantic_text_field_mapper.yml | 93 +------------------ 9 files changed, 93 insertions(+), 173 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 54fe6e01946b4..586850eb948d3 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -166,7 +166,7 @@ public static TestServiceSettings fromMap(Map map) { SimilarityMeasure similarity = null; String similarityStr = (String) map.remove("similarity"); if (similarityStr != null) { - similarity = SimilarityMeasure.valueOf(similarityStr); + similarity = SimilarityMeasure.fromString(similarityStr); } return new TestServiceSettings(model, dimensions, similarity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index f49b2a3856e82..00dc195313a61 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -295,6 +295,7 @@ private Map> createFieldInferenceRequests(Bu continue; } final Map docMap = indexRequest.sourceAsMap(); + boolean hasInput = false; for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { String field = entry.getKey(); String inferenceId = entry.getValue().inferenceId(); @@ -315,6 +316,7 @@ private Map> createFieldInferenceRequests(Bu if (value instanceof String valueStr) { List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + hasInput = true; } else { inferenceResults.get(item.id()).failures.add( new ElasticsearchStatusException( @@ -326,6 +328,12 @@ private Map> createFieldInferenceRequests(Bu ); } } + if (hasInput == false) { + // remove the existing _inference field (if present) since none of the content requires inference.
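+                // (otherwise a stale _inference entry from a previous version of the document could survive the update)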
+ if (docMap.remove(InferenceMetadataFieldMapper.NAME) != null) { + indexRequest.source(docMap); + } + } } return fieldRequestsMap; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java index d03cbdeceaa56..e962fc0ec9270 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -11,7 +11,6 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.DocumentParsingException; @@ -37,7 +36,6 @@ import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentLocation; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.support.MapXContentParser; import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; @@ -67,8 +65,8 @@ * "my_semantic_text_field": "these are not the droids you're looking for", * "_inference": { * "my_semantic_text_field": { + * "inference_id": "my_inference_id", * "model_settings": { - * "inference_id": "my_inference_id", * "task_type": "SPARSE_EMBEDDING" * }, * "results" [ @@ -118,6 +116,7 @@ public class InferenceMetadataFieldMapper extends MetadataFieldMapper { public static final String NAME = "_inference"; public static final String CONTENT_TYPE = "_inference"; + private static final String INFERENCE_ID = "inference_id"; public static final String RESULTS = "results"; public static final String INFERENCE_CHUNKS_RESULTS = "inference"; public static final String INFERENCE_CHUNKS_TEXT = "text"; @@ -178,19 +177,20 @@ private NestedObjectMapper updateSemanticTextFieldMapper( MapperBuilderContext mapperBuilderContext, ObjectMapper parent, SemanticTextFieldMapper original, + String inferenceId, SemanticTextModelSettings modelSettings, XContentLocation xContentLocation ) { - if (modelSettings.inferenceId().equals(original.fieldType().getInferenceId()) == false) { + if (inferenceId.equals(original.fieldType().getInferenceId()) == false) { throw new DocumentParsingException( xContentLocation, Strings.format( "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", - SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(), - modelSettings.inferenceId(), + INFERENCE_ID, + inferenceId, original.name(), - SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(), - modelSettings.inferenceId() + INFERENCE_ID, + original.fieldType().getInferenceId() ) ); } @@ -208,7 +208,7 @@ private NestedObjectMapper updateSemanticTextFieldMapper( original.simpleName(), docContext.indexSettings().getIndexVersionCreated(), docContext.indexAnalyzers() - ).setInferenceId(modelSettings.inferenceId()).setModelSettings(modelSettings).build(mapperBuilderContext); + ).setInferenceId(original.fieldType().getInferenceId()).setModelSettings(modelSettings).build(mapperBuilderContext); 
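+            // register the rebuilt mapper as a dynamic update so the mapping change is applied when the document is indexed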
docContext.addDynamicMapper(newMapper); return newMapper.getSubMappers(); } else { @@ -238,8 +238,18 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex // record the location of the inference field in the original source XContentLocation xContentLocation = parser.getTokenLocation(); - // parse eagerly to extract the model settings first + // parse eagerly to extract the inference id and the model settings first Map map = parser.mapOrdered(); + logger.info("map=" + map.toString()); + + // inference_id + Object inferenceIdObj = map.remove(INFERENCE_ID); + final String inferenceId = XContentMapValues.nodeStringValue(inferenceIdObj, null); + if (inferenceId == null) { + throw new IllegalArgumentException("required [" + INFERENCE_ID + "] is missing"); + } + + // model_settings Object modelSettingsObj = map.remove(SemanticTextModelSettings.NAME); if (modelSettingsObj == null) { throw new DocumentParsingException( @@ -252,12 +262,9 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex ) ); } - Map modelSettingsMap = XContentMapValues.nodeMapValue(modelSettingsObj, "model_settings"); final SemanticTextModelSettings modelSettings; try { - modelSettings = SemanticTextModelSettings.parse( - XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, modelSettingsMap) - ); + modelSettings = SemanticTextModelSettings.fromMap(modelSettingsObj); } catch (Exception exc) { throw new DocumentParsingException( xContentLocation, @@ -270,11 +277,13 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex exc ); } + var nestedObjectMapper = updateSemanticTextFieldMapper( context, mapperBuilderContext, res.parent, (SemanticTextFieldMapper) res.mapper, + inferenceId, modelSettings, xContentLocation ); @@ -406,8 +415,9 @@ public static void applyFieldInference( ); } Map fieldMap = new LinkedHashMap<>(); + fieldMap.put(INFERENCE_ID, model.getInferenceEntityId()); fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - fieldMap.put(InferenceMetadataFieldMapper.RESULTS, chunks); + fieldMap.put(RESULTS, chunks); inferenceMap.put(field, fieldMap); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index cacb2fc176f18..564cca4821cb1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -277,8 +277,7 @@ private static Mapper.Builder createInferenceMapperBuilder( ); } } - Integer dimensions = modelSettings.dimensions(); - denseVectorMapperBuilder.dimensions(dimensions); + denseVectorMapperBuilder.dimensions(modelSettings.dimensions()); yield denseVectorMapperBuilder; } default -> throw new IllegalArgumentException( @@ -287,24 +286,6 @@ private static Mapper.Builder createInferenceMapperBuilder( }; } - @Override - protected void checkIncomingMergeType(FieldMapper mergeWith) { - if (mergeWith instanceof SemanticTextFieldMapper other) { - if (other.modelSettings != null && other.modelSettings.inferenceId().equals(other.fieldType().getInferenceId()) == false) { - throw new IllegalArgumentException( - "mapper [" - + name() - + "] refers to different model ids [" - + other.modelSettings.inferenceId() - + "] and [" - + other.fieldType().getInferenceId() - + "]" - 
); - } - } - super.checkIncomingMergeType(mergeWith); - } - static boolean canMergeModelSettings( SemanticTextModelSettings previous, SemanticTextModelSettings current, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index f4a170acb0649..54e86c108fa43 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -37,30 +37,21 @@ public class SemanticTextModelSettings implements ToXContentObject { public static final String NAME = "model_settings"; public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); - public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); private final TaskType taskType; - private final String inferenceId; private final Integer dimensions; private final SimilarityMeasure similarity; - public SemanticTextModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { + public SemanticTextModelSettings(Model model) { + this(model.getTaskType(), model.getServiceSettings().dimensions(), model.getServiceSettings().similarity()); + } + + public SemanticTextModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) { Objects.requireNonNull(taskType, "task type must not be null"); - Objects.requireNonNull(inferenceId, "inferenceId must not be null"); this.taskType = taskType; - this.inferenceId = inferenceId; this.dimensions = dimensions; this.similarity = similarity; - } - - public SemanticTextModelSettings(Model model) { - this( - model.getTaskType(), - model.getInferenceEntityId(), - model.getServiceSettings().dimensions(), - model.getServiceSettings().similarity() - ); validate(); } @@ -68,16 +59,18 @@ public static SemanticTextModelSettings parse(XContentParser parser) throws IOEx return PARSER.apply(parser, null); } - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { - TaskType taskType = TaskType.fromString((String) args[0]); - String inferenceId = (String) args[1]; - Integer dimensions = (Integer) args[2]; - SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[3]); - return new SemanticTextModelSettings(taskType, inferenceId, dimensions, similarity); - }); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + args -> { + TaskType taskType = TaskType.fromString((String) args[0]); + Integer dimensions = (Integer) args[1]; + SimilarityMeasure similarity = args[2] == null ? 
null : SimilarityMeasure.fromString((String) args[2]); + return new SemanticTextModelSettings(taskType, dimensions, similarity); + } + ); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); - PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); } @@ -88,11 +81,6 @@ public static SemanticTextModelSettings fromMap(Object node) { } try { Map map = XContentMapValues.nodeMapValue(node, NAME); - if (map.containsKey(INFERENCE_ID_FIELD.getPreferredName()) == false) { - throw new IllegalArgumentException( - "Failed to parse [" + NAME + "], required [" + INFERENCE_ID_FIELD.getPreferredName() + "] is missing" - ); - } if (map.containsKey(TASK_TYPE_FIELD.getPreferredName()) == false) { throw new IllegalArgumentException( "Failed to parse [" + NAME + "], required [" + TASK_TYPE_FIELD.getPreferredName() + "] is missing" @@ -113,7 +101,6 @@ public static SemanticTextModelSettings fromMap(Object node) { public Map asMap() { Map attrsMap = new HashMap<>(); attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); - attrsMap.put(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); if (dimensions != null) { attrsMap.put(DIMENSIONS_FIELD.getPreferredName(), dimensions); } @@ -127,10 +114,6 @@ public TaskType taskType() { return taskType; } - public String inferenceId() { - return inferenceId; - } - public Integer dimensions() { return dimensions; } @@ -143,7 +126,6 @@ public SimilarityMeasure similarity() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); - builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); if (dimensions != null) { builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); } @@ -156,12 +138,31 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public void validate() { switch (taskType) { case TEXT_EMBEDDING: + if (dimensions == null) { + throw new IllegalArgumentException( + "required [" + DIMENSIONS_FIELD.getPreferredName() + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + if (similarity == null) { + throw new IllegalArgumentException( + "required [" + SIMILARITY_FIELD.getPreferredName() + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + break; case SPARSE_EMBEDDING: break; default: - throw new IllegalArgumentException("Wrong [" + TASK_TYPE_FIELD.getPreferredName() + "], expected " + - TEXT_EMBEDDING + "or " + SPARSE_EMBEDDING + ", got " + taskType.name()); + throw new IllegalArgumentException( + "Wrong [" + + TASK_TYPE_FIELD.getPreferredName() + + "], expected " + + TEXT_EMBEDDING + + " or " + + SPARSE_EMBEDDING + + ", got " + + taskType.name() + ); } } @@ -170,14 +171,11 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; SemanticTextModelSettings that = (SemanticTextModelSettings) o; - return taskType == that.taskType - && inferenceId.equals(that.inferenceId) - && Objects.equals(dimensions, that.dimensions) - && similarity == that.similarity; + return taskType == that.taskType && Objects.equals(dimensions, that.dimensions) && similarity == that.similarity; } @Override public int hashCode() { - return Objects.hash(taskType, inferenceId, dimensions, similarity);
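+        // inference_id is intentionally left out: it is now reported at the top level of the _inference metadata instead of the model settings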
+ return Objects.hash(taskType, dimensions, similarity); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index adb1d93f2bffb..15c50e530c1c0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -166,7 +166,6 @@ public void testUpdateModelSettings() throws IOException { .field("type", "semantic_text") .field("inference_id", "test_model") .startObject("model_settings") - .field("inference_id", "test_model") .field("task_type", "sparse_embedding") .endObject() .endObject() @@ -186,10 +185,7 @@ public void testUpdateModelSettings() throws IOException { ); assertThat( exc.getMessage(), - containsString( - "Cannot update parameter [model_settings] " - + "from [{\"task_type\":\"sparse_embedding\",\"inference_id\":\"test_model\"}] to [null]" - ) + containsString("Cannot update parameter [model_settings] " + "from [{\"task_type\":\"sparse_embedding\"}] to [null]") ); } { @@ -202,9 +198,9 @@ public void testUpdateModelSettings() throws IOException { .field("type", "semantic_text") .field("inference_id", "test_model") .startObject("model_settings") - .field("inference_id", "test_model") .field("task_type", "text_embedding") .field("dimensions", 10) + .field("similarity", "cosine") .endObject() .endObject() ) @@ -214,8 +210,8 @@ public void testUpdateModelSettings() throws IOException { exc.getMessage(), containsString( "Cannot update parameter [model_settings] " - + "from [{\"task_type\":\"sparse_embedding\",\"inference_id\":\"test_model\"}] " - + "to [{\"task_type\":\"text_embedding\",\"inference_id\":\"test_model\",\"dimensions\":10}]" + + "from [{\"task_type\":\"sparse_embedding\"}] " + + "to [{\"task_type\":\"text_embedding\",\"dimensions\":10,\"similarity\":\"cosine\"}]" ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index 75e7ca12c1d56..b64485a3d3fb2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContentObject; @@ -121,6 +122,16 @@ public void writeTo(StreamOutput out) throws IOException { public ToXContentObject getFilteredXContentObject() { return this; } + + @Override + public SimilarityMeasure similarity() { + return SimilarityMeasure.COSINE; + } + + @Override + public Integer dimensions() { + return 100; + } } public record TestTaskSettings(Integer temperature) implements TaskSettings { diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 
eb330c60fd0d2..4488af74c0616 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -27,6 +27,7 @@ setup: "service_settings": { "model": "my_model", "dimensions": 10, + "similarity": "cosine", "api_key": "abc64" }, "task_settings": { diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index 46e3f3cfa5be3..27f233436b925 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -27,7 +27,8 @@ setup: "service_settings": { "model": "my_model", "dimensions": 10, - "api_key": "abc64" + "api_key": "abc64", + "similarity": "cosine" }, "task_settings": { } @@ -55,25 +56,7 @@ setup: index: test-index id: doc_1 body: - non_inference_field: "you know, for testing" - _inference: - sparse_field: - model_settings: - inference_id: sparse-inference-id - task_type: sparse_embedding - results: - - text: "inference test" - inference: - feature_1: 0.1 - feature_2: 0.2 - feature_3: 0.3 - feature_4: 0.4 - - text: "another inference test" - inference: - feature_1: 0.1 - feature_2: 0.2 - feature_3: 0.3 - feature_4: 0.4 + sparse_field: "you know, for testing" --- "Dense vector results format": @@ -82,72 +65,4 @@ setup: index: test-index id: doc_1 body: - non_inference_field: "you know, for testing" - _inference: - dense_field: - model_settings: - inference_id: dense-inference-id - task_type: text_embedding - dimensions: 5 - similarity: cosine - results: - - text: "inference test" - inference: [0.1, 0.2, 0.3, 0.4, 0.5] - - text: "another inference test" - inference: [-0.1, -0.2, -0.3, -0.4, -0.5] - ---- -"Model settings inference id not included": - - do: - catch: /Required \[inference_id\]/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - sparse_field: - model_settings: - task_type: sparse_embedding - results: - - text: "inference test" - inference: - feature_1: 0.1 - ---- -"Model settings task type not included": - - do: - catch: /Required \[task_type\]/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - sparse_field: - model_settings: - inference_id: sparse-inference-id - results: - - text: "inference test" - inference: - feature_1: 0.1 - ---- -"Model settings dense vector dimensions not included": - - do: - catch: /Model settings for field \[dense_field\] must contain dimensions/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - dense_field: - model_settings: - inference_id: dense-inference-id - task_type: text_embedding - results: - - text: "inference test" - inference: [0.1, 0.2, 0.3, 0.4, 0.5] - - text: "another inference test" - inference: [-0.1, -0.2, -0.3, -0.4, -0.5] + dense_field: "you know, for testing" From eb4731fa33ce7f0c39e35240ed45be1af1fbd3b3 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 20 Mar 2024 22:51:37 +0000 Subject: [PATCH 09/13] remove unused --- .../xpack/inference/mapper/SemanticTextFieldMapper.java | 5 ----- 1 file changed, 5 
deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 564cca4821cb1..5fb536022bc76 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -86,11 +86,6 @@ private SemanticTextFieldMapper( this.subMappers = subMappers; } - @Override - public String name() { - return super.name(); - } - @Override public Iterator iterator() { List subIterators = new ArrayList<>(); From b3fb5d3017454e668c6e2249c425caa3bd5a0ec4 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 21 Mar 2024 10:09:36 +0000 Subject: [PATCH 10/13] address review comments --- .../mapper/InferenceMetadataFieldMapper.java | 8 +- .../mapper/SemanticTextFieldMapper.java | 23 ++--- .../ShardBulkInferenceActionFilterTests.java | 2 +- .../InferenceMetadataFieldMapperTests.java | 89 ++++++++++++++++--- .../mapper/SemanticTextFieldMapperTests.java | 6 +- 5 files changed, 93 insertions(+), 35 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java index e962fc0ec9270..c11e8f8b82bf0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -116,7 +116,7 @@ public class InferenceMetadataFieldMapper extends MetadataFieldMapper { public static final String NAME = "_inference"; public static final String CONTENT_TYPE = "_inference"; - private static final String INFERENCE_ID = "inference_id"; + public static final String INFERENCE_ID = "inference_id"; public static final String RESULTS = "results"; public static final String INFERENCE_CHUNKS_RESULTS = "inference"; public static final String INFERENCE_CHUNKS_TEXT = "text"; @@ -206,8 +206,7 @@ private NestedObjectMapper updateSemanticTextFieldMapper( } SemanticTextFieldMapper newMapper = new SemanticTextFieldMapper.Builder( original.simpleName(), - docContext.indexSettings().getIndexVersionCreated(), - docContext.indexAnalyzers() + docContext.indexSettings().getIndexVersionCreated() ).setInferenceId(original.fieldType().getInferenceId()).setModelSettings(modelSettings).build(mapperBuilderContext); docContext.addDynamicMapper(newMapper); return newMapper.getSubMappers(); @@ -227,7 +226,7 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex XContentParser parser = context.parser(); String fieldName = parser.currentName(); var res = findMapper(context.mappingLookup().getMapping().getRoot(), fieldName); - if (res == null || res.mapper == null || res.mapper instanceof SemanticTextFieldMapper == false) { + if (res == null || res.mapper instanceof SemanticTextFieldMapper == false) { throw new DocumentParsingException( parser.getTokenLocation(), Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) @@ -240,7 +239,6 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex XContentLocation xContentLocation = parser.getTokenLocation(); 
// parse eagerly to extract the inference id and the model settings first Map map = parser.mapOrdered(); - logger.info("map=" + map.toString()); // inference_id Object inferenceIdObj = map.remove(INFERENCE_ID); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 5fb536022bc76..7925bfac39476 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -10,12 +10,12 @@ import org.apache.lucene.search.Query; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.analysis.IndexAnalyzers; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.index.mapper.InferenceModelFieldType; +import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperBuilderContext; @@ -23,7 +23,6 @@ import org.elasticsearch.index.mapper.ObjectMapper; import org.elasticsearch.index.mapper.SimpleMappedFieldType; import org.elasticsearch.index.mapper.SourceValueFetcher; -import org.elasticsearch.index.mapper.TextFieldMapper; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -61,13 +60,12 @@ private static SemanticTextFieldMapper toType(FieldMapper in) { } public static final TypeParser PARSER = new TypeParser( - (n, c) -> new Builder(n, c.indexVersionCreated(), c.getIndexAnalyzers()), + (n, c) -> new Builder(n, c.indexVersionCreated()), notInMultiFields(CONTENT_TYPE) ); private final IndexVersion indexVersionCreated; private final SemanticTextModelSettings modelSettings; - private final IndexAnalyzers indexAnalyzers; private final NestedObjectMapper subMappers; private SemanticTextFieldMapper( @@ -75,13 +73,11 @@ private SemanticTextFieldMapper( MappedFieldType mappedFieldType, CopyTo copyTo, IndexVersion indexVersionCreated, - IndexAnalyzers indexAnalyzers, SemanticTextModelSettings modelSettings, NestedObjectMapper subMappers ) { super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); this.indexVersionCreated = indexVersionCreated; - this.indexAnalyzers = indexAnalyzers; this.modelSettings = modelSettings; this.subMappers = subMappers; } @@ -95,7 +91,7 @@ public Iterator iterator() { @Override public FieldMapper.Builder getMergeBuilder() { - return new Builder(simpleName(), indexVersionCreated, indexAnalyzers).init(this); + return new Builder(simpleName(), indexVersionCreated).init(this); } @Override @@ -124,7 +120,6 @@ public NestedObjectMapper getSubMappers() { public static class Builder extends FieldMapper.Builder { private final IndexVersion indexVersionCreated; - private final IndexAnalyzers indexAnalyzers; private final Parameter inferenceId = Parameter.stringParam( "inference_id", @@ -149,10 +144,9 @@ public static class Builder extends FieldMapper.Builder { 
).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); private final Parameter> meta = Parameter.metaParam(); - public Builder(String name, IndexVersion indexVersionCreated, IndexAnalyzers indexAnalyzers) { + public Builder(String name, IndexVersion indexVersionCreated) { super(name); this.indexVersionCreated = indexVersionCreated; - this.indexAnalyzers = indexAnalyzers; } public Builder setInferenceId(String id) { @@ -175,11 +169,9 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { final String fullName = context.buildFullName(name()); NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(RESULTS, indexVersionCreated); nestedBuilder.dynamic(ObjectMapper.Dynamic.FALSE); - TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( - INFERENCE_CHUNKS_TEXT, - indexVersionCreated, - indexAnalyzers - ).index(false).store(false); + KeywordFieldMapper.Builder textMapperBuilder = new KeywordFieldMapper.Builder(INFERENCE_CHUNKS_TEXT, indexVersionCreated) + .indexed(false) + .docValues(false); if (modelSettings.get() != null) { nestedBuilder.add(createInferenceMapperBuilder(INFERENCE_CHUNKS_RESULTS, modelSettings.get(), indexVersionCreated)); } @@ -191,7 +183,6 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subMappers, meta.getValue()), copyTo, indexVersionCreated, - indexAnalyzers, modelSettings.getValue(), subMappers ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 766a27f24df75..8b18cf74236a0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -285,7 +285,7 @@ private static BulkItemRequest[] randomBulkItemRequest( final ChunkedInferenceServiceResults results; switch (taskType) { case TEXT_EMBEDDING: - results = randomTextEmbeddings(chunks); + results = randomTextEmbeddings(model, chunks); break; case SPARSE_EMBEDDING: diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java index aee2db47e18a3..1d517204ab598 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java @@ -102,8 +102,8 @@ public void testSuccessfulParse() throws IOException { final String fieldName1 = randomFieldName(depth); final String fieldName2 = randomFieldName(depth + 1); - Model model1 = randomModel(); - Model model2 = randomModel(); + Model model1 = randomModel(TaskType.SPARSE_EMBEDDING); + Model model2 = randomModel(TaskType.SPARSE_EMBEDDING); XContentBuilder mapping = mapping(b -> { addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId()); addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId()); @@ -216,7 +216,7 @@ public void testSuccessfulParse() throws IOException { public void 
testMissingSubfields() throws IOException { final String fieldName = randomAlphaOfLengthBetween(5, 15); - final Model model = randomModel(); + final Model model = randomModel(randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING); DocumentMapper documentMapper = createDocumentMapper( mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId())) @@ -283,7 +283,7 @@ public void testMissingSubfields() throws IOException { public void testExtraSubfields() throws IOException { final String fieldName = randomAlphaOfLengthBetween(5, 15); - final Model model = randomModel(); + final Model model = randomModel(randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING); final List semanticTextInferenceResultsList = List.of( randomSemanticTextInferenceResults(fieldName, model, List.of("a b")) ); @@ -388,7 +388,13 @@ public void testMissingSemanticTextMapping() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(randomSemanticTextInferenceResults(fieldName, randomModel(), List.of("a b"))) + List.of( + randomSemanticTextInferenceResults( + fieldName, + randomModel(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)), + List.of("a b") + ) + ) ) ) ) @@ -401,6 +407,64 @@ public void testMissingSemanticTextMapping() throws IOException { ); } + public void testMissingInferenceId() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject(InferenceMetadataFieldMapper.NAME) + .startObject("field") + .startObject(SemanticTextModelSettings.NAME) + .field(SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(), TaskType.SPARSE_EMBEDDING) + .endObject() + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getMessage(), containsString("required [inference_id] is missing")); + } + + public void testMissingModelSettings() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + DocumentParsingException ex = expectThrows( + DocumentParsingException.class, + DocumentParsingException.class, + () -> documentMapper.parse( + source( + b -> b.startObject(InferenceMetadataFieldMapper.NAME) + .startObject("field") + .field(InferenceMetadataFieldMapper.INFERENCE_ID, "my_id") + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getMessage(), containsString("Missing required [model_settings] for field [field] of type [semantic_text]")); + } + + public void testMissingTaskType() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + DocumentParsingException ex = expectThrows( + DocumentParsingException.class, + DocumentParsingException.class, + () -> documentMapper.parse( + source( + b -> b.startObject(InferenceMetadataFieldMapper.NAME) + .startObject("field") + .field(InferenceMetadataFieldMapper.INFERENCE_ID, "my_id") + .startObject(SemanticTextModelSettings.NAME) + .endObject() + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString(" Failed to parse [model_settings], required [task_type] is missing")); + } + private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { mappingBuilder.startObject(fieldName); 
mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); @@ -408,10 +472,10 @@ private static void addSemanticTextMapping(XContentBuilder mappingBuilder, Strin mappingBuilder.endObject(); } - public static ChunkedTextEmbeddingResults randomTextEmbeddings(List inputs) { + public static ChunkedTextEmbeddingResults randomTextEmbeddings(Model model, List inputs) { List chunks = new ArrayList<>(); for (String input : inputs) { - double[] values = new double[5]; + double[] values = new double[model.getServiceSettings().dimensions()]; for (int j = 0; j < values.length; j++) { values[j] = randomDouble(); } @@ -437,7 +501,12 @@ private static SemanticTextInferenceResults randomSemanticTextInferenceResults( Model model, List chunks ) { - return new SemanticTextInferenceResults(semanticTextFieldName, model, randomSparseEmbeddings(chunks), chunks); + ChunkedInferenceServiceResults chunkedResults = switch (model.getTaskType()) { + case TEXT_EMBEDDING -> randomTextEmbeddings(model, chunks); + case SPARSE_EMBEDDING -> randomSparseEmbeddings(chunks); + default -> throw new AssertionError("unkwnown task type: " + model.getTaskType().name()); + }; + return new SemanticTextInferenceResults(semanticTextFieldName, model, chunkedResults, chunks); } private static void addSemanticTextInferenceResults( @@ -495,12 +564,12 @@ static String randomFieldName(int numLevel) { return builder.toString(); } - private static Model randomModel() { + private static Model randomModel(TaskType taskType) { String serviceName = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomAlphaOfLengthBetween(5, 10); return new TestModel( inferenceId, - TaskType.SPARSE_EMBEDDING, + taskType, serviceName, new TestModel.TestServiceSettings("my-model"), new TestModel.TestTaskSettings(randomIntBetween(1, 100)), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 15c50e530c1c0..551b5f73fe27e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.DocumentMapper; +import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperParsingException; @@ -18,7 +19,6 @@ import org.elasticsearch.index.mapper.MapperTestCase; import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; -import org.elasticsearch.index.mapper.TextFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.plugins.Plugin; @@ -240,8 +240,8 @@ static void assertSemanticTextField(MapperService mapperService, String fieldNam assertThat(nestedObjectMapper, equalTo(semanticFieldMapper.getSubMappers())); Mapper textMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT); assertNotNull(textMapper); - assertThat(textMapper, instanceOf(TextFieldMapper.class)); - TextFieldMapper textFieldMapper = 
(TextFieldMapper) textMapper; + assertThat(textMapper, instanceOf(KeywordFieldMapper.class)); + KeywordFieldMapper textFieldMapper = (KeywordFieldMapper) textMapper; assertFalse(textFieldMapper.fieldType().isIndexed()); assertFalse(textFieldMapper.fieldType().hasDocValues()); if (expectedModelSettings) { From 8ddc37ff496cd2d5d11c2349a99279d3da7e347d Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 21 Mar 2024 13:21:38 +0000 Subject: [PATCH 11/13] Fix the mapper builder context when updating the semantic text field definition --- .../mapper/InferenceMetadataFieldMapper.java | 94 +++++++++++-------- .../mapper/SemanticTextFieldMapperTests.java | 10 +- 2 files changed, 61 insertions(+), 43 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java index c11e8f8b82bf0..20315cb43e2a0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -43,6 +43,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; @@ -158,14 +159,13 @@ public InferenceMetadataFieldMapper() { protected void parseCreateField(DocumentParserContext context) throws IOException { XContentParser parser = context.parser(); failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); - MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); boolean origWithLeafObject = context.path().isWithinLeafObject(); try { // make sure that we don't expand dots in field names while parsing context.path().setWithinLeafObject(true); for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.FIELD_NAME); - parseSingleField(context, mapperBuilderContext); + parseSingleField(context); } } finally { context.path().setWithinLeafObject(origWithLeafObject); @@ -174,59 +174,56 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio private NestedObjectMapper updateSemanticTextFieldMapper( DocumentParserContext docContext, - MapperBuilderContext mapperBuilderContext, - ObjectMapper parent, - SemanticTextFieldMapper original, - String inferenceId, - SemanticTextModelSettings modelSettings, + SemanticTextMapperContext semanticFieldContext, + String newInferenceId, + SemanticTextModelSettings newModelSettings, XContentLocation xContentLocation ) { - if (inferenceId.equals(original.fieldType().getInferenceId()) == false) { + final String fullFieldName = semanticFieldContext.mapper.fieldType().name(); + final String inferenceId = semanticFieldContext.mapper.fieldType().getInferenceId(); + if (newInferenceId.equals(inferenceId) == false) { throw new DocumentParsingException( xContentLocation, Strings.format( "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", INFERENCE_ID, inferenceId, - original.name(), + fullFieldName, INFERENCE_ID, - original.fieldType().getInferenceId() + newInferenceId ) ); } - if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING && modelSettings.dimensions() == null) { + 
if (newModelSettings.taskType() == TaskType.TEXT_EMBEDDING && newModelSettings.dimensions() == null) { throw new DocumentParsingException( xContentLocation, - "Model settings for field [" + original.name() + "] must contain dimensions" + "Model settings for field [" + fullFieldName + "] must contain dimensions" ); } - if (original.getModelSettings() == null) { - if (parent != docContext.root()) { - mapperBuilderContext = mapperBuilderContext.createChildContext(parent.name(), ObjectMapper.Dynamic.FALSE); - } + if (semanticFieldContext.mapper.getModelSettings() == null) { SemanticTextFieldMapper newMapper = new SemanticTextFieldMapper.Builder( - original.simpleName(), + semanticFieldContext.mapper.simpleName(), docContext.indexSettings().getIndexVersionCreated() - ).setInferenceId(original.fieldType().getInferenceId()).setModelSettings(modelSettings).build(mapperBuilderContext); + ).setInferenceId(newInferenceId).setModelSettings(newModelSettings).build(semanticFieldContext.context); docContext.addDynamicMapper(newMapper); return newMapper.getSubMappers(); } else { - SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(original.name()); - SemanticTextFieldMapper.canMergeModelSettings(original.getModelSettings(), modelSettings, conflicts); + SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName); + SemanticTextFieldMapper.canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); try { conflicts.check(); } catch (Exception exc) { throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc); } } - return original.getSubMappers(); + return semanticFieldContext.mapper.getSubMappers(); } - private void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { + private void parseSingleField(DocumentParserContext context) throws IOException { XContentParser parser = context.parser(); String fieldName = parser.currentName(); - var res = findMapper(context.mappingLookup().getMapping().getRoot(), fieldName); - if (res == null || res.mapper instanceof SemanticTextFieldMapper == false) { + SemanticTextMapperContext builderContext = createSemanticFieldContext(context, fieldName); + if (builderContext == null) { throw new DocumentParsingException( parser.getTokenLocation(), Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) @@ -276,15 +273,7 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex ); } - var nestedObjectMapper = updateSemanticTextFieldMapper( - context, - mapperBuilderContext, - res.parent, - (SemanticTextFieldMapper) res.mapper, - inferenceId, - modelSettings, - xContentLocation - ); + var nestedObjectMapper = updateSemanticTextFieldMapper(context, builderContext, inferenceId, modelSettings, xContentLocation); // we know the model settings, so we can (re) parse the results array now XContentParser subParser = new MapXContentParser( @@ -419,17 +408,40 @@ public static void applyFieldInference( inferenceMap.put(field, fieldMap); } - record MapperAndParent(ObjectMapper parent, Mapper mapper) {} + record SemanticTextMapperContext(MapperBuilderContext context, SemanticTextFieldMapper mapper) {} + + /** + * Returns the {@link SemanticTextFieldMapper} associated with the provided {@code fullName} + * and the {@link MapperBuilderContext} that was used to build it. + * If the field is not found or is of the wrong type, this method returns {@code null}. 
+ */ + static SemanticTextMapperContext createSemanticFieldContext(DocumentParserContext docContext, String fullName) { + ObjectMapper rootMapper = docContext.mappingLookup().getMapping().getRoot(); + return createSemanticFieldContext(MapperBuilderContext.root(false, false), rootMapper, fullName, fullName.split("\\.")); + } - static MapperAndParent findMapper(ObjectMapper mapper, String fullPath) { - String[] pathElements = fullPath.split("\\."); - for (int i = 0; i < pathElements.length - 1; i++) { - Mapper next = mapper.getMapper(pathElements[i]); - if (next == null || next instanceof ObjectMapper == false) { + static SemanticTextMapperContext createSemanticFieldContext( + MapperBuilderContext mapperContext, + ObjectMapper objectMapper, + String fullName, + String[] paths + ) { + Mapper mapper = objectMapper.getMapper(paths[0]); + if (mapper instanceof ObjectMapper newObjectMapper) { + mapperContext = mapperContext.createChildContext(paths[0], ObjectMapper.Dynamic.FALSE); + return createSemanticFieldContext(mapperContext, newObjectMapper, fullName, Arrays.copyOfRange(paths, 1, paths.length)); + } else if (mapper instanceof SemanticTextFieldMapper semanticMapper) { + return new SemanticTextMapperContext(mapperContext, semanticMapper); + } else { + if (mapper == null || paths.length == 1) { return null; } - mapper = (ObjectMapper) next; + // check if the semantic field is defined within a multi-field + Mapper fieldMapper = objectMapper.getMapper(String.join(".", Arrays.asList(paths))); + if (fieldMapper instanceof SemanticTextFieldMapper semanticMapper) { + return new SemanticTextMapperContext(mapperContext, semanticMapper); + } } - return new MapperAndParent(mapper, mapper.getMapper(pathElements[pathElements.length - 1])); + return null; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 551b5f73fe27e..e9b5a788256d0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MapperBuilderContext; import org.elasticsearch.index.mapper.MapperParsingException; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperTestCase; @@ -31,7 +32,7 @@ import java.util.List; import static java.util.Collections.singletonList; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.findMapper; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.createSemanticFieldContext; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -219,7 +220,12 @@ public void testUpdateModelSettings() throws IOException { } static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { - var res = findMapper(mapperService.mappingLookup().getMapping().getRoot(), fieldName); + InferenceMetadataFieldMapper.SemanticTextMapperContext res = createSemanticFieldContext( + 
MapperBuilderContext.root(false, false), + mapperService.mappingLookup().getMapping().getRoot(), + fieldName, + fieldName.split("\\.") + ); Mapper mapper = res.mapper(); assertNotNull(mapper); assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); From b3ae2840d2351e103bb287f2d141a85c99589849 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 21 Mar 2024 13:31:03 +0000 Subject: [PATCH 12/13] string formatting error --- .../xpack/inference/mapper/SemanticTextModelSettings.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index 54e86c108fa43..b1d0511008db8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -158,7 +158,7 @@ public void validate() { + TASK_TYPE_FIELD.getPreferredName() + "], expected " + TEXT_EMBEDDING - + "or " + + " or " + SPARSE_EMBEDDING + ", got " + taskType.name() From 2e7fc7f386bf5b0b62d87b897a6b3593ebb204ed Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Fri, 22 Mar 2024 10:17:43 +0000 Subject: [PATCH 13/13] results => chunks renaming --- .../mapper/InferenceMetadataFieldMapper.java | 42 +- .../mapper/InferenceResultFieldMapper.java | 372 ------------------ .../mapper/SemanticTextFieldMapper.java | 11 +- .../InferenceMetadataFieldMapperTests.java | 26 +- .../mapper/SemanticTextFieldMapperTests.java | 3 +- .../inference/10_semantic_text_inference.yml | 44 +-- 6 files changed, 66 insertions(+), 432 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java index 20315cb43e2a0..9eeb7a5407bc4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -70,7 +70,7 @@ * "model_settings": { * "task_type": "SPARSE_EMBEDDING" * }, - * "results" [ + * "chunks" [ * { * "inference": { * "lucas": 0.05212344, @@ -89,22 +89,25 @@ * } * * - * This mapper parses the contents of the {@code _semantic_text_inference} field and indexes it as if the mapping were configured like so: + * This mapper parses the contents of the {@code _inference} field and indexes it as if the mapping were configured like so: *
*
*
  * {
  *     "mappings": {
  *         "properties": {
- *             "my_semantic_text_field": {
- *                 "type": "nested",
- *                 "properties": {
- *                     "sparse_embedding": {
- *                         "type": "sparse_vector"
- *                     },
- *                     "text": {
- *                         "type": "text",
- *                         "index": false
+ *             "my_semantic_field": {
+ *                 "chunks": {
+ *                     "type": "nested",
+ *                     "properties": {
+ *                         "embedding": {
+ *                             "type": "sparse_vector|dense_vector"
+ *                         },
+ *                         "text": {
+ *                             "type": "keyword",
+ *                             "index": false,
+ *                             "doc_values": false
+ *                         }
  *                     }
  *                 }
  *             }
@@ -118,7 +121,8 @@ public class InferenceMetadataFieldMapper extends MetadataFieldMapper {
     public static final String CONTENT_TYPE = "_inference";
 
     public static final String INFERENCE_ID = "inference_id";
-    public static final String RESULTS = "results";
+    public static final String CHUNKS = "chunks";
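+    // Subfields of each element in the "chunks" array: the inference results and the chunk text they were computed from.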
     public static final String INFERENCE_CHUNKS_RESULTS = "inference";
     public static final String INFERENCE_CHUNKS_TEXT = "text";
 
@@ -283,10 +286,11 @@ private void parseSingleField(DocumentParserContext context) throws IOException
             XContentType.JSON
         );
         DocumentParserContext mapContext = context.switchParser(subParser);
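+        // Re-parse the collected inference entry as JSON so its chunks are indexed through the nested mapper resolved above.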
-        parseFieldInferenceObject(xContentLocation, subParser, mapContext, nestedObjectMapper);
+        parseFieldInference(xContentLocation, subParser, mapContext, nestedObjectMapper);
     }
 
-    private void parseFieldInferenceObject(
+    private void parseFieldInference(
         XContentLocation xContentLocation,
         XContentParser parser,
         DocumentParserContext context,
@@ -296,13 +299,13 @@ private void parseFieldInferenceObject(
         failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT);
         for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) {
             switch (parser.currentName()) {
-                case RESULTS -> parseResultsList(xContentLocation, parser, context, nestedMapper);
+                case CHUNKS -> parseChunks(xContentLocation, parser, context, nestedMapper);
                 default -> throw new DocumentParsingException(xContentLocation, "Unknown field name " + parser.currentName());
             }
         }
     }
 
-    private void parseResultsList(
+    private void parseChunks(
         XContentLocation xContentLocation,
         XContentParser parser,
         DocumentParserContext context,
@@ -404,7 +407,9 @@ public static void applyFieldInference(
         Map fieldMap = new LinkedHashMap<>();
         fieldMap.put(INFERENCE_ID, model.getInferenceEntityId());
         fieldMap.putAll(new SemanticTextModelSettings(model).asMap());
-        fieldMap.put(RESULTS, chunks);
+        fieldMap.put(CHUNKS, chunks);
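+        // The resulting entry follows the _inference format documented on this class:
+        // { "<field>": { "inference_id": ..., "model_settings": { ... }, "chunks": [ ... ] } }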
         inferenceMap.put(field, fieldMap);
     }
 
@@ -417,19 +420,20 @@ record SemanticTextMapperContext(MapperBuilderContext context, SemanticTextField
      */
     static SemanticTextMapperContext createSemanticFieldContext(DocumentParserContext docContext, String fullName) {
         ObjectMapper rootMapper = docContext.mappingLookup().getMapping().getRoot();
-        return createSemanticFieldContext(MapperBuilderContext.root(false, false), rootMapper, fullName, fullName.split("\\."));
+        return createSemanticFieldContext(MapperBuilderContext.root(false, false), rootMapper, fullName.split("\\."));
     }
 
     static SemanticTextMapperContext createSemanticFieldContext(
         MapperBuilderContext mapperContext,
         ObjectMapper objectMapper,
-        String fullName,
         String[] paths
     ) {
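+        // Consume one path segment per recursive call, creating a child MapperBuilderContext for each
+        // intermediate object mapper so the semantic_text field is rebuilt in the correct scope.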
         Mapper mapper = objectMapper.getMapper(paths[0]);
         if (mapper instanceof ObjectMapper newObjectMapper) {
             mapperContext = mapperContext.createChildContext(paths[0], ObjectMapper.Dynamic.FALSE);
-            return createSemanticFieldContext(mapperContext, newObjectMapper, fullName, Arrays.copyOfRange(paths, 1, paths.length));
+            return createSemanticFieldContext(mapperContext, newObjectMapper, Arrays.copyOfRange(paths, 1, paths.length));
         } else if (mapper instanceof SemanticTextFieldMapper semanticMapper) {
             return new SemanticTextMapperContext(mapperContext, semanticMapper);
         } else {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java
deleted file mode 100644
index 2ede5419ab74e..0000000000000
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java
+++ /dev/null
@@ -1,372 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.mapper;
-
-import org.apache.lucene.search.Query;
-import org.elasticsearch.ElasticsearchException;
-import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.index.IndexVersion;
-import org.elasticsearch.index.mapper.DocumentParserContext;
-import org.elasticsearch.index.mapper.DocumentParsingException;
-import org.elasticsearch.index.mapper.FieldMapper;
-import org.elasticsearch.index.mapper.MappedFieldType;
-import org.elasticsearch.index.mapper.Mapper;
-import org.elasticsearch.index.mapper.MapperBuilderContext;
-import org.elasticsearch.index.mapper.MetadataFieldMapper;
-import org.elasticsearch.index.mapper.NestedObjectMapper;
-import org.elasticsearch.index.mapper.ObjectMapper;
-import org.elasticsearch.index.mapper.SourceLoader;
-import org.elasticsearch.index.mapper.SourceValueFetcher;
-import org.elasticsearch.index.mapper.TextFieldMapper;
-import org.elasticsearch.index.mapper.TextSearchInfo;
-import org.elasticsearch.index.mapper.ValueFetcher;
-import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
-import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
-import org.elasticsearch.index.query.SearchExecutionContext;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
-import org.elasticsearch.inference.Model;
-import org.elasticsearch.inference.SimilarityMeasure;
-import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.logging.LogManager;
-import org.elasticsearch.logging.Logger;
-import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xcontent.XContentParser;
-import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.stream.Collectors;
-
-/**
- * A mapper for the {@code _semantic_text_inference} field.
- * 
- *
- * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results. - * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so: - *
- *
- *
- * {
- *     "_source": {
- *         "my_semantic_text_field": "these are not the droids you're looking for",
- *         "_inference": {
- *             "my_semantic_text_field": [
- *                 {
- *                     "sparse_embedding": {
- *                          "lucas": 0.05212344,
- *                          "ty": 0.041213956,
- *                          "dragon": 0.50991,
- *                          "type": 0.23241979,
- *                          "dr": 1.9312073,
- *                          "##o": 0.2797593
- *                     },
- *                     "text": "these are not the droids you're looking for"
- *                 }
- *             ]
- *         }
- *     }
- * }
- * 
- * - * This mapper parses the contents of the {@code _semantic_text_inference} field and indexes it as if the mapping were configured like so: - *
- *
- *
- * {
- *     "mappings": {
- *         "properties": {
- *             "my_semantic_text_field": {
- *                 "type": "nested",
- *                 "properties": {
- *                     "sparse_embedding": {
- *                         "type": "sparse_vector"
- *                     },
- *                     "text": {
- *                         "type": "text",
- *                         "index": false
- *                     }
- *                 }
- *             }
- *         }
- *     }
- * }
- * 
- */ -public class InferenceResultFieldMapper extends MetadataFieldMapper { - public static final String NAME = "_inference"; - public static final String CONTENT_TYPE = "_inference"; - - public static final String RESULTS = "results"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceResultFieldMapper()); - - private static final Logger logger = LogManager.getLogger(InferenceResultFieldMapper.class); - - private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); - - static class SemanticTextInferenceFieldType extends MappedFieldType { - private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); - - SemanticTextInferenceFieldType() { - super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.identity(name(), context, format); - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - return null; - } - } - - public InferenceResultFieldMapper() { - super(SemanticTextInferenceFieldType.INSTANCE); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - parseAllFields(context); - } - - private static void parseAllFields(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - parseSingleField(context, mapperBuilderContext); - } - } - - private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { - - XContentParser parser = context.parser(); - String fieldName = parser.currentName(); - Mapper mapper = context.getMapper(fieldName); - if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) - ); - } - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - parser.nextToken(); - SemanticTextModelSettings modelSettings = SemanticTextModelSettings.parse(parser); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - if (RESULTS.equals(currentName)) { - NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( - context, - mapperBuilderContext, - fieldName, - modelSettings - ); - parseFieldInferenceChunks(context, mapperBuilderContext, fieldName, modelSettings, nestedObjectMapper); - } else { - logger.debug("Skipping unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - } - } - } - - private static void 
parseFieldInferenceChunks( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings, - NestedObjectMapper nestedObjectMapper - ) throws IOException { - XContentParser parser = context.parser(); - - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_ARRAY); - - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { - DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper); - parseFieldInferenceChunkElement(nestedContext, nestedObjectMapper, modelSettings); - } - } - - private static void parseFieldInferenceChunkElement( - DocumentParserContext context, - ObjectMapper objectMapper, - SemanticTextModelSettings modelSettings - ) throws IOException { - XContentParser parser = context.parser(); - DocumentParserContext childContext = context.createChildContext(objectMapper); - - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - Set visitedSubfields = new HashSet<>(); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - visitedSubfields.add(currentName); - - Mapper childMapper = objectMapper.getMapper(currentName); - if (childMapper == null) { - logger.debug("Skipping indexing of unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - continue; - } - - if (childMapper instanceof FieldMapper fieldMapper) { - parser.nextToken(); - fieldMapper.parse(childContext); - } else { - // This should never happen, but fail parsing if it does so that it's not a silent failure - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Unhandled mapper type [%s] for field [%s]", childMapper.getClass(), currentName) - ); - } - } - - if (visitedSubfields.containsAll(REQUIRED_SUBFIELDS) == false) { - Set missingSubfields = REQUIRED_SUBFIELDS.stream() - .filter(s -> visitedSubfields.contains(s) == false) - .collect(Collectors.toSet()); - throw new DocumentParsingException(parser.getTokenLocation(), "Missing required subfields: " + missingSubfields); - } - } - - private static NestedObjectMapper createInferenceResultsObjectMapper( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings - ) { - IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated(); - FieldMapper.Builder resultsBuilder; - if (modelSettings.taskType() == TaskType.SPARSE_EMBEDDING) { - resultsBuilder = new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); - } else if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING) { - DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( - INFERENCE_CHUNKS_RESULTS, - indexVersionCreated - ); - SimilarityMeasure similarity = modelSettings.similarity(); - if (similarity != null) { - switch (similarity) { - case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); - case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); - default -> throw new IllegalArgumentException( - "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity - ); - } - } - Integer dimensions = 
modelSettings.dimensions(); - if (dimensions == null) { - throw new IllegalArgumentException("Model settings for field [" + fieldName + "] must contain dimensions"); - } - denseVectorMapperBuilder.dimensions(dimensions); - resultsBuilder = denseVectorMapperBuilder; - } else { - throw new IllegalArgumentException("Unknown task type for field [" + fieldName + "]: " + modelSettings.taskType()); - } - - TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( - INFERENCE_CHUNKS_TEXT, - indexVersionCreated, - context.indexAnalyzers() - ).index(false).store(false); - - NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder( - fieldName, - context.indexSettings().getIndexVersionCreated() - ); - nestedBuilder.add(resultsBuilder).add(textMapperBuilder); - - return nestedBuilder.build(mapperBuilderContext); - } - - private static void advancePastCurrentFieldName(XContentParser parser) throws IOException { - assert parser.currentToken() == XContentParser.Token.FIELD_NAME; - - XContentParser.Token token = parser.nextToken(); - if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { - parser.skipChildren(); - } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_* or VALUE_*, got " + token); - } - } - - private static void failIfTokenIsNot(XContentParser parser, XContentParser.Token expected) { - if (parser.currentToken() != expected) { - throw new DocumentParsingException( - parser.getTokenLocation(), - "Expected a " + expected.toString() + ", got " + parser.currentToken() - ); - } - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - return SourceLoader.SyntheticFieldLoader.NOTHING; - } - - public static void applyFieldInference( - Map inferenceMap, - String field, - Model model, - ChunkedInferenceServiceResults results - ) throws ElasticsearchException { - List> chunks = new ArrayList<>(); - if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add(chunk.asMap()); - } - } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add(chunk.asMap()); - } - } else { - throw new ElasticsearchStatusException( - "Invalid inference results format for field [{}] with inference id [{}], got {}", - RestStatus.BAD_REQUEST, - field, - model.getInferenceEntityId(), - results.getWriteableName() - ); - } - Map fieldMap = new LinkedHashMap<>(); - fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - fieldMap.put(InferenceResultFieldMapper.RESULTS, chunks); - inferenceMap.put(field, fieldMap); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 7925bfac39476..2445d5c8751a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -22,6 +22,7 @@ import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ObjectMapper; import 
org.elasticsearch.index.mapper.SimpleMappedFieldType; +import org.elasticsearch.index.mapper.SourceLoader; import org.elasticsearch.index.mapper.SourceValueFetcher; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; @@ -40,9 +41,9 @@ import java.util.Map; import java.util.Objects; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.RESULTS; /** * A {@link FieldMapper} for semantic text fields. @@ -132,7 +133,6 @@ public static class Builder extends FieldMapper.Builder { } }); - @SuppressWarnings("unchecked") private final Parameter modelSettings = new Parameter<>( "model_settings", true, @@ -167,7 +167,7 @@ protected Parameter[] getParameters() { @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { final String fullName = context.buildFullName(name()); - NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(RESULTS, indexVersionCreated); + NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(CHUNKS, indexVersionCreated); nestedBuilder.dynamic(ObjectMapper.Dynamic.FALSE); KeywordFieldMapper.Builder textMapperBuilder = new KeywordFieldMapper.Builder(INFERENCE_CHUNKS_TEXT, indexVersionCreated) .indexed(false) @@ -241,6 +241,11 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext } } + @Override + public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { + return super.syntheticFieldLoader(); + } + private static Mapper.Builder createInferenceMapperBuilder( String fieldName, SemanticTextModelSettings modelSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java index 1d517204ab598..37e4e5e774bec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java @@ -59,9 +59,9 @@ import java.util.Set; import java.util.function.Consumer; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.RESULTS; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -131,9 +131,9 @@ public void testSuccessfulParse() throws IOException { assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); } // nested docs are in reversed order - assertSparseFeatures(luceneDocs.get(0), fieldName1 + ".results.inference", 2); - assertSparseFeatures(luceneDocs.get(1), fieldName1 + ".results.inference", 1); - assertSparseFeatures(luceneDocs.get(2), fieldName2 + ".results.inference", 3); + assertSparseFeatures(luceneDocs.get(0), fieldName1 + ".chunks.inference", 2); 
+ assertSparseFeatures(luceneDocs.get(1), fieldName1 + ".chunks.inference", 1); + assertSparseFeatures(luceneDocs.get(2), fieldName2 + ".chunks.inference", 3); assertEquals(doc.rootDoc(), luceneDocs.get(3)); assertNull(luceneDocs.get(3).getParent()); @@ -147,9 +147,9 @@ public void testSuccessfulParse() throws IOException { Set visitedNestedIdentities = new HashSet<>(); Set expectedVisitedNestedIdentities = Set.of( - new SearchHit.NestedIdentity(fieldName1 + "." + RESULTS, 0, null), - new SearchHit.NestedIdentity(fieldName1 + "." + RESULTS, 1, null), - new SearchHit.NestedIdentity(fieldName2 + "." + RESULTS, 0, null) + new SearchHit.NestedIdentity(fieldName1 + "." + CHUNKS, 0, null), + new SearchHit.NestedIdentity(fieldName1 + "." + CHUNKS, 1, null), + new SearchHit.NestedIdentity(fieldName2 + "." + CHUNKS, 0, null) ); assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); @@ -167,7 +167,7 @@ public void testSuccessfulParse() throws IOException { TopDocs topDocs = searcher.search( generateNestedTermSparseVectorQuery( mapperService.mappingLookup().nestedLookup(), - fieldName1 + "." + RESULTS, + fieldName1 + "." + CHUNKS, List.of("a") ), 10 @@ -179,7 +179,7 @@ public void testSuccessfulParse() throws IOException { TopDocs topDocs = searcher.search( generateNestedTermSparseVectorQuery( mapperService.mappingLookup().nestedLookup(), - fieldName1 + "." + RESULTS, + fieldName1 + "." + CHUNKS, List.of("a", "b") ), 10 @@ -191,7 +191,7 @@ public void testSuccessfulParse() throws IOException { TopDocs topDocs = searcher.search( generateNestedTermSparseVectorQuery( mapperService.mappingLookup().nestedLookup(), - fieldName2 + "." + RESULTS, + fieldName2 + "." + CHUNKS, List.of("d") ), 10 @@ -203,7 +203,7 @@ public void testSuccessfulParse() throws IOException { TopDocs topDocs = searcher.search( generateNestedTermSparseVectorQuery( mapperService.mappingLookup().nestedLookup(), - fieldName2 + "." + RESULTS, + fieldName2 + "." + CHUNKS, List.of("z") ), 10 @@ -294,7 +294,7 @@ public void testExtraSubfields() throws IOException { Consumer checkParsedDocument = d -> { Set visitedChildDocs = new HashSet<>(); - Set expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName + "." + RESULTS)); + Set expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName + "." 
+ CHUNKS)); List luceneDocs = d.docs(); assertEquals(2, luceneDocs.size()); @@ -539,7 +539,7 @@ private static void addSemanticTextInferenceResults( semanticTextInferenceResult.results ); Map optionsMap = (Map) inferenceResultsMap.get(semanticTextInferenceResult.fieldName); - List> fieldResultList = (List>) optionsMap.get(RESULTS); + List> fieldResultList = (List>) optionsMap.get(CHUNKS); for (var entry : fieldResultList) { if (includeTextSubfield == false) { entry.remove(INFERENCE_CHUNKS_TEXT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index e9b5a788256d0..1b5311ac9effb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -223,7 +223,6 @@ static void assertSemanticTextField(MapperService mapperService, String fieldNam InferenceMetadataFieldMapper.SemanticTextMapperContext res = createSemanticFieldContext( MapperBuilderContext.root(false, false), mapperService.mappingLookup().getMapping().getRoot(), - fieldName, fieldName.split("\\.") ); Mapper mapper = res.mapper(); @@ -242,7 +241,7 @@ static void assertSemanticTextField(MapperService mapperService, String fieldNam NestedObjectMapper nestedObjectMapper = mapperService.mappingLookup() .nestedLookup() .getNestedMappers() - .get(fieldName + "." + InferenceMetadataFieldMapper.RESULTS); + .get(fieldName + "." + InferenceMetadataFieldMapper.CHUNKS); assertThat(nestedObjectMapper, equalTo(semanticFieldMapper.getSubMappers())); Mapper textMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT); assertNotNull(textMapper); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 4488af74c0616..528003e278aeb 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -84,11 +84,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - exists: _source._inference.inference_field.results.0.inference - - exists: _source._inference.another_inference_field.results.0.inference + - exists: _source._inference.inference_field.chunks.0.inference + - exists: _source._inference.another_inference_field.chunks.0.inference --- "text expansion documents do not create new mappings": @@ -121,11 +121,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { 
_source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - exists: _source._inference.inference_field.results.0.inference - - exists: _source._inference.another_inference_field.results.0.inference + - exists: _source._inference.inference_field.chunks.0.inference + - exists: _source._inference.another_inference_field.chunks.0.inference --- @@ -155,8 +155,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } - - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.chunks.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.chunks.0.inference: another_inference_field_embedding } - do: update: @@ -175,11 +175,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "another non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": @@ -215,8 +215,8 @@ setup: - match: { _source.another_inference_field: "another updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "updated inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another updated inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "updated inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another updated inference test" } --- "Reindex works for semantic_text fields": @@ -234,8 +234,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } - - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.chunks.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.chunks.0.inference: another_inference_field_embedding } - do: indices.refresh: { } @@ -272,11 +272,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { 
_source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- "Fails for non-existent model":