Skip to content

Commit

Permalink
add updated_time in conversation meta index and update transport acti…
Browse files Browse the repository at this point in the history
…ons to support it

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Nov 27, 2023
1 parent c346f2a commit d0b493c
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public class ConversationMeta implements Writeable, ToXContentObject {
@Getter
private Instant createdTime;
@Getter
private Instant updatedTime;
@Getter
private String name;
@Getter
private String user;
Expand All @@ -66,9 +68,10 @@ public static ConversationMeta fromSearchHit(SearchHit hit) {
*/
public static ConversationMeta fromMap(String id, Map<String, Object> docFields) {
Instant created = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_CREATED_FIELD));
Instant updated = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_UPDATED_FIELD));
String name = (String) docFields.get(ConversationalIndexConstants.META_NAME_FIELD);
String user = (String) docFields.get(ConversationalIndexConstants.USER_FIELD);
return new ConversationMeta(id, created, name, user);
return new ConversationMeta(id, created, updated, name, user);
}

/**
Expand All @@ -81,15 +84,17 @@ public static ConversationMeta fromMap(String id, Map<String, Object> docFields)
public static ConversationMeta fromStream(StreamInput in) throws IOException {
String id = in.readString();
Instant created = in.readInstant();
Instant updated = in.readInstant();
String name = in.readString();
String user = in.readOptionalString();
return new ConversationMeta(id, created, name, user);
return new ConversationMeta(id, created, updated, name, user);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(id);
out.writeInstant(createdTime);
out.writeInstant(updatedTime);
out.writeString(name);
out.writeOptionalString(user);
}
Expand All @@ -104,6 +109,7 @@ public IndexRequest toIndexRequest(String index) {
IndexRequest request = new IndexRequest(index);
return request.id(this.id).source(
ConversationalIndexConstants.META_CREATED_FIELD, this.createdTime,
ConversationalIndexConstants.META_UPDATED_FIELD, this.createdTime,
ConversationalIndexConstants.META_NAME_FIELD, this.name
);
}
Expand All @@ -113,6 +119,7 @@ public String toString() {
return "{id=" + id
+ ", name=" + name
+ ", created=" + createdTime.toString()
+ ", updated=" + updatedTime.toString()
+ ", user=" + user
+ "}";
}
Expand All @@ -122,6 +129,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
builder.startObject();
builder.field(ActionConstants.CONVERSATION_ID_FIELD, this.id);
builder.field(ConversationalIndexConstants.META_CREATED_FIELD, this.createdTime);
builder.field(ConversationalIndexConstants.META_UPDATED_FIELD, this.updatedTime);
builder.field(ConversationalIndexConstants.META_NAME_FIELD, this.name);
if(this.user != null) {
builder.field(ConversationalIndexConstants.USER_FIELD, this.user);
Expand All @@ -139,6 +147,7 @@ public boolean equals(Object other) {
return Objects.equals(this.id, otherConversation.id) &&
Objects.equals(this.user, otherConversation.user) &&
Objects.equals(this.createdTime, otherConversation.createdTime) &&
Objects.equals(this.updatedTime, otherConversation.updatedTime) &&
Objects.equals(this.name, otherConversation.name);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ public class ConversationalIndexConstants {
public final static String META_INDEX_NAME = ".plugins-ml-memory-meta";
/** Name of the metadata field for initial timestamp */
public final static String META_CREATED_FIELD = "create_time";
/** Name of the metadata field for updated timestamp */
public final static String META_UPDATED_FIELD = "updated_time";
/** Name of the metadata field for name of the conversation */
public final static String META_NAME_FIELD = "name";
/** Name of the owning user field in all indices */
Expand All @@ -48,6 +50,9 @@ public class ConversationalIndexConstants {
+ META_CREATED_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ META_UPDATED_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ USER_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationMeta;
Expand Down Expand Up @@ -238,4 +239,10 @@ public ActionFuture<String> createInteraction(
*/
public ActionFuture<SearchResponse> searchInteractions(String conversationId, SearchRequest request);

/**
* Update a conversation
* @param updateContent update content for the conversations index
* @param listener receives the update response
*/
public void updateConversation(String conversationId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
*/
package org.opensearch.ml.memory.action.conversation;

import java.util.HashMap;
import java.util.Map;

import org.opensearch.OpenSearchException;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
Expand Down Expand Up @@ -93,10 +96,10 @@ protected void doExecute(Task task, CreateInteractionRequest request, ActionList
Integer traceNumber = request.getTraceNum();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<CreateInteractionResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<String> al = ActionListener
.wrap(iid -> { internalListener.onResponse(new CreateInteractionResponse(iid)); }, e -> {
internalListener.onFailure(e);
});
ActionListener<String> al = ActionListener.wrap(iid -> {
cmHandler.updateConversation(cid, new HashMap<>(), getUpdateResponseListener(cid, iid, internalListener));
// internalListener.onResponse(new CreateInteractionResponse(iid));
}, e -> { internalListener.onFailure(e); });
if (parentId == null || traceNumber == null) {
cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al);
} else {
Expand All @@ -108,4 +111,29 @@ protected void doExecute(Task task, CreateInteractionRequest request, ActionList
}
}

private ActionListener<UpdateResponse> getUpdateResponseListener(
String conversationId,
String interactionId,
ActionListener<CreateInteractionResponse> actionListener
) {
return ActionListener.wrap(updateResponse -> {
if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) {
log.info("Failed to update the Conversation with ID: {} after interaction {} is created", conversationId, interactionId);
actionListener.onResponse(new CreateInteractionResponse(conversationId));
return;
}
log.info("Successfully updated the Conversation with ID: {} after interaction {} is created", conversationId);
actionListener.onResponse(new CreateInteractionResponse(conversationId));
}, exception -> {
log
.error(
"Failed to update Conversation with ID {} after interaction {} is created. Details: {}",
conversationId,
interactionId,
exception
);
actionListener.onResponse(new CreateInteractionResponse(conversationId));
});

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.memory.action.conversation;

import java.time.Instant;
import java.util.Map;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
Expand Down Expand Up @@ -36,7 +39,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
String conversationId = updateConversationRequest.getConversationId();
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId);
updateRequest.doc(updateConversationRequest.getUpdateContent());
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_FIELD, Instant.now());

updateRequest.doc(updateContent);
updateRequest.docAsUpsert(true);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.client.Requests;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -128,6 +130,8 @@ public void createConversation(String name, String applicationType, ActionListen
.source(
ConversationalIndexConstants.META_CREATED_FIELD,
Instant.now(),
ConversationalIndexConstants.META_UPDATED_FIELD,
Instant.now(),
ConversationalIndexConstants.META_NAME_FIELD,
name,
ConversationalIndexConstants.USER_FIELD,
Expand Down Expand Up @@ -346,4 +350,18 @@ public void searchConversations(SearchRequest request, ActionListener<SearchResp
listener.onFailure(e);
}
}

/**
* Update conversations in the index
* @param updateRequest original update request
* @param listener receives the update response for the wrapped query
*/
public void updateConversation(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, listener);
} catch (Exception e) {
log.error("Failed to update Conversation. Details {}:", e);
listener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationMeta;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.common.conversation.Interaction.InteractionBuilder;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
Expand Down Expand Up @@ -383,4 +386,13 @@ public void getTraces(String interactionId, int from, int maxResults, ActionList
interactionsIndex.getTraces(interactionId, from, maxResults, listener);
}

public void updateConversation(String conversationId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener) {
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId);
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_FIELD, Instant.now());

updateRequest.doc(updateContent);
updateRequest.docAsUpsert(true);

conversationMetaIndex.updateConversation(updateRequest, listener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ public class GetConversationsResponseTests extends OpenSearchTestCase {
public void setup() {
conversations = List
.of(
new ConversationMeta("0", Instant.now(), "name0", "user0"),
new ConversationMeta("1", Instant.now(), "name1", "user0"),
new ConversationMeta("2", Instant.now(), "name2", "user2")
new ConversationMeta("0", Instant.now(), Instant.now(), "name0", "user0"),
new ConversationMeta("1", Instant.now(), Instant.now(), "name1", "user0"),
new ConversationMeta("2", Instant.now(), Instant.now(), "name2", "user2")
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ public void testGetConversations() {
log.info("testing get conversations transport");
List<ConversationMeta> testResult = List
.of(
new ConversationMeta("testcid1", Instant.now(), "", null),
new ConversationMeta("testcid2", Instant.now(), "testname", null)
new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null),
new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null)
);
doAnswer(invocation -> {
ActionListener<List<ConversationMeta>> listener = invocation.getArgument(2);
Expand All @@ -130,9 +130,9 @@ public void testGetConversations() {
public void testPagination() {
List<ConversationMeta> testResult = List
.of(
new ConversationMeta("testcid1", Instant.now(), "", null),
new ConversationMeta("testcid2", Instant.now(), "testname", null),
new ConversationMeta("testcid3", Instant.now(), "testname", null)
new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null),
new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null),
new ConversationMeta("testcid3", Instant.now(), Instant.now(), "testname", null)
);
doAnswer(invocation -> {
ActionListener<List<ConversationMeta>> listener = invocation.getArgument(2);
Expand Down

0 comments on commit d0b493c

Please sign in to comment.