diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 7a783059fe..6d03601d08 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -13,6 +13,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -29,11 +30,13 @@ import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.ActionRequest; import org.opensearch.action.StepListener; +import org.opensearch.action.support.GroupedActionListener; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; @@ -55,6 +58,7 @@ import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.engine.tools.MLModelTool; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.ml.repackage.com.google.common.collect.Lists; @@ -376,6 +380,64 @@ private void runReAct( } if (finalAnswer != null) { finalAnswer = finalAnswer.trim(); + String finalAnswer2 = finalAnswer; + // Composite execution response and reply. + final ActionListener executionListener = ActionListener.notifyOnce(ActionListener.wrap(r -> { + cotModelTensors + .add( + ModelTensors + .builder() + .mlModelTensors( + Collections.singletonList(ModelTensor.builder().name("response").result(finalAnswer2).build()) + ) + .build() + ); + + List finalModelTensors = new ArrayList<>(); + finalModelTensors + .add( + ModelTensors + .builder() + .mlModelTensors( + List + .of( + ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), + ModelTensor + .builder() + .name(MLAgentExecutor.PARENT_INTERACTION_ID) + .result(parentInteractionId) + .build() + ) + ) + .build() + ); + finalModelTensors + .add( + ModelTensors + .builder() + .mlModelTensors( + Collections + .singletonList( + ModelTensor + .builder() + .name("response") + .dataAsMap( + ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo) + ) + .build() + ) + ) + .build() + ); + getFinalAnswer.set(true); + if (verbose) { + listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build()); + } else { + listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build()); + } + }, listener::onFailure)); + // Sending execution response by internalListener is after the trace and answer saving. + final GroupedActionListener groupedListener = createGroupedListener(2, executionListener); if (conversationIndexMemory != null) { String finalAnswer1 = finalAnswer; // Create final trace message. @@ -387,71 +449,23 @@ private void runReAct( .finalAnswer(true) .sessionId(sessionId) .build(); - conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), null); - // Update root interaction. + // Save last trace and update final answer in parallel. + conversationIndexMemory + .save( + msgTemp, + parentInteractionId, + traceNumber.addAndGet(1), + null, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) + ); conversationIndexMemory .getMemoryManager() .updateInteraction( parentInteractionId, ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1, ADDITIONAL_INFO_FIELD, additionalInfo), - ActionListener.wrap(updateResponse -> { - log.info("Updated final answer into interaction id: {}", parentInteractionId); - log.info("Final answer: {}", finalAnswer1); - }, e -> log.error("Failed to update root interaction", e)) + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) ); } - cotModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections.singletonList(ModelTensor.builder().name("response").result(finalAnswer).build()) - ) - .build() - ); - - List finalModelTensors = new ArrayList<>(); - finalModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - List - .of( - ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), - ModelTensor - .builder() - .name(MLAgentExecutor.PARENT_INTERACTION_ID) - .result(parentInteractionId) - .build() - ) - ) - .build() - ); - finalModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections - .singletonList( - ModelTensor - .builder() - .name("response") - .dataAsMap( - ImmutableMap.of("response", finalAnswer, ADDITIONAL_INFO_FIELD, additionalInfo) - ) - .build() - ) - ) - .build() - ); - getFinalAnswer.set(true); - if (verbose) { - listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build()); - } else { - listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build()); - } return; } @@ -679,4 +693,27 @@ private void runReAct( client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener); } + private GroupedActionListener createGroupedListener(final int size, final ActionListener listener) { + return new GroupedActionListener<>(new ActionListener>() { + @Override + public void onResponse(final Collection responses) { + CreateInteractionResponse createInteractionResponse = extractResponse(responses, CreateInteractionResponse.class); + log.info("saved message with interaction id: {}", createInteractionResponse.getId()); + UpdateResponse updateResponse = extractResponse(responses, UpdateResponse.class); + log.info("Updated final answer into interaction id: {}", updateResponse.getId()); + + listener.onResponse(true); + } + + @Override + public void onFailure(final Exception e) { + listener.onFailure(e); + } + }, size); + } + + @SuppressWarnings("unchecked") + private static A extractResponse(final Collection responses, Class c) { + return (A) responses.stream().filter(c::isInstance).findFirst().get(); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 5a700fd602..5eaa10306a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -35,6 +35,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionType; import org.opensearch.action.StepListener; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -53,6 +54,7 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.MLMemoryManager; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; public class MLChatAgentRunnerTest { @@ -97,6 +99,10 @@ public class MLChatAgentRunnerTest { private ConversationIndexMemory conversationIndexMemory; @Mock private MLMemoryManager mlMemoryManager; + @Mock + private CreateInteractionResponse createInteractionResponse; + @Mock + private UpdateResponse updateResponse; @Mock private ConversationIndexMemory.Factory memoryFactory; @@ -104,6 +110,10 @@ public class MLChatAgentRunnerTest { private ArgumentCaptor> memoryFactoryCapture; @Captor private ArgumentCaptor>> memoryInteractionCapture; + @Captor + private ArgumentCaptor> conversationIndexMemoryCapture; + @Captor + private ArgumentCaptor> mlMemoryManagerCapture; @Before @SuppressWarnings("unchecked") @@ -127,6 +137,18 @@ public void setup() { listener.onResponse(conversationIndexMemory); return null; }).when(memoryFactory).create(any(), any(), any(), memoryFactoryCapture.capture()); + when(createInteractionResponse.getId()).thenReturn("create_interaction_id"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(4); + listener.onResponse(createInteractionResponse); + return null; + }).when(conversationIndexMemory).save(any(), any(), any(), any(), conversationIndexMemoryCapture.capture()); + when(updateResponse.getId()).thenReturn("update_interaction_id"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), mlMemoryManagerCapture.capture()); mlChatAgentRunner = new MLChatAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryMap); when(firstToolFactory.create(Mockito.anyMap())).thenReturn(firstTool); @@ -668,6 +690,31 @@ public void testToolParameters() { assertNotNull(modelTensorOutput); } + @Test + public void testSaveLastTraceFailure() { + // Mock tool validation to return true. + when(firstTool.validate(any())).thenReturn(true); + + // Create an MLAgent with tools + MLAgent mlAgent = createMLAgentWithTools(); + + // Create parameters for the agent + Map params = createAgentParamsWithAction(FIRST_TOOL, "someInput"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(4); + listener.onFailure(new IllegalArgumentException()); + return null; + }).when(conversationIndexMemory).save(any(), any(), any(), any(), conversationIndexMemoryCapture.capture()); + // Run the MLChatAgentRunner + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify that the tool's run method was called + verify(firstTool).run(any(), any()); + + Mockito.verify(agentActionListener).onFailure(any(IllegalArgumentException.class)); + } + // Helper methods to create MLAgent and parameters private MLAgent createMLAgentWithTools() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();