From 73198b1c23b635681cba2be756222a05baaec6e1 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Tue, 23 Jul 2024 13:45:01 -0700 Subject: [PATCH] pass all parameters including chat_history to run tools Signed-off-by: Jing Zhang --- .../algorithms/agent/MLChatAgentRunner.java | 5 +- .../agent/MLChatAgentRunnerTest.java | 54 ++++++++++++++++++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 4740565dd1..4b14f1af17 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -473,7 +473,10 @@ private static void runTool( llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput); tools.get(action).run(llmToolTmpParameters, toolListener); // run tool } else { - tools.get(action).run(toolParams, toolListener); // run tool + Map parameters = new HashMap<>(); + parameters.putAll(tmpParameters); + parameters.putAll(toolParams); + tools.get(action).run(parameters, toolListener); // run tool } } catch (Exception e) { nextStepListener diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 760d9f9c06..1a1bc27dd1 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -121,7 +121,7 @@ public class MLChatAgentRunnerTest { @Captor private ArgumentCaptor> mlMemoryManagerCapture; @Captor - private ArgumentCaptor> ToolParamsCapture; + private ArgumentCaptor> toolParamsCapture; @Before @SuppressWarnings("unchecked") @@ -767,6 +767,58 @@ public void testSaveLastTraceFailure() { Mockito.verify(agentActionListener).onFailure(any(IllegalArgumentException.class)); } + @Test + public void testToolExecutionWithChatHistoryParameter() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec + .builder() + .name(FIRST_TOOL) + .parameters(Map.of("firsttoolspec", "firsttoolspec")) + .description("first tool spec") + .type(FIRST_TOOL) + .includeOutputInAgentResponse(false) + .build(); + MLToolSpec secondToolSpec = MLToolSpec + .builder() + .name(SECOND_TOOL) + .parameters(Map.of("secondtoolspec", "secondtoolspec")) + .description("second tool spec") + .type(SECOND_TOOL) + .includeOutputInAgentResponse(true) + .build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .memory(mlMemorySpec) + .llm(llmSpec) + .description("mlagent description") + .tools(Arrays.asList(firstToolSpec, secondToolSpec)) + .build(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + List interactionList = generateInteractions(2); + Interaction inProgressInteraction = Interaction.builder().id("interaction-99").input("input-99").response(null).build(); + interactionList.add(inProgressInteraction); + listener.onResponse(interactionList); + return null; + }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture()); + + doAnswer(generateToolResponse("First tool response")) + .when(firstTool) + .run(toolParamsCapture.capture(), toolListenerCaptor.capture()); + + HashMap params = new HashMap<>(); + params.put(MESSAGE_HISTORY_LIMIT, "5"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + String chatHistory = params.get(MLChatAgentRunner.CHAT_HISTORY); + Assert.assertFalse(chatHistory.contains("input-99")); + Assert.assertEquals(5, messageHistoryLimitCapture.getValue().intValue()); + Assert.assertTrue(toolParamsCapture.getValue().containsKey(MLChatAgentRunner.CHAT_HISTORY)); + } + // Helper methods to create MLAgent and parameters private MLAgent createMLAgentWithTools() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();