send agent execution response after saving memory (#1999) #2030

Merged
1 commit merged on Feb 6, 2024
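In brief: the chat agent runner previously built the final ModelTensorOutput and called listener.onResponse before the last trace message and the root interaction update were persisted. This change moves the response-building code into an executionListener and gates it behind a GroupedActionListener of size 2, so the execution response is sent only after both conversationIndexMemory.save(...) and updateInteraction(...) complete, and any failure is propagated to the caller. The sketch below illustrates that coordination pattern in plain Java; it does not use the OpenSearch GroupedActionListener class, and all names in it (GroupedCallback, GroupedCallbackDemo) are hypothetical.

// Minimal sketch of the "respond only after both memory writes finish" pattern.
// Plain Java for illustration; the real code uses OpenSearch's GroupedActionListener.
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.function.Consumer;

final class GroupedCallback<T> {
    private final int expected;                      // number of async results to wait for
    private final List<T> results = new ArrayList<>();
    private final Consumer<Collection<T>> onAllDone;
    private final Consumer<Exception> onError;
    private boolean failed;

    GroupedCallback(int expected, Consumer<Collection<T>> onAllDone, Consumer<Exception> onError) {
        this.expected = expected;
        this.onAllDone = onAllDone;
        this.onError = onError;
    }

    // Called once per completed async operation (e.g. trace save, interaction update).
    synchronized void onResponse(T result) {
        if (failed) {
            return;
        }
        results.add(result);
        if (results.size() == expected) {
            onAllDone.accept(results);               // reply to the caller only now
        }
    }

    // Any single failure short-circuits and is propagated to the caller.
    synchronized void onFailure(Exception e) {
        if (failed) {
            return;
        }
        failed = true;
        onError.accept(e);
    }
}

public class GroupedCallbackDemo {
    public static void main(String[] args) {
        GroupedCallback<String> grouped = new GroupedCallback<>(
            2,
            responses -> System.out.println("send execution response after: " + responses),
            e -> System.err.println("propagate failure: " + e.getMessage())
        );
        // In the PR, these two completions come from conversationIndexMemory.save(...)
        // and getMemoryManager().updateInteraction(...), each finishing asynchronously.
        grouped.onResponse("last trace saved");
        grouped.onResponse("final answer updated on root interaction");
    }
}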
File 1: MLChatAgentRunner (agent runner implementation)
@@ -13,6 +13,7 @@
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -29,11 +30,13 @@
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.StepListener;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
@@ -55,6 +58,7 @@
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.engine.tools.MLModelTool;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
import org.opensearch.ml.repackage.com.google.common.collect.Lists;

@@ -376,6 +380,64 @@ private void runReAct(
}
if (finalAnswer != null) {
finalAnswer = finalAnswer.trim();
String finalAnswer2 = finalAnswer;
// Compose the execution response and reply to the caller.
final ActionListener<Boolean> executionListener = ActionListener.notifyOnce(ActionListener.wrap(r -> {
cotModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Collections.singletonList(ModelTensor.builder().name("response").result(finalAnswer2).build())
)
.build()
);

List<ModelTensors> finalModelTensors = new ArrayList<>();
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
List
.of(
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
ModelTensor
.builder()
.name(MLAgentExecutor.PARENT_INTERACTION_ID)
.result(parentInteractionId)
.build()
)
)
.build()
);
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Collections
.singletonList(
ModelTensor
.builder()
.name("response")
.dataAsMap(
ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)
)
.build()
)
)
.build()
);
getFinalAnswer.set(true);
if (verbose) {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
} else {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
}
}, listener::onFailure));
// The execution response is sent by executionListener only after the trace and final answer have been saved.
final GroupedActionListener<ActionResponse> groupedListener = createGroupedListener(2, executionListener);
if (conversationIndexMemory != null) {
String finalAnswer1 = finalAnswer;
// Create final trace message.
@@ -387,71 +449,23 @@ private void runReAct(
.finalAnswer(true)
.sessionId(sessionId)
.build();
conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), null);
// Update root interaction.
// Save last trace and update final answer in parallel.
conversationIndexMemory
.save(
msgTemp,
parentInteractionId,
traceNumber.addAndGet(1),
null,
ActionListener.<CreateInteractionResponse>wrap(groupedListener::onResponse, groupedListener::onFailure)
);
conversationIndexMemory
.getMemoryManager()
.updateInteraction(
parentInteractionId,
ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1, ADDITIONAL_INFO_FIELD, additionalInfo),
ActionListener.<UpdateResponse>wrap(updateResponse -> {
log.info("Updated final answer into interaction id: {}", parentInteractionId);
log.info("Final answer: {}", finalAnswer1);
}, e -> log.error("Failed to update root interaction", e))
ActionListener.<UpdateResponse>wrap(groupedListener::onResponse, groupedListener::onFailure)
);
}
cotModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Collections.singletonList(ModelTensor.builder().name("response").result(finalAnswer).build())
)
.build()
);

List<ModelTensors> finalModelTensors = new ArrayList<>();
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
List
.of(
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
ModelTensor
.builder()
.name(MLAgentExecutor.PARENT_INTERACTION_ID)
.result(parentInteractionId)
.build()
)
)
.build()
);
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Collections
.singletonList(
ModelTensor
.builder()
.name("response")
.dataAsMap(
ImmutableMap.of("response", finalAnswer, ADDITIONAL_INFO_FIELD, additionalInfo)
)
.build()
)
)
.build()
);
getFinalAnswer.set(true);
if (verbose) {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
} else {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
}
return;
}

@@ -679,4 +693,27 @@ private void runReAct(
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
}

private GroupedActionListener<ActionResponse> createGroupedListener(final int size, final ActionListener<Boolean> listener) {
return new GroupedActionListener<>(new ActionListener<Collection<ActionResponse>>() {
@Override
public void onResponse(final Collection<ActionResponse> responses) {
CreateInteractionResponse createInteractionResponse = extractResponse(responses, CreateInteractionResponse.class);
log.info("saved message with interaction id: {}", createInteractionResponse.getId());
UpdateResponse updateResponse = extractResponse(responses, UpdateResponse.class);
log.info("Updated final answer into interaction id: {}", updateResponse.getId());

listener.onResponse(true);
}

@Override
public void onFailure(final Exception e) {
listener.onFailure(e);
}
}, size);
}

@SuppressWarnings("unchecked")
private static <A extends ActionResponse> A extractResponse(final Collection<? extends ActionResponse> responses, Class<A> c) {
return (A) responses.stream().filter(c::isInstance).findFirst().get();
}
}
File 2: MLChatAgentRunnerTest (unit tests)
@@ -35,6 +35,7 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.StepListener;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
@@ -53,6 +54,7 @@
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.MLMemoryManager;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;

public class MLChatAgentRunnerTest {
@@ -97,13 +99,21 @@ public class MLChatAgentRunnerTest {
private ConversationIndexMemory conversationIndexMemory;
@Mock
private MLMemoryManager mlMemoryManager;
@Mock
private CreateInteractionResponse createInteractionResponse;
@Mock
private UpdateResponse updateResponse;

@Mock
private ConversationIndexMemory.Factory memoryFactory;
@Captor
private ArgumentCaptor<ActionListener<ConversationIndexMemory>> memoryFactoryCapture;
@Captor
private ArgumentCaptor<ActionListener<List<Interaction>>> memoryInteractionCapture;
@Captor
private ArgumentCaptor<ActionListener<CreateInteractionResponse>> conversationIndexMemoryCapture;
@Captor
private ArgumentCaptor<ActionListener<UpdateResponse>> mlMemoryManagerCapture;

@Before
@SuppressWarnings("unchecked")
@@ -127,6 +137,18 @@ public void setup() {
listener.onResponse(conversationIndexMemory);
return null;
}).when(memoryFactory).create(any(), any(), any(), memoryFactoryCapture.capture());
when(createInteractionResponse.getId()).thenReturn("create_interaction_id");
doAnswer(invocation -> {
ActionListener<CreateInteractionResponse> listener = invocation.getArgument(4);
listener.onResponse(createInteractionResponse);
return null;
}).when(conversationIndexMemory).save(any(), any(), any(), any(), conversationIndexMemoryCapture.capture());
when(updateResponse.getId()).thenReturn("update_interaction_id");
doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
listener.onResponse(updateResponse);
return null;
}).when(mlMemoryManager).updateInteraction(any(), any(), mlMemoryManagerCapture.capture());

mlChatAgentRunner = new MLChatAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryMap);
when(firstToolFactory.create(Mockito.anyMap())).thenReturn(firstTool);
@@ -620,6 +642,57 @@ public void testToolThrowException() {
assertNotNull(modelTensorOutput);
}

@Test
public void testToolParameters() {
        // Mock tool validation to return true.
when(firstTool.validate(any())).thenReturn(true);

// Create an MLAgent with a tool including two parameters.
MLAgent mlAgent = createMLAgentWithTools();

// Create parameters for the agent.
Map<String, String> params = createAgentParamsWithAction(FIRST_TOOL, "someInput");

// Run the MLChatAgentRunner.
mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Verify that the tool's run method was called.
verify(firstTool).run(any(), any());
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(3, ((Map) argumentCaptor.getValue()).size());

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
assertNotNull(modelTensorOutput);
}

@Test
public void testSaveLastTraceFailure() {
// Mock tool validation to return true.
when(firstTool.validate(any())).thenReturn(true);

// Create an MLAgent with tools
MLAgent mlAgent = createMLAgentWithTools();

// Create parameters for the agent
Map<String, String> params = createAgentParamsWithAction(FIRST_TOOL, "someInput");

doAnswer(invocation -> {
ActionListener<CreateInteractionResponse> listener = invocation.getArgument(4);
listener.onFailure(new IllegalArgumentException());
return null;
}).when(conversationIndexMemory).save(any(), any(), any(), any(), conversationIndexMemoryCapture.capture());
// Run the MLChatAgentRunner
mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Verify that the tool's run method was called
verify(firstTool).run(any(), any());

Mockito.verify(agentActionListener).onFailure(any(IllegalArgumentException.class));
}

// Helper methods to create MLAgent and parameters
private MLAgent createMLAgentWithTools() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();