Skip to content

Commit

Permalink
ut for malformed response field
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Feb 7, 2024
1 parent 959918d commit 112c5a5
Showing 1 changed file with 22 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,28 @@ public void testMLModelsWithDefaultOutputParserAndCustomizedResponseField() thro
assertEquals("action1", future.get());
}

@Test
public void testMLModelsWithDefaultOutputParserAndMalformedResponseField() throws ExecutionException, InterruptedException {
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
doAnswer(invocation -> {

ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);

actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
return null;
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

Tool tool = MLModelTool.Factory.getInstance().create(Map.of("model_id", "modelId", "response_field", "malformed field"));
final CompletableFuture<String> future = new CompletableFuture<>();
ActionListener<String> listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); });
tool.run(null, listener);

future.join();
assertEquals(null, future.get());
}

@Test
public void testMLModelsWithCustomizedOutputParser() {
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build();
Expand Down

0 comments on commit 112c5a5

Please sign in to comment.