
Commit

Merge branch 'refs/heads/feature/semantic-text' into carlosdelest/semantic-text-add-query-yaml-tests
carlosdelest committed Apr 8, 2024
2 parents 5711bba + f565596 commit 3bdefa9
Showing 2 changed files with 84 additions and 19 deletions.
@@ -37,6 +37,7 @@
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -282,16 +283,27 @@ public void onResponse(List<ChunkedInferenceServiceResults> results) {
                     var request = requests.get(i);
                     var result = results.get(i);
                     var acc = inferenceResults.get(request.id);
-                    acc.addOrUpdateResponse(
-                        new FieldInferenceResponse(
-                            request.field(),
-                            request.input(),
-                            request.inputOrder(),
-                            request.isOriginalFieldInput(),
-                            inferenceProvider.model,
-                            result
-                        )
-                    );
+                    if (result instanceof ErrorChunkedInferenceResults error) {
+                        acc.addFailure(
+                            new ElasticsearchException(
+                                "Exception when running inference id [{}] on field [{}]",
+                                error.getException(),
+                                inferenceProvider.model.getInferenceEntityId(),
+                                request.field
+                            )
+                        );
+                    } else {
+                        acc.addOrUpdateResponse(
+                            new FieldInferenceResponse(
+                                request.field(),
+                                request.input(),
+                                request.inputOrder(),
+                                request.isOriginalFieldInput(),
+                                inferenceProvider.model,
+                                result
+                            )
+                        );
+                    }
                 }
             } finally {
                 onFinish();
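
The branch added above distinguishes error results from successful ones: when a chunked inference call comes back as an ErrorChunkedInferenceResults, the accumulator records a per-item failure (wrapping the original exception in an ElasticsearchException that names the inference id and field) instead of a field response, so a single failing document no longer fails the whole batch. A minimal standalone sketch of the pattern, using only the constructor and accessor visible in this diff (the surrounding variable names are illustrative, not the production code path):

    // Hypothetical sketch, not the production code path.
    // ErrorChunkedInferenceResults carries the original failure; getException()
    // recovers it so the caller can wrap and report it per bulk item.
    ChunkedInferenceServiceResults result =
        new ErrorChunkedInferenceResults(new IllegalArgumentException("boom"));
    if (result instanceof ErrorChunkedInferenceResults error) {
        Exception cause = error.getException(); // the IllegalArgumentException above
        // the filter wraps this cause in an ElasticsearchException and passes it
        // to acc.addFailure(...), marking only this item as failed
    }
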
@@ -19,6 +19,7 @@
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.inference.ChunkedInferenceServiceResults;
 import org.elasticsearch.inference.InferenceService;
@@ -33,7 +34,7 @@
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xcontent.json.JsonXContent;
 import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
-import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
+import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
 import org.elasticsearch.xpack.inference.model.TestModel;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.junit.After;
@@ -54,7 +55,9 @@
 import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSparseEmbeddings;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.mockito.Mockito.any;
@@ -144,6 +147,60 @@ public void testInferenceNotFound() throws Exception {
         awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
     }
 
+    @SuppressWarnings({ "unchecked", "rawtypes" })
+    public void testItemFailures() throws Exception {
+        StaticModel model = randomStaticModel();
+        ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            Map.of(model.getInferenceEntityId(), model),
+            randomIntBetween(1, 10)
+        );
+        model.putResult("I am a failure", new ErrorChunkedInferenceResults(new IllegalArgumentException("boom")));
+        model.putResult("I am a success", randomSparseEmbeddings(List.of("I am a success")));
+        CountDownLatch chainExecuted = new CountDownLatch(1);
+        ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
+            try {
+                BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
+                assertNull(bulkShardRequest.getInferenceFieldMap());
+                assertThat(bulkShardRequest.items().length, equalTo(3));
+
+                // item 0 is a failure
+                assertNotNull(bulkShardRequest.items()[0].getPrimaryResponse());
+                assertTrue(bulkShardRequest.items()[0].getPrimaryResponse().isFailed());
+                BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure();
+                assertThat(failure.getCause().getCause().getMessage(), containsString("boom"));
+
+                // item 1 is a success
+                assertNull(bulkShardRequest.items()[1].getPrimaryResponse());
+                IndexRequest actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request());
+                assertThat(XContentMapValues.extractValue("field1.text", actualRequest.sourceAsMap()), equalTo("I am a success"));
+
+                // item 2 is a failure
+                assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse());
+                assertTrue(bulkShardRequest.items()[2].getPrimaryResponse().isFailed());
+                failure = bulkShardRequest.items()[2].getPrimaryResponse().getFailure();
+                assertThat(failure.getCause().getCause().getMessage(), containsString("boom"));
+            } finally {
+                chainExecuted.countDown();
+            }
+        };
+        ActionListener actionListener = mock(ActionListener.class);
+        Task task = mock(Task.class);
+
+        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
+            "field1",
+            new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" })
+        );
+        BulkItemRequest[] items = new BulkItemRequest[3];
+        items[0] = new BulkItemRequest(0, new IndexRequest("index").source("field1", "I am a failure"));
+        items[1] = new BulkItemRequest(1, new IndexRequest("index").source("field1", "I am a success"));
+        items[2] = new BulkItemRequest(2, new IndexRequest("index").source("field1", "I am a failure"));
+        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
+        request.setInferenceFieldMap(inferenceFieldMap);
+        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
+        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
+    }
+
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testManyRandomDocs() throws Exception {
         Map<String, StaticModel> inferenceModelMap = new HashMap<>();
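
A note on the failure assertions in testItemFailures above: the item's recorded failure is the ElasticsearchException built by the filter ("Exception when running inference id [{}] on field [{}]"), whose own cause is the exception wrapped by the ErrorChunkedInferenceResults, so the test unwraps two levels with failure.getCause().getCause() before matching "boom".
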
@@ -282,7 +339,7 @@ private static BulkItemRequest[] randomBulkItemRequest(
                 continue;
             }
             var result = randomSemanticText(field, model, List.of(text), requestContentType);
-            model.putResult(text, result);
+            model.putResult(text, toChunkedResult(result));
             expectedDocMap.put(field, result);
         }
         return new BulkItemRequest[] {
@@ -304,7 +361,7 @@ private static StaticModel randomStaticModel() {
     }
 
     private static class StaticModel extends TestModel {
-        private final Map<String, SemanticTextField> resultMap;
+        private final Map<String, ChunkedInferenceServiceResults> resultMap;
 
         StaticModel(
             String inferenceEntityId,
@@ -319,14 +376,10 @@ private static class StaticModel extends TestModel {
         }
 
         ChunkedInferenceServiceResults getResults(String text) {
-            SemanticTextField result = resultMap.get(text);
-            if (result == null) {
-                return new ChunkedSparseEmbeddingResults(List.of());
-            }
-            return toChunkedResult(result);
+            return resultMap.getOrDefault(text, new ChunkedSparseEmbeddingResults(List.of()));
         }
 
-        void putResult(String text, SemanticTextField result) {
+        void putResult(String text, ChunkedInferenceServiceResults result) {
             resultMap.put(text, result);
         }
     }
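
With StaticModel now stubbing ChunkedInferenceServiceResults directly, a test can seed successes and failures for specific inputs side by side, and getResults falls back to an empty ChunkedSparseEmbeddingResults for anything unseeded; randomBulkItemRequest accordingly converts its SemanticTextField fixture up front via toChunkedResult. A short usage sketch, assuming only the helpers visible in this test file:

    // Sketch mirroring how testItemFailures seeds per-input results.
    StaticModel model = randomStaticModel();
    model.putResult("ok", randomSparseEmbeddings(List.of("ok")));
    model.putResult("bad", new ErrorChunkedInferenceResults(new IllegalArgumentException("bad input")));

    ChunkedInferenceServiceResults ok = model.getResults("ok");       // the stubbed embeddings
    ChunkedInferenceServiceResults bad = model.getResults("bad");     // the stubbed error result
    ChunkedInferenceServiceResults other = model.getResults("other"); // empty ChunkedSparseEmbeddingResults
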
