diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index f7493e31fc..8d01c59bf2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -60,6 +60,7 @@ public class AgentUtils { public static final String PROMPT_CHAT_HISTORY_PREFIX = "prompt.chat_history_prefix"; public static final String DISABLE_TRACE = "disable_trace"; public static final String VERBOSE = "verbose"; + public static final String LLM_GEN_INPUT = "llm_generated_input"; public static String addExamplesToPrompt(Map parameters, String prompt) { Map examplesMap = new HashMap<>(); @@ -472,7 +473,7 @@ public static Map constructToolParams( if (toolSpecConfigMap != null) { toolParams.putAll(toolSpecConfigMap); } - toolParams.put("llm_generated_action_input", actionInput); + toolParams.put(LLM_GEN_INPUT, actionInput); if (isJson(actionInput)) { Map params = getParameterMap(gson.fromJson(actionInput, Map.class)); toolParams.putAll(params); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index 0227c77ed3..c2bf05205a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -8,6 +8,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.when; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_GEN_INPUT; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION; @@ -608,7 +609,7 @@ public void testConstructToolParams() { Assert.assertEquals("abc", toolParams.get("detectorName")); Assert.assertEquals("sample-data", toolParams.get("indices")); Assert.assertEquals("value1", toolParams.get("key1")); - Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input")); + Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT)); }); } @@ -619,7 +620,7 @@ public void testConstructToolParamsNullActionInput() { verifyConstructToolParams(question, actionInput, (toolParams) -> { Assert.assertEquals(3, toolParams.size()); Assert.assertEquals("value1", toolParams.get("key1")); - Assert.assertNull(toolParams.get("llm_generated_action_input")); + Assert.assertNull(toolParams.get(LLM_GEN_INPUT)); Assert.assertNull(toolParams.get("input")); }); } @@ -633,7 +634,7 @@ public void testConstructToolParams_UseOriginalInput() { Assert.assertEquals(5, toolParams.size()); Assert.assertEquals(question, toolParams.get("input")); Assert.assertEquals("value1", toolParams.get("key1")); - Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input")); + Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT)); Assert.assertEquals("sample-data", toolParams.get("indices")); Assert.assertEquals("abc", toolParams.get("detectorName")); }); @@ -652,7 +653,7 @@ public void testConstructToolParams_PlaceholderConfigInput() { .builder() .type("tool1") .parameters(Map.of("key1", "value1")) - .configMap(Map.of("input", preConfigInputStr + "${parameters.llm_generated_action_input}")) + .configMap(Map.of("input", preConfigInputStr + "${parameters.llm_generated_input}")) .build() ); AtomicReference lastActionInput = new AtomicReference<>(); @@ -661,7 +662,7 @@ public void testConstructToolParams_PlaceholderConfigInput() { Assert.assertEquals(3, toolParams.size()); Assert.assertEquals(preConfigInputStr + actionInput, toolParams.get("input")); Assert.assertEquals("value1", toolParams.get("key1")); - Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input")); + Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT)); } @Test @@ -686,7 +687,7 @@ public void testConstructToolParams_PlaceholderConfigInputJson() { Assert.assertEquals(5, toolParams.size()); Assert.assertEquals(preConfigInputStr + "abc", toolParams.get("input")); Assert.assertEquals("value1", toolParams.get("key1")); - Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input")); + Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT)); } private void verifyConstructToolParams(String question, String actionInput, Consumer> verify) {