From cf9ed90f3c5b66765d6410d93af8464e432164ca Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Tue, 23 Jul 2024 23:31:30 -0700 Subject: [PATCH] pass all parameters including chat_history to run tools (#2714) Signed-off-by: Jing Zhang --- .../algorithms/agent/MLChatAgentRunner.java | 5 +- .../agent/MLChatAgentRunnerTest.java | 58 ++++++++++++++++++- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 4740565dd1..4b14f1af17 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -473,7 +473,10 @@ private static void runTool( llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput); tools.get(action).run(llmToolTmpParameters, toolListener); // run tool } else { - tools.get(action).run(toolParams, toolListener); // run tool + Map parameters = new HashMap<>(); + parameters.putAll(tmpParameters); + parameters.putAll(toolParams); + tools.get(action).run(parameters, toolListener); // run tool } } catch (Exception e) { nextStepListener diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 760d9f9c06..0d7d24eddc 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -121,7 +121,7 @@ public class MLChatAgentRunnerTest { @Captor private ArgumentCaptor> mlMemoryManagerCapture; @Captor - private ArgumentCaptor> ToolParamsCapture; + private ArgumentCaptor> toolParamsCapture; @Before @SuppressWarnings("unchecked") @@ -706,7 +706,7 @@ public void testToolParameters() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(3, ((Map) argumentCaptor.getValue()).size()); + assertEquals(14, ((Map) argumentCaptor.getValue()).size()); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); @@ -734,7 +734,7 @@ public void testToolUseOriginalInput() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(3, ((Map) argumentCaptor.getValue()).size()); + assertEquals(15, ((Map) argumentCaptor.getValue()).size()); assertEquals("raw input", ((Map) argumentCaptor.getValue()).get("input")); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); @@ -767,6 +767,58 @@ public void testSaveLastTraceFailure() { Mockito.verify(agentActionListener).onFailure(any(IllegalArgumentException.class)); } + @Test + public void testToolExecutionWithChatHistoryParameter() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec + .builder() + .name(FIRST_TOOL) + .parameters(Map.of("firsttoolspec", "firsttoolspec")) + .description("first tool spec") + .type(FIRST_TOOL) + .includeOutputInAgentResponse(false) + .build(); + MLToolSpec secondToolSpec = MLToolSpec + .builder() + .name(SECOND_TOOL) + .parameters(Map.of("secondtoolspec", "secondtoolspec")) + .description("second tool spec") + .type(SECOND_TOOL) + .includeOutputInAgentResponse(true) + .build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .memory(mlMemorySpec) + .llm(llmSpec) + .description("mlagent description") + .tools(Arrays.asList(firstToolSpec, secondToolSpec)) + .build(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + List interactionList = generateInteractions(2); + Interaction inProgressInteraction = Interaction.builder().id("interaction-99").input("input-99").response(null).build(); + interactionList.add(inProgressInteraction); + listener.onResponse(interactionList); + return null; + }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture()); + + doAnswer(generateToolResponse("First tool response")) + .when(firstTool) + .run(toolParamsCapture.capture(), toolListenerCaptor.capture()); + + HashMap params = new HashMap<>(); + params.put(MESSAGE_HISTORY_LIMIT, "5"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + String chatHistory = params.get(MLChatAgentRunner.CHAT_HISTORY); + Assert.assertFalse(chatHistory.contains("input-99")); + Assert.assertEquals(5, messageHistoryLimitCapture.getValue().intValue()); + Assert.assertTrue(toolParamsCapture.getValue().containsKey(MLChatAgentRunner.CHAT_HISTORY)); + } + // Helper methods to create MLAgent and parameters private MLAgent createMLAgentWithTools() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();