diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java
index b0e5129c8439b..2b84ec8746cd2 100644
--- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java
+++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java
@@ -29,7 +29,6 @@
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.routing.IndexRouting;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.TriConsumer;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
@@ -48,6 +47,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
 import java.util.function.LongSupplier;
 
 import static org.elasticsearch.cluster.metadata.IndexNameExpressionResolver.EXCLUDED_DATA_STREAMS_KEY;
@@ -237,15 +237,12 @@ void processRequestsByShards(
             BulkShardRequest bulkShardRequest = createBulkShardRequest(clusterState, shardId, requests);
 
             Releasable ref = bulkItemRequestCompleteRefCount.acquire();
-            final TriConsumer<BulkItemRequest, Integer, Exception> bulkItemFailedListener = (
-                itemReq,
-                itemIndex,
-                e) -> markBulkItemRequestFailed(bulkShardRequest, itemReq, itemIndex, e);
+            final BiConsumer<BulkItemRequest, Exception> bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed(itemReq, e);
             bulkShardRequestInferenceProvider.processBulkShardRequest(bulkShardRequest, new ActionListener<>() {
                 @Override
-                public void onResponse(BulkShardRequest bulkShardRequest) {
+                public void onResponse(BulkShardRequest inferenceBulkShardRequest) {
                     executeBulkShardRequest(
-                        bulkShardRequest,
+                        inferenceBulkShardRequest,
                         ActionListener.releaseAfter(ActionListener.noop(), ref),
                         bulkItemFailedListener
                     );
@@ -276,21 +273,18 @@ private BulkShardRequest createBulkShardRequest(ClusterState clusterState, Shard
     }
 
     // When an item fails, store the failure in the responses array
-    private void markBulkItemRequestFailed(BulkShardRequest shardRequest, BulkItemRequest itemRequest, int bulkItemIndex, Exception e) {
+    private void markBulkItemRequestFailed(BulkItemRequest itemRequest, Exception e) {
         final String indexName = itemRequest.index();
         DocWriteRequest<?> docWriteRequest = itemRequest.request();
 
         BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e);
         responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure));
-
-        // make sure the request gets never processed again, removing the item from the shard request
-        shardRequest.items()[bulkItemIndex] = null;
     }
 
     private void executeBulkShardRequest(
         BulkShardRequest bulkShardRequest,
         ActionListener<BulkShardResponse> listener,
-        TriConsumer<BulkItemRequest, Integer, Exception> bulkItemErrorListener
+        BiConsumer<BulkItemRequest, Exception> bulkItemErrorListener
     ) {
         if (bulkShardRequest.items().length == 0) {
             // No requests to execute due to previous errors, terminate early
@@ -315,9 +309,8 @@ public void onResponse(BulkShardResponse bulkShardResponse) {
             public void onFailure(Exception e) {
                 // create failures for all relevant requests
                 BulkItemRequest[] items = bulkShardRequest.items();
-                for (int i = 0; i < items.length; i++) {
-                    BulkItemRequest request = items[i];
-                    bulkItemErrorListener.apply(request, i, e);
+                for (BulkItemRequest item : items) {
+                    bulkItemErrorListener.accept(item, e);
                 }
                 listener.onFailure(e);
             }
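
Note on the BulkOperation changes above: since failed items are no longer nulled out of the shard request by array index, the failure callback shrinks from a `TriConsumer<BulkItemRequest, Integer, Exception>` to a `BiConsumer<BulkItemRequest, Exception>`. A minimal standalone sketch of that callback shape (plain Java with a hypothetical `Item` record and `FailureListenerSketch` class, not Elasticsearch code):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;

public class FailureListenerSketch {

    // Minimal stand-in for a bulk item; only a per-item id is needed to record a failure.
    record Item(int id, String source) {}

    public static void main(String[] args) {
        List<String> failures = new ArrayList<>();

        // The callback records the failure against the item itself; it no longer needs the array
        // index, because dropping failed items is handled where the index is already known.
        BiConsumer<Item, Exception> onItemFailure =
            (item, e) -> failures.add("item " + item.id() + " failed: " + e.getMessage());

        onItemFailure.accept(new Item(0, "{\"field\":\"value\"}"), new IllegalArgumentException("no inference provider"));
        System.out.println(failures);
    }
}
```
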
diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
index 8a2847ddcb842..7acc93be13a46 100644
--- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
+++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
@@ -14,7 +14,6 @@
 import org.elasticsearch.action.support.RefCountingRunnable;
 import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.cluster.ClusterState;
-import org.elasticsearch.common.TriConsumer;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.inference.InferenceResults;
@@ -26,6 +25,7 @@
 import org.elasticsearch.inference.ModelRegistry;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -33,6 +33,7 @@
 import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.BiConsumer;
 import java.util.stream.Collectors;
 
 /**
@@ -106,10 +107,20 @@ public void onFailure(Exception e) {
         }
     }
 
+    /**
+     * Performs inference on the fields that have inference models for a bulk shard request. Bulk items from
+     * the original request will be modified with the inference results, to avoid copying the entire requests from
+     * the original bulk request.
+     *
+     * @param bulkShardRequest original BulkShardRequest that will be modified with inference results.
+     * @param listener listener to be called when the inference process is finished with the new BulkShardRequest,
+     *                 which may have fewer items than the original because of inference failures
+     * @param onBulkItemFailure invoked when a bulk item fails inference
+     */
     public void processBulkShardRequest(
         BulkShardRequest bulkShardRequest,
         ActionListener<BulkShardRequest> listener,
-        TriConsumer<BulkItemRequest, Integer, Exception> onBulkItemFailure
+        BiConsumer<BulkItemRequest, Exception> onBulkItemFailure
     ) {
 
         Map<String, Set<String>> fieldsForModels = clusterState.metadata()
@@ -121,14 +132,41 @@ public void processBulkShardRequest(
             return;
         }
 
-        Runnable onInferenceComplete = () -> { listener.onResponse(bulkShardRequest); };
+        Set<Integer> failedItems = Collections.synchronizedSet(new HashSet<>());
+        Runnable onInferenceComplete = () -> {
+            if (failedItems.isEmpty()) {
+                listener.onResponse(bulkShardRequest);
+                return;
+            }
+            // Remove failed items from the original bulk shard request
+            BulkItemRequest[] originalItems = bulkShardRequest.items();
+            BulkItemRequest[] newItems = new BulkItemRequest[originalItems.length - failedItems.size()];
+            for (int i = 0, j = 0; i < originalItems.length; i++) {
+                if (failedItems.contains(i) == false) {
+                    newItems[j++] = originalItems[i];
+                }
+            }
+            BulkShardRequest newBulkShardRequest = new BulkShardRequest(
+                bulkShardRequest.shardId(),
+                bulkShardRequest.getRefreshPolicy(),
+                newItems
+            );
+            listener.onResponse(newBulkShardRequest);
+        };
         try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) {
             BulkItemRequest[] items = bulkShardRequest.items();
             for (int i = 0; i < items.length; i++) {
                 BulkItemRequest bulkItemRequest = items[i];
                 // Bulk item might be null because of previous errors, skip in that case
                 if (bulkItemRequest != null) {
-                    performInferenceOnBulkItemRequest(bulkItemRequest, i, fieldsForModels, onBulkItemFailure, bulkItemReqRef.acquire());
+                    performInferenceOnBulkItemRequest(
+                        bulkItemRequest,
+                        fieldsForModels,
+                        i,
+                        onBulkItemFailure,
+                        failedItems,
+                        bulkItemReqRef.acquire()
+                    );
                 }
             }
         }
@@ -136,9 +174,10 @@
 
     private void performInferenceOnBulkItemRequest(
         BulkItemRequest bulkItemRequest,
-        int bulkItemIndex,
         Map<String, Set<String>> fieldsForModels,
-        TriConsumer<BulkItemRequest, Integer, Exception> onBulkItemFailure,
+        Integer itemIndex,
+        BiConsumer<BulkItemRequest, Exception> onBulkItemFailure,
+        Set<Integer> failedItems,
         Releasable releaseOnFinish
     ) {
 
@@ -186,9 +225,9 @@ private void performInferenceOnBulkItemRequest(
 
             InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId);
             if (inferenceProvider == null) {
-                onBulkItemFailure.apply(
+                failedItems.add(itemIndex);
+                onBulkItemFailure.accept(
                     bulkItemRequest,
-                    bulkItemIndex,
                     new IllegalArgumentException("No inference provider found for model ID " + modelId)
                 );
                 continue;
@@ -223,7 +262,8 @@ public void onResponse(InferenceServiceResults results) {
 
                 @Override
                 public void onFailure(Exception e) {
-                    onBulkItemFailure.apply(bulkItemRequest, bulkItemIndex, e);
+                    failedItems.add(itemIndex);
+                    onBulkItemFailure.accept(bulkItemRequest, e);
                 }
             };
             inferenceProvider.service()
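
The core of the BulkShardRequestInferenceProvider change is the onInferenceComplete step: indices of items that failed inference are collected in a synchronized set while callbacks run concurrently, and a smaller BulkShardRequest is built that skips those indices. A standalone sketch of just that filtering step, using plain arrays and hypothetical names (`DropFailedItemsSketch`, `withoutFailedItems`), not the PR's code:

```java
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

public class DropFailedItemsSketch {

    // Build a smaller array that skips the indices recorded in failedItems.
    static String[] withoutFailedItems(String[] originalItems, Set<Integer> failedItems) {
        String[] newItems = new String[originalItems.length - failedItems.size()];
        for (int i = 0, j = 0; i < originalItems.length; i++) {
            if (failedItems.contains(i) == false) {
                newItems[j++] = originalItems[i];
            }
        }
        return newItems;
    }

    public static void main(String[] args) {
        // Synchronized set, because inference callbacks may record failures from different threads.
        Set<Integer> failedItems = Collections.synchronizedSet(new HashSet<>());
        failedItems.add(1); // e.g. item 1 failed inference

        String[] filtered = withoutFailedItems(new String[] { "doc-0", "doc-1", "doc-2" }, failedItems);
        System.out.println(String.join(", ", filtered)); // prints: doc-0, doc-2
    }
}
```

Handing the listener a new, smaller request (instead of nulling array slots in place) is what lets the failure callback in BulkOperation drop the index parameter.
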
diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java
index c1500abc28abe..a688df5d797a2 100644
--- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java
+++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java
@@ -111,6 +111,7 @@ public void testNoInference() {
             fieldsForModels,
             modelRegistry,
             inferenceServiceRegistry,
+            true,
             bulkOperationListener
         );
         verify(bulkOperationListener).onResponse(any());
@@ -151,6 +152,7 @@ public void testFailedBulkShardRequest() {
             modelRegistry,
             inferenceServiceRegistry,
             bulkOperationListener,
+            true,
             request -> new BulkShardResponse(
                 request.shardId(),
                 new BulkItemResponse[] {
@@ -217,6 +219,7 @@ public void testInference() {
             fieldsForModels,
             modelRegistry,
             inferenceServiceRegistry,
+            true,
             bulkOperationListener
         );
         verify(bulkOperationListener).onResponse(any());
@@ -267,16 +270,8 @@ public void testFailedInference() {
         ArgumentCaptor<BulkResponse> bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class);
         @SuppressWarnings("unchecked")
         ActionListener<BulkResponse> bulkOperationListener = mock(ActionListener.class);
-        BulkShardRequest bulkShardRequest = runBulkOperation(
-            originalSource,
-            fieldsForModels,
-            modelRegistry,
-            inferenceServiceRegistry,
-            bulkOperationListener
-        );
-        BulkItemRequest[] items = bulkShardRequest.items();
-        assertThat(items.length, equalTo(1));
-        assertNull(items[0]);
+        runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener);
+        verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture());
 
         BulkResponse bulkResponse = bulkResponseCaptor.getValue();
         assertTrue(bulkResponse.hasFailures());
@@ -313,16 +308,10 @@ public void testInferenceIdNotFound() {
         ArgumentCaptor<BulkResponse> bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class);
         @SuppressWarnings("unchecked")
         ActionListener<BulkResponse> bulkOperationListener = mock(ActionListener.class);
-        BulkShardRequest bulkShardRequest = runBulkOperation(
-            originalSource,
-            fieldsForModels,
-            modelRegistry,
-            inferenceServiceRegistry,
-            bulkOperationListener
-        );
-        BulkItemRequest[] items = bulkShardRequest.items();
-        assertThat(items.length, equalTo(1));
-        assertNull(items[0]);
+        doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture());
+
+        runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener);
+        verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture());
 
         BulkResponse bulkResponse = bulkResponseCaptor.getValue();
         assertTrue(bulkResponse.hasFailures());
@@ -392,6 +381,7 @@ private static BulkShardRequest runBulkOperation(
         Map<String, Set<String>> fieldsForModels,
         ModelRegistry modelRegistry,
         InferenceServiceRegistry inferenceServiceRegistry,
+        boolean expectTransportShardBulkActionToExecute,
         ActionListener<BulkResponse> bulkOperationListener
     ) {
         return runBulkOperation(
@@ -400,6 +390,7 @@ private static BulkShardRequest runBulkOperation(
             modelRegistry,
             inferenceServiceRegistry,
             bulkOperationListener,
+            expectTransportShardBulkActionToExecute,
             successfulBulkShardResponse
         );
     }
@@ -410,6 +401,7 @@ private static BulkShardRequest runBulkOperation(
         ModelRegistry modelRegistry,
         InferenceServiceRegistry inferenceServiceRegistry,
         ActionListener<BulkResponse> bulkOperationListener,
+        boolean expectTransportShardBulkActionToExecute,
         Function<BulkShardRequest, BulkShardResponse> bulkShardResponseSupplier
     ) {
         Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build();
@@ -456,9 +448,12 @@ private static BulkShardRequest runBulkOperation(
         );
         bulkOperation.doRun();
 
-        verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any());
+        if (expectTransportShardBulkActionToExecute) {
+            verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any());
+            return bulkShardRequestCaptor.getValue();
+        }
 
-        return bulkShardRequestCaptor.getValue();
+        return null;
     }
 
     private static final Function<BulkShardRequest, BulkShardResponse> successfulBulkShardResponse = (request) -> {
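
The test helper now takes an expectTransportShardBulkActionToExecute flag: when inference failures are expected to drop every item, TransportShardBulkAction is never invoked, so the helper skips the verify and returns null instead of a captured request. A minimal Mockito sketch of that conditional-verify pattern (hypothetical `Client` and `ConditionalVerifySketch` names, assuming mockito-core on the classpath; not the PR's helper):

```java
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

import java.util.function.Consumer;

import org.mockito.ArgumentCaptor;

public class ConditionalVerifySketch {

    // Stand-in for the mocked client used by the real test.
    interface Client {
        void execute(String request);
    }

    static String runOperation(boolean expectExecution, Consumer<Client> operation) {
        Client client = mock(Client.class);
        ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);

        operation.accept(client);

        if (expectExecution) {
            // Only assert the downstream call (and capture its argument) when the scenario is
            // expected to reach the client; failed-inference scenarios never get this far.
            verify(client).execute(captor.capture());
            return captor.getValue();
        }
        return null;
    }

    public static void main(String[] args) {
        System.out.println(runOperation(true, client -> client.execute("bulk-request"))); // prints: bulk-request
        System.out.println(runOperation(false, client -> {}));                            // prints: null
    }
}
```
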