Skip to content

Commit

Permalink
change llm_generated_action_input to llm_generated_input
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es committed Dec 30, 2024
1 parent 997c50b commit 7b57dfd
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public class AgentUtils {
public static final String PROMPT_CHAT_HISTORY_PREFIX = "prompt.chat_history_prefix";
public static final String DISABLE_TRACE = "disable_trace";
public static final String VERBOSE = "verbose";
public static final String LLM_GEN_INPUT = "llm_generated_input";

public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> examplesMap = new HashMap<>();
Expand Down Expand Up @@ -472,7 +473,7 @@ public static Map<String, String> constructToolParams(
if (toolSpecConfigMap != null) {
toolParams.putAll(toolSpecConfigMap);
}
toolParams.put("llm_generated_action_input", actionInput);
toolParams.put(LLM_GEN_INPUT, actionInput);
if (isJson(actionInput)) {
Map<String, String> params = getParameterMap(gson.fromJson(actionInput, Map.class));
toolParams.putAll(params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_GEN_INPUT;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION;
Expand Down Expand Up @@ -608,7 +609,7 @@ public void testConstructToolParams() {
Assert.assertEquals("abc", toolParams.get("detectorName"));
Assert.assertEquals("sample-data", toolParams.get("indices"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input"));
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
});
}

Expand All @@ -619,7 +620,7 @@ public void testConstructToolParamsNullActionInput() {
verifyConstructToolParams(question, actionInput, (toolParams) -> {
Assert.assertEquals(3, toolParams.size());
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertNull(toolParams.get("llm_generated_action_input"));
Assert.assertNull(toolParams.get(LLM_GEN_INPUT));
Assert.assertNull(toolParams.get("input"));
});
}
Expand All @@ -633,7 +634,7 @@ public void testConstructToolParams_UseOriginalInput() {
Assert.assertEquals(5, toolParams.size());
Assert.assertEquals(question, toolParams.get("input"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input"));
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
Assert.assertEquals("sample-data", toolParams.get("indices"));
Assert.assertEquals("abc", toolParams.get("detectorName"));
});
Expand Down Expand Up @@ -661,7 +662,7 @@ public void testConstructToolParams_PlaceholderConfigInput() {
Assert.assertEquals(3, toolParams.size());
Assert.assertEquals(preConfigInputStr + actionInput, toolParams.get("input"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input"));
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
}

@Test
Expand All @@ -686,7 +687,7 @@ public void testConstructToolParams_PlaceholderConfigInputJson() {
Assert.assertEquals(5, toolParams.size());
Assert.assertEquals(preConfigInputStr + "abc", toolParams.get("input"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input"));
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
}

private void verifyConstructToolParams(String question, String actionInput, Consumer<Map<String, String>> verify) {
Expand Down

0 comments on commit 7b57dfd

Please sign in to comment.