Skip to content

Commit

Permalink
add llm generated action input as parameters for tool execution in co…
Browse files Browse the repository at this point in the history
…nversational agent

Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es committed Dec 2, 2024
1 parent 78a304a commit b1c518a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,11 @@ public static Map<String, String> constructToolParams(
if (toolSpecConfigMap != null) {
toolParams.putAll(toolSpecConfigMap);
}
toolParams.put("llm_generated_action_input", actionInput);
if (isJson(actionInput)) {
Map<String, String> params = getParameterMap(gson.fromJson(actionInput, Map.class));
toolParams.putAll(params);
}
if (tools.get(action).useOriginalInput()) {
toolParams.put("input", question);
lastActionInput.set(question);
Expand All @@ -486,10 +491,6 @@ public static Map<String, String> constructToolParams(
}
} else {
toolParams.put("input", actionInput);
if (isJson(actionInput)) {
Map<String, String> params = getParameterMap(gson.fromJson(actionInput, Map.class));
toolParams.putAll(params);
}
}
return toolParams;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -603,11 +603,12 @@ public void testConstructToolParams() {
String question = "dummy question";
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
verifyConstructToolParams(question, actionInput, (toolParams) -> {
Assert.assertEquals(4, toolParams.size());
Assert.assertEquals(5, toolParams.size());
Assert.assertEquals(actionInput, toolParams.get("input"));
Assert.assertEquals("abc", toolParams.get("detectorName"));
Assert.assertEquals("sample-data", toolParams.get("indices"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input"));
});
}

Expand All @@ -617,12 +618,65 @@ public void testConstructToolParams_UseOriginalInput() {
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
when(tool1.useOriginalInput()).thenReturn(true);
verifyConstructToolParams(question, actionInput, (toolParams) -> {
Assert.assertEquals(2, toolParams.size());
Assert.assertEquals(5, toolParams.size());
Assert.assertEquals(question, toolParams.get("input"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input"));
Assert.assertEquals("sample-data", toolParams.get("indices"));
Assert.assertEquals("abc", toolParams.get("detectorName"));
});
}

@Test
public void testConstructToolParams_PlaceholderConfigInput() {
String question = "dummy question";
String actionInput = "action input";
String preConfigInputStr = "Config Input: ";
Map<String, Tool> tools = Map.of("tool1", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
.of(
"tool1",
MLToolSpec
.builder()
.type("tool1")
.parameters(Map.of("key1", "value1"))
.configMap(Map.of("input", preConfigInputStr + "${parameters.llm_generated_action_input}"))
.build()
);
AtomicReference<String> lastActionInput = new AtomicReference<>();
String action = "tool1";
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
Assert.assertEquals(3, toolParams.size());
Assert.assertEquals(preConfigInputStr + actionInput, toolParams.get("input"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input"));
}

@Test
public void testConstructToolParams_PlaceholderConfigInputJson() {
String question = "dummy question";
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
String preConfigInputStr = "Config Input: ";
Map<String, Tool> tools = Map.of("tool1", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
.of(
"tool1",
MLToolSpec
.builder()
.type("tool1")
.parameters(Map.of("key1", "value1"))
.configMap(Map.of("input", preConfigInputStr + "${parameters.detectorName}"))
.build()
);
AtomicReference<String> lastActionInput = new AtomicReference<>();
String action = "tool1";
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
Assert.assertEquals(5, toolParams.size());
Assert.assertEquals(preConfigInputStr + "abc", toolParams.get("input"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input"));
}

private void verifyConstructToolParams(String question, String actionInput, Consumer<Map<String, String>> verify) {
Map<String, Tool> tools = Map.of("tool1", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
Expand Down

0 comments on commit b1c518a

Please sign in to comment.