From f8edd8c2d40992ccf5ed314d17ba596194a62e76 Mon Sep 17 00:00:00 2001 From: Mingshi Liu Date: Sat, 20 Jul 2024 10:57:06 -0700 Subject: [PATCH] add ITs Signed-off-by: Mingshi Liu --- .../MLInferenceSearchResponseProcessor.java | 387 ++++++--- .../{VersionedMapUtils.java => MapUtils.java} | 17 +- .../ml/plugin/MachineLearningPluginTests.java | 4 +- ...InferenceSearchResponseProcessorTests.java | 807 ++++++++++++++++-- .../ml/rest/MLCommonsRestTestCase.java | 7 + ...tMLInferenceSearchResponseProcessorIT.java | 475 +++++++++++ .../opensearch/ml/utils/MapUtilsTests.java | 83 ++ 7 files changed, 1603 insertions(+), 177 deletions(-) rename plugin/src/main/java/org/opensearch/ml/utils/{VersionedMapUtils.java => MapUtils.java} (66%) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java create mode 100644 plugin/src/test/java/org/opensearch/ml/utils/MapUtilsTests.java diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java index 239beaab64..3348d82f50 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java @@ -5,6 +5,7 @@ package org.opensearch.ml.processor; +import static java.lang.Math.max; import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; @@ -17,7 +18,6 @@ import java.util.Collection; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; @@ -42,7 +42,7 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import 
org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.utils.StringUtils; -import org.opensearch.ml.utils.VersionedMapUtils; +import org.opensearch.ml.utils.MapUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.pipeline.AbstractProcessor; import org.opensearch.search.pipeline.PipelineProcessingContext; @@ -118,6 +118,14 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp throw new RuntimeException("ML inference search response processor make asynchronous calls and does not call processRequest"); } + /** + * Processes the search response asynchronously by rewriting the documents with the inference results. + * + * @param request the search request + * @param response the search response + * @param responseContext the pipeline processing context + * @param responseListener the listener to be notified when the response is processed + */ @Override public void processResponseAsync( SearchRequest request, @@ -134,27 +142,36 @@ public void processResponseAsync( } rewriteResponseDocuments(response, hits, responseListener); } catch (Exception e) { - responseListener.onFailure(e); + if (ignoreFailure) { + responseListener.onResponse(response); + } else { + responseListener.onFailure(e); + } } } + /** + * Rewrite the documents in the search response with the inference results. + * + * @param response the search response + * @param hits the search hits + * @param responseListener the listener to be notified when the response is processed + * @throws IOException if an I/O error occurs during the rewriting process + */ private void rewriteResponseDocuments(SearchResponse response, SearchHit[] hits, ActionListener responseListener) throws IOException { List> processInputMap = inferenceProcessorAttributes.getInputMaps(); List> processOutputMap = inferenceProcessorAttributes.getOutputMaps(); - int inputMapSize = (processInputMap != null) ? 
processInputMap.size() : 0; + int inputMapSize = (processInputMap == null) ? 0 : processInputMap.size(); - // TODO decide the default mapping - if (inputMapSize == 0) { - responseListener.onResponse(response); - return; - } - - ActionListener> rewriteResponseListener = createRewriteRequestListener( + // hitCountInPredictions keeps track of the count of hit that have the required input fields for each round of prediction + Map hitCountInPredictions = new HashMap<>(); + ActionListener> rewriteResponseListener = createRewriteResponseListener( response, responseListener, processInputMap, - processOutputMap + processOutputMap, + hitCountInPredictions ); GroupedActionListener> batchPredictionListener = createBatchPredictionListener( @@ -162,17 +179,29 @@ private void rewriteResponseDocuments(SearchResponse response, SearchHit[] hits, inputMapSize ); - for (int inputMapIndex = 0; inputMapIndex < inputMapSize; inputMapIndex++) { - processPredictions(response, hits, processInputMap, inputMapIndex, batchPredictionListener); + for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) { + processPredictions(response, hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions); } } + /** + * Processes the predictions for the given input map index. 
+ * + * @param response the search response + * @param hits the search hits + * @param processInputMap the list of input mappings + * @param inputMapIndex the index of the input mapping to process + * @param batchPredictionListener the listener to be notified when the predictions are processed + * @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction + * @throws IOException if an I/O error occurs during the prediction process + */ private void processPredictions( SearchResponse response, SearchHit[] hits, List> processInputMap, int inputMapIndex, - GroupedActionListener> batchPredictionListener + GroupedActionListener> batchPredictionListener, + Map hitCountInPredictions ) throws IOException { Map modelParameters = new HashMap<>(); @@ -188,91 +217,125 @@ private void processPredictions( Map inputMapping; if (processInputMap != null) { inputMapping = processInputMap.get(inputMapIndex); + for (SearchHit hit : hits) { - for (Map.Entry entry : inputMapping.entrySet()) { - // model field as key, document field name as value - String modelInputFieldName = entry.getKey(); - String documentFieldName = entry.getValue(); - - Map document = hit.getSourceAsMap(); - Object documentJson = JsonPath.parse(document).read("$"); - Configuration configuration = Configuration - .builder() - .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL) - .build(); - - Object documentValue = JsonPath.using(configuration).parse(documentJson).read(documentFieldName); - if (documentValue != null) { - // when not existed in the map, add into the modelInputParameters map - if (!modelInputParameters.containsKey(modelInputFieldName)) { - modelInputParameters.put(modelInputFieldName, documentValue); - } else { - if (modelInputParameters.get(modelInputFieldName) instanceof List) { - List valueList = ((List) modelInputParameters.get(modelInputFieldName)); - valueList.add(documentValue); + Map document = 
hit.getSourceAsMap(); + boolean isModelInputMissing = checkIsModelInputMissing(document, inputMapping); + if (!isModelInputMissing) { + MapUtils.incrementCounter(hitCountInPredictions, inputMapIndex); + for (Map.Entry entry : inputMapping.entrySet()) { + // model field as key, document field name as value + String modelInputFieldName = entry.getKey(); + String documentFieldName = entry.getValue(); + + Object documentJson = JsonPath.parse(document).read("$"); + Configuration configuration = Configuration + .builder() + .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL) + .build(); + + Object documentValue = JsonPath.using(configuration).parse(documentJson).read(documentFieldName); + if (documentValue != null) { + + // when not existed in the map, add into the modelInputParameters map + if (!modelInputParameters.containsKey(modelInputFieldName)) { + modelInputParameters.put(modelInputFieldName, documentValue); } else { - Object firstValue = modelInputParameters.remove(modelInputFieldName); - List documentValueList = new ArrayList<>(); - documentValueList.add(firstValue); - documentValueList.add(documentValue); - modelInputParameters.put(modelInputFieldName, documentValueList); + if (modelInputParameters.get(modelInputFieldName) instanceof List) { + List valueList = ((List) modelInputParameters.get(modelInputFieldName)); + valueList.add(documentValue); + } else { + Object firstValue = modelInputParameters.remove(modelInputFieldName); + List documentValueList = new ArrayList<>(); + documentValueList.add(firstValue); + documentValueList.add(documentValue); + modelInputParameters.put(modelInputFieldName, documentValueList); + } } } } - // when document does not contain the documentFieldName, skip when ignoreMissing - else { - if (!ignoreMissing) { - throw new IllegalArgumentException("cannot find field name: " + documentFieldName + " in hit:" + hit); - } - + } else { // when document does not contain the documentFieldName, skip when ignoreMissing + if 
(!ignoreMissing) { + throw new IllegalArgumentException( + "cannot find all required input fields: " + inputMapping.values() + " in hit:" + hit + ); } } } - for (Map.Entry entry : modelInputParameters.entrySet()) { - String key = entry.getKey(); - Object value = entry.getValue(); - modelParameters.put(key, StringUtils.toJson(value)); - } + } else { + for (SearchHit hit : hits) { + Map document = hit.getSourceAsMap(); + MapUtils.incrementCounter(hitCountInPredictions, inputMapIndex); + for (Map.Entry entry : document.entrySet()) { + // model field as key, document field name as value + String modelInputFieldName = entry.getKey(); + Object documentValue = entry.getValue(); - Set inputMapKeys = new HashSet<>(modelParameters.keySet()); - inputMapKeys.removeAll(modelConfigs.keySet()); + // when not existed in the map, add into the modelInputParameters map + if (!modelInputParameters.containsKey(modelInputFieldName)) { + modelInputParameters.put(modelInputFieldName, documentValue); + } else { + if (modelInputParameters.get(modelInputFieldName) instanceof List) { + List valueList = ((List) modelInputParameters.get(modelInputFieldName)); + valueList.add(documentValue); + } else { + Object firstValue = modelInputParameters.remove(modelInputFieldName); + List documentValueList = new ArrayList<>(); + documentValueList.add(firstValue); + documentValueList.add(documentValue); + modelInputParameters.put(modelInputFieldName, documentValueList); + } + } - Map inputMappings = new HashMap<>(); - for (String k : inputMapKeys) { - inputMappings.put(k, modelParameters.get(k)); + } } + } - ActionRequest request = getMLModelInferenceRequest( - xContentRegistry, - modelParameters, - modelConfigs, - inputMappings, - inferenceProcessorAttributes.getModelId(), - functionName, - modelInput - ); + modelParameters = StringUtils.getParameterMap(modelInputParameters); - client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() { + Set inputMapKeys = new 
HashSet<>(modelParameters.keySet()); + inputMapKeys.removeAll(modelConfigs.keySet()); - @Override - public void onResponse(MLTaskResponse mlTaskResponse) { - MLOutput mlOutput = mlTaskResponse.getOutput(); - Map mlOutputMap = new HashMap<>(); - mlOutputMap.put(inputMapIndex, mlOutput); - batchPredictionListener.onResponse(mlOutputMap); - } + Map inputMappings = new HashMap<>(); + for (String k : inputMapKeys) { + inputMappings.put(k, modelParameters.get(k)); + } - @Override - public void onFailure(Exception e) { - batchPredictionListener.onFailure(e); - } - }); + ActionRequest request = getMLModelInferenceRequest( + xContentRegistry, + modelParameters, + modelConfigs, + inputMappings, + inferenceProcessorAttributes.getModelId(), + functionName, + modelInput + ); - } + client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() { + @Override + public void onResponse(MLTaskResponse mlTaskResponse) { + MLOutput mlOutput = mlTaskResponse.getOutput(); + Map mlOutputMap = new HashMap<>(); + mlOutputMap.put(inputMapIndex, mlOutput); + batchPredictionListener.onResponse(mlOutputMap); + } + + @Override + public void onFailure(Exception e) { + batchPredictionListener.onFailure(e); + } + }); } + /** + * Creates a grouped action listener for batch predictions. + * + * @param rewriteResponseListener the listener to be notified when the response is rewritten + * @param inputMapSize the size of the input map + * @return a grouped action listener for batch predictions + */ private GroupedActionListener> createBatchPredictionListener( ActionListener> rewriteResponseListener, int inputMapSize @@ -292,21 +355,31 @@ public void onFailure(Exception e) { logger.error("Prediction Failed:", e); rewriteResponseListener.onFailure(e); } - }, Math.max(inputMapSize, 1)); + }, max(inputMapSize, 1)); } - private ActionListener> createRewriteRequestListener( + /** + * Creates an action listener for rewriting the response with the inference results. 
+ * + * @param response the search response + * @param responseListener the listener to be notified when the response is processed + * @param processInputMap the list of input mappings + * @param processOutputMap the list of output mappings + * @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction + * @return an action listener for rewriting the response with the inference results + */ + private ActionListener> createRewriteResponseListener( SearchResponse response, ActionListener responseListener, List> processInputMap, - List> processOutputMap + List> processOutputMap, + Map hitCountInPredictions ) { return new ActionListener<>() { @Override public void onResponse(Map multipleMLOutputs) { try { - - Map> hasInputMapFieldDocCounter = new HashMap<>(); + Map> writeOutputMapDocCounter = new HashMap<>(); for (SearchHit hit : response.getHits().getHits()) { Map sourceAsMapWithInference = new HashMap<>(); @@ -317,30 +390,28 @@ public void onResponse(Map multipleMLOutputs) { Map sourceAsMap = typeAndSourceMap.v2(); sourceAsMapWithInference.putAll(sourceAsMap); + Map document = hit.getSourceAsMap(); + for (Map.Entry entry : multipleMLOutputs.entrySet()) { Integer mappingIndex = entry.getKey(); MLOutput mlOutput = entry.getValue(); - Map outputMapping = processOutputMap.get(mappingIndex); - Map inputMapping = processInputMap.get(mappingIndex); - // TODO deal with no inputMapping and no outputMapping edge case. - Iterator> inputIterator = inputMapping.entrySet().iterator(); - Iterator> outputIterator = outputMapping.entrySet().iterator(); + Map inputMapping = getDefaultInputMapping(sourceAsMap, mappingIndex, processInputMap); + Map outputMapping = getDefaultOutputMapping(mappingIndex, processOutputMap); - // Iterate over both maps simultaneously - while (inputIterator.hasNext() || outputIterator.hasNext()) { - Map.Entry inputMapEntry = inputIterator.hasNext() ? 
inputIterator.next() : null; - Map.Entry outputMapEntry = outputIterator.hasNext() ? outputIterator.next() : null; - String modelInputFieldName = inputMapEntry.getKey(); - String oldDocumentFieldName = inputMapEntry.getValue(); + boolean isModelInputMissing = false; + if (processInputMap != null) { + isModelInputMissing = checkIsModelInputMissing(document, inputMapping); + } + if (!isModelInputMissing) { + // Iterate over outputMapping + for (Map.Entry outputMapEntry : outputMapping.entrySet()) { - Map document = hit.getSourceAsMap(); - if (hasField(document, oldDocumentFieldName)) { + String newDocumentFieldName = outputMapEntry.getKey(); // text_embedding + String modelOutputFieldName = outputMapEntry.getValue(); // response - VersionedMapUtils.incrementCounter(hasInputMapFieldDocCounter, mappingIndex, modelInputFieldName); + MapUtils.incrementCounter(writeOutputMapDocCounter, mappingIndex, modelOutputFieldName); - String newDocumentFieldName = outputMapEntry.getKey(); - String modelOutputFieldName = outputMapEntry.getValue(); Object modelOutputValue = getModelOutputValue( mlOutput, modelOutputFieldName, @@ -348,12 +419,10 @@ public void onResponse(Map multipleMLOutputs) { fullResponsePath ); Object modelOutputValuePerDoc; - if (modelOutputValue instanceof List && ((List) modelOutputValue).size() > 1) { + if (modelOutputValue instanceof List + && ((List) modelOutputValue).size() == hitCountInPredictions.get(mappingIndex)) { Object valuePerDoc = ((List) modelOutputValue) - .get( - VersionedMapUtils - .getCounter(hasInputMapFieldDocCounter, mappingIndex, modelInputFieldName) - ); + .get(MapUtils.getCounter(writeOutputMapDocCounter, mappingIndex, modelOutputFieldName)); modelOutputValuePerDoc = valuePerDoc; } else { modelOutputValuePerDoc = modelOutputValue; @@ -363,27 +432,26 @@ public void onResponse(Map multipleMLOutputs) { if (override) { sourceAsMapWithInference.remove(newDocumentFieldName); sourceAsMapWithInference.put(newDocumentFieldName, 
modelOutputValuePerDoc); + } else { + logger + .debug( + "{} already exists in the search response hit. Skip processing this field.", + newDocumentFieldName + ); + // TODO when the response has the same field name, should it throw exception? currently, + // ingest processor quietly skip it } } else { sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); } - } else { - if (!ignoreMissing) { - throw new IllegalArgumentException( - "cannot find field name: " + oldDocumentFieldName + " in hit:" + hit - ); - } } - } } - XContentBuilder builder = XContentBuilder.builder(typeAndSourceMap.v1().xContent()); builder.map(sourceAsMapWithInference); hit.sourceRef(BytesReference.bytes(builder)); } - } } catch (Exception e) { if (ignoreFailure) { @@ -393,14 +461,13 @@ public void onResponse(Map multipleMLOutputs) { responseListener.onFailure(e); } } - responseListener.onResponse(response); } @Override public void onFailure(Exception e) { if (ignoreFailure) { - logger.error("Failed in writing prediction outcomes to new query", e); + logger.error("Failed in writing prediction outcomes to search response", e); responseListener.onResponse(response); } else { @@ -410,6 +477,84 @@ public void onFailure(Exception e) { }; } + /** + * Checks if the document is missing any of the required input fields specified in the input mapping. 
+ * + * @param document the document map + * @param inputMapping the input mapping + * @return true if the document is missing any of the required input fields, false otherwise + */ + private boolean checkIsModelInputMissing(Map document, Map inputMapping) { + boolean isModelInputMissing = false; + + for (Map.Entry inputMapEntry : inputMapping.entrySet()) { + String oldDocumentFieldName = inputMapEntry.getValue(); + boolean checkSingleModelInputPresent = hasField(document, oldDocumentFieldName); + if (!checkSingleModelInputPresent) { + isModelInputMissing = true; + break; + } + } + return isModelInputMissing; + } + + /** + * Retrieves the default output mapping for a given mapping index. + * + *

<p>If the provided processOutputMap is null or empty, a new HashMap is created with a default + * output field name mapped to a JsonPath expression representing the root object ($) followed by + * the default output field name. + * + *

If the processOutputMap is not null and not empty, the mapping at the specified mappingIndex + * is returned. + * + * @param mappingIndex the index of the mapping to retrieve from the processOutputMap + * @param processOutputMap the list of output mappings, can be null or empty + * @return a Map containing the output mapping, either the default mapping or the mapping at the + * specified index + */ + private static Map getDefaultOutputMapping(Integer mappingIndex, List> processOutputMap) { + Map outputMapping; + if (processOutputMap == null || processOutputMap.size() == 0) { + outputMapping = new HashMap<>(); + outputMapping.put(DEFAULT_OUTPUT_FIELD_NAME, "$." + DEFAULT_OUTPUT_FIELD_NAME); + } else { + outputMapping = processOutputMap.get(mappingIndex); + } + return outputMapping; + } + + /** + * Retrieves the default input mapping for a given mapping index and source map. + * + *

<p>If the provided processInputMap is null or empty, a new HashMap is created by extracting + * key-value pairs from the sourceAsMap using StringUtils.getParameterMap(). + * + *

If the processInputMap is not null and not empty, the mapping at the specified mappingIndex + * is returned. + * + * @param sourceAsMap the source map containing the input data + * @param mappingIndex the index of the mapping to retrieve from the processInputMap + * @param processInputMap the list of input mappings, can be null or empty + * @return a Map containing the input mapping, either the mapping extracted from sourceAsMap or + * the mapping at the specified index + */ + private static Map getDefaultInputMapping( + Map sourceAsMap, + Integer mappingIndex, + List> processInputMap + ) { + Map inputMapping; + + if (processInputMap == null || processInputMap.size() == 0) { + inputMapping = new HashMap<>(); + inputMapping.putAll(StringUtils.getParameterMap(sourceAsMap)); + } else { + inputMapping = processInputMap.get(mappingIndex); + } + return inputMapping; + } + /** * Returns the type of the processor. * @@ -462,8 +607,8 @@ public MLInferenceSearchResponseProcessor create( String modelId = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, MODEL_ID); Map modelConfigInput = ConfigurationUtils.readOptionalMap(TYPE, processorTag, config, MODEL_CONFIG); - List> inputMaps = ConfigurationUtils.readList(TYPE, processorTag, config, INPUT_MAP); - List> outputMaps = ConfigurationUtils.readList(TYPE, processorTag, config, OUTPUT_MAP); + List> inputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, INPUT_MAP); + List> outputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, OUTPUT_MAP); int maxPredictionTask = ConfigurationUtils .readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS); boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false); @@ -504,9 +649,9 @@ public MLInferenceSearchResponseProcessor create( } if (outputMaps != null && inputMaps != null && outputMaps.size() != inputMaps.size()) { throw new 
IllegalArgumentException( - "when output_maps and input_maps are provided, their length needs to match. The input_maps is in length of" + "when output_maps and input_maps are provided, their length needs to match. The input_maps is in length of " + inputMaps.size() - + ", while output_maps is in the length of" + + ", while output_maps is in the length of " + outputMaps.size() + ". Please adjust mappings." ); diff --git a/plugin/src/main/java/org/opensearch/ml/utils/VersionedMapUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MapUtils.java similarity index 66% rename from plugin/src/main/java/org/opensearch/ml/utils/VersionedMapUtils.java rename to plugin/src/main/java/org/opensearch/ml/utils/MapUtils.java index 301b9a9400..bc1deb085a 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/VersionedMapUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MapUtils.java @@ -8,7 +8,7 @@ import java.util.HashMap; import java.util.Map; -public class VersionedMapUtils { +public class MapUtils { /** * Increments the counter for the given key in the specified version. @@ -35,4 +35,19 @@ public static int getCounter(Map> versionedCounter return counters != null ? counters.getOrDefault(key, -1) : 0; } + /** + * Increments the counter value for the given key in the provided counters map. + * If the key does not exist in the map, its count is treated as 0, so it is added with a value of 1. + * + * @param counters A map that stores integer counters for each integer key. + * @param key The integer key for which the counter needs to be incremented. 
+ */ + public static void incrementCounter(Map counters, int key) { + counters.put(key, counters.getOrDefault(key, 0) + 1); + } + + public static int getCounter(Map counters, int key) { + return counters.getOrDefault(key, 0); + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java index ba75fb7df1..de9b040238 100644 --- a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java +++ b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java @@ -38,6 +38,7 @@ import org.opensearch.ml.common.spi.MLCommonsExtension; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.engine.tools.MLModelTool; +import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor; import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; @@ -83,10 +84,11 @@ public void testGetRequestProcessors() { public void testGetResponseProcessors() { SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); Map responseProcessors = plugin.getResponseProcessors(parameters); - assertEquals(1, responseProcessors.size()); + assertEquals(2, responseProcessors.size()); assertTrue( responseProcessors.get(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE) instanceof GenerativeQAResponseProcessor.Factory ); + assertTrue(responseProcessors.get(MLInferenceSearchResponseProcessor.TYPE) instanceof MLInferenceSearchResponseProcessor.Factory); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index bdef23f37e..8739dd5ff8 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ 
b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -7,6 +7,11 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP; import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.*; import java.util.ArrayList; @@ -20,6 +25,7 @@ import org.junit.Before; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; @@ -68,15 +74,12 @@ public void setup() { * @throws Exception if an error occurs during the test */ public void testProcessResponseException() throws Exception { - String modelInputField = "inputs"; - String originalDocumentField = "query.term.text.value"; - String newDocumentField = "query.term.text.value"; - String modelOutputField = "response"; + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( - modelOutputField, - modelInputField, - originalDocumentField, - newDocumentField, + null, + null, + null, + null, false, false, false @@ -93,6 +96,11 @@ public void testProcessResponseException() throws Exception { } } + /** + * Tests the successful processing of a response with a single pair of input and output mappings. 
+ * + * @throws Exception if an error occurs during the test + */ public void testProcessResponseSuccess() throws Exception { String modelInputField = "inputs"; String originalDocumentField = "text"; @@ -107,6 +115,8 @@ public void testProcessResponseSuccess() throws Exception { false, false ); + + assertEquals(responseProcessor.getType(), TYPE); SearchRequest request = getSearchRequest(); String fieldName = "text"; SearchResponse response = getSearchResponse(5, true, fieldName); @@ -128,8 +138,6 @@ public void testProcessResponseSuccess() throws Exception { @Override public void onResponse(SearchResponse newSearchResponse) { assertEquals(newSearchResponse.getHits().getHits().length, 5); - System.out.println(newSearchResponse.getHits().getHits()[0].getSourceAsString()); - System.out.println(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text_embedding")); assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), 0.0); assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), 1.0); assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get("text_embedding"), 2.0); @@ -146,6 +154,87 @@ public void onFailure(Exception e) { responseProcessor.processResponseAsync(request, response, responseContext, listener); } + /** + * Tests the successful processing of a response without any input-output mappings. 
+ * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseNoMappingSuccess() throws Exception { + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals( + newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 
4.0]}}]}]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the successful processing of a response with a list of embeddings as the output. + * + * @throws Exception if an error occurs during the test + */ public void testProcessResponseListOfEmbeddingsSuccess() throws Exception { /** * sample response before inference @@ -161,8 +250,8 @@ public void testProcessResponseListOfEmbeddingsSuccess() throws Exception { * { "text" : "value 0", "text_embedding":[0.1, 0.2]}, * { "text" : "value 1", "text_embedding":[0.2, 0.2]}, * { "text" : "value 2", "text_embedding":[0.3, 0.2]}, - * { "text" : "value 3","text_embedding":[0.4, 0.2]}, - * { "text" : "value 4","text_embedding":[0.5, 0.2]} + * { "text" : "value 3","text_embedding":[0.4, 0.2]}, + * { "text" : "value 4","text_embedding":[0.5, 0.2]} */ String modelInputField = "inputs"; @@ -212,9 +301,6 @@ public void testProcessResponseListOfEmbeddingsSuccess() throws Exception { @Override public void onResponse(SearchResponse newSearchResponse) { assertEquals(newSearchResponse.getHits().getHits().length, 5); - System.out.println("printing document.. 
"); - System.out.println(newSearchResponse.getHits().getHits()[0].getSourceAsString()); - System.out.println(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text_embedding")); assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), Arrays.asList(0.1, 0.2)); assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), Arrays.asList(0.2, 0.2)); assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get("text_embedding"), Arrays.asList(0.3, 0.2)); @@ -231,6 +317,99 @@ public void onFailure(Exception e) { responseProcessor.processResponseAsync(request, response, responseContext, listener); } + /** + * Tests the successful processing of a response where the existing document field is overridden. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOverrideSameField() throws Exception { + /** + * sample response before inference + * { + * { "text" : "value 0" }, + * { "text" : "value 1" }, + * { "text" : "value 2" }, + * { "text" : "value 3" }, + * { "text" : "value 4" } + * } + * + * sample response after inference + * { "text":[0.1, 0.2]}, + * { "text":[0.2, 0.2]}, + * { "text":[0.3, 0.2]}, + * { "text":[0.4, 0.2]}, + * { "text":[0.5, 0.2]} + */ + + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text"; + String modelOutputField = "response"; + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + true, + false, + false + ); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap( + ImmutableMap + .of( + "response", + Arrays + .asList( + Arrays.asList(0.1, 0.2), + 
Arrays.asList(0.2, 0.2), + Arrays.asList(0.3, 0.2), + Arrays.asList(0.4, 0.2), + Arrays.asList(0.5, 0.2) + ) + ) + ) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text"), Arrays.asList(0.1, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text"), Arrays.asList(0.2, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get("text"), Arrays.asList(0.3, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get("text"), Arrays.asList(0.4, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get("text"), Arrays.asList(0.5, 0.2)); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the successful processing of a response where one input field is missing, + * and the `ignoreMissing` flag is set to true. 
+ * + * @throws Exception if an error occurs during the test + */ public void testProcessResponseListOfEmbeddingsMissingOneInputIgnoreMissingSuccess() throws Exception { /** * sample response before inference @@ -290,10 +469,6 @@ public void testProcessResponseListOfEmbeddingsMissingOneInputIgnoreMissingSucce @Override public void onResponse(SearchResponse newSearchResponse) { assertEquals(newSearchResponse.getHits().getHits().length, 5); - - System.out.println("printing document.. "); - System.out.println(newSearchResponse.getHits().getHits()[0].getSourceAsString()); - System.out.println(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text_embedding")); assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), Arrays.asList(0.1, 0.2)); assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), Arrays.asList(0.2, 0.2)); @@ -311,7 +486,97 @@ public void onFailure(Exception e) { responseProcessor.processResponseAsync(request, response, responseContext, listener); } - public void testProcessResponseTwoRoundsOfPredictionSuccess() { + /** + * Tests the case where one input field is missing, and an exception is expected + * when the `ignoreMissing` flag is set to false. 
+ * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseListOfEmbeddingsMissingOneInputException() throws Exception { + /** + * sample response before inference + * { + * { "text" : "value 0" }, + * { "text" : "value 1" }, + * { "textMissing" : "value 2" }, + * { "text" : "value 3" }, + * { "text" : "value 4" } + * } + * + * sample response after inference + * { "text" : "value 0", "text_embedding":[0.1, 0.2]}, + * { "text" : "value 1", "text_embedding":[0.2, 0.2]}, + * { "textMissing" : "value 2"}, + * { "text" : "value 3","text_embedding":[0.4, 0.2]}, + * { "text" : "value 4","text_embedding":[0.5, 0.2]} + */ + + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + false, + false, + false + ); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponseMissingField(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap( + ImmutableMap + .of( + "response", + Arrays.asList(Arrays.asList(0.1, 0.2), Arrays.asList(0.2, 0.2), Arrays.asList(0.4, 0.2), Arrays.asList(0.5, 0.2)) + ) + ) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public 
void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + assertEquals( + "cannot find all required input fields: [text] in hit:{\n" + + " \"_id\" : \"doc 2\",\n" + + " \"_score\" : 2.0,\n" + + " \"_source\" : {\n" + + " \"textMissing\" : \"value 2\"\n" + + " }\n" + + "}", + e.getMessage() + ); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the successful processing of a response with two rounds of prediction. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseTwoRoundsOfPredictionSuccess() throws Exception { String modelInputField = "inputs"; String modelOutputField = "response"; @@ -341,6 +606,8 @@ public void testProcessResponseTwoRoundsOfPredictionSuccess() { output2.put(newDocumentField1, modelOutputField); outputMap.add(output2); + Map modelConfig = new HashMap<>(); + modelConfig.put("model_task_type", "TEXT_EMBEDDING"); MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", inputMap, @@ -378,8 +645,6 @@ public void testProcessResponseTwoRoundsOfPredictionSuccess() { @Override public void onResponse(SearchResponse newSearchResponse) { assertEquals(newSearchResponse.getHits().getHits().length, 5); - System.out.println(newSearchResponse.getHits().getHits()[0].getSourceAsString()); - System.out.println(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField)); assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField), 0.0); assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField), 1.0); assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField), 2.0); @@ -402,30 +667,254 @@ public void onFailure(Exception e) { 
responseProcessor.processResponseAsync(request, response, responseContext, listener); } - private static SearchRequest getSearchRequest() { - QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); - SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); - SearchRequest request = new SearchRequest().source(source); - return request; - } - /** - * Helper method to create an instance of the MLInferenceSearchResponseProcessor with the specified parameters in - * single pair of input and output mapping. + * Tests the successful processing of a response with one model input and multiple model outputs. * - * @param modelInputField the model input field name - * @param originalDocumentField the original query field name - * @param newDocumentField the new document field name - * @param override the flag indicating whether to override existing document field - * @param ignoreFailure the flag indicating whether to ignore failures or not - * @param ignoreMissing the flag indicating whether to ignore missing fields or not - * @return an instance of the MLInferenceSearchResponseProcessor + * @throws Exception if an error occurs during the test */ - private MLInferenceSearchResponseProcessor getMlInferenceSearchResponseProcessorSinglePairMapping( - String modelOutputField, - String modelInputField, - String originalDocumentField, - String newDocumentField, + public void testProcessResponseOneModelInputMultipleModelOutputs() throws Exception { + // one model input + String modelInputField = "inputs"; + String originalDocumentField = "text"; + + // two model outputs + String modelOutputField = "response"; + String newDocumentField = "text_embedding"; + String modelOutputField1 = "response_type"; + String newDocumentField1 = "embedding_type"; + + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalDocumentField); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new 
HashMap<>(); + output.put(newDocumentField, modelOutputField); + output.put(newDocumentField1, modelOutputField1); + outputMap.add(output); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY + ); + SearchRequest request = getSearchRequest(); + SearchResponse response = getSearchResponse(5, true, originalDocumentField); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of(modelOutputField, Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0), "response_type", "embedding_float")) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField), 0.0); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField), 1.0); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField), 2.0); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField), 3.0); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField), 4.0); + + 
assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField1), "embedding_float"); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField1), "embedding_float"); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField1), "embedding_float"); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField1), "embedding_float"); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField1), "embedding_float"); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the case where an exception occurs during prediction, and the `ignoreFailure` flag is set to false. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponsePredictionException() throws Exception { + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + throw new RuntimeException("Prediction Failed"); + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly."); + } + + @Override + public void onFailure(Exception e) { + assertEquals("Prediction Failed", 
e.getMessage()); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the case where an exception occurs during prediction, but the `ignoreFailure` flag is set to true. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponsePredictionExceptionIgnoreFailure() throws Exception { + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + true, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + throw new RuntimeException("Prediction Failed"); + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(response, newSearchResponse); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("error handling not properly."); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the case where there are no hits in the search response. 
+     *
+     * @throws Exception if an error occurs during the test
+     */
+    public void testProcessResponseEmptyHit() throws Exception {
+        MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor(
+            "model1",
+            null,
+            null,
+            null,
+            DEFAULT_MAX_PREDICTION_TASKS,
+            PROCESSOR_TAG,
+            DESCRIPTION,
+            false,
+            "remote",
+            true,
+            false,
+            false,
+            "{ \"parameters\": ${ml_inference.parameters} }",
+            client,
+            TEST_XCONTENT_REGISTRY_FOR_QUERY
+        );
+
+        SearchRequest request = getSearchRequest();
+        String fieldName = "text";
+        // A size of 0 yields a search response that carries no hits at all.
+        SearchResponse response = getSearchResponse(0, true, fieldName);
+
+        ActionListener<SearchResponse> listener = new ActionListener<>() {
+            @Override
+            public void onResponse(SearchResponse newSearchResponse) {
+                // With nothing to rewrite, the listener must receive the original response.
+                assertEquals(response, newSearchResponse);
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                // Surface any failure instead of letting the test pass silently.
+                throw new RuntimeException(e);
+            }
+        };
+
+        responseProcessor.processResponseAsync(request, response, responseContext, listener);
+    }
+
+    private static SearchRequest getSearchRequest() {
+        // Minimal term query on the "text" field, wrapped in a search request.
+        QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
+        SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
+        SearchRequest request = new SearchRequest().source(source);
+        return request;
+    }
+
+    /**
+     * Helper method to create an instance of the MLInferenceSearchResponseProcessor with the specified parameters in
+     * single pair of input and output mapping.
+ * + * @param modelInputField the model input field name + * @param originalDocumentField the original query field name + * @param newDocumentField the new document field name + * @param override the flag indicating whether to override existing document field + * @param ignoreFailure the flag indicating whether to ignore failures or not + * @param ignoreMissing the flag indicating whether to ignore missing fields or not + * @return an instance of the MLInferenceSearchResponseProcessor + */ + private MLInferenceSearchResponseProcessor getMlInferenceSearchResponseProcessorSinglePairMapping( + String modelOutputField, + String modelInputField, + String originalDocumentField, + String newDocumentField, boolean override, boolean ignoreFailure, boolean ignoreMissing @@ -462,7 +951,6 @@ private MLInferenceSearchResponseProcessor getMlInferenceSearchResponseProcessor private SearchResponse getSearchResponse(int size, boolean includeMapping, String fieldName) { SearchHit[] hits = new SearchHit[size]; - System.out.println("printing hit.. "); for (int i = 0; i < size; i++) { Map searchHitFields = new HashMap<>(); if (includeMapping) { @@ -472,8 +960,6 @@ private SearchResponse getSearchResponse(int size, boolean includeMapping, Strin hits[i] = new SearchHit(i, "doc " + i, searchHitFields, Collections.emptyMap()); hits[i].sourceRef(new BytesArray("{ \"" + fieldName + "\" : \"value " + i + "\" }")); hits[i].score(i); - - System.out.println(hits[i].getSourceAsString()); } SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2L, TotalHits.Relation.EQUAL_TO), size); SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); @@ -482,7 +968,6 @@ private SearchResponse getSearchResponse(int size, boolean includeMapping, Strin private SearchResponse getSearchResponseMissingField(int size, boolean includeMapping, String fieldName) { SearchHit[] hits = new SearchHit[size]; - System.out.println("printing hit.. 
"); for (int i = 0; i < size; i++) { if (i == (size % 3)) { @@ -494,8 +979,6 @@ private SearchResponse getSearchResponseMissingField(int size, boolean includeMa hits[i] = new SearchHit(i, "doc " + i, searchHitFields, Collections.emptyMap()); hits[i].sourceRef(new BytesArray("{ \"" + fieldName + "Missing" + "\" : \"value " + i + "\" }")); hits[i].score(i); - - System.out.println(hits[i].getSourceAsString()); } else { Map searchHitFields = new HashMap<>(); if (includeMapping) { @@ -505,8 +988,6 @@ private SearchResponse getSearchResponseMissingField(int size, boolean includeMa hits[i] = new SearchHit(i, "doc " + i, searchHitFields, Collections.emptyMap()); hits[i].sourceRef(new BytesArray("{ \"" + fieldName + "\" : \"value " + i + "\" }")); hits[i].score(i); - - System.out.println(hits[i].getSourceAsString()); } } SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2L, TotalHits.Relation.EQUAL_TO), size); @@ -516,7 +997,6 @@ private SearchResponse getSearchResponseMissingField(int size, boolean includeMa private SearchResponse getSearchResponseTwoFields(int size, boolean includeMapping, String fieldName, String fieldName1) { SearchHit[] hits = new SearchHit[size]; - System.out.println("printing hit.. 
"); for (int i = 0; i < size; i++) { Map searchHitFields = new HashMap<>(); if (includeMapping) { @@ -533,11 +1013,230 @@ private SearchResponse getSearchResponseTwoFields(int size, boolean includeMappi ) ); hits[i].score(i); - - System.out.println(hits[i].getSourceAsString()); } SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2L, TotalHits.Relation.EQUAL_TO), size); SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); return new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); } + + private MLInferenceSearchResponseProcessor.Factory factory; + + @Mock + private NamedXContentRegistry xContentRegistry; + + @Before + public void init() { + factory = new MLInferenceSearchResponseProcessor.Factory(client, xContentRegistry); + } + + /** + * Tests the creation of the MLInferenceSearchResponseProcessor with required fields. + * + * @throws Exception if an error occurs during the test + */ + public void testCreateRequiredFields() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + String processorTag = randomAlphaOfLength(10); + MLInferenceSearchResponseProcessor MLInferenceSearchResponseProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + assertNotNull(MLInferenceSearchResponseProcessor); + assertEquals(MLInferenceSearchResponseProcessor.getTag(), processorTag); + assertEquals(MLInferenceSearchResponseProcessor.getType(), MLInferenceSearchResponseProcessor.TYPE); + } + + /** + * Tests the creation of the MLInferenceSearchResponseProcessor for a local model. 
+ * + * @throws Exception if an error occurs during the test + */ + public void testCreateLocalModelProcessor() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + config.put(FUNCTION_NAME, "text_embedding"); + config.put(FULL_RESPONSE_PATH, true); + config.put(MODEL_INPUT, "{ \"text_docs\": ${ml_inference.text_docs} }"); + Map model_config = new HashMap<>(); + model_config.put("return_number", true); + config.put(MODEL_CONFIG, model_config); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 5); + String processorTag = randomAlphaOfLength(10); + MLInferenceSearchResponseProcessor MLInferenceSearchResponseProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + assertNotNull(MLInferenceSearchResponseProcessor); + assertEquals(MLInferenceSearchResponseProcessor.getTag(), processorTag); + assertEquals(MLInferenceSearchResponseProcessor.getType(), MLInferenceSearchResponseProcessor.TYPE); + } + + /** + * The model input field is required for using a local model + * when missing the model input field, expected to throw Exceptions + * + * @throws Exception if an error occurs during the test + */ + public void testCreateLocalModelProcessorMissingModelInputField() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + config.put(FUNCTION_NAME, "text_embedding"); + config.put(FULL_RESPONSE_PATH, true); + Map model_config = new HashMap<>(); + model_config.put("return_number", true); + config.put(MODEL_CONFIG, model_config); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", 
"text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 5); + String processorTag = randomAlphaOfLength(10); + try { + MLInferenceSearchResponseProcessor MLInferenceSearchResponseProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + assertNotNull(MLInferenceSearchResponseProcessor); + } catch (Exception e) { + assertEquals(e.getMessage(), "Please provide model input when using a local model in ML Inference Processor"); + } + } + + /** + * Tests the case where the `model_id` field is missing in the configuration, and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testCreateNoFieldPresent() throws Exception { + Map config = new HashMap<>(); + try { + factory.create(Collections.emptyMap(), "no field", null, false, config, null); + fail("factory create should have failed"); + } catch (OpenSearchParseException e) { + assertEquals(e.getMessage(), ("[model_id] required property is missing")); + } + } + + /** + * Tests the case where the number of prediction tasks exceeds the maximum allowed value, and an exception is expected. 
+     *
+     * @throws Exception if an error occurs during the test
+     */
+    public void testExceedMaxPredictionTasks() throws Exception {
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        // Three input maps mean three prediction tasks, one more than the limit configured below.
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        Map<String, String> input0 = new HashMap<>();
+        input0.put("inputs", "text");
+        inputMap.add(input0);
+        Map<String, String> input1 = new HashMap<>();
+        input1.put("inputs", "hashtag");
+        inputMap.add(input1);
+        Map<String, String> input2 = new HashMap<>();
+        input2.put("inputs", "timestamp");
+        inputMap.add(input2);
+        List<Map<String, String>> outputMap = new ArrayList<>();
+        Map<String, String> output = new HashMap<>();
+        output.put("text_embedding", "$.inference_results[0].output[0].data");
+        outputMap.add(output);
+        config.put(INPUT_MAP, inputMap);
+        config.put(OUTPUT_MAP, outputMap);
+        config.put(MAX_PREDICTION_TASKS, 2);
+        String processorTag = randomAlphaOfLength(10);
+
+        try {
+            factory.create(Collections.emptyMap(), processorTag, null, false, config, null);
+            // Reaching this line means no exception was thrown; fail so the test cannot pass vacuously.
+            fail("factory create should have failed");
+        } catch (IllegalArgumentException e) {
+            assertEquals(
+                e.getMessage(),
+                ("The number of prediction task setting in this process is 3. It exceeds the max_prediction_tasks of 2. Please reduce the size of input_map or increase max_prediction_tasks.")
+            );
+        }
+    }
+
+    /**
+     * Tests the case where the length of the `output_map` list is greater than the length of the `input_map` list,
+     * and an exception is expected.
+     *
+     * @throws Exception if an error occurs during the test
+     */
+    public void testOutputMapsExceedInputMaps() throws Exception {
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        // Two input maps ...
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        Map<String, String> input0 = new HashMap<>();
+        input0.put("inputs", "text");
+        inputMap.add(input0);
+        Map<String, String> input1 = new HashMap<>();
+        input1.put("inputs", "hashtag");
+        inputMap.add(input1);
+        config.put(INPUT_MAP, inputMap);
+        // ... against three output maps, so the two lists differ in length.
+        List<Map<String, String>> outputMap = new ArrayList<>();
+        Map<String, String> output1 = new HashMap<>();
+        output1.put("text_embedding", "response");
+        outputMap.add(output1);
+        Map<String, String> output2 = new HashMap<>();
+        output2.put("hashtag_embedding", "response");
+        outputMap.add(output2);
+        Map<String, String> output3 = new HashMap<>();
+        // Populate output3 itself; the entry was previously put into output2, leaving output3 empty.
+        output3.put("hashtvg_embedding", "response");
+        outputMap.add(output3);
+        config.put(OUTPUT_MAP, outputMap);
+        config.put(MAX_PREDICTION_TASKS, 2);
+        String processorTag = randomAlphaOfLength(10);
+
+        try {
+            factory.create(Collections.emptyMap(), processorTag, null, false, config, null);
+            // Reaching this line means no exception was thrown; fail so the test cannot pass vacuously.
+            fail("factory create should have failed");
+        } catch (IllegalArgumentException e) {
+            assertEquals(
+                e.getMessage(),
+                "when output_maps and input_maps are provided, their length needs to match. The input_maps is in length of 2, while output_maps is in the length of 3. Please adjust mappings."
+            );
+        }
+    }
+
+    /**
+     * Tests the creation of the MLInferenceSearchResponseProcessor with optional fields.
+     *
+     * @throws Exception if an error occurs during the test
+     */
+    public void testCreateOptionalFields() throws Exception {
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        Map<String, Object> model_config = new HashMap<>();
+        model_config.put("hidden_size", 768);
+        model_config.put("gradient_checkpointing", false);
+        model_config.put("position_embedding_type", "absolute");
+        config.put(MODEL_CONFIG, model_config);
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        Map<String, String> input = new HashMap<>();
+        input.put("inputs", "text");
+        inputMap.add(input);
+        List<Map<String, String>> outputMap = new ArrayList<>();
+        Map<String, String> output = new HashMap<>();
+        output.put("text_embedding", "response");
+        outputMap.add(output);
+        config.put(INPUT_MAP, inputMap);
+        config.put(OUTPUT_MAP, outputMap);
+        config.put(MAX_PREDICTION_TASKS, 5);
+        String processorTag = randomAlphaOfLength(10);
+
+        // The local is deliberately not named after its class, so the TYPE lookup below
+        // unambiguously refers to the class constant.
+        MLInferenceSearchResponseProcessor processor = factory
+            .create(Collections.emptyMap(), processorTag, null, false, config, null);
+        assertNotNull(processor);
+        // JUnit convention: expected value first, actual value second.
+        assertEquals(processorTag, processor.getTag());
+        assertEquals(MLInferenceSearchResponseProcessor.TYPE, processor.getType());
+    }
 }
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java
index 788a98239c..2092b9f4b4 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java
@@ -1015,4 +1015,11 @@ public String registerRemoteModel(String createConnectorInput, String modelName,
         }
         return modelId;
     }
+
+    /**
+     * Runs a search against an index through a search pipeline and returns the parsed response.
+     *
+     * @param client the REST client used to send the request
+     * @param indexName the index to search
+     * @param pipelineName the pipeline name passed as the {@code search_pipeline} URL parameter
+     * @param query the JSON request body, sent exactly as given
+     * @return the response as produced by {@code parseResponseToMap}
+     * @throws IOException if sending the request or parsing the response raises one
+     */
+    public Map searchWithPipeline(RestClient client, String indexName, String pipelineName, String query) throws IOException {
+        // The query is sent verbatim. It must not be used as a String.format pattern: a literal '%'
+        // inside the JSON would be parsed as a format specifier and either throw or alter the body.
+        Response response = TestHelper
+            .makeRequest(client, "GET", "/" + indexName + "/" + "_search?search_pipeline=" + pipelineName, null, query, null);
+        return parseResponseToMap(response);
+    }
 }
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java
new file mode 100644
index 0000000000..9e031935d7
--- /dev/null
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java
@@ -0,0 +1,475 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.ml.rest;
+
+import static org.junit.Assert.assertEquals;
+import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD;
+import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL;
+import static org.opensearch.ml.utils.TestHelper.makeRequest;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hc.core5.http.HttpHeaders;
+import org.apache.hc.core5.http.message.BasicHeader;
+import org.junit.Assert;
+import org.junit.Before;
+import org.opensearch.client.Request;
+import org.opensearch.client.Response;
+import org.opensearch.ml.common.FunctionName;
+import org.opensearch.ml.common.MLTaskState;
+import org.opensearch.ml.common.model.MLModelConfig;
+import org.opensearch.ml.common.model.MLModelFormat;
+import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
+import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
+import org.opensearch.ml.utils.TestHelper;
+
+import com.google.common.collect.ImmutableList;
+import com.jayway.jsonpath.JsonPath;
+
+/**
+ * Integration tests that run searches through a search pipeline containing the
+ * {@code ml_inference} response processor and check the documents returned in the hits.
+ */
+public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestCase {
+
+    // Read from the environment; null when the variable is not set.
+    private final String OPENAI_KEY = System.getenv("OPENAI_KEY");
+    // Id of the model registered from completionModelConnectorEntity.
+    private String openAIChatModelId;
+    // Id of the model registered from bedrockEmbeddingModelConnectorEntity.
+    private String bedrockEmbeddingModelId;
+    // Id of the locally registered text embedding model; assigned inside the local-model test.
+    private String localModelId;
+    // NOTE: despite the "completion"/"chat" naming, this connector calls the OpenAI embeddings
+    // endpoint with the text-embedding-ada-002 model.
+    private final String completionModelConnectorEntity = "{\n"
+        + " \"name\": \"OpenAI text embedding model Connector\",\n"
+        + " \"description\": \"The connector to public OpenAI text embedding model service\",\n"
+        + " \"version\": 1,\n"
+        + " \"protocol\": \"http\",\n"
+        + " \"parameters\": {\n"
+        + " \"model\": \"text-embedding-ada-002\"\n"
+        + " },\n"
+        + " \"credential\": {\n"
+        + " \"openAI_key\": \""
+        + OPENAI_KEY
+        + "\"\n"
+        + " },\n"
+        + " \"actions\": [\n"
+        + " {\n"
+        + " \"action_type\": \"predict\",\n"
+        + " \"method\": \"POST\",\n"
+        + " \"url\": \"https://api.openai.com/v1/embeddings\",\n"
+        + " \"headers\": { \n"
+        + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+        + " },\n"
+        + " \"request_body\": \"{ \\\"input\\\": ${parameters.input}, \\\"model\\\": \\\"${parameters.model}\\\" }\"\n"
+        + " }\n"
+        + " ]\n"
+        + "}";
+
+    // AWS credentials are read from the environment; each is null when its variable is not set.
+    private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
+    private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
+    private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
+    private static final String GITHUB_CI_AWS_REGION = "us-west-2";
+
+    // Connector definition for the Amazon Bedrock Titan text embedding model (aws_sigv4 protocol).
+    private final String bedrockEmbeddingModelConnectorEntity = "{\n"
+        + " \"name\": \"Amazon Bedrock Connector: embedding\",\n"
+        + " \"description\": \"The connector to bedrock Titan embedding model\",\n"
+        + " \"version\": 1,\n"
+        + " \"protocol\": \"aws_sigv4\",\n"
+        + " \"parameters\": {\n"
+        + " \"region\": \""
+        + GITHUB_CI_AWS_REGION
+        + "\",\n"
+        + " \"service_name\": \"bedrock\",\n"
+        + " \"model_name\": \"amazon.titan-embed-text-v1\"\n"
+        + " },\n"
+        + " \"credential\": {\n"
+        + " \"access_key\": \""
+        + AWS_ACCESS_KEY_ID
+        + "\",\n"
+        + " \"secret_key\": \""
+        + AWS_SECRET_ACCESS_KEY
+        + "\",\n"
+        + " \"session_token\": \""
+        + AWS_SESSION_TOKEN
+        + "\"\n"
+        + " },\n"
+        + " \"actions\": [\n"
+        + " {\n"
+        + " \"action_type\": \"predict\",\n"
+        + " \"method\": \"POST\",\n"
+        + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n"
+        + " \"headers\": {\n"
+        + " \"content-type\": \"application/json\",\n"
+        + " \"x-amz-content-sha256\": \"required\"\n"
+        + " },\n"
+        + " \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n"
+        + " \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n"
+        + " \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n"
+        + " }\n"
+        + " ]\n"
+        + "}";
+
+    /**
+     * Registers the OpenAI and Bedrock remote models, then creates the {@code daily_index} index
+     * and uploads two documents to it before each test.
+     *
+     * @throws Exception if any error occurs during the setup
+     */
+    @Before
+    public void setup() throws Exception {
+        RestMLRemoteInferenceIT.disableClusterConnectorAccessControl();
+        // NOTE(review): presumably waits for the cluster setting change to take effect - confirm.
+        Thread.sleep(20000);
+        String openAIChatModelName = "openAI-GPT-3.5 chat model " + randomAlphaOfLength(5);
+        this.openAIChatModelId = registerRemoteModel(completionModelConnectorEntity, openAIChatModelName, true);
+        String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
+        this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true);
+
+        String index_name = "daily_index";
+        // diary_embedding_size is mapped as keyword so the tests can select a document with a term query.
+        String createIndexRequestBody = "{\n"
+            + " \"mappings\": {\n"
+            + " \"properties\": {\n"
+            + " \"diary_embedding_size\": {\n"
+            + " \"type\": \"keyword\"\n"
+            + " }\n"
+            + " }\n"
+            + " }\n"
+            + "}";
+        String uploadDocumentRequestBodyDoc1 = "{\n"
+            + " \"id\": 1,\n"
+            + " \"diary\": [\"happy\",\"first day at school\"],\n"
+            + " \"diary_embedding_size\": \"1536\",\n" // embedding size for ada model
+            + " \"weather\": \"sunny\"\n"
+            + " }";
+
+        String uploadDocumentRequestBodyDoc2 = "{\n"
+            + " \"id\": 2,\n"
+            + " \"diary\": [\"bored\",\"at home\"],\n"
+            + " \"diary_embedding_size\": \"768\",\n" // embedding size for local text embedding model
+            + " \"weather\": \"sunny\"\n"
+            + " }";
+
+        createIndex(index_name,
createIndexRequestBody);
+        uploadDocument(index_name, "1", uploadDocumentRequestBodyDoc1);
+        uploadDocument(index_name, "2", uploadDocumentRequestBodyDoc2);
+    }
+
+    /**
+     * Tests the MLInferenceSearchResponseProcessor with the OpenAI remote model.
+     * The pipeline maps the document field {@code weather} to the model input {@code input} and writes
+     * {@code data[*].embedding} from the model output to {@code weather_embedding}. The test searches
+     * through the pipeline and checks the original fields and the embedding on the first hit.
+     *
+     * @throws Exception if any error occurs during the test
+     */
+    public void testMLInferenceProcessorRemoteModelObjectField() throws Exception {
+        // Skip test if key is null
+        if (OPENAI_KEY == null) {
+            return;
+        }
+        String createPipelineRequestBody = "{\n"
+            + " \"response_processors\": [\n"
+            + " {\n"
+            + " \"ml_inference\": {\n"
+            + " \"tag\": \"ml_inference\",\n"
+            + " \"description\": \"This processor is going to run ml inference during search request\",\n"
+            + " \"model_id\": \""
+            + this.openAIChatModelId
+            + "\",\n"
+            + " \"input_map\": [\n"
+            + " {\n"
+            + " \"input\": \"weather\"\n"
+            + " }\n"
+            + " ],\n"
+            + " \"output_map\": [\n"
+            + " {\n"
+            + " \"weather_embedding\": \"data[*].embedding\"\n"
+            + " }\n"
+            + " ],\n"
+            + " \"ignore_missing\": false,\n"
+            + " \"ignore_failure\": false\n"
+            + " }\n"
+            + " }\n"
+            + " ]\n"
+            + "}";
+
+        // NOTE(review): both documents uploaded in setup() have weather "sunny"; the assertions below
+        // assume document 1 is returned as the first hit.
+        String query = "{\"query\":{\"term\":{\"weather\":{\"value\":\"sunny\"}}}}";
+
+        String index_name = "daily_index";
+        String pipelineName = "weather_embedding_pipeline";
+        createSearchPipelineProcessor(createPipelineRequestBody, pipelineName);
+
+        Map response = searchWithPipeline(client(), index_name, pipelineName, query);
+        // JUnit convention: expected value first, actual value second.
+        Assert.assertEquals("1536", JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"));
+        Assert.assertEquals("sunny", JsonPath.parse(response).read("$.hits.hits[0]._source.weather"));
+        Assert.assertEquals("happy", JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"));
+        Assert.assertEquals("first day at school", JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"));
+        List embeddingList = (List) JsonPath.parse(response).read("$.hits.hits[0]._source.weather_embedding");
+        Assert.assertEquals(1536, embeddingList.size());
+        Assert.assertEquals(0.00020525085, (Double) embeddingList.get(0), 0.005);
+        Assert.assertEquals(-0.0071890163, (Double) embeddingList.get(1), 0.005);
+    }
+
+    /**
+     * Tests the MLInferenceSearchResponseProcessor with the Bedrock remote embedding model.
+     * The pipeline maps the document field {@code weather} to the model input {@code input} and writes
+     * {@code $.embedding} from the model output to {@code weather_embedding}. The test searches
+     * through the pipeline and checks the original fields and the embedding on the first hit.
+     *
+     * @throws Exception if any error occurs during the test
+     */
+    public void testMLInferenceProcessorRemoteModelString() throws Exception {
+        // Skip test if any AWS credential is null: the connector body concatenates all three values,
+        // so a missing one would be sent as the literal text "null".
+        if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
+            return;
+        }
+
+        String createPipelineRequestBody = "{\n"
+            + " \"response_processors\": [\n"
+            + " {\n"
+            + " \"ml_inference\": {\n"
+            + " \"tag\": \"ml_inference\",\n"
+            + " \"description\": \"This processor is going to run ml inference during search request\",\n"
+            + " \"model_id\": \""
+            + this.bedrockEmbeddingModelId
+            + "\",\n"
+            + " \"input_map\": [\n"
+            + " {\n"
+            + " \"input\": \"weather\"\n"
+            + " }\n"
+            + " ],\n"
+            + " \"output_map\": [\n"
+            + " {\n"
+            + " \"weather_embedding\": \"$.embedding\"\n"
+            + " }\n"
+            + " ],\n"
+            + " \"ignore_missing\": false,\n"
+            + " \"ignore_failure\": false\n"
+            + " }\n"
+            + " }\n"
+            + " ]\n"
+            + "}";
+
+        // Only document 1 has diary_embedding_size "1536".
+        String query = "{\"query\":{\"term\":{\"diary_embedding_size\":{\"value\":\"1536\"}}}}";
+
+        String index_name = "daily_index";
+        String pipelineName = "weather_embedding_pipeline";
+        createSearchPipelineProcessor(createPipelineRequestBody, pipelineName);
+
+        Map response = searchWithPipeline(client(), index_name, pipelineName, query);
+        // JUnit convention: expected value first, actual value second.
+        Assert.assertEquals("1536", JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"));
+        Assert.assertEquals("sunny", JsonPath.parse(response).read("$.hits.hits[0]._source.weather"));
+        Assert.assertEquals("happy", JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"));
+        Assert.assertEquals("first day at school", JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"));
+        List embeddingList = (List) JsonPath.parse(response).read("$.hits.hits[0]._source.weather_embedding");
+        Assert.assertEquals(1536, embeddingList.size());
+        Assert.assertEquals(0.734375, (Double) embeddingList.get(0), 0.005);
+        Assert.assertEquals(0.87109375, (Double) embeddingList.get(1), 0.005);
+    }
+
+    /**
+     * Tests the MLInferenceSearchResponseProcessor with the OpenAI remote model and an element of a list field as input.
+     * The pipeline maps {@code diary[0]} to the model input {@code input} and writes
+     * {@code data[*].embedding} from the model output to {@code dairy_embedding}. The test searches
+     * through the pipeline and checks the original fields and the embedding on the first hit.
+     *
+     * @throws Exception if any error occurs during the test
+     */
+    public void testMLInferenceProcessorRemoteModelNestedListField() throws Exception {
+        // Skip test if key is null
+        if (OPENAI_KEY == null) {
+            return;
+        }
+        // The output field is spelled "dairy_embedding" here and in the assertion below; the two must match.
+        String createPipelineRequestBody = "{\n"
+            + " \"response_processors\": [\n"
+            + " {\n"
+            + " \"ml_inference\": {\n"
+            + " \"tag\": \"ml_inference\",\n"
+            + " \"description\": \"This processor is going to run ml inference during search request\",\n"
+            + " \"model_id\": \""
+            + this.openAIChatModelId
+            + "\",\n"
+            + " \"input_map\": [\n"
+            + " {\n"
+            + " \"input\": \"diary[0]\"\n"
+            + " }\n"
+            + " ],\n"
+            + " \"output_map\": [\n"
+            + " {\n"
+            + " \"dairy_embedding\": \"data[*].embedding\"\n"
+            + " }\n"
+            + " ],\n"
+            + " \"ignore_missing\": false,\n"
+            + " \"ignore_failure\": false\n"
+            + " }\n"
+            + " }\n"
+            + " ]\n"
+            + "}";
+
+        // NOTE(review): both documents uploaded in setup() have weather "sunny"; the assertions below
+        // assume document 1 is returned as the first hit.
+        String query = "{\"query\":{\"term\":{\"weather\":{\"value\":\"sunny\"}}}}";
+
+        String index_name = "daily_index";
+        String pipelineName = "diary_embedding_pipeline";
+        createSearchPipelineProcessor(createPipelineRequestBody, pipelineName);
+
+        Map response = searchWithPipeline(client(), index_name, pipelineName, query);
+        // JUnit convention: expected value first, actual value second.
+        Assert.assertEquals("1536", JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"));
+        Assert.assertEquals("sunny", JsonPath.parse(response).read("$.hits.hits[0]._source.weather"));
+        Assert.assertEquals("happy", JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"));
+        Assert.assertEquals("first day at school", JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"));
+        List embeddingList = (List) JsonPath.parse(response).read("$.hits.hits[0]._source.dairy_embedding");
+        Assert.assertEquals(1536, embeddingList.size());
+        Assert.assertEquals(-0.011842756, (Double) embeddingList.get(0), 0.005);
+        Assert.assertEquals(-0.012508746, (Double) embeddingList.get(1), 0.005);
+    }
+
+    /**
+     * Tests the ML inference processor with a local model.
+     * It registers, deploys, and gets a local model, creates a search pipeline with the ML inference processor
+     * configured to use the local model, and then performs a search using the pipeline.
+     * The test verifies that the first hit keeps its original fields and carries the embedding
+     * produced by the local model.
+     *
+     * @throws Exception if any error occurs during the test
+     */
+    public void testMLInferenceProcessorLocalModel() throws Exception {
+
+        // Register the local model, then deploy it from inside the task callback once its id is known.
+        String taskId = registerModel(TestHelper.toJsonString(registerModelInput()));
+        waitForTask(taskId, MLTaskState.COMPLETED);
+        getTask(client(), taskId, response -> {
+            assertNotNull(response.get(MODEL_ID_FIELD));
+            this.localModelId = (String) response.get(MODEL_ID_FIELD);
+            try {
+                String deployTaskID = deployModel(this.localModelId);
+                waitForTask(deployTaskID, MLTaskState.COMPLETED);
+
+                getModel(client(), this.localModelId, model -> { assertEquals("DEPLOYED", model.get("model_state")); });
+            } catch (IOException | InterruptedException e) {
+                // The callback cannot throw checked exceptions; rethrow with the cause preserved.
+                throw new RuntimeException(e);
+            }
+        });
+
+        String createPipelineRequestBody = "{\n"
+            + " \"response_processors\": [\n"
+            + " {\n"
+            + " \"ml_inference\": {\n"
+            + " \"tag\": \"ml_inference\",\n"
+            + " \"description\": \"This processor is going to run ml inference during search request\",\n"
+            + " \"model_id\": \""
+            + this.localModelId
+            + "\",\n"
+            + " \"model_input\": \"{ \\\"text_docs\\\": [\\\"${ml_inference.text_docs}\\\"] ,\\\"return_number\\\": true,\\\"target_response\\\": [\\\"sentence_embedding\\\"]}\",\n"
+            + " \"function_name\": \"text_embedding\",\n"
+            + " \"full_response_path\": true,\n"
+
+            + " \"input_map\": [\n"
+            + " {\n"
+            + " \"input\": \"weather\"\n"
+            + " }\n"
+            + " ],\n"
+            + " \"output_map\": [\n"
+            + " {\n"
+            + " \"weather_embedding\": \"$.inference_results[0].output[0].data\"\n"
+            + " }\n"
+            + " ],\n"
+            + " \"ignore_missing\": false,\n"
+            + " \"ignore_failure\": false\n"
+            + " }\n"
+            + " }\n"
+            + " ]\n"
+            + "}";
+
+        String index_name = "daily_index";
+        String pipelineName = "weather_embedding_pipeline_local";
+        createSearchPipelineProcessor(createPipelineRequestBody, pipelineName);
+
+        // Selects the document whose diary_embedding_size is "768".
+        String query = "{\"query\":{\"term\":{\"diary_embedding_size\":{\"value\":\"768\"}}}}";
+        Map response = searchWithPipeline(client(), index_name, pipelineName, query);
+        // JUnit convention: expected value first, actual value second.
+        Assert.assertEquals("768", JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"));
+        Assert.assertEquals("sunny", JsonPath.parse(response).read("$.hits.hits[0]._source.weather"));
+        Assert.assertEquals("bored", JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"));
+        Assert.assertEquals("at home", JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"));
+
+        List embeddingList = (List) JsonPath.parse(response).read("$.hits.hits[0]._source.weather_embedding");
+        Assert.assertEquals(768, embeddingList.size());
+        Assert.assertEquals(0.54809606, (Double) embeddingList.get(0), 0.005);
+        Assert.assertEquals(0.46797526, (Double) embeddingList.get(1), 0.005);
+
+    }
+
+    /**
+     * Creates a search pipeline with the given name by sending the request body to
+     * {@code PUT /_search/pipeline/{pipelineName}} and asserts that the response status is 200.
+     *
+     * @param requestBody the request body for creating the search pipeline processor
+     * @param pipelineName the name of the search pipeline
+     * @throws Exception if any error occurs during the creation of the search pipeline processor
+     */
+    protected void createSearchPipelineProcessor(String requestBody, final String pipelineName) throws Exception {
+        Response pipelineCreateResponse = TestHelper
+            .makeRequest(
+                client(),
+                "PUT",
+                "/_search/pipeline/" + pipelineName,
+                null,
+                requestBody,
+                ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
+            );
+        assertEquals(200, pipelineCreateResponse.getStatusLine().getStatusCode());
+
+    }
+
+    /**
+     * Creates an index with the given name and request body.
+     *
+     * @param indexName the name of the index
+     * @param requestBody the request body for creating the index
+     * @throws Exception if any error occurs during the creation of the index
+     */
+    protected void createIndex(String indexName, String requestBody) throws Exception {
+        // PUT the body to the index endpoint with an empty User-Agent header, then require HTTP 200.
+        final Response creation = makeRequest(
+            client(),
+            "PUT",
+            indexName,
+            null,
+            requestBody,
+            ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
+        );
+        final int httpStatus = creation.getStatusLine().getStatusCode();
+        assertEquals(200, httpStatus);
+    }
+
+    /**
+     * Indexes one document under the given id, using {@code refresh=true} on the request.
+     *
+     * @param index the name of the index
+     * @param docId the document ID
+     * @param jsonBody the JSON body of the document
+     * @throws IOException if an I/O error occurs during the upload
+     */
+    protected void uploadDocument(final String index, final String docId, final String jsonBody) throws IOException {
+        final String endpoint = "/" + index + "/_doc/" + docId + "?refresh=true";
+        final Request indexDocument = new Request("PUT", endpoint);
+        indexDocument.setJsonEntity(jsonBody);
+        client().performRequest(indexDocument);
+    }
+
+    /**
+     * Creates a MLRegisterModelInput instance with the specified configuration.
+     *
+     * @return the MLRegisterModelInput instance
+     * @throws IOException if an I/O error occurs during the creation of the input
+     * @throws InterruptedException if the thread is interrupted during the creation of the input
+     */
+    protected MLRegisterModelInput registerModelInput() throws IOException, InterruptedException {
+
+        // Text embedding configuration: 768-dimension sentence-transformers model of type "bert".
+        MLModelConfig modelConfig = TextEmbeddingModelConfig
+            .builder()
+            .modelType("bert")
+            .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
+            .embeddingDimension(768)
+            .build();
+        // deployModel(false): registration only; the local-model test deploys the model itself.
+        return MLRegisterModelInput
+            .builder()
+            .modelName("test_model_name")
+            .version("1.0.0")
+            .functionName(FunctionName.TEXT_EMBEDDING)
+            .modelFormat(MLModelFormat.TORCH_SCRIPT)
+            .modelConfig(modelConfig)
+            .url(SENTENCE_TRANSFORMER_MODEL_URL)
+            .deployModel(false)
+            .hashValue("e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021")
+            .build();
+    }
+
+}
diff --git a/plugin/src/test/java/org/opensearch/ml/utils/MapUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/utils/MapUtilsTests.java
new file mode 100644
index 0000000000..d133e00657
--- /dev/null
+++ b/plugin/src/test/java/org/opensearch/ml/utils/MapUtilsTests.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.ml.utils;
+
+import static org.junit.Assert.assertEquals;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.Test;
+
+/**
+ * Unit tests for the counter helpers in {@code MapUtils}.
+ * The assertions expect a versioned (version, key) counter to read 0 after its first increment,
+ * while a plain per-version integer counter reads 1 after its first increment.
+ */
+public class MapUtilsTests {
+
+    @Test
+    public void testIncrementCounterForVersionedCounters() {
+        Map<Integer, Map<String, Integer>> versionedCounters = new HashMap<>();
+
+        // Test incrementing counter for a new version and key
+        MapUtils.incrementCounter(versionedCounters, 0, "key1");
+        assertEquals(0, (int) versionedCounters.get(0).get("key1"));
+
+        // Test incrementing counter for an existing version and key
+        MapUtils.incrementCounter(versionedCounters, 0, "key1");
+        assertEquals(1, (int) versionedCounters.get(0).get("key1"));
+
+        // Test incrementing counter for a new key in an existing version
+        MapUtils.incrementCounter(versionedCounters, 0, "key2");
+        assertEquals(0, (int) versionedCounters.get(0).get("key2"));
+
+        // Test incrementing counter for a new version
+        MapUtils.incrementCounter(versionedCounters, 1, "key3");
+        assertEquals(0, (int) versionedCounters.get(1).get("key3"));
+    }
+
+    @Test
+    public void testIncrementCounterForIntegerCounters() {
+        Map<Integer, Integer> counters = new HashMap<>();
+
+        // Test incrementing counter for a new key
+        MapUtils.incrementCounter(counters, 1);
+        assertEquals(1, (int) counters.get(1));
+
+        // Test incrementing counter for an existing key
+        MapUtils.incrementCounter(counters, 1);
+        assertEquals(2, (int) counters.get(1));
+
+        // Test incrementing counter for another new key
+        MapUtils.incrementCounter(counters, 2);
+        assertEquals(1, (int) counters.get(2));
+    }
+
+    @Test
+    public void testGetCounterForVersionedCounters() {
+        Map<Integer, Map<String, Integer>> versionedCounters = new HashMap<>();
+        versionedCounters.put(0, new HashMap<>());
+        versionedCounters.put(1, new HashMap<>());
+        versionedCounters.get(0).put("key1", 5);
+        versionedCounters.get(1).put("key2", 10);
+
+        // Test getting counter for an existing key
+        assertEquals(5, MapUtils.getCounter(versionedCounters, 0, "key1"));
+        assertEquals(10, MapUtils.getCounter(versionedCounters, 1, "key2"));
+
+        // Test getting counter for a non-existing key: -1 is expected when the version exists
+        // but the key does not, and 0 when the version itself is absent
+        assertEquals(-1, MapUtils.getCounter(versionedCounters, 0, "key3"));
+        assertEquals(0, MapUtils.getCounter(versionedCounters, 2, "key4"));
+    }
+
+    @Test
+    public void testGetCounterForIntegerCounters() {
+        Map<Integer, Integer> counters = new HashMap<>();
+        counters.put(1, 5);
+        counters.put(2, 10);
+
+        // Test getting counter for an existing key
+        assertEquals(5, MapUtils.getCounter(counters, 1));
+        assertEquals(10, MapUtils.getCounter(counters, 2));
+
+        // Test getting counter for a non-existing key
+        assertEquals(0, MapUtils.getCounter(counters, 3));
+    }
+}