From 3ec993461fae8b6da800c011bba39ababc2aadb3 Mon Sep 17 00:00:00 2001
From: zane-neo
Date: Thu, 9 May 2024 07:50:49 +0800
Subject: [PATCH] add UTs

Signed-off-by: zane-neo
---
Review notes (free text between the three-dash marker and the diffstat is not
part of the commit message):
* Restored the one-entry-per-line patch layout; the text had been collapsed onto
  a few physical lines. Indentation of the Java lines is reconstructed (4-space
  steps assumed) and should be confirmed against the target tree.
* Restored generic type arguments (Map<String, String>, ArgumentCaptor<Exception>)
  on the new test locals, which appeared as raw types - presumably stripped
  together with the e-mail addresses missing from the From:/Signed-off-by: lines.
* Dropped the unused `predictAction` local from the emptyPredictionAction test;
  that connector is built without actions on purpose.
* Documented both new tests. The userDefinedPreProcessFunction test still has
  no verification (see the NOTE inside the test).
* The last hunk header and the diffstat of AwsConnectorExecutorTest.java were
  recomputed for the edits above. The commit id and the `index` blob ids are
  those of the original commit and no longer describe the edited content.
* The subject only mentions tests, but the patch also removes
  MultiModalEmbeddingPreProcessFunction and its registration in
  MLPreProcessFunction.

 .../connector/MLPreProcessFunction.java       |  3 -
 ...MultiModalEmbeddingPreProcessFunction.java | 36 --------
 .../remote/AwsConnectorExecutorTest.java      | 86 +++++++++++++++++++
 3 files changed, 86 insertions(+), 39 deletions(-)
 delete mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java

diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java
index 9499c0f0a7..e09d1c550e 100644
--- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java
+++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java
@@ -8,7 +8,6 @@
 import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
 import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
 import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
-import org.opensearch.ml.common.connector.functions.preprocess.MultiModalEmbeddingPreProcessFunction;
 import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
 import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
@@ -35,13 +34,11 @@ public class MLPreProcessFunction {
         CohereEmbeddingPreProcessFunction cohereEmbeddingPreProcessFunction = new CohereEmbeddingPreProcessFunction();
         OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction();
         BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
-        MultiModalEmbeddingPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalEmbeddingPreProcessFunction();
         CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
         PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
         PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
         PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
         PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
-        PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
         PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_DEFAULT_INPUT, cohereRerankPreProcessFunction);
         PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, cohereRerankPreProcessFunction);
     }
diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java
deleted file mode 100644
index 6cd3576d7c..0000000000
--- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.ml.common.connector.functions.preprocess;
-
-import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
-import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
-import org.opensearch.ml.common.input.MLInput;
-
-import java.util.Map;
-
-import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
-
-public class MultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {
-
-    public MultiModalEmbeddingPreProcessFunction() {
-        this.returnDirectlyForRemoteInferenceInput = true;
-    }
-
-    @Override
-    public void validate(MLInput mlInput) {
-        validateTextDocsInput(mlInput);
-    }
-
-    @Override
-    public RemoteInferenceInputDataSet process(MLInput mlInput) {
-        TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
-        if (inputData.getDocs().size() == 1) {
-            return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("inputText", inputData.getDocs().get(0)))).build();
-        } else {
-            return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("inputText", inputData.getDocs().get(0), "inputImage", inputData.getDocs().get(1)))).build();
-        }
-    }
-}
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
index 5e1a9dfacb..86cda87343 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
@@ -6,6 +6,7 @@
 package org.opensearch.ml.engine.algorithms.remote;
 
 import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.when;
@@ -30,6 +31,7 @@
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.util.concurrent.ThreadContext;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.ingest.TestTemplateService;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.connector.AwsConnector;
 import org.opensearch.ml.common.connector.Connector;
@@ -42,6 +44,7 @@
 import org.opensearch.ml.common.transport.MLTaskResponse;
 import org.opensearch.ml.engine.encryptor.Encryptor;
 import org.opensearch.ml.engine.encryptor.EncryptorImpl;
+import org.opensearch.script.ScriptService;
 import org.opensearch.threadpool.ThreadPool;
 
 import com.google.common.collect.ImmutableList;
@@ -67,10 +70,14 @@ public class AwsConnectorExecutorTest {
 
     Encryptor encryptor;
 
+    @Mock
+    private ScriptService scriptService;
+
     @Before
     public void setUp() {
         MockitoAnnotations.openMocks(this);
         encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
+        when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"hello world\"}"));
     }
 
     @Test
@@ -282,4 +289,83 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg
         Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
         assert exceptionCaptor.getValue() instanceof IllegalArgumentException;
     }
+
+    /**
+     * Verifies that executePredict reports an IllegalArgumentException with the message
+     * "no predict action found" to the listener when the connector defines no actions.
+     */
+    @Test
+    public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictionAction() {
+        Map<String, String> credential = ImmutableMap
+            .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
+        Map<String, String> parameters = ImmutableMap
+            .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
+        // Deliberately no .actions(...): the missing predict action is what this test exercises.
+        Connector connector = AwsConnector
+            .awsConnectorBuilder()
+            .name("test connector")
+            .version("1")
+            .protocol("http")
+            .parameters(parameters)
+            .credential(credential)
+            .build();
+        connector.decrypt((c) -> encryptor.decrypt(c));
+        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
+        Settings settings = Settings.builder().build();
+        threadContext = new ThreadContext(settings);
+        when(executor.getClient()).thenReturn(client);
+        when(client.threadPool()).thenReturn(threadPool);
+        when(threadPool.getThreadContext()).thenReturn(threadContext);
+
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
+        executor
+            .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
+        ArgumentCaptor<Exception> exceptionArgumentCaptor = ArgumentCaptor.forClass(Exception.class);
+        Mockito.verify(actionListener, times(1)).onFailure(exceptionArgumentCaptor.capture());
+        assert exceptionArgumentCaptor.getValue() instanceof IllegalArgumentException;
+        assert "no predict action found".equals(exceptionArgumentCaptor.getValue().getMessage());
+    }
+
+    /**
+     * Runs executePredict with a user-defined (script) pre-process function. The script service is
+     * the mock stubbed in setUp(), whose compile(...) returns a factory built from a fixed JSON string.
+     */
+    @Test
+    public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPreProcessFunction() {
+        ConnectorAction predictAction = ConnectorAction
+            .builder()
+            .actionType(ConnectorAction.ActionType.PREDICT)
+            .method("POST")
+            .url("http://openai.com/mock")
+            .requestBody("{\"input\": ${parameters.input}}")
+            .preProcessFunction("\n StringBuilder builder = new StringBuilder();\n builder.append(\"\\\"\");\n String first = params.text_docs[0];\n builder.append(first);\n builder.append(\"\\\"\");\n def parameters = \"{\" +\"\\\"text_inputs\\\":\" + builder + \"}\";\n return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";")
+            .build();
+        Map<String, String> credential = ImmutableMap
+            .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
+        Map<String, String> parameters = ImmutableMap
+            .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
+        Connector connector = AwsConnector
+            .awsConnectorBuilder()
+            .name("test connector")
+            .version("1")
+            .protocol("http")
+            .parameters(parameters)
+            .credential(credential)
+            .actions(Arrays.asList(predictAction))
+            .build();
+        connector.decrypt((c) -> encryptor.decrypt(c));
+        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
+        Settings settings = Settings.builder().build();
+        threadContext = new ThreadContext(settings);
+        when(executor.getClient()).thenReturn(client);
+        when(client.threadPool()).thenReturn(threadPool);
+        when(threadPool.getThreadContext()).thenReturn(threadContext);
+        when(executor.getScriptService()).thenReturn(scriptService);
+
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
+        executor
+            .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
+        // NOTE(review): nothing is verified after this call, so the test only shows that
+        // executePredict does not throw. TODO: verify the expected actionListener interaction.
+    }
 }