From 0f1ef56c9c334f4ac53e52a6a30b502cd40e3647 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 14 May 2024 17:45:02 +0200 Subject: [PATCH] Refactorings --- .../BulkShardOperationInferenceProcessor.java | 2 +- .../ShardBulkInferenceActionFilter.java | 536 ------------------ .../ShardBulkInferenceActionFilterTests.java | 386 ------------- 3 files changed, 1 insertion(+), 923 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/bulk/BulkShardOperationInferenceProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/bulk/BulkShardOperationInferenceProcessor.java index 51ebc21d11ecb..4a69a57602244 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/bulk/BulkShardOperationInferenceProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/bulk/BulkShardOperationInferenceProcessor.java @@ -85,7 +85,7 @@ public BulkShardOperationInferenceProcessor( @Override public void apply(BulkShardRequest request, ClusterState clusterState, ActionListener listener) { - var indexMetadata = clusterState.getMetadata().index(request.shardId().getIndexName()); + var indexMetadata = clusterState.metadata().index(request.shardId().getIndexName()); if (indexMetadata != null) { var fieldInferenceMetadata = indexMetadata.getInferenceFields(); if (fieldInferenceMetadata.isEmpty() == false) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java deleted file mode 100644 index 38d8b8d9b35c0..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ /dev/null @@ -1,536 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.action.filter; - -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.ExceptionsHelper; -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.ActionRequest; -import org.elasticsearch.action.ActionResponse; -import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.bulk.BulkItemRequest; -import org.elasticsearch.action.bulk.BulkShardRequest; -import org.elasticsearch.action.bulk.TransportShardBulkAction; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.support.ActionFilterChain; -import org.elasticsearch.action.support.MappedActionFilter; -import org.elasticsearch.action.support.RefCountingRunnable; -import org.elasticsearch.action.update.UpdateRequest; -import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; -import org.elasticsearch.common.util.concurrent.AtomicArray; -import org.elasticsearch.common.xcontent.support.XContentMapValues; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.Releasable; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.ChunkingOptions; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.inference.mapper.SemanticTextField; -import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; - -/** - * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified - * as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in - * the request source, we generate embeddings and include the results in the source under the semantic text field - * name as a {@link SemanticTextField}. - * This transformation happens on the bulk coordinator node, and the {@link SemanticTextFieldMapper} parses the - * results during indexing on the shard. - * - * TODO: batchSize should be configurable via a cluster setting - */ -public class ShardBulkInferenceActionFilter implements MappedActionFilter { - protected static final int DEFAULT_BATCH_SIZE = 512; - - private final InferenceServiceRegistry inferenceServiceRegistry; - private final ModelRegistry modelRegistry; - private final int batchSize; - - public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) { - this(inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE); - } - - public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, int batchSize) { - this.inferenceServiceRegistry = inferenceServiceRegistry; - this.modelRegistry = modelRegistry; - this.batchSize = batchSize; - } - - @Override - public int order() { - // must execute last (after the security action filter) - return Integer.MAX_VALUE; - } - - @Override - public String actionName() { - return TransportShardBulkAction.ACTION_NAME; - } - - @Override - public void apply( - Task task, - String action, - Request request, - ActionListener listener, - ActionFilterChain chain - ) { - if (TransportShardBulkAction.ACTION_NAME.equals(action)) { - BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap(); - if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) { - Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); - processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); - return; - } - } - chain.proceed(task, action, request, listener); - } - - private void processBulkShardRequest( - Map fieldInferenceMap, - BulkShardRequest bulkShardRequest, - Runnable onCompletion - ) { - new AsyncBulkShardInferenceAction(fieldInferenceMap, bulkShardRequest, onCompletion).run(); - } - - private record InferenceProvider(InferenceService service, Model model) {} - - /** - * A field inference request on a single input. - * @param index The index of the request in the original bulk request. - * @param field The target field. - * @param input The input to run inference on. - * @param inputOrder The original order of the input. - * @param isOriginalFieldInput Whether the input is part of the original values of the field. - */ - private record FieldInferenceRequest(int index, String field, String input, int inputOrder, boolean isOriginalFieldInput) {} - - /** - * The field inference response. - * @param field The target field. - * @param input The input that was used to run inference. - * @param inputOrder The original order of the input. - * @param isOriginalFieldInput Whether the input is part of the original values of the field. - * @param model The model used to run inference. - * @param chunkedResults The actual results. - */ - private record FieldInferenceResponse( - String field, - String input, - int inputOrder, - boolean isOriginalFieldInput, - Model model, - ChunkedInferenceServiceResults chunkedResults - ) {} - - private record FieldInferenceResponseAccumulator( - int id, - Map> responses, - List failures - ) { - void addOrUpdateResponse(FieldInferenceResponse response) { - synchronized (this) { - var list = responses.computeIfAbsent(response.field, k -> new ArrayList<>()); - list.add(response); - } - } - - void addFailure(Exception exc) { - synchronized (this) { - failures.add(exc); - } - } - } - - private class AsyncBulkShardInferenceAction implements Runnable { - private final Map fieldInferenceMap; - private final BulkShardRequest bulkShardRequest; - private final Runnable onCompletion; - private final AtomicArray inferenceResults; - - private AsyncBulkShardInferenceAction( - Map fieldInferenceMap, - BulkShardRequest bulkShardRequest, - Runnable onCompletion - ) { - this.fieldInferenceMap = fieldInferenceMap; - this.bulkShardRequest = bulkShardRequest; - this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); - this.onCompletion = onCompletion; - } - - @Override - public void run() { - Map> inferenceRequests = createFieldInferenceRequests(bulkShardRequest); - Runnable onInferenceCompletion = () -> { - try { - for (var inferenceResponse : inferenceResults.asList()) { - var request = bulkShardRequest.items()[inferenceResponse.id]; - try { - applyInferenceResponses(request, inferenceResponse); - } catch (Exception exc) { - request.abort(bulkShardRequest.index(), exc); - } - } - } finally { - onCompletion.run(); - } - }; - try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) { - for (var entry : inferenceRequests.entrySet()) { - executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire()); - } - } - } - - private void executeShardBulkInferenceAsync( - final String inferenceId, - @Nullable InferenceProvider inferenceProvider, - final List requests, - final Releasable onFinish - ) { - if (inferenceProvider == null) { - ActionListener modelLoadingListener = new ActionListener<>() { - @Override - public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { - var service = inferenceServiceRegistry.getService(unparsedModel.service()); - if (service.isEmpty() == false) { - var provider = new InferenceProvider( - service.get(), - service.get() - .parsePersistedConfigWithSecrets( - inferenceId, - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ) - ); - executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish); - } else { - try (onFinish) { - for (FieldInferenceRequest request : requests) { - inferenceResults.get(request.index).failures.add( - new ResourceNotFoundException( - "Inference service [{}] not found for field [{}]", - unparsedModel.service(), - request.field - ) - ); - } - } - } - } - - @Override - public void onFailure(Exception exc) { - try (onFinish) { - for (FieldInferenceRequest request : requests) { - Exception failure; - if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) { - failure = new ResourceNotFoundException( - "Inference id [{}] not found for field [{}]", - inferenceId, - request.field - ); - } else { - failure = new ElasticsearchException( - "Error loading inference for inference id [{}] on field [{}]", - exc, - inferenceId, - request.field - ); - } - inferenceResults.get(request.index).failures.add(failure); - } - } - } - }; - modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); - return; - } - int currentBatchSize = Math.min(requests.size(), batchSize); - final List currentBatch = requests.subList(0, currentBatchSize); - final List nextBatch = requests.subList(currentBatchSize, requests.size()); - final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); - ActionListener> completionListener = new ActionListener<>() { - @Override - public void onResponse(List results) { - try { - var requestsIterator = requests.iterator(); - for (ChunkedInferenceServiceResults result : results) { - var request = requestsIterator.next(); - var acc = inferenceResults.get(request.index); - if (result instanceof ErrorChunkedInferenceResults error) { - acc.addFailure( - new ElasticsearchException( - "Exception when running inference id [{}] on field [{}]", - error.getException(), - inferenceProvider.model.getInferenceEntityId(), - request.field - ) - ); - } else { - acc.addOrUpdateResponse( - new FieldInferenceResponse( - request.field(), - request.input(), - request.inputOrder(), - request.isOriginalFieldInput(), - inferenceProvider.model, - result - ) - ); - } - } - } finally { - onFinish(); - } - } - - @Override - public void onFailure(Exception exc) { - try { - for (FieldInferenceRequest request : requests) { - addInferenceResponseFailure( - request.index, - new ElasticsearchException( - "Exception when running inference id [{}] on field [{}]", - exc, - inferenceProvider.model.getInferenceEntityId(), - request.field - ) - ); - } - } finally { - onFinish(); - } - } - - private void onFinish() { - if (nextBatch.isEmpty()) { - onFinish.close(); - } else { - executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish); - } - } - }; - inferenceProvider.service() - .chunkedInfer( - inferenceProvider.model(), - null, - inputs, - Map.of(), - InputType.INGEST, - new ChunkingOptions(null, null), - TimeValue.MAX_VALUE, - completionListener - ); - } - - private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { - FieldInferenceResponseAccumulator acc = inferenceResults.get(id); - if (acc == null) { - acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>()); - inferenceResults.set(id, acc); - } - return acc; - } - - private void addInferenceResponseFailure(int id, Exception failure) { - var acc = ensureResponseAccumulatorSlot(id); - acc.addFailure(failure); - } - - /** - * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}. - * If the response contains failures, the bulk item request is marked as failed for the downstream action. - * Otherwise, the source of the request is augmented with the field inference results under the - * {@link SemanticTextField#INFERENCE_FIELD} field. - */ - private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { - if (response.failures().isEmpty() == false) { - for (var failure : response.failures()) { - item.abort(item.index(), failure); - } - return; - } - - final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); - var newDocMap = indexRequest.sourceAsMap(); - for (var entry : response.responses.entrySet()) { - var fieldName = entry.getKey(); - var responses = entry.getValue(); - var model = responses.get(0).model(); - // ensure that the order in the original field is consistent in case of multiple inputs - Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); - List inputs = responses.stream().filter(r -> r.isOriginalFieldInput).map(r -> r.input).collect(Collectors.toList()); - List results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList()); - var result = new SemanticTextField( - fieldName, - inputs, - new SemanticTextField.InferenceResult( - model.getInferenceEntityId(), - new SemanticTextField.ModelSettings(model), - toSemanticTextFieldChunks(fieldName, model.getInferenceEntityId(), results, indexRequest.getContentType()) - ), - indexRequest.getContentType() - ); - newDocMap.put(fieldName, result); - } - indexRequest.source(newDocMap, indexRequest.getContentType()); - } - - /** - * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index. - * If results are already populated for fields in the original index request, the inference request for this specific - * field is skipped, and the existing results remain unchanged. - * Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing, - * where an error will be thrown if they mismatch or if the content is malformed. - *

- * TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ? - */ - private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { - Map> fieldRequestsMap = new LinkedHashMap<>(); - int itemIndex = 0; - for (var item : bulkShardRequest.items()) { - if (item.getPrimaryResponse() != null) { - // item was already aborted/processed by a filter in the chain upstream (e.g. security) - continue; - } - boolean isUpdateRequest = false; - final IndexRequest indexRequest; - if (item.request() instanceof IndexRequest ir) { - indexRequest = ir; - } else if (item.request() instanceof UpdateRequest updateRequest) { - isUpdateRequest = true; - if (updateRequest.script() != null) { - addInferenceResponseFailure( - item.id(), - new ElasticsearchStatusException( - "Cannot apply update with a script on indices that contain [{}] field(s)", - RestStatus.BAD_REQUEST, - SemanticTextFieldMapper.CONTENT_TYPE - ) - ); - continue; - } - indexRequest = updateRequest.doc(); - } else { - // ignore delete request - continue; - } - final Map docMap = indexRequest.sourceAsMap(); - for (var entry : fieldInferenceMap.values()) { - String field = entry.getName(); - String inferenceId = entry.getInferenceId(); - var originalFieldValue = XContentMapValues.extractValue(field, docMap); - if (originalFieldValue instanceof Map) { - continue; - } - int order = 0; - for (var sourceField : entry.getSourceFields()) { - boolean isOriginalFieldInput = sourceField.equals(field); - var valueObj = XContentMapValues.extractValue(sourceField, docMap); - if (valueObj == null) { - if (isUpdateRequest) { - addInferenceResponseFailure( - item.id(), - new ElasticsearchStatusException( - "Field [{}] must be specified on an update request to calculate inference for field [{}]", - RestStatus.BAD_REQUEST, - sourceField, - field - ) - ); - break; - } - continue; - } - ensureResponseAccumulatorSlot(itemIndex); - final List values; - try { - values = nodeStringValues(field, valueObj); - } catch (Exception exc) { - addInferenceResponseFailure(item.id(), exc); - break; - } - List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); - for (var v : values) { - fieldRequests.add(new FieldInferenceRequest(itemIndex, field, v, order++, isOriginalFieldInput)); - } - } - } - itemIndex++; - } - return fieldRequestsMap; - } - } - - /** - * This method converts the given {@code valueObj} into a list of strings. - * If {@code valueObj} is not a string or a collection of strings, it throws an ElasticsearchStatusException. - */ - private static List nodeStringValues(String field, Object valueObj) { - if (valueObj instanceof String value) { - return List.of(value); - } else if (valueObj instanceof Collection values) { - List valuesString = new ArrayList<>(); - for (var v : values) { - if (v instanceof String value) { - valuesString.add(value); - } else { - throw new ElasticsearchStatusException( - "Invalid format for field [{}], expected [String] got [{}]", - RestStatus.BAD_REQUEST, - field, - valueObj.getClass().getSimpleName() - ); - } - } - return valuesString; - } - throw new ElasticsearchStatusException( - "Invalid format for field [{}], expected [String] got [{}]", - RestStatus.BAD_REQUEST, - field, - valueObj.getClass().getSimpleName() - ); - } - - static IndexRequest getIndexRequestOrNull(DocWriteRequest docWriteRequest) { - if (docWriteRequest instanceof IndexRequest indexRequest) { - return indexRequest; - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - return updateRequest.doc(); - } else { - return null; - } - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java deleted file mode 100644 index c87faa2b52cc8..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ /dev/null @@ -1,386 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.action.filter; - -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.bulk.BulkItemRequest; -import org.elasticsearch.action.bulk.BulkItemResponse; -import org.elasticsearch.action.bulk.BulkShardRequest; -import org.elasticsearch.action.bulk.TransportShardBulkAction; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.support.ActionFilterChain; -import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.common.xcontent.support.XContentMapValues; -import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.tasks.Task; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.inference.model.TestModel; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.junit.After; -import org.junit.Before; -import org.mockito.stubbing.Answer; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - -import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; -import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; -import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; -import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSparseEmbeddings; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult; -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class ShardBulkInferenceActionFilterTests extends ESTestCase { - private ThreadPool threadPool; - - @Before - public void setupThreadPool() { - threadPool = new TestThreadPool(getTestName()); - } - - @After - public void tearDownThreadPool() throws Exception { - terminate(threadPool); - } - - @SuppressWarnings({ "unchecked", "rawtypes" }) - public void testFilterNoop() throws Exception { - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE); - CountDownLatch chainExecuted = new CountDownLatch(1); - ActionFilterChain actionFilterChain = (task, action, request, listener) -> { - try { - assertNull(((BulkShardRequest) request).getInferenceFieldMap()); - } finally { - chainExecuted.countDown(); - } - }; - ActionListener actionListener = mock(ActionListener.class); - Task task = mock(Task.class); - BulkShardRequest request = new BulkShardRequest( - new ShardId("test", "test", 0), - WriteRequest.RefreshPolicy.NONE, - new BulkItemRequest[0] - ); - request.setInferenceFieldMap( - Map.of("foo", new InferenceFieldMetadata("foo", "bar", generateRandomStringArray(5, 10, false, false))) - ); - filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); - awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); - } - - @SuppressWarnings({ "unchecked", "rawtypes" }) - public void testInferenceNotFound() throws Exception { - StaticModel model = StaticModel.createRandomInstance(); - ShardBulkInferenceActionFilter filter = createFilter( - threadPool, - Map.of(model.getInferenceEntityId(), model), - randomIntBetween(1, 10) - ); - CountDownLatch chainExecuted = new CountDownLatch(1); - ActionFilterChain actionFilterChain = (task, action, request, listener) -> { - try { - BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getInferenceFieldMap()); - for (BulkItemRequest item : bulkShardRequest.items()) { - assertNotNull(item.getPrimaryResponse()); - assertTrue(item.getPrimaryResponse().isFailed()); - BulkItemResponse.Failure failure = item.getPrimaryResponse().getFailure(); - assertThat(failure.getStatus(), equalTo(RestStatus.NOT_FOUND)); - } - } finally { - chainExecuted.countDown(); - } - }; - ActionListener actionListener = mock(ActionListener.class); - Task task = mock(Task.class); - - Map inferenceFieldMap = Map.of( - "field1", - new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }), - "field2", - new InferenceFieldMetadata("field2", "inference_0", new String[] { "field2" }), - "field3", - new InferenceFieldMetadata("field3", "inference_0", new String[] { "field3" }) - ); - BulkItemRequest[] items = new BulkItemRequest[10]; - for (int i = 0; i < items.length; i++) { - items[i] = randomBulkItemRequest(Map.of(), inferenceFieldMap)[0]; - } - BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); - request.setInferenceFieldMap(inferenceFieldMap); - filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); - awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); - } - - @SuppressWarnings({ "unchecked", "rawtypes" }) - public void testItemFailures() throws Exception { - StaticModel model = StaticModel.createRandomInstance(); - ShardBulkInferenceActionFilter filter = createFilter( - threadPool, - Map.of(model.getInferenceEntityId(), model), - randomIntBetween(1, 10) - ); - model.putResult("I am a failure", new ErrorChunkedInferenceResults(new IllegalArgumentException("boom"))); - model.putResult("I am a success", randomSparseEmbeddings(List.of("I am a success"))); - CountDownLatch chainExecuted = new CountDownLatch(1); - ActionFilterChain actionFilterChain = (task, action, request, listener) -> { - try { - BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getInferenceFieldMap()); - assertThat(bulkShardRequest.items().length, equalTo(3)); - - // item 0 is a failure - assertNotNull(bulkShardRequest.items()[0].getPrimaryResponse()); - assertTrue(bulkShardRequest.items()[0].getPrimaryResponse().isFailed()); - BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure(); - assertThat(failure.getCause().getCause().getMessage(), containsString("boom")); - - // item 1 is a success - assertNull(bulkShardRequest.items()[1].getPrimaryResponse()); - IndexRequest actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request()); - assertThat(XContentMapValues.extractValue("field1.text", actualRequest.sourceAsMap()), equalTo("I am a success")); - - // item 2 is a failure - assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse()); - assertTrue(bulkShardRequest.items()[2].getPrimaryResponse().isFailed()); - failure = bulkShardRequest.items()[2].getPrimaryResponse().getFailure(); - assertThat(failure.getCause().getCause().getMessage(), containsString("boom")); - } finally { - chainExecuted.countDown(); - } - }; - ActionListener actionListener = mock(ActionListener.class); - Task task = mock(Task.class); - - Map inferenceFieldMap = Map.of( - "field1", - new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }) - ); - BulkItemRequest[] items = new BulkItemRequest[3]; - items[0] = new BulkItemRequest(0, new IndexRequest("index").source("field1", "I am a failure")); - items[1] = new BulkItemRequest(1, new IndexRequest("index").source("field1", "I am a success")); - items[2] = new BulkItemRequest(2, new IndexRequest("index").source("field1", "I am a failure")); - BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); - request.setInferenceFieldMap(inferenceFieldMap); - filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); - awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); - } - - @SuppressWarnings({ "unchecked", "rawtypes" }) - public void testManyRandomDocs() throws Exception { - Map inferenceModelMap = new HashMap<>(); - int numModels = randomIntBetween(1, 5); - for (int i = 0; i < numModels; i++) { - StaticModel model = StaticModel.createRandomInstance(); - inferenceModelMap.put(model.getInferenceEntityId(), model); - } - - int numInferenceFields = randomIntBetween(1, 5); - Map inferenceFieldMap = new HashMap<>(); - for (int i = 0; i < numInferenceFields; i++) { - String field = randomAlphaOfLengthBetween(5, 10); - String inferenceId = randomFrom(inferenceModelMap.keySet()); - inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field })); - } - - int numRequests = randomIntBetween(100, 1000); - BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; - BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; - for (int id = 0; id < numRequests; id++) { - BulkItemRequest[] res = randomBulkItemRequest(inferenceModelMap, inferenceFieldMap); - originalRequests[id] = res[0]; - modifiedRequests[id] = res[1]; - } - - ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, randomIntBetween(10, 30)); - CountDownLatch chainExecuted = new CountDownLatch(1); - ActionFilterChain actionFilterChain = (task, action, request, listener) -> { - try { - assertThat(request, instanceOf(BulkShardRequest.class)); - BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getInferenceFieldMap()); - BulkItemRequest[] items = bulkShardRequest.items(); - assertThat(items.length, equalTo(originalRequests.length)); - for (int id = 0; id < items.length; id++) { - IndexRequest actualRequest = getIndexRequestOrNull(items[id].request()); - IndexRequest expectedRequest = getIndexRequestOrNull(modifiedRequests[id].request()); - try { - assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), expectedRequest.getContentType()); - } catch (Exception exc) { - throw new IllegalStateException(exc); - } - } - } finally { - chainExecuted.countDown(); - } - }; - ActionListener actionListener = mock(ActionListener.class); - Task task = mock(Task.class); - BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); - original.setInferenceFieldMap(inferenceFieldMap); - filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); - awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); - } - - @SuppressWarnings("unchecked") - private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap, int batchSize) { - ModelRegistry modelRegistry = mock(ModelRegistry.class); - Answer unparsedModelAnswer = invocationOnMock -> { - String id = (String) invocationOnMock.getArguments()[0]; - ActionListener listener = (ActionListener) invocationOnMock - .getArguments()[1]; - var model = modelMap.get(id); - if (model != null) { - listener.onResponse( - new ModelRegistry.UnparsedModel( - model.getInferenceEntityId(), - model.getTaskType(), - model.getServiceSettings().model(), - XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getTaskSettings()), false), - XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getSecretSettings()), false) - ) - ); - } else { - listener.onFailure(new ResourceNotFoundException("model id [{}] not found", id)); - } - return null; - }; - doAnswer(unparsedModelAnswer).when(modelRegistry).getModelWithSecrets(any(), any()); - - InferenceService inferenceService = mock(InferenceService.class); - Answer chunkedInferAnswer = invocationOnMock -> { - StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; - List inputs = (List) invocationOnMock.getArguments()[2]; - ActionListener> listener = (ActionListener< - List>) invocationOnMock.getArguments()[7]; - Runnable runnable = () -> { - List results = new ArrayList<>(); - for (String input : inputs) { - results.add(model.getResults(input)); - } - listener.onResponse(results); - }; - if (randomBoolean()) { - try { - threadPool.generic().execute(runnable); - } catch (Exception exc) { - listener.onFailure(exc); - } - } else { - runnable.run(); - } - return null; - }; - doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any(), any(), any()); - - Answer modelAnswer = invocationOnMock -> { - String inferenceId = (String) invocationOnMock.getArguments()[0]; - return modelMap.get(inferenceId); - }; - doAnswer(modelAnswer).when(inferenceService).parsePersistedConfigWithSecrets(any(), any(), any(), any()); - - InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); - when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); - ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry, batchSize); - return filter; - } - - private static BulkItemRequest[] randomBulkItemRequest( - Map modelMap, - Map fieldInferenceMap - ) { - Map docMap = new LinkedHashMap<>(); - Map expectedDocMap = new LinkedHashMap<>(); - XContentType requestContentType = randomFrom(XContentType.values()); - for (var entry : fieldInferenceMap.values()) { - String field = entry.getName(); - var model = modelMap.get(entry.getInferenceId()); - String text = randomAlphaOfLengthBetween(10, 100); - docMap.put(field, text); - expectedDocMap.put(field, text); - if (model == null) { - // ignore results, the doc should fail with a resource not found exception - continue; - } - var result = randomSemanticText(field, model, List.of(text), requestContentType); - model.putResult(text, toChunkedResult(result)); - expectedDocMap.put(field, result); - } - - int requestId = randomIntBetween(0, Integer.MAX_VALUE); - return new BulkItemRequest[] { - new BulkItemRequest(requestId, new IndexRequest("index").source(docMap, requestContentType)), - new BulkItemRequest(requestId, new IndexRequest("index").source(expectedDocMap, requestContentType)) }; - } - - private static class StaticModel extends TestModel { - private final Map resultMap; - - StaticModel( - String inferenceEntityId, - TaskType taskType, - String service, - TestServiceSettings serviceSettings, - TestTaskSettings taskSettings, - TestSecretSettings secretSettings - ) { - super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, secretSettings); - this.resultMap = new HashMap<>(); - } - - public static StaticModel createRandomInstance() { - TestModel testModel = TestModel.createRandomInstance(); - return new StaticModel( - testModel.getInferenceEntityId(), - testModel.getTaskType(), - randomAlphaOfLength(10), - testModel.getServiceSettings(), - testModel.getTaskSettings(), - testModel.getSecretSettings() - ); - } - - ChunkedInferenceServiceResults getResults(String text) { - return resultMap.getOrDefault(text, new ChunkedSparseEmbeddingResults(List.of())); - } - - void putResult(String text, ChunkedInferenceServiceResults result) { - resultMap.put(text, result); - } - } -}