diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 287fbb8127..edf26b954d 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.isJson; +import static org.opensearch.ml.common.utils.StringUtils.parseParameters; import java.io.IOException; import java.time.Instant; @@ -322,40 +323,16 @@ public T createPayload(String action, Map parameters) { if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) { String payload = connectorAction.get().getRequestBody(); payload = fillNullParameters(parameters, payload); + parseParameters(parameters); StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); + if (!isJson(payload)) { - String payloadAfterEscape = connectorAction.get().getRequestBody(); - Map escapedParameters = escapeMapValues(parameters); - StringSubstitutor escapedSubstitutor = new StringSubstitutor(escapedParameters, "${parameters.", "}"); - payloadAfterEscape = escapedSubstitutor.replace(payloadAfterEscape); - if (!isJson(payloadAfterEscape)) { - throw new IllegalArgumentException("Invalid payload: " + payload); - } else { - payload = payloadAfterEscape; - } + throw new IllegalArgumentException("Invalid payload: " + payload); } return (T) payload; } return (T) parameters.get("http_body"); - - } - - public static Map escapeMapValues(Map parameters) { - Map escapedMap = new HashMap<>(); - if (parameters != null) { - for (Map.Entry entry : parameters.entrySet()) { - String key = entry.getKey(); - String value = entry.getValue(); - String escapedValue = escapeValue(value); - escapedMap.put(key, escapedValue); - } - } - return escapedMap; - } - - private static String escapeValue(String value) { - return value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t"); } protected String fillNullParameters(Map parameters, String payload) { diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index e71636e01b..57c24c22fd 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -50,6 +50,7 @@ public class StringUtils { static { gson = new Gson(); } + public static final String TO_STRING_FUNCTION_NAME = ".toString()"; public static boolean isValidJsonString(String Json) { try { @@ -233,4 +234,49 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea return errorMessage + " Model ID: " + modelId; } } + + /** + * Collects the prefixes of the toString() method calls present in the values of the given map. + * + * @param map A map containing key-value pairs where the values may contain toString() method calls. + * @return A list of prefixes for the toString() method calls found in the map values. + */ + public static List collectToStringPrefixes(Map map) { + List prefixes = new ArrayList<>(); + for (String key : map.keySet()) { + String value = map.get(key); + if (value != null) { + Pattern pattern = Pattern.compile("\\$\\{parameters\\.(.+?)\\.toString\\(\\)\\}"); + Matcher matcher = pattern.matcher(value); + while (matcher.find()) { + String prefix = matcher.group(1); + prefixes.add(prefix); + } + } + } + return prefixes; + } + + /** + * Parses the given parameters map and processes the values containing toString() method calls. + * + * @param parameters A map containing key-value pairs where the values may contain toString() method calls. + * @return A new map with the processed values for the toString() method calls. + */ + public static Map parseParameters(Map parameters) { + if (parameters != null) { + List toStringParametersPrefixes = collectToStringPrefixes(parameters); + + if (!toStringParametersPrefixes.isEmpty()) { + for (String prefix : toStringParametersPrefixes) { + String value = parameters.get(prefix); + if (value != null) { + parameters.put(prefix + TO_STRING_FUNCTION_NAME, processTextDoc(value)); + } + } + } + } + return parameters; + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index a84652791f..0115ac1376 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -6,7 +6,6 @@ package org.opensearch.ml.common.connector; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; -import static org.opensearch.ml.common.utils.StringUtils.toJson; import java.io.IOException; import java.util.ArrayList; @@ -184,114 +183,6 @@ public void createPayload_InvalidJson() { connector.validatePayload(predictPayload); } - @Test - public void createPayloadWithString() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - parameters.put("context", "document1"); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - Assert.assertEquals("{\"prompt\": \"answer question based on context: document1\"}", predictPayload); - } - - @Test - public void createPayloadWithList() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - ArrayList listOfDocuments = new ArrayList<>(); - listOfDocuments.add("document1"); - listOfDocuments.add("document2"); - parameters.put("context", toJson(listOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithNestedList() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - ArrayList listOfDocuments = new ArrayList<>(); - listOfDocuments.add("document1"); - ArrayList NestedListOfDocuments = new ArrayList<>(); - NestedListOfDocuments.add("document2"); - listOfDocuments.add(toJson(NestedListOfDocuments)); - parameters.put("context", toJson(listOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithMap() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - Map mapOfDocuments = new HashMap<>(); - mapOfDocuments.put("name", "John"); - parameters.put("context", toJson(mapOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithNestedMapOfString() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - Map mapOfDocuments = new HashMap<>(); - mapOfDocuments.put("name", "John"); - Map nestedMapOfDocuments = new HashMap<>(); - nestedMapOfDocuments.put("city", "New York"); - mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments)); - parameters.put("context", toJson(mapOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithNestedMapOfObject() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - Map mapOfDocuments = new HashMap<>(); - mapOfDocuments.put("name", "John"); - Map nestedMapOfDocuments = new HashMap<>(); - nestedMapOfDocuments.put("city", "New York"); - mapOfDocuments.put("hometown", nestedMapOfDocuments); - parameters.put("context", toJson(mapOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithNestedListOfMapOfObject() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - ArrayList listOfDocuments = new ArrayList<>(); - listOfDocuments.add("document1"); - ArrayList NestedListOfDocuments = new ArrayList<>(); - Map mapOfDocuments = new HashMap<>(); - mapOfDocuments.put("name", "John"); - Map nestedMapOfDocuments = new HashMap<>(); - nestedMapOfDocuments.put("city", "New York"); - mapOfDocuments.put("hometown", nestedMapOfDocuments); - listOfDocuments.add(toJson(NestedListOfDocuments)); - parameters.put("context", toJson(listOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - @Test public void createPayload() { HttpConnector connector = createHttpConnector(); diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index cf112d6ca3..a4b1460f39 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -6,7 +6,12 @@ package org.opensearch.ml.common.utils; import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME; +import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes; +import static org.opensearch.ml.common.utils.StringUtils.parseParameters; +import static org.opensearch.ml.common.utils.StringUtils.toJson; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; @@ -14,6 +19,7 @@ import java.util.Map; import java.util.Set; +import org.apache.commons.text.StringSubstitutor; import org.junit.Assert; import org.junit.Test; @@ -218,4 +224,203 @@ public void testGetErrorMessageWhenHiddenNull() { // Assert assertEquals(expected, result); } + + /** + * Tests the collectToStringPrefixes method with a map containing toString() method calls + * in the values. Verifies that the method correctly extracts the prefixes of the toString() + * method calls. + */ + @Test + public void testGetToStringPrefix() { + Map parameters = new HashMap<>(); + parameters + .put( + "prompt", + "answer question based on context: ${parameters.context.toString()} and conversation history based on history: ${parameters.history.toString()}" + ); + parameters.put("context", "${parameters.text.toString()}"); + + List prefixes = collectToStringPrefixes(parameters); + List expectPrefixes = new ArrayList<>(); + expectPrefixes.add("text"); + expectPrefixes.add("context"); + expectPrefixes.add("history"); + assertEquals(prefixes, expectPrefixes); + } + + /** + * Tests the parseParameters method with a map containing a list of strings as the value + * for the "context" key. Verifies that the method correctly processes the list and adds + * the processed value to the map with the expected key. Also tests the string substitution + * using the processed values. + */ + @Test + public void testParseParametersListToString() { + Map parameters = new HashMap<>(); + parameters.put("prompt", "answer question based on context: ${parameters.context.toString()}"); + ArrayList listOfDocuments = new ArrayList<>(); + listOfDocuments.add("document1"); + parameters.put("context", toJson(listOfDocuments)); + + parseParameters(parameters); + assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]"); + + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + requestBody = substitutor.replace(requestBody); + assertEquals(requestBody, "{\"prompt\": \"answer question based on context: [\\\"document1\\\"]\"}"); + } + + /** + * Tests the parseParameters method with a map containing a list of strings as the value + * for the "context" key, and the "prompt" value containing escaped characters. Verifies + * that the method correctly processes the list and adds the processed value to the map + * with the expected key. Also tests the string substitution using the processed values. + */ + @Test + public void testParseParametersListToStringWithEscapedPrompt() { + Map parameters = new HashMap<>(); + parameters + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context.toString()}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + ArrayList listOfDocuments = new ArrayList<>(); + listOfDocuments.add("document1"); + parameters.put("context", toJson(listOfDocuments)); + + parseParameters(parameters); + assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]"); + + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + requestBody = substitutor.replace(requestBody); + assertEquals( + requestBody, + "{\"prompt\": \"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\"]. \\n\\n Human: please summarize the documents \\n\\n Assistant:\"}" + ); + } + + /** + * Tests the parseParameters method with a map containing a list of strings as the value + * for the "context" key, and the "prompt" value containing escaped characters. Verifies + * that the method correctly processes the list and adds the processed value to the map + * with the expected key. Also tests the string substitution using the processed values. + */ + @Test + public void testParseParametersListToStringModelConfig() { + Map parameters = new HashMap<>(); + parameters + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.model_config.context.toString()}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + ArrayList listOfDocuments = new ArrayList<>(); + listOfDocuments.add("document1"); + parameters.put("model_config.context", toJson(listOfDocuments)); + + parseParameters(parameters); + assertEquals(parameters.get("model_config.context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]"); + + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + requestBody = substitutor.replace(requestBody); + assertEquals( + requestBody, + "{\"prompt\": \"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\"]. \\n\\n Human: please summarize the documents \\n\\n Assistant:\"}" + ); + } + + /** + * Tests the parseParameters method with a map containing a nested list of strings as the + * value for the "context" key. Verifies that the method correctly processes the nested + * list and adds the processed value to the map with the expected key. Also tests the + * string substitution using the processed values. + */ + @Test + public void testParseParametersNestedListToString() { + Map parameters = new HashMap<>(); + parameters.put("prompt", "answer question based on context: ${parameters.context.toString()}"); + ArrayList listOfDocuments = new ArrayList<>(); + listOfDocuments.add("document1"); + ArrayList NestedListOfDocuments = new ArrayList<>(); + NestedListOfDocuments.add("document2"); + listOfDocuments.add(toJson(NestedListOfDocuments)); + parameters.put("context", toJson(listOfDocuments)); + + parseParameters(parameters); + assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]"); + + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + requestBody = substitutor.replace(requestBody); + assertEquals( + requestBody, + "{\"prompt\": \"answer question based on context: [\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]\"}" + ); + } + + /** + * Tests the parseParameters method with a map containing a map of strings as the value + * for the "context" key. Verifies that the method correctly processes the map and adds + * the processed value to the map with the expected key. Also tests the string substitution + * using the processed values. + */ + @Test + public void testParseParametersMapToString() { + Map parameters = new HashMap<>(); + parameters + .put( + "prompt", + "answer question based on context: ${parameters.context.toString()} and conversation history based on history: ${parameters.history.toString()}" + ); + Map mapOfDocuments = new HashMap<>(); + mapOfDocuments.put("name", "John"); + parameters.put("context", toJson(mapOfDocuments)); + parameters.put("history", "hello\n"); + parseParameters(parameters); + assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "{\\\"name\\\":\\\"John\\\"}"); + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + requestBody = substitutor.replace(requestBody); + assertEquals( + requestBody, + "{\"prompt\": \"answer question based on context: {\\\"name\\\":\\\"John\\\"} and conversation history based on history: hello\\n\"}" + ); + } + + /** + * Tests the parseParameters method with a map containing a nested map of strings as the + * value for the "context" key. Verifies that the method correctly processes the nested + * map and adds the processed value to the map with the expected key. Also tests the + * string substitution using the processed values. + */ + @Test + public void testParseParametersNestedMapToString() { + Map parameters = new HashMap<>(); + parameters + .put( + "prompt", + "answer question based on context: ${parameters.context.toString()} and conversation history based on history: ${parameters.history.toString()}" + ); + Map mapOfDocuments = new HashMap<>(); + mapOfDocuments.put("name", "John"); + Map nestedMapOfDocuments = new HashMap<>(); + nestedMapOfDocuments.put("city", "New York"); + mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments)); + parameters.put("context", toJson(mapOfDocuments)); + parameters.put("history", "hello\n"); + parseParameters(parameters); + assertEquals( + parameters.get("context" + TO_STRING_FUNCTION_NAME), + "{\\\"hometown\\\":\\\"{\\\\\\\"city\\\\\\\":\\\\\\\"New York\\\\\\\"}\\\",\\\"name\\\":\\\"John\\\"}" + ); + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + requestBody = substitutor.replace(requestBody); + assertEquals( + requestBody, + "{\"prompt\": \"answer question based on context: {\\\"hometown\\\":\\\"{\\\\\\\"city\\\\\\\":\\\\\\\"New York\\\\\\\"}\\\",\\\"name\\\":\\\"John\\\"} and conversation history based on history: hello\\n\"}" + ); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java index f3da7c77bc..38e62528f3 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java @@ -384,8 +384,8 @@ private void processPredictions( } } } - - modelParameters = StringUtils.getParameterMap(modelInputParameters); + Map modelParametersInString = StringUtils.getParameterMap(modelInputParameters); + modelParameters.putAll(modelParametersInString); Set inputMapKeys = new HashSet<>(modelParameters.keySet()); inputMapKeys.removeAll(modelConfigs.keySet()); diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index 850e466ba6..62b397f84b 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -169,10 +169,10 @@ public void onFailure(Exception e) { /** * Tests create processor with one_to_one is true * with custom prompt - * with many to one prediction, 5 documents in hits are calling 1 prediction tasks + * with one to one prediction, 5 documents in hits are calling 5 prediction tasks * @throws Exception if an error occurs during the test */ - public void testProcessResponseManyToOneWithCustomPrompt() throws Exception { + public void testProcessResponseOneToOneWithCustomPrompt() throws Exception { String newDocumentField = "context"; String modelOutputField = "response"; @@ -202,6 +202,102 @@ public void testProcessResponseManyToOneWithCustomPrompt() throws Exception { "{ \"prompt\": \"${model_config.prompt}\"}", client, TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "there is 1 value")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals( + newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField).toString(), + "there is 1 value" + ); + assertEquals( + newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField).toString(), + "there is 1 value" + ); + assertEquals( + newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField).toString(), + "there is 1 value" + ); + assertEquals( + newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField).toString(), + "there is 1 value" + ); + assertEquals( + newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField).toString(), + "there is 1 value" + ); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(5)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is false + * with custom prompt + * with many to one prediction, 5 documents in hits are calling 1 prediction tasks + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseManyToOneWithCustomPrompt() throws Exception { + + String documentField = "text"; + String modelInputField = "context"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, documentField); + inputMap.add(input); + + String newDocumentField = "llm_response"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + modelConfig + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, false ); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java index 64a9306691..9c82547623 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java @@ -36,6 +36,7 @@ public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestC private String openAIChatModelId; private String bedrockEmbeddingModelId; private String localModelId; + private String bedrockClaudeModelId; private final String completionModelConnectorEntity = "{\n" + " \"name\": \"OpenAI text embedding model Connector\",\n" + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" @@ -106,6 +107,47 @@ public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestC + " ]\n" + "}"; + private final String bedrockClaudeModelConnectorEntity = "{\n" + + " \"name\": \"BedRock Claude instant-v1 Connector\",\n" + + " \"description\": \"The connector to bedrock for claude model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" + + " \"max_tokens_to_sample\": 8000,\n" + + " \"temperature\": 0.0001,\n" + + " \"response_filter\": \"$.completion\",\n" + + " \"stop_sequences\": [\"\\n\\nHuman:\",\"\\nObservation:\",\"\\n\\tObservation:\",\"\\nObservation\",\"\\n\\tObservation\",\"\\n\\nQuestion\"]\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/anthropic.claude-instant-v1/invoke\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{\\\"prompt\\\":\\\"${parameters.prompt}\\\", \\\"stop_sequences\\\": ${parameters.stop_sequences}, \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + /** * Registers two remote models and creates an index and documents before running the tests. * @@ -119,7 +161,8 @@ public void setup() throws Exception { this.openAIChatModelId = registerRemoteModel(completionModelConnectorEntity, openAIChatModelName, true); String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true); - + String bedrockClaudeModelName = "bedrock claude model " + randomAlphaOfLength(5); + this.bedrockClaudeModelId = registerRemoteModel(bedrockClaudeModelConnectorEntity, bedrockClaudeModelName, true); String index_name = "daily_index"; String createIndexRequestBody = "{\n" + " \"mappings\": {\n" @@ -152,13 +195,14 @@ public void setup() throws Exception { /** * Tests the MLInferenceSearchResponseProcessor with a remote model and an object field as input. * It creates a search pipeline with the processor configured to use the remote model, - * performs a search using the pipeline, and verifies the inference results. - * + * performs a search using the pipeline, gathering search documents into context and added in a custom prompt + * Using a toString() in placeholder to specify the context needs to cast as string + * and verifies the inference results. * @throws Exception if any error occurs during the test */ - public void testMLInferenceProcessorRemoteModelObjectField() throws Exception { + public void testMLInferenceProcessorRemoteModelCustomPrompt() throws Exception { // Skip test if key is null - if (OPENAI_KEY == null) { + if (AWS_ACCESS_KEY_ID == null) { return; } String createPipelineRequestBody = "{\n" @@ -168,20 +212,26 @@ public void testMLInferenceProcessorRemoteModelObjectField() throws Exception { + " \"tag\": \"ml_inference\",\n" + " \"description\": \"This processor is going to run ml inference during search request\",\n" + " \"model_id\": \"" - + this.openAIChatModelId + + this.bedrockClaudeModelId + "\",\n" + + " \"function_name\": \"REMOTE\",\n" + " \"input_map\": [\n" + " {\n" - + " \"input\": \"weather\"\n" + + " \"context\": \"weather\"\n" + " }\n" + " ],\n" + " \"output_map\": [\n" + " {\n" - + " \"weather_embedding\": \"data[*].embedding\"\n" + + " \"llm_response\":\"$.response\"\n" + + " \n" + " }\n" + " ],\n" - + " \"ignore_missing\": false,\n" + + " \"model_config\": {\n" + + " \"prompt\":\"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context.toString()}. \\n\\n Human: please summarize the documents \\n\\n Assistant:\"\n" + + " },\n" + + " \"ignore_missing\":false,\n" + " \"ignore_failure\": false\n" + + " \n" + " }\n" + " }\n" + " ]\n" @@ -190,18 +240,13 @@ public void testMLInferenceProcessorRemoteModelObjectField() throws Exception { String query = "{\"query\":{\"term\":{\"weather\":{\"value\":\"sunny\"}}}}"; String index_name = "daily_index"; - String pipelineName = "weather_embedding_pipeline"; + String pipelineName = "qa_pipeline"; createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); Map response = searchWithPipeline(client(), index_name, pipelineName, query); - Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536"); - Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "sunny"); - Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "happy"); - Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"), "first day at school"); - List embeddingList = (List) JsonPath.parse(response).read("$.hits.hits[0]._source.weather_embedding"); - Assert.assertEquals(embeddingList.size(), 1536); - Assert.assertEquals((Double) embeddingList.get(0), 0.00020525085, 0.005); - Assert.assertEquals((Double) embeddingList.get(1), -0.0071890163, 0.005); + System.out.println(response); + Assert.assertNotNull(JsonPath.parse(response).read("$.hits.hits[0]._source.llm_response")); + Assert.assertNotNull(JsonPath.parse(response).read("$.hits.hits[1]._source.llm_response")); } /** @@ -312,6 +357,61 @@ public void testMLInferenceProcessorRemoteModelNestedListField() throws Exceptio Assert.assertEquals((Double) embeddingList.get(1), -0.012508746, 0.005); } + /** + * Tests the MLInferenceSearchResponseProcessor with a remote model and an object field as input. + * It creates a search pipeline with the processor configured to use the remote model, + * performs a search using the pipeline, and verifies the inference results. + * + * @throws Exception if any error occurs during the test + */ + public void testMLInferenceProcessorRemoteModelObjectField() throws Exception { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } + String createPipelineRequestBody = "{\n" + + " \"response_processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"tag\": \"ml_inference\",\n" + + " \"description\": \"This processor is going to run ml inference during search request\",\n" + + " \"model_id\": \"" + + this.openAIChatModelId + + "\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"input\": \"weather\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"weather_embedding\": \"data[*].embedding\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": false,\n" + + " \"ignore_failure\": false\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + String query = "{\"query\":{\"term\":{\"weather\":{\"value\":\"sunny\"}}}}"; + + String index_name = "daily_index"; + String pipelineName = "weather_embedding_pipeline"; + createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); + + Map response = searchWithPipeline(client(), index_name, pipelineName, query); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "sunny"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "happy"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"), "first day at school"); + List embeddingList = (List) JsonPath.parse(response).read("$.hits.hits[0]._source.weather_embedding"); + Assert.assertEquals(embeddingList.size(), 1536); + Assert.assertEquals((Double) embeddingList.get(0), 0.00020525085, 0.005); + Assert.assertEquals((Double) embeddingList.get(1), -0.0071890163, 0.005); + } + /** * Tests the ML inference processor with a local model. * It registers, deploys, and gets a local model, creates a search pipeline with the ML inference processor