forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix bedrock embedding generation issue (opensearch-project#2495)
* Fix bedrock connector embedding generation issue Signed-off-by: zane-neo <[email protected]> * format code Signed-off-by: zane-neo <[email protected]> * add IT Signed-off-by: zane-neo <[email protected]> * add ITs Signed-off-by: zane-neo <[email protected]> * format code Signed-off-by: zane-neo <[email protected]> * change input to fix number format exception in local Signed-off-by: zane-neo <[email protected]> * Add log to identify the failure IT root cause Signed-off-by: zane-neo <[email protected]> * format code Signed-off-by: zane-neo <[email protected]> * remove debug log Signed-off-by: zane-neo <[email protected]> * Update plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java Co-authored-by: Yaliang Wu <[email protected]> Signed-off-by: zane-neo <[email protected]> * address comments Signed-off-by: zane-neo <[email protected]> --------- Signed-off-by: zane-neo <[email protected]> Co-authored-by: Yaliang Wu <[email protected]>
- Loading branch information
Showing
5 changed files
with
248 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
91 changes: 91 additions & 0 deletions
91
plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.rest; | ||
|
||
import java.io.IOException; | ||
import java.nio.file.Files; | ||
import java.nio.file.Path; | ||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
|
||
import org.junit.Before; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
import org.opensearch.ml.common.utils.StringUtils; | ||
|
||
import lombok.SneakyThrows; | ||
|
||
public class RestBedRockInferenceIT extends MLCommonsRestTestCase { | ||
private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); | ||
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); | ||
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); | ||
private static final String GITHUB_CI_AWS_REGION = "us-west-2"; | ||
|
||
@SneakyThrows | ||
@Before | ||
public void setup() throws IOException, InterruptedException { | ||
RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); | ||
Thread.sleep(20000); | ||
} | ||
|
||
public void test_bedrock_embedding_model() throws Exception { | ||
// Skip test if key is null | ||
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { | ||
return; | ||
} | ||
String templates = Files | ||
.readString( | ||
Path | ||
.of( | ||
RestMLPredictionAction.class | ||
.getClassLoader() | ||
.getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json") | ||
.toURI() | ||
) | ||
); | ||
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class); | ||
for (Map.Entry<String, Object> templateEntry : templateMap.entrySet()) { | ||
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); | ||
String testCaseName = templateEntry.getKey(); | ||
String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); | ||
String modelId = registerRemoteModel( | ||
String | ||
.format( | ||
StringUtils.gson.toJson(templateEntry.getValue()), | ||
GITHUB_CI_AWS_REGION, | ||
AWS_ACCESS_KEY_ID, | ||
AWS_SECRET_ACCESS_KEY, | ||
AWS_SESSION_TOKEN | ||
), | ||
bedrockEmbeddingModelName, | ||
true | ||
); | ||
|
||
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); | ||
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); | ||
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); | ||
assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); | ||
List output = (List) inferenceResult.get("inference_results"); | ||
assertEquals(errorMsg, 2, output.size()); | ||
assertTrue(errorMsg, output.get(0) instanceof Map); | ||
assertTrue(errorMsg, output.get(1) instanceof Map); | ||
validateOutput(errorMsg, (Map) output.get(0)); | ||
validateOutput(errorMsg, (Map) output.get(1)); | ||
} | ||
} | ||
|
||
private void validateOutput(String errorMsg, Map<String, Object> output) { | ||
assertTrue(errorMsg, output.containsKey("output")); | ||
assertTrue(errorMsg, output.get("output") instanceof List); | ||
List outputList = (List) output.get("output"); | ||
assertEquals(errorMsg, 1, outputList.size()); | ||
assertTrue(errorMsg, outputList.get(0) instanceof Map); | ||
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List); | ||
assertEquals(errorMsg, 1536, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size()); | ||
} | ||
} |
63 changes: 63 additions & 0 deletions
63
plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
{ | ||
"without_step_size": { | ||
"name": "Amazon Bedrock Connector: embedding", | ||
"description": "The connector to bedrock Titan embedding model", | ||
"version": 1, | ||
"protocol": "aws_sigv4", | ||
"parameters": { | ||
"region": "%s", | ||
"service_name": "bedrock", | ||
"model_name": "amazon.titan-embed-text-v1" | ||
}, | ||
"credential": { | ||
"access_key": "%s", | ||
"secret_key": "%s", | ||
"session_token": "%s" | ||
}, | ||
"actions": [ | ||
{ | ||
"action_type": "predict", | ||
"method": "POST", | ||
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", | ||
"headers": { | ||
"content-type": "application/json", | ||
"x-amz-content-sha256": "required" | ||
}, | ||
"request_body": "{ \"inputText\": \"${parameters.inputText}\" }", | ||
"pre_process_function": "connector.pre_process.bedrock.embedding", | ||
"post_process_function": "connector.post_process.bedrock.embedding" | ||
} | ||
] | ||
}, | ||
"with_step_size": { | ||
"name": "Amazon Bedrock Connector: embedding", | ||
"description": "The connector to bedrock Titan embedding model", | ||
"version": 1, | ||
"protocol": "aws_sigv4", | ||
"parameters": { | ||
"region": "%s", | ||
"service_name": "bedrock", | ||
"model_name": "amazon.titan-embed-text-v1", | ||
"input_docs_processed_step_size": "1" | ||
}, | ||
"credential": { | ||
"access_key": "%s", | ||
"secret_key": "%s", | ||
"session_token": "%s" | ||
}, | ||
"actions": [ | ||
{ | ||
"action_type": "predict", | ||
"method": "POST", | ||
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", | ||
"headers": { | ||
"content-type": "application/json", | ||
"x-amz-content-sha256": "required" | ||
}, | ||
"request_body": "{ \"inputText\": \"${parameters.inputText}\" }", | ||
"pre_process_function": "connector.pre_process.bedrock.embedding", | ||
"post_process_function": "connector.post_process.bedrock.embedding" | ||
} | ||
] | ||
} | ||
} |