From cc786f70b7ef57a4e20b00534c539c6fb6943b1e Mon Sep 17 00:00:00 2001 From: Mingshi Liu Date: Wed, 24 Jul 2024 09:54:58 -0700 Subject: [PATCH] Add initial search request inference processor (#2616) * add initial search request inference processor Signed-off-by: Mingshi Liu * Add ITs for MLInferenceSearchRequestProcessor Signed-off-by: Mingshi Liu * skip running OPENAI when key is not present and fix yaml test issue Signed-off-by: Mingshi Liu --------- Signed-off-by: Mingshi Liu --- .../ml/plugin/MachineLearningPlugin.java | 7 +- .../MLInferenceSearchRequestProcessor.java | 578 ++++++++ .../ml/processor/ModelExecutor.java | 8 +- .../ml/plugin/MachineLearningPluginTests.java | 4 +- ...LInferenceSearchRequestProcessorTests.java | 1290 +++++++++++++++++ .../RestMLInferenceIngestProcessorIT.java | 2 +- ...stMLInferenceSearchRequestProcessorIT.java | 380 +++++ .../org/opensearch/ml/utils/TestData.java | 2 + .../30_inference_search_request_processor.yml | 41 + 9 files changed, 2308 insertions(+), 4 deletions(-) create mode 100644 plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java create mode 100644 plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchRequestProcessorIT.java create mode 100644 plugin/src/yamlRestTest/resources/rest-api-spec/test/30_inference_search_request_processor.yml diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 470922d58f..7e9ab6d940 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -213,6 +213,7 @@ import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.processor.MLInferenceIngestProcessor; +import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor; import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.rest.RestMLCreateConnectorAction; @@ -977,7 +978,11 @@ public Map> getRequestProcesso GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory(() -> this.ragSearchPipelineEnabled) ); - + requestProcessors + .put( + MLInferenceSearchRequestProcessor.TYPE, + new MLInferenceSearchRequestProcessor.Factory(parameters.client, parameters.namedXContentRegistry) + ); return requestProcessors; } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java new file mode 100644 index 0000000000..9c97091057 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java @@ -0,0 +1,578 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.processor; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.commons.text.StringSubstitutor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.pipeline.AbstractProcessor; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchRequestProcessor; + +import com.jayway.jsonpath.Configuration; +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.Option; +import com.jayway.jsonpath.ReadContext; + +/** + * MLInferenceSearchRequestProcessor requires a modelId string to call model inferences + * maps fields from query string for model input, and maps model inference output to the query strings or query template + * this processor also handles dot path notation for nested object( map of array) by rewriting json path accordingly + */ +public class MLInferenceSearchRequestProcessor extends AbstractProcessor implements SearchRequestProcessor, ModelExecutor { + private final NamedXContentRegistry xContentRegistry; + private static final Logger logger = LogManager.getLogger(MLInferenceSearchRequestProcessor.class); + private final InferenceProcessorAttributes inferenceProcessorAttributes; + private final boolean ignoreMissing; + private final String functionName; + private String queryTemplate; + private final boolean fullResponsePath; + private final boolean ignoreFailure; + private final String modelInput; + private static Client client; + public static final String TYPE = "ml_inference"; + // allow to ignore a field from mapping is not present in the query, and when the output field is not found in the + // prediction outcomes, return the whole prediction outcome by skipping filtering + public static final String IGNORE_MISSING = "ignore_missing"; + public static final String QUERY_TEMPLATE = "query_template"; + public static final String FUNCTION_NAME = "function_name"; + public static final String FULL_RESPONSE_PATH = "full_response_path"; + public static final String MODEL_INPUT = "model_input"; + public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }"; + // At default, ml inference processor allows maximum 10 prediction tasks running in parallel + // it can be overwritten using max_prediction_tasks when creating processor + public static final int DEFAULT_MAX_PREDICTION_TASKS = 10; + + protected MLInferenceSearchRequestProcessor( + String modelId, + String queryTemplate, + List> inputMaps, + List> outputMaps, + Map modelConfigMaps, + int maxPredictionTask, + String tag, + String description, + boolean ignoreMissing, + String functionName, + boolean fullResponsePath, + boolean ignoreFailure, + String modelInput, + Client client, + NamedXContentRegistry xContentRegistry + ) { + super(tag, description, ignoreFailure); + this.inferenceProcessorAttributes = new InferenceProcessorAttributes( + modelId, + inputMaps, + outputMaps, + modelConfigMaps, + maxPredictionTask + ); + this.ignoreMissing = ignoreMissing; + this.functionName = functionName; + this.fullResponsePath = fullResponsePath; + this.queryTemplate = queryTemplate; + this.ignoreFailure = ignoreFailure; + this.modelInput = modelInput; + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + /** + * Process a SearchRequest without receiving request-scoped state. + * Implement this method if the processor makes no asynchronous calls. + * But this processor is going to make asynchronous calls. + * @param request the search request (which may have been modified by an earlier processor) + * @return the modified search request + * @throws Exception implementation-specific processing exception + */ + + @Override + public SearchRequest processRequest(SearchRequest request) throws Exception { + throw new RuntimeException("ML inference search request processor make asynchronous calls and does not call processRequest"); + } + + /** + * Transform a {@link SearchRequest}. + * Make one or more predictions, rewrite query in a search request + * Executed on the coordinator node before any {@link org.opensearch.action.search.SearchPhase} + * executes. + *

+ * Expert method: Implement this if the processor needs to make asynchronous calls. Otherwise, implement processRequest. + * @param request the executed {@link SearchRequest} + * @param requestListener callback to be invoked on successful processing or on failure + */ + @Override + public void processRequestAsync( + SearchRequest request, + PipelineProcessingContext requestContext, + ActionListener requestListener + ) { + + try { + if (request.source() == null) { + throw new IllegalArgumentException("query body is empty, cannot processor inference on empty query request."); + } + + String queryString = request.source().toString(); + + rewriteQueryString(request, queryString, requestListener); + + } catch (Exception e) { + if (ignoreFailure) { + requestListener.onResponse(request); + } else { + requestListener.onFailure(e); + } + } + } + + /** + * Rewrites the query string based on the input and output mappings and the ML model output. + * + * @param request the {@link SearchRequest} to be rewritten + * @param queryString the original query string + * @param requestListener the {@link ActionListener} to be notified when the rewriting is complete + * @throws IOException if an I/O error occurs during the rewriting process + */ + private void rewriteQueryString(SearchRequest request, String queryString, ActionListener requestListener) + throws IOException { + List> processInputMap = inferenceProcessorAttributes.getInputMaps(); + List> processOutputMap = inferenceProcessorAttributes.getOutputMaps(); + int inputMapSize = (processInputMap != null) ? processInputMap.size() : 0; + + if (inputMapSize == 0) { + requestListener.onResponse(request); + return; + } + + try { + if (!validateQueryFieldInQueryString(processInputMap, processOutputMap, queryString)) { + requestListener.onResponse(request); + } + } catch (Exception e) { + if (ignoreMissing) { + requestListener.onResponse(request); + return; + } else { + requestListener.onFailure(e); + return; + } + } + + ActionListener> rewriteRequestListener = createRewriteRequestListener( + request, + queryString, + requestListener, + processOutputMap + ); + GroupedActionListener> batchPredictionListener = createBatchPredictionListener( + rewriteRequestListener, + inputMapSize + ); + + for (int inputMapIndex = 0; inputMapIndex < inputMapSize; inputMapIndex++) { + processPredictions(queryString, processInputMap, inputMapIndex, batchPredictionListener); + } + + } + + /** + * Creates an {@link ActionListener} that handles the response from the ML model inference and updates the query string + * or query template with the model output. + * + * @param request the {@link SearchRequest} to be updated + * @param queryString the original query string + * @param requestListener the {@link ActionListener} to be notified when the query string or query template is updated + * @param processOutputMap the list of output mappings + * @return an {@link ActionListener} that handles the response from the ML model inference + */ + private ActionListener> createRewriteRequestListener( + SearchRequest request, + String queryString, + ActionListener requestListener, + List> processOutputMap + ) { + return new ActionListener<>() { + @Override + public void onResponse(Map multipleMLOutputs) { + for (Map.Entry entry : multipleMLOutputs.entrySet()) { + Integer mappingIndex = entry.getKey(); + MLOutput mlOutput = entry.getValue(); + Map outputMapping = processOutputMap.get(mappingIndex); + try { + if (queryTemplate == null) { + Object incomeQueryObject = JsonPath.parse(queryString).read("$"); + updateIncomeQueryObject(incomeQueryObject, outputMapping, mlOutput); + SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder( + xContentRegistry, + StringUtils.toJson(incomeQueryObject) + ); + request.source(searchSourceBuilder); + requestListener.onResponse(request); + } else { + String newQueryString = updateQueryTemplate(queryTemplate, outputMapping, mlOutput); + SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(xContentRegistry, newQueryString); + request.source(searchSourceBuilder); + requestListener.onResponse(request); + } + } catch (Exception e) { + if (ignoreMissing || ignoreFailure) { + logger.error("Failed in writing prediction outcomes to new query", e); + requestListener.onResponse(request); + + } else { + requestListener.onFailure(e); + } + } + } + } + + @Override + public void onFailure(Exception e) { + if (ignoreFailure) { + logger.error("Failed in writing prediction outcomes to new query", e); + requestListener.onResponse(request); + + } else { + requestListener.onFailure(e); + } + } + + private void updateIncomeQueryObject(Object incomeQueryObject, Map outputMapping, MLOutput mlOutput) { + for (Map.Entry outputMapEntry : outputMapping.entrySet()) { + String newQueryField = outputMapEntry.getKey(); + String modelOutputFieldName = outputMapEntry.getValue(); + Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath); + String jsonPathExpression = "$." + newQueryField; + JsonPath.parse(incomeQueryObject).set(jsonPathExpression, modelOutputValue); + } + } + + private String updateQueryTemplate(String queryTemplate, Map outputMapping, MLOutput mlOutput) { + Map valuesMap = new HashMap<>(); + for (Map.Entry outputMapEntry : outputMapping.entrySet()) { + String newQueryField = outputMapEntry.getKey(); + String modelOutputFieldName = outputMapEntry.getValue(); + Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath); + valuesMap.put(newQueryField, modelOutputValue); + } + StringSubstitutor sub = new StringSubstitutor(valuesMap); + return sub.replace(queryTemplate); + } + }; + } + + /** + * Creates a {@link GroupedActionListener} that collects the responses from multiple ML model inferences. + * + * @param rewriteRequestListner the {@link ActionListener} to be notified when all ML model inferences are complete + * @param inputMapSize the number of input mappings + * @return a {@link GroupedActionListener} that handles the responses from multiple ML model inferences + */ + private GroupedActionListener> createBatchPredictionListener( + ActionListener> rewriteRequestListner, + int inputMapSize + ) { + return new GroupedActionListener<>(new ActionListener<>() { + @Override + public void onResponse(Collection> mlOutputMapCollection) { + Map mlOutputMaps = new HashMap<>(); + for (Map mlOutputMap : mlOutputMapCollection) { + mlOutputMaps.putAll(mlOutputMap); + } + rewriteRequestListner.onResponse(mlOutputMaps); + } + + @Override + public void onFailure(Exception e) { + logger.error("Prediction Failed:", e); + rewriteRequestListner.onFailure(e); + } + }, Math.max(inputMapSize, 1)); + } + + /** + * Validates that the query fields specified in the input and output mappings exist in the query string. + * + * @param processInputMap the list of input mappings + * @param processOutputMap the list of output mappings + * @param queryString the query string to be validated + * @return true if all query fields exist in the query string, false otherwise + */ + private boolean validateQueryFieldInQueryString( + List> processInputMap, + List> processOutputMap, + String queryString + ) { + // Suppress errors thrown by JsonPath and instead return null if a path does not exist in a JSON blob. + Configuration suppressExceptionConfiguration = Configuration.defaultConfiguration().addOptions(Option.SUPPRESS_EXCEPTIONS); + ReadContext jsonData = JsonPath.using(suppressExceptionConfiguration).parse(queryString); + + // check all values if exists in query + for (Map inputMap : processInputMap) { + for (Map.Entry entry : inputMap.entrySet()) { + // the inputMap takes in model input as keys and query fields as value + String queryField = entry.getValue(); + String pathData = jsonData.read(queryField); + if (pathData == null) { + throw new IllegalArgumentException("cannot find field: " + queryField + " in query string: " + jsonData.jsonString()); + } + } + } + if (queryTemplate == null) { + for (Map outputMap : processOutputMap) { + for (Map.Entry entry : outputMap.entrySet()) { + String queryField = entry.getKey(); + String pathData = jsonData.read(queryField); + if (pathData == null) { + throw new IllegalArgumentException( + "cannot find field: " + queryField + " in query string: " + jsonData.jsonString() + ); + } + } + } + } + return true; + + } + + /** + * Processes the ML model inference for a given input mapping index. + * + * @param queryString the original query string + * @param processInputMap the list of input mappings + * @param inputMapIndex the index of the input mapping to be processed + * @param batchPredictionListener the {@link GroupedActionListener} to be notified when the ML model inference is complete + * @throws IOException if an I/O error occurs during the processing + */ + private void processPredictions( + String queryString, + List> processInputMap, + int inputMapIndex, + GroupedActionListener batchPredictionListener + ) throws IOException { + Map modelParameters = new HashMap<>(); + Map modelConfigs = new HashMap<>(); + + if (inferenceProcessorAttributes.getModelConfigMaps() != null) { + modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + } + Map inputMapping = new HashMap<>(); + + if (processInputMap != null) { + inputMapping = processInputMap.get(inputMapIndex); + Object newQuery = JsonPath.parse(queryString).read("$"); + for (Map.Entry entry : inputMapping.entrySet()) { + // model field as key, query field name as value + String modelInputFieldName = entry.getKey(); + String queryFieldName = entry.getValue(); + String queryFieldValue = JsonPath.parse(newQuery).read(queryFieldName); + modelParameters.put(modelInputFieldName, queryFieldValue); + } + } + + Set inputMapKeys = new HashSet<>(modelParameters.keySet()); + inputMapKeys.removeAll(modelConfigs.keySet()); + + Map inputMappings = new HashMap<>(); + for (String k : inputMapKeys) { + inputMappings.put(k, modelParameters.get(k)); + } + + ActionRequest request = getMLModelInferenceRequest( + xContentRegistry, + modelParameters, + modelConfigs, + inputMappings, + inferenceProcessorAttributes.getModelId(), + functionName, + modelInput + ); + + client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() { + + @Override + public void onResponse(MLTaskResponse mlTaskResponse) { + MLOutput mlOutput = mlTaskResponse.getOutput(); + Map mlOutputMap = new HashMap<>(); + mlOutputMap.put(inputMapIndex, mlOutput); + batchPredictionListener.onResponse(mlOutputMap); + } + + @Override + public void onFailure(Exception e) { + batchPredictionListener.onFailure(e); + } + }); + + } + + /** + * Creates a SearchSourceBuilder instance from the given query string. + * + * @param xContentRegistry the XContentRegistry instance to be used for parsing + * @param queryString the query template string to be parsed + * @return a SearchSourceBuilder instance created from the query string + * @throws IOException if an I/O error occurs during parsing + */ + private static SearchSourceBuilder getSearchSourceBuilder(NamedXContentRegistry xContentRegistry, String queryString) + throws IOException { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + + XContentParser queryParser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryString); + ensureExpectedToken(XContentParser.Token.START_OBJECT, queryParser.nextToken(), queryParser); + + searchSourceBuilder.parseXContent(queryParser); + return searchSourceBuilder; + } + + /** + * Returns the type of the processor. + * + * @return the type of the processor as a string + */ + @Override + public String getType() { + return TYPE; + } + + /** + * A factory class for creating instances of the MLInferenceSearchRequestProcessor. + * This class implements the Processor.Factory interface for creating SearchRequestProcessor instances. + */ + public static class Factory implements Processor.Factory { + private final Client client; + private final NamedXContentRegistry xContentRegistry; + + /** + * Constructs a new instance of the Factory class. + * + * @param client the Client instance to be used by the Factory + * @param xContentRegistry the xContentRegistry instance to be used by the Factory + */ + public Factory(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + /** + * Creates a new instance of the MLInferenceSearchRequestProcessor. + * + * @param processorFactories a map of processor factories + * @param processorTag the tag of the processor + * @param description the description of the processor + * @param ignoreFailure a flag indicating whether to ignore failures or not + * @param config the configuration map for the processor + * @param pipelineContext the pipeline context + * @return a new instance of the MLInferenceSearchRequestProcessor + */ + @Override + public MLInferenceSearchRequestProcessor create( + Map> processorFactories, + String processorTag, + String description, + boolean ignoreFailure, + Map config, + PipelineContext pipelineContext + ) { + String modelId = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, MODEL_ID); + String queryTemplate = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, QUERY_TEMPLATE); + Map modelConfigInput = ConfigurationUtils.readOptionalMap(TYPE, processorTag, config, MODEL_CONFIG); + + List> inputMaps = ConfigurationUtils.readList(TYPE, processorTag, config, INPUT_MAP); + List> outputMaps = ConfigurationUtils.readList(TYPE, processorTag, config, OUTPUT_MAP); + int maxPredictionTask = ConfigurationUtils + .readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS); + boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false); + String functionName = ConfigurationUtils + .readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name()); + String modelInput = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, MODEL_INPUT); + + // if model input is not provided for remote models, use default value + if (functionName.equalsIgnoreCase("remote")) { + modelInput = (modelInput != null) ? modelInput : DEFAULT_MODEl_INPUT; + } else if (modelInput == null) { + // if model input is not provided for local models, throw exception since it is mandatory here + throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor"); + } + boolean defaultFullResponsePath = !functionName.equalsIgnoreCase(FunctionName.REMOTE.name()); + boolean fullResponsePath = ConfigurationUtils + .readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultFullResponsePath); + + ignoreFailure = ConfigurationUtils + .readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false); + + // convert model config user input data structure to Map + Map modelConfigMaps = null; + if (modelConfigInput != null) { + modelConfigMaps = StringUtils.getParameterMap(modelConfigInput); + } + // check if the number of prediction tasks exceeds max prediction tasks + if (inputMaps != null && inputMaps.size() > maxPredictionTask) { + throw new IllegalArgumentException( + "The number of prediction task setting in this process is " + + inputMaps.size() + + ". It exceeds the max_prediction_tasks of " + + maxPredictionTask + + ". Please reduce the size of input_map or increase max_prediction_tasks." + ); + } + + return new MLInferenceSearchRequestProcessor( + modelId, + queryTemplate, + inputMaps, + outputMaps, + modelConfigMaps, + maxPredictionTask, + processorTag, + description, + ignoreMissing, + functionName, + fullResponsePath, + ignoreFailure, + modelInput, + client, + xContentRegistry + ); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index aa69a0b72e..9922a8a94b 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -189,7 +189,13 @@ default Object getModelOutputValue(MLOutput mlOutput, String modelOutputFieldNam return modelTensorOutputMap; } else { try { - return JsonPath.parse(modelTensorOutputMap).read(modelOutputFieldName); + Object modelOutputValue = JsonPath.parse(modelTensorOutputMap).read(modelOutputFieldName); + if (modelOutputValue == null) { + throw new IllegalArgumentException( + "model inference output cannot find such json path: " + modelOutputFieldName + " in " + modelTensorOutputMap + ); + } + return modelOutputValue; } catch (Exception e) { if (ignoreMissing) { return modelTensorOutputMap; diff --git a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java index de9b040238..6da9cb406a 100644 --- a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java +++ b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java @@ -38,6 +38,7 @@ import org.opensearch.ml.common.spi.MLCommonsExtension; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.engine.tools.MLModelTool; +import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor; import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor; import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.SearchPipelinePlugin; @@ -74,10 +75,11 @@ public void testGetSearchExts() { public void testGetRequestProcessors() { SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); Map requestProcessors = plugin.getRequestProcessors(parameters); - assertEquals(1, requestProcessors.size()); + assertEquals(2, requestProcessors.size()); assertTrue( requestProcessors.get(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE) instanceof GenerativeQARequestProcessor.Factory ); + assertTrue(requestProcessors.get(MLInferenceSearchRequestProcessor.TYPE) instanceof MLInferenceSearchRequestProcessor.Factory); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java new file mode 100644 index 0000000000..b48b870fdc --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java @@ -0,0 +1,1290 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.processor; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.*; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; +import static org.opensearch.ml.processor.MLInferenceSearchRequestProcessor.*; +import static org.opensearch.ml.processor.MLInferenceSearchRequestProcessor.MODEL_INPUT; + +import java.util.*; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ingest.Processor; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; +import org.opensearch.search.SearchModule; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.test.AbstractBuilderTestCase; + +public class MLInferenceSearchRequestProcessorTests extends AbstractBuilderTestCase { + + @Mock + private Client client; + + @Mock + private PipelineProcessingContext requestContext; + + static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() + ); + private static final String PROCESSOR_TAG = "inference"; + private static final String DESCRIPTION = "inference_test"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + } + + /** + * Tests that an exception is thrown when the `processRequest` method is called, as this processor + * makes asynchronous calls and does not support synchronous processing. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessRequestException() throws Exception { + + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "query.term.text.value"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + null, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + try { + requestProcessor.processRequest(request); + + } catch (Exception e) { + assertEquals("ML inference search request processor make asynchronous calls and does not call processRequest", e.getMessage()); + } + } + + /** + * Tests the case where no input or output mappings are provided. The original search request + * should be returned without any modifications. + * + * @throws Exception if an error occurs during the test + */ + public void testProcessRequestAsyncWithNoMappings() throws Exception { + + MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor( + "model1", + null, + null, + null, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY + ); + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertEquals(incomingQuery, request.source().query()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync."); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the successful rewriting of a single string in a term query based on the model output. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_rewriteSingleStringTermQuerySuccess() throws Exception { + + /** + * example term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "query.term.text.value"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + null, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "eng")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + QueryBuilder expectedQuery = new TermQueryBuilder("text", "eng"); + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertEquals(expectedQuery, newSearchRequest.source().query()); + assertEquals(request.toString(), newSearchRequest.toString()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync."); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the successful rewriting of multiple string in a term query based on the model output. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_rewriteMultipleStringTermQuerySuccess() throws Exception { + /** + * example term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "query.term.text.value"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + null, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "car, truck, vehicle")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + /** + * example term query: {"query":{"term":{"text":{"value":"car, truck, vehicle","boost":1.0}}}} + */ + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + QueryBuilder expectedQuery = new TermQueryBuilder("text", "car, truck, vehicle"); + assertEquals(expectedQuery, newSearchRequest.source().query()); + assertEquals(request.toString(), newSearchRequest.toString()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync."); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the successful rewriting of a double in a term query based on the model output. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_rewriteDoubleQuerySuccess() throws Exception { + + /** + * example term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "query.term.text.value"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + null, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", 0.123)).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + /** + * example term query: {"query":{"term":{"text":{"value":0.123,"boost":1.0}}}} + */ + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + QueryBuilder expectedQuery = new TermQueryBuilder("text", 0.123); + assertEquals(expectedQuery, newSearchRequest.source().query()); + assertEquals(request.toString(), newSearchRequest.toString()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync."); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the successful rewriting of a term query to a range query based on the model output + * and the provided query template. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_rewriteStringFromTermQueryToRangeQuerySuccess() throws Exception { + /** + * example term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "modelPredictionScore"; + String modelOutputField = "response"; + String queryTemplate = "{\"query\":{\"range\":{\"text\":{\"gte\":${modelPredictionScore}}}}}"; + + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + queryTemplate, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "0.123")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + /** + * example input term query: {"query":{"term":{"text":{"value":foo,"boost":1.0}}}} + */ + /** + * example output range query: {"query":{"range":{"text":{"gte":"2"}}}} + */ + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + RangeQueryBuilder expectedQuery = new RangeQueryBuilder("text"); + expectedQuery.from(0.123); + expectedQuery.includeLower(true); + assertEquals(expectedQuery, newSearchRequest.source().query()); + assertEquals(request.toString(), newSearchRequest.toString()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync."); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the successful rewriting of a term query to a range query based on the model output + * and the provided query template, where the model output is a double value. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_rewriteDoubleFromTermQueryToRangeQuerySuccess() throws Exception { + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "modelPredictionScore"; + String modelOutputField = "response"; + String queryTemplate = "{\"query\":{\"range\":{\"text\":{\"gte\":${modelPredictionScore}}}}}"; + + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + queryTemplate, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", 0.123)).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + /** + * example input term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + /** + * example output range query: {"query":{"range":{"text":{"gte":0.123}}}} + */ + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + RangeQueryBuilder expectedQuery = new RangeQueryBuilder("text"); + expectedQuery.from(0.123); + expectedQuery.includeLower(true); + assertEquals(expectedQuery, newSearchRequest.source().query()); + assertEquals(request.toString(), newSearchRequest.toString()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync."); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the successful rewriting of a term query to a geometry query based on the model output + * and the provided query template, where the model output is a list of coordinates. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_rewriteListFromTermQueryToGeometryQuerySuccess() throws Exception { + + String queryTemplate = "{\n" + + " \"query\": {\n" + + " \"geo_shape\" : {\n" + + " \"location\" : {\n" + + " \"shape\" : {\n" + + " \"type\" : \"Envelope\",\n" + + " \"coordinates\" : ${modelPredictionOutcome} \n" + + " },\n" + + " \"relation\" : \"intersects\"\n" + + " },\n" + + " \"ignore_unmapped\" : false,\n" + + " \"boost\" : 42.0\n" + + " }\n" + + " }\n" + + "}"; + + String expectedNewQueryString = "{\n" + + " \"query\": {\n" + + " \"geo_shape\" : {\n" + + " \"location\" : {\n" + + " \"shape\" : {\n" + + " \"type\" : \"Envelope\",\n" + + " \"coordinates\" : [ [ 0.0, 6.0], [ 4.0, 2.0] ]\n" + + " },\n" + + " \"relation\" : \"intersects\"\n" + + " },\n" + + " \"ignore_unmapped\" : false,\n" + + " \"boost\" : 42.0\n" + + " }\n" + + " }\n" + + "}"; + + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "modelPredictionOutcome"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + queryTemplate, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(Arrays.asList(0.0, 6.0), Arrays.asList(4.0, 2.0)))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "Seattle"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + XContentParser parser = createParser(JsonXContent.jsonXContent, expectedNewQueryString); + SearchSourceBuilder expectedSearchSourceBuilder = new SearchSourceBuilder(); + expectedSearchSourceBuilder.parseXContent(parser); + SearchRequest expectedRequest = new SearchRequest().source(expectedSearchSourceBuilder); + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertEquals(expectedRequest.source().query(), newSearchRequest.source().query()); + assertEquals(expectedRequest.toString(), newSearchRequest.toString()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync."); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + } + + /** + * Tests the scenario where an exception occurs during the model inference process. + * The test sets up a mock client that simulates a failure during the model execution, + * and verifies that the appropriate exception is propagated to the listener. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_InferenceException() { + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "query.term.text.value"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + null, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Executing Model failed with exception.")); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception ex) { + try { + throw ex; + } catch (Exception e) { + assertEquals("Executing Model failed with exception.", e.getMessage()); + } + } + }; + + requestProcessor.processRequestAsync(request, requestContext, listener); + } + + /** + * Tests the scenario where an exception occurs during the model inference process, + * but the `ignoreFailure` flag is set to true. In this case, the original search + * request should be returned without any modifications. + */ + public void testExecute_InferenceExceptionIgnoreFailure() { + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "query.term.text.value"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + null, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + true, + false + ); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Executing Model failed with exception.")); + return null; + + }).when(client).execute(any(), any(), any()); + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertEquals(incomingQuery, newSearchRequest.source().query()); + } + + @Override + public void onFailure(Exception ex) { + throw new RuntimeException("error handling not properly"); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, listener); + + } + + /** + * Tests the case where the query string is null, and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testNullQueryStringException() throws Exception { + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "modelPredictionScore"; + String modelOutputField = "response"; + String queryTemplate = ""; + + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + queryTemplate, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + SearchRequest request = new SearchRequest(); + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception ex) { + try { + throw ex; + } catch (Exception e) { + assertEquals("query body is empty, cannot processor inference on empty query request.", e.getMessage()); + } + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the case where the query string is null, but the `ignoreFailure` flag is set to true. + * The original search request should be returned without any modifications. + * + * @throws Exception if an error occurs during the test + */ + public void testNullQueryStringIgnoreFailure() throws Exception { + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "modelPredictionScore"; + String modelOutputField = "response"; + String queryTemplate = ""; + + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + queryTemplate, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + true, + false + ); + SearchRequest request = new SearchRequest(); + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertNull(newSearchRequest.source()); + } + + @Override + public void onFailure(Exception ex) { + throw new RuntimeException("Failed in executing processRequestAsync."); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the case where the query template contains an invalid query format, and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_invalidQueryFormatInQueryTemplateException() throws Exception { + /** + * example term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "modelPredictionScore"; + String modelOutputField = "response"; + // typo in query + String queryTemplate = "{\"query\":{\"range1\":{\"text\":{\"gte\":${modelPredictionScore}}}}}"; + + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + queryTemplate, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "0.123")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + /** + * example input term query: {"query":{"term":{"text":{"value":foo,"boost":1.0}}}} + */ + /** + * example output range query: {"query":{"range":{"text":{"gte":"2"}}}} + */ + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + assertEquals("unknown query [range1] did you mean [range]?", e.getMessage()); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the case where the query field specified in the input mapping is not found in the original query string, + * and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_queryFieldNotFoundInOriginalQueryException() throws Exception { + String modelInputField = "inputs"; + // test typo in query field name + String originalQueryField = "query.term.text.value1"; + String newQueryField = "modelPredictionScore"; + String modelOutputField = "response"; + String queryTemplate = ""; + + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + queryTemplate, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + /** + * example input term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception ex) { + assertEquals( + "cannot find field: query.term.text.value1 in query string: {\"query\":{\"term\":{\"text\":{\"value\":\"foo\",\"boost\":1.0}}}}", + ex.getMessage() + ); + } + }; + requestProcessor.processRequestAsync(request, requestContext, Listener); + } + + /** + * Tests the case where the query field specified in the input mapping is not found in the query template, + * and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_queryFieldNotFoundInQueryTemplateException() throws Exception { + String modelInputField = "inputs"; + // test typo in query field name + String originalQueryField = "query.term.text.value1"; + String newQueryField = "modelPredictionScore"; + String modelOutputField = "response"; + String queryTemplate = "{\"query\":{\"range\":{\"text\":{\"gte\":\"${modelPredictionScore}\"}}}}"; + + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + queryTemplate, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + /** + * example input term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + /** + * example output range query: {"query":{"range":{"text":{"gte":0.123}}}} + */ + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception ex) { + assertEquals( + "cannot find field: query.term.text.value1 in query string: {\"query\":{\"term\":{\"text\":{\"value\":\"foo\",\"boost\":1.0}}}}", + ex.getMessage() + ); + } + }; + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the case where the query field specified in the input mapping is not found in the query template, + * but the `ignoreFailure` flag is set to true. The original search request should be returned without any modifications. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_queryFieldNotFoundInQueryTemplateIgnoreFailure() throws Exception { + String modelInputField = "inputs"; + // test typo in query field name + String originalQueryField = "query.term.text.value1"; + String newQueryField = "modelPredictionScore"; + String modelOutputField = "response"; + String queryTemplate = "{\"query\":{\"range\":{\"text\":{\"gte\":\"${modelPredictionScore}\"}}}}"; + + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + queryTemplate, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + true, + false + ); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + /** + * example input term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertEquals(newSearchRequest.source().query(), incomingQuery); + } + + @Override + public void onFailure(Exception ex) { + throw new RuntimeException("error handling not properly"); + } + }; + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the case where the query field specified in the input mapping is not found in the query template, + * but the `ignoreMissing` flag is set to true. The original search request should be returned without any modifications. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_queryFieldNotFoundInQueryTemplateIgnoreMissing() throws Exception { + String modelInputField = "inputs"; + // test typo in query field name + String originalQueryField = "query.term.text.value1"; + String newQueryField = "modelPredictionScore"; + String modelOutputField = "response"; + String queryTemplate = "{\"query\":{\"range\":{\"text\":{\"gte\":\"${modelPredictionScore}\"}}}}"; + + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + queryTemplate, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + true + ); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + /** + * example input term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertEquals(newSearchRequest.source().query(), incomingQuery); + } + + @Override + public void onFailure(Exception ex) { + throw new RuntimeException("error handling not properly"); + } + }; + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Helper method to create an instance of the MLInferenceSearchRequestProcessor with the specified parameters. + * + * @param queryTemplate the query template + * @param modelInputField the model input field name + * @param originalQueryField the original query field name + * @param newQueryField the new query field name + * @param modelOutputField the model output field name + * @param ignoreFailure the flag indicating whether to ignore failures or not + * @param ignoreMissing the flag indicating whether to ignore missing fields or not + * @return an instance of the MLInferenceSearchRequestProcessor + */ + private MLInferenceSearchRequestProcessor getMlInferenceSearchRequestProcessor( + String queryTemplate, + String modelInputField, + String originalQueryField, + String newQueryField, + String modelOutputField, + boolean ignoreFailure, + boolean ignoreMissing + ) { + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalQueryField); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newQueryField, modelOutputField); + outputMap.add(output); + + MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor( + "model1", + queryTemplate, + inputMap, + outputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + ignoreMissing, + "remote", + false, + ignoreFailure, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY + ); + return requestProcessor; + } + + /** + * Tests the creation of the MLInferenceSearchRequestProcessor with required fields. + * + * @throws Exception if an error occurs during the test + */ + private MLInferenceSearchRequestProcessor.Factory factory; + + @Mock + private NamedXContentRegistry xContentRegistry; + + @Before + public void init() { + factory = new MLInferenceSearchRequestProcessor.Factory(client, xContentRegistry); + } + + /** + * Tests the creation of the MLInferenceSearchRequestProcessor with required fields. + * + * @throws Exception if an error occurs during the test + */ + public void testCreateRequiredFields() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + String processorTag = randomAlphaOfLength(10); + MLInferenceSearchRequestProcessor MLInferenceSearchRequestProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + assertNotNull(MLInferenceSearchRequestProcessor); + assertEquals(MLInferenceSearchRequestProcessor.getTag(), processorTag); + assertEquals(MLInferenceSearchRequestProcessor.getType(), MLInferenceSearchRequestProcessor.TYPE); + } + + /** + * Tests the creation of the MLInferenceSearchRequestProcessor for a local model. + * + * @throws Exception if an error occurs during the test + */ + public void testCreateLocalModelProcessor() throws Exception { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + config.put(FUNCTION_NAME, "text_embedding"); + config.put(FULL_RESPONSE_PATH, true); + config.put(MODEL_INPUT, "{ \"text_docs\": ${ml_inference.text_docs} }"); + Map model_config = new HashMap<>(); + model_config.put("return_number", true); + config.put(MODEL_CONFIG, model_config); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 5); + String processorTag = randomAlphaOfLength(10); + MLInferenceSearchRequestProcessor MLInferenceSearchRequestProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + assertNotNull(MLInferenceSearchRequestProcessor); + assertEquals(MLInferenceSearchRequestProcessor.getTag(), processorTag); + assertEquals(MLInferenceSearchRequestProcessor.getType(), MLInferenceSearchRequestProcessor.TYPE); + } + + /** + * The model input field is required for using a local model + * when missing the model input field, expected to throw Exceptions + * + * @throws Exception if an error occurs during the test + */ + public void testCreateLocalModelProcessorMissingModelInputField() throws Exception { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + config.put(FUNCTION_NAME, "text_embedding"); + config.put(FULL_RESPONSE_PATH, true); + Map model_config = new HashMap<>(); + model_config.put("return_number", true); + config.put(MODEL_CONFIG, model_config); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 5); + String processorTag = randomAlphaOfLength(10); + try { + MLInferenceSearchRequestProcessor MLInferenceSearchRequestProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + assertNotNull(MLInferenceSearchRequestProcessor); + } catch (Exception e) { + assertEquals(e.getMessage(), "Please provide model input when using a local model in ML Inference Processor"); + } + } + + /** + * Tests the case where the `input_map` field is missing in the configuration, and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testCreateNoFieldPresent() throws Exception { + Map config = new HashMap<>(); + try { + factory.create(Collections.emptyMap(), "no field", null, false, config, null); + fail("factory create should have failed"); + } catch (OpenSearchParseException e) { + assertEquals(e.getMessage(), ("[model_id] required property is missing")); + } + } + + /** + * Tests the case where the `input_map` field is missing in the configuration, and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testMissingInputMapFields() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + String processorTag = randomAlphaOfLength(10); + try { + MLInferenceSearchRequestProcessor MLInferenceSearchRequestProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + fail("factory create should have failed"); + } catch (OpenSearchParseException e) { + assertEquals(e.getMessage(), ("[input_map] required property is missing")); + } + } + + /** + * Tests the case where the `output_map` field is missing in the configuration, and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testMissingOutputMapFields() throws Exception { + Map config = new HashMap<>(); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "text"); + inputMap.add(input); + config.put(MODEL_ID, "model1"); + String processorTag = randomAlphaOfLength(10); + config.put(INPUT_MAP, inputMap); + try { + MLInferenceSearchRequestProcessor MLInferenceSearchRequestProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + fail("factory create should have failed"); + } catch (OpenSearchParseException e) { + assertEquals(e.getMessage(), ("[output_map] required property is missing")); + } + } + + /** + * Tests the case where the number of prediction tasks exceeds the maximum allowed value, and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testExceedMaxPredictionTasks() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model2"); + List> inputMap = new ArrayList<>(); + Map input0 = new HashMap<>(); + input0.put("inputs", "text"); + inputMap.add(input0); + Map input1 = new HashMap<>(); + input1.put("inputs", "hashtag"); + inputMap.add(input1); + Map input2 = new HashMap<>(); + input2.put("inputs", "timestamp"); + inputMap.add(input2); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 2); + String processorTag = randomAlphaOfLength(10); + + try { + factory.create(Collections.emptyMap(), processorTag, null, false, config, null); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + ("The number of prediction task setting in this process is 3. It exceeds the max_prediction_tasks of 2. Please reduce the size of input_map or increase max_prediction_tasks.") + ); + } + } + + /** + * Tests the case where the length of the `output_map` list is greater than the length of the `input_map` list, + * and an exception is expected. + * + * @throws Exception if an error occurs during the test + */ + public void testOutputMapsExceedInputMaps() throws Exception { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(MODEL_ID, "model2"); + List> inputMap = new ArrayList<>(); + Map input0 = new HashMap<>(); + input0.put("inputs", "text"); + inputMap.add(input0); + Map input1 = new HashMap<>(); + input1.put("inputs", "hashtag"); + inputMap.add(input1); + config.put(INPUT_MAP, inputMap); + List> outputMap = new ArrayList<>(); + Map output1 = new HashMap<>(); + output1.put("text_embedding", "response"); + outputMap.add(output1); + Map output2 = new HashMap<>(); + output2.put("hashtag_embedding", "response"); + outputMap.add(output2); + Map output3 = new HashMap<>(); + output2.put("hashtvg_embedding", "response"); + outputMap.add(output3); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 2); + String processorTag = randomAlphaOfLength(10); + + try { + factory.create(Collections.emptyMap(), processorTag, null, false, config, null); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), ("The length of output_map and the length of input_map do no match.")); + } + } + + /** + * Tests the creation of the MLInferenceSearchRequestProcessor with optional fields. + * + * @throws Exception if an error occurs during the test + */ + public void testCreateOptionalFields() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model2"); + Map model_config = new HashMap<>(); + model_config.put("hidden_size", 768); + model_config.put("gradient_checkpointing", false); + model_config.put("position_embedding_type", "absolute"); + config.put(MODEL_CONFIG, model_config); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("inputs", "text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "response"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 5); + String processorTag = randomAlphaOfLength(10); + + MLInferenceSearchRequestProcessor MLInferenceSearchRequestProcessor = factory + .create(Collections.emptyMap(), processorTag, null, false, config, null); + assertNotNull(MLInferenceSearchRequestProcessor); + assertEquals(MLInferenceSearchRequestProcessor.getTag(), processorTag); + assertEquals(MLInferenceSearchRequestProcessor.getType(), MLInferenceSearchRequestProcessor.TYPE); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java index f8d623fc74..54ab526dee 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java @@ -243,11 +243,11 @@ public void testMLInferenceProcessorWithNestedFieldType() throws Exception { String index_name = "book_index"; createPipelineProcessor(createPipelineRequestBody, "embedding_pipeline"); createIndex(index_name, createIndexRequestBody); - // Skip test if key is null if (OPENAI_KEY == null) { return; } uploadDocument(index_name, "1", uploadDocumentRequestBody); + Map document = getDocument(index_name, "1"); List embeddingList = JsonPath.parse(document).read("_source.book[*].chunk.text[*].context_embedding"); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchRequestProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchRequestProcessorIT.java new file mode 100644 index 0000000000..84906ef6f2 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchRequestProcessorIT.java @@ -0,0 +1,380 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD; +import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL; +import static org.opensearch.ml.utils.TestHelper.makeRequest; + +import java.io.IOException; +import java.util.Map; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Assert; +import org.junit.Before; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; +import com.jayway.jsonpath.JsonPath; + +/** + * test ml inference search request processor to rewrite query with inference results + */ +public class RestMLInferenceSearchRequestProcessorIT extends MLCommonsRestTestCase { + private final String OPENAI_KEY = System.getenv("OPENAI_KEY"); + private String openAIChatModelId; + private String bedrockEmbeddingModelId; + + private String localModelId; + private final String completionModelConnectorEntity = "{\n" + + " \"name\": \"OpenAI text embedding model Connector\",\n" + + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"model\": \"text-embedding-ada-002\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + OPENAI_KEY + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/embeddings\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"input\\\": [\\\"${parameters.input}\\\"], \\\"model\\\": \\\"${parameters.model}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + + private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); + private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); + private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; + + private final String bedrockEmbeddingModelConnectorEntity = "{\n" + + " \"name\": \"Amazon Bedrock Connector: embedding\",\n" + + " \"description\": \"The connector to bedrock Titan embedding model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"model_name\": \"amazon.titan-embed-text-v1\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n" + + " \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n" + + " \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n" + + " }\n" + + " ]\n" + + "}"; + + /** + * register two remote models and create an index and document before tests + * @throws Exception + */ + @Before + public void setup() throws Exception { + RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); + Thread.sleep(20000); + String openAIChatModelName = "openAI-GPT-3.5 chat model " + randomAlphaOfLength(5); + this.openAIChatModelId = registerRemoteModel(completionModelConnectorEntity, openAIChatModelName, true); + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true); + + String index_name = "daily_index"; + String createIndexRequestBody = "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"diary_embedding_size\": {\n" + + " \"type\": \"keyword\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + String uploadDocumentRequestBodyDoc1 = "{\n" + + " \"id\": 1,\n" + + " \"diary\": [\"happy\",\"first day at school\"],\n" + + " \"diary_embedding_size\": \"1536\",\n" // embedding size for ada model + + " \"weather\": \"rainy\"\n" + + " }"; + + String uploadDocumentRequestBodyDoc2 = "{\n" + + " \"id\": 2,\n" + + " \"diary\": [\"bored\",\"at home\"],\n" + + " \"diary_embedding_size\": \"768\",\n" // embedding size for local text embedding model + + " \"weather\": \"sunny\"\n" + + " }"; + + createIndex(index_name, createIndexRequestBody); + uploadDocument(index_name, "1", uploadDocumentRequestBodyDoc1); + uploadDocument(index_name, "2", uploadDocumentRequestBodyDoc2); + } + + /** + * Tests the ML inference processor with a remote model to rewrite the query string. + * It creates a search pipeline with the ML inference processor, + * and then performs a search using the pipeline. The test verifies that the query string is rewritten + * correctly based on the inference results from the remote model. + * + * @throws Exception if any error occurs during the test + */ + public void testMLInferenceProcessorRemoteModelRewriteQueryString() throws Exception { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } + String createPipelineRequestBody = "{\n" + + " \"request_processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"tag\": \"ml_inference\",\n" + + " \"description\": \"This processor is going to run ml inference during search request\",\n" + + " \"model_id\": \"" + + this.openAIChatModelId + + "\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"input\": \"query.term.diary_embedding_size.value\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"query.term.diary_embedding_size.value\": \"data[0].embedding.length()\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\":false,\n" + + " \"ignore_failure\": false\n" + + " \n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + String query = "{\"query\":{\"term\":{\"diary_embedding_size\":{\"value\":\"happy\"}}}}"; + + String index_name = "daily_index"; + String pipelineName = "diary_embedding_pipeline"; + createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); + + Map response = searchWithPipeline(client(), index_name, pipelineName, query); + + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "rainy"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "happy"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"), "first day at school"); + } + + /** + * Tests the ML inference processor with a remote model to rewrite the query type. + * It creates a search pipeline with the ML inference processor configured to rewrite + * a term query to a range query based on the inference results from the remote model. + * The test then performs a search using the pipeline and verifies that the query type + * is rewritten correctly. + * + * @throws Exception if any error occurs during the test + */ + public void testMLInferenceProcessorRemoteModelRewriteQueryType() throws Exception { + // Skip test if key is null + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + return; + } + String createPipelineRequestBody = "{\n" + + " \"request_processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"tag\": \"ml_inference\",\n" + + " \"description\": \"This processor is going to run ml inference during search request\",\n" + + " \"model_id\": \"" + + this.bedrockEmbeddingModelId + + "\",\n" + + " \"query_template\": \"{\\\"query\\\":{\\\"range\\\":{\\\"diary_embedding_size\\\":{\\\"lte\\\":${modelPrediction}}}}}\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"input\": \"query.term.diary_embedding_size.value\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"modelPrediction\": \"embedding.length()\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": false,\n" + + " \"ignore_failure\": false\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String index_name = "daily_index"; + + String pipelineName = "diary_embedding_pipeline_range_query"; + String query = "{\"query\":{\"term\":{\"diary_embedding_size\":{\"value\":\"happy\"}}}}"; + createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); + + Map response = searchWithPipeline(client(), index_name, pipelineName, query); + + assertEquals((int) JsonPath.parse(response).read("$.hits.hits.length()"), 2); + } + + /** + * Tests the ML inference processor with a local model. + * It registers, deploys, and gets a local model, creates a search pipeline with the ML inference processor + * configured to use the local model, and then performs a search using the pipeline. + * The test verifies that the query string is rewritten correctly based on the inference results + * from the local model. + * + * @throws Exception if any error occurs during the test + */ + public void testMLInferenceProcessorLocalModel() throws Exception { + + String taskId = registerModel(TestHelper.toJsonString(registerModelInput())); + waitForTask(taskId, MLTaskState.COMPLETED); + getTask(client(), taskId, response -> { + assertNotNull(response.get(MODEL_ID_FIELD)); + this.localModelId = (String) response.get(MODEL_ID_FIELD); + try { + String deployTaskID = deployModel(this.localModelId); + waitForTask(deployTaskID, MLTaskState.COMPLETED); + + getModel(client(), this.localModelId, model -> { assertEquals("DEPLOYED", model.get("model_state")); }); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + + String createPipelineRequestBody = "{\n" + + " \"request_processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"tag\": \"ml_inference\",\n" + + " \"description\": \"This processor is going to run ml inference during search request\",\n" + + " \"model_id\": \"" + + this.localModelId + + "\",\n" + + " \"model_input\": \"{ \\\"text_docs\\\": [\\\"${ml_inference.text_docs}\\\"] ,\\\"return_number\\\": true,\\\"target_response\\\": [\\\"sentence_embedding\\\"]}\",\n" + + " \"function_name\": \"text_embedding\",\n" + + " \"full_response_path\": true,\n" + + " \"input_map\": [\n" + + " {\n" + + " \"text_docs\": \"query.term.diary_embedding_size.value\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"query.term.diary_embedding_size.value\": \"$.inference_results[0].output[0].data.length()\"\n" + + " }\n" + + " ],\n" + + " \n" + + " \"ignore_missing\":false,\n" + + " \"ignore_failure\": false\n" + + " \n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + String index_name = "daily_index"; + String pipelineName = "diary_embedding_pipeline_local"; + createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); + + String query = "{\"query\":{\"term\":{\"diary_embedding_size\":{\"value\":\"bored\"}}}}"; + Map response = searchWithPipeline(client(), index_name, pipelineName, query); + + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "768"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "sunny"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "bored"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"), "at home"); + } + + protected void createSearchPipelineProcessor(String requestBody, final String pipelineName) throws Exception { + Response pipelineCreateResponse = TestHelper + .makeRequest( + client(), + "PUT", + "/_search/pipeline/" + pipelineName, + null, + requestBody, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, pipelineCreateResponse.getStatusLine().getStatusCode()); + + } + + protected void createIndex(String indexName, String requestBody) throws Exception { + Response response = makeRequest( + client(), + "PUT", + indexName, + null, + requestBody, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + protected void uploadDocument(final String index, final String docId, final String jsonBody) throws IOException { + Request request = new Request("PUT", "/" + index + "/_doc/" + docId + "?refresh=true"); + request.setJsonEntity(jsonBody); + client().performRequest(request); + } + + protected MLRegisterModelInput registerModelInput() throws IOException, InterruptedException { + + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(768) + .build(); + return MLRegisterModelInput + .builder() + .modelName("test_model_name") + .version("1.0.0") + .functionName(FunctionName.TEXT_EMBEDDING) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .url(SENTENCE_TRANSFORMER_MODEL_URL) + .deployModel(false) + .hashValue("e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021") + .build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestData.java b/plugin/src/test/java/org/opensearch/ml/utils/TestData.java index afd39c7d1e..ab7acf38a0 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestData.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestData.java @@ -31,6 +31,8 @@ public class TestData { public static final String SENTENCE_TRANSFORMER_MODEL_URL = "https://github.com/opensearch-project/ml-commons/blob/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/traced_small_model.zip?raw=true"; public static final String TIME_FIELD = "timestamp"; + public static final String HUGGINGFACE_TRANSFORMER_MODEL_HASH_VALUE = + "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021"; public static final String TARGET_FIELD = "price"; public static DataFrame constructTestDataFrame(int size) { diff --git a/plugin/src/yamlRestTest/resources/rest-api-spec/test/30_inference_search_request_processor.yml b/plugin/src/yamlRestTest/resources/rest-api-spec/test/30_inference_search_request_processor.yml new file mode 100644 index 0000000000..8fdcb79fbc --- /dev/null +++ b/plugin/src/yamlRestTest/resources/rest-api-spec/test/30_inference_search_request_processor.yml @@ -0,0 +1,41 @@ +--- +teardown: + - do: + search_pipeline.delete: + id: "my_pipeline" + ignore: 404 + +--- +"Test ML Inference Search Request Processor": + - skip: + version: " - 2.15.99" + reason: "Added in 2.16.0" + - do: + search_pipeline.put: + id: "my_pipeline" + body: > + { + "request_processors": [ + { + "ml_inference": { + "tag": "ml_inference", + "description": "This processor is going to run ml inference during search request", + "model_id": "KaJtepABtwewtgmtfKhq", + "input_map": [ + { + "texts": "query.term.text.value" + } + ], + "output_map": [ + { + "query.term.text.value": "response_type" + } + ], + "ignore_missing":false, + "ignore_failure": false + + } + } + ] + } + - match: { acknowledged: true }