Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add action input as parameters for tool execution in conversational agent #3200

Merged
merged 3 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,11 @@ public static Map<String, String> constructToolParams(
if (toolSpecConfigMap != null) {
toolParams.putAll(toolSpecConfigMap);
}
toolParams.put("llm_generated_action_input", actionInput);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe no need to mention action explicitly considering REST API uses tool. User may feel confused about tool and action. How about just llm_generated_input ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can also consider using constant for this string since it is reused in the tests

if (isJson(actionInput)) {
Map<String, String> params = getParameterMap(gson.fromJson(actionInput, Map.class));
toolParams.putAll(params);
}
if (tools.get(action).useOriginalInput()) {
toolParams.put("input", question);
lastActionInput.set(question);
Expand All @@ -486,10 +491,6 @@ public static Map<String, String> constructToolParams(
}
} else {
toolParams.put("input", actionInput);
if (isJson(actionInput)) {
Map<String, String> params = getParameterMap(gson.fromJson(actionInput, Map.class));
toolParams.putAll(params);
}
}
return toolParams;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -603,11 +603,24 @@ public void testConstructToolParams() {
String question = "dummy question";
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
verifyConstructToolParams(question, actionInput, (toolParams) -> {
Assert.assertEquals(4, toolParams.size());
Assert.assertEquals(5, toolParams.size());
Assert.assertEquals(actionInput, toolParams.get("input"));
Assert.assertEquals("abc", toolParams.get("detectorName"));
Assert.assertEquals("sample-data", toolParams.get("indices"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input"));
});
}

@Test
public void testConstructToolParamsNullActionInput() {
String question = "dummy question";
String actionInput = null;
verifyConstructToolParams(question, actionInput, (toolParams) -> {
Assert.assertEquals(3, toolParams.size());
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertNull(toolParams.get("llm_generated_action_input"));
Assert.assertNull(toolParams.get("input"));
});
}

Expand All @@ -617,12 +630,65 @@ public void testConstructToolParams_UseOriginalInput() {
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
when(tool1.useOriginalInput()).thenReturn(true);
verifyConstructToolParams(question, actionInput, (toolParams) -> {
Assert.assertEquals(2, toolParams.size());
Assert.assertEquals(5, toolParams.size());
Assert.assertEquals(question, toolParams.get("input"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input"));
Assert.assertEquals("sample-data", toolParams.get("indices"));
Assert.assertEquals("abc", toolParams.get("detectorName"));
});
}

@Test
public void testConstructToolParams_PlaceholderConfigInput() {
String question = "dummy question";
String actionInput = "action input";
String preConfigInputStr = "Config Input: ";
Map<String, Tool> tools = Map.of("tool1", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
.of(
"tool1",
MLToolSpec
.builder()
.type("tool1")
.parameters(Map.of("key1", "value1"))
.configMap(Map.of("input", preConfigInputStr + "${parameters.llm_generated_action_input}"))
.build()
);
AtomicReference<String> lastActionInput = new AtomicReference<>();
String action = "tool1";
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
Assert.assertEquals(3, toolParams.size());
Assert.assertEquals(preConfigInputStr + actionInput, toolParams.get("input"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input"));
}

@Test
public void testConstructToolParams_PlaceholderConfigInputJson() {
String question = "dummy question";
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
String preConfigInputStr = "Config Input: ";
Map<String, Tool> tools = Map.of("tool1", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
.of(
"tool1",
MLToolSpec
.builder()
.type("tool1")
.parameters(Map.of("key1", "value1"))
.configMap(Map.of("input", preConfigInputStr + "${parameters.detectorName}"))
.build()
);
AtomicReference<String> lastActionInput = new AtomicReference<>();
String action = "tool1";
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
Assert.assertEquals(5, toolParams.size());
Assert.assertEquals(preConfigInputStr + "abc", toolParams.get("input"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get("llm_generated_action_input"));
}

private void verifyConstructToolParams(String question, String actionInput, Consumer<Map<String, String>> verify) {
Map<String, Tool> tools = Map.of("tool1", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ public void testToolParameters() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(14, ((Map) argumentCaptor.getValue()).size());
assertEquals(15, ((Map) argumentCaptor.getValue()).size());

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
Expand Down Expand Up @@ -734,7 +734,7 @@ public void testToolUseOriginalInput() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
assertEquals("raw input", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
Expand Down Expand Up @@ -763,7 +763,7 @@ public void testToolConfig() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
// The value of input should be "config_value".
assertEquals("config_value", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

Expand Down Expand Up @@ -793,7 +793,7 @@ public void testToolConfigWithInputPlaceholder() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
// The value of input should be replaced with the value associated with the key "key2" of the first tool.
assertEquals("value2", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

Expand Down
Loading