Skip to content

Commit

Permalink
pass all parameters including chat_history to run tools
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es committed Jul 23, 2024
1 parent 9b413a7 commit bef8770
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,10 @@ private static void runTool(
llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput);
tools.get(action).run(llmToolTmpParameters, toolListener); // run tool
} else {
tools.get(action).run(toolParams, toolListener); // run tool
Map<String, String> parameters = new HashMap<>();
parameters.putAll(tmpParameters);
parameters.putAll(toolParams);
tools.get(action).run(parameters, toolListener); // run tool
}
} catch (Exception e) {
nextStepListener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public class MLChatAgentRunnerTest {
@Captor
private ArgumentCaptor<ActionListener<UpdateResponse>> mlMemoryManagerCapture;
@Captor
private ArgumentCaptor<Map<String, String>> ToolParamsCapture;
private ArgumentCaptor<Map<String, String>> toolParamsCapture;

@Before
@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -706,7 +706,7 @@ public void testToolParameters() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(3, ((Map) argumentCaptor.getValue()).size());
assertEquals(14, ((Map) argumentCaptor.getValue()).size());

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
Expand Down Expand Up @@ -734,7 +734,7 @@ public void testToolUseOriginalInput() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(3, ((Map) argumentCaptor.getValue()).size());
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
assertEquals("raw input", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
Expand Down Expand Up @@ -767,6 +767,58 @@ public void testSaveLastTraceFailure() {
Mockito.verify(agentActionListener).onFailure(any(IllegalArgumentException.class));
}

@Test
public void testToolExecutionWithChatHistoryParameter() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
MLToolSpec firstToolSpec = MLToolSpec
.builder()
.name(FIRST_TOOL)
.parameters(Map.of("firsttoolspec", "firsttoolspec"))
.description("first tool spec")
.type(FIRST_TOOL)
.includeOutputInAgentResponse(false)
.build();
MLToolSpec secondToolSpec = MLToolSpec
.builder()
.name(SECOND_TOOL)
.parameters(Map.of("secondtoolspec", "secondtoolspec"))
.description("second tool spec")
.type(SECOND_TOOL)
.includeOutputInAgentResponse(true)
.build();
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
.type(MLAgentType.CONVERSATIONAL.name())
.memory(mlMemorySpec)
.llm(llmSpec)
.description("mlagent description")
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
.build();

doAnswer(invocation -> {
ActionListener<List<Interaction>> listener = invocation.getArgument(0);
List<Interaction> interactionList = generateInteractions(2);
Interaction inProgressInteraction = Interaction.builder().id("interaction-99").input("input-99").response(null).build();
interactionList.add(inProgressInteraction);
listener.onResponse(interactionList);
return null;
}).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture());

doAnswer(generateToolResponse("First tool response"))
.when(firstTool)
.run(toolParamsCapture.capture(), toolListenerCaptor.capture());

HashMap<String, String> params = new HashMap<>();
params.put(MESSAGE_HISTORY_LIMIT, "5");
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
String chatHistory = params.get(MLChatAgentRunner.CHAT_HISTORY);
Assert.assertFalse(chatHistory.contains("input-99"));
Assert.assertEquals(5, messageHistoryLimitCapture.getValue().intValue());
Assert.assertTrue(toolParamsCapture.getValue().containsKey(MLChatAgentRunner.CHAT_HISTORY));
}

// Helper methods to create MLAgent and parameters
private MLAgent createMLAgentWithTools() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Expand Down

0 comments on commit bef8770

Please sign in to comment.