From 8381354fc77ffe74a07f0aadcca21fd4c73f1a63 Mon Sep 17 00:00:00 2001 From: Mingshi Liu Date: Tue, 1 Oct 2024 23:17:07 -0700 Subject: [PATCH] add MLInferenceSearchResponse Signed-off-by: Mingshi Liu --- .../processor/MLInferenceSearchResponse.java | 57 ++++++++ .../MLInferenceSearchResponseProcessor.java | 71 ++++++++-- ...InferenceSearchResponseProcessorTests.java | 130 +++++++++++++++++- 3 files changed, 239 insertions(+), 19 deletions(-) create mode 100644 plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponse.java diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponse.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponse.java new file mode 100644 index 0000000000..1fb3b1471e --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponse.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.processor; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.core.xcontent.XContentBuilder; + +public class MLInferenceSearchResponse extends SearchResponse { + private static final String EXT_SECTION_NAME = "ext"; + + private Map params; + + public MLInferenceSearchResponse( + Map params, + SearchResponseSections internalResponse, + String scrollId, + int totalShards, + int successfulShards, + int skippedShards, + long tookInMillis, + ShardSearchFailure[] shardFailures, + Clusters clusters + ) { + super(internalResponse, scrollId, totalShards, successfulShards, skippedShards, tookInMillis, shardFailures, clusters); + this.params = params; + } + + public void setParams(Map params) { + this.params = params; + } + + public Map getParams() { + return this.params; + } + + @Override + public 
XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + innerToXContent(builder, params); + + if (this.params != null) { + builder.startObject(EXT_SECTION_NAME); + builder.field(MLInferenceSearchResponseProcessor.TYPE, this.params); + + builder.endObject(); + } + builder.endObject(); + return builder; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java index 2164877b9f..1574195153 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java @@ -84,6 +84,7 @@ public class MLInferenceSearchResponseProcessor extends AbstractProcessor implem // it can be overwritten using max_prediction_tasks when creating processor public static final int DEFAULT_MAX_PREDICTION_TASKS = 10; public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results"; + public static final String EXTENSION_PREFIX = "ext.ml_inference"; protected MLInferenceSearchResponseProcessor( String modelId, @@ -158,7 +159,19 @@ public void processResponseAsync( // if many to one, run rewriteResponseDocuments if (!oneToOne) { - rewriteResponseDocuments(response, responseListener); + // use MLInferenceSearchResponseProcessor to allow writing to extension + MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse( + null, + response.getInternalResponse(), + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), + response.getSkippedShards(), + response.getTook().millis(), + response.getShardFailures(), + response.getClusters() + ); + rewriteResponseDocuments(mLInferenceSearchResponse, responseListener); } else { // if one to one, make one hit search response and run rewriteResponseDocuments
GroupedActionListener combineResponseListener = getCombineResponseGroupedActionListener( @@ -545,22 +558,37 @@ public void onResponse(Map multipleMLOutputs) { } else { modelOutputValuePerDoc = modelOutputValue; } - - if (sourceAsMap.containsKey(newDocumentFieldName)) { - if (override) { - sourceAsMapWithInference.remove(newDocumentFieldName); - sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); + // writing to search response extension + if (newDocumentFieldName.startsWith(EXTENSION_PREFIX)) { + Map params = ((MLInferenceSearchResponse) response).getParams(); + String paramsName = newDocumentFieldName.replaceFirst(EXTENSION_PREFIX + ".", ""); + + if (params != null) { + params.put(paramsName, modelOutputValuePerDoc); + ((MLInferenceSearchResponse) response).setParams(params); } else { - logger - .debug( - "{} already exists in the search response hit. Skip processing this field.", - newDocumentFieldName - ); - // TODO when the response has the same field name, should it throw exception? currently, - // ingest processor quietly skip it + Map newParams = new HashMap<>(); + newParams.put(paramsName, modelOutputValuePerDoc); + ((MLInferenceSearchResponse) response).setParams(newParams); } } else { - sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); + // writing to search response hits + if (sourceAsMap.containsKey(newDocumentFieldName)) { + if (override) { + sourceAsMapWithInference.remove(newDocumentFieldName); + sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); + } else { + logger + .debug( + "{} already exists in the search response hit. Skip processing this field.", + newDocumentFieldName + ); + // TODO when the response has the same field name, should it throw exception? 
currently, + // ingest processor quietly skip it + } + } else { + sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); + } } } } @@ -774,6 +802,21 @@ public MLInferenceSearchResponseProcessor create( + ". Please adjust mappings." ); } + boolean writeToSearchExtension = false; + + if (outputMaps != null) { + for (Map outputMap : outputMaps) { + for (String key : outputMap.keySet()) { + if (key.startsWith(EXTENSION_PREFIX)) { + writeToSearchExtension = true; + break; + } + } + } + } + if (writeToSearchExtension && oneToOne) { + throw new IllegalArgumentException("Writing model response to search extension does not support when one_to_one is true."); + } return new MLInferenceSearchResponseProcessor( modelId, diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index 8f04cab9d4..142fc6d995 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -21,6 +21,7 @@ import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.FULL_RESPONSE_PATH; import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.FUNCTION_NAME; import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.MODEL_INPUT; +import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.ONE_TO_ONE; import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.TYPE; import java.util.ArrayList; @@ -60,6 +61,7 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.pipeline.PipelineProcessingContext; import
org.opensearch.test.AbstractBuilderTestCase; @@ -503,6 +505,81 @@ public void onFailure(Exception e) { verify(client, times(1)).execute(any(), any(), any()); } + /** + * Tests the successful processing of a response with a single pair of input and output mappings. + * read the query text into model config + * with query extensions + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseSuccessWriteToExt() throws Exception { + String documentField = "text"; + String modelInputField = "context"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, documentField); + inputMap.add(input); + + String newDocumentField = "ext.ml_inference.llm_response"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + modelConfig + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. 
\\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "there is 1 value")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(1)).execute(any(), any(), any()); + } + /** * Tests create processor with one_to_one is true * with no mapping provided @@ -978,14 +1055,18 @@ public void testProcessResponseCreateRewriteResponseListenerExceptionIgnoreFailu SearchResponse mockResponse = mock(SearchResponse.class); SearchHits searchHits = response.getHits(); + + InternalSearchResponse internalSearchResponse = new InternalSearchResponse(searchHits, 
null, null, null, false, null, 1); + when(mockResponse.getInternalResponse()).thenReturn(internalSearchResponse); + RuntimeException mockException = new RuntimeException("Mock exception"); AtomicInteger callCount = new AtomicInteger(0); - ; + when(mockResponse.getHits()).thenAnswer(invocation -> { int count = callCount.getAndIncrement(); - if (count == 2) { + if (count == 6) { // throw exception when it reaches createRewriteResponseListener throw mockException; } else { @@ -1011,7 +1092,7 @@ public void onFailure(Exception e) { } /** - * Tests create processor with one_to_one is true + * Tests create processor with one_to_one is false * with output_maps * createRewriteResponseListener throw Exceptions * expect to run one prediction task @@ -1066,7 +1147,10 @@ public void testProcessResponseCreateRewriteResponseListenerException() throws E SearchHits searchHits = response.getHits(); RuntimeException mockException = new RuntimeException("Mock exception"); AtomicInteger callCount = new AtomicInteger(0); - ; + + InternalSearchResponse internalSearchResponse = new InternalSearchResponse(searchHits, null, null, null, false, null, 1); + when(mockResponse.getInternalResponse()).thenReturn(internalSearchResponse); + when(mockResponse.getHits()).thenAnswer(invocation -> { int count = callCount.getAndIncrement(); @@ -3538,7 +3622,7 @@ public void testOutputMapsExceedInputMaps() throws Exception { output2.put("hashtag_embedding", "response"); outputMap.add(output2); Map output3 = new HashMap<>(); - output2.put("hashtvg_embedding", "response"); + output3.put("hashtvg_embedding", "response"); outputMap.add(output3); config.put(OUTPUT_MAP, outputMap); config.put(MAX_PREDICTION_TASKS, 2); @@ -3587,4 +3671,40 @@ public void testCreateOptionalFields() throws Exception { assertEquals(MLInferenceSearchResponseProcessor.getTag(), processorTag); assertEquals(MLInferenceSearchResponseProcessor.getType(), MLInferenceSearchResponseProcessor.TYPE); } + + /** + * Tests the case where output 
map try to write to extension and one to one inference is true + * and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testWriteToExtensionAndOneToOne() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model2"); + List> inputMap = new ArrayList<>(); + Map input0 = new HashMap<>(); + input0.put("inputs", "text"); + inputMap.add(input0); + Map input1 = new HashMap<>(); + input1.put("inputs", "hashtag"); + inputMap.add(input1); + config.put(INPUT_MAP, inputMap); + List> outputMap = new ArrayList<>(); + Map output1 = new HashMap<>(); + output1.put("text_embedding", "response"); + outputMap.add(output1); + Map output2 = new HashMap<>(); + output2.put("ext.ml_inference.hashtag_embedding", "response"); + outputMap.add(output2); + config.put(OUTPUT_MAP, outputMap); + config.put(ONE_TO_ONE, true); + String processorTag = randomAlphaOfLength(10); + + try { + factory.create(Collections.emptyMap(), processorTag, null, false, config, null); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), "Writing model response to search extension does not support when one_to_one is true."); + + } + } }