diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index 64b4f5267f..e36c296066 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -270,17 +270,17 @@ public void checkAccess(String conversationId, ActionListener listener) String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - // If security is off - User doesn't exist - you have permission - if (userstr == null || User.parse(userstr) == null) { - internalListener.onResponse(true); - return; - } GetRequest getRequest = Requests.getRequest(indexName).id(conversationId); ActionListener al = ActionListener.wrap(getResponse -> { // If the conversation doesn't exist, fail if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) { throw new ResourceNotFoundException("Conversation [" + conversationId + "] not found"); } + // If security is off - User doesn't exist - you have permission + if (userstr == null || User.parse(userstr) == null) { + internalListener.onResponse(true); + return; + } ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap()); String user = User.parse(userstr).getName(); // If you're not the owner of this conversation, you do not have permission @@ -290,7 +290,13 @@ public void checkAccess(String conversationId, ActionListener listener) } internalListener.onResponse(true); }, e -> { internalListener.onFailure(e); }); - client.get(getRequest, al); + client + .admin() + .indices() + .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.get(getRequest, al); }, e -> { + log.error("Failed to refresh conversations index during check access ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index d2c70ff6e7..6b33533ee2 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -33,9 +33,12 @@ import com.google.common.annotations.VisibleForTesting; +import lombok.extern.log4j.Log4j2; + /** * Class for handling all Conversational Memory operactions */ +@Log4j2 public class OpenSearchConversationalMemoryHandler implements ConversationalMemoryHandler { private ConversationMetaIndex conversationMetaIndex; @@ -247,19 +250,25 @@ public ActionFuture> getConversations(int maxResults) { public void deleteConversation(String conversationId, ActionListener listener) { StepListener accessListener = new StepListener<>(); conversationMetaIndex.checkAccess(conversationId, accessListener); - + log.info("DELETING CONVERSATION " + conversationId); accessListener.whenComplete(access -> { if (access) { StepListener metaDeleteListener = new StepListener<>(); StepListener interactionsListener = new StepListener<>(); - conversationMetaIndex.deleteConversation(conversationId, metaDeleteListener); interactionsIndex.deleteConversation(conversationId, interactionsListener); - metaDeleteListener.whenComplete(metaResult -> { - interactionsListener - .whenComplete(interactionResult -> { listener.onResponse(metaResult && interactionResult); }, listener::onFailure); + interactionsListener + .whenComplete( + interactionResult -> { conversationMetaIndex.deleteConversation(conversationId, metaDeleteListener); }, + listener::onFailure + ); + + metaDeleteListener.whenComplete(metaDeleteResult -> { + log.info("SUCCESSFUL DELETION OF CONVERSATION " + conversationId); + listener.onResponse(metaDeleteResult && interactionsListener.result()); }, listener::onFailure); + } else { listener.onResponse(false); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java b/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java index 7c842bdcac..6ee0d4cc31 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java @@ -249,16 +249,18 @@ public void testCanDeleteConversations() { }); StepListener> inters2 = new StepListener<>(); - inters1.whenComplete(ints -> { cmHandler.getInteractions(cid2.result(), 0, 10, inters2); }, e -> { + inters1.whenComplete(ints -> { cdl.countDown(); assert (false); + }, e -> { + assert (e.getMessage().startsWith("Conversation [")); + cmHandler.getInteractions(cid2.result(), 0, 10, inters2); }); LatchedActionListener> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(r -> { assert (del.result()); assert (conversations.result().size() == 1); assert (conversations.result().get(0).getId().equals(cid2.result())); - assert (inters1.result().size() == 0); assert (inters2.result().size() == 1); assert (inters2.result().get(0).getId().equals(iid3.result())); }, e -> { assert (false); }), cdl); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index c85384407e..821d801cdf 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -404,6 +404,7 @@ public void testDelete_DeleteFails_ThenFail() { public void testCheckAccess_DoesNotExist_ThenFail() { setupUser("user"); + setupRefreshSuccess(); doReturn(true).when(metadata).hasIndex(anyString()); GetResponse response = mock(GetResponse.class); doReturn(false).when(response).isExists(); @@ -423,6 +424,7 @@ public void testCheckAccess_DoesNotExist_ThenFail() { public void testCheckAccess_WrongId_ThenFail() { setupUser("user"); + setupRefreshSuccess(); doReturn(true).when(metadata).hasIndex(anyString()); GetResponse response = mock(GetResponse.class); doReturn(true).when(response).isExists(); @@ -443,6 +445,7 @@ public void testCheckAccess_WrongId_ThenFail() { public void testCheckAccess_GetFails_ThenFail() { setupUser("user"); + setupRefreshSuccess(); doReturn(true).when(metadata).hasIndex(anyString()); doAnswer(invocation -> { ActionListener al = invocation.getArgument(1); @@ -459,6 +462,7 @@ public void testCheckAccess_GetFails_ThenFail() { public void testCheckAccess_ClientFails_ThenFail() { setupUser("user"); + setupRefreshSuccess(); doReturn(true).when(metadata).hasIndex(anyString()); doThrow(new RuntimeException("Client Test Fail")).when(client).get(any(), any()); @SuppressWarnings("unchecked") diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java index 14995ddb5b..2eb7589696 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java @@ -26,6 +26,7 @@ import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Before; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.ml.settings.MLCommonsSettings; @@ -163,15 +164,20 @@ public void testDeleteConversation_WithInteractions() throws IOException { assert (!gcmap.containsKey("next_token")); assert (((ArrayList) gcmap.get("conversations")).size() == 0); - Response giresponse = TestHelper - .makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null); - assert (giresponse != null); - assert (TestHelper.restStatus(giresponse) == RestStatus.OK); - HttpEntity gihttpEntity = giresponse.getEntity(); - String gientityString = TestHelper.httpEntityToString(gihttpEntity); - Map gimap = gson.fromJson(gientityString, Map.class); - assert (gimap.containsKey("interactions")); - assert (!gimap.containsKey("next_token")); - assert (((ArrayList) gimap.get("interactions")).size() == 0); + try { + Response giresponse = TestHelper + .makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null); + assert (giresponse != null); + assert (TestHelper.restStatus(giresponse) == RestStatus.OK); + HttpEntity gihttpEntity = giresponse.getEntity(); + String gientityString = TestHelper.httpEntityToString(gihttpEntity); + Map gimap = gson.fromJson(gientityString, Map.class); + assert (gimap.containsKey("interactions")); + assert (!gimap.containsKey("next_token")); + assert (((ArrayList) gimap.get("interactions")).size() == 0); + assert (false); + } catch (ResponseException e) { + assert (TestHelper.restStatus(e.getResponse()) == RestStatus.NOT_FOUND); + } } }