diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 7d1b4b2ac0..fa262583bb 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -90,15 +90,22 @@ jacocoTestReport { dependsOn test } +List jacocoExclusions = [ + // TODO: add more unit test to meet the minimal test coverage. + 'org.opensearch.ml.engine.algorithms.agent.MLConversationalFlowAgentRunner' +] + jacocoTestCoverageVerification { violationRules { rule { limit { counter = 'LINE' + excludes = jacocoExclusions minimum = 0.65 //TODO: increase coverage to 0.90 } limit { counter = 'BRANCH' + excludes = jacocoExclusions minimum = 0.55 //TODO: increase coverage to 0.85 } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 7e9bdd1ab0..5268f4a559 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine.algorithms.agent; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MESSAGE_HISTORY_LIMIT; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES; @@ -14,6 +15,7 @@ import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_SUFFIX; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS; import java.security.AccessController; import java.security.PrivilegedActionException; @@ -26,6 +28,7 @@ import java.util.regex.Pattern; import org.apache.commons.text.StringSubstitutor; +import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.spi.tools.Tool; @@ -185,4 +188,13 @@ public static String parseInputFromLLMReturn(Map retMap) { } } + + public static int getMessageHistoryLimit(Map params) { + String messageHistoryLimitStr = params.get(MESSAGE_HISTORY_LIMIT); + return messageHistoryLimitStr != null ? Integer.parseInt(messageHistoryLimitStr) : LAST_N_INTERACTIONS; + } + + public static String getToolName(MLToolSpec toolSpec) { + return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType(); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 52156771bd..cb779abf31 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -65,6 +65,7 @@ public class MLAgentExecutor implements Executable { public static final String QUESTION = "question"; public static final String PARENT_INTERACTION_ID = "parent_interaction_id"; public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id"; + public static final String MESSAGE_HISTORY_LIMIT = "message_history_limit"; private Client client; private Settings settings; @@ -281,6 +282,15 @@ protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { switch (mlAgent.getType()) { case "flow": return new MLFlowAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryFactoryMap); + case "conversational_flow": + return new MLConversationalFlowAgentRunner( + client, + settings, + clusterService, + xContentRegistry, + toolFactories, + memoryFactoryMap + ); case "conversational": return new MLChatAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryFactoryMap); default: diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index ae629e80ab..f944aa90e9 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn; @@ -110,11 +111,13 @@ public MLChatAgentRunner( this.memoryFactoryMap = memoryFactoryMap; } + @Override public void run(MLAgent mlAgent, Map params, ActionListener listener) { String memoryType = mlAgent.getMemory().getType(); String memoryId = params.get(MLAgentExecutor.MEMORY_ID); String appType = mlAgent.getAppType(); String title = params.get(MLAgentExecutor.QUESTION); + int messageHistoryLimit = getMessageHistoryLimit(params); ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.wrap(memory -> { @@ -152,7 +155,7 @@ public void run(MLAgent mlAgent, Map params, ActionListener { log.error("Failed to get chat history", e); listener.onFailure(e); - })); + }), messageHistoryLimit); }, listener::onFailure)); } @@ -370,7 +373,7 @@ private void runReAct( if (conversationIndexMemory != null) { ConversationIndexMessage msgTemp = ConversationIndexMessage .conversationIndexMessageBuilder() - .type("ReAct") + .type(memory.getType()) .question(question) .response(thought) .finalAnswer(false) @@ -443,7 +446,7 @@ private void runReAct( // Create final trace message. ConversationIndexMessage msgTemp = ConversationIndexMessage .conversationIndexMessageBuilder() - .type("ReAct") + .type(memory.getType()) .question(question) .response(finalAnswer1) .finalAnswer(true) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java new file mode 100644 index 0000000000..c2471e4f35 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -0,0 +1,449 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import static org.apache.commons.text.StringEscapeUtils.escapeJson; +import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; +import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; +import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID; +import static org.opensearch.ml.common.conversation.ActionConstants.PARENT_INTERACTION_ID_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName; +import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.commons.text.StringSubstitutor; +import org.opensearch.action.StepListener; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.memory.Memory; +import org.opensearch.ml.common.spi.memory.Message; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.memory.ConversationIndexMemory; +import org.opensearch.ml.engine.memory.ConversationIndexMessage; +import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; + +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@Data +@NoArgsConstructor +public class MLConversationalFlowAgentRunner implements MLAgentRunner { + + public static final String CHAT_HISTORY = "chat_history"; + public static final String SELECTED_TOOLS = "selected_tools"; + private Client client; + private Settings settings; + private ClusterService clusterService; + private NamedXContentRegistry xContentRegistry; + private Map toolFactories; + private Map memoryFactoryMap; + + public MLConversationalFlowAgentRunner( + Client client, + Settings settings, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + Map toolFactories, + Map memoryFactoryMap + ) { + this.client = client; + this.settings = settings; + this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; + this.toolFactories = toolFactories; + this.memoryFactoryMap = memoryFactoryMap; + } + + @Override + public void run(MLAgent mlAgent, Map params, ActionListener listener) { + String appType = mlAgent.getAppType(); + String memoryId = params.get(MLAgentExecutor.MEMORY_ID); + String parentInteractionId = params.get(MLAgentExecutor.PARENT_INTERACTION_ID); + if (appType == null || mlAgent.getMemory() == null) { + runAgent(mlAgent, params, listener, null, memoryId, parentInteractionId); + return; + } + + // TODO: refactor to extract common part with MLChatAgentRunner and MLFlowAgentRunner + String memoryType = mlAgent.getMemory().getType(); + String title = params.get(QUESTION); + int messageHistoryLimit = getMessageHistoryLimit(params); + + ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); + conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.wrap(memory -> { + memory.getMessages(ActionListener.>wrap(r -> { + List messageList = new ArrayList<>(); + for (Interaction next : r) { + String question = next.getInput(); + String response = next.getResponse(); + // As we store the conversation with empty response first and then update when have final answer, + // filter out those in-flight requests when run in parallel + if (Strings.isNullOrEmpty(response)) { + continue; + } + messageList + .add( + ConversationIndexMessage + .conversationIndexMessageBuilder() + .sessionId(memory.getConversationId()) + .question(question) + .response(response) + .build() + ); + } + + StringBuilder chatHistoryBuilder = new StringBuilder(); + if (messageList.size() > 0) { + chatHistoryBuilder.append("Below is Chat History between Human and AI which sorted by time with asc order:\n"); + for (Message message : messageList) { + chatHistoryBuilder.append(message.toString()).append("\n"); + } + params.put(CHAT_HISTORY, chatHistoryBuilder.toString()); + } + + runAgent(mlAgent, params, listener, memory, memory.getConversationId(), parentInteractionId); + }, e -> { + log.error("Failed to get chat history", e); + listener.onFailure(e); + }), messageHistoryLimit); + }, listener::onFailure)); + } + + private void runAgent( + MLAgent mlAgent, + Map params, + ActionListener listener, + ConversationIndexMemory memory, + String memoryId, + String parentInteractionId + ) { + + StepListener firstStepListener = null; + Tool firstTool = null; + List flowAgentOutput = new ArrayList<>(); + Map firstToolExecuteParams = null; + StepListener previousStepListener = null; + Map additionalInfo = new ConcurrentHashMap<>(); + String selectedToolsStr = params.get(SELECTED_TOOLS); + List toolSpecs = getMlToolSpecs(mlAgent, selectedToolsStr); + + if (toolSpecs == null || toolSpecs.size() == 0) { + listener.onFailure(new IllegalArgumentException("no tool configured")); + return; + } + AtomicInteger traceNumber = new AtomicInteger(0); + if (memory != null) { + flowAgentOutput.add(ModelTensor.builder().name(MEMORY_ID).result(memoryId).build()); + flowAgentOutput.add(ModelTensor.builder().name(PARENT_INTERACTION_ID_FIELD).result(parentInteractionId).build()); + } + + MLMemorySpec memorySpec = mlAgent.getMemory(); + for (int i = 0; i <= toolSpecs.size(); i++) { + if (i == 0) { + MLToolSpec toolSpec = toolSpecs.get(i); + Tool tool = createTool(toolSpec); + firstStepListener = new StepListener<>(); + previousStepListener = firstStepListener; + firstTool = tool; + firstToolExecuteParams = getToolExecuteParams(toolSpec, params); + } else { + MLToolSpec previousToolSpec = toolSpecs.get(i - 1); + StepListener nextStepListener = new StepListener<>(); + int finalI = i; + previousStepListener.whenComplete(output -> { + processOutput( + params, + listener, + memory, + memoryId, + parentInteractionId, + toolSpecs, + flowAgentOutput, + additionalInfo, + traceNumber, + memorySpec, + previousToolSpec, + finalI, + output, + nextStepListener + ); + }, e -> { + log.error("Failed to run flow agent", e); + listener.onFailure(e); + }); + previousStepListener = nextStepListener; + } + } + if (toolSpecs.size() == 1) { + firstTool.run(firstToolExecuteParams, ActionListener.wrap(output -> { + MLToolSpec toolSpec = toolSpecs.get(0); + processOutput( + params, + listener, + memory, + memoryId, + parentInteractionId, + toolSpecs, + flowAgentOutput, + additionalInfo, + traceNumber, + memorySpec, + toolSpec, + 1, + output, + null + ); + }, e -> { listener.onFailure(e); })); + } else { + firstTool.run(firstToolExecuteParams, firstStepListener); + } + } + + private static List getMlToolSpecs(MLAgent mlAgent, String selectedToolsStr) { + List toolSpecs = mlAgent.getTools(); + if (selectedToolsStr != null) { + List selectedTools = gson.fromJson(selectedToolsStr, List.class); + Map toolNameSpecMap = new HashMap<>(); + for (MLToolSpec toolSpec : toolSpecs) { + toolNameSpecMap.put(getToolName(toolSpec), toolSpec); + } + List selectedToolSpecs = new ArrayList<>(); + for (String tool : selectedTools) { + if (toolNameSpecMap.containsKey(tool)) { + selectedToolSpecs.add(toolNameSpecMap.get(tool)); + } + } + toolSpecs = selectedToolSpecs; + } + return toolSpecs; + } + + private void processOutput( + Map params, + ActionListener listener, + ConversationIndexMemory memory, + String memoryId, + String parentInteractionId, + List toolSpecs, + List flowAgentOutput, + Map additionalInfo, + AtomicInteger traceNumber, + MLMemorySpec memorySpec, + MLToolSpec previousToolSpec, + int finalI, + Object output, + StepListener nextStepListener + ) throws IOException, + PrivilegedActionException { + String toolName = getToolName(previousToolSpec); + String outputKey = toolName + ".output"; + String outputResponse = parseResponse(output); + params.put(outputKey, escapeJson(outputResponse)); + + if (previousToolSpec.isIncludeOutputInAgentResponse() || finalI == toolSpecs.size()) { + if (output instanceof ModelTensorOutput) { + flowAgentOutput.addAll(((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors()); + } else { + String result = output instanceof String + ? (String) output + : AccessController.doPrivileged((PrivilegedExceptionAction) () -> StringUtils.toJson(output)); + + ModelTensor stepOutput = ModelTensor.builder().name(toolName).result(result).build(); + flowAgentOutput.add(stepOutput); + } + if (memory == null) { + additionalInfo.put(outputKey, outputResponse); + } + } + + if (finalI == toolSpecs.size()) { + ActionListener updateListener = ActionListener.wrap(r -> { + log.info("Updated additional info for interaction " + r.getId() + " of flow agent."); + listener.onResponse(flowAgentOutput); + }, e -> { + log.error("Failed to update root interaction", e); + listener.onResponse(flowAgentOutput); + }); + if (memory == null) { + if (memoryId == null || parentInteractionId == null || memorySpec == null || memorySpec.getType() == null) { + listener.onResponse(flowAgentOutput); + } else { + updateMemoryWithListener(additionalInfo, memorySpec, memoryId, parentInteractionId, updateListener); + } + } else { + saveMessage(params, memory, outputResponse, memoryId, parentInteractionId, toolName, traceNumber, ActionListener.wrap(r -> { + log.info("saved last trace for interaction " + parentInteractionId + " of flow agent"); + Map updateContent = Map.of(AI_RESPONSE_FIELD, outputResponse, ADDITIONAL_INFO_FIELD, additionalInfo); + memory.update(parentInteractionId, updateContent, updateListener); + }, e -> { + log.error("Failed to update root interaction ", e); + listener.onFailure(e); + })); + } + } else { + if (memory == null) { + runNextStep(params, toolSpecs, finalI, nextStepListener); + } else { + saveMessage(params, memory, outputResponse, memoryId, parentInteractionId, toolName, traceNumber, ActionListener.wrap(r -> { + runNextStep(params, toolSpecs, finalI, nextStepListener); + }, e -> { + log.error("Failed to update root interaction ", e); + listener.onFailure(e); + })); + } + } + } + + private void runNextStep(Map params, List toolSpecs, int finalI, StepListener nextStepListener) { + MLToolSpec toolSpec = toolSpecs.get(finalI); + Tool tool = createTool(toolSpec); + if (finalI < toolSpecs.size()) { + tool.run(getToolExecuteParams(toolSpec, params), nextStepListener); + } + } + + private void saveMessage( + Map params, + ConversationIndexMemory memory, + String outputResponse, + String memoryId, + String parentInteractionId, + String toolName, + AtomicInteger traceNumber, + ActionListener listener + ) { + ConversationIndexMessage finalMessage = ConversationIndexMessage + .conversationIndexMessageBuilder() + .type(memory.getType()) + .question(params.get(QUESTION)) + .response(outputResponse) + .finalAnswer(true) + .sessionId(memoryId) + .build(); + memory.save(finalMessage, parentInteractionId, traceNumber.addAndGet(1), toolName, listener); + } + + @VisibleForTesting + void updateMemoryWithListener( + Map additionalInfo, + MLMemorySpec memorySpec, + String memoryId, + String interactionId, + ActionListener listener + ) { + if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) { + return; + } + ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap + .get(memorySpec.getType()); + conversationIndexMemoryFactory + .create( + memoryId, + ActionListener + .wrap( + memory -> memory.update(interactionId, Map.of(ActionConstants.ADDITIONAL_INFO_FIELD, additionalInfo), listener), + e -> log.error("Failed create memory from id: " + memoryId, e) + ) + ); + } + + @VisibleForTesting + String parseResponse(Object output) throws IOException { + if (output instanceof List && !((List) output).isEmpty() && ((List) output).get(0) instanceof ModelTensors) { + ModelTensors tensors = (ModelTensors) ((List) output).get(0); + return tensors.toXContent(JsonXContent.contentBuilder(), null).toString(); + } else if (output instanceof ModelTensor) { + return ((ModelTensor) output).toXContent(JsonXContent.contentBuilder(), null).toString(); + } else if (output instanceof ModelTensorOutput) { + return ((ModelTensorOutput) output).toXContent(JsonXContent.contentBuilder(), null).toString(); + } else { + if (output instanceof String) { + return (String) output; + } else { + return StringUtils.toJson(output); + } + } + } + + @VisibleForTesting + Tool createTool(MLToolSpec toolSpec) { + Map toolParams = new HashMap<>(); + if (toolSpec.getParameters() != null) { + toolParams.putAll(toolSpec.getParameters()); + } + if (!toolFactories.containsKey(toolSpec.getType())) { + throw new IllegalArgumentException("Tool not found: " + toolSpec.getType()); + } + Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams); + if (toolSpec.getName() != null) { + tool.setName(toolSpec.getName()); + } + + if (toolSpec.getDescription() != null) { + tool.setDescription(toolSpec.getDescription()); + } + return tool; + } + + @VisibleForTesting + Map getToolExecuteParams(MLToolSpec toolSpec, Map params) { + Map executeParams = new HashMap<>(); + if (toolSpec.getParameters() != null) { + executeParams.putAll(toolSpec.getParameters()); + } + for (String key : params.keySet()) { + String toBeReplaced = null; + if (key.startsWith(toolSpec.getType() + ".")) { + toBeReplaced = toolSpec.getType() + "."; + } + if (toolSpec.getName() != null && key.startsWith(toolSpec.getName() + ".")) { + toBeReplaced = toolSpec.getName() + "."; + } + if (toBeReplaced != null) { + executeParams.put(key.replace(toBeReplaced, ""), params.get(key)); + } else { + executeParams.put(key, params.get(key)); + } + } + + if (executeParams.containsKey("input")) { + String input = executeParams.get("input"); + StringSubstitutor substitutor = new StringSubstitutor(executeParams, "${parameters.", "}"); + input = substitutor.replace(input); + executeParams.put("input", input); + } + return executeParams; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index 1d574abc86..674a1237c6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -71,6 +71,7 @@ public MLFlowAgentRunner( this.memoryFactoryMap = memoryFactoryMap; } + @Override public void run(MLAgent mlAgent, Map params, ActionListener listener) { List toolSpecs = mlAgent.getTools(); StepListener firstStepListener = null; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java index 8dcbe050bb..eb8947fdd1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java @@ -12,6 +12,7 @@ import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; @@ -140,6 +141,10 @@ public void getMessages(ActionListener listener) { memoryManager.getFinalInteractions(conversationId, LAST_N_INTERACTIONS, listener); } + public void getMessages(ActionListener listener, int size) { + memoryManager.getFinalInteractions(conversationId, size, listener); + } + @Override public void clear() { throw new RuntimeException("clear method is not supported in ConversationIndexMemory"); @@ -150,6 +155,10 @@ public void remove(String id) { throw new RuntimeException("remove method is not supported in ConversationIndexMemory"); } + public void update(String messageId, Map updateContent, ActionListener updateListener) { + getMemoryManager().updateInteraction(messageId, updateContent, updateListener); + } + public static class Factory implements Memory.Factory { private Client client; private MLIndicesHandler mlIndicesHandler; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 0518846a36..ff3be400dd 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -16,6 +16,8 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MESSAGE_HISTORY_LIMIT; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS; import java.util.Arrays; import java.util.HashMap; @@ -112,6 +114,8 @@ public class MLChatAgentRunnerTest { @Captor private ArgumentCaptor>> memoryInteractionCapture; @Captor + private ArgumentCaptor messageHistoryLimitCapture; + @Captor private ArgumentCaptor> conversationIndexMemoryCapture; @Captor private ArgumentCaptor> mlMemoryManagerCapture; @@ -132,7 +136,7 @@ public void setup() { ActionListener> listener = invocation.getArgument(0); listener.onResponse(generateInteractions(2)); return null; - }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture()); + }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture()); when(conversationIndexMemory.getConversationId()).thenReturn("conversation_id"); when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager); doAnswer(invocation -> { @@ -464,13 +468,15 @@ public void testChatHistoryExcludeOngoingQuestion() { interactionList.add(inProgressInteraction); listener.onResponse(interactionList); return null; - }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture()); + }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture()); HashMap params = new HashMap<>(); + params.put(MESSAGE_HISTORY_LIMIT, "5"); mlChatAgentRunner.run(mlAgent, params, agentActionListener); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); String chatHistory = params.get(MLChatAgentRunner.CHAT_HISTORY); Assert.assertFalse(chatHistory.contains("input-99")); + Assert.assertEquals(5, messageHistoryLimitCapture.getValue().intValue()); } @Test @@ -517,7 +523,7 @@ private void testInteractions(String maxInteraction) { interactionList.add(inProgressInteraction); listener.onResponse(interactionList); return null; - }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture()); + }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture()); HashMap params = new HashMap<>(); params.put("verbose", "true"); @@ -525,6 +531,7 @@ private void testInteractions(String maxInteraction) { Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); String chatHistory = params.get(MLChatAgentRunner.CHAT_HISTORY); Assert.assertFalse(chatHistory.contains("input-99")); + Assert.assertEquals(LAST_N_INTERACTIONS, messageHistoryLimitCapture.getValue().intValue()); } @Test @@ -545,7 +552,7 @@ public void testChatHistoryException() { ActionListener> listener = invocation.getArgument(0); listener.onFailure(new RuntimeException("Test Exception")); return null; - }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture()); + }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture()); HashMap params = new HashMap<>(); mlChatAgentRunner.run(mlAgent, params, agentActionListener);