Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Jun 11, 2024
1 parent 35e8133 commit a9d0963
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddi
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
Expand All @@ -645,7 +645,11 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddi

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
.executeAction(
PREDICT.name(),
MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(),
actionListener
);
}

@Test
Expand All @@ -669,7 +673,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreproces
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
Expand All @@ -680,7 +684,11 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreproces

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
.executeAction(
PREDICT.name(),
MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(),
actionListener
);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,24 @@ public void test_bedrock_embedding_model() throws Exception {

TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
Map inferenceResult = predictRemoteModel(modelId, mlInput);
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
List output = (List) inferenceResult.get("inference_results");
assertEquals(errorMsg, 2, output.size());
assertTrue(errorMsg, output.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) output.get(0)).get("output") instanceof List);
List outputList = (List) ((Map<?, ?>) output.get(0)).get("output");
assertEquals(errorMsg, 1, outputList.size());
assertTrue(errorMsg, outputList.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
assertEquals(errorMsg, 1536, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
assertTrue(errorMsg, output.get(1) instanceof Map);
validateOutput(errorMsg, (Map) output.get(0));
validateOutput(errorMsg, (Map) output.get(1));
}
}

private void validateOutput(String errorMsg, Map<String, Object> output) {
assertTrue(errorMsg, output.containsKey("output"));
assertTrue(errorMsg, output.get("output") instanceof List);
List outputList = (List) output.get("output");
assertEquals(errorMsg, 1, outputList.size());
assertTrue(errorMsg, outputList.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
assertEquals(errorMsg, 1536, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
}
}

0 comments on commit a9d0963

Please sign in to comment.