From 81e6fe833133fb81a147ef20fb469398e173daca Mon Sep 17 00:00:00 2001 From: Mingshi Liu Date: Thu, 17 Oct 2024 11:28:52 -0700 Subject: [PATCH] add more tests Signed-off-by: Mingshi Liu --- .../MLInferenceSearchResponseProcessor.java | 19 +- ...InferenceSearchResponseProcessorTests.java | 44 +++++ .../MLInferenceSearchResponseTests.java | 181 ++++++++++++++++++ 3 files changed, 233 insertions(+), 11 deletions(-) create mode 100644 plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseTests.java diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java index 158eaa545c..b59dbbb86b 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java @@ -14,13 +14,7 @@ import static org.opensearch.ml.processor.MLInferenceIngestProcessor.OVERRIDE; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.logging.log4j.LogManager; @@ -84,6 +78,8 @@ public class MLInferenceSearchResponseProcessor extends AbstractProcessor implem // it can be overwritten using max_prediction_tasks when creating processor public static final int DEFAULT_MAX_PREDICTION_TASKS = 10; public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results"; + // allow to write to the extension of the search response, the path to point to search extension + // is prefix with ext.ml_inference public static final String EXTENSION_PREFIX = "ext.ml_inference"; protected MLInferenceSearchResponseProcessor( @@ -804,11 +800,12 @@ public MLInferenceSearchResponseProcessor create( } boolean writeToSearchExtension = false; - if (outputMaps != null - && outputMaps + if (outputMaps != null) { + writeToSearchExtension = outputMaps .stream() - .anyMatch(outputMap -> outputMap.keySet().stream().anyMatch(key -> key.startsWith(EXTENSION_PREFIX)))) { - writeToSearchExtension = true; + .filter(Objects::nonNull) // To avoid potential NullPointerExceptions from null outputMaps + .flatMap(outputMap -> outputMap.keySet().stream()) + .anyMatch(key -> key.startsWith(EXTENSION_PREFIX)); } if (writeToSearchExtension & oneToOne) { diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index 142fc6d995..ad6f1db493 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -34,6 +34,7 @@ import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; @@ -87,6 +88,7 @@ public void setup() { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseException() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( @@ -115,6 +117,7 @@ public void testProcessResponseException() throws Exception { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseSuccess() throws Exception { String modelInputField = "inputs"; String originalDocumentField = "text"; @@ -174,6 +177,7 @@ public void onFailure(Exception e) { * with one to one prediction, 5 documents in hits are calling 5 prediction tasks * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithCustomPrompt() throws Exception { String newDocumentField = "context"; @@ -263,6 +267,7 @@ public void onFailure(Exception e) { * with many to one prediction, 5 documents in hits are calling 1 prediction tasks * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseManyToOneWithCustomPrompt() throws Exception { String documentField = "text"; @@ -360,6 +365,7 @@ public void onFailure(Exception e) { * with full response path false and no output mapping is provided * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseManyToOneWithCustomPromptFullResponsePathFalse() throws Exception { String documentField = "text"; @@ -436,6 +442,7 @@ public void onFailure(Exception e) { * with full response path true and no output mapping is provided * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseManyToOneWithCustomPromptFullResponsePathTrue() throws Exception { String documentField = "text"; @@ -511,6 +518,7 @@ public void onFailure(Exception e) { * with query extensions * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseSuccessWriteToExt() throws Exception { String documentField = "text"; String modelInputField = "context"; @@ -586,6 +594,7 @@ public void onFailure(Exception e) { * with one to one prediction, 5 documents in hits are calling 5 prediction tasks * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithNoMappings() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( @@ -666,6 +675,7 @@ public void onFailure(Exception e) { * with one to one prediction, 5 documents in hits are calling 5 prediction tasks * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithEmptyMappings() throws Exception { List> outputMap = new ArrayList<>(); List> inputMap = new ArrayList<>(); @@ -747,6 +757,7 @@ public void onFailure(Exception e) { * with one to one prediction, 5 documents in hits are calling 5 prediction tasks * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappings() throws Exception { String newDocumentField = "text_embedding"; @@ -836,6 +847,7 @@ public void onFailure(Exception e) { * when there is one document, the combinedResponseListener calls onFailure * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsCombineResponseListenerFail() throws Exception { String newDocumentField = "text_embedding"; @@ -897,6 +909,7 @@ public void onFailure(Exception e) { * when there is one document, the combinedResponseListener calls onFailure * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsCombineResponseListenerException() throws Exception { String newDocumentField = "text_embedding"; @@ -953,6 +966,7 @@ public void onFailure(Exception e) { * when there is one document and ignoreFailure, should return the original response * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsCombineResponseListenerExceptionIgnoreFailure() throws Exception { String newDocumentField = "text_embedding"; @@ -1009,6 +1023,7 @@ public void onFailure(Exception e) { * when there is one document and ignoreFailure, should return the original response * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseCreateRewriteResponseListenerExceptionIgnoreFailure() throws Exception { String newDocumentField = "text_embedding"; @@ -1099,6 +1114,7 @@ public void onFailure(Exception e) { * createRewriteResponseListener should reach on Failure * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseCreateRewriteResponseListenerException() throws Exception { String newDocumentField = "text_embedding"; @@ -1185,6 +1201,7 @@ public void onFailure(Exception e) { * test throwing OpenSearchStatusException * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOpenSearchStatusException() throws Exception { String newDocumentField = "text_embedding"; @@ -1268,6 +1285,7 @@ public void onFailure(Exception e) { * test throwing MLResourceNotFoundException * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseMLResourceNotFoundException() throws Exception { String newDocumentField = "text_embedding"; @@ -1353,6 +1371,7 @@ public void onFailure(Exception e) { * when there is one document, the combinedResponseListener calls onFailure * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsIgnoreFailure() throws Exception { String newDocumentField = "text_embedding"; @@ -1414,6 +1433,7 @@ public void onFailure(Exception e) { * when there is one document, the combinedResponseListener calls onFailure * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsMLTaskResponseExceptionIgnoreFailure() throws Exception { String newDocumentField = "text_embedding"; @@ -1476,6 +1496,7 @@ public void onFailure(Exception e) { * expect to run one prediction task and the rest 4 predictions tasks are not created * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsPredictException() throws Exception { String newDocumentField = "text_embedding"; @@ -1532,6 +1553,7 @@ public void onFailure(Exception e) { * expect to run one prediction task and the rest 4 predictions tasks are not created * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsPredictFail() throws Exception { String newDocumentField = "text_embedding"; @@ -1594,6 +1616,7 @@ public void onFailure(Exception e) { * then return original response * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsPredictFailIgnoreFailure() throws Exception { String newDocumentField = "text_embedding"; @@ -1653,6 +1676,7 @@ public void onFailure(Exception e) { * with one to one prediction, 5 documents in hits are calling 10 prediction tasks * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneTwoRoundsPredictions() throws Exception { String modelInputField = "inputs"; @@ -1783,6 +1807,7 @@ public void onFailure(Exception e) { * expect to throw exception without further processing * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneTwoRoundsPredictionsOneException() throws Exception { String modelInputField = "inputs"; @@ -1883,6 +1908,7 @@ public void onFailure(Exception e) { * expect to return document with second round prediction results * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneTwoRoundsPredictionsOneExceptionIgnoreMissing() throws Exception { String modelInputField = "inputs"; @@ -1993,6 +2019,7 @@ public void onFailure(Exception e) { * expect to return document with second round prediction results * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneTwoRoundsPredictionsOneExceptionIgnoreFailure() throws Exception { String modelInputField = "inputs"; @@ -2086,6 +2113,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseNoMappingSuccess() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", @@ -2163,6 +2191,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseEmptyMappingSuccess() throws Exception { List> inputMap = new ArrayList<>(); Map input = new HashMap<>(); @@ -2243,6 +2272,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseListOfEmbeddingsSuccess() throws Exception { /** * sample response before inference @@ -2330,6 +2360,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOverrideSameField() throws Exception { /** * sample response before inference @@ -2416,6 +2447,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOverrideSameFieldFalse() throws Exception { /** * sample response before inference @@ -2504,6 +2536,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseListOfEmbeddingsMissingOneInputIgnoreMissingSuccess() throws Exception { /** * sample response before inference @@ -2586,6 +2619,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseListOfEmbeddingsMissingOneInputException() throws Exception { /** * sample response before inference @@ -2670,6 +2704,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseTwoRoundsOfPredictionSuccess() throws Exception { String modelInputField = "inputs"; String modelOutputField = "response"; @@ -2767,6 +2802,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneModelInputMultipleModelOutputs() throws Exception { // one model input String modelInputField = "inputs"; @@ -2853,6 +2889,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponsePredictionException() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", @@ -2898,6 +2935,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponsePredictionFailed() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", @@ -2948,6 +2986,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponsePredictionExceptionIgnoreFailure() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", @@ -2998,6 +3037,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseEmptyHit() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", @@ -3042,6 +3082,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseHitWithNoSource() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", @@ -3087,6 +3128,7 @@ public void onFailure(Exception e) { * Exceptions happen when replaceHits to be one Hit Response * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneMadeOneHitResponseExceptions() throws Exception { String newDocumentField = "text_embedding"; @@ -3154,6 +3196,7 @@ public void onFailure(Exception e) { * Exceptions happen when replaceHits and ignoreFailure return original response * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneMadeOneHitResponseExceptionsIgnoreFailure() throws Exception { String newDocumentField = "text_embedding"; @@ -3220,6 +3263,7 @@ public void onFailure(Exception e) { * Exceptions happen when replaceHits * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneCombinedHitsExceptions() throws Exception { String newDocumentField = "text_embedding"; diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseTests.java new file mode 100644 index 0000000000..a50467e261 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseTests.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.processor; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentGenerator; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.test.OpenSearchTestCase; + +public class MLInferenceSearchResponseTests extends OpenSearchTestCase { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + /** + * Tests the toXContent method of MLInferenceSearchResponse with non-null parameters. + * This test ensures that the method correctly serializes the response when parameters are present. + * + * @throws IOException if an I/O error occurs during the test + */ + @Test + public void testToXContent() throws IOException { + Map params = new HashMap<>(); + params.put("key1", "value1"); + params.put("key2", "value2"); + + SearchResponseSections internal = new SearchResponseSections( + new SearchHits(new SearchHit[0], null, 0), + null, + null, + false, + false, + null, + 0 + ); + MLInferenceSearchResponse searchResponse = new MLInferenceSearchResponse( + params, + internal, + null, + 0, + 0, + 0, + 0, + new ShardSearchFailure[0], + MLInferenceSearchResponse.Clusters.EMPTY + ); + + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + XContentBuilder actual = searchResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(actual); + } + + /** + * Tests the toXContent method of MLInferenceSearchResponse with null parameters. + * This test verifies that the method handles null parameters correctly during serialization. + * + * @throws IOException if an I/O error occurs during the test + */ + @Test + public void testToXContentWithNullParams() throws IOException { + SearchResponseSections internal = new SearchResponseSections( + new SearchHits(new SearchHit[0], null, 0), + null, + null, + false, + false, + null, + 0 + ); + MLInferenceSearchResponse searchResponse = new MLInferenceSearchResponse( + null, + internal, + null, + 0, + 0, + 0, + 0, + new ShardSearchFailure[0], + MLInferenceSearchResponse.Clusters.EMPTY + ); + + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + XContentBuilder actual = searchResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(actual); + } + + /** + * Tests the getParams method of MLInferenceSearchResponse. + * This test ensures that the method correctly returns the parameters that were set during object creation. + */ + @Test + public void testGetParams() { + Map params = new HashMap<>(); + params.put("key1", "value1"); + params.put("key2", "value2"); + + SearchResponseSections internal = new SearchResponseSections( + new SearchHits(new SearchHit[0], null, 0), + null, + null, + false, + false, + null, + 0 + ); + MLInferenceSearchResponse searchResponse = new MLInferenceSearchResponse( + params, + internal, + null, + 0, + 0, + 0, + 0, + new ShardSearchFailure[0], + MLInferenceSearchResponse.Clusters.EMPTY + ); + + assertEquals(params, searchResponse.getParams()); + } + + /** + * Tests the setParams method of MLInferenceSearchResponse. + * This test verifies that the method correctly updates the parameters of the response object. + */ + @Test + public void testSetParams() { + SearchResponseSections internal = new SearchResponseSections( + new SearchHits(new SearchHit[0], null, 0), + null, + null, + false, + false, + null, + 0 + ); + MLInferenceSearchResponse searchResponse = new MLInferenceSearchResponse( + null, + internal, + null, + 0, + 0, + 0, + 0, + new ShardSearchFailure[0], + MLInferenceSearchResponse.Clusters.EMPTY + ); + + Map newParams = new HashMap<>(); + newParams.put("key3", "value3"); + searchResponse.setParams(newParams); + + assertEquals(newParams, searchResponse.getParams()); + } +}