Skip to content

Commit

Permalink
Fix request ids management. We need to specify the index of the request item in the bulk request, not use the item id, as it can be non-correlative with multiple shards.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Apr 17, 2024
1 parent 596bbdf commit 43a3b2b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ private record InferenceProvider(InferenceService service, Model model) {}

/**
* A field inference request on a single input.
* @param id The id of the request in the original bulk request.
* @param index The index of the request in the original bulk request.
* @param field The target field.
* @param input The input to run inference on.
* @param inputOrder The original order of the input.
* @param isOriginalFieldInput Whether the input is part of the original values of the field.
*/
private record FieldInferenceRequest(int id, String field, String input, int inputOrder, boolean isOriginalFieldInput) {}
private record FieldInferenceRequest(int index, String field, String input, int inputOrder, boolean isOriginalFieldInput) {}

/**
* The field inference response.
Expand Down Expand Up @@ -245,7 +245,7 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) {
try (onFinish) {
for (int i = 0; i < requests.size(); i++) {
var request = requests.get(i);
inferenceResults.get(request.id).failures.add(
inferenceResults.get(request.index).failures.add(
new ResourceNotFoundException(
"Inference service [{}] not found for field [{}]",
unparsedModel.service(),
Expand All @@ -262,7 +262,7 @@ public void onFailure(Exception exc) {
try (onFinish) {
for (int i = 0; i < requests.size(); i++) {
var request = requests.get(i);
inferenceResults.get(request.id).failures.add(
inferenceResults.get(request.index).failures.add(
new ResourceNotFoundException("Inference id [{}] not found for field [{}]", inferenceId, request.field)
);
}
Expand All @@ -283,7 +283,7 @@ public void onResponse(List<ChunkedInferenceServiceResults> results) {
for (int i = 0; i < results.size(); i++) {
var request = requests.get(i);
var result = results.get(i);
var acc = inferenceResults.get(request.id);
var acc = inferenceResults.get(request.index);
if (result instanceof ErrorChunkedInferenceResults error) {
acc.addFailure(
new ElasticsearchException(
Expand Down Expand Up @@ -317,7 +317,7 @@ public void onFailure(Exception exc) {
for (int i = 0; i < requests.size(); i++) {
var request = requests.get(i);
addInferenceResponseFailure(
request.id,
request.index,
new ElasticsearchException(
"Exception when running inference id [{}] on field [{}]",
exc,
Expand Down Expand Up @@ -416,6 +416,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
*/
private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) {
Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new LinkedHashMap<>();
int itemIndex = 0;
for (var item : bulkShardRequest.items()) {
if (item.getPrimaryResponse() != null) {
// item was already aborted/processed by a filter in the chain upstream (e.g. security)
Expand Down Expand Up @@ -470,7 +471,7 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
}
continue;
}
ensureResponseAccumulatorSlot(item.id());
ensureResponseAccumulatorSlot(itemIndex);
final List<String> values;
try {
values = nodeStringValues(field, valueObj);
Expand All @@ -480,10 +481,11 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
}
List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
for (var v : values) {
fieldRequests.add(new FieldInferenceRequest(item.id(), field, v, order++, isOriginalFieldInput));
fieldRequests.add(new FieldInferenceRequest(itemIndex, field, v, order++, isOriginalFieldInput));
}
}
}
itemIndex++;
}
return fieldRequestsMap;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public void testInferenceNotFound() throws Exception {
);
BulkItemRequest[] items = new BulkItemRequest[10];
for (int i = 0; i < items.length; i++) {
items[i] = randomBulkItemRequest(i, Map.of(), inferenceFieldMap)[0];
items[i] = randomBulkItemRequest(Map.of(), inferenceFieldMap)[0];
}
BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
request.setInferenceFieldMap(inferenceFieldMap);
Expand Down Expand Up @@ -222,7 +222,7 @@ public void testManyRandomDocs() throws Exception {
BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests];
BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests];
for (int id = 0; id < numRequests; id++) {
BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, inferenceFieldMap);
BulkItemRequest[] res = randomBulkItemRequest(inferenceModelMap, inferenceFieldMap);
originalRequests[id] = res[0];
modifiedRequests[id] = res[1];
}
Expand Down Expand Up @@ -321,7 +321,6 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool
}

private static BulkItemRequest[] randomBulkItemRequest(
int id,
Map<String, StaticModel> modelMap,
Map<String, InferenceFieldMetadata> fieldInferenceMap
) {
Expand All @@ -342,9 +341,11 @@ private static BulkItemRequest[] randomBulkItemRequest(
model.putResult(text, toChunkedResult(result));
expectedDocMap.put(field, result);
}

int requestId = randomIntBetween(0, Integer.MAX_VALUE);
return new BulkItemRequest[] {
new BulkItemRequest(id, new IndexRequest("index").source(docMap, requestContentType)),
new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap, requestContentType)) };
new BulkItemRequest(requestId, new IndexRequest("index").source(docMap, requestContentType)),
new BulkItemRequest(requestId, new IndexRequest("index").source(expectedDocMap, requestContentType)) };
}

private static StaticModel randomStaticModel() {
Expand Down

0 comments on commit 43a3b2b

Please sign in to comment.