forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: zane-neo <[email protected]>
- Loading branch information
Showing
3 changed files
with
163 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* | ||
* | ||
* * Copyright OpenSearch Contributors | ||
* * SPDX-License-Identifier: Apache-2.0 | ||
* | ||
*/ | ||
|
||
package org.opensearch.ml.rest; | ||
|
||
import java.io.IOException; | ||
import java.nio.file.Files; | ||
import java.nio.file.Path; | ||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
|
||
import org.junit.Before; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
import org.opensearch.ml.common.utils.StringUtils; | ||
|
||
import lombok.SneakyThrows; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
public class RestBedRockInferenceIT extends MLCommonsRestTestCase { | ||
private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); | ||
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); | ||
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); | ||
private static final String GITHUB_CI_AWS_REGION = "ap-northeast-1"; | ||
|
||
@SneakyThrows | ||
@Before | ||
public void setup() throws IOException, InterruptedException { | ||
RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); | ||
Thread.sleep(20000); | ||
} | ||
|
||
public void test_bedrock_multimodal_model() throws Exception { | ||
String imageBase64 = | ||
"iVBORw0KGgoAAAANSUhEUgAAAEkAAAAaCAYAAAD7aXGFAAABXmlDQ1BJQ0MgUHJvZmlsZQAAKJFtkD9LA0EQxd+ZaEADRpRUFulUiBIvAbGMUVRIcUTFP5WXvTOJ5OJydyJ24mcQO1sRrCWFFn6EgKBoIYoI9uI1mpyzOfUSdYdlfjxmZmcf0BFWOS8HARgV28zNTsVWVtdioRd0UfRSTKjM4mlFyVIJvnP7ca4hiXw1KmZFjftG4PTtttS/3njar8l/69tOt6ZbjPIHXZlx0wakBLGyY3PBe8QDJi1FfCC44PGJ4LzHF82axVyGuEYcYUVVI34gjudb9EILG+Vt9rWD2D6sV5YWKEfpDmIaM8hSxKBARgrjmMQcefR/T6rZk8EWOHZhooQCirCpO00KRxk68TwqYBhDnFhGQswVXv/20Ne0ZyBp0FPDvrYZAc4doO/M14Ye6TtHwKXCVVP9cVZygtZG0vNf6qkCnYeu+7oMhEaA+o3rvlddt34MBO6o1/kEFollXGoMcoEAAABWZVhJZk1NACoAAAAIAAGHaQAEAAAAAQAAABoAAAAAAAOShgAHAAAAEgAAAESgAgAEAAAAAQAAAEmgAwAEAAAAAQAAABoAAAAAQVNDSUkAAABTY3JlZW5zaG90dJ8lxQAAAdRpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6bWV0YS8iIHg6eG1wdGs9IlhNUCBDb3JlIDYuMC4wIj4KICAgPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICAgICAgPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIKICAgICAgICAgICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iPgogICAgICAgICA8ZXhpZjpQaXhlbFlEaW1lbnNpb24+MjY8L2V4aWY6UGl4ZWxZRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpQaXhlbFhEaW1lbnNpb24+NzM8L2V4aWY6UGl4ZWxYRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpVc2VyQ29tbWVudD5TY3JlZW5zaG90PC9leGlmOlVzZXJDb21tZW50PgogICAgICA8L3JkZjpEZXNjcmlwdGlvbj4KICAgPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KaUYItQAABhNJREFUWAntWAtMFGcQHu4OH4g2UhW1PBtRW18V0aCpAQEVvYqpD0xMaqI2ISKlpqRGq0YxatMGa60FxUBsQmIaSWqkoYhK1UqgYK1UrCCoRahW8VUFNYaH3W/O//f22N3b1vpKmOT2n/fuPzv/zOx5+Pr6PqROMIyAzVD6EggtFiv5+Pjwk16/fu1fPbGPz6tksVjo3r17yu+uru1LH6SFC9+n+Pj5vMGpUyN0N6olyMnJJZvNSseO/UQbNqzRUmGeRVfykgiQCU8bnv4dnvYOnoH/ziCZCLJukJDGNptNKWxWTTeQ4acF7mwhHz06jAYNGqxlruK53qdv3340aVIMDRz4mkrPiPD29qawsHEUHj6BevTwNlLVlGnvUlFds2Y9TZgwkVpb28huj1IZDx36Jm3dup15+/Z9RxkZW1XynTu/IX//QLpz5zbNnRsnZXFx79Ls2fOoXz9f7ioQwH9Dw0XKzEynkyd/kbpAVq9eTxMnRlBLSyulpq5S6FTq1q0b65SWFtO6datU+q5Ely5dKC3tKxoy5A2VqL6+jtau/UTFMyJ0M6mkpJjtUP1DQtRvfPr0GdLnuHHhEhfIgAF+jJ4+fUqwKDbWTkuXLqP+/QfIAEEI/8HBr9PGjZ93yCzIAFarRdnURhkg8Fpb27EYwpYt6R0CBIOAgCBKT89SnsPD0F4IdYN05EgRPXzomDMjI2OEPq9hYWMl7dj04yOJgIrNHTiwn/VwtJYt+5hxZMWOHV8rGWWnBQvmUW7ut8y3Wq2ETfXu3Vv6FgiOp6enjVv1ihUpNGfODNq8eZMQa67JySky6Ddv3qRNm1Jp2rQogv3587Xk5eWlelmaTh4xdYPU0tJC1641stqYMY+D0rVrV2V468P8trY28vDwoIiIyEfuiKKiJjPe3t5OZWUljM+f/x7rgcCx2bs3l5qbm+nq1SuUlbWd9uzZzXo4Hnb7TMZdLzjWmGVwJJua7tD9+/ddVVR0TMxUprGPxMRFdPToj9Te3sb2SUkJhMCZBd0gwcHx42Xsx88vQPqLjp7CG25qaqKqqt+ZHx3teCAQyBrApUt/Kg/lOBJBQcHMa2ioV3z+zLjzJTs7U6k7LcwaOfItZ5HEc3J2Sdwd4uXVg/AyAYWF+XTr1i2VCYK1bdsXKp4RYRikgoLv2RapjnMMiIyM5rWysoIOHz7E+LBhw3nFRQQUhRWAo9KzZy/G6+r+4FXrcuPGdWYHBgZ1EOOIInvMQmjoGKl66tRvEndGKip+dSYNccMg1dbWyDccEzOFHYlOUVCQTwcP7ue6hTfXp09fQsYgoID8/Dxe0apxJAHoYnogjra3tyOgenpm+M6jRW3tWU0TfKuhXJgBwyDBwblzNewHdcnfP4A7DJyXl5fSgwcPSHxU2u1xhKMIwAfjlSt/MX758iXZADAW6AHGAkBzs/mM0fMlnhnykJAhmmp4sWgWZsBtkFDwAIGBwbKo1tc/zoiyslKWY6YKDXXUo+rqM8zDBXXp9u2/mUar1wN8kQMuXqzj9UkuJ04cl+YjRoySuDPifCSd+Vq42yCJNu7p6alkiqNzFRcflb7y8/cxjixDIAFFRQd4FRexcT8/fxo//m3BlmtCwlLlmHoyXVmpXUOksgkEnQ9ZDoiNfUfpxo4XIEzxFZGU9JEg3a5ug3T3brNsl716vcIORWBAXLhwnh8IqYuNYrYS2Sfujs4kZi5M0fhrA75QrxITP6RZs+JZFRvLy9srzJ5oLSz8ge1RIzMysrkU4BNn7Nhwhc7SnMf0bug2SDCsqDgh7dFlXFuq8/FqbLwqi70wQnakpX3KJAbNxYsTlCEyj3bt2k0zZ85iPupccnKCPJrC9r+u6elf0tmzVWyOAXX58lVKMylSZq3PeMLHCPO/FW7cRbwV4FqtE11OQHl5xzkIskOHCpXvvTTV/AQ+2jsK7cqVKeQ6Ipj59DDaaErKB3TmzGncRgXI/iVLFsk5TiXUIDye13/cw4eP4k7mGhiNZ3xiVvfu3Wnw4KHsp6am2u207nrD5xYk1wd5kWlTNelF3sCzeLbOIJmIcmeQOoNkIgImVP4BXZkNVryYcSoAAAAASUVORK5CYII="; | ||
// Skip test if key is null | ||
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { | ||
log.info("#### The AWS credentials are not set. Skipping test. ####"); | ||
return; | ||
} | ||
String templates = Files | ||
.readString( | ||
Path | ||
.of( | ||
RestMLPredictionAction.class | ||
.getClassLoader() | ||
.getResource("org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json") | ||
.toURI() | ||
) | ||
); | ||
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class); | ||
for (Map.Entry<String, Object> templateEntry : templateMap.entrySet()) { | ||
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); | ||
String testCaseName = templateEntry.getKey(); | ||
String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); | ||
String modelId = registerRemoteModel( | ||
String | ||
.format( | ||
StringUtils.gson.toJson(templateEntry.getValue()), | ||
GITHUB_CI_AWS_REGION, | ||
AWS_ACCESS_KEY_ID, | ||
AWS_SECRET_ACCESS_KEY, | ||
AWS_SESSION_TOKEN | ||
), | ||
bedrockEmbeddingModelName, | ||
true | ||
); | ||
|
||
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", imageBase64)).build(); | ||
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); | ||
Map inferenceResult = predictRemoteModel(modelId, mlInput); | ||
assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); | ||
List output = (List) inferenceResult.get("inference_results"); | ||
assertEquals(errorMsg, 1, output.size()); | ||
assertTrue(errorMsg, output.get(0) instanceof Map); | ||
assertTrue(errorMsg, ((Map<?, ?>) output.get(0)).get("output") instanceof List); | ||
List outputList = (List) ((Map<?, ?>) output.get(0)).get("output"); | ||
assertEquals(errorMsg, 1, outputList.size()); | ||
assertTrue(errorMsg, outputList.get(0) instanceof Map); | ||
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List); | ||
assertEquals(errorMsg, 1536, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size()); | ||
} | ||
|
||
} | ||
} |
64 changes: 64 additions & 0 deletions
64
...src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
{ | ||
"without_step_size": { | ||
"name": "Amazon Bedrock Connector: multimodal", | ||
"description": "The connector to bedrock Titan multimodal model", | ||
"version": 1, | ||
"protocol": "aws_sigv4", | ||
"parameters": { | ||
"region": "%s", | ||
"service_name": "bedrock", | ||
"model_name": "amazon.titan-embed-image-v1", | ||
"input_docs_processed_step_size": "2" | ||
}, | ||
"credential": { | ||
"access_key": "%s", | ||
"secret_key": "%s", | ||
"session_token": "%s" | ||
}, | ||
"actions": [ | ||
{ | ||
"action_type": "predict", | ||
"method": "POST", | ||
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", | ||
"headers": { | ||
"content-type": "application/json", | ||
"x-amz-content-sha256": "required" | ||
}, | ||
"request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImage\": \"${parameters.inputImage}\" }", | ||
"pre_process_function": "connector.pre_process.multimodal.embedding", | ||
"post_process_function": "connector.post_process.bedrock.embedding" | ||
} | ||
] | ||
}, | ||
"with_step_size": { | ||
"name": "Amazon Bedrock Connector: multimodal", | ||
"description": "The connector to bedrock Titan multimodal model", | ||
"version": 1, | ||
"protocol": "aws_sigv4", | ||
"parameters": { | ||
"region": "%s", | ||
"service_name": "bedrock", | ||
"model_name": "amazon.titan-embed-image-v1", | ||
"input_docs_processed_step_size": "2" | ||
}, | ||
"credential": { | ||
"access_key": "%s", | ||
"secret_key": "%s", | ||
"session_token": "%s" | ||
}, | ||
"actions": [ | ||
{ | ||
"action_type": "predict", | ||
"method": "POST", | ||
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", | ||
"headers": { | ||
"content-type": "application/json", | ||
"x-amz-content-sha256": "required" | ||
}, | ||
"request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImages\": \"${parameters.inputImages}\" }", | ||
"pre_process_function": "connector.pre_process.multimodal.embedding", | ||
"post_process_function": "connector.post_process.bedrock.embedding" | ||
} | ||
] | ||
} | ||
} |