diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index 7e79c98d45..ab03859271 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -106,6 +106,9 @@ public void run(MLAgent mlAgent, Map params, ActionListener params, ActionListener> nextStepListenerCaptor; + @Captor + private ArgumentCaptor> memoryMapCaptor; + @Before @SuppressWarnings("unchecked") public void setup() { @@ -154,9 +157,18 @@ private Answer generateToolTensorResponse() { public void testRunWithIncludeOutputNotSet() { final Map params = new HashMap<>(); params.put(MLAgentExecutor.MEMORY_ID, "memoryId"); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "interaction_id"); MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build(); MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build(); + ConversationIndexMemory memory = mock(ConversationIndexMemory.class); + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(memory); + return null; + }).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any()); + Mockito.when(memory.getMemoryManager()).thenReturn(memoryManager); + final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") @@ -170,6 +182,11 @@ public void testRunWithIncludeOutputNotSet() { // Respond with last tool output assertEquals(SECOND_TOOL, agentOutput.get(0).getName()); assertEquals(SECOND_TOOL_RESPONSE, agentOutput.get(0).getResult()); + + verify(memoryManager).updateInteraction(anyString(), memoryMapCaptor.capture(), any(ActionListener.class)); + Map additionalInfo = (Map) memoryMapCaptor.getValue().get("additional_info"); + assertEquals(1, additionalInfo.size()); + assertNotNull(additionalInfo.get(SECOND_TOOL + ".output")); } @Test() @@ -188,9 +205,17 @@ public void testRunWithNoToolSpec() { public void testRunWithIncludeOutputSet() { final Map params = new HashMap<>(); params.put(MLAgentExecutor.MEMORY_ID, "memoryId"); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "interaction_id"); MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).includeOutputInAgentResponse(true).build(); MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).includeOutputInAgentResponse(true).build(); MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build(); + ConversationIndexMemory memory = mock(ConversationIndexMemory.class); + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(memory); + return null; + }).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any()); + Mockito.when(memory.getMemoryManager()).thenReturn(memoryManager); final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") @@ -206,6 +231,10 @@ public void testRunWithIncludeOutputSet() { assertEquals(SECOND_TOOL, agentOutput.get(1).getName()); assertEquals(FIRST_TOOL_RESPONSE, agentOutput.get(0).getResult()); assertEquals(SECOND_TOOL_RESPONSE, agentOutput.get(1).getResult()); + + verify(memoryManager).updateInteraction(anyString(), memoryMapCaptor.capture(), any(ActionListener.class)); + Map additionalInfo = (Map) memoryMapCaptor.getValue().get("additional_info"); + assertEquals(2, additionalInfo.size()); } @Test