Skip to content

Commit

Permalink
Merge branch 'opensearch-project:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
rbhavna authored Jan 17, 2024
2 parents e56eb2f + 38c71c6 commit 6131325
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -293,13 +294,10 @@ private void runReAct(

StepListener firstListener;
AtomicReference<StepListener<MLTaskResponse>> lastLlmListener = new AtomicReference<>();
AtomicReference<StepListener<Object>> lastToolListener = new AtomicReference<>();
AtomicBoolean getFinalAnswer = new AtomicBoolean(false);
AtomicReference<String> lastTool = new AtomicReference<>();
AtomicReference<String> lastThought = new AtomicReference<>();
AtomicReference<String> lastAction = new AtomicReference<>();
AtomicReference<String> lastActionInput = new AtomicReference<>();
AtomicReference<String> lastActionResult = new AtomicReference<>();
Map<String, Object> additionalInfo = new ConcurrentHashMap<>();

StepListener<?> lastStepListener = null;
Expand Down Expand Up @@ -470,28 +468,51 @@ private void runReAct(
Map<String, String> toolParams = new HashMap<>();
toolParams.put("input", actionInput);
if (tools.get(action).validate(toolParams)) {
if (tools.get(action) instanceof MLModelTool) {
Map<String, String> llmToolTmpParameters = new HashMap<>();
llmToolTmpParameters.putAll(tmpParameters);
llmToolTmpParameters.putAll(toolSpecMap.get(action).getParameters());
// TODO: support tool parameter override : langauge_model_tool.prompt
llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput);
tools.get(action).run(llmToolTmpParameters, nextStepListener); // run tool
} else {
tools.get(action).run(toolParams, nextStepListener); // run tool
try {
String finalAction = action;
ActionListener<Object> toolListener = ActionListener
.wrap(r -> { ((ActionListener<Object>) nextStepListener).onResponse(r); }, e -> {
((ActionListener<Object>) nextStepListener)
.onResponse(
String
.format(
Locale.ROOT,
"Failed to run the tool %s with the error message %s.",
finalAction,
e.getMessage()
)
);
});
if (tools.get(action) instanceof MLModelTool) {
Map<String, String> llmToolTmpParameters = new HashMap<>();
llmToolTmpParameters.putAll(tmpParameters);
llmToolTmpParameters.putAll(toolSpecMap.get(action).getParameters());
// TODO: support tool parameter override : langauge_model_tool.prompt
llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput);
tools.get(action).run(llmToolTmpParameters, toolListener); // run tool
} else {
tools.get(action).run(toolParams, toolListener); // run tool
}
} catch (Exception e) {
((ActionListener<Object>) nextStepListener)
.onResponse(
String
.format(
Locale.ROOT,
"Failed to run the tool %s with the error message %s.",
action,
e.getMessage()
)
);
}
} else {
lastActionResult.set("Tool " + action + " can't work for input: " + actionInput);
lastTool.set(action);
String res = "Tool " + action + " can't work for input: " + actionInput;
String res = String
.format(Locale.ROOT, "Failed to run the tool %s due to wrong input %s.", action, actionInput);
((ActionListener<Object>) nextStepListener).onResponse(res);
}
} else {
lastTool.set(null);
lastToolListener.set(null);
((ActionListener<Object>) nextStepListener).onResponse("no access to this tool ");
lastActionResult.set("no access to this tool ");

String res = String.format(Locale.ROOT, "Failed to run the tool %s which is unsupported.", action);
((ActionListener<Object>) nextStepListener).onResponse(res);
StringSubstitutor substitutor = new StringSubstitutor(
ImmutableMap.of(SCRATCHPAD, scratchpadBuilder.toString()),
"${parameters.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ public class MLChatAgentRunnerTest {
@Captor
private ArgumentCaptor<StepListener<Object>> nextStepListenerCaptor;

@Captor
private ArgumentCaptor<ActionListener<Object>> toolListenerCaptor;

private MLMemorySpec mlMemorySpec;
@Mock
private ConversationIndexMemory conversationIndexMemory;
Expand Down Expand Up @@ -134,14 +137,8 @@ public void setup() {
when(secondTool.getDescription()).thenReturn("Second tool description");
when(firstTool.validate(Mockito.anyMap())).thenReturn(true);
when(secondTool.validate(Mockito.anyMap())).thenReturn(true);
Mockito
.doAnswer(generateToolResponse("First tool response"))
.when(firstTool)
.run(Mockito.anyMap(), nextStepListenerCaptor.capture());
Mockito
.doAnswer(generateToolResponse("Second tool response"))
.when(secondTool)
.run(Mockito.anyMap(), nextStepListenerCaptor.capture());
Mockito.doAnswer(generateToolResponse("First tool response")).when(firstTool).run(Mockito.anyMap(), toolListenerCaptor.capture());
Mockito.doAnswer(generateToolResponse("Second tool response")).when(secondTool).run(Mockito.anyMap(), toolListenerCaptor.capture());

Mockito
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "thought 1", "action", FIRST_TOOL)))
Expand Down Expand Up @@ -422,6 +419,58 @@ public void testToolNotFound() {
verify(secondTool, never()).run(any(), any());
}

@Test
public void testToolFailure() {
// Mock tool validation to return false
when(firstTool.validate(any())).thenReturn(true);

// Create an MLAgent with tools
MLAgent mlAgent = createMLAgentWithTools();

// Create parameters for the agent
Map<String, String> params = createAgentParamsWithAction(FIRST_TOOL, "someInput");

Mockito
.doAnswer(generateToolFailure(new IllegalArgumentException("tool error")))
.when(firstTool)
.run(Mockito.anyMap(), toolListenerCaptor.capture());
// Run the MLChatAgentRunner
mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Verify that the tool's run method was called
verify(firstTool).run(any(), any());

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
assertNotNull(modelTensorOutput);
}

@Test
public void testToolThrowException() {
// Mock tool validation to return false
when(firstTool.validate(any())).thenReturn(true);

// Create an MLAgent with tools
MLAgent mlAgent = createMLAgentWithTools();

// Create parameters for the agent
Map<String, String> params = createAgentParamsWithAction(FIRST_TOOL, "someInput");

Mockito
.doThrow(new IllegalArgumentException("tool error"))
.when(firstTool)
.run(Mockito.anyMap(), toolListenerCaptor.capture());
// Run the MLChatAgentRunner
mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Verify that the tool's run method was called
verify(firstTool).run(any(), any());

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
assertNotNull(modelTensorOutput);
}

// Helper methods to create MLAgent and parameters
private MLAgent createMLAgentWithTools() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Expand Down Expand Up @@ -463,4 +512,12 @@ private Answer generateToolResponse(String response) {
};
}

private Answer generateToolFailure(Exception e) {
return invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
listener.onFailure(e);
return null;
};
}

}

0 comments on commit 6131325

Please sign in to comment.