Skip to content

Commit

Permalink
Add failing inference test
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Feb 5, 2024
1 parent 0465118 commit d97a043
Showing 1 changed file with 65 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

public class BulkOperationTests extends ESTestCase {
Expand All @@ -88,9 +89,8 @@ public void testInference() {
Set.of(INFERENCE_FIELD_SERVICE_2)
);

ModelRegistry modelRegistry = createModelRegistry(Map.of(
INFERENCE_SERVICE_1_ID, SERVICE_1_ID,
INFERENCE_SERVICE_2_ID, SERVICE_2_ID)
ModelRegistry modelRegistry = createModelRegistry(
Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID)
);

Model model1 = mock(Model.class);
Expand All @@ -117,7 +117,15 @@ public void testInference() {
"yet_another_value"
);

BulkShardRequest bulkShardRequest = runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry);
ActionListener<BulkResponse> bulkOperationListener = mock(ActionListener.class);
BulkShardRequest bulkShardRequest = runBulkOperation(
originalSource,
fieldsForModels,
modelRegistry,
inferenceServiceRegistry,
bulkOperationListener
);
verify(bulkOperationListener).onResponse(any());

BulkItemRequest[] items = bulkShardRequest.items();
assertThat(items.length, equalTo(1));
Expand Down Expand Up @@ -145,6 +153,45 @@ public void testInference() {
checkInferenceResult(inferenceRootResultField, INFERENCE_FIELD_SERVICE_2, inferenceTextService2);
}

public void testFailedInference() {

Map<String, Set<String>> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1));

ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID));

Model model = mock(Model.class);
InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID);
InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService));

String firstInferenceTextService1 = "firstInferenceTextService1";
Map<String, String> originalSource = Map.of(
FIRST_INFERENCE_FIELD_SERVICE_1,
firstInferenceTextService1,
"other_field",
"other_value"
);

ArgumentCaptor<BulkResponse> bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class);
@SuppressWarnings("unchecked")
ActionListener<BulkResponse> bulkOperationListener = mock(ActionListener.class);
BulkShardRequest bulkShardRequest = runBulkOperation(
originalSource,
fieldsForModels,
modelRegistry,
inferenceServiceRegistry,
bulkOperationListener
);
BulkItemRequest[] items = bulkShardRequest.items();
assertThat(items.length, equalTo(1));
assertNull(items[0]);
verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture());
BulkResponse bulkResponse = bulkResponseCaptor.getValue();
assertTrue(bulkResponse.hasFailures());

verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_1_ID, inferenceService, model, List.of(firstInferenceTextService1));

}

private static void checkInferenceResult(Map<String, Object> inferenceRootResultField, String fieldName, String expectedText) {
@SuppressWarnings("unchecked")
List<Map<String, Object>> inferenceService1FieldResults = (List<Map<String, Object>>) inferenceRootResultField.get(fieldName);
Expand All @@ -165,6 +212,7 @@ private static void verifyInferenceServiceInvoked(
verify(modelRegistry).getModel(eq(inferenceService1Id), any());
verify(inferenceService).parsePersistedConfig(eq(inferenceService1Id), eq(TaskType.SPARSE_EMBEDDING), anyMap());
verify(inferenceService).infer(eq(model), argThat(containsAll(inferenceTexts)), anyMap(), eq(InputType.INGEST), any());
verifyNoMoreInteractions(inferenceService);
}

private static ArgumentMatcher<List<String>> containsAll(List<String> expected) {
Expand All @@ -185,7 +233,8 @@ private static BulkShardRequest runBulkOperation(
Map<String, String> docSource,
Map<String, Set<String>> fieldsForModels,
ModelRegistry modelRegistry,
InferenceServiceRegistry inferenceServiceRegistry
InferenceServiceRegistry inferenceServiceRegistry,
ActionListener<BulkResponse> bulkOperationListener
) {
Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build();
IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME)
Expand Down Expand Up @@ -233,8 +282,6 @@ private static BulkShardRequest runBulkOperation(
return null;
}).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any());

@SuppressWarnings("unchecked")
ActionListener<BulkResponse> bulkOperationListener = mock(ActionListener.class);
Task task = new Task(randomLong(), "transport", "action", "", null, emptyMap());
BulkOperation bulkOperation = new BulkOperation(
task,
Expand All @@ -254,7 +301,6 @@ private static BulkShardRequest runBulkOperation(
);

bulkOperation.doRun();
verify(bulkOperationListener).onResponse(any());
verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any());

return bulkShardRequestCaptor.getValue();
Expand All @@ -277,6 +323,17 @@ private static InferenceService createInferenceService(Model model, String infer
return inferenceService;
}

private static InferenceService createInferenceServiceThatFails(Model model, String inferenceServiceId) {
InferenceService inferenceService = mock(InferenceService.class);
when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model);
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(4);
listener.onFailure(new IllegalArgumentException("Inference failed"));
return null;
}).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any());
return inferenceService;
}

private static InferenceResults createInferenceResults() {
InferenceResults inferenceResults = mock(InferenceResults.class);
when(inferenceResults.asMap(any())).then(
Expand Down

0 comments on commit d97a043

Please sign in to comment.