diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index d2d65ebdfd..758e1e3ce0 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -8,6 +8,7 @@ import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -22,6 +23,7 @@ public class MLPreProcessFunction { public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding"; + public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.multimodal.embedding"; public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding"; public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank"; public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank"; @@ -34,7 +36,9 @@ public class MLPreProcessFunction { OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction(); BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction(); CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction(); + MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction(); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java index 5701e0fa3a..c4c88532ef 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -15,16 +15,35 @@ import org.opensearch.script.TemplateScript; import java.util.Collections; +import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.function.Function; import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod; +/** + * This abstract class represents a pre-processing function for a connector. + * It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}. + * The input data is expected to be of type {@link MLInput}, and the pre-processing function can be customized by implementing the {@link #validate(MLInput)} and {@link #process(MLInput)} methods. + * If the input data is already of type {@link RemoteInferenceInputDataSet}, it can be returned directly by setting the {@link #returnDirectlyForRemoteInferenceInput} flag to true. + */ @Log4j2 public abstract class ConnectorPreProcessFunction implements Function { + /** + * This is a flag that can be used to determine if the pre-process function should return the input directly for RemoteInferenceInputDataSet. + * If this is true and the input is already of type RemoteInferenceInputDataSet, it will be returned directly, otherwise it will be processed. + */ protected boolean returnDirectlyForRemoteInferenceInput; + /** + * Applies the pre-processing function to the given MLInput object and returns the resulting RemoteInferenceInputDataSet. + * + * @param mlInput the MLInput object to be processed + * @return the RemoteInferenceInputDataSet resulting from the pre-processing function + * @throws IllegalArgumentException if the input MLInput object is null + */ @Override public RemoteInferenceInputDataSet apply(MLInput mlInput) { if (mlInput == null) { @@ -42,9 +61,17 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) { public abstract RemoteInferenceInputDataSet process(MLInput mlInput); + /** + * Validates the input of a pre-process function for text documents. + * + * @param mlInput the input data to be validated + * @throws IllegalArgumentException if the input dataset is not an instance of TextDocsInputDataSet + * or if there is no input text or image provided + */ public void validateTextDocsInput(MLInput mlInput) { if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) { - throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet"); + log.error(String.format(Locale.ROOT, "This pre_process_function can only support TextDocsInputDataSet, actual input type is: %s", mlInput.getInputDataset().getClass().getName())); + throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet which including a list of string with key 'text_docs'"); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java new file mode 100644 index 0000000000..008b1efe58 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java @@ -0,0 +1,59 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + +/** + * This class provides a pre-processing function for multi-modal input data. + * It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}. + * The input data is expected to be of type {@link TextDocsInputDataSet}, with the first document representing text input and the second document representing an image input. + * The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object. + * If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly. + */ +public class MultiModalConnectorPreProcessFunction extends ConnectorPreProcessFunction { + + public MultiModalConnectorPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + validateTextDocsInput(mlInput); + List docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs(); + if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) { + throw new IllegalArgumentException("No input text or image provided"); + } + } + + /** + * @param mlInput The input data to be processed. + * This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object. + * If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly. + * The inputText will always show up in the first document, even it's null. + */ + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); + Map parametersMap = new HashMap<>(); + parametersMap.put("inputText", inputData.getDocs().get(0)); + if (inputData.getDocs().size() > 1) { + parametersMap.put("inputImage", inputData.getDocs().get(1)); + } + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap))).build(); + + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java new file mode 100644 index 0000000000..4bc4c4cd8f --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +public class MultiModalConnectorPreProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + MultiModalConnectorPreProcessFunction function; + + TextSimilarityInputDataSet textSimilarityInputDataSet; + TextDocsInputDataSet textDocsInputDataSet; + RemoteInferenceInputDataSet remoteInferenceInputDataSet; + + MLInput textEmbeddingInput; + MLInput textSimilarityInput; + MLInput remoteInferenceInput; + + @Before + public void setUp() { + function = new MultiModalConnectorPreProcessFunction(); + textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); + textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); + remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("inputText", "value1", "inputImage", "value2")).build(); + + textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); + remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); + } + + @Test + public void testProcess_whenNullInput_expectIllegalArgumentException() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function input can't be null"); + function.apply(null); + } + + @Test + public void testProcess_whenWrongInput_expectIllegalArgumentException() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); + function.apply(textSimilarityInput); + } + + @Test + public void testProcess_whenCorrectInput_expectCorrectOutput() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(2, dataSet.getParameters().size()); + assertEquals("hello", dataSet.getParameters().get("inputText")); + assertEquals("world", dataSet.getParameters().get("inputImage")); + } + + @Test + public void testProcess_whenInputTextOnly_expectInputTextShowUp() { + TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(Arrays.asList("hello")).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(1, dataSet.getParameters().size()); + assertEquals("hello", dataSet.getParameters().get("inputText")); + } + + @Test + public void testProcess_whenInputTextIsnull_expectIllegalArgumentException() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("No input text or image provided"); + List docs = new ArrayList<>(); + docs.add(null); + TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(docs).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + } + + @Test + public void testProcess_whenRemoteInferenceInput_expectRemoteInferenceInputDataSet() { + RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput); + assertEquals(remoteInferenceInputDataSet, dataSet); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 4bb43d5748..d7e9f450ce 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -52,6 +52,7 @@ import org.junit.Before; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; import org.opensearch.client.RestClient; import org.opensearch.client.RestClientBuilder; import org.opensearch.common.io.PathUtils; @@ -96,6 +97,9 @@ import com.google.gson.Gson; import com.google.gson.JsonArray; +import lombok.extern.log4j.Log4j2; + +@Log4j2 public abstract class MLCommonsRestTestCase extends OpenSearchRestTestCase { protected Gson gson = new Gson(); public static long CUSTOM_MODEL_TIMEOUT = 20_000; // 20 seconds @@ -900,8 +904,14 @@ public Map predictTextEmbedding(String modelId) throws IOException { public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOException { String requestBody = TestHelper.toJsonString(input); - Response response = TestHelper - .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); + Response response = null; + try { + response = TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); + } catch (ResponseException e) { + log.error(e.getMessage(), e); + response = e.getResponse(); + } return parseResponseToMap(response); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index fea981afe7..286d45d308 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -8,6 +8,7 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; @@ -19,7 +20,9 @@ import org.opensearch.ml.common.utils.StringUtils; import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +@Log4j2 public class RestBedRockInferenceIT extends MLCommonsRestTestCase { private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); @@ -35,7 +38,7 @@ public void setup() throws IOException, InterruptedException { public void test_bedrock_embedding_model() throws Exception { // Skip test if key is null - if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + if (tokenNotSet()) { return; } String templates = Files @@ -88,4 +91,162 @@ private void validateOutput(String errorMsg, Map output) { assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); } + + public void test_bedrock_multimodal_model() throws Exception { + // Skip test if key is null + if (tokenNotSet()) { + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + String imageBase64 = + "iVBORw0KGgoAAAANSUhEUgAAAEkAAAAaCAYAAAD7aXGFAAABXmlDQ1BJQ0MgUHJvZmlsZQAAKJFtkD9LA0EQxd+ZaEADRpRUFulUiBIvAbGMUVRIcUTFP5WXvTOJ5OJydyJ24mcQO1sRrCWFFn6EgKBoIYoI9uI1mpyzOfUSdYdlfjxmZmcf0BFWOS8HARgV28zNTsVWVtdioRd0UfRSTKjM4mlFyVIJvnP7ca4hiXw1KmZFjftG4PTtttS/3njar8l/69tOt6ZbjPIHXZlx0wakBLGyY3PBe8QDJi1FfCC44PGJ4LzHF82axVyGuEYcYUVVI34gjudb9EILG+Vt9rWD2D6sV5YWKEfpDmIaM8hSxKBARgrjmMQcefR/T6rZk8EWOHZhooQCirCpO00KRxk68TwqYBhDnFhGQswVXv/20Ne0ZyBp0FPDvrYZAc4doO/M14Ye6TtHwKXCVVP9cVZygtZG0vNf6qkCnYeu+7oMhEaA+o3rvlddt34MBO6o1/kEFollXGoMcoEAAABWZVhJZk1NACoAAAAIAAGHaQAEAAAAAQAAABoAAAAAAAOShgAHAAAAEgAAAESgAgAEAAAAAQAAAEmgAwAEAAAAAQAAABoAAAAAQVNDSUkAAABTY3JlZW5zaG90dJ8lxQAAAdRpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6bWV0YS8iIHg6eG1wdGs9IlhNUCBDb3JlIDYuMC4wIj4KICAgPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICAgICAgPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIKICAgICAgICAgICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iPgogICAgICAgICA8ZXhpZjpQaXhlbFlEaW1lbnNpb24+MjY8L2V4aWY6UGl4ZWxZRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpQaXhlbFhEaW1lbnNpb24+NzM8L2V4aWY6UGl4ZWxYRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpVc2VyQ29tbWVudD5TY3JlZW5zaG90PC9leGlmOlVzZXJDb21tZW50PgogICAgICA8L3JkZjpEZXNjcmlwdGlvbj4KICAgPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KaUYItQAABhNJREFUWAntWAtMFGcQHu4OH4g2UhW1PBtRW18V0aCpAQEVvYqpD0xMaqI2ISKlpqRGq0YxatMGa60FxUBsQmIaSWqkoYhK1UqgYK1UrCCoRahW8VUFNYaH3W/O//f22N3b1vpKmOT2n/fuPzv/zOx5+Pr6PqROMIyAzVD6EggtFiv5+Pjwk16/fu1fPbGPz6tksVjo3r17yu+uru1LH6SFC9+n+Pj5vMGpUyN0N6olyMnJJZvNSseO/UQbNqzRUmGeRVfykgiQCU8bnv4dnvYOnoH/ziCZCLJukJDGNptNKWxWTTeQ4acF7mwhHz06jAYNGqxlruK53qdv3340aVIMDRz4mkrPiPD29qawsHEUHj6BevTwNlLVlGnvUlFds2Y9TZgwkVpb28huj1IZDx36Jm3dup15+/Z9RxkZW1XynTu/IX//QLpz5zbNnRsnZXFx79Ls2fOoXz9f7ioQwH9Dw0XKzEynkyd/kbpAVq9eTxMnRlBLSyulpq5S6FTq1q0b65SWFtO6datU+q5Ely5dKC3tKxoy5A2VqL6+jtau/UTFMyJ0M6mkpJjtUP1DQtRvfPr0GdLnuHHhEhfIgAF+jJ4+fUqwKDbWTkuXLqP+/QfIAEEI/8HBr9PGjZ93yCzIAFarRdnURhkg8Fpb27EYwpYt6R0CBIOAgCBKT89SnsPD0F4IdYN05EgRPXzomDMjI2OEPq9hYWMl7dj04yOJgIrNHTiwn/VwtJYt+5hxZMWOHV8rGWWnBQvmUW7ut8y3Wq2ETfXu3Vv6FgiOp6enjVv1ihUpNGfODNq8eZMQa67JySky6Ddv3qRNm1Jp2rQogv3587Xk5eWlelmaTh4xdYPU0tJC1641stqYMY+D0rVrV2V468P8trY28vDwoIiIyEfuiKKiJjPe3t5OZWUljM+f/x7rgcCx2bs3l5qbm+nq1SuUlbWd9uzZzXo4Hnb7TMZdLzjWmGVwJJua7tD9+/ddVVR0TMxUprGPxMRFdPToj9Te3sb2SUkJhMCZBd0gwcHx42Xsx88vQPqLjp7CG25qaqKqqt+ZHx3teCAQyBrApUt/Kg/lOBJBQcHMa2ioV3z+zLjzJTs7U6k7LcwaOfItZ5HEc3J2Sdwd4uXVg/AyAYWF+XTr1i2VCYK1bdsXKp4RYRikgoLv2RapjnMMiIyM5rWysoIOHz7E+LBhw3nFRQQUhRWAo9KzZy/G6+r+4FXrcuPGdWYHBgZ1EOOIInvMQmjoGKl66tRvEndGKip+dSYNccMg1dbWyDccEzOFHYlOUVCQTwcP7ue6hTfXp09fQsYgoID8/Dxe0apxJAHoYnogjra3tyOgenpm+M6jRW3tWU0TfKuhXJgBwyDBwblzNewHdcnfP4A7DJyXl5fSgwcPSHxU2u1xhKMIwAfjlSt/MX758iXZADAW6AHGAkBzs/mM0fMlnhnykJAhmmp4sWgWZsBtkFDwAIGBwbKo1tc/zoiyslKWY6YKDXXUo+rqM8zDBXXp9u2/mUar1wN8kQMuXqzj9UkuJ04cl+YjRoySuDPifCSd+Vq42yCJNu7p6alkiqNzFRcflb7y8/cxjixDIAFFRQd4FRexcT8/fxo//m3BlmtCwlLlmHoyXVmpXUOksgkEnQ9ZDoiNfUfpxo4XIEzxFZGU9JEg3a5ug3T3brNsl716vcIORWBAXLhwnh8IqYuNYrYS2Sfujs4kZi5M0fhrA75QrxITP6RZs+JZFRvLy9srzJ5oLSz8ge1RIzMysrkU4BNn7Nhwhc7SnMf0bug2SDCsqDgh7dFlXFuq8/FqbLwqi70wQnakpX3KJAbNxYsTlCEyj3bt2k0zZ85iPupccnKCPJrC9r+u6elf0tmzVWyOAXX58lVKMylSZq3PeMLHCPO/FW7cRbwV4FqtE11OQHl5xzkIskOHCpXvvTTV/AQ+2jsK7cqVKeQ6Ipj59DDaaErKB3TmzGncRgXI/iVLFsk5TiXUIDye13/cw4eP4k7mGhiNZ3xiVvfu3Wnw4KHsp6am2u207nrD5xYk1wd5kWlTNelF3sCzeLbOIJmIcmeQOoNkIgImVP4BXZkNVryYcSoAAAAASUVORK5CYII="; + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", imageBase64)).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + String errorMsg = String + .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 1, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 1024, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + + } + + public void test_bedrock_multimodal_model_empty_imageInput() throws Exception { + // Skip test if key is null + if (tokenNotSet()) { + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello")).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + String errorMsg = String + .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 1, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 1024, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + } + + public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() throws Exception { + // Skip test if key is null + if (tokenNotSet()) { + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + List input = new ArrayList<>(); + input.add(null); + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(input).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + String errorMsg = String + .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); + assertTrue(errorMsg, inferenceResult.containsKey("status")); + assertTrue(errorMsg, String.valueOf(inferenceResult.get("status")).contains("400")); + assertTrue(errorMsg, inferenceResult.containsKey("error")); + assertTrue(errorMsg, inferenceResult.get("error") instanceof Map); + assertEquals(errorMsg, "illegal_argument_exception", ((Map) inferenceResult.get("error")).get("type")); + assertEquals(errorMsg, "No input text or image provided", ((Map) inferenceResult.get("error")).get("reason")); + } + } + + private boolean tokenNotSet() { + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + log.info("#### The AWS credentials are not set. Skipping test. ####"); + return true; + } + return false; + } } diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json new file mode 100644 index 0000000000..dbba50c434 --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json @@ -0,0 +1,63 @@ +{ + "without_step_size": { + "name": "Amazon Bedrock Connector: multimodal", + "description": "The connector to bedrock Titan multimodal model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-image-v1" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImage\": \"${parameters.inputImage:-null}\" }", + "pre_process_function": "connector.pre_process.multimodal.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + }, + "with_step_size": { + "name": "Amazon Bedrock Connector: multimodal", + "description": "The connector to bedrock Titan multimodal model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-image-v1", + "input_docs_processed_step_size": "2" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImage\": \"${parameters.inputImage:-null}\" }", + "pre_process_function": "connector.pre_process.multimodal.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + } +}