Fix BulkOperationTests
carlosdelest committed Feb 14, 2024
1 parent af763f0 commit fbefa0b
Showing 2 changed files with 83 additions and 68 deletions.
@@ -21,10 +21,10 @@
public record ModelSettings(TaskType taskType, String inferenceId, @Nullable Integer dimensions, @Nullable SimilarityMeasure similarity) {

public static final String NAME = "model_settings";
private static final ParseField TASK_TYPE_FIELD = new ParseField("task_type");
private static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
private static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
private static final ParseField SIMILARITY_FIELD = new ParseField("similarity");
public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type");
public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
public static final ParseField SIMILARITY_FIELD = new ParseField("similarity");

public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) {
Objects.requireNonNull(taskType, "task type must not be null");
@@ -33,6 +33,9 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelRegistry;
import org.elasticsearch.inference.ModelSettings;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
@@ -56,6 +59,9 @@
import java.util.stream.Collectors;

import static java.util.Collections.emptyMap;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
@@ -91,10 +97,10 @@ public void testNoInference() {
Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID)
);

Model model1 = mock(Model.class);
InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID);
Model model2 = mock(Model.class);
InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID);
Model model1 = mockModel(INFERENCE_SERVICE_1_ID);
InferenceService inferenceService1 = createInferenceService(model1);
Model model2 = mockModel(INFERENCE_SERVICE_2_ID);
InferenceService inferenceService2 = createInferenceService(model2);
InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(
Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2)
);
@@ -130,6 +136,26 @@ public void testNoInference() {
verifyNoMoreInteractions(inferenceServiceRegistry);
}

private static Model mockModel(String inferenceServiceId) {
Model model = mock(Model.class);

when(model.getInferenceEntityId()).thenReturn(inferenceServiceId);
TaskType taskType = randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING;
when(model.getTaskType()).thenReturn(taskType);

ServiceSettings serviceSettings = mock(ServiceSettings.class);
when(model.getServiceSettings()).thenReturn(serviceSettings);
SimilarityMeasure similarity = switch (randomInt(2)) {
case 0 -> SimilarityMeasure.COSINE;
case 1 -> SimilarityMeasure.DOT_PRODUCT;
default -> null;
};
when(serviceSettings.similarity()).thenReturn(similarity);
when(serviceSettings.dimensions()).thenReturn(randomBoolean() ? null : randomIntBetween(1, 1000));

return model;
}

public void testFailedBulkShardRequest() {

Map<String, Set<String>> fieldsForModels = Map.of();
@@ -191,10 +217,10 @@ public void testInference() {
Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID)
);

Model model1 = mock(Model.class);
InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID);
Model model2 = mock(Model.class);
InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID);
Model model1 = mockModel(INFERENCE_SERVICE_1_ID);
InferenceService inferenceService1 = createInferenceService(model1);
Model model2 = mockModel(INFERENCE_SERVICE_2_ID);
InferenceService inferenceService2 = createInferenceService(model2);
InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(
Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2)
);
@@ -257,8 +283,8 @@ public void testFailedInference() {

ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID));

Model model = mock(Model.class);
InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID);
Model model = mockModel(INFERENCE_SERVICE_1_ID);
InferenceService inferenceService = createInferenceServiceThatFails(model);
InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService));

String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100);
@@ -291,8 +317,8 @@ public void testInferenceFailsForIncorrectRootObject() {

ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID));

Model model = mock(Model.class);
InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID);
Model model = mockModel(INFERENCE_SERVICE_1_ID);
InferenceService inferenceService = createInferenceServiceThatFails(model);
InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService));

Map<String, Object> originalSource = Map.of(
@@ -315,39 +341,6 @@ public void testInferenceFailsForIncorrectRootObject() {
assertThat(item.getFailure().getCause().getMessage(), containsString("[_semantic_text_inference] is not an object"));
}

public void testInferenceFailsForIncorrectInferenceFieldObject() {

Map<String, Set<String>> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1));

ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID));

Model model = mock(Model.class);
InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID);
InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService));

Map<String, Object> originalSource = Map.of(
FIRST_INFERENCE_FIELD_SERVICE_1,
randomAlphaOfLengthBetween(1, 100),
ROOT_INFERENCE_FIELD,
Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, "incorrect_inference_field_value")
);

ArgumentCaptor<BulkResponse> bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class);
@SuppressWarnings("unchecked")
ActionListener<BulkResponse> bulkOperationListener = mock(ActionListener.class);
runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener);

verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture());
BulkResponse bulkResponse = bulkResponseCaptor.getValue();
assertTrue(bulkResponse.hasFailures());
BulkItemResponse item = bulkResponse.getItems()[0];
assertTrue(item.isFailed());
assertThat(
item.getFailure().getCause().getMessage(),
containsString("Inference result field [_semantic_text_inference.first_inference_field_service_1] is not an object")
);
}

public void testInferenceIdNotFound() {

Map<String, Set<String>> fieldsForModels = Map.of(
Expand All @@ -359,8 +352,8 @@ public void testInferenceIdNotFound() {

ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID));

Model model = mock(Model.class);
InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID);
Model model = mockModel(INFERENCE_SERVICE_1_ID);
InferenceService inferenceService = createInferenceService(model);
InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService));

Map<String, Object> originalSource = Map.of(
@@ -400,17 +393,20 @@ private static void checkInferenceResults(
);

for (String inferenceFieldName : inferenceFieldNames) {
List<Map<String, Object>> inferenceService1FieldResults = (List<Map<String, Object>>) inferenceRootResultField.get(
inferenceFieldName
);
Map<String, Object> inferenceService1FieldResults = (Map<String, Object>) inferenceRootResultField.get(inferenceFieldName);
assertNotNull(inferenceService1FieldResults);
assertThat(inferenceService1FieldResults.size(), equalTo(1));
Map<String, Object> inferenceResultElement = inferenceService1FieldResults.get(0);
assertNotNull(inferenceResultElement.get(BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME));
assertThat(
inferenceResultElement.get(BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME),
equalTo(docSource.get(inferenceFieldName))
assertThat(inferenceService1FieldResults.size(), equalTo(2));
Map<String, Object> modelSettings = (Map<String, Object>) inferenceService1FieldResults.get(ModelSettings.NAME);
assertNotNull(modelSettings);
assertNotNull(modelSettings.get(ModelSettings.TASK_TYPE_FIELD.getPreferredName()));
assertNotNull(modelSettings.get(ModelSettings.INFERENCE_ID_FIELD.getPreferredName()));

List<Map<String, Object>> inferenceResultElement = (List<Map<String, Object>>) inferenceService1FieldResults.get(
INFERENCE_RESULTS
);
assertFalse(inferenceResultElement.isEmpty());
assertNotNull(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_RESULTS));
assertThat(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_TEXT), equalTo(docSource.get(inferenceFieldName)));
}
}

@@ -421,8 +417,13 @@ private static void verifyInferenceServiceInvoked(
Model model,
Collection<String> inferenceTexts
) {
verify(modelRegistry).getModel(eq(inferenceService1Id), any());
verify(inferenceService).parsePersistedConfig(eq(inferenceService1Id), eq(TaskType.SPARSE_EMBEDDING), anyMap());
verify(modelRegistry).getModelWithSecrets(eq(inferenceService1Id), any());
verify(inferenceService).parsePersistedConfigWithSecrets(
eq(inferenceService1Id),
eq(TaskType.SPARSE_EMBEDDING),
anyMap(),
anyMap()
);
verify(inferenceService).infer(eq(model), argThat(containsInAnyOrder(inferenceTexts)), anyMap(), eq(InputType.INGEST), any());
verifyNoMoreInteractions(inferenceService);
}
@@ -537,9 +538,16 @@ private static BulkShardRequest runBulkOperation(
);
};

private static InferenceService createInferenceService(Model model, String inferenceServiceId) {
private static InferenceService createInferenceService(Model model) {
InferenceService inferenceService = mock(InferenceService.class);
when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model);
when(
inferenceService.parsePersistedConfigWithSecrets(
eq(model.getInferenceEntityId()),
eq(TaskType.SPARSE_EMBEDDING),
anyMap(),
anyMap()
)
).thenReturn(model);
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(4);
InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class);
@@ -556,9 +564,16 @@ private static InferenceService createInferenceService(Model model, String infer
return inferenceService;
}

private static InferenceService createInferenceServiceThatFails(Model model, String inferenceServiceId) {
private static InferenceService createInferenceServiceThatFails(Model model) {
InferenceService inferenceService = mock(InferenceService.class);
when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model);
when(
inferenceService.parsePersistedConfigWithSecrets(
eq(model.getInferenceEntityId()),
eq(TaskType.SPARSE_EMBEDDING),
anyMap(),
anyMap()
)
).thenReturn(model);
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(4);
listener.onFailure(new IllegalArgumentException(INFERENCE_FAILED_MSG));
@@ -591,7 +606,7 @@ private static ModelRegistry createModelRegistry(Map<String, String> inferenceId
ActionListener<ModelRegistry.UnparsedModel> listener = invocation.getArgument(1);
listener.onFailure(new IllegalArgumentException("Model not found"));
return null;
}).when(modelRegistry).getModel(any(), any());
}).when(modelRegistry).getModelWithSecrets(any(), any());
inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> {
ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel(
inferenceId,
Expand All @@ -604,7 +619,7 @@ private static ModelRegistry createModelRegistry(Map<String, String> inferenceId
ActionListener<ModelRegistry.UnparsedModel> listener = invocation.getArgument(1);
listener.onResponse(unparsedModel);
return null;
}).when(modelRegistry).getModel(eq(inferenceId), any());
}).when(modelRegistry).getModelWithSecrets(eq(inferenceId), any());
});

return modelRegistry;