diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 169f5a8d3f..4c35e1c605 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -214,6 +214,7 @@ import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.processor.MLInferenceIngestProcessor; import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor; +import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.rest.RestMLCreateConnectorAction; import org.opensearch.ml.rest.RestMLCreateControllerAction; @@ -996,6 +997,12 @@ public Map> getResponseProces new GenerativeQAResponseProcessor.Factory(this.client, () -> this.ragSearchPipelineEnabled) ); + responseProcessors + .put( + MLInferenceSearchResponseProcessor.TYPE, + new MLInferenceSearchResponseProcessor.Factory(parameters.client, parameters.namedXContentRegistry) + ); + return responseProcessors; } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java new file mode 100644 index 0000000000..3b4f2e24ac --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java @@ -0,0 +1,674 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.processor; + +import static java.lang.Math.max; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP; +import static org.opensearch.ml.processor.MLInferenceIngestProcessor.OVERRIDE; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.client.Client; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.utils.MapUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.pipeline.AbstractProcessor; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +import com.jayway.jsonpath.Configuration; +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.Option; + +public class MLInferenceSearchResponseProcessor extends AbstractProcessor implements SearchResponseProcessor, ModelExecutor { + + private final NamedXContentRegistry xContentRegistry; + private static final Logger logger = LogManager.getLogger(MLInferenceSearchResponseProcessor.class); + private final InferenceProcessorAttributes inferenceProcessorAttributes; + private final boolean ignoreMissing; + private final String functionName; + private final boolean override; + private final boolean fullResponsePath; + private final boolean oneToOne; + private final boolean ignoreFailure; + private final String modelInput; + private static Client client; + public static final String TYPE = "ml_inference"; + // allow to ignore a field from mapping is not present in the query, and when the output field is not found in the + // prediction outcomes, return the whole prediction outcome by skipping filtering + public static final String IGNORE_MISSING = "ignore_missing"; + public static final String FUNCTION_NAME = "function_name"; + public static final String FULL_RESPONSE_PATH = "full_response_path"; + public static final String MODEL_INPUT = "model_input"; + public static final String ONE_TO_ONE = "one_to_one"; + public static final String DEFAULT_MODEL_INPUT = "{ \"parameters\": ${ml_inference.parameters} }"; + // At default, ml inference processor allows maximum 10 prediction tasks running in parallel + // it can be overwritten using max_prediction_tasks when creating processor + public static final int DEFAULT_MAX_PREDICTION_TASKS = 10; + public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results"; + + protected MLInferenceSearchResponseProcessor( + String modelId, + List> inputMaps, + List> outputMaps, + Map modelConfigMaps, + int maxPredictionTask, + String tag, + String description, + boolean ignoreMissing, + String functionName, + boolean fullResponsePath, + boolean ignoreFailure, + boolean override, + String modelInput, + Client client, + NamedXContentRegistry xContentRegistry, + boolean oneToOne + ) { + super(tag, description, ignoreFailure); + this.oneToOne = oneToOne; + this.inferenceProcessorAttributes = new InferenceProcessorAttributes( + modelId, + inputMaps, + outputMaps, + modelConfigMaps, + maxPredictionTask + ); + this.ignoreMissing = ignoreMissing; + this.functionName = functionName; + this.fullResponsePath = fullResponsePath; + this.ignoreFailure = ignoreFailure; + this.override = override; + this.modelInput = modelInput; + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { + throw new RuntimeException("ML inference search response processor make asynchronous calls and does not call processRequest"); + } + + /** + * Processes the search response asynchronously by rewriting the documents with the inference results. + * + * @param request the search request + * @param response the search response + * @param responseContext the pipeline processing context + * @param responseListener the listener to be notified when the response is processed + */ + @Override + public void processResponseAsync( + SearchRequest request, + SearchResponse response, + PipelineProcessingContext responseContext, + ActionListener responseListener + ) { + try { + SearchHit[] hits = response.getHits().getHits(); + // skip processing when there is no hit + if (hits.length == 0) { + responseListener.onResponse(response); + return; + } + rewriteResponseDocuments(response, responseListener); + } catch (Exception e) { + if (ignoreFailure) { + responseListener.onResponse(response); + } else { + responseListener.onFailure(e); + } + } + } + + /** + * Rewrite the documents in the search response with the inference results. + * + * @param response the search response + * @param responseListener the listener to be notified when the response is processed + * @throws IOException if an I/O error occurs during the rewriting process + */ + private void rewriteResponseDocuments(SearchResponse response, ActionListener responseListener) throws IOException { + List> processInputMap = inferenceProcessorAttributes.getInputMaps(); + List> processOutputMap = inferenceProcessorAttributes.getOutputMaps(); + int inputMapSize = (processInputMap == null) ? 0 : processInputMap.size(); + + // hitCountInPredictions keeps track of the count of hit that have the required input fields for each round of prediction + Map hitCountInPredictions = new HashMap<>(); + if (!oneToOne) { + ActionListener> rewriteResponseListener = createRewriteResponseListenerManyToOne( + response, + responseListener, + processInputMap, + processOutputMap, + hitCountInPredictions + ); + + GroupedActionListener> batchPredictionListener = createBatchPredictionListenerManyToOne( + rewriteResponseListener, + inputMapSize + ); + SearchHit[] hits = response.getHits().getHits(); + for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) { + processPredictionsManyToOne(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions); + } + } else { + responseListener.onFailure(new IllegalArgumentException("one to one prediction is not supported yet.")); + } + + } + + /** + * Processes the predictions for the given input map index. + * + * @param hits the search hits + * @param processInputMap the list of input mappings + * @param inputMapIndex the index of the input mapping to process + * @param batchPredictionListener the listener to be notified when the predictions are processed + * @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction + * @throws IOException if an I/O error occurs during the prediction process + */ + private void processPredictionsManyToOne( + SearchHit[] hits, + List> processInputMap, + int inputMapIndex, + GroupedActionListener> batchPredictionListener, + Map hitCountInPredictions + ) throws IOException { + + Map modelParameters = new HashMap<>(); + Map modelConfigs = new HashMap<>(); + + if (inferenceProcessorAttributes.getModelConfigMaps() != null) { + modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + } + + Map modelInputParameters = new HashMap<>(); + + Map inputMapping; + if (processInputMap != null && !processInputMap.isEmpty()) { + inputMapping = processInputMap.get(inputMapIndex); + + for (SearchHit hit : hits) { + Map document = hit.getSourceAsMap(); + boolean isModelInputMissing = checkIsModelInputMissing(document, inputMapping); + if (!isModelInputMissing) { + MapUtils.incrementCounter(hitCountInPredictions, inputMapIndex); + for (Map.Entry entry : inputMapping.entrySet()) { + // model field as key, document field name as value + String modelInputFieldName = entry.getKey(); + String documentFieldName = entry.getValue(); + + Object documentJson = JsonPath.parse(document).read("$"); + Configuration configuration = Configuration + .builder() + .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL) + .build(); + + Object documentValue = JsonPath.using(configuration).parse(documentJson).read(documentFieldName); + if (documentValue != null) { + // when not existed in the map, add into the modelInputParameters map + updateModelInputParametersManyToOne(modelInputParameters, modelInputFieldName, documentValue); + } + } + } else { // when document does not contain the documentFieldName, skip when ignoreMissing + if (!ignoreMissing) { + throw new IllegalArgumentException( + "cannot find all required input fields: " + inputMapping.values() + " in hit:" + hit + ); + } + } + } + } else { + for (SearchHit hit : hits) { + Map document = hit.getSourceAsMap(); + MapUtils.incrementCounter(hitCountInPredictions, inputMapIndex); + for (Map.Entry entry : document.entrySet()) { + // model field as key, document field name as value + String modelInputFieldName = entry.getKey(); + Object documentValue = entry.getValue(); + + // when not existed in the map, add into the modelInputParameters map + updateModelInputParametersManyToOne(modelInputParameters, modelInputFieldName, documentValue); + + } + } + } + + modelParameters = StringUtils.getParameterMap(modelInputParameters); + + Set inputMapKeys = new HashSet<>(modelParameters.keySet()); + inputMapKeys.removeAll(modelConfigs.keySet()); + + Map inputMappings = new HashMap<>(); + for (String k : inputMapKeys) { + inputMappings.put(k, modelParameters.get(k)); + } + + ActionRequest request = getMLModelInferenceRequest( + xContentRegistry, + modelParameters, + modelConfigs, + inputMappings, + inferenceProcessorAttributes.getModelId(), + functionName, + modelInput + ); + + client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() { + + @Override + public void onResponse(MLTaskResponse mlTaskResponse) { + MLOutput mlOutput = mlTaskResponse.getOutput(); + Map mlOutputMap = new HashMap<>(); + mlOutputMap.put(inputMapIndex, mlOutput); + batchPredictionListener.onResponse(mlOutputMap); + } + + @Override + public void onFailure(Exception e) { + batchPredictionListener.onFailure(e); + } + }); + } + + private void updateModelInputParametersManyToOne( + Map modelInputParameters, + String modelInputFieldName, + Object documentValue + ) { + if (!modelInputParameters.containsKey(modelInputFieldName)) { + List documentValueList = new ArrayList<>(); + documentValueList.add(documentValue); + modelInputParameters.put(modelInputFieldName, documentValueList); + } else { + List valueList = ((List) modelInputParameters.get(modelInputFieldName)); + valueList.add(documentValue); + } + } + + /** + * Creates a grouped action listener for batch predictions. + * + * @param rewriteResponseListener the listener to be notified when the response is rewritten + * @param inputMapSize the size of the input map + * @return a grouped action listener for batch predictions + */ + private GroupedActionListener> createBatchPredictionListenerManyToOne( + ActionListener> rewriteResponseListener, + int inputMapSize + ) { + return new GroupedActionListener<>(new ActionListener<>() { + @Override + public void onResponse(Collection> mlOutputMapCollection) { + Map mlOutputMaps = new HashMap<>(); + for (Map mlOutputMap : mlOutputMapCollection) { + mlOutputMaps.putAll(mlOutputMap); + } + rewriteResponseListener.onResponse(mlOutputMaps); + } + + @Override + public void onFailure(Exception e) { + logger.error("Prediction Failed:", e); + rewriteResponseListener.onFailure(e); + } + }, max(inputMapSize, 1)); + } + + /** + * Creates an action listener for rewriting the response with the inference results. + * + * @param response the search response + * @param responseListener the listener to be notified when the response is processed + * @param processInputMap the list of input mappings + * @param processOutputMap the list of output mappings + * @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction + * @return an action listener for rewriting the response with the inference results + */ + private ActionListener> createRewriteResponseListenerManyToOne( + SearchResponse response, + ActionListener responseListener, + List> processInputMap, + List> processOutputMap, + Map hitCountInPredictions + ) { + return new ActionListener<>() { + @Override + public void onResponse(Map multipleMLOutputs) { + try { + Map> writeOutputMapDocCounter = new HashMap<>(); + + for (SearchHit hit : response.getHits().getHits()) { + Map sourceAsMapWithInference = new HashMap<>(); + if (hit.hasSource()) { + BytesReference sourceRef = hit.getSourceRef(); + Tuple> typeAndSourceMap = XContentHelper + .convertToMap(sourceRef, false, (MediaType) null); + + Map sourceAsMap = typeAndSourceMap.v2(); + sourceAsMapWithInference.putAll(sourceAsMap); + Map document = hit.getSourceAsMap(); + + for (Map.Entry entry : multipleMLOutputs.entrySet()) { + Integer mappingIndex = entry.getKey(); + MLOutput mlOutput = entry.getValue(); + + Map inputMapping = getDefaultInputMapping(sourceAsMap, mappingIndex, processInputMap); + Map outputMapping = getDefaultOutputMapping(mappingIndex, processOutputMap); + + boolean isModelInputMissing = false; + if (processInputMap != null) { + isModelInputMissing = checkIsModelInputMissing(document, inputMapping); + } + if (!isModelInputMissing) { + // Iterate over outputMapping + for (Map.Entry outputMapEntry : outputMapping.entrySet()) { + + String newDocumentFieldName = outputMapEntry.getKey(); + String modelOutputFieldName = outputMapEntry.getValue(); + + MapUtils.incrementCounter(writeOutputMapDocCounter, mappingIndex, modelOutputFieldName); + + Object modelOutputValue = getModelOutputValue( + mlOutput, + modelOutputFieldName, + ignoreMissing, + fullResponsePath + ); + Object modelOutputValuePerDoc; + if (modelOutputValue instanceof List + && ((List) modelOutputValue).size() == hitCountInPredictions.get(mappingIndex)) { + Object valuePerDoc = ((List) modelOutputValue) + .get(MapUtils.getCounter(writeOutputMapDocCounter, mappingIndex, modelOutputFieldName)); + modelOutputValuePerDoc = valuePerDoc; + } else { + modelOutputValuePerDoc = modelOutputValue; + } + + if (sourceAsMap.containsKey(newDocumentFieldName)) { + if (override) { + sourceAsMapWithInference.remove(newDocumentFieldName); + sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); + } else { + logger + .debug( + "{} already exists in the search response hit. Skip processing this field.", + newDocumentFieldName + ); + // TODO when the response has the same field name, should it throw exception? currently, + // ingest processor quietly skip it + } + } else { + sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); + } + } + } + } + XContentBuilder builder = XContentBuilder.builder(typeAndSourceMap.v1().xContent()); + builder.map(sourceAsMapWithInference); + hit.sourceRef(BytesReference.bytes(builder)); + + } + } + } catch (Exception e) { + if (ignoreFailure) { + responseListener.onResponse(response); + + } else { + responseListener.onFailure(e); + } + } + responseListener.onResponse(response); + } + + @Override + public void onFailure(Exception e) { + if (ignoreFailure) { + logger.error("Failed in writing prediction outcomes to search response", e); + responseListener.onResponse(response); + + } else { + responseListener.onFailure(e); + } + } + }; + } + + /** + * Checks if the document is missing any of the required input fields specified in the input mapping. + * + * @param document the document map + * @param inputMapping the input mapping + * @return true if the document is missing any of the required input fields, false otherwise + */ + private boolean checkIsModelInputMissing(Map document, Map inputMapping) { + boolean isModelInputMissing = false; + + for (Map.Entry inputMapEntry : inputMapping.entrySet()) { + String oldDocumentFieldName = inputMapEntry.getValue(); + boolean checkSingleModelInputPresent = hasField(document, oldDocumentFieldName); + if (!checkSingleModelInputPresent) { + isModelInputMissing = true; + break; + } + } + return isModelInputMissing; + } + + /** + * Retrieves the default output mapping for a given mapping index. + * + *

If the provided processOutputMap is null or empty, a new HashMap is created with a default + * output field name mapped to a JsonPath expression representing the root object ($) followed by + * the default output field name. + * + *

If the processOutputMap is not null and not empty, the mapping at the specified mappingIndex + * is returned. + * + * @param mappingIndex the index of the mapping to retrieve from the processOutputMap + * @param processOutputMap the list of output mappings, can be null or empty + * @return a Map containing the output mapping, either the default mapping or the mapping at the + * specified index + */ + private static Map getDefaultOutputMapping(Integer mappingIndex, List> processOutputMap) { + Map outputMapping; + if (processOutputMap == null || processOutputMap.size() == 0) { + outputMapping = new HashMap<>(); + outputMapping.put(DEFAULT_OUTPUT_FIELD_NAME, "$." + DEFAULT_OUTPUT_FIELD_NAME); + } else { + outputMapping = processOutputMap.get(mappingIndex); + } + return outputMapping; + } + + /** + * Retrieves the default input mapping for a given mapping index and source map. + * + *

If the provided processInputMap is null or empty, a new HashMap is created by extracting + * key-value pairs from the sourceAsMap using StringUtils.getParameterMap(). + * + *

If the processInputMap is not null and not empty, the mapping at the specified mappingIndex + * is returned. + * + * @param sourceAsMap the source map containing the input data + * @param mappingIndex the index of the mapping to retrieve from the processInputMap + * @param processInputMap the list of input mappings, can be null or empty + * @return a Map containing the input mapping, either the mapping extracted from sourceAsMap or + * the mapping at the specified index + */ + private static Map getDefaultInputMapping( + Map sourceAsMap, + Integer mappingIndex, + List> processInputMap + ) { + Map inputMapping; + + if (processInputMap == null || processInputMap.size() == 0) { + inputMapping = new HashMap<>(); + inputMapping.putAll(StringUtils.getParameterMap(sourceAsMap)); + } else { + inputMapping = processInputMap.get(mappingIndex); + } + return inputMapping; + } + + /** + * Returns the type of the processor. + * + * @return the type of the processor as a string + */ + @Override + public String getType() { + return TYPE; + } + + /** + * A factory class for creating instances of the MLInferenceSearchResponseProcessor. + * This class implements the Processor.Factory interface for creating SearchResponseProcessor instances. + */ + public static class Factory implements Processor.Factory { + private final Client client; + private final NamedXContentRegistry xContentRegistry; + + /** + * Constructs a new instance of the Factory class. + * + * @param client the Client instance to be used by the Factory + * @param xContentRegistry the xContentRegistry instance to be used by the Factory + */ + public Factory(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + /** + * Creates a new instance of the MLInferenceSearchResponseProcessor. + * + * @param processorFactories a map of processor factories + * @param processorTag the tag of the processor + * @param description the description of the processor + * @param ignoreFailure a flag indicating whether to ignore failures or not + * @param config the configuration map for the processor + * @param pipelineContext the pipeline context + * @return a new instance of the MLInferenceSearchResponseProcessor + */ + @Override + public MLInferenceSearchResponseProcessor create( + Map> processorFactories, + String processorTag, + String description, + boolean ignoreFailure, + Map config, + PipelineContext pipelineContext + ) { + String modelId = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, MODEL_ID); + Map modelConfigInput = ConfigurationUtils.readOptionalMap(TYPE, processorTag, config, MODEL_CONFIG); + + List> inputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, INPUT_MAP); + List> outputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, OUTPUT_MAP); + int maxPredictionTask = ConfigurationUtils + .readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS); + boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false); + String functionName = ConfigurationUtils + .readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name()); + boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false); + boolean oneToOne = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, ONE_TO_ONE, false); + + String modelInput = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, MODEL_INPUT); + + // if model input is not provided for remote models, use default value + if (functionName.equalsIgnoreCase("remote")) { + modelInput = (modelInput != null) ? modelInput : DEFAULT_MODEL_INPUT; + } else if (modelInput == null) { + // if model input is not provided for local models, throw exception since it is mandatory here + throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor"); + } + boolean defaultFullResponsePath = !functionName.equalsIgnoreCase(FunctionName.REMOTE.name()); + boolean fullResponsePath = ConfigurationUtils + .readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultFullResponsePath); + + ignoreFailure = ConfigurationUtils + .readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false); + + // convert model config user input data structure to Map + Map modelConfigMaps = null; + if (modelConfigInput != null) { + modelConfigMaps = StringUtils.getParameterMap(modelConfigInput); + } + // check if the number of prediction tasks exceeds max prediction tasks + if (inputMaps != null && inputMaps.size() > maxPredictionTask) { + throw new IllegalArgumentException( + "The number of prediction task setting in this process is " + + inputMaps.size() + + ". It exceeds the max_prediction_tasks of " + + maxPredictionTask + + ". Please reduce the size of input_map or increase max_prediction_tasks." + ); + } + if (outputMaps != null && inputMaps != null && outputMaps.size() != inputMaps.size()) { + throw new IllegalArgumentException( + "when output_maps and input_maps are provided, their length needs to match. The input_maps is in length of " + + inputMaps.size() + + ", while output_maps is in the length of " + + outputMaps.size() + + ". Please adjust mappings." + ); + } + + return new MLInferenceSearchResponseProcessor( + modelId, + inputMaps, + outputMaps, + modelConfigMaps, + maxPredictionTask, + processorTag, + description, + ignoreMissing, + functionName, + fullResponsePath, + ignoreFailure, + override, + modelInput, + client, + xContentRegistry, + oneToOne + ); + } + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index b922cd8819..cf17afd904 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -281,6 +281,15 @@ default String toString(Object originalFieldValue) { return StringUtils.toJson(originalFieldValue); } + default boolean hasField(Object json, String path) { + Object value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path); + + if (value != null) { + return true; + } + return false; + } + /** * Writes a new dot path for a nested object within the given JSON object. * This method is useful when dealing with arrays or nested objects in the JSON structure. @@ -321,5 +330,4 @@ default List writeNewDotPathForNestedObject(Object json, String dotPath) default String convertToDotPath(String path) { return path.replaceAll("\\[(\\d+)\\]", "$1\\.").replaceAll("\\['(.*?)']", "$1\\.").replaceAll("^\\$", "").replaceAll("\\.$", ""); } - } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MapUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MapUtils.java new file mode 100644 index 0000000000..bc1deb085a --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/utils/MapUtils.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.utils; + +import java.util.HashMap; +import java.util.Map; + +public class MapUtils { + + /** + * Increments the counter for the given key in the specified version. + * If the key doesn't exist, it initializes the counter to 0. + * + * @param version the version of the counter + * @param key the key for which the counter needs to be incremented + */ + public static void incrementCounter(Map> versionedCounters, int version, String key) { + Map counters = versionedCounters.computeIfAbsent(version, k -> new HashMap<>()); + counters.put(key, counters.getOrDefault(key, -1) + 1); + } + + /** + * Retrieves the counter value for the given key in the specified version. + * If the key doesn't exist, it returns 0. + * + * @param version the version of the counter + * @param key the key for which the counter needs to be retrieved + * @return the counter value for the given key + */ + public static int getCounter(Map> versionedCounters, int version, String key) { + Map counters = versionedCounters.get(version); + return counters != null ? counters.getOrDefault(key, -1) : 0; + } + + /** + * Increments the counter value for the given key in the provided counters map. + * If the key does not exist in the map, it is added with an initial counter value of 0. + * + * @param counters A map that stores integer counters for each integer key. + * @param key The integer key for which the counter needs to be incremented. + */ + public static void incrementCounter(Map counters, int key) { + counters.put(key, counters.getOrDefault(key, 0) + 1); + } + + public static int getCounter(Map counters, int key) { + return counters.getOrDefault(key, 0); + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/utils/SearchResponseUtil.java b/plugin/src/main/java/org/opensearch/ml/utils/SearchResponseUtil.java new file mode 100644 index 0000000000..44f7ad294b --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/utils/SearchResponseUtil.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.utils; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.profile.SearchProfileShardResults; + +public class SearchResponseUtil { + private SearchResponseUtil() {} + + /** + * Construct a new {@link SearchResponse} based on an existing one, replacing just the {@link SearchHits}. + * @param newHits new {@link SearchHits} + * @param response the existing search response + * @return a new search response where the {@link SearchHits} has been replaced + */ + public static SearchResponse replaceHits(SearchHits newHits, SearchResponse response) { + SearchResponseSections searchResponseSections; + if (response.getAggregations() == null || response.getAggregations() instanceof InternalAggregations) { + // We either have no aggregations, or we have Writeable InternalAggregations. + // Either way, we can produce a Writeable InternalSearchResponse. + searchResponseSections = new InternalSearchResponse( + newHits, + (InternalAggregations) response.getAggregations(), + response.getSuggest(), + new SearchProfileShardResults(response.getProfileResults()), + response.isTimedOut(), + response.isTerminatedEarly(), + response.getNumReducePhases() + ); + } else { + // We have non-Writeable Aggregations, so the whole SearchResponseSections is non-Writeable. + searchResponseSections = new SearchResponseSections( + newHits, + response.getAggregations(), + response.getSuggest(), + response.isTimedOut(), + response.isTerminatedEarly(), + new SearchProfileShardResults(response.getProfileResults()), + response.getNumReducePhases() + ); + } + + return new SearchResponse( + searchResponseSections, + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), + response.getSkippedShards(), + response.getTook().millis(), + response.getShardFailures(), + response.getClusters(), + response.pointInTimeId() + ); + } + + /** + * Convenience method when only replacing the {@link SearchHit} array within the {@link SearchHits} in a {@link SearchResponse}. + * @param newHits the new array of {@link SearchHit} elements. + * @param response the search response to update + * @return a {@link SearchResponse} where the underlying array of {@link SearchHit} within the {@link SearchHits} has been replaced. + */ + public static SearchResponse replaceHits(SearchHit[] newHits, SearchResponse response) { + if (response.getHits() == null) { + throw new IllegalStateException("Response must have hits"); + } + SearchHits searchHits = new SearchHits( + newHits, + response.getHits().getTotalHits(), + response.getHits().getMaxScore(), + response.getHits().getSortFields(), + response.getHits().getCollapseField(), + response.getHits().getCollapseValues() + ); + return replaceHits(searchHits, response); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java index ab6ab07739..6da9cb406a 100644 --- a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java +++ b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java @@ -39,6 +39,7 @@ import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.engine.tools.MLModelTool; import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor; +import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor; import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; @@ -85,10 +86,11 @@ public void testGetRequestProcessors() { public void testGetResponseProcessors() { SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); Map responseProcessors = plugin.getResponseProcessors(parameters); - assertEquals(1, responseProcessors.size()); + assertEquals(2, responseProcessors.size()); assertTrue( responseProcessors.get(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE) instanceof GenerativeQAResponseProcessor.Factory ); + assertTrue(responseProcessors.get(MLInferenceSearchResponseProcessor.TYPE) instanceof MLInferenceSearchResponseProcessor.Factory); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java new file mode 100644 index 0000000000..7d38597751 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -0,0 +1,1391 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.processor; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP; +import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.client.Client; +import org.opensearch.common.document.DocumentField; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchModule; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.test.AbstractBuilderTestCase; + +public class MLInferenceSearchResponseProcessorTests extends AbstractBuilderTestCase { + @Mock + private Client client; + + @Mock + private PipelineProcessingContext responseContext; + + static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() + ); + private static final String PROCESSOR_TAG = "inference"; + private static final String DESCRIPTION = "inference_test"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + } + + /** + * Tests that an exception is thrown when the `processResponse` method is called, as this processor + * makes asynchronous calls. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseException() throws Exception { + + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + null, + null, + null, + null, + false, + false, + false + ); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + try { + responseProcessor.processResponse(request, response); + + } catch (Exception e) { + assertEquals("ML inference search response processor make asynchronous calls and does not call processRequest", e.getMessage()); + } + } + + /** + * Tests the successful processing of a response with a single pair of input and output mappings. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseSuccess() throws Exception { + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + false, + false, + false + ); + + assertEquals(responseProcessor.getType(), TYPE); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), 0.0); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), 1.0); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get("text_embedding"), 2.0); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get("text_embedding"), 3.0); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get("text_embedding"), 4.0); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests create processor with many_to_one is false + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneToOneException() throws Exception { + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + assertEquals("one to one prediction is not supported yet.", e.getMessage()); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + + } + + /** + * Tests the successful processing of a response without any input-output mappings. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseNoMappingSuccess() throws Exception { + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals( + newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the successful processing of a response without any input-output mappings. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseEmptyMappingSuccess() throws Exception { + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + inputMap.add(input); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals( + newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + assertEquals( + newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(DEFAULT_OUTPUT_FIELD_NAME).toString(), + "[{output=[{dataAsMap={response=[0.0, 1.0, 2.0, 3.0, 4.0]}}]}]" + ); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the successful processing of a response with a list of embeddings as the output. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseListOfEmbeddingsSuccess() throws Exception { + /** + * sample response before inference + * { + * { "text" : "value 0" }, + * { "text" : "value 1" }, + * { "text" : "value 2" }, + * { "text" : "value 3" }, + * { "text" : "value 4" } + * } + * + * sample response after inference + * { "text" : "value 0", "text_embedding":[0.1, 0.2]}, + * { "text" : "value 1", "text_embedding":[0.2, 0.2]}, + * { "text" : "value 2", "text_embedding":[0.3, 0.2]}, + * { "text" : "value 3","text_embedding":[0.4, 0.2]}, + * { "text" : "value 4","text_embedding":[0.5, 0.2]} + */ + + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + false, + false, + false + ); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap( + ImmutableMap + .of( + "response", + Arrays + .asList( + Arrays.asList(0.1, 0.2), + Arrays.asList(0.2, 0.2), + Arrays.asList(0.3, 0.2), + Arrays.asList(0.4, 0.2), + Arrays.asList(0.5, 0.2) + ) + ) + ) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), Arrays.asList(0.1, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), Arrays.asList(0.2, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get("text_embedding"), Arrays.asList(0.3, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get("text_embedding"), Arrays.asList(0.4, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get("text_embedding"), Arrays.asList(0.5, 0.2)); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the successful processing of a response where the existing document field is overridden. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOverrideSameField() throws Exception { + /** + * sample response before inference + * { + * { "text" : "value 0" }, + * { "text" : "value 1" }, + * { "text" : "value 2" }, + * { "text" : "value 3" }, + * { "text" : "value 4" } + * } + * + * sample response after inference + * { "text":[0.1, 0.2]}, + * { "text":[0.2, 0.2]}, + * { "text":[0.3, 0.2]}, + * { "text":[0.4, 0.2]}, + * { "text":[0.5, 0.2]} + */ + + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text"; + String modelOutputField = "response"; + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + true, + false, + false + ); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap( + ImmutableMap + .of( + "response", + Arrays + .asList( + Arrays.asList(0.1, 0.2), + Arrays.asList(0.2, 0.2), + Arrays.asList(0.3, 0.2), + Arrays.asList(0.4, 0.2), + Arrays.asList(0.5, 0.2) + ) + ) + ) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text"), Arrays.asList(0.1, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text"), Arrays.asList(0.2, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get("text"), Arrays.asList(0.3, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get("text"), Arrays.asList(0.4, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get("text"), Arrays.asList(0.5, 0.2)); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the successful processing of a response where one input field is missing, + * and the `ignoreMissing` flag is set to true. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseListOfEmbeddingsMissingOneInputIgnoreMissingSuccess() throws Exception { + /** + * sample response before inference + * { + * { "text" : "value 0" }, + * { "text" : "value 1" }, + * { "textMissing" : "value 2" }, + * { "text" : "value 3" }, + * { "text" : "value 4" } + * } + * + * sample response after inference + * { "text" : "value 0", "text_embedding":[0.1, 0.2]}, + * { "text" : "value 1", "text_embedding":[0.2, 0.2]}, + * { "textMissing" : "value 2"}, + * { "text" : "value 3","text_embedding":[0.4, 0.2]}, + * { "text" : "value 4","text_embedding":[0.5, 0.2]} + */ + + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + false, + false, + true + ); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponseMissingField(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap( + ImmutableMap + .of( + "response", + Arrays.asList(Arrays.asList(0.1, 0.2), Arrays.asList(0.2, 0.2), Arrays.asList(0.4, 0.2), Arrays.asList(0.5, 0.2)) + ) + ) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), Arrays.asList(0.1, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), Arrays.asList(0.2, 0.2)); + + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get("text_embedding"), Arrays.asList(0.4, 0.2)); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get("text_embedding"), Arrays.asList(0.5, 0.2)); + + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the case where one input field is missing, and an exception is expected + * when the `ignoreMissing` flag is set to false. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseListOfEmbeddingsMissingOneInputException() throws Exception { + /** + * sample response before inference + * { + * { "text" : "value 0" }, + * { "text" : "value 1" }, + * { "textMissing" : "value 2" }, + * { "text" : "value 3" }, + * { "text" : "value 4" } + * } + * + * sample response after inference + * { "text" : "value 0", "text_embedding":[0.1, 0.2]}, + * { "text" : "value 1", "text_embedding":[0.2, 0.2]}, + * { "textMissing" : "value 2"}, + * { "text" : "value 3","text_embedding":[0.4, 0.2]}, + * { "text" : "value 4","text_embedding":[0.5, 0.2]} + */ + + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + false, + false, + false + ); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponseMissingField(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap( + ImmutableMap + .of( + "response", + Arrays.asList(Arrays.asList(0.1, 0.2), Arrays.asList(0.2, 0.2), Arrays.asList(0.4, 0.2), Arrays.asList(0.5, 0.2)) + ) + ) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + assertEquals( + "cannot find all required input fields: [text] in hit:{\n" + + " \"_id\" : \"doc 2\",\n" + + " \"_score\" : 2.0,\n" + + " \"_source\" : {\n" + + " \"textMissing\" : \"value 2\"\n" + + " }\n" + + "}", + e.getMessage() + ); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the successful processing of a response with two rounds of prediction. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseTwoRoundsOfPredictionSuccess() throws Exception { + String modelInputField = "inputs"; + String modelOutputField = "response"; + + // document fields for first round of prediction + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + + // document fields for second round of prediction + String originalDocumentField1 = "image"; + String newDocumentField1 = "image_embedding"; + + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalDocumentField); + inputMap.add(input); + + Map input1 = new HashMap<>(); + input1.put(modelInputField, originalDocumentField1); + inputMap.add(input1); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + + Map output2 = new HashMap<>(); + output2.put(newDocumentField1, modelOutputField); + outputMap.add(output2); + + Map modelConfig = new HashMap<>(); + modelConfig.put("model_task_type", "TEXT_EMBEDDING"); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + SearchRequest request = getSearchRequest(); + SearchResponse response = getSearchResponseTwoFields(5, true, originalDocumentField, originalDocumentField1); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField), 0.0); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField), 1.0); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField), 2.0); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField), 3.0); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField), 4.0); + + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField1), 0.0); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField1), 1.0); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField1), 2.0); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField1), 3.0); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField1), 4.0); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the successful processing of a response with one model input and multiple model outputs. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseOneModelInputMultipleModelOutputs() throws Exception { + // one model input + String modelInputField = "inputs"; + String originalDocumentField = "text"; + + // two model outputs + String modelOutputField = "response"; + String newDocumentField = "text_embedding"; + String modelOutputField1 = "response_type"; + String newDocumentField1 = "embedding_type"; + + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalDocumentField); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + output.put(newDocumentField1, modelOutputField1); + outputMap.add(output); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + SearchRequest request = getSearchRequest(); + SearchResponse response = getSearchResponse(5, true, originalDocumentField); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of(modelOutputField, Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0), "response_type", "embedding_float")) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField), 0.0); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField), 1.0); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField), 2.0); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField), 3.0); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField), 4.0); + + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField1), "embedding_float"); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField1), "embedding_float"); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField1), "embedding_float"); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField1), "embedding_float"); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField1), "embedding_float"); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the case where an exception occurs during prediction, and the `ignoreFailure` flag is set to false. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponsePredictionException() throws Exception { + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Prediction Failed")); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly."); + } + + @Override + public void onFailure(Exception e) { + assertEquals("Prediction Failed", e.getMessage()); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the case where an exception occurs during prediction, but the `ignoreFailure` flag is set to true. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponsePredictionExceptionIgnoreFailure() throws Exception { + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + true, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Prediction Failed")); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(response, newSearchResponse); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("error handling not properly."); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + /** + * Tests the case where there are no hits in the search response. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseEmptyHit() throws Exception { + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + null, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(0, true, fieldName); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(response, newSearchResponse); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + } + + private static SearchRequest getSearchRequest() { + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + return request; + } + + /** + * Helper method to create an instance of the MLInferenceSearchResponseProcessor with the specified parameters in + * single pair of input and output mapping. + * + * @param modelInputField the model input field name + * @param originalDocumentField the original query field name + * @param newDocumentField the new document field name + * @param override the flag indicating whether to override existing document field + * @param ignoreFailure the flag indicating whether to ignore failures or not + * @param ignoreMissing the flag indicating whether to ignore missing fields or not + * @return an instance of the MLInferenceSearchResponseProcessor + */ + private MLInferenceSearchResponseProcessor getMlInferenceSearchResponseProcessorSinglePairMapping( + String modelOutputField, + String modelInputField, + String originalDocumentField, + String newDocumentField, + boolean override, + boolean ignoreFailure, + boolean ignoreMissing + ) { + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalDocumentField); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map model_config = new HashMap<>(); + model_config.put("truncate_result", "false"); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + model_config, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + ignoreMissing, + "remote", + false, + ignoreFailure, + override, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + return responseProcessor; + } + + private SearchResponse getSearchResponse(int size, boolean includeMapping, String fieldName) { + SearchHit[] hits = new SearchHit[size]; + for (int i = 0; i < size; i++) { + Map searchHitFields = new HashMap<>(); + if (includeMapping) { + searchHitFields.put(fieldName, new DocumentField("value " + i, Collections.emptyList())); + } + searchHitFields.put(fieldName, new DocumentField("value " + i, Collections.emptyList())); + hits[i] = new SearchHit(i, "doc " + i, searchHitFields, Collections.emptyMap()); + hits[i].sourceRef(new BytesArray("{ \"" + fieldName + "\" : \"value " + i + "\" }")); + hits[i].score(i); + } + SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2L, TotalHits.Relation.EQUAL_TO), size); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + return new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); + } + + private SearchResponse getSearchResponseMissingField(int size, boolean includeMapping, String fieldName) { + SearchHit[] hits = new SearchHit[size]; + for (int i = 0; i < size; i++) { + + if (i == (size % 3)) { + Map searchHitFields = new HashMap<>(); + if (includeMapping) { + searchHitFields.put(fieldName + "Missing", new DocumentField("value " + i, Collections.emptyList())); + } + searchHitFields.put(fieldName + "Missing", new DocumentField("value " + i, Collections.emptyList())); + hits[i] = new SearchHit(i, "doc " + i, searchHitFields, Collections.emptyMap()); + hits[i].sourceRef(new BytesArray("{ \"" + fieldName + "Missing" + "\" : \"value " + i + "\" }")); + hits[i].score(i); + } else { + Map searchHitFields = new HashMap<>(); + if (includeMapping) { + searchHitFields.put(fieldName, new DocumentField("value " + i, Collections.emptyList())); + } + searchHitFields.put(fieldName, new DocumentField("value " + i, Collections.emptyList())); + hits[i] = new SearchHit(i, "doc " + i, searchHitFields, Collections.emptyMap()); + hits[i].sourceRef(new BytesArray("{ \"" + fieldName + "\" : \"value " + i + "\" }")); + hits[i].score(i); + } + } + SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2L, TotalHits.Relation.EQUAL_TO), size); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + return new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); + } + + private SearchResponse getSearchResponseTwoFields(int size, boolean includeMapping, String fieldName, String fieldName1) { + SearchHit[] hits = new SearchHit[size]; + for (int i = 0; i < size; i++) { + Map searchHitFields = new HashMap<>(); + if (includeMapping) { + searchHitFields.put(fieldName, new DocumentField("value " + i, Collections.emptyList())); + searchHitFields.put(fieldName1, new DocumentField("value " + i, Collections.emptyList())); + } + searchHitFields.put(fieldName, new DocumentField("value " + i, Collections.emptyList())); + searchHitFields.put(fieldName1, new DocumentField("value " + i, Collections.emptyList())); + hits[i] = new SearchHit(i, "doc " + i, searchHitFields, Collections.emptyMap()); + hits[i] + .sourceRef( + new BytesArray( + "{ \"" + fieldName + "\" : \"value " + i + "\", " + "\"" + fieldName1 + "\" : \"value " + i + "\" " + " }" + ) + ); + hits[i].score(i); + } + SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2L, TotalHits.Relation.EQUAL_TO), size); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + return new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); + } + + private MLInferenceSearchResponseProcessor.Factory factory; + + @Mock + private NamedXContentRegistry xContentRegistry; + + @Before + public void init() { + factory = new MLInferenceSearchResponseProcessor.Factory(client, xContentRegistry); + } + + /** + * Tests the creation of the MLInferenceSearchResponseProcessor with required fields. + * + * @throws Exception if an error occurs during the test + */ + public void testCreateRequiredFields() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + String processorTag = randomAlphaOfLength(10); + MLInferenceSearchResponseProcessor MLInferenceSearchResponseProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + assertNotNull(MLInferenceSearchResponseProcessor); + assertEquals(MLInferenceSearchResponseProcessor.getTag(), processorTag); + assertEquals(MLInferenceSearchResponseProcessor.getType(), MLInferenceSearchResponseProcessor.TYPE); + } + + /** + * Tests the creation of the MLInferenceSearchResponseProcessor for a local model. + * + * @throws Exception if an error occurs during the test + */ + public void testCreateLocalModelProcessor() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + config.put(FUNCTION_NAME, "text_embedding"); + config.put(FULL_RESPONSE_PATH, true); + config.put(MODEL_INPUT, "{ \"text_docs\": ${ml_inference.text_docs} }"); + Map model_config = new HashMap<>(); + model_config.put("return_number", true); + config.put(MODEL_CONFIG, model_config); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 5); + String processorTag = randomAlphaOfLength(10); + MLInferenceSearchResponseProcessor MLInferenceSearchResponseProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + assertNotNull(MLInferenceSearchResponseProcessor); + assertEquals(MLInferenceSearchResponseProcessor.getTag(), processorTag); + assertEquals(MLInferenceSearchResponseProcessor.getType(), MLInferenceSearchResponseProcessor.TYPE); + } + + /** + * The model input field is required for using a local model + * when missing the model input field, expected to throw Exceptions + * + * @throws Exception if an error occurs during the test + */ + public void testCreateLocalModelProcessorMissingModelInputField() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + config.put(FUNCTION_NAME, "text_embedding"); + config.put(FULL_RESPONSE_PATH, true); + Map model_config = new HashMap<>(); + model_config.put("return_number", true); + config.put(MODEL_CONFIG, model_config); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 5); + String processorTag = randomAlphaOfLength(10); + try { + MLInferenceSearchResponseProcessor MLInferenceSearchResponseProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + assertNotNull(MLInferenceSearchResponseProcessor); + } catch (Exception e) { + assertEquals(e.getMessage(), "Please provide model input when using a local model in ML Inference Processor"); + } + } + + /** + * Tests the case where the `model_id` field is missing in the configuration, and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testCreateNoFieldPresent() throws Exception { + Map config = new HashMap<>(); + try { + factory.create(Collections.emptyMap(), "no field", null, false, config, null); + fail("factory create should have failed"); + } catch (OpenSearchParseException e) { + assertEquals(e.getMessage(), ("[model_id] required property is missing")); + } + } + + /** + * Tests the case where the number of prediction tasks exceeds the maximum allowed value, and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testExceedMaxPredictionTasks() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model2"); + List> inputMap = new ArrayList<>(); + Map input0 = new HashMap<>(); + input0.put("inputs", "text"); + inputMap.add(input0); + Map input1 = new HashMap<>(); + input1.put("inputs", "hashtag"); + inputMap.add(input1); + Map input2 = new HashMap<>(); + input2.put("inputs", "timestamp"); + inputMap.add(input2); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 2); + String processorTag = randomAlphaOfLength(10); + + try { + factory.create(Collections.emptyMap(), processorTag, null, false, config, null); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + ("The number of prediction task setting in this process is 3. It exceeds the max_prediction_tasks of 2. Please reduce the size of input_map or increase max_prediction_tasks.") + ); + } + } + + /** + * Tests the case where the length of the `output_map` list is greater than the length of the `input_map` list, + * and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testOutputMapsExceedInputMaps() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model2"); + List> inputMap = new ArrayList<>(); + Map input0 = new HashMap<>(); + input0.put("inputs", "text"); + inputMap.add(input0); + Map input1 = new HashMap<>(); + input1.put("inputs", "hashtag"); + inputMap.add(input1); + config.put(INPUT_MAP, inputMap); + List> outputMap = new ArrayList<>(); + Map output1 = new HashMap<>(); + output1.put("text_embedding", "response"); + outputMap.add(output1); + Map output2 = new HashMap<>(); + output2.put("hashtag_embedding", "response"); + outputMap.add(output2); + Map output3 = new HashMap<>(); + output2.put("hashtvg_embedding", "response"); + outputMap.add(output3); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 2); + String processorTag = randomAlphaOfLength(10); + + try { + factory.create(Collections.emptyMap(), processorTag, null, false, config, null); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + "when output_maps and input_maps are provided, their length needs to match. The input_maps is in length of 2, while output_maps is in the length of 3. Please adjust mappings." + ); + + } + } + + /** + * Tests the creation of the MLInferenceSearchResponseProcessor with optional fields. + * + * @throws Exception if an error occurs during the test + */ + public void testCreateOptionalFields() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model2"); + Map model_config = new HashMap<>(); + model_config.put("hidden_size", 768); + model_config.put("gradient_checkpointing", false); + model_config.put("position_embedding_type", "absolute"); + config.put(MODEL_CONFIG, model_config); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("inputs", "text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "response"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 5); + String processorTag = randomAlphaOfLength(10); + + MLInferenceSearchResponseProcessor MLInferenceSearchResponseProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + assertNotNull(MLInferenceSearchResponseProcessor); + assertEquals(MLInferenceSearchResponseProcessor.getTag(), processorTag); + assertEquals(MLInferenceSearchResponseProcessor.getType(), MLInferenceSearchResponseProcessor.TYPE); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java new file mode 100644 index 0000000000..6bc7a54e9d --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java @@ -0,0 +1,420 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.rest; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD; +import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL; +import static org.opensearch.ml.utils.TestHelper.makeRequest; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Assert; +import org.junit.Before; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; +import com.jayway.jsonpath.JsonPath; + +public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestCase { + + private final String OPENAI_KEY = System.getenv("OPENAI_KEY"); + private String openAIChatModelId; + private String bedrockEmbeddingModelId; + private String localModelId; + private final String completionModelConnectorEntity = "{\n" + + " \"name\": \"OpenAI text embedding model Connector\",\n" + + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"model\": \"text-embedding-ada-002\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + OPENAI_KEY + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/embeddings\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"input\\\": ${parameters.input}, \\\"model\\\": \\\"${parameters.model}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + + private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); + private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); + private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; + + private final String bedrockEmbeddingModelConnectorEntity = "{\n" + + " \"name\": \"Amazon Bedrock Connector: embedding\",\n" + + " \"description\": \"The connector to bedrock Titan embedding model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"model_name\": \"amazon.titan-embed-text-v1\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n" + + " \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n" + + " \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n" + + " }\n" + + " ]\n" + + "}"; + + /** + * Registers two remote models and creates an index and documents before running the tests. + * + * @throws Exception if any error occurs during the setup + */ + @Before + public void setup() throws Exception { + RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); + Thread.sleep(20000); + String openAIChatModelName = "openAI-GPT-3.5 chat model " + randomAlphaOfLength(5); + this.openAIChatModelId = registerRemoteModel(completionModelConnectorEntity, openAIChatModelName, true); + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true); + + String index_name = "daily_index"; + String createIndexRequestBody = "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"diary_embedding_size\": {\n" + + " \"type\": \"keyword\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + String uploadDocumentRequestBodyDoc1 = "{\n" + + " \"id\": 1,\n" + + " \"diary\": [\"happy\",\"first day at school\"],\n" + + " \"diary_embedding_size\": \"1536\",\n" // embedding size for ada model + + " \"weather\": \"sunny\"\n" + + " }"; + + String uploadDocumentRequestBodyDoc2 = "{\n" + + " \"id\": 2,\n" + + " \"diary\": [\"bored\",\"at home\"],\n" + + " \"diary_embedding_size\": \"768\",\n" // embedding size for local text embedding model + + " \"weather\": \"sunny\"\n" + + " }"; + + createIndex(index_name, createIndexRequestBody); + uploadDocument(index_name, "1", uploadDocumentRequestBodyDoc1); + uploadDocument(index_name, "2", uploadDocumentRequestBodyDoc2); + } + + /** + * Tests the MLInferenceSearchResponseProcessor with a remote model and an object field as input. + * It creates a search pipeline with the processor configured to use the remote model, + * performs a search using the pipeline, and verifies the inference results. + * + * @throws Exception if any error occurs during the test + */ + public void testMLInferenceProcessorRemoteModelObjectField() throws Exception { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } + String createPipelineRequestBody = "{\n" + + " \"response_processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"tag\": \"ml_inference\",\n" + + " \"description\": \"This processor is going to run ml inference during search request\",\n" + + " \"model_id\": \"" + + this.openAIChatModelId + + "\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"input\": \"weather\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"weather_embedding\": \"data[*].embedding\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": false,\n" + + " \"ignore_failure\": false\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + String query = "{\"query\":{\"term\":{\"weather\":{\"value\":\"sunny\"}}}}"; + + String index_name = "daily_index"; + String pipelineName = "weather_embedding_pipeline"; + createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); + + Map response = searchWithPipeline(client(), index_name, pipelineName, query); + System.out.println(response); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "sunny"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "happy"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"), "first day at school"); + List embeddingList = (List) JsonPath.parse(response).read("$.hits.hits[0]._source.weather_embedding"); + Assert.assertEquals(embeddingList.size(), 1536); + Assert.assertEquals((Double) embeddingList.get(0), 0.00020525085, 0.005); + Assert.assertEquals((Double) embeddingList.get(1), -0.0071890163, 0.005); + } + + /** + * Tests the MLInferenceSearchResponseProcessor with a remote model and a nested list field as input. + * It creates a search pipeline with the processor configured to use the remote model, + * performs a search using the pipeline, and verifies the inference results. + * + * @throws Exception if any error occurs during the test + */ + public void testMLInferenceProcessorRemoteModelNestedListField() throws Exception { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } + String createPipelineRequestBody = "{\n" + + " \"response_processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"tag\": \"ml_inference\",\n" + + " \"description\": \"This processor is going to run ml inference during search request\",\n" + + " \"model_id\": \"" + + this.openAIChatModelId + + "\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"input\": \"diary[0]\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"dairy_embedding\": \"data[*].embedding\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": false,\n" + + " \"ignore_failure\": false\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + String query = "{\"query\":{\"term\":{\"weather\":{\"value\":\"sunny\"}}}}"; + + String index_name = "daily_index"; + String pipelineName = "diary_embedding_pipeline"; + createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); + + Map response = searchWithPipeline(client(), index_name, pipelineName, query); + System.out.println(response); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "sunny"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "happy"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"), "first day at school"); + List embeddingList = (List) JsonPath.parse(response).read("$.hits.hits[0]._source.dairy_embedding"); + Assert.assertEquals(embeddingList.size(), 1536); + Assert.assertEquals((Double) embeddingList.get(0), -0.011842756, 0.005); + Assert.assertEquals((Double) embeddingList.get(1), -0.012508746, 0.005); + } + + /** + * Tests the ML inference processor with a local model. + * It registers, deploys, and gets a local model, creates a search pipeline with the ML inference processor + * configured to use the local model, and then performs a search using the pipeline. + * The test verifies that the query string is rewritten correctly based on the inference results + * from the local model. + * + * @throws Exception if any error occurs during the test + */ + public void testMLInferenceProcessorLocalModel() throws Exception { + + String taskId = registerModel(TestHelper.toJsonString(registerModelInput())); + waitForTask(taskId, MLTaskState.COMPLETED); + getTask(client(), taskId, response -> { + assertNotNull(response.get(MODEL_ID_FIELD)); + this.localModelId = (String) response.get(MODEL_ID_FIELD); + try { + String deployTaskID = deployModel(this.localModelId); + waitForTask(deployTaskID, MLTaskState.COMPLETED); + + getModel(client(), this.localModelId, model -> { assertEquals("DEPLOYED", model.get("model_state")); }); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + + String createPipelineRequestBody = "{\n" + + " \"response_processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"tag\": \"ml_inference\",\n" + + " \"description\": \"This processor is going to run ml inference during search request\",\n" + + " \"model_id\": \"" + + this.localModelId + + "\",\n" + + " \"model_input\": \"{ \\\"text_docs\\\": [\\\"${ml_inference.text_docs}\\\"] ,\\\"return_number\\\": true,\\\"target_response\\\": [\\\"sentence_embedding\\\"]}\",\n" + + " \"function_name\": \"text_embedding\",\n" + + " \"full_response_path\": true,\n" + + " \"input_map\": [\n" + + " {\n" + + " \"input\": \"weather\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"weather_embedding\": \"$.inference_results[0].output[0].data\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": false,\n" + + " \"ignore_failure\": false\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + String index_name = "daily_index"; + String pipelineName = "weather_embedding_pipeline_local"; + createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); + + String query = "{\"query\":{\"term\":{\"diary_embedding_size\":{\"value\":\"768\"}}}}"; + Map response = searchWithPipeline(client(), index_name, pipelineName, query); + System.out.println(response); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "768"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "sunny"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "bored"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"), "at home"); + + List embeddingList = (List) JsonPath.parse(response).read("$.hits.hits[0]._source.weather_embedding"); + Assert.assertEquals(embeddingList.size(), 768); + Assert.assertEquals((Double) embeddingList.get(0), 0.54809606, 0.005); + Assert.assertEquals((Double) embeddingList.get(1), 0.46797526, 0.005); + + } + + /** + * Creates a search pipeline processor with the given request body and pipeline name. + * + * @param requestBody the request body for creating the search pipeline processor + * @param pipelineName the name of the search pipeline + * @throws Exception if any error occurs during the creation of the search pipeline processor + */ + protected void createSearchPipelineProcessor(String requestBody, final String pipelineName) throws Exception { + Response pipelineCreateResponse = TestHelper + .makeRequest( + client(), + "PUT", + "/_search/pipeline/" + pipelineName, + null, + requestBody, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, pipelineCreateResponse.getStatusLine().getStatusCode()); + + } + + /** + * Creates an index with the given name and request body. + * + * @param indexName the name of the index + * @param requestBody the request body for creating the index + * @throws Exception if any error occurs during the creation of the index + */ + protected void createIndex(String indexName, String requestBody) throws Exception { + Response response = makeRequest( + client(), + "PUT", + indexName, + null, + requestBody, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + /** + * Uploads a document to the specified index with the given document ID and JSON body. + * + * @param index the name of the index + * @param docId the document ID + * @param jsonBody the JSON body of the document + * @throws IOException if an I/O error occurs during the upload + */ + protected void uploadDocument(final String index, final String docId, final String jsonBody) throws IOException { + Request request = new Request("PUT", "/" + index + "/_doc/" + docId + "?refresh=true"); + request.setJsonEntity(jsonBody); + client().performRequest(request); + } + + /** + * Creates a MLRegisterModelInput instance with the specified configuration. + * + * @return the MLRegisterModelInput instance + * @throws IOException if an I/O error occurs during the creation of the input + * @throws InterruptedException if the thread is interrupted during the creation of the input + */ + protected MLRegisterModelInput registerModelInput() throws IOException, InterruptedException { + + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(768) + .build(); + return MLRegisterModelInput + .builder() + .modelName("test_model_name") + .version("1.0.0") + .functionName(FunctionName.TEXT_EMBEDDING) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .url(SENTENCE_TRANSFORMER_MODEL_URL) + .deployModel(false) + .hashValue("e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021") + .build(); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/utils/MapUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/utils/MapUtilsTests.java new file mode 100644 index 0000000000..d133e00657 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/utils/MapUtilsTests.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.utils; + +import static org.junit.Assert.assertEquals; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; + +public class MapUtilsTests { + + @Test + public void testIncrementCounterForVersionedCounters() { + Map> versionedCounters = new HashMap<>(); + + MapUtils.incrementCounter(versionedCounters, 0, "key1"); + assertEquals(0, (int) versionedCounters.get(0).get("key1")); + + // Test incrementing counter for an existing version and key + MapUtils.incrementCounter(versionedCounters, 0, "key1"); + assertEquals(1, (int) versionedCounters.get(0).get("key1")); + + // Test incrementing counter for a new key in an existing version + MapUtils.incrementCounter(versionedCounters, 0, "key2"); + assertEquals(0, (int) versionedCounters.get(0).get("key2")); + + // Test incrementing counter for a new version + MapUtils.incrementCounter(versionedCounters, 1, "key3"); + assertEquals(0, (int) versionedCounters.get(1).get("key3")); + } + + @Test + public void testIncrementCounterForIntegerCounters() { + Map counters = new HashMap<>(); + + // Test incrementing counter for a new key + MapUtils.incrementCounter(counters, 1); + assertEquals(1, (int) counters.get(1)); + + // Test incrementing counter for an existing key + MapUtils.incrementCounter(counters, 1); + assertEquals(2, (int) counters.get(1)); + + // Test incrementing counter for a new key + MapUtils.incrementCounter(counters, 2); + assertEquals(1, (int) counters.get(2)); + } + + @Test + public void testGetCounterForVersionedCounters() { + Map> versionedCounters = new HashMap<>(); + versionedCounters.put(0, new HashMap<>()); + versionedCounters.put(1, new HashMap<>()); + versionedCounters.get(0).put("key1", 5); + versionedCounters.get(1).put("key2", 10); + + // Test getting counter for an existing key + assertEquals(5, MapUtils.getCounter(versionedCounters, 0, "key1")); + assertEquals(10, MapUtils.getCounter(versionedCounters, 1, "key2")); + + // Test getting counter for a non-existing key + assertEquals(-1, MapUtils.getCounter(versionedCounters, 0, "key3")); + assertEquals(0, MapUtils.getCounter(versionedCounters, 2, "key4")); + } + + @Test + public void testGetCounterForIntegerCounters() { + Map counters = new HashMap<>(); + counters.put(1, 5); + counters.put(2, 10); + + // Test getting counter for an existing key + assertEquals(5, MapUtils.getCounter(counters, 1)); + assertEquals(10, MapUtils.getCounter(counters, 2)); + + // Test getting counter for a non-existing key + assertEquals(0, MapUtils.getCounter(counters, 3)); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/utils/SearchResponseUtilTests.java b/plugin/src/test/java/org/opensearch/ml/utils/SearchResponseUtilTests.java new file mode 100644 index 0000000000..bf877ebc76 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/utils/SearchResponseUtilTests.java @@ -0,0 +1,146 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.Collections; + +import org.apache.lucene.search.TotalHits; +import org.junit.Test; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.profile.SearchProfileShardResults; +import org.opensearch.search.suggest.Suggest; + +public class SearchResponseUtilTests { + + @Test + public void testPrivateConstructor() throws NoSuchMethodException, + IllegalAccessException, + InvocationTargetException, + InstantiationException { + Constructor constructor = SearchResponseUtil.class.getDeclaredConstructor(); + constructor.setAccessible(true); + SearchResponseUtil instance = constructor.newInstance(); + assertNotNull(instance); + } + + @Test + public void testReplaceHits() { + SearchHit[] originalHits = new SearchHit[10]; + SearchHits originalSearchHits = new SearchHits(originalHits, new TotalHits(10, TotalHits.Relation.EQUAL_TO), 0.5f); + SearchResponse originalResponse = new SearchResponse( + new InternalSearchResponse( + originalSearchHits, + InternalAggregations.EMPTY, + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 1, + 1, + 0, + 0, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + + SearchHit[] newHits = new SearchHit[10]; + + SearchResponse newResponse = SearchResponseUtil.replaceHits(newHits, originalResponse); + + assertNotNull(newResponse); + assertEquals(newHits.length, newResponse.getHits().getHits().length); + } + + @Test + public void testReplaceHitsWithSearchHits() throws IOException { + // Arrange + SearchHit[] originalHits = new SearchHit[10]; + SearchHits originalSearchHits = new SearchHits(originalHits, new TotalHits(10, TotalHits.Relation.EQUAL_TO), 0.5f); + SearchResponse originalResponse = new SearchResponse( + new InternalSearchResponse( + originalSearchHits, + InternalAggregations.EMPTY, + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 1, + 1, + 0, + 0, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + + SearchHit[] newHits = new SearchHit[15]; + SearchHits newSearchHits = new SearchHits(newHits, new TotalHits(15, TotalHits.Relation.EQUAL_TO), 0.7f); + + SearchResponse newResponse = SearchResponseUtil.replaceHits(newSearchHits, originalResponse); + + assertNotNull(newResponse); + assertEquals(newHits.length, newResponse.getHits().getHits().length); + assertEquals(15, newResponse.getHits().getTotalHits().value); + assertEquals(TotalHits.Relation.EQUAL_TO, newResponse.getHits().getTotalHits().relation); + assertEquals(0.7f, newResponse.getHits().getMaxScore(), 0.0001f); + } + + @Test + public void testReplaceHitsWithNonWriteableAggregations() { + SearchHit[] originalHits = new SearchHit[10]; + SearchHits originalSearchHits = new SearchHits(originalHits, new TotalHits(10, TotalHits.Relation.EQUAL_TO), 0.5f); + + Aggregations nonWriteableAggregations = mock(Aggregations.class); + SearchResponse originalResponse = mock(SearchResponse.class); + when(originalResponse.getHits()).thenReturn(originalSearchHits); + when(originalResponse.getAggregations()).thenReturn(nonWriteableAggregations); + when(originalResponse.getSuggest()).thenReturn(new Suggest(Collections.emptyList())); + when(originalResponse.isTimedOut()).thenReturn(false); + when(originalResponse.isTerminatedEarly()).thenReturn(false); + when(originalResponse.getProfileResults()).thenReturn(Collections.emptyMap()); + when(originalResponse.getNumReducePhases()).thenReturn(1); + when(originalResponse.getTook()).thenReturn(new TimeValue(100)); + SearchHit[] newHits = new SearchHit[15]; + + SearchResponse newResponse = SearchResponseUtil + .replaceHits(new SearchHits(newHits, new TotalHits(15, TotalHits.Relation.EQUAL_TO), 0.7f), originalResponse); + + assertNotNull(newResponse); + assertEquals(newHits.length, newResponse.getHits().getHits().length); + assertEquals(15, newResponse.getHits().getTotalHits().value); + assertEquals(TotalHits.Relation.EQUAL_TO, newResponse.getHits().getTotalHits().relation); + assertEquals(0.7f, newResponse.getHits().getMaxScore(), 0.0001f); + } + + @Test + public void testReplaceHitsWithNoHits() { + SearchResponse originalResponse = mock(SearchResponse.class); + when(originalResponse.getHits()).thenReturn(null); + + assertThrows(IllegalStateException.class, () -> SearchResponseUtil.replaceHits(new SearchHit[0], originalResponse)); + } +}