Avoid removing bulk items from request on failure, fix tests
carlosdelest committed Feb 6, 2024
1 parent 7738460 commit c4154b9
Showing 3 changed files with 74 additions and 46 deletions.
BulkOperation.java

@@ -29,7 +29,6 @@
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.routing.IndexRouting;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.TriConsumer;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
@@ -48,6 +47,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
 import java.util.function.LongSupplier;
 
 import static org.elasticsearch.cluster.metadata.IndexNameExpressionResolver.EXCLUDED_DATA_STREAMS_KEY;
@@ -237,15 +237,12 @@ void processRequestsByShards(
             BulkShardRequest bulkShardRequest = createBulkShardRequest(clusterState, shardId, requests);
 
             Releasable ref = bulkItemRequestCompleteRefCount.acquire();
-            final TriConsumer<BulkItemRequest, Integer, Exception> bulkItemFailedListener = (
-                itemReq,
-                itemIndex,
-                e) -> markBulkItemRequestFailed(bulkShardRequest, itemReq, itemIndex, e);
+            final BiConsumer<BulkItemRequest, Exception> bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed(itemReq, e);
             bulkShardRequestInferenceProvider.processBulkShardRequest(bulkShardRequest, new ActionListener<>() {
                 @Override
-                public void onResponse(BulkShardRequest bulkShardRequest) {
+                public void onResponse(BulkShardRequest inferenceBulkShardRequest) {
                     executeBulkShardRequest(
-                        bulkShardRequest,
+                        inferenceBulkShardRequest,
                         ActionListener.releaseAfter(ActionListener.noop(), ref),
                         bulkItemFailedListener
                     );
@@ -276,21 +273,18 @@ private BulkShardRequest createBulkShardRequest(ClusterState clusterState, Shard
     }
 
     // When an item fails, store the failure in the responses array
-    private void markBulkItemRequestFailed(BulkShardRequest shardRequest, BulkItemRequest itemRequest, int bulkItemIndex, Exception e) {
+    private void markBulkItemRequestFailed(BulkItemRequest itemRequest, Exception e) {
         final String indexName = itemRequest.index();
 
         DocWriteRequest<?> docWriteRequest = itemRequest.request();
         BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e);
         responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure));
-
-        // make sure the request gets never processed again, removing the item from the shard request
-        shardRequest.items()[bulkItemIndex] = null;
     }
 
     private void executeBulkShardRequest(
         BulkShardRequest bulkShardRequest,
         ActionListener<BulkShardRequest> listener,
-        TriConsumer<BulkItemRequest, Integer, Exception> bulkItemErrorListener
+        BiConsumer<BulkItemRequest, Exception> bulkItemErrorListener
     ) {
         if (bulkShardRequest.items().length == 0) {
             // No requests to execute due to previous errors, terminate early
@@ -315,9 +309,8 @@ public void onResponse(BulkShardResponse bulkShardResponse) {
             public void onFailure(Exception e) {
                 // create failures for all relevant requests
                 BulkItemRequest[] items = bulkShardRequest.items();
-                for (int i = 0; i < items.length; i++) {
-                    BulkItemRequest request = items[i];
-                    bulkItemErrorListener.apply(request, i, e);
+                for (BulkItemRequest item : items) {
+                    bulkItemErrorListener.accept(item, e);
                 }
                 listener.onFailure(e);
             }
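
A note on the listener change in BulkOperation above: the failure callback narrows from TriConsumer<BulkItemRequest, Integer, Exception> to BiConsumer<BulkItemRequest, Exception> because the positional index is no longer needed — the failure is recorded in the responses array under the item's own id, and the item's slot in the shard request is left alone. A minimal, self-contained sketch of that contract (plain Java with hypothetical Item/Failure records, not Elasticsearch types):

import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.function.BiConsumer;

class FailureRecordingSketch {
    record Item(int id, String index) {}
    record Failure(int itemId, String index, Exception cause) {}

    public static void main(String[] args) {
        // Responses array sized to the whole bulk request; a failed slot is filled here
        // instead of nulling the request item, so the request itself stays intact.
        AtomicReferenceArray<Failure> responses = new AtomicReferenceArray<>(3);

        // Two arguments suffice: the item carries its own id, so no positional index is needed.
        BiConsumer<Item, Exception> onItemFailure =
            (item, e) -> responses.set(item.id(), new Failure(item.id(), item.index(), e));

        onItemFailure.accept(new Item(1, "my-index"), new IllegalStateException("inference failed"));
        System.out.println(responses.get(1)); // Failure[itemId=1, index=my-index, cause=...]
    }
}
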
BulkShardRequestInferenceProvider.java

@@ -14,7 +14,6 @@
 import org.elasticsearch.action.support.RefCountingRunnable;
 import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.cluster.ClusterState;
-import org.elasticsearch.common.TriConsumer;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.inference.InferenceResults;
@@ -26,13 +25,15 @@
 import org.elasticsearch.inference.ModelRegistry;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.BiConsumer;
 import java.util.stream.Collectors;
 
 /**
@@ -106,10 +107,20 @@ public void onFailure(Exception e) {
         }
     }
 
+    /**
+     * Performs inference on the fields that have inference models for a bulk shard request. Bulk items from
+     * the original request will be modified with the inference results, to avoid copying the entire requests from
+     * the original bulk request.
+     *
+     * @param bulkShardRequest original BulkShardRequest that will be modified with inference results.
+     * @param listener listener to be called when the inference process is finished with the new BulkShardRequest,
+     *                 which may have fewer items than the original because of inference failures
+     * @param onBulkItemFailure invoked when a bulk item fails inference
+     */
     public void processBulkShardRequest(
         BulkShardRequest bulkShardRequest,
         ActionListener<BulkShardRequest> listener,
-        TriConsumer<BulkItemRequest, Integer, Exception> onBulkItemFailure
+        BiConsumer<BulkItemRequest, Exception> onBulkItemFailure
     ) {
 
         Map<String, Set<String>> fieldsForModels = clusterState.metadata()
@@ -121,24 +132,52 @@ public void processBulkShardRequest(
             return;
         }
 
-        Runnable onInferenceComplete = () -> { listener.onResponse(bulkShardRequest); };
+        Set<Integer> failedItems = Collections.synchronizedSet(new HashSet<>());
+        Runnable onInferenceComplete = () -> {
+            if (failedItems.isEmpty()) {
+                listener.onResponse(bulkShardRequest);
+                return;
+            }
+            // Remove failed items from the original bulk shard request
+            BulkItemRequest[] originalItems = bulkShardRequest.items();
+            BulkItemRequest[] newItems = new BulkItemRequest[originalItems.length - failedItems.size()];
+            for (int i = 0, j = 0; i < originalItems.length; i++) {
+                if (failedItems.contains(i) == false) {
+                    newItems[j++] = originalItems[i];
+                }
+            }
+            BulkShardRequest newBulkShardRequest = new BulkShardRequest(
+                bulkShardRequest.shardId(),
+                bulkShardRequest.getRefreshPolicy(),
+                newItems
+            );
+            listener.onResponse(newBulkShardRequest);
+        };
         try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) {
             BulkItemRequest[] items = bulkShardRequest.items();
             for (int i = 0; i < items.length; i++) {
                 BulkItemRequest bulkItemRequest = items[i];
                 // Bulk item might be null because of previous errors, skip in that case
                 if (bulkItemRequest != null) {
-                    performInferenceOnBulkItemRequest(bulkItemRequest, i, fieldsForModels, onBulkItemFailure, bulkItemReqRef.acquire());
+                    performInferenceOnBulkItemRequest(
+                        bulkItemRequest,
+                        fieldsForModels,
+                        i,
+                        onBulkItemFailure,
+                        failedItems,
+                        bulkItemReqRef.acquire()
+                    );
                 }
             }
         }
     }
 
     private void performInferenceOnBulkItemRequest(
         BulkItemRequest bulkItemRequest,
-        int bulkItemIndex,
         Map<String, Set<String>> fieldsForModels,
-        TriConsumer<BulkItemRequest, Integer, Exception> onBulkItemFailure,
+        Integer itemIndex,
+        BiConsumer<BulkItemRequest, Exception> onBulkItemFailure,
+        Set<Integer> failedItems,
         Releasable releaseOnFinish
     ) {
 
@@ -186,9 +225,9 @@ private void performInferenceOnBulkItemRequest(
 
             InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId);
             if (inferenceProvider == null) {
-                onBulkItemFailure.apply(
+                failedItems.add(itemIndex);
+                onBulkItemFailure.accept(
                     bulkItemRequest,
-                    bulkItemIndex,
                     new IllegalArgumentException("No inference provider found for model ID " + modelId)
                 );
                 continue;
@@ -223,7 +262,8 @@ public void onResponse(InferenceServiceResults results) {
 
                 @Override
                 public void onFailure(Exception e) {
-                    onBulkItemFailure.apply(bulkItemRequest, bulkItemIndex, e);
+                    failedItems.add(itemIndex);
+                    onBulkItemFailure.accept(bulkItemRequest, e);
                 }
             };
             inferenceProvider.service()
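
The heart of the fix is the compaction in onInferenceComplete above: failed item indices accumulate in a synchronized set while per-item inference callbacks run, and once all items finish, a new, shorter shard request is built rather than nulling slots in the original. A standalone sketch of that index-based compaction (plain Java arrays standing in for BulkItemRequest[]):

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

class CompactionSketch {
    // Copy every element whose index is not in failedIndices, preserving the original order.
    static String[] removeFailed(String[] original, Set<Integer> failedIndices) {
        String[] remaining = new String[original.length - failedIndices.size()];
        for (int i = 0, j = 0; i < original.length; i++) {
            if (failedIndices.contains(i) == false) {
                remaining[j++] = original[i];
            }
        }
        return remaining;
    }

    public static void main(String[] args) {
        // Synchronized set, as in the diff: failure callbacks may record indices from multiple threads.
        Set<Integer> failed = Collections.synchronizedSet(new HashSet<>());
        failed.add(1);
        System.out.println(String.join(",", removeFailed(new String[] { "a", "b", "c" }, failed))); // a,c
    }
}

The tests below are updated to match: when every item fails inference, the compacted request is empty and the transport shard bulk action is never executed.
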
BulkOperationTests.java

@@ -111,6 +111,7 @@ public void testNoInference() {
             fieldsForModels,
             modelRegistry,
             inferenceServiceRegistry,
+            true,
             bulkOperationListener
         );
         verify(bulkOperationListener).onResponse(any());
@@ -151,6 +152,7 @@ public void testFailedBulkShardRequest() {
             modelRegistry,
             inferenceServiceRegistry,
             bulkOperationListener,
+            true,
             request -> new BulkShardResponse(
                 request.shardId(),
                 new BulkItemResponse[] {
@@ -217,6 +219,7 @@ public void testInference() {
             fieldsForModels,
             modelRegistry,
             inferenceServiceRegistry,
+            true,
             bulkOperationListener
         );
         verify(bulkOperationListener).onResponse(any());
@@ -267,16 +270,8 @@ public void testFailedInference() {
         ArgumentCaptor<BulkResponse> bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class);
         @SuppressWarnings("unchecked")
         ActionListener<BulkResponse> bulkOperationListener = mock(ActionListener.class);
-        BulkShardRequest bulkShardRequest = runBulkOperation(
-            originalSource,
-            fieldsForModels,
-            modelRegistry,
-            inferenceServiceRegistry,
-            bulkOperationListener
-        );
-        BulkItemRequest[] items = bulkShardRequest.items();
-        assertThat(items.length, equalTo(1));
-        assertNull(items[0]);
+        runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener);
+
         verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture());
         BulkResponse bulkResponse = bulkResponseCaptor.getValue();
         assertTrue(bulkResponse.hasFailures());
@@ -313,16 +308,10 @@ public void testInferenceIdNotFound() {
         ArgumentCaptor<BulkResponse> bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class);
         @SuppressWarnings("unchecked")
         ActionListener<BulkResponse> bulkOperationListener = mock(ActionListener.class);
-        BulkShardRequest bulkShardRequest = runBulkOperation(
-            originalSource,
-            fieldsForModels,
-            modelRegistry,
-            inferenceServiceRegistry,
-            bulkOperationListener
-        );
-        BulkItemRequest[] items = bulkShardRequest.items();
-        assertThat(items.length, equalTo(1));
-        assertNull(items[0]);
+        doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture());
+
+        runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener);
+
         verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture());
         BulkResponse bulkResponse = bulkResponseCaptor.getValue();
         assertTrue(bulkResponse.hasFailures());
@@ -392,6 +381,7 @@ private static BulkShardRequest runBulkOperation(
         Map<String, Set<String>> fieldsForModels,
         ModelRegistry modelRegistry,
         InferenceServiceRegistry inferenceServiceRegistry,
+        boolean expectTransportShardBulkActionToExecute,
         ActionListener<BulkResponse> bulkOperationListener
     ) {
         return runBulkOperation(
@@ -400,6 +390,7 @@ private static BulkShardRequest runBulkOperation(
             modelRegistry,
             inferenceServiceRegistry,
             bulkOperationListener,
+            expectTransportShardBulkActionToExecute,
             successfulBulkShardResponse
         );
     }
@@ -410,6 +401,7 @@ private static BulkShardRequest runBulkOperation(
         ModelRegistry modelRegistry,
         InferenceServiceRegistry inferenceServiceRegistry,
         ActionListener<BulkResponse> bulkOperationListener,
+        boolean expectTransportShardBulkActionToExecute,
         Function<BulkShardRequest, BulkShardResponse> bulkShardResponseSupplier
     ) {
         Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build();
@@ -456,9 +448,12 @@ private static BulkShardRequest runBulkOperation(
         );
 
         bulkOperation.doRun();
-        verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any());
+        if (expectTransportShardBulkActionToExecute) {
+            verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any());
+            return bulkShardRequestCaptor.getValue();
+        }
 
-        return bulkShardRequestCaptor.getValue();
+        return null;
     }
 
     private static final Function<BulkShardRequest, BulkShardResponse> successfulBulkShardResponse = (request) -> {
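
On the new expectTransportShardBulkActionToExecute flag above: when inference fails for every item, the compacted shard request is empty, TransportShardBulkAction never runs, and the helper returns null instead of a captured request. The commit only asserts the positive case; a small self-contained sketch of how the negative case could also be pinned down with Mockito's never() (an assumption, not part of this commit — the Client interface here is a stand-in for the mocked node client):

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;

class NeverVerifySketch {
    // Stand-in for the mocked client used by the tests.
    interface Client { void executeLocally(String actionType, Object request, Object listener); }

    public static void main(String[] args) {
        Client client = mock(Client.class);
        boolean expectExecute = false; // mirrors expectTransportShardBulkActionToExecute

        // ... the code under test would run here; with all items failing, nothing reaches the client ...

        if (expectExecute) {
            verify(client).executeLocally(any(), any(), any());
        } else {
            verify(client, never()).executeLocally(any(), any(), any());
        }
    }
}
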
