diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java index b59dbbb86b..56e98474c7 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java @@ -14,7 +14,14 @@ import static org.opensearch.ml.processor.MLInferenceIngestProcessor.OVERRIDE; import java.io.IOException; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.logging.log4j.LogManager; @@ -156,18 +163,27 @@ public void processResponseAsync( // if many to one, run rewriteResponseDocuments if (!oneToOne) { // use MLInferenceSearchResponseProcessor to allow writing to extension - MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse( - null, - response.getInternalResponse(), - response.getScrollId(), - response.getTotalShards(), - response.getSuccessfulShards(), - response.getSkippedShards(), - response.getSuccessfulShards(), - response.getShardFailures(), - response.getClusters() - ); - rewriteResponseDocuments(mLInferenceSearchResponse, responseListener); + // check whether the search response is of type MLInferenceSearchResponse + // if not, create a new MLInferenceSearchResponse + MLInferenceSearchResponse mlInferenceSearchResponse; + + if (response instanceof MLInferenceSearchResponse) { + mlInferenceSearchResponse = (MLInferenceSearchResponse) response; + } else { + mlInferenceSearchResponse = new MLInferenceSearchResponse( + null, + response.getInternalResponse(), + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), 
+ response.getSkippedShards(), + response.getSuccessfulShards(), + response.getShardFailures(), + response.getClusters() + ); + } + + rewriteResponseDocuments(mlInferenceSearchResponse, responseListener); } else { // if one to one, make one hit search response and run rewriteResponseDocuments GroupedActionListener combineResponseListener = getCombineResponseGroupedActionListener( diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index ad6f1db493..f462408943 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -3327,6 +3327,246 @@ public void onFailure(Exception e) { responseProcessor.processResponseAsync(request, mockResponse, responseContext, listener); } + /** + * Tests the processResponseAsync method when the input is a regular SearchResponse. + * + * This test verifies that when a regular SearchResponse is passed to the method, + * it attempts to create a new MLInferenceSearchResponse object. 
+ */ + @Test + public void testProcessResponseAsync_WithRegularSearchResponse() { + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + + SearchResponse response = getSearchResponse(5, true, originalDocumentField); + Map params = new HashMap<>(); + params.put("llm_response", "answer"); + MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse( + params, + response.getInternalResponse(), + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), + response.getSkippedShards(), + response.getSuccessfulShards(), + response.getShardFailures(), + response.getClusters() + ); + + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + false, + false, + false + ); + SearchRequest request = getSearchRequest(); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + MLInferenceSearchResponse responseAfterProcessor = (MLInferenceSearchResponse) newSearchResponse; + assertEquals(responseAfterProcessor.getHits().getHits().length, 5); + 
assertEquals(responseAfterProcessor.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), 0.0); + assertEquals(responseAfterProcessor.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), 1.0); + assertEquals(responseAfterProcessor.getHits().getHits()[2].getSourceAsMap().get("text_embedding"), 2.0); + assertEquals(responseAfterProcessor.getHits().getHits()[3].getSourceAsMap().get("text_embedding"), 3.0); + assertEquals(responseAfterProcessor.getHits().getHits()[4].getSourceAsMap().get("text_embedding"), 4.0); + assertEquals(responseAfterProcessor.getParams(), params); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, mLInferenceSearchResponse, responseContext, listener); + + } + + /** + * Tests the processResponseAsync method when the input is already an MLInferenceSearchResponse. + * + * This test verifies that when an MLInferenceSearchResponse is passed to the method, + * its params are carried over to the processed response. + */ + @Test + public void testProcessResponseAsync_WithMLInferenceSearchResponse() { + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + + SearchResponse response = getSearchResponse(5, true, originalDocumentField); + Map params = new HashMap<>(); + params.put("llm_response", "answer"); + MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse( + params, + response.getInternalResponse(), + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), + response.getSkippedShards(), + response.getSuccessfulShards(), + response.getShardFailures(), + response.getClusters() + ); + + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + false, 
+ false, + false + ); + SearchRequest request = getSearchRequest(); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + MLInferenceSearchResponse responseAfterProcessor = (MLInferenceSearchResponse) newSearchResponse; + assertEquals(responseAfterProcessor.getHits().getHits().length, 5); + assertEquals(responseAfterProcessor.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), 0.0); + assertEquals(responseAfterProcessor.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), 1.0); + assertEquals(responseAfterProcessor.getHits().getHits()[2].getSourceAsMap().get("text_embedding"), 2.0); + assertEquals(responseAfterProcessor.getHits().getHits()[3].getSourceAsMap().get("text_embedding"), 3.0); + assertEquals(responseAfterProcessor.getHits().getHits()[4].getSourceAsMap().get("text_embedding"), 4.0); + assertEquals(responseAfterProcessor.getParams(), params); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, mLInferenceSearchResponse, responseContext, listener); + + } + + /** + * Tests the processResponseAsync method when the input is already an MLInferenceSearchResponse. 
+ * + * This test verifies that when an MLInferenceSearchResponse is passed to the method, + * the existing params are carried over and new params are added. + */ + @Test + public void testProcessResponseAsync_WriteExtensionToMLInferenceSearchResponse() { + String documentField = "text"; + String modelInputField = "context"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, documentField); + inputMap.add(input); + + String newDocumentField = "ext.ml_inference.summary"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + modelConfig + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. 
\\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + SearchResponse response = getSearchResponse(5, true, documentField); + Map params = new HashMap<>(); + params.put("llm_response", "answer"); + MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse( + params, + response.getInternalResponse(), + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), + response.getSkippedShards(), + response.getSuccessfulShards(), + response.getShardFailures(), + response.getClusters() + ); + + SearchRequest request = getSearchRequest(); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "there is 1 value")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + MLInferenceSearchResponse responseAfterProcessor = (MLInferenceSearchResponse) newSearchResponse; + assertEquals(responseAfterProcessor.getHits().getHits().length, 5); + Map newParams = new HashMap<>(); + newParams.put("llm_response", "answer"); + newParams.put("summary", "there is 1 value"); + 
assertEquals(responseAfterProcessor.getParams(), newParams); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, mLInferenceSearchResponse, responseContext, listener); + + } + private static SearchRequest getSearchRequest() { QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);