From 920685d042c14a0a4613efb466d9d9c277f91e80 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Mon, 5 Aug 2024 16:31:39 +0800 Subject: [PATCH] init commit for adding additional info for memory metadata (#2750) create conversation support additional info add test for search conversation add bwc Signed-off-by: Hailong Cui --- .../org/opensearch/ml/common/CommonValue.java | 1 + .../common/conversation/ConversationMeta.java | 31 ++++++- .../ConversationalIndexConstants.java | 10 ++- .../conversation/ConversationMetaTests.java | 11 ++- .../memory/ConversationalMemoryHandler.java | 14 ++++ .../CreateConversationRequest.java | 45 +++++++++-- .../CreateConversationTransportAction.java | 5 +- .../UpdateConversationRequest.java | 3 +- .../memory/index/ConversationMetaIndex.java | 17 +++- ...OpenSearchConversationalMemoryHandler.java | 18 ++++- .../CreateConversationRequestTests.java | 26 +++++- ...reateConversationTransportActionTests.java | 10 +-- .../GetConversationResponseTests.java | 25 +++++- .../GetConversationTransportActionTests.java | 2 +- .../GetConversationsResponseTests.java | 6 +- .../GetConversationsTransportActionTests.java | 10 +-- .../index/ConversationMetaIndexITTests.java | 81 +++++++++++++++++++ ...earchConversationalMemoryHandlerTests.java | 18 ++++- 18 files changed, 289 insertions(+), 44 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index b4a1a665a5..39da1edd23 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -541,4 +541,5 @@ public class CommonValue { public static final Version VERSION_2_13_0 = Version.fromString("2.13.0"); public static final Version VERSION_2_14_0 = Version.fromString("2.14.0"); public static final Version VERSION_2_16_0 = Version.fromString("2.16.0"); + public static final Version VERSION_2_17_0 = Version.fromString("2.17.0"); } diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java index ae38ab7429..9cc3b49bc4 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java @@ -22,23 +22,26 @@ import java.util.Map; import java.util.Objects; -import org.opensearch.action.index.IndexRequest; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.CommonValue; import org.opensearch.search.SearchHit; import lombok.AllArgsConstructor; import lombok.Getter; +import static org.opensearch.ml.common.CommonValue.VERSION_2_17_0; + /** * Class for holding conversational metadata */ @AllArgsConstructor public class ConversationMeta implements Writeable, ToXContentObject { - + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO = CommonValue.VERSION_2_17_0; @Getter private String id; @Getter @@ -49,6 +52,8 @@ public class ConversationMeta implements Writeable, ToXContentObject { private String name; @Getter private String user; + @Getter + private Map additionalInfos; /** * Creates a conversationMeta object from a SearchHit object @@ -71,7 +76,8 @@ public static ConversationMeta fromMap(String id, Map docFields) Instant updated = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_UPDATED_TIME_FIELD)); String name = (String) docFields.get(ConversationalIndexConstants.META_NAME_FIELD); String user = (String) docFields.get(ConversationalIndexConstants.USER_FIELD); - return new ConversationMeta(id, created, updated, name, user); + Map additionalInfos = (Map)docFields.get(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD); + return new ConversationMeta(id, created, updated, name, user, additionalInfos); } /** @@ -87,7 +93,13 @@ public static ConversationMeta fromStream(StreamInput in) throws IOException { Instant updated = in.readInstant(); String name = in.readString(); String user = in.readOptionalString(); - return new ConversationMeta(id, created, updated, name, user); + Map additionalInfos = null; + if (in.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) { + if (in.readBoolean()) { + additionalInfos = in.readMap(StreamInput::readString, StreamInput::readString); + } + } + return new ConversationMeta(id, created, updated, name, user, additionalInfos); } @Override @@ -97,6 +109,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInstant(updatedTime); out.writeString(name); out.writeOptionalString(user); + if(out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) { + if (additionalInfos == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeMap(additionalInfos, StreamOutput::writeString, StreamOutput::writeString); + } + } } @Override @@ -119,6 +139,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para if(this.user != null) { builder.field(ConversationalIndexConstants.USER_FIELD, this.user); } + if (this.additionalInfos != null) { + builder.field(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, this.additionalInfos); + } builder.endObject(); return builder; } diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java index bf65b778dd..b542864726 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java @@ -24,7 +24,7 @@ */ public class ConversationalIndexConstants { /** Version of the meta index schema */ - public final static Integer META_INDEX_SCHEMA_VERSION = 1; + public final static Integer META_INDEX_SCHEMA_VERSION = 2; /** Name of the conversational metadata index */ public final static String META_INDEX_NAME = ".plugins-ml-memory-meta"; /** Name of the metadata field for initial timestamp */ @@ -37,6 +37,9 @@ public class ConversationalIndexConstants { public final static String USER_FIELD = "user"; /** Name of the application that created this conversation */ public final static String APPLICATION_TYPE_FIELD = "application_type"; + /** Name of the additional information for this memory */ + public final static String META_ADDITIONAL_INFO_FIELD = "additional_info"; + /** Mappings for the conversational metadata index */ public final static String META_MAPPING = "{\n" + " \"_meta\": {\n" @@ -57,7 +60,10 @@ public class ConversationalIndexConstants { + "\": {\"type\": \"keyword\"},\n" + " \"" + APPLICATION_TYPE_FIELD - + "\": {\"type\": \"keyword\"}\n" + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + META_ADDITIONAL_INFO_FIELD + + "\": {\"type\": \"flat_object\"}\n" + " }\n" + "}"; diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java index 304703d34f..2b4e628d05 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java @@ -20,6 +20,7 @@ import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; public class ConversationMetaTests { @@ -30,7 +31,7 @@ public class ConversationMetaTests { @Before public void setUp() { time = Instant.now(); - conversationMeta = new ConversationMeta("test_id", time, time, "test_name", "admin"); + conversationMeta = new ConversationMeta("test_id", time, time, "test_name", "admin", null); } @Test @@ -41,6 +42,7 @@ public void test_fromSearchHit() throws IOException { content.field(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, time); content.field(ConversationalIndexConstants.META_NAME_FIELD, "meta name"); content.field(ConversationalIndexConstants.USER_FIELD, "admin"); + content.field(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, Map.of("test_key", "test_value")); content.endObject(); SearchHit[] hits = new SearchHit[1]; @@ -50,6 +52,7 @@ public void test_fromSearchHit() throws IOException { assertEquals(conversationMeta.getId(), "cId"); assertEquals(conversationMeta.getName(), "meta name"); assertEquals(conversationMeta.getUser(), "admin"); + assertEquals(conversationMeta.getAdditionalInfos().get("test_key"), "test_value"); } @Test @@ -85,7 +88,7 @@ public void test_fromStream() throws IOException { @Test public void test_ToXContent() throws IOException { - ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin"); + ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin", null); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); conversationMeta.toXContent(builder, EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); @@ -94,13 +97,13 @@ public void test_ToXContent() throws IOException { @Test public void test_toString() { - ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin"); + ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin", null); assertEquals("{id=test_id, name=test meta, created=1970-01-01T00:00:00.123Z, updated=1970-01-01T00:00:00.123Z, user=admin}", conversationMeta.toString()); } @Test public void test_equal() { - ConversationMeta meta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin"); + ConversationMeta meta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin", null); assertEquals(meta.equals(conversationMeta), false); } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java index b553a222a9..aa24016b96 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java @@ -68,6 +68,20 @@ public interface ConversationalMemoryHandler { */ public void createConversation(String name, String applicationType, ActionListener listener); + /** + * Create a new conversation + * @param name the name of the new conversation + * @param applicationType the application that creates this conversation + * @param additionalInfos additional information associated with this conversation + * @param listener listener to wait for this op to finish, gets unique id of new conversation + */ + public void createConversation( + String name, + String applicationType, + Map additionalInfos, + ActionListener listener + ); + /** * Adds an interaction to the conversation indicated, updating the conversational metadata * @param conversationId the conversation to add the interaction to diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java index 950a2c5a88..991ddde2a7 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java @@ -18,15 +18,19 @@ package org.opensearch.ml.memory.action.conversation; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD; import java.io.IOException; import java.util.Map; import org.opensearch.OpenSearchParseException; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.rest.RestRequest; @@ -36,10 +40,14 @@ * Action Request for creating a conversation */ public class CreateConversationRequest extends ActionRequest { + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO = CommonValue.VERSION_2_17_0; + @Getter private String name = null; @Getter private String applicationType = null; + @Getter + private Map additionalInfos = null; /** * Constructor @@ -50,6 +58,11 @@ public CreateConversationRequest(StreamInput in) throws IOException { super(in); this.name = in.readOptionalString(); this.applicationType = in.readOptionalString(); + if (in.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) { + if (in.readBoolean()) { + this.additionalInfos = in.readMap(StreamInput::readString, StreamInput::readString); + } + } } /** @@ -71,6 +84,19 @@ public CreateConversationRequest(String name, String applicationType) { this.applicationType = applicationType; } + /** + * Constructor + * @param name name of the conversation + * @param applicationType of the conversation + * @param additionalInfos information of the conversation + */ + public CreateConversationRequest(String name, String applicationType, Map additionalInfos) { + super(); + this.name = name; + this.applicationType = applicationType; + this.additionalInfos = additionalInfos; + } + /** * Constructor * name will be null @@ -82,6 +108,14 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeOptionalString(name); out.writeOptionalString(applicationType); + if (out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) { + if (additionalInfos == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeMap(additionalInfos, StreamOutput::writeString, StreamOutput::writeString); + } + } } @Override @@ -101,12 +135,13 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest) if (!restRequest.hasContent()) { return new CreateConversationRequest(); } - try { - Map body = restRequest.contentParser().mapStrings(); - if (body.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)) { + try (XContentParser parser = restRequest.contentParser()) { + Map body = parser.map(); + if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) { return new CreateConversationRequest( - body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD), - body.get(APPLICATION_TYPE_FIELD) + (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD), + body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD), + body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map) body.get(META_ADDITIONAL_INFO_FIELD) ); } else { return new CreateConversationRequest(); diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java index 0a882b00dd..11bc569d00 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java @@ -19,6 +19,8 @@ import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE; +import java.util.Map; + import org.opensearch.OpenSearchException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -79,6 +81,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis } String name = request.getName(); String applicationType = request.getApplicationType(); + Map additionalInfos = request.getAdditionalInfos(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); ActionListener al = ActionListener.wrap(r -> { internalListener.onResponse(new CreateConversationResponse(r)); }, e -> { @@ -89,7 +92,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis if (name == null) { cmHandler.createConversation(al); } else { - cmHandler.createConversation(name, applicationType, al); + cmHandler.createConversation(name, applicationType, additionalInfos, al); } } catch (Exception e) { log.error("Failed to create new memory with name " + request.getName(), e); diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java index 7afec5d0ab..de0907099f 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java @@ -6,6 +6,7 @@ package org.opensearch.ml.memory.action.conversation; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; import java.io.ByteArrayInputStream; @@ -35,7 +36,7 @@ public class UpdateConversationRequest extends ActionRequest { private String conversationId; private Map updateContent; - private static final Set allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD)); + private static final Set allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD, META_ADDITIONAL_INFO_FIELD)); @Builder public UpdateConversationRequest(String conversationId, Map updateContent) { diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index d4dde66326..2bcd2541fa 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -24,6 +24,7 @@ import java.time.Instant; import java.util.LinkedList; import java.util.List; +import java.util.Map; import org.opensearch.OpenSearchStatusException; import org.opensearch.OpenSearchWrapperException; @@ -127,9 +128,15 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) * Adds a new conversation with the specified name to the index * @param name user-specified name of the conversation to be added * @param applicationType the application type that creates this conversation + * @param additionalInfos the additional info that creates this conversation * @param listener listener to wait for this to finish */ - public void createConversation(String name, String applicationType, ActionListener listener) { + public void createConversation( + String name, + String applicationType, + Map additionalInfos, + ActionListener listener + ) { initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> { if (indexExists) { String userstr = getUserStrFromThreadContext(); @@ -146,7 +153,9 @@ public void createConversation(String name, String applicationType, ActionListen ConversationalIndexConstants.USER_FIELD, userstr == null ? null : User.parse(userstr).getName(), ConversationalIndexConstants.APPLICATION_TYPE_FIELD, - applicationType + applicationType, + ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, + additionalInfos == null ? Map.of() : additionalInfos ); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); @@ -177,7 +186,7 @@ public void createConversation(String name, String applicationType, ActionListen * @param listener listener to wait for this to finish */ public void createConversation(ActionListener listener) { - createConversation("", "", listener); + createConversation("", "", null, listener); } /** @@ -186,7 +195,7 @@ public void createConversation(ActionListener listener) { * @param listener listener to wait for this to finish */ public void createConversation(String name, ActionListener listener) { - createConversation(name, "", listener); + createConversation(name, "", null, listener); } /** diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index 89a128e0f3..755d207a86 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -103,7 +103,23 @@ public void createConversation(String name, ActionListener listener) { * @param listener listener to wait for this op to finish, gets unique id of new conversation */ public void createConversation(String name, String applicationType, ActionListener listener) { - conversationMetaIndex.createConversation(name, applicationType, listener); + conversationMetaIndex.createConversation(name, applicationType, null, listener); + } + + /** + * Create a new conversation + * @param name the name of the new conversation + * @param applicationType the application that creates this conversation + * @param additionalInfos the additional information associated with this conversation + * @param listener listener to wait for this op to finish, gets unique id of new conversation + */ + public void createConversation( + String name, + String applicationType, + Map additionalInfos, + ActionListener listener + ) { + conversationMetaIndex.createConversation(name, applicationType, additionalInfos, listener); } /** diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java index dd50f8fffe..0f2dd2b5ce 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java @@ -18,14 +18,15 @@ package org.opensearch.ml.memory.action.conversation; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD; import java.io.IOException; import java.util.Map; +import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; -import org.opensearch.OpenSearchParseException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; @@ -107,11 +108,28 @@ public void testNamedRestRequest_WithAppType() throws IOException { } public void testRestRequest_NullName() throws IOException { - exceptionRule.expect(OpenSearchParseException.class); - exceptionRule.expectMessage("Can't get text on a VALUE_NULL"); RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withContent(new BytesArray("{\"name\":null}"), MediaTypeRegistry.JSON) .build(); - CreateConversationRequest.fromRestRequest(req); + CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req); + Assert.assertNull(request.getName()); + } + + public void testRestRequest_WithAdditionalInfo() throws IOException { + String name = "test-name"; + Map additionalInfo = Map.of("key1", "value1", "key2", 123); + RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withContent( + new BytesArray( + gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, name, META_ADDITIONAL_INFO_FIELD, additionalInfo)) + ), + MediaTypeRegistry.JSON + ) + .build(); + CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req); + assert (request.getName().equals(name)); + Assert.assertNull(request.getApplicationType()); + Assert.assertEquals("value1", request.getAdditionalInfos().get("key1")); + Assert.assertEquals(123, request.getAdditionalInfos().get("key2")); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java index c6df207d16..1946bd3d8f 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java @@ -111,10 +111,10 @@ public void setup() throws IOException { public void testCreateConversation() { log.info("testing create conversation transport"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(3); listener.onResponse("testID"); return null; - }).when(cmHandler).createConversation(any(), any(), any()); + }).when(cmHandler).createConversation(any(), any(), any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class); verify(actionListener).onResponse(argCaptor.capture()); @@ -137,10 +137,10 @@ public void testCreateConversationWithNullName() { public void testCreateConversationFails_thenFail() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(3); listener.onFailure(new Exception("Testing Error")); return null; - }).when(cmHandler).createConversation(any(), any(), any()); + }).when(cmHandler).createConversation(any(), any(), any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); @@ -148,7 +148,7 @@ public void testCreateConversationFails_thenFail() { } public void testDoExecuteFails_thenFail() { - doThrow(new RuntimeException("Test doExecute Error")).when(cmHandler).createConversation(any(), any(), any()); + doThrow(new RuntimeException("Test doExecute Error")).when(cmHandler).createConversation(any(), any(), any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java index b3ed2f14ff..0b39d546f8 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java @@ -19,8 +19,10 @@ import java.io.IOException; import java.time.Instant; +import java.util.Map; import org.apache.lucene.search.spell.LevenshteinDistance; +import org.junit.Assert; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.bytes.BytesReference; @@ -36,7 +38,7 @@ public class GetConversationResponseTests extends OpenSearchTestCase { public void testGetConversationResponseStreaming() throws IOException { - ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null); + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, null); GetConversationResponse response = new GetConversationResponse(convo); assert (response.getConversation().equals(convo)); @@ -49,7 +51,7 @@ public void testGetConversationResponseStreaming() throws IOException { } public void testToXContent() throws IOException { - ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null); + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, null); GetConversationResponse response = new GetConversationResponse(convo); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -63,4 +65,23 @@ public void testToXContent() throws IOException { LevenshteinDistance ld = new LevenshteinDistance(); assert (ld.getDistance(result, expected) > 0.95); } + + public void testToXContent_withAdditionalInfo() throws IOException { + Map additionalInfos = Map.of("key1", "value1"); + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, additionalInfos); + GetConversationResponse response = new GetConversationResponse(convo); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + String expected = "{\"memory_id\":\"cid\",\"create_time\":\"" + + convo.getCreatedTime() + + "\",\"updated_time\":\"" + + convo.getUpdatedTime() + + "\",\"name\":\"name\"" + + ",\"additional_info\":{\"key1\":\"value1\"}" + + "}"; + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + Assert.assertTrue(ld.getDistance(result, expected) > 0.95); + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java index 83137292db..558ecd9b65 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java @@ -107,7 +107,7 @@ public void setup() throws IOException { } public void testGetConversation() { - ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), Instant.now(), "name", null); + ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), Instant.now(), "name", null, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(result); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java index a04a022973..b28ed26d0f 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java @@ -46,9 +46,9 @@ public class GetConversationsResponseTests extends OpenSearchTestCase { public void setup() { conversations = List .of( - new ConversationMeta("0", Instant.now(), Instant.now(), "name0", "user0"), - new ConversationMeta("1", Instant.now(), Instant.now(), "name1", "user0"), - new ConversationMeta("2", Instant.now(), Instant.now(), "name2", "user2") + new ConversationMeta("0", Instant.now(), Instant.now(), "name0", "user0", null), + new ConversationMeta("1", Instant.now(), Instant.now(), "name1", "user0", null), + new ConversationMeta("2", Instant.now(), Instant.now(), "name2", "user2", null) ); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java index f349ac6c97..a866167d37 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java @@ -114,8 +114,8 @@ public void testGetConversations() { log.info("testing get conversations transport"); List testResult = List .of( - new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null), - new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null) + new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null, null), + new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null, null) ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(2); @@ -132,9 +132,9 @@ public void testGetConversations() { public void testPagination() { List testResult = List .of( - new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null), - new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null), - new ConversationMeta("testcid3", Instant.now(), Instant.now(), "testname", null) + new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null, null), + new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null, null), + new ConversationMeta("testcid3", Instant.now(), Instant.now(), "testname", null, null) ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(2); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java index 74ee62bb73..5baefa358d 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -20,11 +20,14 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.Stack; import java.util.concurrent.CountDownLatch; import java.util.function.Consumer; +import org.junit.Assert; import org.junit.Before; import org.junit.Ignore; import org.opensearch.OpenSearchStatusException; @@ -561,6 +564,8 @@ public void testCanGetAConversationById() { assert (cid2.result().equals(get2.result().getId())); assert (get1.result().getName().equals("convo1")); assert (get2.result().getName().equals("convo2")); + Assert.assertTrue(convo2.getAdditionalInfos().isEmpty()); + Assert.assertTrue(get1.result().getAdditionalInfos().isEmpty()); cdl.countDown(); }, e -> { cdl.countDown(); @@ -634,4 +639,80 @@ public void testCanGetAConversationByIdSecurely() { } } + public void testCanCreateConversationWithAdditionalInfo() { + CountDownLatch cdl = new CountDownLatch(1); + StepListener cid1 = new StepListener<>(); + index.createConversation("hailong-convo", "app", Map.of("k", "v"), cid1); + + StepListener get1 = new StepListener<>(); + cid1.whenComplete(cid -> { index.getConversation(cid1.result(), get1); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + get1.whenComplete(convo1 -> { + try { + Assert.assertEquals(cid1.result(), convo1.getId()); + Assert.assertEquals("hailong-convo", convo1.getName()); + Assert.assertNotNull(convo1.getAdditionalInfos()); + Assert.assertEquals("v", convo1.getAdditionalInfos().get("k")); + } finally { + cdl.countDown(); + } + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + public void testCanQueryOverConversationsByAdditionalInfo() { + CountDownLatch cdl = new CountDownLatch(1); + StepListener convo1 = new StepListener<>(); + index.createConversation("Conversation1", "app", Map.of("k1", "v1"), convo1); + + StepListener convo2 = new StepListener<>(); + convo1.whenComplete(cid -> { index.createConversation("Mehul Conversation", convo2); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener search = new StepListener<>(); + convo2.whenComplete(cid -> { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(QueryBuilders.matchQuery(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD + ".k1", "v1")); + index.searchConversations(request, search); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + search.whenComplete(response -> { + log.info("SEARCH RESPONSE"); + log.info(response.toString()); + cdl.countDown(); + assert (response.getHits().getAt(0).getId().equals(convo1.result())); + Assert.assertEquals(1L, Objects.requireNonNull(response.getHits().getTotalHits()).value); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java index bf4fafd1b6..903be08338 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -29,6 +29,9 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -74,7 +77,7 @@ public void testCreateConversation_NoName_FutureSuccess() { assert (result.actionGet(200).equals("cid")); } - public void testCreateConversation_Named_FutureSucess() { + public void testCreateConversation_Named_FutureSuccess() { doAnswer(invocation -> { ActionListener al = invocation.getArgument(1); al.onResponse("cid"); @@ -84,6 +87,17 @@ public void testCreateConversation_Named_FutureSucess() { assert (result.actionGet(200).equals("cid")); } + public void testCreateConversation_AdditionalInfo_Success() throws Exception { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(3); + al.onResponse("cid"); + return null; + }).when(conversationMetaIndex).createConversation(anyString(), anyString(), any(), any()); + CompletableFuture future = new CompletableFuture<>(); + cmHandler.createConversation("FutureSuccess", "", Map.of(), ActionListener.wrap(future::complete, future::completeExceptionally)); + assert (future.get(200, TimeUnit.MILLISECONDS).equals("cid")); + } + public void testCreateInteraction_Future() { doAnswer(invocation -> { ActionListener al = invocation.getArgument(7); @@ -301,7 +315,7 @@ public void testSearchInteractions_Future() { } public void testGetAConversation_Future() { - ConversationMeta response = new ConversationMeta("cid", Instant.now(), Instant.now(), "boring name", null); + ConversationMeta response = new ConversationMeta("cid", Instant.now(), Instant.now(), "boring name", null, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(response);