From fbefa0b7422d9ff7d1d7dd24365ef775997e4c5c Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 14 Feb 2024 19:11:56 +0100 Subject: [PATCH] Fix BulkOperationTests --- .../inference/ModelSettings.java | 8 +- .../action/bulk/BulkOperationTests.java | 143 ++++++++++-------- 2 files changed, 83 insertions(+), 68 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java index f11f5e67f80fe..154d4d34ba74d 100644 --- a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java @@ -21,10 +21,10 @@ public record ModelSettings(TaskType taskType, String inferenceId, @Nullable Integer dimensions, @Nullable SimilarityMeasure similarity) { public static final String NAME = "model_settings"; - private static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); - private static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); - private static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); - private static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); + public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); + public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); + public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); + public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { Objects.requireNonNull(taskType, "task type must not be null"); diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index f8ed331d358b2..cbbe09ae90049 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -33,6 +33,9 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelRegistry; +import org.elasticsearch.inference.ModelSettings; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; @@ -56,6 +59,9 @@ import java.util.stream.Collectors; import static java.util.Collections.emptyMap; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS; import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; @@ -91,10 +97,10 @@ public void testNoInference() { Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) ); - Model model1 = mock(Model.class); - InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID); - Model model2 = mock(Model.class); - InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID); + Model model1 = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService1 = createInferenceService(model1); + Model model2 = mockModel(INFERENCE_SERVICE_2_ID); + InferenceService inferenceService2 = createInferenceService(model2); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) ); @@ -130,6 +136,26 @@ public void testNoInference() { verifyNoMoreInteractions(inferenceServiceRegistry); } + private static Model mockModel(String inferenceServiceId) { + Model model = mock(Model.class); + + when(model.getInferenceEntityId()).thenReturn(inferenceServiceId); + TaskType taskType = randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING; + when(model.getTaskType()).thenReturn(taskType); + + ServiceSettings serviceSettings = mock(ServiceSettings.class); + when(model.getServiceSettings()).thenReturn(serviceSettings); + SimilarityMeasure similarity = switch (randomInt(2)) { + case 0 -> SimilarityMeasure.COSINE; + case 1 -> SimilarityMeasure.DOT_PRODUCT; + default -> null; + }; + when(serviceSettings.similarity()).thenReturn(similarity); + when(serviceSettings.dimensions()).thenReturn(randomBoolean() ? null : randomIntBetween(1, 1000)); + + return model; + } + public void testFailedBulkShardRequest() { Map> fieldsForModels = Map.of(); @@ -191,10 +217,10 @@ public void testInference() { Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) ); - Model model1 = mock(Model.class); - InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID); - Model model2 = mock(Model.class); - InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID); + Model model1 = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService1 = createInferenceService(model1); + Model model2 = mockModel(INFERENCE_SERVICE_2_ID); + InferenceService inferenceService2 = createInferenceService(model2); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) ); @@ -257,8 +283,8 @@ public void testFailedInference() { ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - Model model = mock(Model.class); - InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID); + Model model = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService = createInferenceServiceThatFails(model); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); @@ -291,8 +317,8 @@ public void testInferenceFailsForIncorrectRootObject() { ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - Model model = mock(Model.class); - InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID); + Model model = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService = createInferenceServiceThatFails(model); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); Map originalSource = Map.of( @@ -315,39 +341,6 @@ public void testInferenceFailsForIncorrectRootObject() { assertThat(item.getFailure().getCause().getMessage(), containsString("[_semantic_text_inference] is not an object")); } - public void testInferenceFailsForIncorrectInferenceFieldObject() { - - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mock(Model.class); - InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - randomAlphaOfLengthBetween(1, 100), - ROOT_INFERENCE_FIELD, - Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, "incorrect_inference_field_value") - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat( - item.getFailure().getCause().getMessage(), - containsString("Inference result field [_semantic_text_inference.first_inference_field_service_1] is not an object") - ); - } - public void testInferenceIdNotFound() { Map> fieldsForModels = Map.of( @@ -359,8 +352,8 @@ public void testInferenceIdNotFound() { ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - Model model = mock(Model.class); - InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID); + Model model = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService = createInferenceService(model); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); Map originalSource = Map.of( @@ -400,17 +393,20 @@ private static void checkInferenceResults( ); for (String inferenceFieldName : inferenceFieldNames) { - List> inferenceService1FieldResults = (List>) inferenceRootResultField.get( - inferenceFieldName - ); + Map inferenceService1FieldResults = (Map) inferenceRootResultField.get(inferenceFieldName); assertNotNull(inferenceService1FieldResults); - assertThat(inferenceService1FieldResults.size(), equalTo(1)); - Map inferenceResultElement = inferenceService1FieldResults.get(0); - assertNotNull(inferenceResultElement.get(BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME)); - assertThat( - inferenceResultElement.get(BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME), - equalTo(docSource.get(inferenceFieldName)) + assertThat(inferenceService1FieldResults.size(), equalTo(2)); + Map modelSettings = (Map) inferenceService1FieldResults.get(ModelSettings.NAME); + assertNotNull(modelSettings); + assertNotNull(modelSettings.get(ModelSettings.TASK_TYPE_FIELD.getPreferredName())); + assertNotNull(modelSettings.get(ModelSettings.INFERENCE_ID_FIELD.getPreferredName())); + + List> inferenceResultElement = (List>) inferenceService1FieldResults.get( + INFERENCE_RESULTS ); + assertFalse(inferenceResultElement.isEmpty()); + assertNotNull(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_RESULTS)); + assertThat(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_TEXT), equalTo(docSource.get(inferenceFieldName))); } } @@ -421,8 +417,13 @@ private static void verifyInferenceServiceInvoked( Model model, Collection inferenceTexts ) { - verify(modelRegistry).getModel(eq(inferenceService1Id), any()); - verify(inferenceService).parsePersistedConfig(eq(inferenceService1Id), eq(TaskType.SPARSE_EMBEDDING), anyMap()); + verify(modelRegistry).getModelWithSecrets(eq(inferenceService1Id), any()); + verify(inferenceService).parsePersistedConfigWithSecrets( + eq(inferenceService1Id), + eq(TaskType.SPARSE_EMBEDDING), + anyMap(), + anyMap() + ); verify(inferenceService).infer(eq(model), argThat(containsInAnyOrder(inferenceTexts)), anyMap(), eq(InputType.INGEST), any()); verifyNoMoreInteractions(inferenceService); } @@ -537,9 +538,16 @@ private static BulkShardRequest runBulkOperation( ); }; - private static InferenceService createInferenceService(Model model, String inferenceServiceId) { + private static InferenceService createInferenceService(Model model) { InferenceService inferenceService = mock(InferenceService.class); - when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); + when( + inferenceService.parsePersistedConfigWithSecrets( + eq(model.getInferenceEntityId()), + eq(TaskType.SPARSE_EMBEDDING), + anyMap(), + anyMap() + ) + ).thenReturn(model); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(4); InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class); @@ -556,9 +564,16 @@ private static InferenceService createInferenceService(Model model, String infer return inferenceService; } - private static InferenceService createInferenceServiceThatFails(Model model, String inferenceServiceId) { + private static InferenceService createInferenceServiceThatFails(Model model) { InferenceService inferenceService = mock(InferenceService.class); - when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); + when( + inferenceService.parsePersistedConfigWithSecrets( + eq(model.getInferenceEntityId()), + eq(TaskType.SPARSE_EMBEDDING), + anyMap(), + anyMap() + ) + ).thenReturn(model); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(4); listener.onFailure(new IllegalArgumentException(INFERENCE_FAILED_MSG)); @@ -591,7 +606,7 @@ private static ModelRegistry createModelRegistry(Map inferenceId ActionListener listener = invocation.getArgument(1); listener.onFailure(new IllegalArgumentException("Model not found")); return null; - }).when(modelRegistry).getModel(any(), any()); + }).when(modelRegistry).getModelWithSecrets(any(), any()); inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> { ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel( inferenceId, @@ -604,7 +619,7 @@ private static ModelRegistry createModelRegistry(Map inferenceId ActionListener listener = invocation.getArgument(1); listener.onResponse(unparsedModel); return null; - }).when(modelRegistry).getModel(eq(inferenceId), any()); + }).when(modelRegistry).getModelWithSecrets(eq(inferenceId), any()); }); return modelRegistry;