diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index d8f8d6da94..8d01c59bf2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -60,6 +60,7 @@ public class AgentUtils { public static final String PROMPT_CHAT_HISTORY_PREFIX = "prompt.chat_history_prefix"; public static final String DISABLE_TRACE = "disable_trace"; public static final String VERBOSE = "verbose"; + public static final String LLM_GEN_INPUT = "llm_generated_input"; public static String addExamplesToPrompt(Map parameters, String prompt) { Map examplesMap = new HashMap<>(); @@ -472,6 +473,11 @@ public static Map constructToolParams( if (toolSpecConfigMap != null) { toolParams.putAll(toolSpecConfigMap); } + toolParams.put(LLM_GEN_INPUT, actionInput); + if (isJson(actionInput)) { + Map params = getParameterMap(gson.fromJson(actionInput, Map.class)); + toolParams.putAll(params); + } if (tools.get(action).useOriginalInput()) { toolParams.put("input", question); lastActionInput.set(question); @@ -486,10 +492,6 @@ public static Map constructToolParams( } } else { toolParams.put("input", actionInput); - if (isJson(actionInput)) { - Map params = getParameterMap(gson.fromJson(actionInput, Map.class)); - toolParams.putAll(params); - } } return toolParams; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index 496ce00136..c2bf05205a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -8,6 +8,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.when; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_GEN_INPUT; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION; @@ -603,11 +604,24 @@ public void testConstructToolParams() { String question = "dummy question"; String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }"; verifyConstructToolParams(question, actionInput, (toolParams) -> { - Assert.assertEquals(4, toolParams.size()); + Assert.assertEquals(5, toolParams.size()); Assert.assertEquals(actionInput, toolParams.get("input")); Assert.assertEquals("abc", toolParams.get("detectorName")); Assert.assertEquals("sample-data", toolParams.get("indices")); Assert.assertEquals("value1", toolParams.get("key1")); + Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT)); + }); + } + + @Test + public void testConstructToolParamsNullActionInput() { + String question = "dummy question"; + String actionInput = null; + verifyConstructToolParams(question, actionInput, (toolParams) -> { + Assert.assertEquals(3, toolParams.size()); + Assert.assertEquals("value1", toolParams.get("key1")); + Assert.assertNull(toolParams.get(LLM_GEN_INPUT)); + Assert.assertNull(toolParams.get("input")); }); } @@ -617,12 +631,65 @@ public void testConstructToolParams_UseOriginalInput() { String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }"; when(tool1.useOriginalInput()).thenReturn(true); verifyConstructToolParams(question, actionInput, (toolParams) -> { - Assert.assertEquals(2, toolParams.size()); + Assert.assertEquals(5, toolParams.size()); Assert.assertEquals(question, toolParams.get("input")); Assert.assertEquals("value1", toolParams.get("key1")); + Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT)); + Assert.assertEquals("sample-data", toolParams.get("indices")); + Assert.assertEquals("abc", toolParams.get("detectorName")); }); } + @Test + public void testConstructToolParams_PlaceholderConfigInput() { + String question = "dummy question"; + String actionInput = "action input"; + String preConfigInputStr = "Config Input: "; + Map tools = Map.of("tool1", tool1); + Map toolSpecMap = Map + .of( + "tool1", + MLToolSpec + .builder() + .type("tool1") + .parameters(Map.of("key1", "value1")) + .configMap(Map.of("input", preConfigInputStr + "${parameters.llm_generated_input}")) + .build() + ); + AtomicReference lastActionInput = new AtomicReference<>(); + String action = "tool1"; + Map toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput); + Assert.assertEquals(3, toolParams.size()); + Assert.assertEquals(preConfigInputStr + actionInput, toolParams.get("input")); + Assert.assertEquals("value1", toolParams.get("key1")); + Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT)); + } + + @Test + public void testConstructToolParams_PlaceholderConfigInputJson() { + String question = "dummy question"; + String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }"; + String preConfigInputStr = "Config Input: "; + Map tools = Map.of("tool1", tool1); + Map toolSpecMap = Map + .of( + "tool1", + MLToolSpec + .builder() + .type("tool1") + .parameters(Map.of("key1", "value1")) + .configMap(Map.of("input", preConfigInputStr + "${parameters.detectorName}")) + .build() + ); + AtomicReference lastActionInput = new AtomicReference<>(); + String action = "tool1"; + Map toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput); + Assert.assertEquals(5, toolParams.size()); + Assert.assertEquals(preConfigInputStr + "abc", toolParams.get("input")); + Assert.assertEquals("value1", toolParams.get("key1")); + Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT)); + } + private void verifyConstructToolParams(String question, String actionInput, Consumer> verify) { Map tools = Map.of("tool1", tool1); Map toolSpecMap = Map diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 5044d0a506..41982abe5b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -706,7 +706,7 @@ public void testToolParameters() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(14, ((Map) argumentCaptor.getValue()).size()); + assertEquals(15, ((Map) argumentCaptor.getValue()).size()); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); @@ -734,7 +734,7 @@ public void testToolUseOriginalInput() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + assertEquals(16, ((Map) argumentCaptor.getValue()).size()); assertEquals("raw input", ((Map) argumentCaptor.getValue()).get("input")); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); @@ -763,7 +763,7 @@ public void testToolConfig() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + assertEquals(16, ((Map) argumentCaptor.getValue()).size()); // The value of input should be "config_value". assertEquals("config_value", ((Map) argumentCaptor.getValue()).get("input")); @@ -793,7 +793,7 @@ public void testToolConfigWithInputPlaceholder() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + assertEquals(16, ((Map) argumentCaptor.getValue()).size()); // The value of input should be replaced with the value associated with the key "key2" of the first tool. assertEquals("value2", ((Map) argumentCaptor.getValue()).get("input"));