From b1c518a7194d52c53411705cf126de1403225914 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 4 Nov 2024 10:27:41 -0800 Subject: [PATCH] add llm generated action input as parameters for tool execution in conversational agent Signed-off-by: Jing Zhang --- .../engine/algorithms/agent/AgentUtils.java | 9 +-- .../algorithms/agent/AgentUtilsTest.java | 58 ++++++++++++++++++- 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index d8f8d6da94..f7493e31fc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -472,6 +472,11 @@ public static Map constructToolParams( if (toolSpecConfigMap != null) { toolParams.putAll(toolSpecConfigMap); } + toolParams.put("llm_generated_action_input", actionInput); + if (isJson(actionInput)) { + Map params = getParameterMap(gson.fromJson(actionInput, Map.class)); + toolParams.putAll(params); + } if (tools.get(action).useOriginalInput()) { toolParams.put("input", question); lastActionInput.set(question); @@ -486,10 +491,6 @@ public static Map constructToolParams( } } else { toolParams.put("input", actionInput); - if (isJson(actionInput)) { - Map params = getParameterMap(gson.fromJson(actionInput, Map.class)); - toolParams.putAll(params); - } } return toolParams; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index 496ce00136..347707ad3e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -603,11 +603,12 @@ public void testConstructToolParams() { String question = "dummy question"; String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }"; verifyConstructToolParams(question, actionInput, (toolParams) -> { - Assert.assertEquals(4, toolParams.size()); + Assert.assertEquals(5, toolParams.size()); Assert.assertEquals(actionInput, toolParams.get("input")); Assert.assertEquals("abc", toolParams.get("detectorName")); Assert.assertEquals("sample-data", toolParams.get("indices")); Assert.assertEquals("value1", toolParams.get("key1")); + Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input")); }); } @@ -617,12 +618,65 @@ public void testConstructToolParams_UseOriginalInput() { String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }"; when(tool1.useOriginalInput()).thenReturn(true); verifyConstructToolParams(question, actionInput, (toolParams) -> { - Assert.assertEquals(2, toolParams.size()); + Assert.assertEquals(5, toolParams.size()); Assert.assertEquals(question, toolParams.get("input")); Assert.assertEquals("value1", toolParams.get("key1")); + Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input")); + Assert.assertEquals("sample-data", toolParams.get("indices")); + Assert.assertEquals("abc", toolParams.get("detectorName")); }); } + @Test + public void testConstructToolParams_PlaceholderConfigInput() { + String question = "dummy question"; + String actionInput = "action input"; + String preConfigInputStr = "Config Input: "; + Map tools = Map.of("tool1", tool1); + Map toolSpecMap = Map + .of( + "tool1", + MLToolSpec + .builder() + .type("tool1") + .parameters(Map.of("key1", "value1")) + .configMap(Map.of("input", preConfigInputStr + "${parameters.llm_generated_action_input}")) + .build() + ); + AtomicReference lastActionInput = new AtomicReference<>(); + String action = "tool1"; + Map toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput); + Assert.assertEquals(3, toolParams.size()); + Assert.assertEquals(preConfigInputStr + actionInput, toolParams.get("input")); + Assert.assertEquals("value1", toolParams.get("key1")); + Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input")); + } + + @Test + public void testConstructToolParams_PlaceholderConfigInputJson() { + String question = "dummy question"; + String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }"; + String preConfigInputStr = "Config Input: "; + Map tools = Map.of("tool1", tool1); + Map toolSpecMap = Map + .of( + "tool1", + MLToolSpec + .builder() + .type("tool1") + .parameters(Map.of("key1", "value1")) + .configMap(Map.of("input", preConfigInputStr + "${parameters.detectorName}")) + .build() + ); + AtomicReference lastActionInput = new AtomicReference<>(); + String action = "tool1"; + Map toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput); + Assert.assertEquals(5, toolParams.size()); + Assert.assertEquals(preConfigInputStr + "abc", toolParams.get("input")); + Assert.assertEquals("value1", toolParams.get("key1")); + Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input")); + } + private void verifyConstructToolParams(String question, String actionInput, Consumer> verify) { Map tools = Map.of("tool1", tool1); Map toolSpecMap = Map