From 9478bfaa6e5b84627d63927b6b5f10a2946de2ca Mon Sep 17 00:00:00 2001
From: zane-neo
Date: Wed, 5 Jun 2024 09:08:40 +0800
Subject: [PATCH] Add multi modal default preprocess function

Signed-off-by: zane-neo
---
Notes:
  - MultiModalEmbeddingPreProcessFunction.process() collects its parameters in
    a HashMap and skips null slots, because Map.of rejects null values and
    image-only input is passed as [null, image].
  - The "index" blob ids below were not recomputed after this revision.

 .../connector/MLPreProcessFunction.java       |   4 +
 .../ConnectorPreProcessFunction.java          |   5 +
 ...MultiModalEmbeddingPreProcessFunction.java |  72 ++++++++++++
 ...iModalEmbeddingPreProcessFunctionTest.java | 108 ++++++++++++++++++
 4 files changed, 189 insertions(+)
 create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java
 create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java

diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java
index d2d65ebdfd..c7aabb3aff 100644
--- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java
+++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java
@@ -8,6 +8,7 @@
 import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
 import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
 import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
+import org.opensearch.ml.common.connector.functions.preprocess.MultiModalEmbeddingPreProcessFunction;
 import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
 import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
@@ -22,6 +23,7 @@ public class MLPreProcessFunction {
     public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
     public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
     public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
+    public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.multimodal.embedding";
     public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
     public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
     public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank";
@@ -34,7 +36,9 @@ public class MLPreProcessFunction {
         OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction();
         BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
         CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
+        MultiModalEmbeddingPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalEmbeddingPreProcessFunction();
         PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
+        PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
         PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
         PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
         PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java
index 5701e0fa3a..906f719157 100644
--- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java
+++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java
@@ -15,6 +15,7 @@
 import org.opensearch.script.TemplateScript;
 
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
 
@@ -46,6 +47,10 @@ public void validateTextDocsInput(MLInput mlInput) {
         if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) {
             throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet");
         }
+        List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
+        if (docs.size() == 1 && docs.get(0) == null) {
+            throw new IllegalArgumentException("No input text or image provided");
+        }
     }
 
     protected String executeScript(ScriptService scriptService, String painlessScript, Map params) {
diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java
new file mode 100644
index 0000000000..19ce98f1b7
--- /dev/null
+++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java
@@ -0,0 +1,72 @@
+/*
+ *
+ *  * Copyright OpenSearch Contributors
+ *  * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.ml.common.connector.functions.preprocess;
+
+import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
+import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
+import org.opensearch.ml.common.input.MLInput;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
+
+/**
+ * Pre-process function that maps a {@link TextDocsInputDataSet} onto the {@code inputText} and
+ * {@code inputImage} parameters of a multi-modal embedding request.
+ *
+ * <p>The docs list is positional: element 0 is the input text, the optional element 1 is the input
+ * image. For image-only input the text slot must still be present, as {@code null}.
+ */
+public class MultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {
+
+    public MultiModalEmbeddingPreProcessFunction() {
+        this.returnDirectlyForRemoteInferenceInput = true;
+    }
+
+    /**
+     * Checks that the input is a {@link TextDocsInputDataSet} holding a non-null text or image.
+     *
+     * @param mlInput the input to validate
+     * @throws IllegalArgumentException if the dataset type is wrong, or text and image are both missing
+     */
+    @Override
+    public void validate(MLInput mlInput) {
+        validateTextDocsInput(mlInput);
+        List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
+        boolean hasText = !docs.isEmpty() && docs.get(0) != null;
+        boolean hasImage = docs.size() > 1 && docs.get(1) != null;
+        if (!hasText && !hasImage) {
+            throw new IllegalArgumentException("No input text or image provided");
+        }
+    }
+
+    /**
+     * Builds the remote inference parameters from the positional docs list.
+     *
+     * <p>Only non-null slots are emitted; a {@link HashMap} is used because {@code Map.of} rejects
+     * null values. NOTE(review): assumes the connector request template copes with an absent
+     * {@code inputText} or {@code inputImage} parameter - confirm.
+     *
+     * @param mlInput input whose dataset is a {@link TextDocsInputDataSet}
+     * @return dataset carrying {@code inputText} and/or {@code inputImage}
+     */
+    @Override
+    public RemoteInferenceInputDataSet process(MLInput mlInput) {
+        List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
+        Map<String, String> parameters = new HashMap<>();
+        if (!docs.isEmpty() && docs.get(0) != null) {
+            parameters.put("inputText", docs.get(0));
+        }
+        if (docs.size() > 1 && docs.get(1) != null) {
+            parameters.put("inputImage", docs.get(1));
+        }
+        return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("parameters", parameters))).build();
+    }
+}
diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java
new file mode 100644
index 0000000000..d5319fd6ce
--- /dev/null
+++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.common.connector.functions.preprocess;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.opensearch.ml.common.FunctionName;
+import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
+import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
+import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
+import org.opensearch.ml.common.input.MLInput;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+public class MultiModalEmbeddingPreProcessFunctionTest {
+    @Rule
+    public ExpectedException exceptionRule = ExpectedException.none();
+
+    MultiModalEmbeddingPreProcessFunction function;
+
+    TextSimilarityInputDataSet textSimilarityInputDataSet;
+    TextDocsInputDataSet textDocsInputDataSet;
+    RemoteInferenceInputDataSet remoteInferenceInputDataSet;
+
+    MLInput textEmbeddingInput;
+    MLInput textSimilarityInput;
+    MLInput remoteInferenceInput;
+
+    @Before
+    public void setUp() {
+        function = new MultiModalEmbeddingPreProcessFunction();
+        textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
+        textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
+        remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("inputText", "value1", "inputImage", "value2")).build();
+
+        textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
+        textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build();
+        remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build();
+    }
+
+    @Test
+    public void process_NullInput() {
+        exceptionRule.expect(IllegalArgumentException.class);
+        exceptionRule.expectMessage("Preprocess function input can't be null");
+        function.apply(null);
+    }
+
+    @Test
+    public void process_WrongInput() {
+        exceptionRule.expect(IllegalArgumentException.class);
+        exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet");
+        function.apply(textSimilarityInput);
+    }
+
+    @Test
+    public void process_input_text_image() {
+        MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
+        RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
+        assertEquals(2, dataSet.getParameters().size());
+        assertEquals("hello", dataSet.getParameters().get("inputText"));
+        assertEquals("world", dataSet.getParameters().get("inputImage"));
+    }
+
+    @Test
+    public void process_input_text_only() {
+        TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(Arrays.asList("hello")).build();
+        MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build();
+        RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
+        assertEquals(1, dataSet.getParameters().size());
+        assertEquals("hello", dataSet.getParameters().get("inputText"));
+    }
+
+    @Test
+    public void process_input_image_only() {
+        TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(Arrays.asList(null, "world")).build();
+        MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build();
+        RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
+        assertEquals(1, dataSet.getParameters().size());
+        assertEquals("world", dataSet.getParameters().get("inputImage"));
+    }
+
+    @Test
+    public void process_input_text_null() {
+        exceptionRule.expect(IllegalArgumentException.class);
+        exceptionRule.expectMessage("No input text or image provided");
+        List<String> docs = new ArrayList<>();
+        docs.add(null);
+        TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(docs).build();
+        MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build();
+        function.apply(mlInput);
+    }
+
+    @Test
+    public void process_RemoteInferenceInput() {
+        RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput);
+        assertEquals(remoteInferenceInputDataSet, dataSet);
+    }
+}