Skip to content

Commit

Permalink
Fix bedrock embedding generation issue (opensearch-project#2495)
Browse files Browse the repository at this point in the history
* Fix bedrock connector embedding generation issue

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* add IT

Signed-off-by: zane-neo <[email protected]>

* add ITs

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* change input to fix number format exception in local

Signed-off-by: zane-neo <[email protected]>

* Add log to identify the failure IT root cause

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* remove debug log

Signed-off-by: zane-neo <[email protected]>

* Update plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

Co-authored-by: Yaliang Wu <[email protected]>
Signed-off-by: zane-neo <[email protected]>

* address comments

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>
Co-authored-by: Yaliang Wu <[email protected]>
  • Loading branch information
zane-neo and ylwu-amzn authored Jun 11, 2024
1 parent 9fa49f4 commit 62e238f
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ default void executeAction(String action, MLInput mlInput, ActionListener<MLTask

/**
* Calculate the chunk size.
* @param textDocsInputDataSet
* @param textDocsInputDataSet Input dataset in textDocsInputDataSet format.
* @return Tuple of chunk size and step size.
*/
private Tuple<Integer, Integer> calculateChunkSize(String action, TextDocsInputDataSet textDocsInputDataSet) {
Expand All @@ -117,11 +117,15 @@ private Tuple<Integer, Integer> calculateChunkSize(String action, TextDocsInputD
throw new IllegalArgumentException("no " + action + " action found");
}
String preProcessFunction = connectorAction.get().getPreProcessFunction();
if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) {
// user defined preprocess script, this case, the chunk size is always equals to text docs length.
if (preProcessFunction == null) {
// default preprocess case, consider this a batch.
return Tuple.tuple(1, textDocsLength);
} else if (MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction)
|| !MLPreProcessFunction.contains(preProcessFunction)) {
// bedrock and user defined preprocess script, the chunk size is always equals to text docs length.
return Tuple.tuple(textDocsLength, 1);
}
// consider as batch.
// Other cases: non-bedrock and user defined preprocess script, consider as batch.
return Tuple.tuple(1, textDocsLength);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,85 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPre
);
}

@Test
public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddingPreProcessFunction() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(PREDICT)
.method("POST")
.url("http://openai.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT)
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("aws_sigv4")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(executor.getScriptService()).thenReturn(scriptService);

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
executor
.executeAction(
PREDICT.name(),
MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(),
actionListener
);
}

@Test
public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreprocessFunction() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://openai.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("aws_sigv4")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(executor.getScriptService()).thenReturn(scriptService);

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
executor
.executeAction(
PREDICT.name(),
MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(),
actionListener
);
}

@Test
public void executePredict_whenRetryEnabled_thenInvokeRemoteServiceWithRetry() {
ConnectorAction predictAction = ConnectorAction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,13 @@ public Map predictTextEmbedding(String modelId) throws IOException {
return result;
}

public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOException {
String requestBody = TestHelper.toJsonString(input);
Response response = TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null);
return parseResponseToMap(response);
}

public Consumer<Map<String, Object>> verifyTextEmbeddingModelDeployed() {
return (modelProfile) -> {
if (modelProfile.containsKey("model_state")) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.junit.Before;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;

import lombok.SneakyThrows;

public class RestBedRockInferenceIT extends MLCommonsRestTestCase {
private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
private static final String GITHUB_CI_AWS_REGION = "us-west-2";

@SneakyThrows
@Before
public void setup() throws IOException, InterruptedException {
RestMLRemoteInferenceIT.disableClusterConnectorAccessControl();
Thread.sleep(20000);
}

public void test_bedrock_embedding_model() throws Exception {
// Skip test if key is null
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
return;
}
String templates = Files
.readString(
Path
.of(
RestMLPredictionAction.class
.getClassLoader()
.getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json")
.toURI()
)
);
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class);
for (Map.Entry<String, Object> templateEntry : templateMap.entrySet()) {
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
String testCaseName = templateEntry.getKey();
String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName);
String modelId = registerRemoteModel(
String
.format(
StringUtils.gson.toJson(templateEntry.getValue()),
GITHUB_CI_AWS_REGION,
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
AWS_SESSION_TOKEN
),
bedrockEmbeddingModelName,
true
);

TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
List output = (List) inferenceResult.get("inference_results");
assertEquals(errorMsg, 2, output.size());
assertTrue(errorMsg, output.get(0) instanceof Map);
assertTrue(errorMsg, output.get(1) instanceof Map);
validateOutput(errorMsg, (Map) output.get(0));
validateOutput(errorMsg, (Map) output.get(1));
}
}

private void validateOutput(String errorMsg, Map<String, Object> output) {
assertTrue(errorMsg, output.containsKey("output"));
assertTrue(errorMsg, output.get("output") instanceof List);
List outputList = (List) output.get("output");
assertEquals(errorMsg, 1, outputList.size());
assertTrue(errorMsg, outputList.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
assertEquals(errorMsg, 1536, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{
"without_step_size": {
"name": "Amazon Bedrock Connector: embedding",
"description": "The connector to bedrock Titan embedding model",
"version": 1,
"protocol": "aws_sigv4",
"parameters": {
"region": "%s",
"service_name": "bedrock",
"model_name": "amazon.titan-embed-text-v1"
},
"credential": {
"access_key": "%s",
"secret_key": "%s",
"session_token": "%s"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
"headers": {
"content-type": "application/json",
"x-amz-content-sha256": "required"
},
"request_body": "{ \"inputText\": \"${parameters.inputText}\" }",
"pre_process_function": "connector.pre_process.bedrock.embedding",
"post_process_function": "connector.post_process.bedrock.embedding"
}
]
},
"with_step_size": {
"name": "Amazon Bedrock Connector: embedding",
"description": "The connector to bedrock Titan embedding model",
"version": 1,
"protocol": "aws_sigv4",
"parameters": {
"region": "%s",
"service_name": "bedrock",
"model_name": "amazon.titan-embed-text-v1",
"input_docs_processed_step_size": "1"
},
"credential": {
"access_key": "%s",
"secret_key": "%s",
"session_token": "%s"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
"headers": {
"content-type": "application/json",
"x-amz-content-sha256": "required"
},
"request_body": "{ \"inputText\": \"${parameters.inputText}\" }",
"pre_process_function": "connector.pre_process.bedrock.embedding",
"post_process_function": "connector.post_process.bedrock.embedding"
}
]
}
}

0 comments on commit 62e238f

Please sign in to comment.