From 222c4cdc2bde12e6da8481430cc50483332dcec3 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 09:08:40 +0800 Subject: [PATCH 01/15] Add multi modal default preprocess function Signed-off-by: zane-neo --- .../connector/MLPreProcessFunction.java | 4 + .../ConnectorPreProcessFunction.java | 5 + ...MultiModalEmbeddingPreProcessFunction.java | 39 +++++++ ...iModalEmbeddingPreProcessFunctionTest.java | 101 ++++++++++++++++++ 4 files changed, 149 insertions(+) create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index d2d65ebdfd..c7aabb3aff 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -8,6 +8,7 @@ import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.MultiModalEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -22,6 +23,7 @@ public class MLPreProcessFunction { public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding"; + public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.multimodal.embedding"; public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding"; public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank"; public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank"; @@ -34,7 +36,9 @@ public class MLPreProcessFunction { OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction(); BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction(); CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction(); + MultiModalEmbeddingPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalEmbeddingPreProcessFunction(); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java index 5701e0fa3a..906f719157 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -15,6 +15,7 @@ import org.opensearch.script.TemplateScript; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.function.Function; @@ -46,6 +47,10 @@ public void validateTextDocsInput(MLInput mlInput) { if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) { throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet"); } + List docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs(); + if (docs.size() == 1 && docs.get(0) == null) { + throw new IllegalArgumentException("No input text or image provided"); + } } protected String executeScript(ScriptService scriptService, String painlessScript, Map params) { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java new file mode 100644 index 0000000000..19ce98f1b7 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java @@ -0,0 +1,39 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + +public class MultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { + + public MultiModalEmbeddingPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + validateTextDocsInput(mlInput); + } + + // The input will must have inputText even it's null, input image is optional. + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); + if (inputData.getDocs().size() == 1) { + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0))))).build(); + } else { + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0), "inputImage", inputData.getDocs().get(1))))).build(); + } + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java new file mode 100644 index 0000000000..d5319fd6ce --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.w3c.dom.Text; + +import java.rmi.Remote; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +public class MultiModalEmbeddingPreProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + MultiModalEmbeddingPreProcessFunction function; + + TextSimilarityInputDataSet textSimilarityInputDataSet; + TextDocsInputDataSet textDocsInputDataSet; + RemoteInferenceInputDataSet remoteInferenceInputDataSet; + + MLInput textEmbeddingInput; + MLInput textSimilarityInput; + MLInput remoteInferenceInput; + + @Before + public void setUp() { + function = new MultiModalEmbeddingPreProcessFunction(); + textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); + textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); + remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("inputText", "value1", "inputImage", "value2")).build(); + + textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); + remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); + } + + @Test + public void process_NullInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function input can't be null"); + function.apply(null); + } + + @Test + public void process_WrongInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); + function.apply(textSimilarityInput); + } + + @Test + public void process_input_text_image() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(2, dataSet.getParameters().size()); + assertEquals("hello", dataSet.getParameters().get("inputText")); + assertEquals("world", dataSet.getParameters().get("inputImage")); + } + + @Test + public void process_input_text_only() { + TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(Arrays.asList("hello")).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(1, dataSet.getParameters().size()); + assertEquals("hello", dataSet.getParameters().get("inputText")); + } + + @Test + public void process_input_text_null() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("No input text or image provided"); + List docs = new ArrayList<>(); + docs.add(null); + TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(docs).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + } + + @Test + public void process_RemoteInferenceInput() { + RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput); + assertEquals(remoteInferenceInputDataSet, dataSet); + } +} From 12dece7886245fcab3b7500b4628e03910dc295a Mon Sep 17 00:00:00 2001 From: zane-neo Date: Thu, 6 Jun 2024 15:56:55 +0800 Subject: [PATCH 02/15] Address comments Signed-off-by: zane-neo --- .../connector/MLPreProcessFunction.java | 4 ++-- .../ConnectorPreProcessFunction.java | 10 +++++++++- ...ultiModalConnectorPreProcessFunction.java} | 18 ++++++++++++++--- ...ModalConnectorPreProcessFunctionTest.java} | 20 +++++++++---------- 4 files changed, 35 insertions(+), 17 deletions(-) rename common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/{MultiModalEmbeddingPreProcessFunction.java => MultiModalConnectorPreProcessFunction.java} (54%) rename common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/{MultiModalEmbeddingPreProcessFunctionTest.java => MultiModalConnectorPreProcessFunctionTest.java} (85%) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index c7aabb3aff..758e1e3ce0 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -8,7 +8,7 @@ import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; -import org.opensearch.ml.common.connector.functions.preprocess.MultiModalEmbeddingPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -36,7 +36,7 @@ public class MLPreProcessFunction { OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction(); BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction(); CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction(); - MultiModalEmbeddingPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalEmbeddingPreProcessFunction(); + MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction(); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java index 906f719157..d049dc8956 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -16,11 +16,18 @@ import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.function.Function; import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod; +/** + * This abstract class represents a pre-processing function for a connector. + * It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}. + * The input data is expected to be of type {@link MLInput}, and the pre-processing function can be customized by implementing the {@link #validate(MLInput)} and {@link #process(MLInput)} methods. + * If the input data is already of type {@link RemoteInferenceInputDataSet}, it can be returned directly by setting the {@link #returnDirectlyForRemoteInferenceInput} flag to true. + */ @Log4j2 public abstract class ConnectorPreProcessFunction implements Function { @@ -45,10 +52,11 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) { public void validateTextDocsInput(MLInput mlInput) { if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) { + log.error(String.format(Locale.ROOT, "This pre_process_function can only support TextDocsInputDataSet, actual input type is: %s", mlInput.getInputDataset().getClass().getName())); throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet"); } List docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs(); - if (docs.size() == 1 && docs.get(0) == null) { + if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) { throw new IllegalArgumentException("No input text or image provided"); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java similarity index 54% rename from common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java rename to common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java index 19ce98f1b7..05cc317856 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java @@ -15,9 +15,16 @@ import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; -public class MultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { +/** + * This class provides a pre-processing function for multi-modal input data. + * It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}. + * The input data is expected to be of type {@link TextDocsInputDataSet}, with the first document representing text input and the second document representing an image input. + * The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object. + * If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly. + */ +public class MultiModalConnectorPreProcessFunction extends ConnectorPreProcessFunction { - public MultiModalEmbeddingPreProcessFunction() { + public MultiModalConnectorPreProcessFunction() { this.returnDirectlyForRemoteInferenceInput = true; } @@ -26,7 +33,12 @@ public void validate(MLInput mlInput) { validateTextDocsInput(mlInput); } - // The input will must have inputText even it's null, input image is optional. + /** + * @param mlInput The input data to be processed. + * This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object. + * If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly. + * The inputText will always show up in the first document, even it's null. + */ @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java similarity index 85% rename from common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java rename to common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java index d5319fd6ce..4bc4c4cd8f 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java @@ -14,9 +14,7 @@ import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import org.w3c.dom.Text; -import java.rmi.Remote; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -24,11 +22,11 @@ import static org.junit.Assert.assertEquals; -public class MultiModalEmbeddingPreProcessFunctionTest { +public class MultiModalConnectorPreProcessFunctionTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - MultiModalEmbeddingPreProcessFunction function; + MultiModalConnectorPreProcessFunction function; TextSimilarityInputDataSet textSimilarityInputDataSet; TextDocsInputDataSet textDocsInputDataSet; @@ -40,7 +38,7 @@ public class MultiModalEmbeddingPreProcessFunctionTest { @Before public void setUp() { - function = new MultiModalEmbeddingPreProcessFunction(); + function = new MultiModalConnectorPreProcessFunction(); textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("inputText", "value1", "inputImage", "value2")).build(); @@ -51,21 +49,21 @@ public void setUp() { } @Test - public void process_NullInput() { + public void testProcess_whenNullInput_expectIllegalArgumentException() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Preprocess function input can't be null"); function.apply(null); } @Test - public void process_WrongInput() { + public void testProcess_whenWrongInput_expectIllegalArgumentException() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); function.apply(textSimilarityInput); } @Test - public void process_input_text_image() { + public void testProcess_whenCorrectInput_expectCorrectOutput() { MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); RemoteInferenceInputDataSet dataSet = function.apply(mlInput); assertEquals(2, dataSet.getParameters().size()); @@ -74,7 +72,7 @@ public void process_input_text_image() { } @Test - public void process_input_text_only() { + public void testProcess_whenInputTextOnly_expectInputTextShowUp() { TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(Arrays.asList("hello")).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build(); RemoteInferenceInputDataSet dataSet = function.apply(mlInput); @@ -83,7 +81,7 @@ public void process_input_text_only() { } @Test - public void process_input_text_null() { + public void testProcess_whenInputTextIsnull_expectIllegalArgumentException() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("No input text or image provided"); List docs = new ArrayList<>(); @@ -94,7 +92,7 @@ public void process_input_text_null() { } @Test - public void process_RemoteInferenceInput() { + public void testProcess_whenRemoteInferenceInput_expectRemoteInferenceInputDataSet() { RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput); assertEquals(remoteInferenceInputDataSet, dataSet); } From 28a6ad2eac0769ff10badd113a86b2508886c2fe Mon Sep 17 00:00:00 2001 From: zane-neo Date: Thu, 6 Jun 2024 16:03:31 +0800 Subject: [PATCH 03/15] address comments Signed-off-by: zane-neo --- .../ConnectorPreProcessFunction.java | 18 ++++++++++++++---- .../MultiModalConnectorPreProcessFunction.java | 5 +++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java index d049dc8956..4235e82823 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -33,6 +33,13 @@ public abstract class ConnectorPreProcessFunction implements Function docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs(); - if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) { - throw new IllegalArgumentException("No input text or image provided"); - } } protected String executeScript(ScriptService scriptService, String painlessScript, Map params) { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java index 05cc317856..5e044d2cd6 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java @@ -11,6 +11,7 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import java.util.List; import java.util.Map; import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; @@ -31,6 +32,10 @@ public MultiModalConnectorPreProcessFunction() { @Override public void validate(MLInput mlInput) { validateTextDocsInput(mlInput); + List docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs(); + if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) { + throw new IllegalArgumentException("No input text or image provided"); + } } /** From da814af2fddcde160497cea54c69fd1f844d1b4f Mon Sep 17 00:00:00 2001 From: zane-neo Date: Thu, 6 Jun 2024 16:27:19 +0800 Subject: [PATCH 04/15] add IT Signed-off-by: zane-neo --- .../ml/rest/RestBedRockInferenceIT.java | 61 +++++++++++++++++- .../BedRockMultiModalConnectorBodies.json | 64 +++++++++++++++++++ 2 files changed, 123 insertions(+), 2 deletions(-) create mode 100644 plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index fea981afe7..b4c1ddcf65 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -1,6 +1,8 @@ /* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * */ package org.opensearch.ml.rest; @@ -19,7 +21,9 @@ import org.opensearch.ml.common.utils.StringUtils; import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +@Log4j2 public class RestBedRockInferenceIT extends MLCommonsRestTestCase { private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); @@ -88,4 +92,57 @@ private void validateOutput(String errorMsg, Map output) { assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); } + + public void test_bedrock_multimodal_model() throws Exception { + String imageBase64 = + "iVBORw0KGgoAAAANSUhEUgAAAEkAAAAaCAYAAAD7aXGFAAABXmlDQ1BJQ0MgUHJvZmlsZQAAKJFtkD9LA0EQxd+ZaEADRpRUFulUiBIvAbGMUVRIcUTFP5WXvTOJ5OJydyJ24mcQO1sRrCWFFn6EgKBoIYoI9uI1mpyzOfUSdYdlfjxmZmcf0BFWOS8HARgV28zNTsVWVtdioRd0UfRSTKjM4mlFyVIJvnP7ca4hiXw1KmZFjftG4PTtttS/3njar8l/69tOt6ZbjPIHXZlx0wakBLGyY3PBe8QDJi1FfCC44PGJ4LzHF82axVyGuEYcYUVVI34gjudb9EILG+Vt9rWD2D6sV5YWKEfpDmIaM8hSxKBARgrjmMQcefR/T6rZk8EWOHZhooQCirCpO00KRxk68TwqYBhDnFhGQswVXv/20Ne0ZyBp0FPDvrYZAc4doO/M14Ye6TtHwKXCVVP9cVZygtZG0vNf6qkCnYeu+7oMhEaA+o3rvlddt34MBO6o1/kEFollXGoMcoEAAABWZVhJZk1NACoAAAAIAAGHaQAEAAAAAQAAABoAAAAAAAOShgAHAAAAEgAAAESgAgAEAAAAAQAAAEmgAwAEAAAAAQAAABoAAAAAQVNDSUkAAABTY3JlZW5zaG90dJ8lxQAAAdRpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6bWV0YS8iIHg6eG1wdGs9IlhNUCBDb3JlIDYuMC4wIj4KICAgPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICAgICAgPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIKICAgICAgICAgICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iPgogICAgICAgICA8ZXhpZjpQaXhlbFlEaW1lbnNpb24+MjY8L2V4aWY6UGl4ZWxZRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpQaXhlbFhEaW1lbnNpb24+NzM8L2V4aWY6UGl4ZWxYRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpVc2VyQ29tbWVudD5TY3JlZW5zaG90PC9leGlmOlVzZXJDb21tZW50PgogICAgICA8L3JkZjpEZXNjcmlwdGlvbj4KICAgPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KaUYItQAABhNJREFUWAntWAtMFGcQHu4OH4g2UhW1PBtRW18V0aCpAQEVvYqpD0xMaqI2ISKlpqRGq0YxatMGa60FxUBsQmIaSWqkoYhK1UqgYK1UrCCoRahW8VUFNYaH3W/O//f22N3b1vpKmOT2n/fuPzv/zOx5+Pr6PqROMIyAzVD6EggtFiv5+Pjwk16/fu1fPbGPz6tksVjo3r17yu+uru1LH6SFC9+n+Pj5vMGpUyN0N6olyMnJJZvNSseO/UQbNqzRUmGeRVfykgiQCU8bnv4dnvYOnoH/ziCZCLJukJDGNptNKWxWTTeQ4acF7mwhHz06jAYNGqxlruK53qdv3340aVIMDRz4mkrPiPD29qawsHEUHj6BevTwNlLVlGnvUlFds2Y9TZgwkVpb28huj1IZDx36Jm3dup15+/Z9RxkZW1XynTu/IX//QLpz5zbNnRsnZXFx79Ls2fOoXz9f7ioQwH9Dw0XKzEynkyd/kbpAVq9eTxMnRlBLSyulpq5S6FTq1q0b65SWFtO6datU+q5Ely5dKC3tKxoy5A2VqL6+jtau/UTFMyJ0M6mkpJjtUP1DQtRvfPr0GdLnuHHhEhfIgAF+jJ4+fUqwKDbWTkuXLqP+/QfIAEEI/8HBr9PGjZ93yCzIAFarRdnURhkg8Fpb27EYwpYt6R0CBIOAgCBKT89SnsPD0F4IdYN05EgRPXzomDMjI2OEPq9hYWMl7dj04yOJgIrNHTiwn/VwtJYt+5hxZMWOHV8rGWWnBQvmUW7ut8y3Wq2ETfXu3Vv6FgiOp6enjVv1ihUpNGfODNq8eZMQa67JySky6Ddv3qRNm1Jp2rQogv3587Xk5eWlelmaTh4xdYPU0tJC1641stqYMY+D0rVrV2V468P8trY28vDwoIiIyEfuiKKiJjPe3t5OZWUljM+f/x7rgcCx2bs3l5qbm+nq1SuUlbWd9uzZzXo4Hnb7TMZdLzjWmGVwJJua7tD9+/ddVVR0TMxUprGPxMRFdPToj9Te3sb2SUkJhMCZBd0gwcHx42Xsx88vQPqLjp7CG25qaqKqqt+ZHx3teCAQyBrApUt/Kg/lOBJBQcHMa2ioV3z+zLjzJTs7U6k7LcwaOfItZ5HEc3J2Sdwd4uXVg/AyAYWF+XTr1i2VCYK1bdsXKp4RYRikgoLv2RapjnMMiIyM5rWysoIOHz7E+LBhw3nFRQQUhRWAo9KzZy/G6+r+4FXrcuPGdWYHBgZ1EOOIInvMQmjoGKl66tRvEndGKip+dSYNccMg1dbWyDccEzOFHYlOUVCQTwcP7ue6hTfXp09fQsYgoID8/Dxe0apxJAHoYnogjra3tyOgenpm+M6jRW3tWU0TfKuhXJgBwyDBwblzNewHdcnfP4A7DJyXl5fSgwcPSHxU2u1xhKMIwAfjlSt/MX758iXZADAW6AHGAkBzs/mM0fMlnhnykJAhmmp4sWgWZsBtkFDwAIGBwbKo1tc/zoiyslKWY6YKDXXUo+rqM8zDBXXp9u2/mUar1wN8kQMuXqzj9UkuJ04cl+YjRoySuDPifCSd+Vq42yCJNu7p6alkiqNzFRcflb7y8/cxjixDIAFFRQd4FRexcT8/fxo//m3BlmtCwlLlmHoyXVmpXUOksgkEnQ9ZDoiNfUfpxo4XIEzxFZGU9JEg3a5ug3T3brNsl716vcIORWBAXLhwnh8IqYuNYrYS2Sfujs4kZi5M0fhrA75QrxITP6RZs+JZFRvLy9srzJ5oLSz8ge1RIzMysrkU4BNn7Nhwhc7SnMf0bug2SDCsqDgh7dFlXFuq8/FqbLwqi70wQnakpX3KJAbNxYsTlCEyj3bt2k0zZ85iPupccnKCPJrC9r+u6elf0tmzVWyOAXX58lVKMylSZq3PeMLHCPO/FW7cRbwV4FqtE11OQHl5xzkIskOHCpXvvTTV/AQ+2jsK7cqVKeQ6Ipj59DDaaErKB3TmzGncRgXI/iVLFsk5TiXUIDye13/cw4eP4k7mGhiNZ3xiVvfu3Wnw4KHsp6am2u207nrD5xYk1wd5kWlTNelF3sCzeLbOIJmIcmeQOoNkIgImVP4BXZkNVryYcSoAAAAASUVORK5CYII="; + // Skip test if key is null + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + log.info("#### The AWS credentials are not set. Skipping test. ####"); + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", imageBase64)).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictRemoteModel(modelId, mlInput); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 1, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + + } } diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json new file mode 100644 index 0000000000..84ee1b1623 --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json @@ -0,0 +1,64 @@ +{ + "without_step_size": { + "name": "Amazon Bedrock Connector: multimodal", + "description": "The connector to bedrock Titan multimodal model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-image-v1", + "input_docs_processed_step_size": "2" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImage\": \"${parameters.inputImage}\" }", + "pre_process_function": "connector.pre_process.multimodal.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + }, + "with_step_size": { + "name": "Amazon Bedrock Connector: multimodal", + "description": "The connector to bedrock Titan multimodal model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-image-v1", + "input_docs_processed_step_size": "2" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImages\": \"${parameters.inputImages}\" }", + "pre_process_function": "connector.pre_process.multimodal.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + } +} From 6b308966480fcdefba5609de013c28c931278566 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Thu, 6 Jun 2024 16:59:58 +0800 Subject: [PATCH 05/15] Fix IT Signed-off-by: zane-neo --- .../java/org/opensearch/ml/rest/RestBedRockInferenceIT.java | 2 +- .../ml/rest/templates/BedRockMultiModalConnectorBodies.json | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index b4c1ddcf65..5db1e56a75 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -141,7 +141,7 @@ public void test_bedrock_multimodal_model() throws Exception { assertEquals(errorMsg, 1, outputList.size()); assertTrue(errorMsg, outputList.get(0) instanceof Map); assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); - assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + assertEquals(errorMsg, 1024, ((List) ((Map) outputList.get(0)).get("data")).size()); } } diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json index 84ee1b1623..ff8628fe2a 100644 --- a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json @@ -7,8 +7,7 @@ "parameters": { "region": "%s", "service_name": "bedrock", - "model_name": "amazon.titan-embed-image-v1", - "input_docs_processed_step_size": "2" + "model_name": "amazon.titan-embed-image-v1" }, "credential": { "access_key": "%s", From c7472f7fb83d5d4ce575e1e515adaefb12bf762c Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 11 Jun 2024 10:38:24 +0800 Subject: [PATCH 06/15] Update common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java Co-authored-by: Yaliang Wu Signed-off-by: zane-neo --- .../MultiModalConnectorPreProcessFunction.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java index 5e044d2cd6..54d56f12f1 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java @@ -47,10 +47,12 @@ public void validate(MLInput mlInput) { @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); - if (inputData.getDocs().size() == 1) { - return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0))))).build(); - } else { - return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0), "inputImage", inputData.getDocs().get(1))))).build(); + Map parametersMap = new HashMap<>(); + parametersMap.put("inputText", inputData.getDocs().get(0)); + if (inputData.getDocs().size() > 1) { + parametersMap.put("inputImage", inputData.getDocs().get(1)); } + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap))).build(); + } } From 01b9caf56f4139ef6e862a8f196a8d57c48c0207 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jun 2024 17:24:27 -0700 Subject: [PATCH 07/15] fix test Signed-off-by: Yaliang Wu --- .../preprocess/MultiModalConnectorPreProcessFunction.java | 1 + 1 file changed, 1 insertion(+) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java index 54d56f12f1..231c68c48d 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java @@ -11,6 +11,7 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import java.util.HashMap; import java.util.List; import java.util.Map; From 919183d1d748faa2827c5452751df9c07db10fce Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 12 Jun 2024 08:18:37 +0800 Subject: [PATCH 08/15] Add more ITs Signed-off-by: zane-neo --- ...MultiModalConnectorPreProcessFunction.java | 2 +- .../ml/rest/MLCommonsRestTestCase.java | 10 +- .../ml/rest/RestBedRockInferenceIT.java | 107 +++++++++++++++++- .../BedRockMultiModalConnectorBodies.json | 4 +- 4 files changed, 113 insertions(+), 10 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java index 231c68c48d..008b1efe58 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java @@ -48,7 +48,7 @@ public void validate(MLInput mlInput) { @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); - Map parametersMap = new HashMap<>(); + Map parametersMap = new HashMap<>(); parametersMap.put("inputText", inputData.getDocs().get(0)); if (inputData.getDocs().size() > 1) { parametersMap.put("inputImage", inputData.getDocs().get(1)); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index abfaee0236..341498f5da 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -59,6 +59,7 @@ import org.junit.Before; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; import org.opensearch.client.RestClient; import org.opensearch.client.RestClientBuilder; import org.opensearch.common.io.PathUtils; @@ -915,8 +916,13 @@ public Map predictTextEmbedding(String modelId) throws IOException { public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOException { String requestBody = TestHelper.toJsonString(input); - Response response = TestHelper - .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); + Response response = null; + try { + response = TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); + } catch (ResponseException e) { + response = e.getResponse(); + } return parseResponseToMap(response); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 5db1e56a75..6ba8a1a969 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.rest; @@ -10,6 +8,7 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; @@ -131,7 +130,7 @@ public void test_bedrock_multimodal_model() throws Exception { TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", imageBase64)).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); - Map inferenceResult = predictRemoteModel(modelId, mlInput); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); List output = (List) inferenceResult.get("inference_results"); assertEquals(errorMsg, 1, output.size()); @@ -145,4 +144,102 @@ public void test_bedrock_multimodal_model() throws Exception { } } + + public void test_bedrock_multimodal_model_empty_imageInput() throws Exception { + // Skip test if key is null + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + log.info("#### The AWS credentials are not set. Skipping test. ####"); + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello")).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 1, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 1024, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + } + + public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() throws Exception { + // Skip test if key is null + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + log.info("#### The AWS credentials are not set. Skipping test. ####"); + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + List input = new ArrayList<>(); + input.add(null); + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(input).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + assertTrue(errorMsg, inferenceResult.containsKey("status")); + assertEquals(errorMsg, 400, inferenceResult.get("status")); + assertTrue(errorMsg, inferenceResult.containsKey("error")); + assertTrue(errorMsg, inferenceResult.get("error") instanceof Map); + assertEquals(errorMsg, "illegal_argument_exception", ((Map) inferenceResult.get("error")).get("type")); + assertEquals(errorMsg, "No input text or image provided", ((Map) inferenceResult.get("error")).get("reason")); + } + } } diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json index ff8628fe2a..dbba50c434 100644 --- a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json @@ -23,7 +23,7 @@ "content-type": "application/json", "x-amz-content-sha256": "required" }, - "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImage\": \"${parameters.inputImage}\" }", + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImage\": \"${parameters.inputImage:-null}\" }", "pre_process_function": "connector.pre_process.multimodal.embedding", "post_process_function": "connector.post_process.bedrock.embedding" } @@ -54,7 +54,7 @@ "content-type": "application/json", "x-amz-content-sha256": "required" }, - "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImages\": \"${parameters.inputImages}\" }", + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImage\": \"${parameters.inputImage:-null}\" }", "pre_process_function": "connector.pre_process.multimodal.embedding", "post_process_function": "connector.post_process.bedrock.embedding" } From fe228364420fb4f0655c58da046ca41bf7eafdb8 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 12 Jun 2024 16:05:01 +0800 Subject: [PATCH 09/15] Fix failure ITs Signed-off-by: zane-neo --- .../java/org/opensearch/ml/rest/RestBedRockInferenceIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 6ba8a1a969..74a309a7c9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -235,7 +235,7 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); assertTrue(errorMsg, inferenceResult.containsKey("status")); - assertEquals(errorMsg, 400, inferenceResult.get("status")); + assertEquals(errorMsg, 400, Integer.parseInt(String.valueOf(inferenceResult.get("status")))); assertTrue(errorMsg, inferenceResult.containsKey("error")); assertTrue(errorMsg, inferenceResult.get("error") instanceof Map); assertEquals(errorMsg, "illegal_argument_exception", ((Map) inferenceResult.get("error")).get("type")); From 10c18bf1e0d456defc584ae807c6cb664968877c Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 12 Jun 2024 20:17:13 +0800 Subject: [PATCH 10/15] fix failure IT Signed-off-by: zane-neo --- .../java/org/opensearch/ml/rest/RestBedRockInferenceIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 74a309a7c9..b4b5b9a555 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -235,7 +235,7 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); assertTrue(errorMsg, inferenceResult.containsKey("status")); - assertEquals(errorMsg, 400, Integer.parseInt(String.valueOf(inferenceResult.get("status")))); + assertTrue(errorMsg, String.valueOf(inferenceResult.get("status")).contains("400")); assertTrue(errorMsg, inferenceResult.containsKey("error")); assertTrue(errorMsg, inferenceResult.get("error") instanceof Map); assertEquals(errorMsg, "illegal_argument_exception", ((Map) inferenceResult.get("error")).get("type")); From 43d13fbe7f8eb0e7ef89520b218245e265f0f0a3 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Thu, 13 Jun 2024 13:49:56 +0800 Subject: [PATCH 11/15] Fix failure ITs Signed-off-by: zane-neo --- .../functions/preprocess/ConnectorPreProcessFunction.java | 6 +++++- .../java/org/opensearch/ml/rest/MLCommonsRestTestCase.java | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java index 4235e82823..c4c88532ef 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -31,6 +31,10 @@ @Log4j2 public abstract class ConnectorPreProcessFunction implements Function { + /** + * This is a flag that can be used to determine if the pre-process function should return the input directly for RemoteInferenceInputDataSet. + * If this is true and the input is already of type RemoteInferenceInputDataSet, it will be returned directly, otherwise it will be processed. + */ protected boolean returnDirectlyForRemoteInferenceInput; /** @@ -67,7 +71,7 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) { public void validateTextDocsInput(MLInput mlInput) { if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) { log.error(String.format(Locale.ROOT, "This pre_process_function can only support TextDocsInputDataSet, actual input type is: %s", mlInput.getInputDataset().getClass().getName())); - throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet"); + throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet which including a list of string with key 'text_docs'"); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 341498f5da..22f5cbb275 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -38,6 +38,7 @@ import java.util.function.Consumer; import java.util.stream.Collectors; +import lombok.extern.log4j.Log4j2; import org.apache.hc.client5.http.auth.AuthScope; import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; @@ -104,6 +105,7 @@ import com.google.gson.Gson; import com.google.gson.JsonArray; +@Log4j2 public abstract class MLCommonsRestTestCase extends OpenSearchRestTestCase { protected Gson gson = new Gson(); public static long CUSTOM_MODEL_TIMEOUT = 20_000; // 20 seconds @@ -921,6 +923,7 @@ public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOExc response = TestHelper .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); } catch (ResponseException e) { + log.error(e.getMessage(), e); response = e.getResponse(); } return parseResponseToMap(response); From 8eacfec294da720a09366e2e50917ab2d9b6d3ab Mon Sep 17 00:00:00 2001 From: zane-neo Date: Thu, 13 Jun 2024 13:53:33 +0800 Subject: [PATCH 12/15] format code Signed-off-by: zane-neo --- .../java/org/opensearch/ml/rest/MLCommonsRestTestCase.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 22f5cbb275..788a98239c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -38,7 +38,6 @@ import java.util.function.Consumer; import java.util.stream.Collectors; -import lombok.extern.log4j.Log4j2; import org.apache.hc.client5.http.auth.AuthScope; import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; @@ -105,6 +104,8 @@ import com.google.gson.Gson; import com.google.gson.JsonArray; +import lombok.extern.log4j.Log4j2; + @Log4j2 public abstract class MLCommonsRestTestCase extends OpenSearchRestTestCase { protected Gson gson = new Gson(); From 62a45c35e20ba08f6aa4d7ad936b7391c5c454f8 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 14 Jun 2024 15:54:03 +0800 Subject: [PATCH 13/15] Add error response to make it esay to figure out the failure root cause Signed-off-by: zane-neo --- .../java/org/opensearch/ml/rest/RestBedRockInferenceIT.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index b4b5b9a555..d47fad8a31 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -114,7 +114,6 @@ public void test_bedrock_multimodal_model() throws Exception { for (Map.Entry templateEntry : templateMap.entrySet()) { String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); String testCaseName = templateEntry.getKey(); - String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); String modelId = registerRemoteModel( String .format( @@ -131,6 +130,7 @@ public void test_bedrock_multimodal_model() throws Exception { TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", imageBase64)).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); List output = (List) inferenceResult.get("inference_results"); assertEquals(errorMsg, 1, output.size()); @@ -165,7 +165,6 @@ public void test_bedrock_multimodal_model_empty_imageInput() throws Exception { for (Map.Entry templateEntry : templateMap.entrySet()) { String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); String testCaseName = templateEntry.getKey(); - String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); String modelId = registerRemoteModel( String .format( @@ -182,6 +181,7 @@ public void test_bedrock_multimodal_model_empty_imageInput() throws Exception { TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello")).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); List output = (List) inferenceResult.get("inference_results"); assertEquals(errorMsg, 1, output.size()); @@ -215,7 +215,6 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro for (Map.Entry templateEntry : templateMap.entrySet()) { String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); String testCaseName = templateEntry.getKey(); - String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); String modelId = registerRemoteModel( String .format( @@ -234,6 +233,7 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(input).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); assertTrue(errorMsg, inferenceResult.containsKey("status")); assertTrue(errorMsg, String.valueOf(inferenceResult.get("status")).contains("400")); assertTrue(errorMsg, inferenceResult.containsKey("error")); From f0054fd44aedb891ca2ef9b928fe35ff22d08434 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 14 Jun 2024 16:54:47 +0800 Subject: [PATCH 14/15] format code Signed-off-by: zane-neo --- .../org/opensearch/ml/rest/RestBedRockInferenceIT.java | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index d47fad8a31..e9df3768bf 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -130,7 +130,8 @@ public void test_bedrock_multimodal_model() throws Exception { TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", imageBase64)).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); - String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); + String errorMsg = String + .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); List output = (List) inferenceResult.get("inference_results"); assertEquals(errorMsg, 1, output.size()); @@ -181,7 +182,8 @@ public void test_bedrock_multimodal_model_empty_imageInput() throws Exception { TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello")).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); - String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); + String errorMsg = String + .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); List output = (List) inferenceResult.get("inference_results"); assertEquals(errorMsg, 1, output.size()); @@ -233,7 +235,8 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(input).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); - String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); + String errorMsg = String + .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); assertTrue(errorMsg, inferenceResult.containsKey("status")); assertTrue(errorMsg, String.valueOf(inferenceResult.get("status")).contains("400")); assertTrue(errorMsg, inferenceResult.containsKey("error")); From 31fb324d191b9d7867fe18180adc6161d1d9ef07 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 17 Jun 2024 09:29:45 +0800 Subject: [PATCH 15/15] rebase main Signed-off-by: zane-neo --- .../ml/rest/RestBedRockInferenceIT.java | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index e9df3768bf..286d45d308 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -38,7 +38,7 @@ public void setup() throws IOException, InterruptedException { public void test_bedrock_embedding_model() throws Exception { // Skip test if key is null - if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + if (tokenNotSet()) { return; } String templates = Files @@ -93,11 +93,8 @@ private void validateOutput(String errorMsg, Map output) { } public void test_bedrock_multimodal_model() throws Exception { - String imageBase64 = - "iVBORw0KGgoAAAANSUhEUgAAAEkAAAAaCAYAAAD7aXGFAAABXmlDQ1BJQ0MgUHJvZmlsZQAAKJFtkD9LA0EQxd+ZaEADRpRUFulUiBIvAbGMUVRIcUTFP5WXvTOJ5OJydyJ24mcQO1sRrCWFFn6EgKBoIYoI9uI1mpyzOfUSdYdlfjxmZmcf0BFWOS8HARgV28zNTsVWVtdioRd0UfRSTKjM4mlFyVIJvnP7ca4hiXw1KmZFjftG4PTtttS/3njar8l/69tOt6ZbjPIHXZlx0wakBLGyY3PBe8QDJi1FfCC44PGJ4LzHF82axVyGuEYcYUVVI34gjudb9EILG+Vt9rWD2D6sV5YWKEfpDmIaM8hSxKBARgrjmMQcefR/T6rZk8EWOHZhooQCirCpO00KRxk68TwqYBhDnFhGQswVXv/20Ne0ZyBp0FPDvrYZAc4doO/M14Ye6TtHwKXCVVP9cVZygtZG0vNf6qkCnYeu+7oMhEaA+o3rvlddt34MBO6o1/kEFollXGoMcoEAAABWZVhJZk1NACoAAAAIAAGHaQAEAAAAAQAAABoAAAAAAAOShgAHAAAAEgAAAESgAgAEAAAAAQAAAEmgAwAEAAAAAQAAABoAAAAAQVNDSUkAAABTY3JlZW5zaG90dJ8lxQAAAdRpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6bWV0YS8iIHg6eG1wdGs9IlhNUCBDb3JlIDYuMC4wIj4KICAgPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICAgICAgPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIKICAgICAgICAgICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iPgogICAgICAgICA8ZXhpZjpQaXhlbFlEaW1lbnNpb24+MjY8L2V4aWY6UGl4ZWxZRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpQaXhlbFhEaW1lbnNpb24+NzM8L2V4aWY6UGl4ZWxYRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpVc2VyQ29tbWVudD5TY3JlZW5zaG90PC9leGlmOlVzZXJDb21tZW50PgogICAgICA8L3JkZjpEZXNjcmlwdGlvbj4KICAgPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KaUYItQAABhNJREFUWAntWAtMFGcQHu4OH4g2UhW1PBtRW18V0aCpAQEVvYqpD0xMaqI2ISKlpqRGq0YxatMGa60FxUBsQmIaSWqkoYhK1UqgYK1UrCCoRahW8VUFNYaH3W/O//f22N3b1vpKmOT2n/fuPzv/zOx5+Pr6PqROMIyAzVD6EggtFiv5+Pjwk16/fu1fPbGPz6tksVjo3r17yu+uru1LH6SFC9+n+Pj5vMGpUyN0N6olyMnJJZvNSseO/UQbNqzRUmGeRVfykgiQCU8bnv4dnvYOnoH/ziCZCLJukJDGNptNKWxWTTeQ4acF7mwhHz06jAYNGqxlruK53qdv3340aVIMDRz4mkrPiPD29qawsHEUHj6BevTwNlLVlGnvUlFds2Y9TZgwkVpb28huj1IZDx36Jm3dup15+/Z9RxkZW1XynTu/IX//QLpz5zbNnRsnZXFx79Ls2fOoXz9f7ioQwH9Dw0XKzEynkyd/kbpAVq9eTxMnRlBLSyulpq5S6FTq1q0b65SWFtO6datU+q5Ely5dKC3tKxoy5A2VqL6+jtau/UTFMyJ0M6mkpJjtUP1DQtRvfPr0GdLnuHHhEhfIgAF+jJ4+fUqwKDbWTkuXLqP+/QfIAEEI/8HBr9PGjZ93yCzIAFarRdnURhkg8Fpb27EYwpYt6R0CBIOAgCBKT89SnsPD0F4IdYN05EgRPXzomDMjI2OEPq9hYWMl7dj04yOJgIrNHTiwn/VwtJYt+5hxZMWOHV8rGWWnBQvmUW7ut8y3Wq2ETfXu3Vv6FgiOp6enjVv1ihUpNGfODNq8eZMQa67JySky6Ddv3qRNm1Jp2rQogv3587Xk5eWlelmaTh4xdYPU0tJC1641stqYMY+D0rVrV2V468P8trY28vDwoIiIyEfuiKKiJjPe3t5OZWUljM+f/x7rgcCx2bs3l5qbm+nq1SuUlbWd9uzZzXo4Hnb7TMZdLzjWmGVwJJua7tD9+/ddVVR0TMxUprGPxMRFdPToj9Te3sb2SUkJhMCZBd0gwcHx42Xsx88vQPqLjp7CG25qaqKqqt+ZHx3teCAQyBrApUt/Kg/lOBJBQcHMa2ioV3z+zLjzJTs7U6k7LcwaOfItZ5HEc3J2Sdwd4uXVg/AyAYWF+XTr1i2VCYK1bdsXKp4RYRikgoLv2RapjnMMiIyM5rWysoIOHz7E+LBhw3nFRQQUhRWAo9KzZy/G6+r+4FXrcuPGdWYHBgZ1EOOIInvMQmjoGKl66tRvEndGKip+dSYNccMg1dbWyDccEzOFHYlOUVCQTwcP7ue6hTfXp09fQsYgoID8/Dxe0apxJAHoYnogjra3tyOgenpm+M6jRW3tWU0TfKuhXJgBwyDBwblzNewHdcnfP4A7DJyXl5fSgwcPSHxU2u1xhKMIwAfjlSt/MX758iXZADAW6AHGAkBzs/mM0fMlnhnykJAhmmp4sWgWZsBtkFDwAIGBwbKo1tc/zoiyslKWY6YKDXXUo+rqM8zDBXXp9u2/mUar1wN8kQMuXqzj9UkuJ04cl+YjRoySuDPifCSd+Vq42yCJNu7p6alkiqNzFRcflb7y8/cxjixDIAFFRQd4FRexcT8/fxo//m3BlmtCwlLlmHoyXVmpXUOksgkEnQ9ZDoiNfUfpxo4XIEzxFZGU9JEg3a5ug3T3brNsl716vcIORWBAXLhwnh8IqYuNYrYS2Sfujs4kZi5M0fhrA75QrxITP6RZs+JZFRvLy9srzJ5oLSz8ge1RIzMysrkU4BNn7Nhwhc7SnMf0bug2SDCsqDgh7dFlXFuq8/FqbLwqi70wQnakpX3KJAbNxYsTlCEyj3bt2k0zZ85iPupccnKCPJrC9r+u6elf0tmzVWyOAXX58lVKMylSZq3PeMLHCPO/FW7cRbwV4FqtE11OQHl5xzkIskOHCpXvvTTV/AQ+2jsK7cqVKeQ6Ipj59DDaaErKB3TmzGncRgXI/iVLFsk5TiXUIDye13/cw4eP4k7mGhiNZ3xiVvfu3Wnw4KHsp6am2u207nrD5xYk1wd5kWlTNelF3sCzeLbOIJmIcmeQOoNkIgImVP4BXZkNVryYcSoAAAAASUVORK5CYII="; // Skip test if key is null - if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { - log.info("#### The AWS credentials are not set. Skipping test. ####"); + if (tokenNotSet()) { return; } String templates = Files @@ -126,7 +123,8 @@ public void test_bedrock_multimodal_model() throws Exception { bedrockEmbeddingModelName, true ); - + String imageBase64 = + "iVBORw0KGgoAAAANSUhEUgAAAEkAAAAaCAYAAAD7aXGFAAABXmlDQ1BJQ0MgUHJvZmlsZQAAKJFtkD9LA0EQxd+ZaEADRpRUFulUiBIvAbGMUVRIcUTFP5WXvTOJ5OJydyJ24mcQO1sRrCWFFn6EgKBoIYoI9uI1mpyzOfUSdYdlfjxmZmcf0BFWOS8HARgV28zNTsVWVtdioRd0UfRSTKjM4mlFyVIJvnP7ca4hiXw1KmZFjftG4PTtttS/3njar8l/69tOt6ZbjPIHXZlx0wakBLGyY3PBe8QDJi1FfCC44PGJ4LzHF82axVyGuEYcYUVVI34gjudb9EILG+Vt9rWD2D6sV5YWKEfpDmIaM8hSxKBARgrjmMQcefR/T6rZk8EWOHZhooQCirCpO00KRxk68TwqYBhDnFhGQswVXv/20Ne0ZyBp0FPDvrYZAc4doO/M14Ye6TtHwKXCVVP9cVZygtZG0vNf6qkCnYeu+7oMhEaA+o3rvlddt34MBO6o1/kEFollXGoMcoEAAABWZVhJZk1NACoAAAAIAAGHaQAEAAAAAQAAABoAAAAAAAOShgAHAAAAEgAAAESgAgAEAAAAAQAAAEmgAwAEAAAAAQAAABoAAAAAQVNDSUkAAABTY3JlZW5zaG90dJ8lxQAAAdRpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6bWV0YS8iIHg6eG1wdGs9IlhNUCBDb3JlIDYuMC4wIj4KICAgPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICAgICAgPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIKICAgICAgICAgICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iPgogICAgICAgICA8ZXhpZjpQaXhlbFlEaW1lbnNpb24+MjY8L2V4aWY6UGl4ZWxZRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpQaXhlbFhEaW1lbnNpb24+NzM8L2V4aWY6UGl4ZWxYRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpVc2VyQ29tbWVudD5TY3JlZW5zaG90PC9leGlmOlVzZXJDb21tZW50PgogICAgICA8L3JkZjpEZXNjcmlwdGlvbj4KICAgPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KaUYItQAABhNJREFUWAntWAtMFGcQHu4OH4g2UhW1PBtRW18V0aCpAQEVvYqpD0xMaqI2ISKlpqRGq0YxatMGa60FxUBsQmIaSWqkoYhK1UqgYK1UrCCoRahW8VUFNYaH3W/O//f22N3b1vpKmOT2n/fuPzv/zOx5+Pr6PqROMIyAzVD6EggtFiv5+Pjwk16/fu1fPbGPz6tksVjo3r17yu+uru1LH6SFC9+n+Pj5vMGpUyN0N6olyMnJJZvNSseO/UQbNqzRUmGeRVfykgiQCU8bnv4dnvYOnoH/ziCZCLJukJDGNptNKWxWTTeQ4acF7mwhHz06jAYNGqxlruK53qdv3340aVIMDRz4mkrPiPD29qawsHEUHj6BevTwNlLVlGnvUlFds2Y9TZgwkVpb28huj1IZDx36Jm3dup15+/Z9RxkZW1XynTu/IX//QLpz5zbNnRsnZXFx79Ls2fOoXz9f7ioQwH9Dw0XKzEynkyd/kbpAVq9eTxMnRlBLSyulpq5S6FTq1q0b65SWFtO6datU+q5Ely5dKC3tKxoy5A2VqL6+jtau/UTFMyJ0M6mkpJjtUP1DQtRvfPr0GdLnuHHhEhfIgAF+jJ4+fUqwKDbWTkuXLqP+/QfIAEEI/8HBr9PGjZ93yCzIAFarRdnURhkg8Fpb27EYwpYt6R0CBIOAgCBKT89SnsPD0F4IdYN05EgRPXzomDMjI2OEPq9hYWMl7dj04yOJgIrNHTiwn/VwtJYt+5hxZMWOHV8rGWWnBQvmUW7ut8y3Wq2ETfXu3Vv6FgiOp6enjVv1ihUpNGfODNq8eZMQa67JySky6Ddv3qRNm1Jp2rQogv3587Xk5eWlelmaTh4xdYPU0tJC1641stqYMY+D0rVrV2V468P8trY28vDwoIiIyEfuiKKiJjPe3t5OZWUljM+f/x7rgcCx2bs3l5qbm+nq1SuUlbWd9uzZzXo4Hnb7TMZdLzjWmGVwJJua7tD9+/ddVVR0TMxUprGPxMRFdPToj9Te3sb2SUkJhMCZBd0gwcHx42Xsx88vQPqLjp7CG25qaqKqqt+ZHx3teCAQyBrApUt/Kg/lOBJBQcHMa2ioV3z+zLjzJTs7U6k7LcwaOfItZ5HEc3J2Sdwd4uXVg/AyAYWF+XTr1i2VCYK1bdsXKp4RYRikgoLv2RapjnMMiIyM5rWysoIOHz7E+LBhw3nFRQQUhRWAo9KzZy/G6+r+4FXrcuPGdWYHBgZ1EOOIInvMQmjoGKl66tRvEndGKip+dSYNccMg1dbWyDccEzOFHYlOUVCQTwcP7ue6hTfXp09fQsYgoID8/Dxe0apxJAHoYnogjra3tyOgenpm+M6jRW3tWU0TfKuhXJgBwyDBwblzNewHdcnfP4A7DJyXl5fSgwcPSHxU2u1xhKMIwAfjlSt/MX758iXZADAW6AHGAkBzs/mM0fMlnhnykJAhmmp4sWgWZsBtkFDwAIGBwbKo1tc/zoiyslKWY6YKDXXUo+rqM8zDBXXp9u2/mUar1wN8kQMuXqzj9UkuJ04cl+YjRoySuDPifCSd+Vq42yCJNu7p6alkiqNzFRcflb7y8/cxjixDIAFFRQd4FRexcT8/fxo//m3BlmtCwlLlmHoyXVmpXUOksgkEnQ9ZDoiNfUfpxo4XIEzxFZGU9JEg3a5ug3T3brNsl716vcIORWBAXLhwnh8IqYuNYrYS2Sfujs4kZi5M0fhrA75QrxITP6RZs+JZFRvLy9srzJ5oLSz8ge1RIzMysrkU4BNn7Nhwhc7SnMf0bug2SDCsqDgh7dFlXFuq8/FqbLwqi70wQnakpX3KJAbNxYsTlCEyj3bt2k0zZ85iPupccnKCPJrC9r+u6elf0tmzVWyOAXX58lVKMylSZq3PeMLHCPO/FW7cRbwV4FqtE11OQHl5xzkIskOHCpXvvTTV/AQ+2jsK7cqVKeQ6Ipj59DDaaErKB3TmzGncRgXI/iVLFsk5TiXUIDye13/cw4eP4k7mGhiNZ3xiVvfu3Wnw4KHsp6am2u207nrD5xYk1wd5kWlTNelF3sCzeLbOIJmIcmeQOoNkIgImVP4BXZkNVryYcSoAAAAASUVORK5CYII="; TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", imageBase64)).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); @@ -148,8 +146,7 @@ public void test_bedrock_multimodal_model() throws Exception { public void test_bedrock_multimodal_model_empty_imageInput() throws Exception { // Skip test if key is null - if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { - log.info("#### The AWS credentials are not set. Skipping test. ####"); + if (tokenNotSet()) { return; } String templates = Files @@ -199,8 +196,7 @@ public void test_bedrock_multimodal_model_empty_imageInput() throws Exception { public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() throws Exception { // Skip test if key is null - if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { - log.info("#### The AWS credentials are not set. Skipping test. ####"); + if (tokenNotSet()) { return; } String templates = Files @@ -245,4 +241,12 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro assertEquals(errorMsg, "No input text or image provided", ((Map) inferenceResult.get("error")).get("reason")); } } + + private boolean tokenNotSet() { + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + log.info("#### The AWS credentials are not set. Skipping test. ####"); + return true; + } + return false; + } }