diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 3a5a3427a8..723da8c07d 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -11,6 +11,7 @@ import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.CohereMultiModalEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction; @@ -21,6 +22,7 @@ public class MLPreProcessFunction { private static final Map> PRE_PROCESS_FUNCTIONS = new HashMap<>(); public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; + public static final String IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT = "connector.pre_process.cohere.multimodal_embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding"; public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding"; @@ -37,7 +39,10 @@ public class MLPreProcessFunction { BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction(); CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction(); MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction(); + CohereMultiModalEmbeddingPreProcessFunction cohereMultiModalEmbeddingPreProcessFunction = + new CohereMultiModalEmbeddingPreProcessFunction(); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT, cohereMultiModalEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereMultiModalEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereMultiModalEmbeddingPreProcessFunction.java new file mode 100644 index 0000000000..80c615cb10 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereMultiModalEmbeddingPreProcessFunction.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +public class CohereMultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { + + public CohereMultiModalEmbeddingPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + validateTextDocsInput(mlInput); + List docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs(); + if (docs.isEmpty() || (docs.size() == 1 && docs.getFirst() == null)) { + throw new IllegalArgumentException("No image provided"); + } + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); + Map parametersMap = new HashMap<>(); + + /** + * Cohere multi-modal model expects either image or texts, not both. + * For image, customer can use this pre-process function. For texts, customer can use + * connector.pre_process.cohere.embedding + * Cohere expects An array of image data URIs for the model to embed. Maximum number of images per call is 1. + */ + parametersMap.put("images", inputData.getDocs().getFirst()); + return RemoteInferenceInputDataSet + .builder() + .parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap))) + .build(); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereMultiModalEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereMultiModalEmbeddingPreProcessFunctionTest.java new file mode 100644 index 0000000000..e16f56287d --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereMultiModalEmbeddingPreProcessFunctionTest.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import static org.junit.Assert.assertEquals; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +public class CohereMultiModalEmbeddingPreProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + CohereMultiModalEmbeddingPreProcessFunction function; + + TextSimilarityInputDataSet textSimilarityInputDataSet; + TextDocsInputDataSet textDocsInputDataSet; + RemoteInferenceInputDataSet remoteInferenceInputDataSet; + + MLInput textEmbeddingInput; + MLInput textSimilarityInput; + MLInput remoteInferenceInput; + + @Before + public void setUp() { + function = new CohereMultiModalEmbeddingPreProcessFunction(); + textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(List.of("hello")).build(); + textDocsInputDataSet = TextDocsInputDataSet.builder().docs(List.of("imageString")).build(); + remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("images", "value2")).build(); + + textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); + remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); + } + + @Test + public void testProcess_whenNullInput_expectIllegalArgumentException() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function input can't be null"); + function.apply(null); + } + + @Test + public void testProcess_whenWrongInput_expectIllegalArgumentException() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); + function.apply(textSimilarityInput); + } + + @Test + public void testProcess_whenCorrectInput_expectCorrectOutput() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(1, dataSet.getParameters().size()); + assertEquals("imageString", dataSet.getParameters().get("images")); + + } + + @Test + public void testProcess_whenInputTextIsnull_expectIllegalArgumentException() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("No image provided"); + List docs = new ArrayList<>(); + docs.add(null); + TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(docs).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + } + + @Test + public void testProcess_whenRemoteInferenceInput_expectRemoteInferenceInputDataSet() { + RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput); + assertEquals(remoteInferenceInputDataSet, dataSet); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java index 6ea8da20f9..4442ff7339 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java @@ -39,7 +39,7 @@ public class MultiModalConnectorPreProcessFunctionTest { @Before public void setUp() { function = new MultiModalConnectorPreProcessFunction(); - textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); + textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(List.of("hello")).build(); textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); remoteInferenceInputDataSet = RemoteInferenceInputDataSet .builder()