Skip to content

Commit

Permalink
Fix bedrock connector embedding generation issue
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Jun 4, 2024
1 parent 9b072c4 commit e22fef5
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ default void executePredict(MLInput mlInput, ActionListener<MLTaskResponse> acti

/**
* Calculate the chunk size.
* @param textDocsInputDataSet
* @param textDocsInputDataSet Input dataset in textDocsInputDataSet format.
* @return Tuple of chunk size and step size.
*/
private Tuple<Integer, Integer> calculateChunkSize(TextDocsInputDataSet textDocsInputDataSet) {
default Tuple<Integer, Integer> calculateChunkSize(TextDocsInputDataSet textDocsInputDataSet) {
int textDocsLength = textDocsInputDataSet.getDocs().size();
Map<String, String> parameters = getConnector().getParameters();
if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) {
Expand All @@ -107,11 +107,15 @@ private Tuple<Integer, Integer> calculateChunkSize(TextDocsInputDataSet textDocs
throw new IllegalArgumentException("no predict action found");
}
String preProcessFunction = predictAction.get().getPreProcessFunction();
if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) {
// user defined preprocess script, this case, the chunk size is always equals to text docs length.
if (preProcessFunction == null) {
// default preprocess case, consider this a batch.
return Tuple.tuple(1, textDocsLength);
} else if (MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction)
|| !MLPreProcessFunction.contains(preProcessFunction)) {
// bedrock and user defined preprocess script: the chunk size always equals the text docs length.
return Tuple.tuple(textDocsLength, 1);
}
// consider as batch.
// Other cases: built-in preprocess functions other than bedrock (neither bedrock nor user defined scripts), consider as batch.
return Tuple.tuple(1, textDocsLength);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,4 +366,75 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPre
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
}

/**
 * Runs executePredict against an AWS connector whose predict action uses the built-in
 * {@code TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT} pre-process function. The connector
 * parameters only carry region and service name, and three text docs are submitted.
 *
 * NOTE(review): this test contains no assert/verify call, so it only fails if
 * executePredict throws on the calling thread. Presumably the intent is that each doc is
 * sent as its own request for bedrock; consider verifying the invocation count on the spy
 * or the interaction with actionListener -- TODO confirm expected behavior.
 */
@Test
public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddingPreProcessFunction() {
// Predict action configured with the bedrock embedding pre-process function.
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://openai.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT)
.build();
// Credential values are passed through encryptor.encrypt here and decrypted again below.
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
// Only region and service name are supplied as connector parameters.
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("aws_sigv4")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
// Spy on a real executor so its getClient()/getScriptService() accessors can be stubbed.
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
// Wire the mocked client, thread pool and script service into the executor.
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(executor.getScriptService()).thenReturn(scriptService);

// Three input docs wrapped in a TEXT_EMBEDDING request.
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
}

/**
 * Runs executePredict against an AWS connector whose predict action sets no pre-process
 * function. The connector parameters only carry region and service name, and three text
 * docs are submitted.
 *
 * NOTE(review): this test contains no assert/verify call, so it only fails if
 * executePredict throws on the calling thread. Presumably the intent is that the docs are
 * handled as a single batch when no pre-process function is configured; consider
 * verifying that on the spy or actionListener -- TODO confirm expected behavior.
 */
@Test
public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreprocessFunction() {
// Predict action built without calling preProcessFunction(...).
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://openai.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.build();
// Credential values are passed through encryptor.encrypt here and decrypted again below.
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
// Only region and service name are supplied as connector parameters.
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("aws_sigv4")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
// Spy on a real executor so its getClient()/getScriptService() accessors can be stubbed.
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
// Wire the mocked client, thread pool and script service into the executor.
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(executor.getScriptService()).thenReturn(scriptService);

// Three input docs wrapped in a TEXT_EMBEDDING request.
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
}
}

0 comments on commit e22fef5

Please sign in to comment.