Skip to content

Commit

Permalink
update memory index name and add updated_time (#1784)
Browse files Browse the repository at this point in the history
* update memory index name and add updated_time

Signed-off-by: Xun Zhang <[email protected]>

* remove refresh for searching

Signed-off-by: Xun Zhang <[email protected]>

* some log updates from comments

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored Dec 20, 2023
1 parent 2d5b1bb commit 06f8b2a
Show file tree
Hide file tree
Showing 11 changed files with 340 additions and 183 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class ConversationalIndexConstants {
/** Version of the meta index schema */
public final static Integer META_INDEX_SCHEMA_VERSION = 1;
/** Name of the conversational metadata index */
public final static String META_INDEX_NAME = ".plugins-ml-conversation-meta";
public final static String META_INDEX_NAME = ".plugins-ml-memory-meta";
/** Name of the metadata field for initial timestamp */
public final static String META_CREATED_TIME_FIELD = "create_time";
/** Name of the metadata field for updated timestamp */
Expand Down Expand Up @@ -64,7 +64,7 @@ public class ConversationalIndexConstants {
/** Version of the interactions index schema */
public final static Integer INTERACTIONS_INDEX_SCHEMA_VERSION = 1;
/** Name of the conversational interactions index */
public final static String INTERACTIONS_INDEX_NAME = ".plugins-ml-conversation-interactions";
public final static String INTERACTIONS_INDEX_NAME = ".plugins-ml-memory-message";
/** Name of the interaction field for the conversation Id */
public final static String INTERACTIONS_CONVERSATION_ID_FIELD = "conversation_id";
/** Name of the interaction field for the human input */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
*/
package org.opensearch.ml.memory.action.conversation;

import java.util.HashMap;
import java.util.Map;

import org.opensearch.OpenSearchException;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
Expand Down Expand Up @@ -93,10 +96,9 @@ protected void doExecute(Task task, CreateInteractionRequest request, ActionList
Integer traceNumber = request.getTraceNumber();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<CreateInteractionResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<String> al = ActionListener
.wrap(iid -> { internalListener.onResponse(new CreateInteractionResponse(iid)); }, e -> {
internalListener.onFailure(e);
});
ActionListener<String> al = ActionListener.wrap(iid -> {
cmHandler.updateConversation(cid, new HashMap<>(), getUpdateResponseListener(cid, iid, internalListener));
}, e -> { internalListener.onFailure(e); });
if (parintIid == null || traceNumber == null) {
cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al);
} else {
Expand All @@ -108,4 +110,34 @@ protected void doExecute(Task task, CreateInteractionRequest request, ActionList
}
}

private ActionListener<UpdateResponse> getUpdateResponseListener(
String conversationId,
String interactionId,
ActionListener<CreateInteractionResponse> actionListener
) {
return ActionListener.wrap(updateResponse -> {
if (updateResponse != null && updateResponse.getResult() == DocWriteResponse.Result.UPDATED) {
log
.debug(
"Successfully updated the Conversation with ID: {} after interaction {} is created",
conversationId,
interactionId
);
actionListener.onResponse(new CreateInteractionResponse(interactionId));
} else {
log.error("Failed to update the Conversation with ID: {} after interaction {} is created", conversationId, interactionId);
actionListener.onResponse(new CreateInteractionResponse(interactionId));
}
}, exception -> {
log
.error(
"Failed to update Conversation with ID {} after interaction {} is created. Details: {}",
conversationId,
interactionId,
exception
);
actionListener.onResponse(new CreateInteractionResponse(interactionId));
});

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@

package org.opensearch.ml.memory.action.conversation;

import java.time.Instant;
import java.util.Map;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
Expand Down Expand Up @@ -36,9 +40,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
String conversationId = updateConversationRequest.getConversationId();
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId);
updateRequest.doc(updateConversationRequest.getUpdateContent());
updateRequest.docAsUpsert(true);
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());

updateRequest.doc(updateContent);
updateRequest.docAsUpsert(true);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context));
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
Expand Down Expand Up @@ -393,6 +394,7 @@ public void updateConversation(String conversationId, Map<String, Object> update

updateRequest.doc(updateContent);
updateRequest.docAsUpsert(true);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

conversationMetaIndex.updateConversation(updateRequest, listener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,17 @@
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
Expand Down Expand Up @@ -78,6 +82,8 @@ public class CreateInteractionTransportActionTests extends OpenSearchTestCase {
CreateInteractionRequest request;
CreateInteractionTransportAction action;
ThreadContext threadContext;
UpdateResponse updateResponse;
ShardId shardId;

@Before
public void setup() throws IOException {
Expand All @@ -101,6 +107,9 @@ public void setup() throws IOException {
Collections.singletonMap("metadata", "some meta")
);

shardId = new ShardId(new Index("indexName", "uuid"), 1);
updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED);

Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build();
this.threadContext = new ThreadContext(settings);
when(this.client.threadPool()).thenReturn(this.threadPool);
Expand All @@ -114,7 +123,46 @@ public void setup() throws IOException {
}

public void testCreateInteraction() {
log.info("testing create interaction transport");
doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
listener.onResponse(updateResponse);
return null;
}).when(cmHandler).updateConversation(any(), any(), any());
doAnswer(invocation -> {
ActionListener<String> listener = invocation.getArgument(6);
listener.onResponse("testID");
return null;
}).when(cmHandler).createInteraction(any(), any(), any(), any(), any(), any(), any());
action.doExecute(null, request, actionListener);
ArgumentCaptor<CreateInteractionResponse> argCaptor = ArgumentCaptor.forClass(CreateInteractionResponse.class);
verify(actionListener).onResponse(argCaptor.capture());
assert (argCaptor.getValue().getId().equals("testID"));
}

public void testCreateInteraction_WrongUpdateStatus() {
updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED);
doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
listener.onResponse(updateResponse);
return null;
}).when(cmHandler).updateConversation(any(), any(), any());
doAnswer(invocation -> {
ActionListener<String> listener = invocation.getArgument(6);
listener.onResponse("testID");
return null;
}).when(cmHandler).createInteraction(any(), any(), any(), any(), any(), any(), any());
action.doExecute(null, request, actionListener);
ArgumentCaptor<CreateInteractionResponse> argCaptor = ArgumentCaptor.forClass(CreateInteractionResponse.class);
verify(actionListener).onResponse(argCaptor.capture());
assert (argCaptor.getValue().getId().equals("testID"));
}

public void testCreateInteraction_UpdateException() {
doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
listener.onFailure(new RuntimeException("Update Conversation Exception"));
return null;
}).when(cmHandler).updateConversation(any(), any(), any());
doAnswer(invocation -> {
ActionListener<String> listener = invocation.getArgument(6);
listener.onResponse("testID");
Expand All @@ -138,6 +186,11 @@ public void testCreateInteraction_Trace() {
1
);

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
listener.onResponse(updateResponse);
return null;
}).when(cmHandler).updateConversation(any(), any(), any());
doAnswer(invocation -> {
ActionListener<String> listener = invocation.getArgument(6);
listener.onResponse("testID");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;

import org.junit.Before;
Expand Down Expand Up @@ -79,7 +80,9 @@ public void setup() throws IOException {
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
String conversationId = "test_conversation_id";
Map<String, Object> updateContent = Map.of(META_NAME_FIELD, "new name", META_UPDATED_TIME_FIELD, Instant.ofEpochMilli(123));
Map<String, Object> updateContent = new HashMap<>();
updateContent.put(META_NAME_FIELD, "new name");
updateContent.put(META_UPDATED_TIME_FIELD, Instant.ofEpochMilli(123));
when(updateRequest.getConversationId()).thenReturn(conversationId);
when(updateRequest.getUpdateContent()).thenReturn(updateContent);
shardId = new ShardId(new Index("indexName", "uuid"), 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -574,9 +574,7 @@ public void testGetConversation_NoIndex_ThenFail() {
assert (argCaptor
.getValue()
.getMessage()
.equals(
"no such index [.plugins-ml-conversation-meta] and cannot get conversation since the conversation index does not exist"
));
.equals("no such index [.plugins-ml-memory-meta] and cannot get conversation since the conversation index does not exist"));
}

public void testGetConversation_ResponseNotExist_ThenFail() {
Expand Down Expand Up @@ -652,9 +650,7 @@ public void testUpdateConversation_NoIndex_ThenFail() {
assert (argCaptor
.getValue()
.getMessage()
.equals(
"no such index [.plugins-ml-conversation-meta] and cannot update conversation since the conversation index does not exist"
));
.equals("no such index [.plugins-ml-memory-meta] and cannot update conversation since the conversation index does not exist"));
}

public void testUpdateConversation_Success() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -731,9 +731,7 @@ public void testGetSg_NoIndex_ThenFail() {
assert (argCaptor
.getValue()
.getMessage()
.equals(
"no such index [.plugins-ml-conversation-interactions] and cannot get interaction since the interactions index does not exist"
));
.equals("no such index [.plugins-ml-memory-message] and cannot get interaction since the interactions index does not exist"));
}

public void testGetSg_InteractionNotExist_ThenFail() {
Expand Down
Loading

0 comments on commit 06f8b2a

Please sign in to comment.