diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java index 9d85d0b6cd..cb44ade93c 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java @@ -26,7 +26,7 @@ public class ConversationalIndexConstants { /** Version of the meta index schema */ public final static Integer META_INDEX_SCHEMA_VERSION = 1; /** Name of the conversational metadata index */ - public final static String META_INDEX_NAME = ".plugins-ml-conversation-meta"; + public final static String META_INDEX_NAME = ".plugins-ml-memory-meta"; /** Name of the metadata field for initial timestamp */ public final static String META_CREATED_TIME_FIELD = "create_time"; /** Name of the metadata field for updated timestamp */ @@ -64,7 +64,7 @@ public class ConversationalIndexConstants { /** Version of the interactions index schema */ public final static Integer INTERACTIONS_INDEX_SCHEMA_VERSION = 1; /** Name of the conversational interactions index */ - public final static String INTERACTIONS_INDEX_NAME = ".plugins-ml-conversation-interactions"; + public final static String INTERACTIONS_INDEX_NAME = ".plugins-ml-memory-message"; /** Name of the interaction field for the conversation Id */ public final static String INTERACTIONS_CONVERSATION_ID_FIELD = "conversation_id"; /** Name of the interaction field for the human input */ diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java index f5910119fa..de0d06b2bd 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java @@ -17,11 +17,14 @@ */ package org.opensearch.ml.memory.action.conversation; +import java.util.HashMap; import java.util.Map; import org.opensearch.OpenSearchException; +import org.opensearch.action.DocWriteResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -93,10 +96,9 @@ protected void doExecute(Task task, CreateInteractionRequest request, ActionList Integer traceNumber = request.getTraceNumber(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); - ActionListener al = ActionListener - .wrap(iid -> { internalListener.onResponse(new CreateInteractionResponse(iid)); }, e -> { - internalListener.onFailure(e); - }); + ActionListener al = ActionListener.wrap(iid -> { + cmHandler.updateConversation(cid, new HashMap<>(), getUpdateResponseListener(cid, iid, internalListener)); + }, e -> { internalListener.onFailure(e); }); if (parintIid == null || traceNumber == null) { cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al); } else { @@ -108,4 +110,34 @@ protected void doExecute(Task task, CreateInteractionRequest request, ActionList } } + private ActionListener getUpdateResponseListener( + String conversationId, + String interactionId, + ActionListener actionListener + ) { + return ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { + log + .debug( + "Successfully updated the Conversation with ID: {} after interaction {} is created", + conversationId, + interactionId + ); + actionListener.onResponse(new CreateInteractionResponse(interactionId)); + } else { + log.error("Failed to update the Conversation with ID: {} after interaction {} is created", conversationId, interactionId); + actionListener.onResponse(new CreateInteractionResponse(interactionId)); + } + }, exception -> { + log + .error( + "Failed to update Conversation with ID {} after interaction {} is created. Details: {}", + conversationId, + interactionId, + exception + ); + actionListener.onResponse(new CreateInteractionResponse(interactionId)); + }); + + } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java index 9f8c42f17b..995edb8b16 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java @@ -5,10 +5,14 @@ package org.opensearch.ml.memory.action.conversation; +import java.time.Instant; +import java.util.Map; + import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -36,9 +40,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener updateContent = updateConversationRequest.getUpdateContent(); + updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now()); + updateRequest.doc(updateContent); + updateRequest.docAsUpsert(true); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context)); } catch (Exception e) { diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index 74ba94c88c..b2b753651a 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -25,6 +25,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.action.support.WriteRequest; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -393,6 +394,7 @@ public void updateConversation(String conversationId, Map update updateRequest.doc(updateContent); updateRequest.docAsUpsert(true); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); conversationMetaIndex.updateConversation(updateRequest, listener); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java index eb8e4672ce..22ec5ce386 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java @@ -32,13 +32,17 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; +import org.opensearch.action.DocWriteResponse; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; @@ -78,6 +82,8 @@ public class CreateInteractionTransportActionTests extends OpenSearchTestCase { CreateInteractionRequest request; CreateInteractionTransportAction action; ThreadContext threadContext; + UpdateResponse updateResponse; + ShardId shardId; @Before public void setup() throws IOException { @@ -101,6 +107,9 @@ public void setup() throws IOException { Collections.singletonMap("metadata", "some meta") ); + shardId = new ShardId(new Index("indexName", "uuid"), 1); + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); this.threadContext = new ThreadContext(settings); when(this.client.threadPool()).thenReturn(this.threadPool); @@ -114,7 +123,46 @@ public void setup() throws IOException { } public void testCreateInteraction() { - log.info("testing create interaction transport"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(cmHandler).updateConversation(any(), any(), any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(6); + listener.onResponse("testID"); + return null; + }).when(cmHandler).createInteraction(any(), any(), any(), any(), any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateInteractionResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getId().equals("testID")); + } + + public void testCreateInteraction_WrongUpdateStatus() { + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(cmHandler).updateConversation(any(), any(), any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(6); + listener.onResponse("testID"); + return null; + }).when(cmHandler).createInteraction(any(), any(), any(), any(), any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateInteractionResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getId().equals("testID")); + } + + public void testCreateInteraction_UpdateException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("Update Conversation Exception")); + return null; + }).when(cmHandler).updateConversation(any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(6); listener.onResponse("testID"); @@ -138,6 +186,11 @@ public void testCreateInteraction_Trace() { 1 ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(cmHandler).updateConversation(any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(6); listener.onResponse("testID"); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java index ea713d99bb..f8476f9ccb 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java @@ -16,6 +16,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.HashMap; import java.util.Map; import org.junit.Before; @@ -79,7 +80,9 @@ public void setup() throws IOException { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); String conversationId = "test_conversation_id"; - Map updateContent = Map.of(META_NAME_FIELD, "new name", META_UPDATED_TIME_FIELD, Instant.ofEpochMilli(123)); + Map updateContent = new HashMap<>(); + updateContent.put(META_NAME_FIELD, "new name"); + updateContent.put(META_UPDATED_TIME_FIELD, Instant.ofEpochMilli(123)); when(updateRequest.getConversationId()).thenReturn(conversationId); when(updateRequest.getUpdateContent()).thenReturn(updateContent); shardId = new ShardId(new Index("indexName", "uuid"), 1); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index ef7b048a0a..e3841209c5 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -574,9 +574,7 @@ public void testGetConversation_NoIndex_ThenFail() { assert (argCaptor .getValue() .getMessage() - .equals( - "no such index [.plugins-ml-conversation-meta] and cannot get conversation since the conversation index does not exist" - )); + .equals("no such index [.plugins-ml-memory-meta] and cannot get conversation since the conversation index does not exist")); } public void testGetConversation_ResponseNotExist_ThenFail() { @@ -652,9 +650,7 @@ public void testUpdateConversation_NoIndex_ThenFail() { assert (argCaptor .getValue() .getMessage() - .equals( - "no such index [.plugins-ml-conversation-meta] and cannot update conversation since the conversation index does not exist" - )); + .equals("no such index [.plugins-ml-memory-meta] and cannot update conversation since the conversation index does not exist")); } public void testUpdateConversation_Success() { diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 70743aa9f3..007e78019d 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -731,9 +731,7 @@ public void testGetSg_NoIndex_ThenFail() { assert (argCaptor .getValue() .getMessage() - .equals( - "no such index [.plugins-ml-conversation-interactions] and cannot get interaction since the interactions index does not exist" - )); + .equals("no such index [.plugins-ml-memory-message] and cannot get interaction since the interactions index does not exist")); } public void testGetSg_InteractionNotExist_ThenFail() { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java index dc99ef4438..fd62ad87c1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -5,13 +5,31 @@ package org.opensearch.ml.engine.memory; +import static org.opensearch.ml.common.conversation.ActionConstants.TRACE_NUMBER_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INDEX_NAME; + import java.util.HashMap; +import java.util.LinkedList; import java.util.List; import java.util.Map; +import org.opensearch.OpenSearchSecurityException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.client.Requests; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.memory.action.conversation.CreateConversationAction; import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; @@ -19,15 +37,17 @@ import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; -import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; -import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; -import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; import org.opensearch.ml.memory.action.conversation.GetTracesAction; import org.opensearch.ml.memory.action.conversation.GetTracesRequest; import org.opensearch.ml.memory.action.conversation.GetTracesResponse; import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest; +import org.opensearch.ml.memory.index.ConversationMetaIndex; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import lombok.AllArgsConstructor; @@ -41,6 +61,8 @@ public class MLMemoryManager { private Client client; + private ClusterService clusterService; + private ConversationMetaIndex conversationMetaIndex; /** * Create a new Conversation @@ -106,24 +128,74 @@ public void createInteraction( } /** - * Get the interactions associate with this conversation that are not traces, sorted by recency + * Get the latest interactions associated with this conversation that are not traces, from oldest to newest * @param conversationId the conversation whose interactions to get * @param lastNInteraction Return how many interactions * @param actionListener get all the final interactions that are not traces */ public void getFinalInteractions(String conversationId, int lastNInteraction, ActionListener> actionListener) { - Preconditions.checkNotNull(conversationId); Preconditions.checkArgument(lastNInteraction > 0, "lastN must be at least 1."); log.debug("Getting Interactions, conversationId {}, lastN {}", conversationId, lastNInteraction); - ActionListener al = ActionListener.wrap(getInteractionsResponse -> { - actionListener.onResponse(getInteractionsResponse.getInteractions()); - }, e -> { actionListener.onFailure(e); }); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { + actionListener.onResponse(List.of()); + return; + } + ActionListener accessListener = ActionListener.wrap(access -> { + if (access) { + innerGetFinalInteractions(conversationId, lastNInteraction, actionListener); + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? "" : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } + }, e -> { actionListener.onFailure(e); }); + conversationMetaIndex.checkAccess(conversationId, accessListener); + } catch (Exception e) { + log.error("Failed to get final interactions for conversation " + conversationId, e); + actionListener.onFailure(e); + } + } - try { - client.execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, lastNInteraction), al); - } catch (Exception exception) { - actionListener.onFailure(exception); + @VisibleForTesting + void innerGetFinalInteractions(String conversationId, int lastNInteraction, ActionListener> listener) { + SearchRequest searchRequest = Requests.searchRequest(INTERACTIONS_INDEX_NAME); + + // Build the query + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + + // Add the ExistsQueryBuilder for checking null values + ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(TRACE_NUMBER_FIELD); + boolQueryBuilder.mustNot(existsQueryBuilder); + + // Add the TermQueryBuilder for another field + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); + boolQueryBuilder.must(termQueryBuilder); + + // Set the query to the search source + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(boolQueryBuilder); + searchRequest.source(searchSourceBuilder); + + searchRequest.source().size(lastNInteraction); + searchRequest.source().sort(INTERACTIONS_CREATE_TIME_FIELD, SortOrder.DESC); + + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + ActionListener al = ActionListener.wrap(response -> { + List result = new LinkedList(); + for (SearchHit hit : response.getHits()) { + result.add(0, Interaction.fromSearchHit(hit)); + } + internalListener.onResponse(result); + }, e -> { internalListener.onFailure(e); }); + client.search(searchRequest, al); + } catch (Exception e) { + listener.onFailure(e); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTest.java deleted file mode 100644 index a21a3ed60d..0000000000 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTest.java +++ /dev/null @@ -1,124 +0,0 @@ -package org.opensearch.ml.engine.memory; - -import static org.junit.Assert.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import java.util.List; -import java.util.Map; - -import org.junit.Before; -import org.junit.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.opensearch.action.update.UpdateResponse; -import org.opensearch.client.AdminClient; -import org.opensearch.client.Client; -import org.opensearch.client.IndicesAdminClient; -import org.opensearch.cluster.ClusterState; -import org.opensearch.cluster.metadata.Metadata; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.core.action.ActionListener; -import org.opensearch.ml.common.conversation.Interaction; -import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; -import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; -import org.opensearch.ml.memory.index.ConversationMetaIndex; -import org.opensearch.threadpool.ThreadPool; - -public class MLMemoryManagerTest { - - @Mock - Client client; - - @Mock - AdminClient adminClient; - - @Mock - IndicesAdminClient indicesAdminClient; - - @Mock - ClusterService clusterService; - - @Mock - ClusterState clusterState; - - @Mock - Metadata metadata; - - @Mock - ConversationMetaIndex conversationMetaIndex; - - @Mock - private ThreadPool threadPool; - - MLMemoryManager memoryManager; - Settings settings; - ThreadContext threadContext; - - @Before - public void setUp() { - MockitoAnnotations.openMocks(this); - memoryManager = new MLMemoryManager(client); - doNothing().when(client).execute(any(), any(), any()); - doNothing().when(client).update(any(), any()); - when(client.admin()).thenReturn(adminClient); - when(adminClient.indices()).thenReturn(indicesAdminClient); - doNothing().when(indicesAdminClient).refresh(any(), any()); - doNothing().when(conversationMetaIndex).checkAccess(any(), any()); - when(clusterService.state()).thenReturn(clusterState); - when(clusterState.metadata()).thenReturn(metadata); - when(metadata.hasIndex(anyString())).thenReturn(true); - settings = Settings.builder().put("test_key", 10).build(); - threadContext = new ThreadContext(settings); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - } - - @Test - public void createConversation() { - ActionListener actionListener = mock(ActionListener.class); - memoryManager.createConversation("test", "test", actionListener); - } - - @Test - public void createInteraction() { - ActionListener actionListener = mock(ActionListener.class); - memoryManager.createInteraction("test", "test", "test", "test", "test", Map.of("feedback", "1"), "test", 0, actionListener); - } - - @Test - public void createInteractionNullAdditionalInfo() { - ActionListener actionListener = mock(ActionListener.class); - memoryManager.createInteraction("test", "test", "test", "test", "test", null, "test", 0, actionListener); - } - - @Test - public void getFinalInteractions() { - ActionListener> actionListener = mock(ActionListener.class); - memoryManager.getFinalInteractions("test", 1, actionListener); - } - - @Test - public void getTracesIndex() { - ActionListener> actionListener = mock(ActionListener.class); - memoryManager.getTraces("test", actionListener); - } - - @Test - public void getTracesNoIndex() { - ActionListener> actionListener = mock(ActionListener.class); - when(metadata.hasIndex(anyString())).thenReturn(false); - memoryManager.getTraces("test", actionListener); - } - - @Test - public void updateInteraction() { - ActionListener actionListener = mock(ActionListener.class); - memoryManager.updateInteraction("test", Map.of("feedback", "1"), actionListener); - } -} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java index b3a5f0da56..aa9a2ba75d 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java @@ -8,8 +8,10 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -27,11 +29,26 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.AdminClient; import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.memory.action.conversation.CreateConversationAction; import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; @@ -39,23 +56,46 @@ import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; -import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; -import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; -import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; import org.opensearch.ml.memory.action.conversation.GetTracesAction; import org.opensearch.ml.memory.action.conversation.GetTracesRequest; import org.opensearch.ml.memory.action.conversation.GetTracesResponse; import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest; +import org.opensearch.ml.memory.index.ConversationMetaIndex; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.threadpool.ThreadPool; public class MLMemoryManagerTests { @Mock Client client; + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + @Mock MLMemoryManager mlMemoryManager; + @Mock + ConversationMetaIndex conversationMetaIndex; + + @Mock + Metadata metadata; + + @Mock + AdminClient adminClient; + + @Mock + IndicesAdminClient indicesAdminClient; + + @Mock + ThreadPool threadPool; + @Mock ActionListener createConversationResponseActionListener; @@ -74,9 +114,15 @@ public class MLMemoryManagerTests { @Before public void setUp() { MockitoAnnotations.openMocks(this); - mlMemoryManager = new MLMemoryManager(client); + mlMemoryManager = new MLMemoryManager(client, clusterService, conversationMetaIndex); conversationName = "new conversation"; applicationType = "ml application"; + doReturn(clusterState).when(clusterService).state(); + doReturn(metadata).when(clusterState).metadata(); + doReturn(adminClient).when(client).admin(); + doReturn(indicesAdminClient).when(adminClient).indices(); + doReturn(threadPool).when(client).threadPool(); + doReturn(new ThreadContext(Settings.EMPTY)).when(threadPool).getThreadContext(); } @Test @@ -159,39 +205,111 @@ public void testCreateInteractionFails_thenFail() { } @Test - public void testGetInteractions() { - List interactions = List - .of( - new Interaction( - "id0", - Instant.now(), - "cid", - "input", - "pt", - "response", - "origin", - Collections.singletonMap("metadata", "some meta") - ) - ); - ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); + public void testGetInteractions_NoIndex_ThenEmpty() { + doReturn(false).when(metadata).hasIndex(anyString()); + + mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(interactionListActionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().size() == 0); + } + + @Test + public void testGetInteractions_SearchFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); doAnswer(invocation -> { - ActionListener al = invocation.getArgument(2); - GetInteractionsResponse getInteractionsResponse = new GetInteractionsResponse(interactions, 4, false); - al.onResponse(getInteractionsResponse); + ActionListener al = invocation.getArgument(1); + al.onResponse(true); return null; - }).when(client).execute(any(), any(), any()); + }).when(conversationMetaIndex).checkAccess(anyString(), any()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Failure in Search")); + return null; + }).when(client).search(any(), any()); mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(interactionListActionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in Search")); + } - verify(client, times(1)).execute(eq(GetInteractionsAction.INSTANCE), captor.capture(), any()); - assertEquals("cid", captor.getValue().getConversationId()); - assertEquals(0, captor.getValue().getFrom()); - assertEquals(10, captor.getValue().getMaxResults()); + @Test + public void testGetInteractions_NoAccessNoUser_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + String userstr = ""; + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(false); + return null; + }).when(conversationMetaIndex).checkAccess(anyString(), any()); + + doAnswer(invocation -> { + ThreadContext tc = new ThreadContext(Settings.EMPTY); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + return tc; + }).when(threadPool).getThreadContext(); + mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(interactionListActionListener, times(1)).onFailure(argCaptor.capture()); + System.out.println(argCaptor.getValue().getMessage()); + assert (argCaptor.getValue().getMessage().equals("User [] does not have access to conversation cid")); + } + + @Test + public void testGetInteractions_Success() { + doReturn(true).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(true); + return null; + }).when(conversationMetaIndex).checkAccess(anyString(), any()); + + doAnswer(invocation -> { + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); + content.startObject(); + content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now()); + content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs"); + content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id"); + content.endObject(); + + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content)); + SearchHits searchHits = new SearchHits(hits, null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections( + searchHits, + InternalAggregations.EMPTY, + null, + false, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + ActionListener al = invocation.getArgument(1); + al.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(interactionListActionListener, times(1)).onResponse(argCaptor.capture()); + assertEquals(1, argCaptor.getValue().size()); } @Test public void testGetInteractionFails_thenFail() { - doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + doThrow(new RuntimeException("Failure in runtime")).when(threadPool).getThreadContext(); mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(interactionListActionListener).onFailure(argCaptor.capture());