-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* Add multi modal default preprocess function Signed-off-by: zane-neo <[email protected]> * Address comments Signed-off-by: zane-neo <[email protected]> * address comments Signed-off-by: zane-neo <[email protected]> * add IT Signed-off-by: zane-neo <[email protected]> * Fix IT Signed-off-by: zane-neo <[email protected]> * Update common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java Co-authored-by: Yaliang Wu <[email protected]> Signed-off-by: zane-neo <[email protected]> * fix test Signed-off-by: Yaliang Wu <[email protected]> * Add more ITs Signed-off-by: zane-neo <[email protected]> * Fix failure ITs Signed-off-by: zane-neo <[email protected]> * fix failure IT Signed-off-by: zane-neo <[email protected]> * Fix failure ITs Signed-off-by: zane-neo <[email protected]> * format code Signed-off-by: zane-neo <[email protected]> * Add error response to make it esay to figure out the failure root cause Signed-off-by: zane-neo <[email protected]> * format code Signed-off-by: zane-neo <[email protected]> * rebase main Signed-off-by: zane-neo <[email protected]> --------- Signed-off-by: zane-neo <[email protected]> Signed-off-by: Yaliang Wu <[email protected]> Co-authored-by: Yaliang Wu <[email protected]> (cherry picked from commit 0e89c17) Co-authored-by: zane-neo <[email protected]>
- Loading branch information
1 parent
c3b0d8c
commit 7eadf6d
Showing
7 changed files
with
427 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
59 changes: 59 additions & 0 deletions
59
...earch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* | ||
* * Copyright OpenSearch Contributors | ||
* * SPDX-License-Identifier: Apache-2.0 | ||
* | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.preprocess; | ||
|
||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; | ||
|
||
/** | ||
* This class provides a pre-processing function for multi-modal input data. | ||
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}. | ||
* The input data is expected to be of type {@link TextDocsInputDataSet}, with the first document representing text input and the second document representing an image input. | ||
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object. | ||
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly. | ||
*/ | ||
public class MultiModalConnectorPreProcessFunction extends ConnectorPreProcessFunction { | ||
|
||
public MultiModalConnectorPreProcessFunction() { | ||
this.returnDirectlyForRemoteInferenceInput = true; | ||
} | ||
|
||
@Override | ||
public void validate(MLInput mlInput) { | ||
validateTextDocsInput(mlInput); | ||
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs(); | ||
if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) { | ||
throw new IllegalArgumentException("No input text or image provided"); | ||
} | ||
} | ||
|
||
/** | ||
* @param mlInput The input data to be processed. | ||
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object. | ||
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly. | ||
* The inputText will always show up in the first document, even it's null. | ||
*/ | ||
@Override | ||
public RemoteInferenceInputDataSet process(MLInput mlInput) { | ||
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); | ||
Map<String, String> parametersMap = new HashMap<>(); | ||
parametersMap.put("inputText", inputData.getDocs().get(0)); | ||
if (inputData.getDocs().size() > 1) { | ||
parametersMap.put("inputImage", inputData.getDocs().get(1)); | ||
} | ||
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap))).build(); | ||
|
||
} | ||
} |
99 changes: 99 additions & 0 deletions
99
...h/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.preprocess; | ||
|
||
import org.junit.Before; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.ExpectedException; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
|
||
public class MultiModalConnectorPreProcessFunctionTest { | ||
@Rule | ||
public ExpectedException exceptionRule = ExpectedException.none(); | ||
|
||
MultiModalConnectorPreProcessFunction function; | ||
|
||
TextSimilarityInputDataSet textSimilarityInputDataSet; | ||
TextDocsInputDataSet textDocsInputDataSet; | ||
RemoteInferenceInputDataSet remoteInferenceInputDataSet; | ||
|
||
MLInput textEmbeddingInput; | ||
MLInput textSimilarityInput; | ||
MLInput remoteInferenceInput; | ||
|
||
@Before | ||
public void setUp() { | ||
function = new MultiModalConnectorPreProcessFunction(); | ||
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); | ||
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); | ||
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("inputText", "value1", "inputImage", "value2")).build(); | ||
|
||
textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); | ||
textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); | ||
remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenNullInput_expectIllegalArgumentException() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("Preprocess function input can't be null"); | ||
function.apply(null); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenWrongInput_expectIllegalArgumentException() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); | ||
function.apply(textSimilarityInput); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenCorrectInput_expectCorrectOutput() { | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); | ||
RemoteInferenceInputDataSet dataSet = function.apply(mlInput); | ||
assertEquals(2, dataSet.getParameters().size()); | ||
assertEquals("hello", dataSet.getParameters().get("inputText")); | ||
assertEquals("world", dataSet.getParameters().get("inputImage")); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenInputTextOnly_expectInputTextShowUp() { | ||
TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(Arrays.asList("hello")).build(); | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build(); | ||
RemoteInferenceInputDataSet dataSet = function.apply(mlInput); | ||
assertEquals(1, dataSet.getParameters().size()); | ||
assertEquals("hello", dataSet.getParameters().get("inputText")); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenInputTextIsnull_expectIllegalArgumentException() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("No input text or image provided"); | ||
List<String> docs = new ArrayList<>(); | ||
docs.add(null); | ||
TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(docs).build(); | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build(); | ||
RemoteInferenceInputDataSet dataSet = function.apply(mlInput); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenRemoteInferenceInput_expectRemoteInferenceInputDataSet() { | ||
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput); | ||
assertEquals(remoteInferenceInputDataSet, dataSet); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.