Skip to content

Commit

Permalink
Add test for inference id not found
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Feb 5, 2024
1 parent d97a043 commit 9c8cd37
Showing 1 changed file with 61 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,15 @@

public class BulkOperationTests extends ESTestCase {

public static final String INDEX_NAME = "test-index";
public static final String INFERENCE_SERVICE_1_ID = "inference_service_1_id";
public static final String INFERENCE_SERVICE_2_ID = "inference_service_2_id";
public static final String FIRST_INFERENCE_FIELD_SERVICE_1 = "first_inference_field_service_1";
public static final String SECOND_INFERENCE_FIELD_SERVICE_1 = "second_inference_field_service_1";
public static final String INFERENCE_FIELD_SERVICE_2 = "inference_field_service_2";
public static final String SERVICE_1_ID = "elser_v2";
public static final String SERVICE_2_ID = "e5";
private static final String INDEX_NAME = "test-index";
private static final String INFERENCE_SERVICE_1_ID = "inference_service_1_id";
private static final String INFERENCE_SERVICE_2_ID = "inference_service_2_id";
private static final String FIRST_INFERENCE_FIELD_SERVICE_1 = "first_inference_field_service_1";
private static final String SECOND_INFERENCE_FIELD_SERVICE_1 = "second_inference_field_service_1";
private static final String INFERENCE_FIELD_SERVICE_2 = "inference_field_service_2";
private static final String SERVICE_1_ID = "elser_v2";
private static final String SERVICE_2_ID = "e5";
private static final String INFERENCE_FAILED_MSG = "Inference failed";
private static TestThreadPool threadPool;

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -187,11 +188,56 @@ public void testFailedInference() {
verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture());
BulkResponse bulkResponse = bulkResponseCaptor.getValue();
assertTrue(bulkResponse.hasFailures());
BulkItemResponse item = bulkResponse.getItems()[0];
assertTrue(item.isFailed());
assertThat(item.getFailure().getCause().getMessage(), equalTo(INFERENCE_FAILED_MSG));

verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_1_ID, inferenceService, model, List.of(firstInferenceTextService1));

}

public void testInferenceIdNotFound() {

Map<String, Set<String>> fieldsForModels = Map.of(
INFERENCE_SERVICE_1_ID,
Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1),
INFERENCE_SERVICE_2_ID,
Set.of(INFERENCE_FIELD_SERVICE_2)
);

ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID));

Model model = mock(Model.class);
InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID, 1);
InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService));

String firstInferenceTextService1 = "firstInferenceTextService1";
Map<String, String> originalSource = Map.of(INFERENCE_FIELD_SERVICE_2, "text_for_service_2", "other_field", "other_value");

ArgumentCaptor<BulkResponse> bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class);
@SuppressWarnings("unchecked")
ActionListener<BulkResponse> bulkOperationListener = mock(ActionListener.class);
BulkShardRequest bulkShardRequest = runBulkOperation(
originalSource,
fieldsForModels,
modelRegistry,
inferenceServiceRegistry,
bulkOperationListener
);
BulkItemRequest[] items = bulkShardRequest.items();
assertThat(items.length, equalTo(1));
assertNull(items[0]);
verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture());
BulkResponse bulkResponse = bulkResponseCaptor.getValue();
assertTrue(bulkResponse.hasFailures());
BulkItemResponse item = bulkResponse.getItems()[0];
assertTrue(item.isFailed());
assertThat(
item.getFailure().getCause().getMessage(),
equalTo("No inference provider found for model ID " + INFERENCE_SERVICE_2_ID)
);
}

private static void checkInferenceResult(Map<String, Object> inferenceRootResultField, String fieldName, String expectedText) {
@SuppressWarnings("unchecked")
List<Map<String, Object>> inferenceService1FieldResults = (List<Map<String, Object>>) inferenceRootResultField.get(fieldName);
Expand Down Expand Up @@ -328,7 +374,7 @@ private static InferenceService createInferenceServiceThatFails(Model model, Str
when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model);
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(4);
listener.onFailure(new IllegalArgumentException("Inference failed"));
listener.onFailure(new IllegalArgumentException(INFERENCE_FAILED_MSG));
return null;
}).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any());
return inferenceService;
Expand All @@ -353,6 +399,12 @@ private static InferenceServiceRegistry createInferenceServiceRegistry(Map<Strin

private static ModelRegistry createModelRegistry(Map<String, String> inferenceIdsToServiceIds) {
ModelRegistry modelRegistry = mock(ModelRegistry.class);
// Fails for unknown inference ids
doAnswer(invocation -> {
ActionListener<ModelRegistry.UnparsedModel> listener = invocation.getArgument(1);
listener.onFailure(new IllegalArgumentException("Model not found"));
return null;
}).when(modelRegistry).getModel(any(), any());
inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> {
ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel(
inferenceId,
Expand Down

0 comments on commit 9c8cd37

Please sign in to comment.