Skip to content

Commit

Permalink
2.x consistent get interactions (#1334)
Browse files Browse the repository at this point in the history
* consistent getInteractions response when security/no security

Signed-off-by: HenryL27 <[email protected]>

* fix deletion race condition

Signed-off-by: HenryL27 <[email protected]>

* cleanup

Signed-off-by: HenryL27 <[email protected]>

---------

Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 authored Sep 13, 2023
1 parent d21b032 commit 24a629b
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,17 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
// If security is off - User doesn't exist - you have permission
if (userstr == null || User.parse(userstr) == null) {
internalListener.onResponse(true);
return;
}
GetRequest getRequest = Requests.getRequest(indexName).id(conversationId);
ActionListener<GetResponse> al = ActionListener.wrap(getResponse -> {
// If the conversation doesn't exist, fail
if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) {
throw new ResourceNotFoundException("Conversation [" + conversationId + "] not found");
}
// If security is off - User doesn't exist - you have permission
if (userstr == null || User.parse(userstr) == null) {
internalListener.onResponse(true);
return;
}
ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap());
String user = User.parse(userstr).getName();
// If you're not the owner of this conversation, you do not have permission
Expand All @@ -290,7 +290,13 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
}
internalListener.onResponse(true);
}, e -> { internalListener.onFailure(e); });
client.get(getRequest, al);
client
.admin()
.indices()
.refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.get(getRequest, al); }, e -> {
log.error("Failed to refresh conversations index during check access ", e);
internalListener.onFailure(e);
}));
} catch (Exception e) {
listener.onFailure(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@

import com.google.common.annotations.VisibleForTesting;

import lombok.extern.log4j.Log4j2;

/**
* Class for handling all Conversational Memory operactions
*/
@Log4j2
public class OpenSearchConversationalMemoryHandler implements ConversationalMemoryHandler {

private ConversationMetaIndex conversationMetaIndex;
Expand Down Expand Up @@ -247,19 +250,25 @@ public ActionFuture<List<ConversationMeta>> getConversations(int maxResults) {
public void deleteConversation(String conversationId, ActionListener<Boolean> listener) {
StepListener<Boolean> accessListener = new StepListener<>();
conversationMetaIndex.checkAccess(conversationId, accessListener);

log.info("DELETING CONVERSATION " + conversationId);
accessListener.whenComplete(access -> {
if (access) {
StepListener<Boolean> metaDeleteListener = new StepListener<>();
StepListener<Boolean> interactionsListener = new StepListener<>();

conversationMetaIndex.deleteConversation(conversationId, metaDeleteListener);
interactionsIndex.deleteConversation(conversationId, interactionsListener);

metaDeleteListener.whenComplete(metaResult -> {
interactionsListener
.whenComplete(interactionResult -> { listener.onResponse(metaResult && interactionResult); }, listener::onFailure);
interactionsListener
.whenComplete(
interactionResult -> { conversationMetaIndex.deleteConversation(conversationId, metaDeleteListener); },
listener::onFailure
);

metaDeleteListener.whenComplete(metaDeleteResult -> {
log.info("SUCCESSFUL DELETION OF CONVERSATION " + conversationId);
listener.onResponse(metaDeleteResult && interactionsListener.result());
}, listener::onFailure);

} else {
listener.onResponse(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,16 +249,18 @@ public void testCanDeleteConversations() {
});

StepListener<List<Interaction>> inters2 = new StepListener<>();
inters1.whenComplete(ints -> { cmHandler.getInteractions(cid2.result(), 0, 10, inters2); }, e -> {
inters1.whenComplete(ints -> {
cdl.countDown();
assert (false);
}, e -> {
assert (e.getMessage().startsWith("Conversation ["));
cmHandler.getInteractions(cid2.result(), 0, 10, inters2);
});

LatchedActionListener<List<Interaction>> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(r -> {
assert (del.result());
assert (conversations.result().size() == 1);
assert (conversations.result().get(0).getId().equals(cid2.result()));
assert (inters1.result().size() == 0);
assert (inters2.result().size() == 1);
assert (inters2.result().get(0).getId().equals(iid3.result()));
}, e -> { assert (false); }), cdl);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ public void testDelete_DeleteFails_ThenFail() {

public void testCheckAccess_DoesNotExist_ThenFail() {
setupUser("user");
setupRefreshSuccess();
doReturn(true).when(metadata).hasIndex(anyString());
GetResponse response = mock(GetResponse.class);
doReturn(false).when(response).isExists();
Expand All @@ -423,6 +424,7 @@ public void testCheckAccess_DoesNotExist_ThenFail() {

public void testCheckAccess_WrongId_ThenFail() {
setupUser("user");
setupRefreshSuccess();
doReturn(true).when(metadata).hasIndex(anyString());
GetResponse response = mock(GetResponse.class);
doReturn(true).when(response).isExists();
Expand All @@ -443,6 +445,7 @@ public void testCheckAccess_WrongId_ThenFail() {

public void testCheckAccess_GetFails_ThenFail() {
setupUser("user");
setupRefreshSuccess();
doReturn(true).when(metadata).hasIndex(anyString());
doAnswer(invocation -> {
ActionListener<GetResponse> al = invocation.getArgument(1);
Expand All @@ -459,6 +462,7 @@ public void testCheckAccess_GetFails_ThenFail() {

public void testCheckAccess_ClientFails_ThenFail() {
setupUser("user");
setupRefreshSuccess();
doReturn(true).when(metadata).hasIndex(anyString());
doThrow(new RuntimeException("Client Test Fail")).when(client).get(any(), any());
@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.http.message.BasicHeader;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.settings.MLCommonsSettings;
Expand Down Expand Up @@ -163,15 +164,20 @@ public void testDeleteConversation_WithInteractions() throws IOException {
assert (!gcmap.containsKey("next_token"));
assert (((ArrayList) gcmap.get("conversations")).size() == 0);

Response giresponse = TestHelper
.makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null);
assert (giresponse != null);
assert (TestHelper.restStatus(giresponse) == RestStatus.OK);
HttpEntity gihttpEntity = giresponse.getEntity();
String gientityString = TestHelper.httpEntityToString(gihttpEntity);
Map gimap = gson.fromJson(gientityString, Map.class);
assert (gimap.containsKey("interactions"));
assert (!gimap.containsKey("next_token"));
assert (((ArrayList) gimap.get("interactions")).size() == 0);
try {
Response giresponse = TestHelper
.makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null);
assert (giresponse != null);
assert (TestHelper.restStatus(giresponse) == RestStatus.OK);
HttpEntity gihttpEntity = giresponse.getEntity();
String gientityString = TestHelper.httpEntityToString(gihttpEntity);
Map gimap = gson.fromJson(gientityString, Map.class);
assert (gimap.containsKey("interactions"));
assert (!gimap.containsKey("next_token"));
assert (((ArrayList) gimap.get("interactions")).size() == 0);
assert (false);
} catch (ResponseException e) {
assert (TestHelper.restStatus(e.getResponse()) == RestStatus.NOT_FOUND);
}
}
}

0 comments on commit 24a629b

Please sign in to comment.