Skip to content

Commit

Permalink
Update memory if tool output needs to be included in response (#2010) (
Browse files Browse the repository at this point in the history
…#2018)

Signed-off-by: Arjun kumar Giri <[email protected]>
(cherry picked from commit b62b0de)

Co-authored-by: arjunkumargiri <[email protected]>
  • Loading branch information
1 parent 9860c5c commit f41d8a0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
? previousToolSpec.getName() + ".output"
: previousToolSpec.getType() + ".output";

String outputResponse = parseResponse(output);
params.put(outputKey, escapeJson(outputResponse));

if (previousToolSpec.isIncludeOutputInAgentResponse() || finalI == toolSpecs.size()) {
if (output instanceof ModelTensorOutput) {
flowAgentOutput.addAll(((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors());
Expand All @@ -117,11 +120,9 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
ModelTensor stepOutput = ModelTensor.builder().name(key).result(result).build();
flowAgentOutput.add(stepOutput);
}
}

String outputResponse = parseResponse(output);
params.put(outputKey, escapeJson(outputResponse));
additionalInfo.put(outputKey, outputResponse);
additionalInfo.put(outputKey, outputResponse);
}

if (finalI == toolSpecs.size()) {
updateMemory(additionalInfo, memorySpec, memoryId, parentInteractionId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ public class MLFlowAgentRunnerTest {
@Captor
private ArgumentCaptor<StepListener<Object>> nextStepListenerCaptor;

@Captor
private ArgumentCaptor<Map<String, Object>> memoryMapCaptor;

@Before
@SuppressWarnings("unchecked")
public void setup() {
Expand Down Expand Up @@ -154,9 +157,18 @@ private Answer generateToolTensorResponse() {
public void testRunWithIncludeOutputNotSet() {
final Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.MEMORY_ID, "memoryId");
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "interaction_id");
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build();
MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build();
ConversationIndexMemory memory = mock(ConversationIndexMemory.class);
Mockito.doAnswer(invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
listener.onResponse(memory);
return null;
}).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any());
Mockito.when(memory.getMemoryManager()).thenReturn(memoryManager);

final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
Expand All @@ -170,6 +182,11 @@ public void testRunWithIncludeOutputNotSet() {
// Respond with last tool output
assertEquals(SECOND_TOOL, agentOutput.get(0).getName());
assertEquals(SECOND_TOOL_RESPONSE, agentOutput.get(0).getResult());

verify(memoryManager).updateInteraction(anyString(), memoryMapCaptor.capture(), any(ActionListener.class));
Map<String, Object> additionalInfo = (Map<String, Object>) memoryMapCaptor.getValue().get("additional_info");
assertEquals(1, additionalInfo.size());
assertNotNull(additionalInfo.get(SECOND_TOOL + ".output"));
}

@Test()
Expand All @@ -188,9 +205,17 @@ public void testRunWithNoToolSpec() {
public void testRunWithIncludeOutputSet() {
final Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.MEMORY_ID, "memoryId");
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "interaction_id");
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).includeOutputInAgentResponse(true).build();
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).includeOutputInAgentResponse(true).build();
MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build();
ConversationIndexMemory memory = mock(ConversationIndexMemory.class);
Mockito.doAnswer(invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
listener.onResponse(memory);
return null;
}).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any());
Mockito.when(memory.getMemoryManager()).thenReturn(memoryManager);
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
Expand All @@ -206,6 +231,10 @@ public void testRunWithIncludeOutputSet() {
assertEquals(SECOND_TOOL, agentOutput.get(1).getName());
assertEquals(FIRST_TOOL_RESPONSE, agentOutput.get(0).getResult());
assertEquals(SECOND_TOOL_RESPONSE, agentOutput.get(1).getResult());

verify(memoryManager).updateInteraction(anyString(), memoryMapCaptor.capture(), any(ActionListener.class));
Map<String, Object> additionalInfo = (Map<String, Object>) memoryMapCaptor.getValue().get("additional_info");
assertEquals(2, additionalInfo.size());
}

@Test
Expand Down

0 comments on commit f41d8a0

Please sign in to comment.