diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java index 592d5d4fad..f3da7c77bc 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java @@ -21,9 +21,11 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -33,16 +35,19 @@ import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaType; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.utils.MapUtils; +import org.opensearch.ml.utils.SearchResponseUtil; import org.opensearch.search.SearchHit; import org.opensearch.search.pipeline.AbstractProcessor; import org.opensearch.search.pipeline.PipelineProcessingContext; @@ -125,9 +130,15 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp /** * Processes the search response asynchronously by rewriting the documents with the inference results. * - * @param request the search request - * @param response the search response - * @param responseContext the pipeline processing context + * By default, it processes multiple documents in a single prediction through the rewriteResponseDocuments method. + * However, when processing one document per inference, it separates the N-hits search response into N one-hit search responses, + * executes the same rewriteResponseDocument method for each one-hit search response, + * and after receiving N one-hit search responses with inference results, + * it combines them back into a single N-hits search response. + * + * @param request the search request + * @param response the search response + * @param responseContext the pipeline processing context * @param responseListener the listener to be notified when the response is processed */ @Override @@ -144,20 +155,130 @@ public void processResponseAsync( responseListener.onResponse(response); return; } - rewriteResponseDocuments(response, responseListener); + + // if many to one, run rewriteResponseDocuments + if (!oneToOne) { + rewriteResponseDocuments(response, responseListener); + } else { + // if one to one, make one hit search response and run rewriteResponseDocuments + GroupedActionListener combineResponseListener = getCombineResponseGroupedActionListener( + response, + responseListener, + hits + ); + AtomicBoolean isOneHitListenerFailed = new AtomicBoolean(false); + ; + for (SearchHit hit : hits) { + SearchHit[] newHits = new SearchHit[1]; + newHits[0] = hit; + SearchResponse oneHitResponse = SearchResponseUtil.replaceHits(newHits, response); + ActionListener oneHitListener = getOneHitListener(combineResponseListener, isOneHitListenerFailed); + rewriteResponseDocuments(oneHitResponse, oneHitListener); + // if any OneHitListener failure, try stop the rest of the predictions + if (isOneHitListenerFailed.get()) { + break; + } + } + } + } catch (Exception e) { if (ignoreFailure) { responseListener.onResponse(response); } else { responseListener.onFailure(e); + if (e instanceof OpenSearchStatusException) { + responseListener + .onFailure( + new OpenSearchStatusException( + "Failed to process response: " + e.getMessage(), + RestStatus.fromCode(((OpenSearchStatusException) e).status().getStatus()) + ) + ); + } else if (e instanceof MLResourceNotFoundException) { + responseListener + .onFailure(new OpenSearchStatusException("Failed to process response: " + e.getMessage(), RestStatus.NOT_FOUND)); + } else { + responseListener.onFailure(e); + } } } } + /** + * Creates an ActionListener for a single SearchResponse that delegates its + * onResponse and onFailure callbacks to a GroupedActionListener. + * + * @param combineResponseListener The GroupedActionListener to which the + * onResponse and onFailure callbacks will be + * delegated. + * @param isOneHitListenerFailed + * @return An ActionListener that delegates its callbacks to the provided + * GroupedActionListener. + */ + private static ActionListener getOneHitListener( + GroupedActionListener combineResponseListener, + AtomicBoolean isOneHitListenerFailed + ) { + ActionListener oneHitListener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse response) { + combineResponseListener.onResponse(response); + } + + @Override + public void onFailure(Exception e) { + // if any OneHitListener failure, try stop the rest of the predictions and return + isOneHitListenerFailed.compareAndSet(false, true); + combineResponseListener.onFailure(e); + } + }; + return oneHitListener; + } + + /** + * Creates a GroupedActionListener that combines the SearchResponses from individual hits + * and constructs a new SearchResponse with the combined hits. + * + * @param response The original SearchResponse containing the hits to be processed. + * @param responseListener The ActionListener to be notified with the combined SearchResponse. + * @param hits The array of SearchHits to be processed. + * @return A GroupedActionListener that combines the SearchResponses and constructs a new SearchResponse. + */ + private GroupedActionListener getCombineResponseGroupedActionListener( + SearchResponse response, + ActionListener responseListener, + SearchHit[] hits + ) { + GroupedActionListener combineResponseListener = new GroupedActionListener<>(new ActionListener<>() { + @Override + public void onResponse(Collection responseMapCollection) { + SearchHit[] combinedHits = new SearchHit[hits.length]; + int i = 0; + for (SearchResponse OneHitResponseAfterInference : responseMapCollection) { + SearchHit[] hitsAfterInference = OneHitResponseAfterInference.getHits().getHits(); + combinedHits[i] = hitsAfterInference[0]; + i++; + } + SearchResponse oneToOneInferenceSearchResponse = SearchResponseUtil.replaceHits(combinedHits, response); + responseListener.onResponse(oneToOneInferenceSearchResponse); + } + + @Override + public void onFailure(Exception e) { + if (ignoreFailure) { + responseListener.onResponse(response); + } else { + responseListener.onFailure(e); + } + } + }, hits.length); + return combineResponseListener; + } + /** * Rewrite the documents in the search response with the inference results. * - * @param response the search response + * @param response the search response * @param responseListener the listener to be notified when the response is processed * @throws IOException if an I/O error occurs during the rewriting process */ @@ -168,27 +289,23 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener hitCountInPredictions = new HashMap<>(); - if (!oneToOne) { - ActionListener> rewriteResponseListener = createRewriteResponseListenerManyToOne( - response, - responseListener, - processInputMap, - processOutputMap, - hitCountInPredictions - ); - GroupedActionListener> batchPredictionListener = createBatchPredictionListenerManyToOne( - rewriteResponseListener, - inputMapSize - ); - SearchHit[] hits = response.getHits().getHits(); - for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) { - processPredictionsManyToOne(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions); - } - } else { - responseListener.onFailure(new IllegalArgumentException("one to one prediction is not supported yet.")); - } + ActionListener> rewriteResponseListener = createRewriteResponseListener( + response, + responseListener, + processInputMap, + processOutputMap, + hitCountInPredictions + ); + GroupedActionListener> batchPredictionListener = createBatchPredictionListener( + rewriteResponseListener, + inputMapSize + ); + SearchHit[] hits = response.getHits().getHits(); + for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) { + processPredictions(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions); + } } /** @@ -201,7 +318,7 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener> processInputMap, int inputMapIndex, @@ -242,7 +359,7 @@ private void processPredictionsManyToOne( Object documentValue = JsonPath.using(configuration).parse(documentJson).read(documentFieldName); if (documentValue != null) { // when not existed in the map, add into the modelInputParameters map - updateModelInputParametersManyToOne(modelInputParameters, modelInputFieldName, documentValue); + updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue); } } } else { // when document does not contain the documentFieldName, skip when ignoreMissing @@ -263,8 +380,7 @@ private void processPredictionsManyToOne( Object documentValue = entry.getValue(); // when not existed in the map, add into the modelInputParameters map - updateModelInputParametersManyToOne(modelInputParameters, modelInputFieldName, documentValue); - + updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue); } } } @@ -306,18 +422,28 @@ public void onFailure(Exception e) { }); } - private void updateModelInputParametersManyToOne( - Map modelInputParameters, - String modelInputFieldName, - Object documentValue - ) { - if (!modelInputParameters.containsKey(modelInputFieldName)) { - List documentValueList = new ArrayList<>(); - documentValueList.add(documentValue); - modelInputParameters.put(modelInputFieldName, documentValueList); + /** + * Updates the model input parameters map with the given document value. + * If the setting is one-to-one, + * simply put the document value in the map + * If the setting is many-to-one, + * create a new list and add the document value + * @param modelInputParameters The map containing the model input parameters. + * @param modelInputFieldName The name of the model input field. + * @param documentValue The value from the document that needs to be added to the model input parameters. + */ + private void updateModelInputParameters(Map modelInputParameters, String modelInputFieldName, Object documentValue) { + if (!this.oneToOne) { + if (!modelInputParameters.containsKey(modelInputFieldName)) { + List documentValueList = new ArrayList<>(); + documentValueList.add(documentValue); + modelInputParameters.put(modelInputFieldName, documentValueList); + } else { + List valueList = ((List) modelInputParameters.get(modelInputFieldName)); + valueList.add(documentValue); + } } else { - List valueList = ((List) modelInputParameters.get(modelInputFieldName)); - valueList.add(documentValue); + modelInputParameters.put(modelInputFieldName, documentValue); } } @@ -328,7 +454,7 @@ private void updateModelInputParametersManyToOne( * @param inputMapSize the size of the input map * @return a grouped action listener for batch predictions */ - private GroupedActionListener> createBatchPredictionListenerManyToOne( + private GroupedActionListener> createBatchPredictionListener( ActionListener> rewriteResponseListener, int inputMapSize ) { @@ -353,14 +479,14 @@ public void onFailure(Exception e) { /** * Creates an action listener for rewriting the response with the inference results. * - * @param response the search response - * @param responseListener the listener to be notified when the response is processed - * @param processInputMap the list of input mappings - * @param processOutputMap the list of output mappings - * @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction + * @param response the search response + * @param responseListener the listener to be notified when the response is processed + * @param processInputMap the list of input mappings + * @param processOutputMap the list of output mappings + * @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction * @return an action listener for rewriting the response with the inference results */ - private ActionListener> createRewriteResponseListenerManyToOne( + private ActionListener> createRewriteResponseListener( SearchResponse response, ActionListener responseListener, List> processInputMap, @@ -392,7 +518,7 @@ public void onResponse(Map multipleMLOutputs) { Map outputMapping = getDefaultOutputMapping(mappingIndex, processOutputMap); boolean isModelInputMissing = false; - if (processInputMap != null) { + if (processInputMap != null && !processInputMap.isEmpty()) { isModelInputMissing = checkIsModelInputMissing(document, inputMapping); } if (!isModelInputMissing) { @@ -499,10 +625,10 @@ private boolean checkIsModelInputMissing(Map document, MapIf the processOutputMap is not null and not empty, the mapping at the specified mappingIndex * is returned. * - * @param mappingIndex the index of the mapping to retrieve from the processOutputMap + * @param mappingIndex the index of the mapping to retrieve from the processOutputMap * @param processOutputMap the list of output mappings, can be null or empty * @return a Map containing the output mapping, either the default mapping or the mapping at the - * specified index + * specified index */ private static Map getDefaultOutputMapping(Integer mappingIndex, List> processOutputMap) { Map outputMapping; @@ -524,11 +650,11 @@ private static Map getDefaultOutputMapping(Integer mappingIndex, *

If the processInputMap is not null and not empty, the mapping at the specified mappingIndex * is returned. * - * @param sourceAsMap the source map containing the input data - * @param mappingIndex the index of the mapping to retrieve from the processInputMap + * @param sourceAsMap the source map containing the input data + * @param mappingIndex the index of the mapping to retrieve from the processInputMap * @param processInputMap the list of input mappings, can be null or empty * @return a Map containing the input mapping, either the mapping extracted from sourceAsMap or - * the mapping at the specified index + * the mapping at the specified index */ private static Map getDefaultInputMapping( Map sourceAsMap, diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index 7d38597751..850e466ba6 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -7,12 +7,21 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID; import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP; -import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.*; +import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.DEFAULT_MAX_PREDICTION_TASKS; +import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.DEFAULT_OUTPUT_FIELD_NAME; +import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.FULL_RESPONSE_PATH; +import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.FUNCTION_NAME; +import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.MODEL_INPUT; +import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.TYPE; import java.util.ArrayList; import java.util.Arrays; @@ -20,23 +29,28 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.client.Client; import org.opensearch.common.document.DocumentField; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -52,10 +66,8 @@ public class MLInferenceSearchResponseProcessorTests extends AbstractBuilderTestCase { @Mock private Client client; - @Mock private PipelineProcessingContext responseContext; - static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() ); @@ -82,44 +94,1516 @@ public void testProcessResponseException() throws Exception { null, false, false, - false + false + ); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + try { + responseProcessor.processResponse(request, response); + + } catch (Exception e) { + assertEquals("ML inference search response processor make asynchronous calls and does not call processRequest", e.getMessage()); + } + } + + /** + * Tests the successful processing of a response with a single pair of input and output mappings. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseSuccess() throws Exception { + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + false, + false, + false + ); + + assertEquals(responseProcessor.getType(), TYPE); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), 0.0); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), 1.0); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get("text_embedding"), 2.0); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get("text_embedding"), 3.0); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get("text_embedding"), 4.0); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests create processor with one_to_one is true + * with custom prompt + * with many to one prediction, 5 documents in hits are calling 1 prediction tasks + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseManyToOneWithCustomPrompt() throws Exception { + + String newDocumentField = "context"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + modelConfig + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${input_map.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"prompt\": \"${model_config.prompt}\"}", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "there are 5 values")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals( + newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField).toString(), + "there are 5 values" + ); + assertEquals( + newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField).toString(), + "there are 5 values" + ); + assertEquals( + newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField).toString(), + "there are 5 values" + ); + assertEquals( + newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField).toString(), + "there are 5 values" + ); + assertEquals( + newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField).toString(), + "there are 5 values" + ); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(1)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with no mapping provided + * with one to one prediction, 5 documents in hits are calling 5 prediction tasks + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneWithNoMappings() throws Exception { + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals( + newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}" + ); + assertEquals( + newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}" + ); + assertEquals( + newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}" + ); + assertEquals( + newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}" + ); + assertEquals( + newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}" + ); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(5)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with empty mapping provided + * with one to one prediction, 5 documents in hits are calling 5 prediction tasks + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneWithEmptyMappings() throws Exception { + List> outputMap = new ArrayList<>(); + List> inputMap = new ArrayList<>(); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals( + newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}" + ); + assertEquals( + newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}" + ); + assertEquals( + newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}" + ); + assertEquals( + newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}" + ); + assertEquals( + newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}" + ); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(5)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * with one to one prediction, 5 documents in hits are calling 5 prediction tasks + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneWithOutputMappings() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals( + newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(5)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * with one to one prediction, the only one prediction task onFailure + * expect to run one prediction task + * when there is one document, the combinedResponseListener calls onFailure + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneWithOutputMappingsCombineResponseListenerFail() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(1, true, fieldName); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Prediction Failed")); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + assertEquals("Prediction Failed", e.getMessage()); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(1)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * with one to one prediction, the only one prediction task throw Exceptions + * expect to run one prediction task + * when there is one document, the combinedResponseListener calls onFailure + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneWithOutputMappingsCombineResponseListenerException() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(1, true, fieldName); + when(client.execute(any(), any())).thenThrow(new RuntimeException("Prediction Failed")); + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + assertEquals("Failed to process response: Prediction Failed", e.getMessage()); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(1)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * with one to one prediction, the only one prediction task throw Exceptions + * expect to run one prediction task + * when there is one document and ignoreFailure, should return the original response + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneWithOutputMappingsCombineResponseListenerExceptionIgnoreFailure() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + true, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(1, true, fieldName); + when(client.execute(any(), any())).thenThrow(new RuntimeException("Prediction Failed")); + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits(), response.getHits().getHits()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("error handling not properly"); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(1)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * createRewriteResponseListener throw Exceptions + * expect to run one prediction task + * when there is one document and ignoreFailure, should return the original response + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseCreateRewriteResponseListenerExceptionIgnoreFailure() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + true, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(1, true, fieldName); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + SearchResponse mockResponse = mock(SearchResponse.class); + SearchHits searchHits = response.getHits(); + RuntimeException mockException = new RuntimeException("Mock exception"); + AtomicInteger callCount = new AtomicInteger(0); + ; + when(mockResponse.getHits()).thenAnswer(invocation -> { + + int count = callCount.getAndIncrement(); + + if (count == 2) { + // throw exception when it reaches createRewriteResponseListener + throw mockException; + } else { + return searchHits; + } + }); + + when(mockResponse.getTook()).thenReturn(TimeValue.timeValueNanos(10)); + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertNotNull(newSearchResponse); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("error handling not properly"); + } + + }; + responseProcessor.processResponseAsync(request, mockResponse, responseContext, listener); + verify(client, times(1)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * createRewriteResponseListener throw Exceptions + * expect to run one prediction task + * createRewriteResponseListener should reach on Failure + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseCreateRewriteResponseListenerException() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(1, true, fieldName); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + SearchResponse mockResponse = mock(SearchResponse.class); + SearchHits searchHits = response.getHits(); + RuntimeException mockException = new RuntimeException("Mock exception"); + AtomicInteger callCount = new AtomicInteger(0); + ; + when(mockResponse.getHits()).thenAnswer(invocation -> { + + int count = callCount.getAndIncrement(); + + if (count == 2) { + // throw exception when it reaches createRewriteResponseListener + throw mockException; + } else { + return searchHits; + } + }); + + when(mockResponse.getTook()).thenReturn(TimeValue.timeValueNanos(10)); + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + assertNotNull(e.getMessage()); + } + }; + responseProcessor.processResponseAsync(request, mockResponse, responseContext, listener); + verify(client, times(1)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * test throwing OpenSearchStatusException + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOpenSearchStatusException() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(1, true, fieldName); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + SearchResponse mockResponse = mock(SearchResponse.class); + SearchHits searchHits = response.getHits(); + RuntimeException mockException = new OpenSearchStatusException("Mock exception", RestStatus.BAD_REQUEST); + AtomicInteger callCount = new AtomicInteger(0); + ; + when(mockResponse.getHits()).thenAnswer(invocation -> { + + int count = callCount.getAndIncrement(); + + if (count == 0) { + // throw exception when it reaches processResponseAsync + throw mockException; + } else { + return searchHits; + } + }); + + when(mockResponse.getTook()).thenReturn(TimeValue.timeValueNanos(10)); + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + assertNotNull(e.getMessage()); + } + }; + responseProcessor.processResponseAsync(request, mockResponse, responseContext, listener); + verify(client, times(0)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * test throwing MLResourceNotFoundException + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseMLResourceNotFoundException() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(1, true, fieldName); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + SearchResponse mockResponse = mock(SearchResponse.class); + SearchHits searchHits = response.getHits(); + RuntimeException mockException = new MLResourceNotFoundException("Mock exception"); + AtomicInteger callCount = new AtomicInteger(0); + ; + when(mockResponse.getHits()).thenAnswer(invocation -> { + + int count = callCount.getAndIncrement(); + + if (count == 0) { + // throw exception when it reaches processResponseAsync + throw mockException; + } else { + return searchHits; + } + }); + + when(mockResponse.getTook()).thenReturn(TimeValue.timeValueNanos(10)); + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + assertNotNull(e.getMessage()); + } + }; + responseProcessor.processResponseAsync(request, mockResponse, responseContext, listener); + verify(client, times(0)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * with one to one prediction, the only one prediction task throw exception + * expect to run one prediction task + * when there is one document, the combinedResponseListener calls onFailure + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneWithOutputMappingsIgnoreFailure() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + true, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(1, true, fieldName); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Prediction Failed")); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits(), response.getHits().getHits()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("error handling not properly"); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(1)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * with one to one prediction, the only one prediction task throw exception + * expect to run one prediction task + * when there is one document, the combinedResponseListener calls onFailure + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneWithOutputMappingsMLTaskResponseExceptionIgnoreFailure() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + true, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + MLTaskResponse mockMLTaskResponse = mock(MLTaskResponse.class); + when(mockMLTaskResponse.getOutput()).thenThrow(new RuntimeException("get mlTaskResponse failed.")); + SearchResponse response = getSearchResponse(1, true, fieldName); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(mockMLTaskResponse); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits(), response.getHits().getHits()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("error handling not properly"); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(1)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * with one to one prediction, when one of 5 prediction tasks failed, + * expect to run one prediction task and the rest 4 predictions tasks are not created + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneWithOutputMappingsPredictException() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + when(client.execute(any(), any())).thenThrow(new RuntimeException("Prediction Failed")); + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + + assertEquals("Failed to process response: Prediction Failed", e.getMessage()); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(5)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * with one to one prediction, when one of 5 prediction tasks failed, + * expect to run one prediction task and the rest 4 predictions tasks are not created + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneWithOutputMappingsPredictFail() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Prediction Failed")); + return null; + }).when(client).execute(any(), any(), any()); + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + + assertEquals("Failed to process response: Prediction Failed", e.getMessage()); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(1)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps + * with one to one prediction, prediction tasks throw exception + * ignore Failure is true + * when ignoreFailure, will run all 5 prediction tasks + * then return original response + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneWithOutputMappingsPredictFailIgnoreFailure() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + true, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Prediction Failed")); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits(), response.getHits().getHits()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("error handling not properly"); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(5)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with two rounds predictions for every document + * with one to one prediction, 5 documents in hits are calling 10 prediction tasks + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneTwoRoundsPredictions() throws Exception { + + String modelInputField = "inputs"; + String modelOutputField = "response"; + + // document fields for first round of prediction + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + + // document fields for second round of prediction + String originalDocumentField1 = "image"; + String newDocumentField1 = "image_embedding"; + + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalDocumentField); + inputMap.add(input); + + Map input1 = new HashMap<>(); + input1.put(modelInputField, originalDocumentField1); + inputMap.add(input1); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + Map output2 = new HashMap<>(); + output2.put(newDocumentField1, modelOutputField); + outputMap.add(output2); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + + SearchResponse response = getSearchResponseTwoFields(5, true, originalDocumentField, originalDocumentField1); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals( + newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField1).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField1).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField1).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField1).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField1).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(10)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is true + * with two rounds predictions for every document + * failed in first round prediction when ignoreFailure is false + * expect to throw exception without further processing + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneTwoRoundsPredictionsOneException() throws Exception { + + String modelInputField = "inputs"; + String modelOutputField = "response"; + + // document fields for first round of prediction + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + + // document fields for second round of prediction + String originalDocumentField1 = "image"; + String newDocumentField1 = "image_embedding"; + + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalDocumentField); + inputMap.add(input); + + Map input1 = new HashMap<>(); + input1.put(modelInputField, originalDocumentField1); + inputMap.add(input1); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + Map output2 = new HashMap<>(); + output2.put(newDocumentField1, modelOutputField); + outputMap.add(output2); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true ); + SearchRequest request = getSearchRequest(); - String fieldName = "text"; - SearchResponse response = getSearchResponse(5, true, fieldName); + // create a search response with a typo in the document field + SearchResponse response = getSearchResponseTwoFields(5, true, originalDocumentField + "typo", originalDocumentField1); - try { - responseProcessor.processResponse(request, response); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); - } catch (Exception e) { - assertEquals("ML inference search response processor make asynchronous calls and does not call processRequest", e.getMessage()); - } + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + assertEquals( + "cannot find all required input fields: [text] in hit:{\n" + + " \"_id\" : \"doc 0\",\n" + + " \"_score\" : 0.0,\n" + + " \"_source\" : {\n" + + " \"texttypo\" : \"value 0\",\n" + + " \"image\" : \"value 0\"\n" + + " }\n" + + "}", + e.getMessage() + ); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(0)).execute(any(), any(), any()); } /** - * Tests the successful processing of a response with a single pair of input and output mappings. - * + * Tests create processor with one_to_one is true + * with two rounds predictions for every document + * failed in first round prediction when ignoreMissing is true + * expect to return document with second round prediction results * @throws Exception if an error occurs during the test */ - public void testProcessResponseSuccess() throws Exception { + public void testProcessResponseOneToOneTwoRoundsPredictionsOneExceptionIgnoreMissing() throws Exception { + String modelInputField = "inputs"; + String modelOutputField = "response"; + + // document fields for first round of prediction String originalDocumentField = "text"; String newDocumentField = "text_embedding"; - String modelOutputField = "response"; - MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( - modelOutputField, - modelInputField, - originalDocumentField, - newDocumentField, + + // document fields for second round of prediction + String originalDocumentField1 = "image"; + String newDocumentField1 = "image_embedding"; + + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalDocumentField); + inputMap.add(input); + + Map input1 = new HashMap<>(); + input1.put(modelInputField, originalDocumentField1); + inputMap.add(input1); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + Map output2 = new HashMap<>(); + output2.put(newDocumentField1, modelOutputField); + outputMap.add(output2); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + true, + "remote", false, false, - false + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true ); - assertEquals(responseProcessor.getType(), TYPE); SearchRequest request = getSearchRequest(); - String fieldName = "text"; - SearchResponse response = getSearchResponse(5, true, fieldName); + // create a search response with a typo in the document field + SearchResponse response = getSearchResponseTwoFields(5, true, originalDocumentField + "typo", originalDocumentField1); ModelTensor modelTensor = ModelTensor .builder() @@ -138,41 +1622,88 @@ public void testProcessResponseSuccess() throws Exception { @Override public void onResponse(SearchResponse newSearchResponse) { assertEquals(newSearchResponse.getHits().getHits().length, 5); - assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), 0.0); - assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), 1.0); - assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get("text_embedding"), 2.0); - assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get("text_embedding"), 3.0); - assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get("text_embedding"), 4.0); + assertEquals( + newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField1).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField1).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField1).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField1).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField1).toString(), + "[0.0, 1.0, 2.0, 3.0, 4.0]" + ); } @Override public void onFailure(Exception e) { throw new RuntimeException(e); } - }; + }; responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(10)).execute(any(), any(), any()); } /** - * Tests create processor with many_to_one is false - * + * Tests create processor with one_to_one is true + * with two rounds predictions for every document + * failed in first round prediction when ignoreFailure is true + * expect to return document with second round prediction results * @throws Exception if an error occurs during the test */ - public void testProcessResponseOneToOneException() throws Exception { + public void testProcessResponseOneToOneTwoRoundsPredictionsOneExceptionIgnoreFailure() throws Exception { + + String modelInputField = "inputs"; + String modelOutputField = "response"; + + // document fields for first round of prediction + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + + // document fields for second round of prediction + String originalDocumentField1 = "image"; + String newDocumentField1 = "image_embedding"; + + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalDocumentField); + inputMap.add(input); + + Map input1 = new HashMap<>(); + input1.put(modelInputField, originalDocumentField1); + inputMap.add(input1); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + Map output2 = new HashMap<>(); + output2.put(newDocumentField1, modelOutputField); + outputMap.add(output2); MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", - null, - null, + inputMap, + outputMap, null, DEFAULT_MAX_PREDICTION_TASKS, PROCESSOR_TAG, DESCRIPTION, false, "remote", - true, false, + true, false, "{ \"parameters\": ${ml_inference.parameters} }", client, @@ -181,8 +1712,8 @@ public void testProcessResponseOneToOneException() throws Exception { ); SearchRequest request = getSearchRequest(); - String fieldName = "text"; - SearchResponse response = getSearchResponse(5, true, fieldName); + // create a search response with a typo in the document field + SearchResponse response = getSearchResponseTwoFields(5, true, originalDocumentField + "typo", originalDocumentField1); ModelTensor modelTensor = ModelTensor .builder() @@ -200,17 +1731,22 @@ public void testProcessResponseOneToOneException() throws Exception { ActionListener listener = new ActionListener<>() { @Override public void onResponse(SearchResponse newSearchResponse) { - throw new RuntimeException("error handling not properly"); + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertNull(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField1)); + assertNull(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField1)); + assertNull(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField1)); + assertNull(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField1)); + assertNull(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField1)); } @Override public void onFailure(Exception e) { - assertEquals("one to one prediction is not supported yet.", e.getMessage()); + throw new RuntimeException(e); } }; responseProcessor.processResponseAsync(request, response, responseContext, listener); - + verify(client, times(0)).execute(any(), any(), any()); } /** @@ -462,7 +1998,93 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ - public void testProcessResponseOverrideSameField() throws Exception { + public void testProcessResponseOverrideSameField() throws Exception { + /** + * sample response before inference + * { + * { "text" : "value 0" }, + * { "text" : "value 1" }, + * { "text" : "value 2" }, + * { "text" : "value 3" }, + * { "text" : "value 4" } + * } + * + * sample response after inference + * { "text":[0.1, 0.2]}, + * { "text":[0.2, 0.2]}, + * { "text":[0.3, 0.2]}, + * { "text":[0.4, 0.2]}, + * { "text":[0.5, 0.2]} + */ + + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text"; + String modelOutputField = "response"; + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + true, + false, + false + ); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap( + ImmutableMap + .of( + "response", + Arrays + .asList( + Arrays.asList(0.1, 0.2), + Arrays.asList(0.2, 0.2), + Arrays.asList(0.3, 0.2), + Arrays.asList(0.4, 0.2), + Arrays.asList(0.5, 0.2) + ) + ) + ) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text"), Arrays.asList(0.1, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text"), Arrays.asList(0.2, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get("text"), Arrays.asList(0.3, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get("text"), Arrays.asList(0.4, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get("text"), Arrays.asList(0.5, 0.2)); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the successful processing of a response where the existing document field is skipped. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOverrideSameFieldFalse() throws Exception { /** * sample response before inference * { @@ -490,7 +2112,7 @@ public void testProcessResponseOverrideSameField() throws Exception { modelInputField, originalDocumentField, newDocumentField, - true, + false, false, false ); @@ -528,11 +2150,11 @@ public void testProcessResponseOverrideSameField() throws Exception { @Override public void onResponse(SearchResponse newSearchResponse) { assertEquals(newSearchResponse.getHits().getHits().length, 5); - assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text"), Arrays.asList(0.1, 0.2)); - assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text"), Arrays.asList(0.2, 0.2)); - assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get("text"), Arrays.asList(0.3, 0.2)); - assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get("text"), Arrays.asList(0.4, 0.2)); - assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get("text"), Arrays.asList(0.5, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text"), "value 0"); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text"), "value 1"); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get("text"), "value 2"); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get("text"), "value 3"); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get("text"), "value 4"); } @Override @@ -540,8 +2162,8 @@ public void onFailure(Exception e) { throw new RuntimeException(e); } }; - responseProcessor.processResponseAsync(request, response, responseContext, listener); + } /** @@ -627,7 +2249,7 @@ public void onFailure(Exception e) { } /** - * Tests the case where one input field is missing, and an exception is expected + * Tests the case where one input field is missing, and an IllegalArgumentException is expected * when the `ignoreMissing` flag is set to false. * * @throws Exception if an error occurs during the test @@ -923,6 +2545,51 @@ public void testProcessResponsePredictionException() throws Exception { String fieldName = "text"; SearchResponse response = getSearchResponse(5, true, fieldName); + when(client.execute(any(), any())).thenThrow(new RuntimeException("Prediction Failed")); + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly."); + } + + @Override + public void onFailure(Exception e) { + assertEquals("Prediction Failed", e.getMessage()); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the case where an onFailure occurs during prediction, and the `ignoreFailure` flag is set to false. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponsePredictionFailed() throws Exception { + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); actionListener.onFailure(new RuntimeException("Prediction Failed")); @@ -1038,6 +2705,252 @@ public void onFailure(Exception e) { responseProcessor.processResponseAsync(request, response, responseContext, listener); } + /** + * Tests the case where there are no hits in the search response. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseHitWithNoSource() throws Exception { + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponseNoSource(0, true, fieldName); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(response, newSearchResponse); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps in one to one prediction + * Exceptions happen when replaceHits to be one Hit Response + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneMadeOneHitResponseExceptions() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly."); + } + + @Override + public void onFailure(Exception e) { + assertEquals("Mock exception", e.getMessage()); + } + }; + + SearchResponse mockResponse = mock(SearchResponse.class); + SearchHits searchHits = response.getHits(); + when(mockResponse.getHits()).thenReturn(searchHits); + RuntimeException mockException = new RuntimeException("Mock exception"); + AtomicInteger callCount = new AtomicInteger(0); + when(mockResponse.getAggregations()).thenAnswer(invocation -> { + int count = callCount.getAndIncrement(); + if (count < 6) { + return null; + } else { + throw mockException; + } + }); + when(mockResponse.getTook()).thenReturn(TimeValue.timeValueNanos(10)); + responseProcessor.processResponseAsync(request, mockResponse, responseContext, listener); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps in one to one prediction + * Exceptions happen when replaceHits and ignoreFailure return original response + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneMadeOneHitResponseExceptionsIgnoreFailure() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + true, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits(), response.getHits().getHits()); + assertNull(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField)); + assertNull(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField)); + assertNull(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField)); + assertNull(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField)); + assertNull(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField)); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("error handling not properly."); + } + + }; + + SearchResponse mockResponse = mock(SearchResponse.class); + SearchHits searchHits = response.getHits(); + when(mockResponse.getHits()).thenReturn(searchHits); + RuntimeException mockException = new RuntimeException("Mock exception"); + when(mockResponse.getAggregations()).thenThrow(mockException); + + responseProcessor.processResponseAsync(request, mockResponse, responseContext, listener); + } + + /** + * Tests create processor with one_to_one is true + * with output_maps in one to one prediction + * Exceptions happen when replaceHits + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneCombinedHitsExceptions() throws Exception { + + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly."); + } + + @Override + public void onFailure(Exception e) { + assertEquals("Failed to process response: Mock exception", e.getMessage()); + } + }; + + SearchResponse mockResponse = mock(SearchResponse.class); + SearchHits searchHits = response.getHits(); + when(mockResponse.getHits()).thenReturn(searchHits); + RuntimeException mockException = new RuntimeException("Mock exception"); + AtomicInteger callCount = new AtomicInteger(0); + when(mockResponse.getAggregations()).thenAnswer(invocation -> { + int count = callCount.getAndIncrement(); + // every time getting one Hit Response will call get Aggregations two times , + // the 12th time is used for combine hits + if (count < 12) { + return null; + } else { + throw mockException; + } + }); + when(mockResponse.getTook()).thenReturn(TimeValue.timeValueNanos(10)); + responseProcessor.processResponseAsync(request, mockResponse, responseContext, listener); + } + private static SearchRequest getSearchRequest() { QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); @@ -1115,6 +3028,23 @@ private SearchResponse getSearchResponse(int size, boolean includeMapping, Strin return new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); } + private SearchResponse getSearchResponseNoSource(int size, boolean includeMapping, String fieldName) { + SearchHit[] hits = new SearchHit[size]; + for (int i = 0; i < size; i++) { + Map searchHitFields = new HashMap<>(); + if (includeMapping) { + searchHitFields.put(fieldName, new DocumentField("value " + i, Collections.emptyList())); + } + searchHitFields.put(fieldName, new DocumentField("value " + i, Collections.emptyList())); + hits[i] = new SearchHit(i, "doc " + i, null, Collections.emptyMap()); + hits[i].sourceRef(null); + hits[i].score(i); + } + SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2L, TotalHits.Relation.EQUAL_TO), size); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + return new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); + } + private SearchResponse getSearchResponseMissingField(int size, boolean includeMapping, String fieldName) { SearchHit[] hits = new SearchHit[size]; for (int i = 0; i < size; i++) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java index c67c2138c8..6c8f54110f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java @@ -194,7 +194,6 @@ public void testMLInferenceProcessorRemoteModelObjectField() throws Exception { createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); Map response = searchWithPipeline(client(), index_name, pipelineName, query); - System.out.println(response); Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536"); Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "sunny"); Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "happy"); @@ -205,6 +204,59 @@ public void testMLInferenceProcessorRemoteModelObjectField() throws Exception { Assert.assertEquals((Double) embeddingList.get(1), -0.0071890163, 0.005); } + /** + * Tests the MLInferenceSearchResponseProcessor with a remote model and a string field as input. + * It creates a search pipeline with the processor configured to use the remote model, + * runs one to one prediction by sending one document to one prediction + * performs a search using the pipeline, and verifies the inference results. + * + * @throws Exception if any error occurs during the test + */ + public void testMLInferenceProcessorRemoteModelStringField() throws Exception { + String createPipelineRequestBody = "{\n" + + " \"response_processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"tag\": \"ml_inference\",\n" + + " \"description\": \"This processor is going to run ml inference during search request\",\n" + + " \"model_id\": \"" + + this.bedrockEmbeddingModelId + + "\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"input\": \"weather\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"weather_embedding\": \"$.embedding\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": false,\n" + + " \"ignore_failure\": false,\n" + + " \"one_to_one\": true\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + String query = "{\"query\":{\"term\":{\"diary\":{\"value\":\"happy\"}}}}"; + + String index_name = "daily_index"; + String pipelineName = "weather_embedding_pipeline"; + createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); + + Map response = searchWithPipeline(client(), index_name, pipelineName, query); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "sunny"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "happy"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"), "first day at school"); + List embeddingList = (List) JsonPath.parse(response).read("$.hits.hits[0]._source.weather_embedding"); + Assert.assertEquals(embeddingList.size(), 1536); + Assert.assertEquals((Double) embeddingList.get(0), 0.734375, 0.005); + Assert.assertEquals((Double) embeddingList.get(1), 0.87109375, 0.005); + } + /** * Tests the MLInferenceSearchResponseProcessor with a remote model and a nested list field as input. * It creates a search pipeline with the processor configured to use the remote model, @@ -250,7 +302,6 @@ public void testMLInferenceProcessorRemoteModelNestedListField() throws Exceptio createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); Map response = searchWithPipeline(client(), index_name, pipelineName, query); - System.out.println(response); Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536"); Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "sunny"); Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "happy"); @@ -322,7 +373,6 @@ public void testMLInferenceProcessorLocalModel() throws Exception { String query = "{\"query\":{\"term\":{\"diary_embedding_size\":{\"value\":\"768\"}}}}"; Map response = searchWithPipeline(client(), index_name, pipelineName, query); - System.out.println(response); Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "768"); Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "sunny"); Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "bored");