From 083abad726933629557028047cb27c482fd950ec Mon Sep 17 00:00:00 2001 From: Mingshi Liu Date: Mon, 21 Oct 2024 11:09:39 -0700 Subject: [PATCH] Enable pass query string to input_map in ml inference search response processor (#2899) * enable add query_text to model_config Signed-off-by: Mingshi Liu * change javadoc Signed-off-by: Mingshi Liu * add more tests Signed-off-by: Mingshi Liu * use standard json path config Signed-off-by: Mingshi Liu * add example in javadoc Signed-off-by: Mingshi Liu * read query mapping from input_map Signed-off-by: Mingshi Liu * recognize query mapping by prefix _request. Signed-off-by: Mingshi Liu --------- Signed-off-by: Mingshi Liu --- common/build.gradle | 1 - .../ml/common/utils/StringUtilsTest.java | 50 +++ .../MLInferenceSearchResponseProcessor.java | 125 ++++-- .../ml/processor/ModelExecutor.java | 8 +- ...InferenceSearchResponseProcessorTests.java | 371 +++++++++++++++++- 5 files changed, 508 insertions(+), 47 deletions(-) diff --git a/common/build.gradle b/common/build.gradle index 8b81080a3b..60edb3101a 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -26,7 +26,6 @@ dependencies { compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' compileOnly group: 'org.json', name: 'json', version: '20231013' - implementation('com.google.guava:guava:32.1.2-jre') { exclude group: 'com.google.guava', module: 'failureaccess' exclude group: 'com.google.code.findbugs', module: 'jsr305' diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index aed76c5658..ef5307ae46 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME; 
import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes; import static org.opensearch.ml.common.utils.StringUtils.getJsonPath; +import static org.opensearch.ml.common.utils.StringUtils.isValidJSONPath; import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath; import static org.opensearch.ml.common.utils.StringUtils.parseParameters; import static org.opensearch.ml.common.utils.StringUtils.toJson; @@ -457,4 +458,53 @@ public void testGetJsonPath_ValidJsonPathWithoutSource() { String result = getJsonPath(input); assertEquals("$.response.body.data[*].embedding", result); } + + @Test + public void testisValidJSONPath_InvalidInputs() { + Assert.assertFalse(isValidJSONPath("..bar")); + Assert.assertFalse(isValidJSONPath(".")); + Assert.assertFalse(isValidJSONPath("..")); + Assert.assertFalse(isValidJSONPath("foo.bar.")); + Assert.assertFalse(isValidJSONPath(".foo.bar.")); + } + + @Test + public void testisValidJSONPath_NullInput() { + Assert.assertFalse(isValidJSONPath(null)); + } + + @Test + public void testisValidJSONPath_EmptyInput() { + Assert.assertFalse(isValidJSONPath("")); + } + + @Test + public void testisValidJSONPath_ValidInputs() { + Assert.assertTrue(isValidJSONPath("foo")); + Assert.assertTrue(isValidJSONPath("foo.bar")); + Assert.assertTrue(isValidJSONPath("foo.bar.baz")); + Assert.assertTrue(isValidJSONPath("foo.bar.baz.qux")); + Assert.assertTrue(isValidJSONPath(".foo")); + Assert.assertTrue(isValidJSONPath("$.foo")); + Assert.assertTrue(isValidJSONPath(".foo.bar")); + Assert.assertTrue(isValidJSONPath("$.foo.bar")); + } + + @Test + public void testisValidJSONPath_WithFilter() { + Assert.assertTrue(isValidJSONPath("$.store['book']")); + Assert.assertTrue(isValidJSONPath("$['store']['book'][0]['title']")); + Assert.assertTrue(isValidJSONPath("$.store.book[0]")); + Assert.assertTrue(isValidJSONPath("$.store.book[1,2]")); + Assert.assertTrue(isValidJSONPath("$.store.book[-1:] ")); + 
Assert.assertTrue(isValidJSONPath("$.store.book[0:2]")); + Assert.assertTrue(isValidJSONPath("$.store.book[*]")); + Assert.assertTrue(isValidJSONPath("$.store.book[?(@.price < 10)]")); + Assert.assertTrue(isValidJSONPath("$.store.book[?(@.author == 'J.K. Rowling')]")); + Assert.assertTrue(isValidJSONPath("$..author")); + Assert.assertTrue(isValidJSONPath("$..book[?(@.price > 15)]")); + Assert.assertTrue(isValidJSONPath("$.store.book[0,1]")); + Assert.assertTrue(isValidJSONPath("$['store','warehouse']")); + Assert.assertTrue(isValidJSONPath("$..book[?(@.price > 20)].title")); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java index 56e98474c7..e39b7f4b74 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java @@ -6,6 +6,7 @@ package org.opensearch.ml.processor; import static java.lang.Math.max; +import static org.opensearch.ml.common.utils.StringUtils.toJson; import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; @@ -55,12 +56,11 @@ import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; -import com.jayway.jsonpath.Configuration; import com.jayway.jsonpath.JsonPath; -import com.jayway.jsonpath.Option; public class MLInferenceSearchResponseProcessor extends AbstractProcessor implements SearchResponseProcessor, ModelExecutor { + public static final String REQUEST_PREFIX = "_request."; private final NamedXContentRegistry xContentRegistry; private static final Logger logger = 
LogManager.getLogger(MLInferenceSearchResponseProcessor.class); private final InferenceProcessorAttributes inferenceProcessorAttributes; @@ -155,6 +155,8 @@ public void processResponseAsync( try { SearchHit[] hits = response.getHits().getHits(); // skip processing when there is no hit + + String queryString = request.source().toString(); if (hits.length == 0) { responseListener.onResponse(response); return; @@ -183,7 +185,7 @@ public void processResponseAsync( ); } - rewriteResponseDocuments(mlInferenceSearchResponse, responseListener); + rewriteResponseDocuments(mlInferenceSearchResponse, responseListener, queryString); } else { // if one to one, make one hit search response and run rewriteResponseDocuments GroupedActionListener combineResponseListener = getCombineResponseGroupedActionListener( @@ -198,7 +200,7 @@ public void processResponseAsync( newHits[0] = hit; SearchResponse oneHitResponse = SearchResponseUtil.replaceHits(newHits, response); ActionListener oneHitListener = getOneHitListener(combineResponseListener, isOneHitListenerFailed); - rewriteResponseDocuments(oneHitResponse, oneHitListener); + rewriteResponseDocuments(oneHitResponse, oneHitListener, queryString); // if any OneHitListener failure, try stop the rest of the predictions if (isOneHitListenerFailed.get()) { break; @@ -305,9 +307,11 @@ public void onFailure(Exception e) { * * @param response the search response * @param responseListener the listener to be notified when the response is processed + * @param queryString the query body in string format, for example, "{ \"query\": { \"match_all\": {} } }\n" * @throws IOException if an I/O error occurs during the rewriting process */ - private void rewriteResponseDocuments(SearchResponse response, ActionListener responseListener) throws IOException { + private void rewriteResponseDocuments(SearchResponse response, ActionListener responseListener, String queryString) + throws IOException { List> processInputMap = 
inferenceProcessorAttributes.getInputMaps(); List> processOutputMap = inferenceProcessorAttributes.getOutputMaps(); int inputMapSize = (processInputMap == null) ? 0 : processInputMap.size(); @@ -329,7 +333,7 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener> processInputMap, int inputMapIndex, GroupedActionListener> batchPredictionListener, - Map hitCountInPredictions + Map hitCountInPredictions, + String queryString ) throws IOException { Map modelParameters = new HashMap<>(); Map modelConfigs = new HashMap<>(); if (inferenceProcessorAttributes.getModelConfigMaps() != null) { - modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps()); - modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + Map modelConfigMapsInput = inferenceProcessorAttributes.getModelConfigMaps(); + + modelParameters.putAll(modelConfigMapsInput); + modelConfigs.putAll(modelConfigMapsInput); + } Map modelInputParameters = new HashMap<>(); @@ -364,33 +373,52 @@ private void processPredictions( Map inputMapping; if (processInputMap != null && !processInputMap.isEmpty()) { inputMapping = processInputMap.get(inputMapIndex); + boolean isRequestInputMissing = checkIsRequestInputMissing(queryString, inputMapping); + if (isRequestInputMissing) { + if (!ignoreMissing) { + throw new IllegalArgumentException( + "Missing required input field in query body. 
input_map: " + inputMapping.values() + ", query body:" + queryString + ); + } + } for (SearchHit hit : hits) { Map document = hit.getSourceAsMap(); - boolean isModelInputMissing = checkIsModelInputMissing(document, inputMapping); - if (!isModelInputMissing) { + boolean isDocumentFieldMissing = checkIsDocumentFieldMissing(document, inputMapping); + if (!isDocumentFieldMissing) { MapUtils.incrementCounter(hitCountInPredictions, inputMapIndex); for (Map.Entry entry : inputMapping.entrySet()) { // model field as key, document field name as value String modelInputFieldName = entry.getKey(); String documentFieldName = entry.getValue(); - - Object documentJson = JsonPath.parse(document).read("$"); - Configuration configuration = Configuration - .builder() - .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL) - .build(); - - Object documentValue = JsonPath.using(configuration).parse(documentJson).read(documentFieldName); - if (documentValue != null) { - // when not existed in the map, add into the modelInputParameters map - updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue); + // read the query string when the mapping field name starts with "$._request." or "_request." + // skip when modelInputParameters already has this modelInputFieldName to avoid duplicate read + if (StringUtils.isValidJSONPath(documentFieldName) + && (documentFieldName.startsWith("$." 
+ REQUEST_PREFIX) || documentFieldName.startsWith(REQUEST_PREFIX)) + && !modelInputParameters.containsKey(modelInputFieldName)) { + String requestFieldName = documentFieldName.replaceFirst(REQUEST_PREFIX, ""); + + Object queryText = JsonPath.using(suppressExceptionConfiguration).parse(queryString).read(requestFieldName); + if (queryText != null) { + modelInputParameters.put(modelInputFieldName, toJson(queryText)); + } + } else { + Object documentValue = JsonPath.using(suppressExceptionConfiguration).parse(document).read(documentFieldName); + if (documentValue != null) { + // when not existed in the map, add into the modelInputParameters map + updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue); + } } } } else { // when document does not contain the documentFieldName, skip when ignoreMissing if (!ignoreMissing) { throw new IllegalArgumentException( - "cannot find all required input fields: " + inputMapping.values() + " in hit:" + hit + "cannot find all required input fields: " + + inputMapping.values() + + " in hit:" + + hit + + " and query body:" + + queryString ); } } @@ -542,11 +570,11 @@ public void onResponse(Map multipleMLOutputs) { Map inputMapping = getDefaultInputMapping(sourceAsMap, mappingIndex, processInputMap); Map outputMapping = getDefaultOutputMapping(mappingIndex, processOutputMap); - boolean isModelInputMissing = false; + boolean isDocumentFieldMissing = false; if (processInputMap != null && !processInputMap.isEmpty()) { - isModelInputMissing = checkIsModelInputMissing(document, inputMapping); + isDocumentFieldMissing = checkIsDocumentFieldMissing(document, inputMapping); } - if (!isModelInputMissing) { + if (!isDocumentFieldMissing) { // Iterate over outputMapping for (Map.Entry outputMapEntry : outputMapping.entrySet()) { @@ -637,22 +665,45 @@ public void onFailure(Exception e) { /** * Checks if the document is missing any of the required input fields specified in the input mapping. 
+ * When model config contains the default model_input value, it's not considered as missing model input. * * @param document the document map * @param inputMapping the input mapping * @return true if the document is missing any of the required input fields, false otherwise */ - private boolean checkIsModelInputMissing(Map document, Map inputMapping) { - boolean isModelInputMissing = false; - for (Map.Entry inputMapEntry : inputMapping.entrySet()) { - String oldDocumentFieldName = inputMapEntry.getValue(); - boolean checkSingleModelInputPresent = hasField(document, oldDocumentFieldName); - if (!checkSingleModelInputPresent) { - isModelInputMissing = true; - break; - } - } - return isModelInputMissing; + private boolean checkIsDocumentFieldMissing(Map document, Map inputMapping) { + return inputMapping + .values() + .stream() + .filter(fieldName -> !(fieldName.startsWith("$." + REQUEST_PREFIX) || fieldName.startsWith(REQUEST_PREFIX))) + .anyMatch(fieldName -> { + boolean isFieldPresentInDocument = document != null && hasField(document, fieldName); + boolean isFieldPresentInModelConfig = this.inferenceProcessorAttributes.modelConfigMaps != null + && this.inferenceProcessorAttributes.modelConfigMaps.containsKey(fieldName); + return !isFieldPresentInDocument && !isFieldPresentInModelConfig; + }); + } + + /** + * Checks if the request is missing any of the required input fields specified in the input mapping. + * When model config contains the default model_input value, it's not considered as missing model input. + * + * @param queryString the query body in string format, e.g., "{ \"query\": { \"match_all\": {} } }\n" + * @param inputMapping the input mapping + * @return true if the request is missing any of the required input fields, false otherwise + */ + private boolean checkIsRequestInputMissing(String queryString, Map inputMapping) { + return inputMapping + .values() + .stream() + .filter(fieldName -> fieldName.startsWith("$.
+ REQUEST_PREFIX) || fieldName.startsWith(REQUEST_PREFIX)) + .map(fieldName -> fieldName.replaceFirst(REQUEST_PREFIX, "")) + .anyMatch(requestFieldName -> { + boolean isFieldPresentInQuery = queryString != null && hasField(queryString, requestFieldName); + boolean isFieldPresentInModelConfig = this.inferenceProcessorAttributes.modelConfigMaps != null + && this.inferenceProcessorAttributes.modelConfigMaps.containsKey(requestFieldName); + return !isFieldPresentInQuery && !isFieldPresentInModelConfig; + }); } /** diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index cf17afd904..d32308d2ef 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -282,8 +282,12 @@ default String toString(Object originalFieldValue) { } default boolean hasField(Object json, String path) { - Object value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path); - + Object value; + if (json instanceof String) { + value = JsonPath.using(suppressExceptionConfiguration).parse((String) json).read(path); + } else { + value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path); + } if (value != null) { return true; } diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index f462408943..dedae5f1bd 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -6,11 +6,13 @@ package org.opensearch.ml.processor; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static 
org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.utils.StringUtils.toJson; import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; @@ -35,6 +37,7 @@ import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; @@ -52,11 +55,16 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -171,6 +179,348 @@ public void onFailure(Exception e) { responseProcessor.processResponseAsync(request, response, responseContext, listener); } + /** + * Tests the successful processing of a response with a single pair of input and output mappings. 
+ * read the query text from input_mapping + * @throws Exception if an error occurs during the test + */ + @Test + public void testProcessResponseSuccessReadQueryTextFromInputMap() throws Exception { + String modelInputField = "text_docs"; + String originalDocumentField = "text"; + String newDocumentField = "similarity_score"; + String modelOutputField = "response"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("query_text", "$._request.query.term.text.value"); + input.put(modelInputField, originalDocumentField); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "text_similarity", + false, + false, + false, + "{ \"query_text\": \"${input_map.query_text}\", \"text_docs\":${input_map.text_docs}}", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + assertEquals(responseProcessor.getType(), TYPE); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + 
ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField), 0.0); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField), 1.0); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField), 2.0); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField), 3.0); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField), 4.0); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + responseProcessor.processResponseAsync(request, response, responseContext, listener); + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + TextSimilarityInputDataSet inputDataSet = (TextSimilarityInputDataSet) mlInput.getInputDataset(); + assertEquals(toJson(inputDataSet.getQueryText()), "foo"); + assertEquals(toJson(inputDataSet.getTextDocs()), "[\"value 0\",\"value 1\",\"value 2\",\"value 3\",\"value 4\"]"); + } + + /** + * Tests read the query size and sort field from request + * @throws Exception if an error occurs during the test + */ + @Test + public void testProcessResponseSuccessReadRequestMetaFieldFromInputMap() throws Exception { + String modelInputField = "text_docs"; + String originalDocumentField = "text"; + String newDocumentField = "similarity_score"; + String modelOutputField = "response"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("query_text", 
"$._request.query.term.text.value"); + input.put("sort", "$._request.sort"); + input.put("size", "$._request.size"); + input.put(modelInputField, originalDocumentField); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + assertEquals(responseProcessor.getType(), TYPE); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField), 0.0); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField), 1.0); + 
assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField), 2.0); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField), 3.0); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField), 4.0); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + responseProcessor.processResponseAsync(request, response, responseContext, listener); + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals( + toJson(inputDataSet.getParameters()), + "{\"size\":\"5\",\"sort\":\"[{\\\"text\\\":{\\\"order\\\":\\\"asc\\\"}}]\",\"text_docs\":\"[\\\"value 0\\\",\\\"value 1\\\",\\\"value 2\\\",\\\"value 3\\\",\\\"value 4\\\"]\",\"query_text\":\"foo\"}" + ); + } + + /** + * Tests read the query text based on input_mapping + * when the query mapping is not found, expect to throw an exception + * @throws Exception + */ + @Test + public void testProcessResponseSuccessReadQueryTextException() throws Exception { + String modelInputField = "text_docs"; + String originalDocumentField = "text"; + String newDocumentField = "similarity_score"; + String modelOutputField = "response"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("query_text", "$._request.query.term.text.value1"); + input.put(modelInputField, originalDocumentField); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + + MLInferenceSearchResponseProcessor 
responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + assertEquals(responseProcessor.getType(), TYPE); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + throw new RuntimeException("error handling not properly"); + } + + @Override + public void onFailure(Exception e) { + assertEquals( + e.getMessage(), + "Missing required input field in query body. 
input_map: [text, $._request.query.term.text.value1], query body:{\"size\":5,\"query\":{\"term\":{\"text\":{\"value\":\"foo\",\"boost\":1.0}}},\"sort\":[{\"text\":{\"order\":\"asc\"}}]}" + ); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(0)).execute(any(), any(), any()); + } + + /** + * Tests read the query text based on input_mapping, but query text not found + * isRequestInputMissing is true, isDocumentFieldMissing is false + * when the query mapping is not found, ignoreMissing then expect to read the document input + * @throws Exception + */ + @Test + public void testProcessResponseSuccessReadQueryTextExceptionIgnoreMissing() throws Exception { + String modelInputField = "text_docs"; + String originalDocumentField = "text"; + String newDocumentField = "similarity_score"; + String modelOutputField = "response"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("query_text", "_request.query.term.text.value1"); + input.put(modelInputField, originalDocumentField); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + true, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + assertEquals(responseProcessor.getType(), TYPE); + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + 
.build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField), 0.0); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField), 1.0); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField), 2.0); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField), 3.0); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField), 4.0); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("error handling not properly"); + } + }; + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + responseProcessor.processResponseAsync(request, response, responseContext, listener); + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals( + toJson(inputDataSet.getParameters()), + "{\"text_docs\":\"[\\\"value 0\\\",\\\"value 1\\\",\\\"value 2\\\",\\\"value 3\\\",\\\"value 4\\\"]\"}" + ); + + } + /** * Tests create 
processor with one_to_one is true * with custom prompt @@ -180,7 +530,14 @@ public void onFailure(Exception e) { @Test public void testProcessResponseOneToOneWithCustomPrompt() throws Exception { - String newDocumentField = "context"; + String documentField = "text"; + String modelInputField = "context"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, documentField); + inputMap.add(input); + + String newDocumentField = "llm_response"; String modelOutputField = "response"; List> outputMap = new ArrayList<>(); Map output = new HashMap<>(); @@ -190,11 +547,11 @@ public void testProcessResponseOneToOneWithCustomPrompt() throws Exception { modelConfig .put( "prompt", - "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${input_map.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. 
\\n\\n Human: please summarize the documents \\n\\n Assistant:" ); MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", - null, + inputMap, outputMap, modelConfig, DEFAULT_MAX_PREDICTION_TASKS, @@ -205,7 +562,7 @@ public void testProcessResponseOneToOneWithCustomPrompt() throws Exception { false, false, false, - "{ \"prompt\": \"${model_config.prompt}\"}", + "{ \"parameters\": ${ml_inference.parameters} }", client, TEST_XCONTENT_REGISTRY_FOR_QUERY, true @@ -1891,7 +2248,7 @@ public void onFailure(Exception e) { + " \"texttypo\" : \"value 0\",\n" + " \"image\" : \"value 0\"\n" + " }\n" - + "}", + + "} and query body:{\"size\":5,\"query\":{\"term\":{\"text\":{\"value\":\"foo\",\"boost\":1.0}}},\"sort\":[{\"text\":{\"order\":\"asc\"}}]}", e.getMessage() ); } @@ -2690,7 +3047,7 @@ public void onFailure(Exception e) { + " \"_source\" : {\n" + " \"textMissing\" : \"value 2\"\n" + " }\n" - + "}", + + "} and query body:{\"size\":5,\"query\":{\"term\":{\"text\":{\"value\":\"foo\",\"boost\":1.0}}},\"sort\":[{\"text\":{\"order\":\"asc\"}}]}", e.getMessage() ); } @@ -3569,7 +3926,7 @@ public void onFailure(Exception e) { private static SearchRequest getSearchRequest() { QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); - SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery).size(5).sort("text"); SearchRequest request = new SearchRequest().source(source); return request; }