Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pass all parameters including chat_history to run tools #2714

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,10 @@ private static void runTool(
llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput);
tools.get(action).run(llmToolTmpParameters, toolListener); // run tool
} else {
tools.get(action).run(toolParams, toolListener); // run tool
Map<String, String> parameters = new HashMap<>();
parameters.putAll(tmpParameters);
parameters.putAll(toolParams);
tools.get(action).run(parameters, toolListener); // run tool
}
} catch (Exception e) {
nextStepListener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public class MLChatAgentRunnerTest {
@Captor
private ArgumentCaptor<ActionListener<UpdateResponse>> mlMemoryManagerCapture;
@Captor
private ArgumentCaptor<Map<String, String>> ToolParamsCapture;
private ArgumentCaptor<Map<String, String>> toolParamsCapture;

@Before
@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -706,7 +706,7 @@ public void testToolParameters() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(3, ((Map) argumentCaptor.getValue()).size());
assertEquals(14, ((Map) argumentCaptor.getValue()).size());

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
Expand Down Expand Up @@ -734,7 +734,7 @@ public void testToolUseOriginalInput() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(3, ((Map) argumentCaptor.getValue()).size());
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
assertEquals("raw input", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
Expand Down Expand Up @@ -767,6 +767,58 @@ public void testSaveLastTraceFailure() {
Mockito.verify(agentActionListener).onFailure(any(IllegalArgumentException.class));
}

@Test
public void testToolExecutionWithChatHistoryParameter() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
MLToolSpec firstToolSpec = MLToolSpec
.builder()
.name(FIRST_TOOL)
.parameters(Map.of("firsttoolspec", "firsttoolspec"))
.description("first tool spec")
.type(FIRST_TOOL)
.includeOutputInAgentResponse(false)
.build();
MLToolSpec secondToolSpec = MLToolSpec
.builder()
.name(SECOND_TOOL)
.parameters(Map.of("secondtoolspec", "secondtoolspec"))
.description("second tool spec")
.type(SECOND_TOOL)
.includeOutputInAgentResponse(true)
.build();
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
.type(MLAgentType.CONVERSATIONAL.name())
.memory(mlMemorySpec)
.llm(llmSpec)
.description("mlagent description")
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
.build();

doAnswer(invocation -> {
ActionListener<List<Interaction>> listener = invocation.getArgument(0);
List<Interaction> interactionList = generateInteractions(2);
Interaction inProgressInteraction = Interaction.builder().id("interaction-99").input("input-99").response(null).build();
interactionList.add(inProgressInteraction);
listener.onResponse(interactionList);
return null;
}).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture());

doAnswer(generateToolResponse("First tool response"))
.when(firstTool)
.run(toolParamsCapture.capture(), toolListenerCaptor.capture());

HashMap<String, String> params = new HashMap<>();
params.put(MESSAGE_HISTORY_LIMIT, "5");
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
String chatHistory = params.get(MLChatAgentRunner.CHAT_HISTORY);
Assert.assertFalse(chatHistory.contains("input-99"));
Assert.assertEquals(5, messageHistoryLimitCapture.getValue().intValue());
Assert.assertTrue(toolParamsCapture.getValue().containsKey(MLChatAgentRunner.CHAT_HISTORY));
}

// Helper methods to create MLAgent and parameters
private MLAgent createMLAgentWithTools() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Expand Down
Loading