Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Jun 6, 2024
1 parent aee2a0b commit fece431
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ public abstract class ConnectorPreProcessFunction implements Function<MLInput, R

protected boolean returnDirectlyForRemoteInferenceInput;

/**
* Applies the pre-processing function to the given MLInput object and returns the resulting RemoteInferenceInputDataSet.
*
* @param mlInput the MLInput object to be processed
* @return the RemoteInferenceInputDataSet resulting from the pre-processing function
* @throws IllegalArgumentException if the input MLInput object is null
*/
@Override
public RemoteInferenceInputDataSet apply(MLInput mlInput) {
if (mlInput == null) {
Expand All @@ -50,15 +57,18 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) {

public abstract RemoteInferenceInputDataSet process(MLInput mlInput);

/**
* Validates the input of a pre-process function for text documents.
*
* @param mlInput the input data to be validated
* @throws IllegalArgumentException if the input dataset is not an instance of TextDocsInputDataSet
* or if there is no input text or image provided
*/
public void validateTextDocsInput(MLInput mlInput) {
if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) {
log.error(String.format(Locale.ROOT, "This pre_process_function can only support TextDocsInputDataSet, actual input type is: %s", mlInput.getInputDataset().getClass().getName()));
throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet");
}
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) {
throw new IllegalArgumentException("No input text or image provided");
}
}

protected String executeScript(ScriptService scriptService, String painlessScript, Map<String, Object> params) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.List;
import java.util.Map;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
Expand All @@ -31,6 +32,10 @@ public MultiModalConnectorPreProcessFunction() {
@Override
public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) {
throw new IllegalArgumentException("No input text or image provided");
}
}

/**
Expand Down

0 comments on commit fece431

Please sign in to comment.