From ab2d736b3f4871a8c596bc8746bb9edbfbbc7f26 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 14 Oct 2024 15:30:06 +0800 Subject: [PATCH] Add UT and ITs Signed-off-by: zane-neo --- .../BedrockEmbeddingPreProcessFunction.java | 4 + ...edrockEmbeddingPreProcessFunctionTest.java | 8 ++ .../ml/rest/RestBedRockInferenceIT.java | 102 ++++++++++++++++++ .../BedRockEmbeddingV2ModelBodies.json | 66 ++++++++++++ 4 files changed, 180 insertions(+) create mode 100644 plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java index cbc140fcc1..34b72bee97 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java @@ -15,6 +15,9 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import lombok.extern.slf4j.Slf4j; + +@Slf4j public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { public BedrockEmbeddingPreProcessFunction() { @@ -40,6 +43,7 @@ public RemoteInferenceInputDataSet process(Map connectorParams, // Amazon Titan Text Embeddings V2 model: https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html // Default dimension is 1024 int dimensions = Optional.ofNullable(connectorParams.get("dimensions")).map(x -> NumberUtils.toInt(x, 1024)).orElse(1024); + log.error("The bedrock dimensions parameter value is: {}", dimensions); Map processedResult = Map .of("parameters", Map.of("inputText", inputData.getDocs().get(0), "dimensions", dimensions)); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java index eb6e023c34..228baec782 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java @@ -84,4 +84,12 @@ public void process_TextDocsInput_withConnectorParams() { assertEquals(2, dataSet.getParameters().size()); assertEquals("1024", dataSet.getParameters().get("dimensions")); } + + @Test + public void process_TextDocsInput_withoutConnectorParams() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(Map.of(), mlInput); + assertEquals(2, dataSet.getParameters().size()); + assertEquals("1024", dataSet.getParameters().get("dimensions")); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 286d45d308..d8e21471d9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -16,6 +16,7 @@ import org.junit.Before; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.utils.StringUtils; @@ -242,6 +243,107 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro } } + public void test_bedrock_embedding_v2_model_with_connector_dimensions() throws Exception { + // Skip test if key is null + if (tokenNotSet()) { + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = "with_connector_dimensions"; + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateMap.get("with_connector_dimensions")), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + List input = new ArrayList<>(); + input.add("Can you tell me a joke?"); + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(input).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + String errorMsg = String + .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 1, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 512, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + + public void test_bedrock_embedding_v2_model_with_request_dimensions() throws Exception { + // Skip test if key is null + if (tokenNotSet()) { + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = "with_request_dimensions"; + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateMap.get("with_request_dimensions")), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("inputText", "Can you tell me a joke?", "dimensions", "512")) + .build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + String errorMsg = String + .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 1, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 512, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + private boolean tokenNotSet() { if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { log.info("#### The AWS credentials are not set. Skipping test. ####"); diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json new file mode 100644 index 0000000000..a674843b94 --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json @@ -0,0 +1,66 @@ +{ + "with_connector_dimensions": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v2:0", + "input_docs_processed_step_size": "1", + "dimensions": "512" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}}", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + }, + + "with_request_dimensions": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v2:0", + "input_docs_processed_step_size": "1" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}}", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + } +}