Skip to content

Commit

Permalink
send agent execution response after saving memory (#1999)
Browse files Browse the repository at this point in the history
* send agent execution response after saving memory

Signed-off-by: Jing Zhang <[email protected]>

* spotless

Signed-off-by: Jing Zhang <[email protected]>

---------

Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es authored Feb 5, 2024
1 parent 2c63e8d commit c53d586
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand All @@ -29,11 +30,13 @@
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.StepListener;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
Expand All @@ -55,6 +58,7 @@
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.engine.tools.MLModelTool;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
import org.opensearch.ml.repackage.com.google.common.collect.Lists;

Expand Down Expand Up @@ -376,6 +380,64 @@ private void runReAct(
}
if (finalAnswer != null) {
finalAnswer = finalAnswer.trim();
String finalAnswer2 = finalAnswer;
// Compose the execution response and reply to the caller.
final ActionListener<Boolean> executionListener = ActionListener.notifyOnce(ActionListener.wrap(r -> {
cotModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Collections.singletonList(ModelTensor.builder().name("response").result(finalAnswer2).build())
)
.build()
);

List<ModelTensors> finalModelTensors = new ArrayList<>();
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
List
.of(
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
ModelTensor
.builder()
.name(MLAgentExecutor.PARENT_INTERACTION_ID)
.result(parentInteractionId)
.build()
)
)
.build()
);
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Collections
.singletonList(
ModelTensor
.builder()
.name("response")
.dataAsMap(
ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)
)
.build()
)
)
.build()
);
getFinalAnswer.set(true);
if (verbose) {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
} else {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
}
}, listener::onFailure));
// executionListener sends the execution response only after both the last trace and the final answer have been saved.
final GroupedActionListener<ActionResponse> groupedListener = createGroupedListener(2, executionListener);
if (conversationIndexMemory != null) {
String finalAnswer1 = finalAnswer;
// Create final trace message.
Expand All @@ -387,71 +449,23 @@ private void runReAct(
.finalAnswer(true)
.sessionId(sessionId)
.build();
conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), null);
// Update root interaction.
// Save last trace and update final answer in parallel.
conversationIndexMemory
.save(
msgTemp,
parentInteractionId,
traceNumber.addAndGet(1),
null,
ActionListener.<CreateInteractionResponse>wrap(groupedListener::onResponse, groupedListener::onFailure)
);
conversationIndexMemory
.getMemoryManager()
.updateInteraction(
parentInteractionId,
ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1, ADDITIONAL_INFO_FIELD, additionalInfo),
ActionListener.<UpdateResponse>wrap(updateResponse -> {
log.info("Updated final answer into interaction id: {}", parentInteractionId);
log.info("Final answer: {}", finalAnswer1);
}, e -> log.error("Failed to update root interaction", e))
ActionListener.<UpdateResponse>wrap(groupedListener::onResponse, groupedListener::onFailure)
);
}
cotModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Collections.singletonList(ModelTensor.builder().name("response").result(finalAnswer).build())
)
.build()
);

List<ModelTensors> finalModelTensors = new ArrayList<>();
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
List
.of(
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
ModelTensor
.builder()
.name(MLAgentExecutor.PARENT_INTERACTION_ID)
.result(parentInteractionId)
.build()
)
)
.build()
);
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Collections
.singletonList(
ModelTensor
.builder()
.name("response")
.dataAsMap(
ImmutableMap.of("response", finalAnswer, ADDITIONAL_INFO_FIELD, additionalInfo)
)
.build()
)
)
.build()
);
getFinalAnswer.set(true);
if (verbose) {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
} else {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
}
return;
}

Expand Down Expand Up @@ -679,4 +693,27 @@ private void runReAct(
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
}

/**
 * Creates a {@link GroupedActionListener} with group size {@code size} that forwards the
 * outcome to {@code listener}.
 *
 * <p>When the grouped responses are delivered, the ids of the {@link CreateInteractionResponse}
 * and {@link UpdateResponse} are logged and {@code listener} receives {@code true}. A failure
 * is passed to {@code listener.onFailure} unchanged.
 *
 * @param size     number of responses the grouped listener is created for
 * @param listener listener notified with {@code true} on success, or with the exception on failure
 * @return the grouped listener to hand to the individual memory operations
 */
private GroupedActionListener<ActionResponse> createGroupedListener(final int size, final ActionListener<Boolean> listener) {
    return new GroupedActionListener<>(new ActionListener<Collection<ActionResponse>>() {
        @Override
        public void onResponse(final Collection<ActionResponse> responses) {
            // The log lines are informational only. A missing response type must not prevent
            // the downstream listener from being notified, otherwise the caller never gets a reply.
            try {
                CreateInteractionResponse createInteractionResponse = extractResponse(responses, CreateInteractionResponse.class);
                log.info("saved message with interaction id: {}", createInteractionResponse.getId());
                UpdateResponse updateResponse = extractResponse(responses, UpdateResponse.class);
                log.info("Updated final answer into interaction id: {}", updateResponse.getId());
            } catch (java.util.NoSuchElementException e) {
                log.warn("Expected memory response is missing from the grouped responses", e);
            }

            listener.onResponse(true);
        }

        @Override
        public void onFailure(final Exception e) {
            listener.onFailure(e);
        }
    }, size);
}

/**
 * Returns the first element of {@code responses} that is an instance of {@code c}.
 *
 * @param responses responses to search
 * @param c         the response type to look for
 * @param <A>       the response type
 * @return the first matching response, cast to {@code A}
 * @throws java.util.NoSuchElementException if no element of {@code responses} is an instance of {@code c}
 */
private static <A extends ActionResponse> A extractResponse(final Collection<? extends ActionResponse> responses, Class<A> c) {
    return responses
        .stream()
        .filter(c::isInstance)
        // Class.cast is a checked cast, so no unchecked-warning suppression is needed.
        .map(c::cast)
        .findFirst()
        .orElseThrow(() -> new java.util.NoSuchElementException("No response of type " + c.getName() + " found"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.StepListener;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
Expand All @@ -53,6 +54,7 @@
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.MLMemoryManager;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;

public class MLChatAgentRunnerTest {
Expand Down Expand Up @@ -97,13 +99,21 @@ public class MLChatAgentRunnerTest {
private ConversationIndexMemory conversationIndexMemory;
@Mock
private MLMemoryManager mlMemoryManager;
@Mock
private CreateInteractionResponse createInteractionResponse;
@Mock
private UpdateResponse updateResponse;

@Mock
private ConversationIndexMemory.Factory memoryFactory;
@Captor
private ArgumentCaptor<ActionListener<ConversationIndexMemory>> memoryFactoryCapture;
@Captor
private ArgumentCaptor<ActionListener<List<Interaction>>> memoryInteractionCapture;
@Captor
private ArgumentCaptor<ActionListener<CreateInteractionResponse>> conversationIndexMemoryCapture;
@Captor
private ArgumentCaptor<ActionListener<UpdateResponse>> mlMemoryManagerCapture;

@Before
@SuppressWarnings("unchecked")
Expand All @@ -127,6 +137,18 @@ public void setup() {
listener.onResponse(conversationIndexMemory);
return null;
}).when(memoryFactory).create(any(), any(), any(), memoryFactoryCapture.capture());
when(createInteractionResponse.getId()).thenReturn("create_interaction_id");
doAnswer(invocation -> {
ActionListener<CreateInteractionResponse> listener = invocation.getArgument(4);
listener.onResponse(createInteractionResponse);
return null;
}).when(conversationIndexMemory).save(any(), any(), any(), any(), conversationIndexMemoryCapture.capture());
when(updateResponse.getId()).thenReturn("update_interaction_id");
doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
listener.onResponse(updateResponse);
return null;
}).when(mlMemoryManager).updateInteraction(any(), any(), mlMemoryManagerCapture.capture());

mlChatAgentRunner = new MLChatAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryMap);
when(firstToolFactory.create(Mockito.anyMap())).thenReturn(firstTool);
Expand Down Expand Up @@ -668,6 +690,31 @@ public void testToolParameters() {
assertNotNull(modelTensorOutput);
}

/**
 * Checks that a failure reported by {@code conversationIndexMemory.save(...)} is passed to the
 * agent's action listener as an {@link IllegalArgumentException}.
 */
@Test
public void testSaveLastTraceFailure() {
    // The first tool accepts any input.
    when(firstTool.validate(any())).thenReturn(true);

    // Override the save stub: fail the listener passed as the fifth argument.
    doAnswer(invocation -> {
        ActionListener<CreateInteractionResponse> saveListener = invocation.getArgument(4);
        saveListener.onFailure(new IllegalArgumentException());
        return null;
    }).when(conversationIndexMemory).save(any(), any(), any(), any(), conversationIndexMemoryCapture.capture());

    // Agent with tools, plus parameters that select the first tool.
    final MLAgent agent = createMLAgentWithTools();
    final Map<String, String> agentParams = createAgentParamsWithAction(FIRST_TOOL, "someInput");

    mlChatAgentRunner.run(agent, agentParams, agentActionListener);

    // The tool must have been run, and the save failure must reach the caller's listener.
    verify(firstTool).run(any(), any());
    verify(agentActionListener).onFailure(any(IllegalArgumentException.class));
}

// Helper methods to create MLAgent and parameters
private MLAgent createMLAgentWithTools() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Expand Down

0 comments on commit c53d586

Please sign in to comment.