From a49fca6e7be1ddd374dafefd4bb5e23939fb7466 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Sat, 30 Dec 2023 01:46:16 +0800 Subject: [PATCH] support regenerate for chatbot (#1816) * support regenerate for chatbot Signed-off-by: Hailong Cui * update test method name Signed-off-by: Hailong Cui * exclude MLModelCacheHelper for jacoco coverage check Signed-off-by: Hailong Cui * Address review comments Signed-off-by: Hailong Cui * Address review comments Signed-off-by: Hailong Cui --------- Signed-off-by: Hailong Cui --- .../algorithms/agent/MLAgentExecutor.java | 99 ++++++++++--- .../ml/engine/memory/MLMemoryManager.java | 63 +++++++++ .../algorithms/agent/MLAgentExecutorTest.java | 133 ++++++++++++++++++ .../engine/memory/MLMemoryManagerTests.java | 71 ++++++++++ plugin/build.gradle | 4 +- 5 files changed, 347 insertions(+), 23 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index c1078d24c5..29c86bffc8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -45,6 +45,8 @@ import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionRequest; import com.google.common.annotations.VisibleForTesting; import com.google.gson.Gson; @@ -62,6 +64,7 @@ public class MLAgentExecutor implements Executable { public static final String MEMORY_ID = "memory_id"; public static final String QUESTION = "question"; public static final String PARENT_INTERACTION_ID = "parent_interaction_id"; + public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id"; private Client client; private Settings settings; @@ -113,9 +116,14 @@ public void execute(Input input, ActionListener listener) { MLMemorySpec memorySpec = mlAgent.getMemory(); String memoryId = inputDataSet.getParameters().get(MEMORY_ID); String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); + String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID); String appType = mlAgent.getAppType(); String question = inputDataSet.getParameters().get(QUESTION); + if (memoryId == null && regenerateInteractionId != null) { + throw new IllegalArgumentException("A memory ID must be provided to regenerate."); + } + if (memorySpec != null && memorySpec.getType() != null && memoryFactoryMap.containsKey(memorySpec.getType()) @@ -124,28 +132,27 @@ public void execute(Input input, ActionListener listener) { (ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType()); conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListener.wrap(memory -> { inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId()); - // Create root interaction ID - ConversationIndexMessage msg = ConversationIndexMessage - .conversationIndexMessageBuilder() - .type(appType) - .question(question) - .response("") - .finalAnswer(true) - .sessionId(memory.getConversationId()) - .build(); - memory.save(msg, null, null, null, ActionListener.wrap(interaction -> { - log.info("Created parent interaction ID: " + interaction.getId()); - inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interaction.getId()); - ActionListener agentActionListener = createAgentActionListener( - listener, - outputs, - modelTensors - ); - executeAgent(inputDataSet, mlAgent, agentActionListener); - }, ex -> { - log.error("Failed to create parent interaction", ex); - listener.onFailure(ex); - })); + ActionListener agentActionListener = createAgentActionListener(listener, outputs, modelTensors); + // get question for regenerate + if (regenerateInteractionId != null) { + log.info("Regenerate for existing interaction {}", regenerateInteractionId); + client + .execute( + GetInteractionAction.INSTANCE, + new GetInteractionRequest(memoryId, regenerateInteractionId), + ActionListener.wrap(interactionRes -> { + inputDataSet + .getParameters() + .putIfAbsent(QUESTION, interactionRes.getInteraction().getInput()); + saveRootInteractionAndExecute(agentActionListener, memory, inputDataSet, mlAgent); + }, e -> { + log.error("Failed to get existing interaction for regeneration", e); + listener.onFailure(e); + }) + ); + } else { + saveRootInteractionAndExecute(agentActionListener, memory, inputDataSet, mlAgent); + } }, ex -> { log.error("Failed to read conversation memory", ex); listener.onFailure(ex); @@ -167,6 +174,54 @@ public void execute(Input input, ActionListener listener) { } + /** + * save root interaction and start execute the agent + * @param listener callback listener + * @param memory memory instance + * @param inputDataSet input + * @param mlAgent agent to run + */ + private void saveRootInteractionAndExecute( + ActionListener listener, + ConversationIndexMemory memory, + RemoteInferenceInputDataSet inputDataSet, + MLAgent mlAgent + ) { + String appType = mlAgent.getAppType(); + String question = inputDataSet.getParameters().get(QUESTION); + String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID); + // Create root interaction ID + ConversationIndexMessage msg = ConversationIndexMessage + .conversationIndexMessageBuilder() + .type(appType) + .question(question) + .response("") + .finalAnswer(true) + .sessionId(memory.getConversationId()) + .build(); + memory.save(msg, null, null, null, ActionListener.wrap(interaction -> { + log.info("Created parent interaction ID: " + interaction.getId()); + inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interaction.getId()); + // only delete previous interaction when new interaction created + if (regenerateInteractionId != null) { + memory + .getMemoryManager() + .deleteInteractionAndTrace( + regenerateInteractionId, + ActionListener.wrap(deleted -> executeAgent(inputDataSet, mlAgent, listener), e -> { + log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e); + listener.onFailure(e); + }) + ); + } else { + executeAgent(inputDataSet, mlAgent, listener); + } + }, ex -> { + log.error("Failed to create parent interaction", ex); + listener.onFailure(ex); + })); + } + private void executeAgent(RemoteInferenceInputDataSet inputDataSet, MLAgent mlAgent, ActionListener agentActionListener) { MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent); mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java index fd62ad87c1..6219a4b3a6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -28,8 +28,13 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.memory.action.conversation.CreateConversationAction; import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; @@ -233,4 +238,62 @@ public void updateInteraction(String interactionId, Map updateCo actionListener.onFailure(exception); } } + + /** + * Delete interaction and its trace data + * @param interactionId interaction id + * @param listener callback for delete result + */ + public void deleteInteractionAndTrace(String interactionId, ActionListener listener) { + DeleteByQueryRequest deleteByQueryRequest = new DeleteByQueryRequest(INTERACTIONS_INDEX_NAME); + deleteByQueryRequest.setQuery(buildDeleteInteractionQuery(interactionId)); + deleteByQueryRequest.setRefresh(true); + + innerDeleteInteractionAndTrace(deleteByQueryRequest, interactionId, listener); + } + + @VisibleForTesting + void innerDeleteInteractionAndTrace(DeleteByQueryRequest deleteByQueryRequest, String interactionId, ActionListener listener) { + try (ThreadContext.StoredContext ignored = client.threadPool().getThreadContext().stashContext()) { + ActionListener al = ActionListener.wrap(bulkResponse -> { + if (bulkResponse != null && (!bulkResponse.getBulkFailures().isEmpty() || !bulkResponse.getSearchFailures().isEmpty())) { + log.info("Failed to delete the interaction with ID: {}", interactionId); + listener.onResponse(false); + return; + } + log.info("Successfully delete the interaction with ID: {}", interactionId); + listener.onResponse(true); + }, exception -> { + log.error("Failed to delete interaction with ID {}. Details: {}", interactionId, exception); + listener.onFailure(exception); + }); + // bulk delete interaction and its trace + client.execute(DeleteByQueryAction.INSTANCE, deleteByQueryRequest, al); + } catch (Exception e) { + log.error("Failed to delete interaction with ID {}. Details {}:", interactionId, e); + listener.onFailure(e); + } + } + + @VisibleForTesting + QueryBuilder buildDeleteInteractionQuery(String interactionId) { + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + // interaction itself + boolQueryBuilder.should(QueryBuilders.idsQuery().addIds(interactionId)); + + // Build the trace query + BoolQueryBuilder traceBoolBuilder = QueryBuilders.boolQuery(); + // Add the ExistsQueryBuilder for checking null values + ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD); + traceBoolBuilder.must(existsQueryBuilder); + + // Add the TermQueryBuilder for another field + TermQueryBuilder termQueryBuilder = QueryBuilders + .termQuery(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, interactionId); + traceBoolBuilder.must(termQueryBuilder); + + // interaction trace + boolQueryBuilder.should(traceBoolBuilder); + return boolQueryBuilder; + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java index a843f15b6d..d2ecb58f7a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java @@ -5,7 +5,11 @@ package org.opensearch.ml.engine.algorithms.agent; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.when; +import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MEMORY_ID; +import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; +import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.REGENERATE_INTERACTION_ID; import java.io.IOException; import java.util.Arrays; @@ -23,6 +27,7 @@ import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterState; @@ -40,6 +45,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; @@ -50,7 +56,10 @@ import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.engine.memory.ConversationIndexMemory; +import org.opensearch.ml.engine.memory.MLMemoryManager; import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionResponse; import org.opensearch.threadpool.ThreadPool; import com.google.gson.Gson; @@ -85,6 +94,11 @@ public class MLAgentExecutorTest { private ActionListener agentActionListener; @Mock private MLAgentRunner mlAgentRunner; + + @Mock + private ConversationIndexMemory memory; + @Mock + private MLMemoryManager memoryManager; private MLAgentExecutor mlAgentExecutor; @Captor @@ -114,6 +128,7 @@ public void setup() { Mockito.when(clusterService.state()).thenReturn(clusterState); Mockito.when(clusterState.metadata()).thenReturn(metadata); Mockito.when(metadata.hasIndex(Mockito.anyString())).thenReturn(true); + Mockito.when(memory.getMemoryManager()).thenReturn(memoryManager); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); @@ -306,6 +321,124 @@ public void test_CreateConversation_ReturnsResult() { Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); } + @Test + public void test_Regenerate_Validation() { + Map params = new HashMap<>(); + params.put(REGENERATE_INTERACTION_ID, "foo"); + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); + AgentMLInput agentMLInput = new AgentMLInput("test", FunctionName.AGENT, dataset); + Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + mlAgentExecutor.execute(agentMLInput, agentActionListener); + + Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); + Exception exception = exceptionCaptor.getValue(); + Assert.assertTrue(exception instanceof IllegalArgumentException); + Assert.assertEquals(exception.getMessage(), "A memory ID must be provided to regenerate."); + } + + @Test + public void test_Regenerate_GetOriginalInteraction() { + ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(modelTensor); + return null; + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + + CreateInteractionResponse interaction = Mockito.mock(CreateInteractionResponse.class); + Mockito.when(interaction.getId()).thenReturn("interaction_id"); + Mockito.doAnswer(invocation -> { + ActionListener responseActionListener = invocation.getArgument(4); + responseActionListener.onResponse(interaction); + return null; + }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); + + Mockito.doAnswer(invocation -> { + Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); + ActionListener listener = invocation.getArgument(3); + listener.onResponse(memory); + return null; + }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); + + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(Boolean.TRUE); + return null; + }).when(memoryManager).deleteInteractionAndTrace(Mockito.anyString(), Mockito.any()); + + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + GetInteractionResponse interactionResponse = Mockito.mock(GetInteractionResponse.class); + Interaction mockInteraction = Mockito.mock(Interaction.class); + Mockito.when(mockInteraction.getInput()).thenReturn("regenerate question"); + Mockito.when(interactionResponse.getInteraction()).thenReturn(mockInteraction); + listener.onResponse(interactionResponse); + return null; + }).when(client).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); + + String interactionId = "bar-interaction"; + Map params = new HashMap<>(); + params.put(MEMORY_ID, "foo-memory"); + params.put(REGENERATE_INTERACTION_ID, interactionId); + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); + AgentMLInput agentMLInput = new AgentMLInput("test", FunctionName.AGENT, dataset); + Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + mlAgentExecutor.execute(agentMLInput, agentActionListener); + + Mockito.verify(client, times(1)).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); + Assert.assertEquals(params.get(QUESTION), "regenerate question"); + // original interaction got deleted + Mockito.verify(memoryManager, times(1)).deleteInteractionAndTrace(Mockito.eq(interactionId), Mockito.any()); + } + + @Test + public void test_Regenerate_OriginalInteraction_NotExist() { + ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); + ConversationIndexMemory memory = Mockito.mock(ConversationIndexMemory.class); + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(modelTensor); + return null; + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + + CreateInteractionResponse interaction = Mockito.mock(CreateInteractionResponse.class); + Mockito.when(interaction.getId()).thenReturn("interaction_id"); + Mockito.doAnswer(invocation -> { + ActionListener responseActionListener = invocation.getArgument(4); + responseActionListener.onResponse(interaction); + return null; + }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); + + Mockito.doAnswer(invocation -> { + Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); + ActionListener listener = invocation.getArgument(3); + listener.onResponse(memory); + return null; + }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); + + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new ResourceNotFoundException("Interaction bar-interaction not found")); + return null; + }).when(client).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); + + Map params = new HashMap<>(); + params.put(MEMORY_ID, "foo-memory"); + params.put(REGENERATE_INTERACTION_ID, "bar-interaction"); + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); + AgentMLInput agentMLInput = new AgentMLInput("test", FunctionName.AGENT, dataset); + Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + mlAgentExecutor.execute(agentMLInput, agentActionListener); + + Mockito.verify(client, times(1)).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); + Assert.assertNull(params.get(QUESTION)); + + Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); + Exception exception = exceptionCaptor.getValue(); + Assert.assertTrue(exception instanceof ResourceNotFoundException); + Assert.assertEquals(exception.getMessage(), "Interaction bar-interaction not found"); + } + @Test public void test_CreateFlowAgent() { MLAgent mlAgent = MLAgent.builder().name("test_agent").type("flow").build(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java index a11860db8b..234a8f856e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java @@ -22,12 +22,15 @@ import java.util.List; import java.util.Map; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.bulk.BulkItemResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; @@ -43,10 +46,16 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.memory.action.conversation.CreateConversationAction; @@ -107,6 +116,9 @@ public class MLMemoryManagerTests { @Mock ActionListener updateResponseActionListener; + @Mock + ActionListener deletionInteractionListener; + String conversationName; String applicationType; @@ -398,4 +410,63 @@ public void testUpdateInteraction_thenFail() { verify(updateResponseActionListener).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); } + + @Test + public void testBuildTraceQuery() { + QueryBuilder queryBuilder = mlMemoryManager.buildDeleteInteractionQuery("interaction-id-1"); + String query = Strings.toString(XContentType.JSON, queryBuilder); + Assert + .assertEquals( + "{\"bool\":{\"should\":[{\"ids\":{\"values\":[\"interaction-id-1\"],\"boost\":1.0}},{\"bool\":{\"must\":[{\"exists\":{\"field\":\"trace_number\",\"boost\":1.0}},{\"term\":{\"parent_interaction_id\":{\"value\":\"interaction-id-1\",\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}}],\"adjust_pure_negative\":true,\"boost\":1.0}}", + query + ); + } + + @Test + public void testDeleteInteraction() { + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse bulkByScrollResponse = Mockito.mock(BulkByScrollResponse.class); + Mockito.when(bulkByScrollResponse.getBulkFailures()).thenReturn(List.of()); + Mockito.when(bulkByScrollResponse.getSearchFailures()).thenReturn(List.of()); + listener.onResponse(bulkByScrollResponse); + return null; + }).when(client).execute(Mockito.eq(DeleteByQueryAction.INSTANCE), Mockito.any(DeleteByQueryRequest.class), Mockito.any()); + + mlMemoryManager.deleteInteractionAndTrace("test-interaction", deletionInteractionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + Mockito.verify(deletionInteractionListener, times(1)).onResponse(argumentCaptor.capture()); + Assert.assertTrue(argumentCaptor.getValue()); + } + + @Test + public void testDeleteInteractionFailed() { + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse bulkByScrollResponse = Mockito.mock(BulkByScrollResponse.class); + Mockito.when(bulkByScrollResponse.getBulkFailures()).thenReturn(List.of(Mockito.mock(BulkItemResponse.Failure.class))); + Mockito.when(bulkByScrollResponse.getSearchFailures()).thenReturn(List.of()); + listener.onResponse(bulkByScrollResponse); + return null; + }).when(client).execute(Mockito.eq(DeleteByQueryAction.INSTANCE), Mockito.any(DeleteByQueryRequest.class), Mockito.any()); + + mlMemoryManager.deleteInteractionAndTrace("test-interaction", deletionInteractionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + Mockito.verify(deletionInteractionListener, times(1)).onResponse(argumentCaptor.capture()); + Assert.assertFalse(argumentCaptor.getValue()); + } + + @Test + public void testDeleteInteractionException() { + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new IndexNotFoundException("test-index")); + return null; + }).when(client).execute(Mockito.eq(DeleteByQueryAction.INSTANCE), Mockito.any(DeleteByQueryRequest.class), Mockito.any()); + + mlMemoryManager.deleteInteractionAndTrace("test-interaction", deletionInteractionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + Mockito.verify(deletionInteractionListener, times(1)).onFailure(argumentCaptor.capture()); + Assert.assertTrue(argumentCaptor.getValue() instanceof IndexNotFoundException); + } } diff --git a/plugin/build.gradle b/plugin/build.gradle index a6fbbf1851..ba544abe24 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -298,7 +298,9 @@ List jacocoExclusions = [ 'org.opensearch.ml.cluster.MLSyncUpCron', 'org.opensearch.ml.model.MLModelGroupManager', 'org.opensearch.ml.helper.ModelAccessControlHelper', - 'org.opensearch.ml.action.models.DeleteModelTransportAction.2' + 'org.opensearch.ml.action.models.DeleteModelTransportAction.2', + 'org.opensearch.ml.model.MLModelCacheHelper', + 'org.opensearch.ml.model.MLModelCacheHelper.1' ] jacocoTestCoverageVerification {