Skip to content

Commit

Permalink
Add multiple fields to test
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Feb 5, 2024
1 parent df5f799 commit 0465118
Showing 1 changed file with 141 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,24 @@
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static java.util.Collections.emptyMap;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
Expand All @@ -63,51 +68,125 @@

public class BulkOperationTests extends ESTestCase {

public static final String INFERENCE_SERVICE_ID = "inferenece_service_id";
public static final String INDEX_NAME = "test-index";
public static final String INFERENCE_FIELD = "inference_field";
public static final String SERVICE_ID = "elser_v2";
public static final String INFERENCE_SERVICE_1_ID = "inference_service_1_id";
public static final String INFERENCE_SERVICE_2_ID = "inference_service_2_id";
public static final String FIRST_INFERENCE_FIELD_SERVICE_1 = "first_inference_field_service_1";
public static final String SECOND_INFERENCE_FIELD_SERVICE_1 = "second_inference_field_service_1";
public static final String INFERENCE_FIELD_SERVICE_2 = "inference_field_service_2";
public static final String SERVICE_1_ID = "elser_v2";
public static final String SERVICE_2_ID = "e5";
private static TestThreadPool threadPool;

@SuppressWarnings("unchecked")
public void testInference() {

Map<String, Set<String>> fieldsForModels = Map.of(INFERENCE_SERVICE_ID, Set.of(INFERENCE_FIELD));

ModelRegistry modelRegistry = createModelRegistry();
Map<String, Set<String>> fieldsForModels = Map.of(
INFERENCE_SERVICE_1_ID,
Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1),
INFERENCE_SERVICE_2_ID,
Set.of(INFERENCE_FIELD_SERVICE_2)
);

Model model = mock(Model.class);
InferenceService inferenceService = createInferenceService(model);
InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(inferenceService);
ModelRegistry modelRegistry = createModelRegistry(Map.of(
INFERENCE_SERVICE_1_ID, SERVICE_1_ID,
INFERENCE_SERVICE_2_ID, SERVICE_2_ID)
);

String inferenceText = "test";
Map<String, String> originalSource = Map.of(INFERENCE_FIELD, inferenceText);
Model model1 = mock(Model.class);
InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID, 2);
Model model2 = mock(Model.class);
InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID, 1);
InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(
Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2)
);

BulkShardRequest bulkShardRequest = runBulkOperation(originalSource, fieldsForModels , modelRegistry, inferenceServiceRegistry);
String firstInferenceTextService1 = "firstInferenceTextService1";
String secondInferenceTextService1 = "secondInferenceTextService1";
String inferenceTextService2 = "inferenceTextService2";
Map<String, String> originalSource = Map.of(
FIRST_INFERENCE_FIELD_SERVICE_1,
firstInferenceTextService1,
SECOND_INFERENCE_FIELD_SERVICE_1,
secondInferenceTextService1,
INFERENCE_FIELD_SERVICE_2,
inferenceTextService2,
"other_field",
"other_value",
"yet_another_field",
"yet_another_value"
);

verifyInferenceDone(modelRegistry, inferenceService, model, inferenceText);
BulkShardRequest bulkShardRequest = runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry);

BulkItemRequest[] items = bulkShardRequest.items();
assertThat(items.length, equalTo(1));
Map<String, Object> docSource = ((IndexRequest) items[0].request()).sourceAsMap();
Map<String, Object> inferenceRootResultField = (Map<String, Object>) docSource.get(BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD);
List<Map<String, Object>> inferenceFieldResults = (List<Map<String, Object>>) inferenceRootResultField.get(INFERENCE_FIELD);
assertNotNull(inferenceFieldResults);
assertThat(inferenceFieldResults.size(), equalTo(1));
Map<String, Object> inferenceResultElement = inferenceFieldResults.get(0);

Map<String, Object> writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap();
// Original doc source is preserved
assertTrue(writtenDocSource.keySet().containsAll(originalSource.keySet()));
assertTrue(writtenDocSource.values().containsAll(originalSource.values()));

// Check inference results
verifyInferenceServiceInvoked(
modelRegistry,
INFERENCE_SERVICE_1_ID,
inferenceService1,
model1,
List.of(firstInferenceTextService1, secondInferenceTextService1)
);
verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_2_ID, inferenceService2, model2, List.of(inferenceTextService2));
Map<String, Object> inferenceRootResultField = (Map<String, Object>) writtenDocSource.get(
BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD
);

checkInferenceResult(inferenceRootResultField, FIRST_INFERENCE_FIELD_SERVICE_1, firstInferenceTextService1);
checkInferenceResult(inferenceRootResultField, SECOND_INFERENCE_FIELD_SERVICE_1, secondInferenceTextService1);
checkInferenceResult(inferenceRootResultField, INFERENCE_FIELD_SERVICE_2, inferenceTextService2);
}

private static void checkInferenceResult(Map<String, Object> inferenceRootResultField, String fieldName, String expectedText) {
@SuppressWarnings("unchecked")
List<Map<String, Object>> inferenceService1FieldResults = (List<Map<String, Object>>) inferenceRootResultField.get(fieldName);
assertNotNull(inferenceService1FieldResults);
assertThat(inferenceService1FieldResults.size(), equalTo(1));
Map<String, Object> inferenceResultElement = inferenceService1FieldResults.get(0);
assertNotNull(inferenceResultElement.get(BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME));
assertThat(inferenceResultElement.get(BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME), equalTo(inferenceText));
assertThat(inferenceResultElement.get(BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME), equalTo(expectedText));
}

private static void verifyInferenceDone(ModelRegistry modelRegistry, InferenceService inferenceService, Model model, String inferenceText) {
verify(modelRegistry).getModel(eq(INFERENCE_SERVICE_ID), any());
verify(inferenceService).parsePersistedConfig(eq(INFERENCE_SERVICE_ID), eq(TaskType.SPARSE_EMBEDDING), anyMap());
verify(inferenceService).infer(eq(model), eq(List.of(inferenceText)), anyMap(), eq(InputType.INGEST), any());
private static void verifyInferenceServiceInvoked(
ModelRegistry modelRegistry,
String inferenceService1Id,
InferenceService inferenceService,
Model model,
List<String> inferenceTexts
) {
verify(modelRegistry).getModel(eq(inferenceService1Id), any());
verify(inferenceService).parsePersistedConfig(eq(inferenceService1Id), eq(TaskType.SPARSE_EMBEDDING), anyMap());
verify(inferenceService).infer(eq(model), argThat(containsAll(inferenceTexts)), anyMap(), eq(InputType.INGEST), any());
}

private static BulkShardRequest runBulkOperation(Map<String, String> docSource, Map<String, Set<String>> fieldsForModels, ModelRegistry modelRegistry, InferenceServiceRegistry inferenceServiceRegistry) {
private static ArgumentMatcher<List<String>> containsAll(List<String> expected) {
return new ArgumentMatcher<>() {
@Override
public boolean matches(List<String> argument) {
return argument.containsAll(expected) && argument.size() == expected.size();
}

@Override
public String toString() {
return "containsAll(" + expected.stream().collect(Collectors.joining(", ")) + ")";
}
};
}

private static BulkShardRequest runBulkOperation(
Map<String, String> docSource,
Map<String, Set<String>> fieldsForModels,
ModelRegistry modelRegistry,
InferenceServiceRegistry inferenceServiceRegistry
) {
Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build();
IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME)
.fieldsForModels(fieldsForModels)
Expand All @@ -133,6 +212,7 @@ private static BulkShardRequest runBulkOperation(Map<String, String> docSource,
BulkShardResponse bulkShardResponse = new BulkShardResponse(
request.shardId(),
Arrays.stream(request.items())
.filter(Objects::nonNull)
.map(
item -> BulkItemResponse.success(
item.id(),
Expand All @@ -151,8 +231,7 @@ private static BulkShardRequest runBulkOperation(Map<String, String> docSource,
);
bulkShardResponseListener.onResponse(bulkShardResponse);
return null;
}
).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any());
}).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any());

@SuppressWarnings("unchecked")
ActionListener<BulkResponse> bulkOperationListener = mock(ActionListener.class);
Expand Down Expand Up @@ -181,46 +260,57 @@ private static BulkShardRequest runBulkOperation(Map<String, String> docSource,
return bulkShardRequestCaptor.getValue();
}

private static InferenceService createInferenceService(Model model) {
private static InferenceService createInferenceService(Model model, String inferenceServiceId, int numResults) {
InferenceService inferenceService = mock(InferenceService.class);
when(inferenceService.parsePersistedConfig(eq(INFERENCE_SERVICE_ID), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model);
when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model);
InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class);
List<InferenceResults> inferenceResults = new ArrayList<>();
for (int i = 0; i < numResults; i++) {
inferenceResults.add(createInferenceResults());
}
doReturn(inferenceResults).when(inferenceServiceResults).transformToLegacyFormat();
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(4);
listener.onResponse(inferenceServiceResults);
return null;
}).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any());
return inferenceService;
}

private static InferenceResults createInferenceResults() {
InferenceResults inferenceResults = mock(InferenceResults.class);
when(inferenceResults.asMap(any())).then(
invocation -> Map.of(
(String) invocation.getArguments()[0],
Map.of("sparse_embedding", randomMap(1, 10, () -> new Tuple<>(randomAlphaOfLength(10), randomFloat())))
)
);
doReturn(List.of(inferenceResults)).when(inferenceServiceResults).transformToLegacyFormat();
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(4);
listener.onResponse(inferenceServiceResults);
return null;
}).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any());
return inferenceService;
return inferenceResults;
}

private static InferenceServiceRegistry createInferenceServiceRegistry(InferenceService inferenceService) {
private static InferenceServiceRegistry createInferenceServiceRegistry(Map<String, InferenceService> inferenceServices) {
InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class);
when(inferenceServiceRegistry.getService(SERVICE_ID)).thenReturn(Optional.of(inferenceService));
inferenceServices.forEach((id, service) -> when(inferenceServiceRegistry.getService(id)).thenReturn(Optional.of(service)));
return inferenceServiceRegistry;
}

private static ModelRegistry createModelRegistry() {
private static ModelRegistry createModelRegistry(Map<String, String> inferenceIdsToServiceIds) {
ModelRegistry modelRegistry = mock(ModelRegistry.class);
ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel(
INFERENCE_SERVICE_ID,
TaskType.SPARSE_EMBEDDING,
SERVICE_ID,
emptyMap(),
emptyMap()
);
doAnswer(invocation -> {
ActionListener<ModelRegistry.UnparsedModel> listener = invocation.getArgument(1);
listener.onResponse(unparsedModel);
return null;
}).when(modelRegistry).getModel(eq(INFERENCE_SERVICE_ID), any());
inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> {
ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel(
inferenceId,
TaskType.SPARSE_EMBEDDING,
serviceId,
emptyMap(),
emptyMap()
);
doAnswer(invocation -> {
ActionListener<ModelRegistry.UnparsedModel> listener = invocation.getArgument(1);
listener.onResponse(unparsedModel);
return null;
}).when(modelRegistry).getModel(eq(inferenceId), any());
});

return modelRegistry;
}

Expand All @@ -240,7 +330,6 @@ private static ClusterService createClusterService(IndexMetadata indexMetadata)
return clusterService;
}


@BeforeClass
public static void createThreadPool() {
threadPool = new TestThreadPool(getTestClass().getName());
Expand Down

0 comments on commit 0465118

Please sign in to comment.