From e82d89e48749d7d99c1c5b8bb094000a2dd1aa8a Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jun 2024 17:10:18 -0700 Subject: [PATCH] fix test Signed-off-by: Yaliang Wu --- .../preprocess/MultiModalConnectorPreProcessFunction.java | 1 + .../java/org/opensearch/ml/rest/RestBedRockInferenceIT.java | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java index 54d56f12f1..231c68c48d 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java @@ -11,6 +11,7 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import java.util.HashMap; import java.util.List; import java.util.Map; diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 4656acc6b7..d967ae0676 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -73,7 +73,7 @@ public void test_bedrock_multimodal_model() throws Exception { TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", imageBase64)).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); - Map inferenceResult = predictTextEmbedding(modelId, mlInput); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); List output = (List) inferenceResult.get("inference_results"); assertEquals(errorMsg, 1, output.size());