From 01cf00c2e86fb09e41ff95cae16e4e1a0d29ab0a Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Tue, 12 Dec 2023 13:34:21 -0800 Subject: [PATCH 1/7] add new data fields in the memory layer and update tests (#1730) * add new data fields in the memory layer and update tests Signed-off-by: Xun Zhang * add more tests coverage and address comments Signed-off-by: Xun Zhang * address comments and more tests Signed-off-by: Xun Zhang --------- Signed-off-by: Xun Zhang --- .../common/conversation/ActionConstants.java | 4 + .../common/conversation/ConversationMeta.java | 36 ++- .../ConversationalIndexConstants.java | 28 +- .../ml/common/conversation/Interaction.java | 61 ++++- .../conversation/ConversationMetaTests.java | 106 ++++++++ .../common/conversation/InteractionTests.java | 190 +++++++++++++ .../memory/ConversationalMemoryHandler.java | 54 +++- .../CreateConversationRequest.java | 21 +- .../CreateConversationTransportAction.java | 3 +- .../CreateInteractionRequest.java | 92 ++++++- .../CreateInteractionTransportAction.java | 12 +- .../memory/index/ConversationMetaIndex.java | 67 +++-- .../ml/memory/index/InteractionsIndex.java | 134 +++++++-- ...OpenSearchConversationalMemoryHandler.java | 71 ++++- .../ConversationalMemoryHandlerITTests.java | 254 +++++++++++++----- .../CreateConversationRequestTests.java | 15 ++ ...reateConversationTransportActionTests.java | 12 +- .../CreateInteractionRequestTests.java | 63 ++++- ...CreateInteractionTransportActionTests.java | 33 ++- .../GetConversationResponseTests.java | 10 +- .../GetConversationTransportActionTests.java | 2 +- .../GetConversationsResponseTests.java | 10 +- .../GetConversationsTransportActionTests.java | 10 +- .../GetInteractionResponseTests.java | 26 +- .../GetInteractionTransportActionTests.java | 3 +- .../GetInteractionsResponseTests.java | 38 ++- .../GetInteractionsTransportActionTests.java | 5 +- .../index/ConversationMetaIndexITTests.java | 10 +- .../index/ConversationMetaIndexTests.java | 53 +++- .../index/InteractionsIndexITTests.java | 171 ++++++++---- .../memory/index/InteractionsIndexTests.java | 106 +++++++- ...earchConversationalMemoryHandlerTests.java | 66 ++++- ...estMemoryCreateInteractionActionTests.java | 7 +- .../RestMemoryGetInteractionActionIT.java | 9 +- .../RestMemoryGetInteractionsActionIT.java | 17 +- .../GenerativeQAResponseProcessor.java | 2 +- .../client/ConversationalMemoryClient.java | 3 +- .../GenerativeQAResponseProcessorTests.java | 49 +++- .../ConversationalMemoryClientTests.java | 4 +- 39 files changed, 1564 insertions(+), 293 deletions(-) create mode 100644 common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java index f87da7c433..8776c618b0 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -48,6 +48,10 @@ public class ActionConstants { public final static String PROMPT_TEMPLATE_FIELD = "prompt_template"; /** name of metadata field in all requests */ public final static String ADDITIONAL_INFO_FIELD = "additional_info"; + /** name of metadata field in all requests */ + public final static String PARENT_INTERACTION_ID_FIELD = "parent_interaction_id"; + /** name of metadata field in all requests */ + public final static String TRACE_NUMBER_FIELD = "trace_number"; /** name of success field in all requests */ public final static String SUCCESS_FIELD = "success"; diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java index 8ba518a065..ae38ab7429 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java @@ -44,6 +44,8 @@ public class ConversationMeta implements Writeable, ToXContentObject { @Getter private Instant createdTime; @Getter + private Instant updatedTime; + @Getter private String name; @Getter private String user; @@ -65,10 +67,11 @@ public static ConversationMeta fromSearchHit(SearchHit hit) { * @return a new conversationMeta object representing the map */ public static ConversationMeta fromMap(String id, Map docFields) { - Instant created = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_CREATED_FIELD)); + Instant created = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_CREATED_TIME_FIELD)); + Instant updated = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_UPDATED_TIME_FIELD)); String name = (String) docFields.get(ConversationalIndexConstants.META_NAME_FIELD); String user = (String) docFields.get(ConversationalIndexConstants.USER_FIELD); - return new ConversationMeta(id, created, name, user); + return new ConversationMeta(id, created, updated, name, user); } /** @@ -81,38 +84,27 @@ public static ConversationMeta fromMap(String id, Map docFields) public static ConversationMeta fromStream(StreamInput in) throws IOException { String id = in.readString(); Instant created = in.readInstant(); + Instant updated = in.readInstant(); String name = in.readString(); String user = in.readOptionalString(); - return new ConversationMeta(id, created, name, user); + return new ConversationMeta(id, created, updated, name, user); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(id); out.writeInstant(createdTime); + out.writeInstant(updatedTime); out.writeString(name); out.writeOptionalString(user); } - - /** - * Convert this conversationMeta object into an IndexRequest so it can be indexed - * @param index the index to send this conversation to. Should usually be .conversational-meta - * @return the IndexRequest for the client to send - */ - public IndexRequest toIndexRequest(String index) { - IndexRequest request = new IndexRequest(index); - return request.id(this.id).source( - ConversationalIndexConstants.META_CREATED_FIELD, this.createdTime, - ConversationalIndexConstants.META_NAME_FIELD, this.name - ); - } - @Override public String toString() { return "{id=" + id + ", name=" + name + ", created=" + createdTime.toString() + + ", updated=" + updatedTime.toString() + ", user=" + user + "}"; } @@ -121,7 +113,8 @@ public String toString() { public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Params params) throws IOException { builder.startObject(); builder.field(ActionConstants.CONVERSATION_ID_FIELD, this.id); - builder.field(ConversationalIndexConstants.META_CREATED_FIELD, this.createdTime); + builder.field(ConversationalIndexConstants.META_CREATED_TIME_FIELD, this.createdTime); + builder.field(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, this.updatedTime); builder.field(ConversationalIndexConstants.META_NAME_FIELD, this.name); if(this.user != null) { builder.field(ConversationalIndexConstants.USER_FIELD, this.user); @@ -137,9 +130,10 @@ public boolean equals(Object other) { } ConversationMeta otherConversation = (ConversationMeta) other; return Objects.equals(this.id, otherConversation.id) && - Objects.equals(this.user, otherConversation.user) && - Objects.equals(this.createdTime, otherConversation.createdTime) && - Objects.equals(this.name, otherConversation.name); + Objects.equals(this.user, otherConversation.user) && + Objects.equals(this.createdTime, otherConversation.createdTime) && + Objects.equals(this.updatedTime, otherConversation.updatedTime) && + Objects.equals(this.name, otherConversation.name); } } diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java index c8e652265b..9d85d0b6cd 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java @@ -28,11 +28,15 @@ public class ConversationalIndexConstants { /** Name of the conversational metadata index */ public final static String META_INDEX_NAME = ".plugins-ml-conversation-meta"; /** Name of the metadata field for initial timestamp */ - public final static String META_CREATED_FIELD = "create_time"; + public final static String META_CREATED_TIME_FIELD = "create_time"; + /** Name of the metadata field for updated timestamp */ + public final static String META_UPDATED_TIME_FIELD = "updated_time"; /** Name of the metadata field for name of the conversation */ public final static String META_NAME_FIELD = "name"; /** Name of the owning user field in all indices */ public final static String USER_FIELD = "user"; + /** Name of the application that created this conversation */ + public final static String APPLICATION_TYPE_FIELD = "application_type"; /** Mappings for the conversational metadata index */ public final static String META_MAPPING = "{\n" + " \"_meta\": {\n" @@ -41,12 +45,18 @@ public class ConversationalIndexConstants { + " \"properties\": {\n" + " \"" + META_NAME_FIELD - + "\": {\"type\": \"keyword\"},\n" + + "\": {\"type\": \"text\"},\n" + " \"" - + META_CREATED_FIELD + + META_CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + META_UPDATED_TIME_FIELD + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + " \"" + USER_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + APPLICATION_TYPE_FIELD + "\": {\"type\": \"keyword\"}\n" + " }\n" + "}"; @@ -69,6 +79,10 @@ public class ConversationalIndexConstants { public final static String INTERACTIONS_ADDITIONAL_INFO_FIELD = "additional_info"; /** Name of the interaction field for the timestamp */ public final static String INTERACTIONS_CREATE_TIME_FIELD = "create_time"; + /** Name of the interaction id */ + public final static String PARENT_INTERACTIONS_ID_FIELD = "parent_interaction_id"; + /** The trace number of an interaction */ + public final static String INTERACTIONS_TRACE_NUMBER_FIELD = "trace_number"; /** Mappings for the interactions index */ public final static String INTERACTIONS_MAPPINGS = "{\n" + " \"_meta\": {\n" @@ -95,7 +109,13 @@ public class ConversationalIndexConstants { + "\": {\"type\": \"keyword\"},\n" + " \"" + INTERACTIONS_ADDITIONAL_INFO_FIELD - + "\": {\"type\": \"text\"}\n" + + "\": {\"type\": \"flat_object\"},\n" + + " \"" + + PARENT_INTERACTIONS_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_TRACE_NUMBER_FIELD + + "\": {\"type\": \"long\"}\n" + " }\n" + "}"; diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java index 9b6ec636bd..8e06569672 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.HashMap; import java.util.Map; import org.opensearch.core.common.io.stream.StreamInput; @@ -54,7 +55,25 @@ public class Interaction implements Writeable, ToXContentObject { @Getter private String origin; @Getter - private String additionalInfo; + private Map additionalInfo; + @Getter + private String parentInteractionId; + @Getter + private Integer traceNum; + + @Builder(toBuilder = true) + public Interaction(String id, Instant createTime, String conversationId, String input, String promptTemplate, String response, String origin, Map additionalInfo) { + this.id = id; + this.createTime = createTime; + this.conversationId = conversationId; + this.input = input; + this.promptTemplate = promptTemplate; + this.response = response; + this.origin = origin; + this.additionalInfo = additionalInfo; + this.parentInteractionId = null; + this.traceNum = null; + } /** * Creates an Interaction object from a map of fields in the OS index @@ -69,8 +88,10 @@ public static Interaction fromMap(String id, Map fields) { String promptTemplate = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD); String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD); String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD); - String additionalInfo = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD); - return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo); + Map additionalInfo = (Map) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD); + String parentInteractionId = (String) fields.getOrDefault(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, null); + Integer traceNum = (Integer) fields.getOrDefault(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, null); + return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo, parentInteractionId, traceNum); } /** @@ -97,8 +118,13 @@ public static Interaction fromStream(StreamInput in) throws IOException { String promptTemplate = in.readString(); String response = in.readString(); String origin = in.readString(); - String additionalInfo = in.readOptionalString(); - return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo); + Map additionalInfo = new HashMap<>(); + if (in.readBoolean()) { + additionalInfo = in.readMap(s -> s.readString(), s -> s.readString()); + } + String parentInteractionId = in.readOptionalString(); + Integer traceNum = in.readOptionalInt(); + return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo, parentInteractionId, traceNum); } @@ -111,7 +137,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(promptTemplate); out.writeString(response); out.writeString(origin); - out.writeOptionalString(additionalInfo); + if (additionalInfo != null) { + out.writeBoolean(true); + out.writeMap(additionalInfo, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(parentInteractionId); + out.writeOptionalInt(traceNum); } @Override @@ -127,6 +160,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para if(additionalInfo != null) { builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); } + if (parentInteractionId != null) { + builder.field(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentInteractionId); + } + if (traceNum != null) { + builder.field(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNum); + } builder.endObject(); return builder; } @@ -143,7 +182,12 @@ public boolean equals(Object other) { ((Interaction) other).response.equals(this.response) && ((Interaction) other).origin.equals(this.origin) && ( (((Interaction) other).additionalInfo == null && this.additionalInfo == null) || - ((Interaction) other).additionalInfo.equals(this.additionalInfo)) + ((Interaction) other).additionalInfo.equals(this.additionalInfo)) && + ( (((Interaction) other).parentInteractionId == null && this.parentInteractionId == null) || + ((Interaction) other).parentInteractionId.equals(this.parentInteractionId)) && + ( (((Interaction) other).traceNum == null && this.traceNum == null) || + ((Interaction) other).traceNum.equals(this.traceNum)) + ); } @@ -158,8 +202,9 @@ public String toString() { + ",promt_template=" + promptTemplate + ",response=" + response + ",additional_info=" + additionalInfo + + ",parentInteractionId=" + parentInteractionId + + ",traceNum=" + traceNum + "}"; } - } \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java new file mode 100644 index 0000000000..febb29fbf1 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.conversation; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchHit; + +import java.io.IOException; +import java.time.Instant; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +public class ConversationMetaTests { + + ConversationMeta conversationMeta; + Instant time; + + @Before + public void setUp() { + time = Instant.now(); + conversationMeta = new ConversationMeta("test_id", time, time, "test_name", "admin"); + } + + @Test + public void test_fromSearchHit() throws IOException { + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); + content.startObject(); + content.field(ConversationalIndexConstants.META_CREATED_TIME_FIELD, time); + content.field(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, time); + content.field(ConversationalIndexConstants.META_NAME_FIELD, "meta name"); + content.field(ConversationalIndexConstants.USER_FIELD, "admin"); + content.endObject(); + + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, "cId", null, null).sourceRef(BytesReference.bytes(content)); + + ConversationMeta conversationMeta = ConversationMeta.fromSearchHit(hits[0]); + assertEquals(conversationMeta.getId(), "cId"); + assertEquals(conversationMeta.getName(), "meta name"); + assertEquals(conversationMeta.getUser(), "admin"); + } + + @Test + public void test_fromMap() { + Map params = Map + .of( + ConversationalIndexConstants.META_CREATED_TIME_FIELD, + time.toString(), + ConversationalIndexConstants.META_UPDATED_TIME_FIELD, + time.toString(), + ConversationalIndexConstants.META_NAME_FIELD, + "meta name", + ConversationalIndexConstants.USER_FIELD, + "admin" + ); + ConversationMeta conversationMeta = ConversationMeta.fromMap("test-conversation-meta", params); + assertEquals(conversationMeta.getId(), "test-conversation-meta"); + assertEquals(conversationMeta.getName(), "meta name"); + assertEquals(conversationMeta.getUser(), "admin"); + } + + @Test + public void test_fromStream() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + conversationMeta.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + ConversationMeta meta = ConversationMeta.fromStream(streamInput); + assertEquals(meta.getId(), conversationMeta.getId()); + assertEquals(meta.getName(), conversationMeta.getName()); + assertEquals(meta.getUser(), conversationMeta.getUser()); + } + + @Test + public void test_ToXContent() throws IOException { + ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin"); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + conversationMeta.toXContent(builder, EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + assertEquals(content, "{\"conversation_id\":\"test_id\",\"create_time\":\"1970-01-01T00:00:00.123Z\",\"updated_time\":\"1970-01-01T00:00:00.123Z\",\"name\":\"test meta\",\"user\":\"admin\"}"); + } + + @Test + public void test_toString() { + ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin"); + assertEquals("{id=test_id, name=test meta, created=1970-01-01T00:00:00.123Z, updated=1970-01-01T00:00:00.123Z, user=admin}", conversationMeta.toString()); + } + + @Test + public void test_equal() { + ConversationMeta meta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin"); + assertEquals(meta.equals(conversationMeta), false); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java new file mode 100644 index 0000000000..c704547050 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java @@ -0,0 +1,190 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.conversation; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchHit; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +public class InteractionTests { + + Interaction interaction; + Instant time; + + @Before + public void setUp() { + time = Instant.ofEpochMilli(123); + interaction = Interaction.builder() + .id("test-interaction-id") + .createTime(time) + .conversationId("conversation-id") + .input("sample inputs") + .promptTemplate("some prompt template") + .response("sample responses") + .origin("amazon bedrock") + .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .parentInteractionId("parent id") + .traceNum(1) + .build(); + } + + @Test + public void test_fromMap() { + Map params = Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + time.toString(), + ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, + "conversation-id", + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "sample inputs", + ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, + "some prompt template", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "sample responses", + ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, + "amazon bedrock", + ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, + Collections.singletonMap("suggestion", "new suggestion"), + ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, + "parent id", + ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, + 1 + ); + Interaction interaction = Interaction.fromMap("test-interaction-id", params); + assertEquals(interaction.getId(), "test-interaction-id"); + assertEquals(interaction.getCreateTime(), time); + assertEquals(interaction.getInput(), "sample inputs"); + assertEquals(interaction.getPromptTemplate(), "some prompt template"); + assertEquals(interaction.getResponse(), "sample responses"); + assertEquals(interaction.getOrigin(), "amazon bedrock"); + assertEquals(interaction.getAdditionalInfo(), Collections.singletonMap("suggestion", "new suggestion")); + assertEquals(interaction.getParentInteractionId(), "parent id"); + assertEquals(interaction.getTraceNum().toString(), "1"); + } + + @Test + public void test_fromSearchHit() throws IOException { + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); + content.startObject(); + content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, time); + content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs"); + content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id"); + content.endObject(); + + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content)); + + Interaction interaction = Interaction.fromSearchHit(hits[0]); + assertEquals(interaction.getId(), "iId"); + assertEquals(interaction.getCreateTime(), time); + assertEquals(interaction.getInput(), "sample inputs"); + } + + @Test + public void test_fromStream() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + interaction.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + Interaction interaction1 = Interaction.fromStream(streamInput); + assertEquals(interaction1.getId(), interaction.getId()); + assertEquals(interaction1.getParentInteractionId(), interaction.getParentInteractionId()); + assertEquals(interaction1.getResponse(), interaction.getResponse()); + assertEquals(interaction1.getOrigin(), interaction.getOrigin()); + assertEquals(interaction1.getPromptTemplate(), interaction.getPromptTemplate()); + assertEquals(interaction1.getAdditionalInfo(), interaction.getAdditionalInfo()); + assertEquals(interaction1.getTraceNum(), interaction.getTraceNum()); + assertEquals(interaction1.getConversationId(), interaction.getConversationId()); + } + + @Test + public void test_ToXContent() throws IOException { + Interaction interaction = Interaction.builder() + .conversationId("conversation id") + .origin("amazon bedrock") + .parentInteractionId("parant id") + .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .traceNum(1) + .build(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + interaction.toXContent(builder, EMPTY_PARAMS); + String interactionContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"conversation_id\":\"conversation id\",\"interaction_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_interaction_id\":\"parant id\",\"trace_number\":1}", interactionContent); + } + + @Test + public void test_not_equal() { + Interaction interaction1 = Interaction.builder() + .id("id") + .conversationId("conversation id") + .origin("amazon bedrock") + .parentInteractionId("parent id") + .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .traceNum(1) + .build(); + assertEquals(interaction.equals(interaction1), false); + } + + @Test + public void test_Equal() { + Interaction interaction1 = Interaction.builder() + .id("test-interaction-id") + .createTime(time) + .conversationId("conversation-id") + .input("sample inputs") + .promptTemplate("some prompt template") + .response("sample responses") + .origin("amazon bedrock") + .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .parentInteractionId("parent id") + .traceNum(1) + .build(); + assertEquals(interaction.equals(interaction1), true); + } + + @Test + public void test_toString() { + Interaction interaction1 = Interaction.builder() + .id("id") + .conversationId("conversation id") + .origin("amazon bedrock") + .parentInteractionId("parent id") + .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .traceNum(1) + .build(); + assertEquals("Interaction{id=id,cid=conversation id,create_time=null,origin=amazon bedrock,input=null,promt_template=null,response=null,additional_info={suggestion=new suggestion},parentInteractionId=parent id,traceNum=1}", interaction1.toString()); + } + + @Test + public void test_ParentInteraction() { + Interaction parentInteraction = Interaction.builder() + .id("test-interaction-id") + .createTime(time) + .conversationId("conversation-id") + .input("sample inputs") + .promptTemplate("some prompt template") + .response("sample responses") + .origin("amazon bedrock") + .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .build(); + assertEquals("Interaction{id=test-interaction-id,cid=conversation-id,create_time=1970-01-01T00:00:00.123Z,origin=amazon bedrock,input=sample inputs,promt_template=some prompt template,response=sample responses,additional_info={suggestion=new suggestion},parentInteractionId=null,traceNum=null}", parentInteraction.toString()); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java index 42cece3f2e..0a439fe7e0 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java @@ -18,9 +18,11 @@ package org.opensearch.ml.memory; import java.util.List; +import java.util.Map; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationMeta; @@ -58,6 +60,14 @@ public interface ConversationalMemoryHandler { */ public ActionFuture createConversation(String name); + /** + * Create a new conversation + * @param name the name of the new conversation + * @param applicationType the application that creates this conversation + * @param listener listener to wait for this op to finish, gets unique id of new conversation + */ + public void createConversation(String name, String applicationType, ActionListener listener); + /** * Adds an interaction to the conversation indicated, updating the conversational metadata * @param conversationId the conversation to add the interaction to @@ -74,7 +84,7 @@ public void createInteraction( String promptTemplate, String response, String origin, - String additionalInfo, + Map additionalInfo, ActionListener listener ); @@ -94,7 +104,31 @@ public ActionFuture createInteraction( String promptTemplate, String response, String origin, - String additionalInfo + Map additionalInfo + ); + + /** + * Adds an interaction to the conversation indicated, updating the conversational metadata + * @param conversationId the conversation to add the interaction to + * @param input the human input for the interaction + * @param promptTemplate the prompt template used for this interaction + * @param response the Gen AI response for this interaction + * @param origin the name of the GenAI agent in this interaction + * @param additionalInfo additional information used in constructing the LLM prompt + * @param interactionId the parent interactionId of this interaction + * @param traceNumber the trace number for a parent interaction + * @param listener gets the ID of the new interaction + */ + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + Map additionalInfo, + ActionListener listener, + String interactionId, + Integer traceNumber ); /** @@ -120,6 +154,15 @@ public ActionFuture createInteraction( */ public void getInteractions(String conversationId, int from, int maxResults, ActionListener> listener); + /** + * Get the traces associate with this interaction, sorted by recency + * @param interactionId the interaction whose traces to get + * @param from where to start listing from + * @param maxResults how many traces to get + * @param listener gets the list of traces in this conversation, sorted by recency + */ + public void getTraces(String interactionId, int from, int maxResults, ActionListener> listener); + /** * Get the interactions associate with this conversation, sorted by recency * @param conversationId the conversation whose interactions to get @@ -203,6 +246,13 @@ public ActionFuture createInteraction( */ public ActionFuture searchInteractions(String conversationId, SearchRequest request); + /** + * Update a conversation + * @param updateContent update content for the conversations index + * @param listener receives the update response + */ + public void updateConversation(String conversationId, Map updateContent, ActionListener listener); + /** * Get a single ConversationMeta object * @param conversationId id of the conversation to get diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java index e0a03f13eb..c65c1b581b 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.memory.action.conversation; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD; + import java.io.IOException; import java.util.Map; @@ -35,6 +37,8 @@ public class CreateConversationRequest extends ActionRequest { @Getter private String name = null; + @Getter + private String applicationType = null; /** * Constructor @@ -44,6 +48,7 @@ public class CreateConversationRequest extends ActionRequest { public CreateConversationRequest(StreamInput in) throws IOException { super(in); this.name = in.readOptionalString(); + this.applicationType = in.readOptionalString(); } /** @@ -55,6 +60,16 @@ public CreateConversationRequest(String name) { this.name = name; } + /** + * Constructor + * @param name name of the conversation + */ + public CreateConversationRequest(String name, String applicationType) { + super(); + this.name = name; + this.applicationType = applicationType; + } + /** * Constructor * name will be null @@ -65,6 +80,7 @@ public CreateConversationRequest() {} public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeOptionalString(name); + out.writeOptionalString(applicationType); } @Override @@ -86,7 +102,10 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest) } Map body = restRequest.contentParser().mapStrings(); if (body.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)) { - return new CreateConversationRequest(body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)); + return new CreateConversationRequest( + body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD), + body.get(APPLICATION_TYPE_FIELD) + ); } else { return new CreateConversationRequest(); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java index f6856b7c66..c9c26c6e20 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java @@ -82,6 +82,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis return; } String name = request.getName(); + String applicationType = request.getApplicationType(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); ActionListener al = ActionListener.wrap(r -> { internalListener.onResponse(new CreateConversationResponse(r)); }, e -> { @@ -92,7 +93,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis if (name == null) { cmHandler.createConversation(al); } else { - cmHandler.createConversation(name, al); + cmHandler.createConversation(name, applicationType, al); } } catch (Exception e) { log.error("Failed to create new conversation with name " + request.getName(), e); diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java index 52344b3792..5f9f4a8128 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java @@ -18,14 +18,17 @@ package org.opensearch.ml.memory.action.conversation; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.rest.RestRequest; @@ -48,7 +51,27 @@ public class CreateInteractionRequest extends ActionRequest { @Getter private String origin; @Getter - private String additionalInfo; + private Map additionalInfo; + @Getter + private String parentIid; + @Getter + private Integer traceNumber; + + public CreateInteractionRequest( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + Map additionalInfo + ) { + this.conversationId = conversationId; + this.input = input; + this.promptTemplate = promptTemplate; + this.response = response; + this.origin = origin; + this.additionalInfo = additionalInfo; + } /** * Constructor @@ -62,7 +85,11 @@ public CreateInteractionRequest(StreamInput in) throws IOException { this.promptTemplate = in.readString(); this.response = in.readString(); this.origin = in.readOptionalString(); - this.additionalInfo = in.readOptionalString(); + if (in.readBoolean()) { + this.additionalInfo = in.readMap(s -> s.readString(), s -> s.readString()); + } + this.parentIid = in.readOptionalString(); + this.traceNumber = in.readOptionalInt(); } @Override @@ -73,7 +100,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(promptTemplate); out.writeString(response); out.writeOptionalString(origin); - out.writeOptionalString(additionalInfo); + if (additionalInfo != null) { + out.writeBoolean(true); + out.writeMap(additionalInfo, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(parentIid); + out.writeOptionalInt(traceNumber); } @Override @@ -92,14 +126,52 @@ public ActionRequestValidationException validate() { * @throws IOException if something goes wrong reading from request */ public static CreateInteractionRequest fromRestRequest(RestRequest request) throws IOException { - Map body = request.contentParser().mapStrings(); String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD); - String inp = body.get(ActionConstants.INPUT_FIELD); - String prmpt = body.get(ActionConstants.PROMPT_TEMPLATE_FIELD); - String rsp = body.get(ActionConstants.AI_RESPONSE_FIELD); - String ogn = body.get(ActionConstants.RESPONSE_ORIGIN_FIELD); - String addinf = body.get(ActionConstants.ADDITIONAL_INFO_FIELD); - return new CreateInteractionRequest(cid, inp, prmpt, rsp, ogn, addinf); + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + String input = null; + String prompt = null; + String response = null; + String origin = null; + Map addinf = new HashMap<>(); + String parintid = null; + Integer tracenum = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case ActionConstants.INPUT_FIELD: + input = parser.text(); + break; + case ActionConstants.PROMPT_TEMPLATE_FIELD: + prompt = parser.text(); + break; + case ActionConstants.AI_RESPONSE_FIELD: + response = parser.text(); + break; + case ActionConstants.RESPONSE_ORIGIN_FIELD: + origin = parser.text(); + break; + case ActionConstants.ADDITIONAL_INFO_FIELD: + addinf = parser.mapStrings(); + break; + case ActionConstants.PARENT_INTERACTION_ID_FIELD: + parintid = parser.text(); + break; + case ActionConstants.TRACE_NUMBER_FIELD: + tracenum = parser.intValue(false); + break; + default: + parser.skipChildren(); + break; + } + } + + return new CreateInteractionRequest(cid, input, prompt, response, origin, addinf, parintid, tracenum); } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java index 2273cc32e8..f5910119fa 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.memory.action.conversation; +import java.util.Map; + import org.opensearch.OpenSearchException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -86,14 +88,20 @@ protected void doExecute(Task task, CreateInteractionRequest request, ActionList String rsp = request.getResponse(); String ogn = request.getOrigin(); String prompt = request.getPromptTemplate(); - String additionalInfo = request.getAdditionalInfo(); + Map additionalInfo = request.getAdditionalInfo(); + String parintIid = request.getParentIid(); + Integer traceNumber = request.getTraceNumber(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); ActionListener al = ActionListener .wrap(iid -> { internalListener.onResponse(new CreateInteractionResponse(iid)); }, e -> { internalListener.onFailure(e); }); - cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al); + if (parintIid == null || traceNumber == null) { + cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al); + } else { + cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al, parintIid, traceNumber); + } } catch (Exception e) { log.error("Failed to create interaction for conversation " + cid, e); actionListener.onFailure(e); diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index 47c55ac1e7..87c786c229 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -39,6 +39,8 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.client.Requests; import org.opensearch.cluster.service.ClusterService; @@ -71,7 +73,7 @@ public class ConversationMetaIndex { private Client client; private ClusterService clusterService; - private String userstr() { + private String getUserStrFromThreadContext() { return client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); } @@ -119,21 +121,27 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) /** * Adds a new conversation with the specified name to the index * @param name user-specified name of the conversation to be added + * @param applicationType the application type that creates this conversation * @param listener listener to wait for this to finish */ - public void createConversation(String name, ActionListener listener) { + public void createConversation(String name, String applicationType, ActionListener listener) { initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> { if (indexExists) { - String userstr = userstr(); + String userstr = getUserStrFromThreadContext(); + Instant now = Instant.now(); IndexRequest request = Requests .indexRequest(META_INDEX_NAME) .source( - ConversationalIndexConstants.META_CREATED_FIELD, - Instant.now(), + ConversationalIndexConstants.META_CREATED_TIME_FIELD, + now, + ConversationalIndexConstants.META_UPDATED_TIME_FIELD, + now, ConversationalIndexConstants.META_NAME_FIELD, name, ConversationalIndexConstants.USER_FIELD, - userstr == null ? null : User.parse(userstr).getName() + userstr == null ? null : User.parse(userstr).getName(), + ConversationalIndexConstants.APPLICATION_TYPE_FIELD, + applicationType ); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); @@ -163,7 +171,16 @@ public void createConversation(String name, ActionListener listener) { * @param listener listener to wait for this to finish */ public void createConversation(ActionListener listener) { - createConversation("", listener); + createConversation("", "", listener); + } + + /** + * Adds a new conversation named "" + * @param name user-specified name of the conversation to be added + * @param listener listener to wait for this to finish + */ + public void createConversation(String name, ActionListener listener) { + createConversation(name, "", listener); } /** @@ -175,10 +192,9 @@ public void createConversation(ActionListener listener) { public void getConversations(int from, int maxResults, ActionListener> listener) { if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(List.of()); - return; } SearchRequest request = Requests.searchRequest(META_INDEX_NAME); - String userstr = userstr(); + String userstr = getUserStrFromThreadContext(); QueryBuilder queryBuilder; if (userstr == null) queryBuilder = new MatchAllQueryBuilder(); @@ -186,7 +202,7 @@ public void getConversations(int from, int maxResults, ActionListener> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(searchResponse -> { @@ -228,10 +244,9 @@ public void getConversations(int maxResults, ActionListener listener) { if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(true); - return; } DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId); - String userstr = userstr(); + String userstr = getUserStrFromThreadContext(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); this.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { @@ -272,7 +287,7 @@ public void checkAccess(String conversationId, ActionListener listener) listener.onResponse(true); return; } - String userstr = userstr(); + String userstr = getUserStrFromThreadContext(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId); @@ -317,7 +332,7 @@ public void searchConversations(SearchRequest request, ActionListener listener) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { + listener + .onFailure( + new IndexNotFoundException("cannot update conversation since the conversation index does not exist", META_INDEX_NAME) + ); + return; + } + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + client.update(updateRequest, internalListener); + } catch (Exception e) { + log.error("Failed to update Conversation. Details {}:", e); + listener.onFailure(e); + } + } + /** * Get a single ConversationMeta object * @param conversationId id of the conversation to get @@ -349,7 +386,7 @@ public void getConversation(String conversationId, ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); GetRequest request = Requests.getRequest(META_INDEX_NAME).id(conversationId); diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index bd4eb1e39a..edf4d827d1 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -23,6 +23,7 @@ import java.time.Instant; import java.util.LinkedList; import java.util.List; +import java.util.Map; import org.opensearch.OpenSearchSecurityException; import org.opensearch.OpenSearchWrapperException; @@ -48,12 +49,15 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.SortOrder; import com.google.common.annotations.VisibleForTesting; @@ -74,10 +78,6 @@ public class InteractionsIndex { // How big the steps should be when gathering *ALL* interactions in a conversation private final int resultsAtATime = 300; - private String userstr() { - return client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - } - /** * 'PUT's the index in opensearch if it's not there already * @param listener gets whether the index needed to be initialized. Throws error if it fails to init @@ -130,6 +130,8 @@ public void initInteractionsIndexIfAbsent(ActionListener listener) { * @param origin the origin of the response for this interaction * @param additionalInfo additional information used for constructing the LLM prompt * @param timestamp when this interaction happened + * @param parintid the parent interactionId of this interaction + * @param traceNumber the trace number for a parent interaction * @param listener gets the id of the newly created interaction record */ public void createInteraction( @@ -138,12 +140,17 @@ public void createInteraction( String promptTemplate, String response, String origin, - String additionalInfo, + Map additionalInfo, Instant timestamp, - ActionListener listener + ActionListener listener, + String parintid, + Integer traceNumber ) { initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> { - String userstr = userstr(); + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); if (indexExists) { this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { @@ -164,7 +171,11 @@ public void createInteraction( ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo, ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - timestamp + timestamp, + ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, + parintid, + ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, + traceNumber ); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); @@ -189,6 +200,30 @@ public void createInteraction( }, e -> { listener.onFailure(e); })); } + /** + * Add an interaction to this index. Return the ID of the newly created interaction + * @param conversationId The id of the conversation this interaction belongs to + * @param input the user (human) input into this interaction + * @param promptTemplate the prompt template used for this interaction + * @param response the GenAI response for this interaction + * @param origin the origin of the response for this interaction + * @param additionalInfo additional information used for constructing the LLM prompt + * @param timestamp when this interaction happened + * @param listener gets the id of the newly created interaction record + */ + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + Map additionalInfo, + Instant timestamp, + ActionListener listener + ) { + createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, timestamp, listener, null, null); + } + /** * Add an interaction to this index, timestamped now. Return the id of the newly created interaction * @param conversationId The id of the converation this interaction belongs to @@ -205,10 +240,10 @@ public void createInteraction( String promptTemplate, String response, String origin, - String additionalInfo, + Map additionalInfo, ActionListener listener ) { - createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, Instant.now(), listener); + createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, Instant.now(), listener, null, null); } /** @@ -241,10 +276,26 @@ public void getInteractions(String conversationId, int from, int maxResults, Act @VisibleForTesting void innerGetInteractions(String conversationId, int from, int maxResults, ActionListener> listener) { SearchRequest request = Requests.searchRequest(INTERACTIONS_INDEX_NAME); - TermQueryBuilder builder = new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); - request.source().query(builder); + + // Build the query + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + + // Add the ExistsQueryBuilder for checking null values + ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD); + boolQueryBuilder.mustNot(existsQueryBuilder); + + // Add the TermQueryBuilder for another field + TermQueryBuilder termQueryBuilder = QueryBuilders + .termQuery(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); + boolQueryBuilder.must(termQueryBuilder); + + // Set the query to the search source + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(boolQueryBuilder); + + request.source(searchSourceBuilder); request.source().from(from).size(maxResults); - request.source().sort(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, SortOrder.DESC); + request.source().sort(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, SortOrder.ASC); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(response -> { @@ -265,6 +316,51 @@ void innerGetInteractions(String conversationId, int from, int maxResults, Actio } } + /** + * Gets a list of interactions belonging to a conversation + * @param interactionId the interaction to read from + * @param from where to start in the reading + * @param maxResults how many interactions to return + * @param listener gets the list, sorted by recency, of interactions + */ + public void getTraces(String interactionId, int from, int maxResults, ActionListener> listener) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { + listener.onResponse(List.of()); + return; + } + SearchRequest request = Requests.searchRequest(INTERACTIONS_INDEX_NAME); + // Build the query + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + + // Add the ExistsQueryBuilder for checking null values + ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD); + boolQueryBuilder.must(existsQueryBuilder); + + // Add the TermQueryBuilder for another field + TermQueryBuilder termQueryBuilder = QueryBuilders + .termQuery(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, interactionId); + boolQueryBuilder.must(termQueryBuilder); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(boolQueryBuilder); + + request.source(searchSourceBuilder); + request.source().from(from).size(maxResults); + request.source().sort(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, SortOrder.ASC); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + ActionListener al = ActionListener.wrap(response -> { + List result = new LinkedList(); + for (SearchHit hit : response.getHits()) { + result.add(Interaction.fromSearchHit(hit)); + } + internalListener.onResponse(result); + }, e -> { internalListener.onFailure(e); }); + client.search(request, al); + } catch (Exception e) { + listener.onFailure(e); + } + } + /** * Gets all of the interactions in a conversation, regardless of conversation size * @param conversationId conversation to get all interactions of @@ -321,7 +417,7 @@ public void deleteConversation(String conversationId, ActionListener li listener.onResponse(true); return; } - String userstr = userstr(); + String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); @@ -381,7 +477,10 @@ public void searchInteractions(String conversationId, SearchRequest request, Act listener.onFailure(e); } } else { - String userstr = userstr(); + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } @@ -431,7 +530,10 @@ public void getInteraction(String conversationId, String interactionId, ActionLi listener.onFailure(e); } } else { - String userstr = userstr(); + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index c1997be829..74ba94c88c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -19,16 +19,20 @@ import java.time.Instant; import java.util.List; +import java.util.Map; import org.opensearch.action.StepListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.conversation.Interaction.InteractionBuilder; import org.opensearch.ml.memory.ConversationalMemoryHandler; @@ -89,6 +93,16 @@ public void createConversation(String name, ActionListener listener) { conversationMetaIndex.createConversation(name, listener); } + /** + * Create a new conversation + * @param name the name of the new conversation + * @param applicationType the application that creates this conversation + * @param listener listener to wait for this op to finish, gets unique id of new conversation + */ + public void createConversation(String name, String applicationType, ActionListener listener) { + conversationMetaIndex.createConversation(name, applicationType, listener); + } + /** * Create a new conversation * @param name the name of the new conversation @@ -116,13 +130,52 @@ public void createInteraction( String promptTemplate, String response, String origin, - String additionalInfo, + Map additionalInfo, ActionListener listener ) { Instant time = Instant.now(); interactionsIndex.createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, time, listener); } + /** + * Adds an interaction to the conversation indicated, updating the conversational metadata + * @param conversationId the conversation to add the interaction to + * @param input the human input for the interaction + * @param promptTemplate the prompt template used for this interaction + * @param response the Gen AI response for this interaction + * @param origin the name of the GenAI agent in this interaction + * @param additionalInfo additional information used in constructing the LLM prompt + * @param interactionId the parent interactionId of this interaction + * @param traceNumber the trace number for a parent interaction + * @param listener gets the ID of the new interaction + */ + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + Map additionalInfo, + ActionListener listener, + String interactionId, + Integer traceNumber + ) { + Instant time = Instant.now(); + interactionsIndex + .createInteraction( + conversationId, + input, + promptTemplate, + response, + origin, + additionalInfo, + time, + listener, + interactionId, + traceNumber + ); + } + /** * Adds an interaction to the conversation indicated, updating the conversational metadata * @param conversationId the conversation to add the interaction to @@ -139,7 +192,7 @@ public ActionFuture createInteraction( String promptTemplate, String response, String origin, - String additionalInfo + Map additionalInfo ) { PlainActionFuture fut = PlainActionFuture.newFuture(); createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, fut); @@ -330,6 +383,20 @@ public ActionFuture searchInteractions(String conversationId, Se return fut; } + public void getTraces(String interactionId, int from, int maxResults, ActionListener> listener) { + interactionsIndex.getTraces(interactionId, from, maxResults, listener); + } + + public void updateConversation(String conversationId, Map updateContent, ActionListener listener) { + UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId); + updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now()); + + updateRequest.doc(updateContent); + updateRequest.docAsUpsert(true); + + conversationMetaIndex.updateConversation(updateRequest, listener); + } + /** * Get a single ConversationMeta object * @param conversationId id of the conversation to get diff --git a/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java b/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java index 6ee0d4cc31..5449f91e18 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java @@ -17,6 +17,7 @@ */ package org.opensearch.ml.memory; +import java.util.Collections; import java.util.List; import java.util.Stack; import java.util.concurrent.CountDownLatch; @@ -110,7 +111,16 @@ public void testCanAddNewInteractionsToConversation() { StepListener iid1Listener = new StepListener<>(); cidListener.whenComplete(cid -> { - cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid1Listener); + cmHandler + .createInteraction( + cid, + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid1Listener + ); }, e -> { cdl.countDown(); assert (false); @@ -118,7 +128,16 @@ public void testCanAddNewInteractionsToConversation() { StepListener iid2Listener = new StepListener<>(); iid1Listener.whenComplete(iid -> { - cmHandler.createInteraction(cidListener.result(), "test input1", "pt", "test response", "test origin", "meta", iid2Listener); + cmHandler + .createInteraction( + cidListener.result(), + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid2Listener + ); }, e -> { cdl.countDown(); assert (false); @@ -144,7 +163,16 @@ public void testCanGetInteractionsBackOut() { StepListener iid1Listener = new StepListener<>(); cidListener.whenComplete(cid -> { - cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid1Listener); + cmHandler + .createInteraction( + cid, + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid1Listener + ); }, e -> { cdl.countDown(); assert (false); @@ -152,7 +180,16 @@ public void testCanGetInteractionsBackOut() { StepListener iid2Listener = new StepListener<>(); iid1Listener.whenComplete(iid -> { - cmHandler.createInteraction(cidListener.result(), "test input1", "pt", "test response", "test origin", "meta", iid2Listener); + cmHandler + .createInteraction( + cidListener.result(), + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid2Listener + ); }, e -> { cdl.countDown(); assert (false); @@ -170,8 +207,8 @@ public void testCanGetInteractionsBackOut() { String id2 = iid2Listener.result(); String cid = cidListener.result(); assert (interactions.size() == 2); - assert (interactions.get(0).getId().equals(id2)); - assert (interactions.get(1).getId().equals(id1)); + assert (interactions.get(0).getId().equals(id1)); + assert (interactions.get(1).getId().equals(id2)); assert (conversations.size() == 1); assert (conversations.get(0).getId().equals(cid)); }, e -> { assert (false); }), cdl); @@ -195,24 +232,38 @@ public void testCanDeleteConversations() { cmHandler.createConversation("test", cid1); StepListener iid1 = new StepListener<>(); - cid1 - .whenComplete( - cid -> { cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid1); }, - e -> { - cdl.countDown(); - assert (false); - } - ); + cid1.whenComplete(cid -> { + cmHandler + .createInteraction( + cid, + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid1 + ); + }, e -> { + cdl.countDown(); + assert (false); + }); StepListener iid2 = new StepListener<>(); - iid1 - .whenComplete( - iid -> { cmHandler.createInteraction(cid1.result(), "test input1", "pt", "test response", "test origin", "meta", iid2); }, - e -> { - cdl.countDown(); - assert (false); - } - ); + iid1.whenComplete(iid -> { + cmHandler + .createInteraction( + cid1.result(), + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid2 + ); + }, e -> { + cdl.countDown(); + assert (false); + }); StepListener cid2 = new StepListener<>(); iid2.whenComplete(iid -> { cmHandler.createConversation(cid2); }, e -> { @@ -221,14 +272,21 @@ public void testCanDeleteConversations() { }); StepListener iid3 = new StepListener<>(); - cid2 - .whenComplete( - cid -> { cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid3); }, - e -> { - cdl.countDown(); - assert (false); - } - ); + cid2.whenComplete(cid -> { + cmHandler + .createInteraction( + cid, + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid3 + ); + }, e -> { + cdl.countDown(); + assert (false); + }); StepListener del = new StepListener<>(); iid3.whenComplete(iid -> { cmHandler.deleteConversation(cid1.result(), del); }, e -> { @@ -328,59 +386,102 @@ public void testDifferentUsers_DifferentConversations() { cid1.whenComplete(cid -> { cmHandler.createConversation("conversation2", cid2); }, onFail); - cid2 - .whenComplete( - cid -> { - cmHandler.createInteraction(cid1.result(), "test input1", "pt", "test response", "test origin", "meta", iid1); - }, - onFail - ); + cid2.whenComplete(cid -> { + cmHandler + .createInteraction( + cid1.result(), + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid1 + ); + }, onFail); - iid1 - .whenComplete( - iid -> { - cmHandler.createInteraction(cid1.result(), "test input2", "pt", "test response", "test origin", "meta", iid2); - }, - onFail - ); + iid1.whenComplete(iid -> { + cmHandler + .createInteraction( + cid1.result(), + "test input2", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid2 + ); + }, onFail); - iid2 - .whenComplete( - iid -> { - cmHandler.createInteraction(cid2.result(), "test input3", "pt", "test response", "test origin", "meta", iid3); - }, - onFail - ); + iid2.whenComplete(iid -> { + cmHandler + .createInteraction( + cid2.result(), + "test input3", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid3 + ); + }, onFail); iid3.whenComplete(iid -> { contextStack.push(setUser(user2)); cmHandler.createConversation("conversation3", cid3); }, onFail); - cid3 - .whenComplete( - cid -> { - cmHandler.createInteraction(cid3.result(), "test input4", "pt", "test response", "test origin", "meta", iid4); - }, - onFail - ); + cid3.whenComplete(cid -> { + cmHandler + .createInteraction( + cid3.result(), + "test input4", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid4 + ); + }, onFail); - iid4 - .whenComplete( - iid -> { - cmHandler.createInteraction(cid3.result(), "test input5", "pt", "test response", "test origin", "meta", iid5); - }, - onFail - ); + iid4.whenComplete(iid -> { + cmHandler + .createInteraction( + cid3.result(), + "test input5", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid5 + ); + }, onFail); iid5.whenComplete(iid -> { - cmHandler.createInteraction(cid1.result(), "test inputf1", "pt", "test response", "test origin", "meta", failiid1); + cmHandler + .createInteraction( + cid1.result(), + "test inputf1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + failiid1 + ); }, onFail); failiid1.whenComplete(shouldHaveFailedAsString, e -> { if (e instanceof OpenSearchSecurityException && e.getMessage().startsWith("User [" + user2 + "] does not have access to conversation ")) { - cmHandler.createInteraction(cid1.result(), "test inputf2", "pt", "test response", "test origin", "meta", failiid2); + cmHandler + .createInteraction( + cid1.result(), + "test inputf2", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + failiid2 + ); } else { onFail.accept(e); } @@ -403,8 +504,8 @@ public void testDifferentUsers_DifferentConversations() { inter3.whenComplete(inters -> { assert (inters.size() == 2); - assert (inters.get(0).getId().equals(iid5.result())); - assert (inters.get(1).getId().equals(iid4.result())); + assert (inters.get(0).getId().equals(iid4.result())); + assert (inters.get(1).getId().equals(iid5.result())); cmHandler.getInteractions(cid2.result(), 0, 10, failInter2); }, onFail); @@ -436,8 +537,8 @@ public void testDifferentUsers_DifferentConversations() { inter1.whenComplete(inters -> { assert (inters.size() == 2); - assert (inters.get(0).getId().equals(iid2.result())); - assert (inters.get(1).getId().equals(iid1.result())); + assert (inters.get(0).getId().equals(iid1.result())); + assert (inters.get(1).getId().equals(iid2.result())); cmHandler.getInteractions(cid2.result(), 0, 10, inter2); }, onFail); @@ -450,7 +551,16 @@ public void testDifferentUsers_DifferentConversations() { failInter3.whenComplete(shouldHaveFailedAsInterList, e -> { if (e instanceof OpenSearchSecurityException && e.getMessage().startsWith("User [" + user1 + "] does not have access to conversation ")) { - cmHandler.createInteraction(cid3.result(), "test inputf3", "pt", "test response", "test origin", "meta", failiid3); + cmHandler + .createInteraction( + cid3.result(), + "test inputf3", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + failiid3 + ); } else { onFail.accept(e); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java index 22b55bb7c2..c1148438c3 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.memory.action.conversation; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD; + import java.io.IOException; import java.util.Map; @@ -85,4 +87,17 @@ public void testNamedRestRequest() throws IOException { assert (request.getName().equals(name)); } + public void testNamedRestRequest_WithAppType() throws IOException { + String name = "test-name"; + String appType = "conversational-search"; + RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withContent( + new BytesArray(gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, name, APPLICATION_TYPE_FIELD, appType))), + MediaTypeRegistry.JSON + ) + .build(); + CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req); + assert (request.getName().equals(name)); + assert (request.getApplicationType().equals(appType)); + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java index 313071dc45..c2e4b16e65 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java @@ -31,6 +31,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -80,6 +81,7 @@ public class CreateConversationTransportActionTests extends OpenSearchTestCase { @Before public void setup() throws IOException { + MockitoAnnotations.openMocks(this); this.threadPool = Mockito.mock(ThreadPool.class); this.client = Mockito.mock(Client.class); this.clusterService = Mockito.mock(ClusterService.class); @@ -107,10 +109,10 @@ public void setup() throws IOException { public void testCreateConversation() { log.info("testing create conversation transport"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse("testID"); return null; - }).when(cmHandler).createConversation(any(), any()); + }).when(cmHandler).createConversation(any(), any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class); verify(actionListener).onResponse(argCaptor.capture()); @@ -133,10 +135,10 @@ public void testCreateConversationWithNullName() { public void testCreateConversationFails_thenFail() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new Exception("Testing Error")); return null; - }).when(cmHandler).createConversation(any(), any()); + }).when(cmHandler).createConversation(any(), any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); @@ -144,7 +146,7 @@ public void testCreateConversationFails_thenFail() { } public void testDoExecuteFails_thenFail() { - doThrow(new RuntimeException("Test doExecute Error")).when(cmHandler).createConversation(any(), any()); + doThrow(new RuntimeException("Test doExecute Error")).when(cmHandler).createConversation(any(), any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java index cf027aef79..fae2984af9 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java @@ -18,6 +18,7 @@ package org.opensearch.ml.memory.action.conversation; import java.io.IOException; +import java.util.Collections; import java.util.Map; import org.junit.Before; @@ -47,7 +48,14 @@ public void setup() { } public void testConstructorsAndStreaming() throws IOException { - CreateInteractionRequest request = new CreateInteractionRequest("cid", "input", "pt", "response", "origin", "metadata"); + CreateInteractionRequest request = new CreateInteractionRequest( + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta") + ); assert (request.validate() == null); assert (request.getConversationId().equals("cid")); assert (request.getInput().equals("input")); @@ -67,14 +75,21 @@ public void testConstructorsAndStreaming() throws IOException { } public void testNullCID_thenFail() { - CreateInteractionRequest request = new CreateInteractionRequest(null, "input", "pt", "response", "origin", "metadata"); + CreateInteractionRequest request = new CreateInteractionRequest( + null, + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta") + ); assert (request.validate() != null); assert (request.validate().validationErrors().size() == 1); assert (request.validate().validationErrors().get(0).equals("Interaction MUST belong to a conversation ID")); } public void testFromRestRequest() throws IOException { - Map params = Map + Map params = Map .of( ActionConstants.INPUT_FIELD, "input", @@ -85,19 +100,57 @@ public void testFromRestRequest() throws IOException { ActionConstants.RESPONSE_ORIGIN_FIELD, "origin", ActionConstants.ADDITIONAL_INFO_FIELD, - "metadata" + Collections.singletonMap("metadata", "some meta") ); + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid")) .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) .build(); CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); + assert (request.validate() == null); assert (request.getConversationId().equals("cid")); assert (request.getInput().equals("input")); assert (request.getPromptTemplate().equals("pt")); assert (request.getResponse().equals("response")); assert (request.getOrigin().equals("origin")); - assert (request.getAdditionalInfo().equals("metadata")); + assert (request.getAdditionalInfo().equals(Collections.singletonMap("metadata", "some meta"))); + } + + public void testFromRestRequest_Trace() throws IOException { + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "pt", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.ADDITIONAL_INFO_FIELD, + Collections.singletonMap("metadata", "some meta"), + ActionConstants.PARENT_INTERACTION_ID_FIELD, + "parentId", + ActionConstants.TRACE_NUMBER_FIELD, + 1 + ); + + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "tid")) + .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) + .build(); + CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); + + assert (request.validate() == null); + assert (request.getConversationId().equals("tid")); + assert (request.getInput().equals("input")); + assert (request.getPromptTemplate().equals("pt")); + assert (request.getResponse().equals("response")); + assert (request.getOrigin().equals("origin")); + assert (request.getAdditionalInfo().equals(Collections.singletonMap("metadata", "some meta"))); + assert (request.getParentIid().equals("parentId")); + assert (request.getTraceNumber().equals(1)); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java index 8321a0b65e..eb8e4672ce 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java @@ -25,6 +25,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Collections; import java.util.Set; import org.junit.Before; @@ -91,7 +92,14 @@ public void setup() throws IOException { this.actionListener = al; this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); - this.request = new CreateInteractionRequest("test-cid", "input", "pt", "response", "origin", "metadata"); + this.request = new CreateInteractionRequest( + "test-cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta") + ); Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); this.threadContext = new ThreadContext(settings); @@ -118,6 +126,29 @@ public void testCreateInteraction() { assert (argCaptor.getValue().getId().equals("testID")); } + public void testCreateInteraction_Trace() { + CreateInteractionRequest createConversationRequest = new CreateInteractionRequest( + "test-cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 1 + ); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(6); + listener.onResponse("testID"); + return null; + }).when(cmHandler).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + action.doExecute(null, createConversationRequest, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateInteractionResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getId().equals("testID")); + } + public void testCreateInteractionFails_thenFail() { log.info("testing create interaction transport"); doAnswer(invocation -> { diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java index 4b8f3a8fed..abb8d04de9 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java @@ -36,7 +36,7 @@ public class GetConversationResponseTests extends OpenSearchTestCase { public void testGetConversationResponseStreaming() throws IOException { - ConversationMeta convo = new ConversationMeta("cid", Instant.now(), "name", null); + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null); GetConversationResponse response = new GetConversationResponse(convo); assert (response.getConversation().equals(convo)); @@ -49,12 +49,16 @@ public void testGetConversationResponseStreaming() throws IOException { } public void testToXContent() throws IOException { - ConversationMeta convo = new ConversationMeta("cid", Instant.now(), "name", null); + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null); GetConversationResponse response = new GetConversationResponse(convo); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = BytesReference.bytes(builder).utf8ToString(); - String expected = "{\"conversation_id\":\"cid\",\"create_time\":\"" + convo.getCreatedTime() + "\",\"name\":\"name\"}"; + String expected = "{\"conversation_id\":\"cid\",\"create_time\":\"" + + convo.getCreatedTime() + + "\",\"updated_time\":\"" + + convo.getUpdatedTime() + + "\",\"name\":\"name\"}"; // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness LevenshteinDistance ld = new LevenshteinDistance(); assert (ld.getDistance(result, expected) > 0.95); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java index 3afcc1dd21..97ff87f63b 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java @@ -105,7 +105,7 @@ public void setup() throws IOException { } public void testGetConversation() { - ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), "name", null); + ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), Instant.now(), "name", null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(result); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java index e6ed013b7a..4d14e6f703 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java @@ -46,9 +46,9 @@ public class GetConversationsResponseTests extends OpenSearchTestCase { public void setup() { conversations = List .of( - new ConversationMeta("0", Instant.now(), "name0", "user0"), - new ConversationMeta("1", Instant.now(), "name1", "user0"), - new ConversationMeta("2", Instant.now(), "name2", "user2") + new ConversationMeta("0", Instant.now(), Instant.now(), "name0", "user0"), + new ConversationMeta("1", Instant.now(), Instant.now(), "name1", "user0"), + new ConversationMeta("2", Instant.now(), Instant.now(), "name2", "user2") ); } @@ -75,6 +75,8 @@ public void testToXContent_MoreTokens() throws IOException { String result = BytesReference.bytes(builder).utf8ToString(); String expected = "{\"conversations\":[{\"conversation_id\":\"0\",\"create_time\":\"" + conversation.getCreatedTime() + + "\"updated_time\":\"" + + conversation.getUpdatedTime() + "\",\"name\":\"name0\",\"user\":\"user0\"}],\"next_token\":2}"; log.info("FINDME"); log.info(result); @@ -93,6 +95,8 @@ public void testToXContent_NoMoreTokens() throws IOException { String result = BytesReference.bytes(builder).utf8ToString(); String expected = "{\"conversations\":[{\"conversation_id\":\"0\",\"create_time\":\"" + conversation.getCreatedTime() + + "\"updated_time\":\"" + + conversation.getUpdatedTime() + "\",\"name\":\"name0\",\"user\":\"user0\"}]}"; log.info("FINDME"); log.info(result); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java index 41c99bdc74..130d39c5cb 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java @@ -112,8 +112,8 @@ public void testGetConversations() { log.info("testing get conversations transport"); List testResult = List .of( - new ConversationMeta("testcid1", Instant.now(), "", null), - new ConversationMeta("testcid2", Instant.now(), "testname", null) + new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null), + new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null) ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(2); @@ -130,9 +130,9 @@ public void testGetConversations() { public void testPagination() { List testResult = List .of( - new ConversationMeta("testcid1", Instant.now(), "", null), - new ConversationMeta("testcid2", Instant.now(), "testname", null), - new ConversationMeta("testcid3", Instant.now(), "testname", null) + new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null), + new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null), + new ConversationMeta("testcid3", Instant.now(), Instant.now(), "testname", null) ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(2); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java index b7cbc1c471..5cd79afc4a 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.Collections; import org.apache.lucene.search.spell.LevenshteinDistance; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -36,7 +37,16 @@ public class GetInteractionResponseTests extends OpenSearchTestCase { public void testConstructorAndStreaming() throws IOException { - Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra"); + Interaction interaction = new Interaction( + "iid", + Instant.now(), + "cid", + "inp", + "pt", + "rsp", + "ogn", + Collections.singletonMap("metadata", "some meta") + ); GetInteractionResponse response = new GetInteractionResponse(interaction); assert (response.getInteraction().equals(interaction)); @@ -49,14 +59,24 @@ public void testConstructorAndStreaming() throws IOException { } public void testToXContent() throws IOException { - Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra"); + Interaction interaction = new Interaction( + "iid", + Instant.now(), + "cid", + "inp", + "pt", + "rsp", + "ogn", + Collections.singletonMap("metadata", "some meta") + ); GetInteractionResponse response = new GetInteractionResponse(interaction); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = BytesReference.bytes(builder).utf8ToString(); + System.out.println(result); String expected = "{\"conversation_id\":\"cid\",\"interaction_id\":\"iid\",\"create_time\":\"" + interaction.getCreateTime() - + "\",\"input\":\"inp\",\"prompt_template\":\"pt\",\"response\":\"rsp\",\"origin\":\"ogn\",\"additional_info\":\"extra\"}"; + + "\",\"input\":\"inp\",\"prompt_template\":\"pt\",\"response\":\"rsp\",\"origin\":\"ogn\",\"additional_info\":{\"metadata\":\"some meta\"}}"; // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness LevenshteinDistance ld = new LevenshteinDistance(); assert (ld.getDistance(result, expected) > 0.95); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java index 6ca8197b54..eca0a9251a 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java @@ -27,6 +27,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.Collections; import java.util.Set; import org.junit.Before; @@ -112,7 +113,7 @@ public void testGetInteraction() { "pt", "test-response", "test-origin", - "metadata" + Collections.singletonMap("metadata", "some meta") ); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java index bbd17b2603..c1fdfbffac 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.Collections; import java.util.List; import org.apache.lucene.search.spell.LevenshteinDistance; @@ -45,9 +46,36 @@ public class GetInteractionsResponseTests extends OpenSearchTestCase { public void setup() { interactions = List .of( - new Interaction("id0", Instant.now(), "cid", "input", "pt", "response", "origin", "metadata"), - new Interaction("id1", Instant.now(), "cid", "input", "pt", "response", "origin", "mteadata"), - new Interaction("id2", Instant.now(), "cid", "input", "pt", "response", "origin", "metadata") + new Interaction( + "id0", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta") + ), + new Interaction( + "id1", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta") + ), + new Interaction( + "id2", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta") + ) ); } @@ -74,7 +102,7 @@ public void testToXContent_MoreTokens() throws IOException { String result = BytesReference.bytes(builder).utf8ToString(); String expected = "{\"interactions\":[{\"conversation_id\":\"cid\",\"interaction_id\":\"id0\",\"create_time\":\"" + interaction.getCreateTime() - + "\",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":\"metadata\"}],\"next_token\":2}"; + + "\",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":{\"metadata\":\"some meta\"}}],\"next_token\":2}"; log.info(result); log.info(expected); // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness @@ -91,7 +119,7 @@ public void testToXContent_NoMoreTokens() throws IOException { String result = BytesReference.bytes(builder).utf8ToString(); String expected = "{\"interactions\":[{\"conversation_id\":\"cid\",\"interaction_id\":\"id0\",\"create_time\":\"" + interaction.getCreateTime() - + "\",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":\"metadata\"}]}"; + + "\",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":{\"metadata\":\"some meta\"}}]}"; log.info(result); log.info(expected); // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java index a7a245b680..7b4b62df15 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java @@ -27,6 +27,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.Collections; import java.util.List; import java.util.Set; @@ -118,7 +119,7 @@ public void testGetInteractions_noMorePages() { "pt", "test-response", "test-origin", - "metadata" + Collections.singletonMap("metadata", "some meta") ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); @@ -145,7 +146,7 @@ public void testGetInteractions_MorePages() { "pt", "test-response", "test-origin", - "metadata" + Collections.singletonMap("metadata", "some meta") ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java index fc605e3fb0..fc8e7d0145 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -26,6 +26,7 @@ import java.util.function.Consumer; import org.junit.Before; +import org.junit.Ignore; import org.opensearch.OpenSearchSecurityException; import org.opensearch.action.LatchedActionListener; import org.opensearch.action.StepListener; @@ -38,7 +39,7 @@ import org.opensearch.common.util.concurrent.ThreadContext.StoredContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; -import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.conversation.ConversationMeta; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.search.builder.SearchSourceBuilder; @@ -435,7 +436,7 @@ public void testCanQueryOverConversations() { convo2.whenComplete(cid -> { SearchRequest request = new SearchRequest(); request.source(new SearchSourceBuilder()); - request.source().query(new TermQueryBuilder(ConversationalIndexConstants.META_NAME_FIELD, "Henry Conversation")); + request.source().query(QueryBuilders.matchQuery(ConversationalIndexConstants.META_NAME_FIELD, "Henry Conversation")); index.searchConversations(request, search); }, e -> { cdl.countDown(); @@ -461,6 +462,7 @@ public void testCanQueryOverConversations() { } } + @Ignore // this IT is flaky, not working as expected public void testCanQueryOverConversationsSecurely() { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { CountDownLatch cdl = new CountDownLatch(1); @@ -492,7 +494,7 @@ public void testCanQueryOverConversationsSecurely() { convo2.whenComplete(cid -> { SearchRequest request = new SearchRequest(); request.source(new SearchSourceBuilder()); - request.source().query(new TermQueryBuilder(ConversationalIndexConstants.META_NAME_FIELD, "Dhrubo Conversation")); + request.source().query(QueryBuilders.matchQuery(ConversationalIndexConstants.META_NAME_FIELD, "Dhrubo Conversation")); index.searchConversations(request, search1); }, onFail); @@ -500,7 +502,7 @@ public void testCanQueryOverConversationsSecurely() { search1.whenComplete(response -> { SearchRequest request = new SearchRequest(); request.source(new SearchSourceBuilder()); - request.source().query(new TermQueryBuilder(ConversationalIndexConstants.META_NAME_FIELD, "Jing Conversation")); + request.source().query(QueryBuilders.matchQuery(ConversationalIndexConstants.META_NAME_FIELD, "Jing Conversation")); index.searchConversations(request, search2); }, onFail); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index 5445fd6213..ef7b048a0a 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -35,6 +35,7 @@ import org.opensearch.OpenSearchWrapperException; import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.DocWriteResponse; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.admin.indices.refresh.RefreshResponse; import org.opensearch.action.delete.DeleteResponse; @@ -42,6 +43,8 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; @@ -52,6 +55,8 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.conversation.ConversationMeta; @@ -628,12 +633,56 @@ public void testGetConversation_RefreshFails_ThenFail() { public void testGetConversation_ClientFails_ThenFail() { doReturn(true).when(metadata).hasIndex(anyString()); - doThrow(new RuntimeException("Clietn Failure")).when(client).admin(); + doThrow(new RuntimeException("Client Failure")).when(client).admin(); @SuppressWarnings("unchecked") ActionListener getListener = mock(ActionListener.class); conversationMetaIndex.getConversation("tester_id", getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getListener, times(1)).onFailure(argCaptor.capture()); - assert (argCaptor.getValue().getMessage().equals("Clietn Failure")); + assert (argCaptor.getValue().getMessage().equals("Client Failure")); + } + + public void testUpdateConversation_NoIndex_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + conversationMetaIndex.updateConversation(new UpdateRequest(), getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor + .getValue() + .getMessage() + .equals( + "no such index [.plugins-ml-conversation-meta] and cannot update conversation since the conversation index does not exist" + )); + } + + public void testUpdateConversation_Success() { + doReturn(true).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + + doAnswer(invocation -> { + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(), any()); + conversationMetaIndex.updateConversation(new UpdateRequest(), getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(getListener, times(1)).onResponse(argCaptor.capture()); + } + + public void testUpdateConversation_ClientFails() { + doReturn(true).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + + doThrow(new RuntimeException("Client Failure")).when(client).update(any(), any()); + conversationMetaIndex.updateConversation(new UpdateRequest(), getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Client Failure")); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java index 0c0791fb23..133c31971a 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java @@ -20,6 +20,7 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -104,7 +105,7 @@ public void testCanAddNewInteraction() { "pt", "test response", "test origin", - "metadata", + Collections.singletonMap("metadata", "some meta"), new LatchedActionListener<>(ActionListener.wrap(id -> { ids[0] = id; }, e -> { @@ -121,7 +122,7 @@ public void testCanAddNewInteraction() { "pt", "test response", "test origin", - "metadata", + Collections.singletonMap("metadata", "some meta"), new LatchedActionListener<>(ActionListener.wrap(id -> { ids[1] = id; }, e -> { @@ -145,7 +146,16 @@ public void testGetInteractions() { final String conversation = "test-conversation"; CountDownLatch cdl = new CountDownLatch(1); StepListener id1Listener = new StepListener<>(); - index.createInteraction(conversation, "test input", "pt", "test response", "test origin", "metadata", id1Listener); + index + .createInteraction( + conversation, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + id1Listener + ); StepListener id2Listener = new StepListener<>(); id1Listener.whenComplete(id -> { @@ -156,7 +166,7 @@ public void testGetInteractions() { "pt", "test response", "test origin", - "metadata", + Collections.singletonMap("metadata", "some meta"), Instant.now().plus(3, ChronoUnit.MINUTES), id2Listener ); @@ -175,8 +185,8 @@ public void testGetInteractions() { LatchedActionListener> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(interactions -> { assert (interactions.size() == 2); - assert (interactions.get(0).getId().equals(id2Listener.result())); - assert (interactions.get(1).getId().equals(id1Listener.result())); + assert (interactions.get(0).getId().equals(id1Listener.result())); + assert (interactions.get(1).getId().equals(id2Listener.result())); }, e -> { log.error(e); assert (false); @@ -194,7 +204,16 @@ public void testGetInteractionPages() { final String conversation = "test-conversation"; CountDownLatch cdl = new CountDownLatch(1); StepListener id1Listener = new StepListener<>(); - index.createInteraction(conversation, "test input", "pt", "test response", "test origin", "metadata", id1Listener); + index + .createInteraction( + conversation, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + id1Listener + ); StepListener id2Listener = new StepListener<>(); id1Listener.whenComplete(id -> { @@ -205,7 +224,7 @@ public void testGetInteractionPages() { "pt", "test response", "test origin", - "metadata", + Collections.singletonMap("metadata", "some meta"), Instant.now().plus(3, ChronoUnit.MINUTES), id2Listener ); @@ -224,7 +243,7 @@ public void testGetInteractionPages() { "pt", "test response", "test origin", - "metadata", + Collections.singletonMap("metadata", "some meta"), Instant.now().plus(4, ChronoUnit.MINUTES), id3Listener ); @@ -255,9 +274,9 @@ public void testGetInteractionPages() { String id3 = id3Listener.result(); assert (interactions2.size() == 1); assert (interactions1.size() == 2); - assert (interactions1.get(0).getId().equals(id3)); + assert (interactions1.get(0).getId().equals(id1)); assert (interactions1.get(1).getId().equals(id2)); - assert (interactions2.get(0).getId().equals(id1)); + assert (interactions2.get(0).getId().equals(id3)); }, e -> { log.error(e); assert (false); @@ -276,40 +295,70 @@ public void testDeleteConversation() { final String conversation2 = "conversation2"; CountDownLatch cdl = new CountDownLatch(1); StepListener iid1 = new StepListener<>(); - index.createInteraction(conversation1, "test input", "pt", "test response", "test origin", "metadata", iid1); + index + .createInteraction( + conversation1, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + iid1 + ); StepListener iid2 = new StepListener<>(); - iid1 - .whenComplete( - r -> { index.createInteraction(conversation1, "test input", "pt", "test response", "test origin", "metadata", iid2); }, - e -> { - cdl.countDown(); - log.error(e); - assert (false); - } - ); + iid1.whenComplete(r -> { + index + .createInteraction( + conversation1, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + iid2 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); StepListener iid3 = new StepListener<>(); - iid2 - .whenComplete( - r -> { index.createInteraction(conversation2, "test input", "pt", "test response", "test origin", "metadata", iid3); }, - e -> { - cdl.countDown(); - log.error(e); - assert (false); - } - ); + iid2.whenComplete(r -> { + index + .createInteraction( + conversation2, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + iid3 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); StepListener iid4 = new StepListener<>(); - iid3 - .whenComplete( - r -> { index.createInteraction(conversation1, "test input", "pt", "test response", "test origin", "metadata", iid4); }, - e -> { - cdl.countDown(); - log.error(e); - assert (false); - } - ); + iid3.whenComplete(r -> { + index + .createInteraction( + conversation1, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + iid4 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); StepListener deleteListener = new StepListener<>(); iid4.whenComplete(r -> { index.deleteConversation(conversation1, deleteListener); }, e -> { @@ -368,7 +417,7 @@ public void testSearchInteractions() { "pt", "response about fish", "origin1", - "lots of information about fish", + Collections.singletonMap("metadata", "lots of information about fish"), iid1 ); @@ -381,7 +430,7 @@ public void testSearchInteractions() { "pt", "response about squash", "origin1", - "lots of information about squash", + Collections.singletonMap("metadata", "lots of information about fish"), iid2 ); }, e -> { @@ -399,7 +448,7 @@ public void testSearchInteractions() { "pt2", "response about fish", "origin1", - "lots of information about fish", + Collections.singletonMap("metadata", "lots of information about fish"), iid3 ); }, e -> { @@ -417,7 +466,7 @@ public void testSearchInteractions() { "pt", "response about france", "origin1", - "lots of information about france", + Collections.singletonMap("metadata", "lots of information about france"), iid4 ); }, e -> { @@ -466,18 +515,34 @@ public void testGetInteractionById() { final String conversation = "test-conversation"; CountDownLatch cdl = new CountDownLatch(1); StepListener iid1 = new StepListener<>(); - index.createInteraction(conversation, "test input", "pt", "test response", "test origin", "metadata", iid1); + index + .createInteraction( + conversation, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + iid1 + ); StepListener iid2 = new StepListener<>(); - iid1 - .whenComplete( - iid -> { index.createInteraction(conversation, "test input2", "pt", "test response", "test origin", "metadata", iid2); }, - e -> { - cdl.countDown(); - log.error(e); - assert false; - } - ); + iid1.whenComplete(iid -> { + index + .createInteraction( + conversation, + "test input2", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + iid2 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert false; + }); StepListener get1 = new StepListener<>(); iid2.whenComplete(iid -> { index.getInteraction(conversation, iid1.result(), get1); }, e -> { diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 2d4184eec3..70743aa9f3 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -30,6 +30,7 @@ import static org.mockito.Mockito.verify; import java.time.Instant; +import java.util.Collections; import java.util.List; import org.junit.Before; @@ -44,6 +45,8 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; @@ -52,12 +55,19 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -250,7 +260,8 @@ public void testCreate_NoIndex_ThenFail() { setupDoesNotMakeIndex(); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("no index to add conversation to")); @@ -268,7 +279,8 @@ public void testCreate_BadRestStatus_ThenFail() { }).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Failed to create interaction")); @@ -284,7 +296,8 @@ public void testCreate_InternalFailure_ThenFail() { }).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Test Failure")); @@ -296,7 +309,8 @@ public void testCreate_ClientFails_ThenFail() { doThrow(new RuntimeException("Test Client Failure")).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Test Client Failure")); @@ -308,7 +322,8 @@ public void testCreate_NoAccessNoUser_ThenFail() { doThrow(new RuntimeException("Test Client Failure")).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor @@ -323,7 +338,8 @@ public void testCreate_NoAccessWithUser_ThenFail() { doThrow(new RuntimeException("Test Client Failure")).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("User [user] does not have access to conversation cid")); @@ -337,7 +353,8 @@ public void testCreate_CreateIndexFails_ThenFail() { }).when(interactionsIndex).initInteractionsIndexIfAbsent(any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Fail in Index Creation")); @@ -413,6 +430,73 @@ public void testGet_NoAccessNoUser_ThenFail() { .equals("User [" + ActionConstants.DEFAULT_USERNAME_FOR_ERRORS + "] does not have access to conversation cid")); } + public void testGetTraces_NoIndex_ThenEmpty() { + doReturn(false).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener> getTracesListener = mock(ActionListener.class); + interactionsIndex.getTraces("cid", 0, 10, getTracesListener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(getTracesListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().size() == 0); + } + + public void testGetTraces() { + doAnswer(invocation -> { + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); + content.startObject(); + content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now()); + content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs"); + content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id"); + content.endObject(); + + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content)); + SearchHits searchHits = new SearchHits(hits, null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections( + searchHits, + InternalAggregations.EMPTY, + null, + false, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + ActionListener al = invocation.getArgument(1); + al.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + doReturn(true).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener> getTracesListener = mock(ActionListener.class); + interactionsIndex.getTraces("cid", 0, 10, getTracesListener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(getTracesListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().size() == 1); + } + + public void testGetTraces_clientFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Client Failure")).when(client).search(any(), any()); + ActionListener> getTracesListener = mock(ActionListener.class); + interactionsIndex.getTraces("cid", 0, 10, getTracesListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getTracesListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Client Failure")); + } + public void testGetAll_BadMaxResults_ThenFail() { @SuppressWarnings("unchecked") ActionListener> getInteractionsListener = mock(ActionListener.class); @@ -425,10 +509,10 @@ public void testGetAll_BadMaxResults_ThenFail() { public void testGetAll_Recursion() { List interactions = List .of( - new Interaction("iid1", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta"), - new Interaction("iid2", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta"), - new Interaction("iid3", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta"), - new Interaction("iid4", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta") + new Interaction("iid1", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta")), + new Interaction("iid2", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta")), + new Interaction("iid3", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta")), + new Interaction("iid4", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta")) ); doAnswer(invocation -> { ActionListener> al = invocation.getArgument(3); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java index c8df948bcb..a979505a52 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -26,15 +26,21 @@ import static org.mockito.Mockito.verify; import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; import java.util.List; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.opensearch.action.DocWriteResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.ml.common.conversation.ConversationMeta; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.conversation.Interaction.InteractionBuilder; @@ -82,10 +88,9 @@ public void testCreateInteraction_Future() { ActionListener al = invocation.getArgument(7); al.onResponse("iid"); return null; - }) - .when(interactionsIndex) - .createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), anyString(), any(), any()); - ActionFuture result = cmHandler.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta"); + }).when(interactionsIndex).createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), any(), any(), any()); + ActionFuture result = cmHandler + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta")); assert (result.actionGet(200).equals("iid")); } @@ -94,9 +99,7 @@ public void testCreateInteraction_FromBuilder_Success() { ActionListener al = invocation.getArgument(7); al.onResponse("iid"); return null; - }) - .when(interactionsIndex) - .createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), anyString(), any(), any()); + }).when(interactionsIndex).createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), any(), any(), any()); InteractionBuilder builder = Interaction .builder() .conversationId("cid") @@ -104,7 +107,7 @@ public void testCreateInteraction_FromBuilder_Success() { .origin("origin") .response("rsp") .promptTemplate("pt") - .additionalInfo("meta"); + .additionalInfo(Collections.singletonMap("meta", "some meta")); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); cmHandler.createInteraction(builder, createInteractionListener); @@ -118,9 +121,7 @@ public void testCreateInteraction_FromBuilder_Future() { ActionListener al = invocation.getArgument(7); al.onResponse("iid"); return null; - }) - .when(interactionsIndex) - .createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), anyString(), any(), any()); + }).when(interactionsIndex).createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), any(), any(), any()); InteractionBuilder builder = Interaction .builder() .origin("ogn") @@ -128,7 +129,7 @@ public void testCreateInteraction_FromBuilder_Future() { .input("inp") .response("rsp") .promptTemplate("pt") - .additionalInfo("meta"); + .additionalInfo(Collections.singletonMap("meta", "some meta")); ActionFuture result = cmHandler.createInteraction(builder); assert (result.actionGet(200).equals("iid")); } @@ -163,6 +164,34 @@ public void testGetConversations_Page_Future() { assert (result.actionGet(200).size() == 0); } + public void testGetTraces() { + doAnswer(invocation -> { + ActionListener> al = invocation.getArgument(3); + al.onResponse(List.of()); + return null; + }).when(interactionsIndex).getTraces(any(), anyInt(), anyInt(), any()); + ActionListener> getTracesListener = mock(ActionListener.class); + cmHandler.getTraces("iId", 0, 10, getTracesListener); + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(getTracesListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().size() == 0); + } + + public void testUpdateConversation() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + al.onResponse(updateResponse); + return null; + }).when(conversationMetaIndex).updateConversation(any(), any()); + + ActionListener updateConversationListener = mock(ActionListener.class); + cmHandler.updateConversation("cId", new HashMap<>(), updateConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(updateConversationListener, times(1)).onResponse(argCaptor.capture()); + } + public void testDelete_NoAccess() { doAnswer(invocation -> { ActionListener al = invocation.getArgument(1); @@ -271,7 +300,7 @@ public void testSearchInteractions_Future() { } public void testGetAConversation_Future() { - ConversationMeta response = new ConversationMeta("cid", Instant.now(), "boring name", null); + ConversationMeta response = new ConversationMeta("cid", Instant.now(), Instant.now(), "boring name", null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(response); @@ -282,7 +311,16 @@ public void testGetAConversation_Future() { } public void testGetAnInteraction_Future() { - Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra"); + Interaction interaction = new Interaction( + "iid", + Instant.now(), + "cid", + "inp", + "pt", + "rsp", + "ogn", + Collections.singletonMap("meta", "some meta") + ); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); listener.onResponse(interaction); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionTests.java index ced83f730a..3212f81d36 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionTests.java @@ -23,6 +23,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -61,7 +62,7 @@ public void testBasics() { } public void testPrepareRequest() throws Exception { - Map params = Map + Map params = Map .of( ActionConstants.INPUT_FIELD, "input", @@ -72,7 +73,7 @@ public void testPrepareRequest() throws Exception { ActionConstants.RESPONSE_ORIGIN_FIELD, "origin", ActionConstants.ADDITIONAL_INFO_FIELD, - "metadata" + Collections.singletonMap("metadata", "some meta") ); RestMemoryCreateInteractionAction action = new RestMemoryCreateInteractionAction(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) @@ -92,6 +93,6 @@ public void testPrepareRequest() throws Exception { assert (req.getPromptTemplate().equals("pt")); assert (req.getResponse().equals("response")); assert (req.getOrigin().equals("origin")); - assert (req.getAdditionalInfo().equals("metadata")); + assert (req.getAdditionalInfo().equals(Collections.singletonMap("metadata", "some meta"))); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java index 691195a99b..da196ad7d8 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java @@ -18,6 +18,7 @@ package org.opensearch.ml.rest; import java.io.IOException; +import java.util.Collections; import java.util.Map; import org.apache.hc.core5.http.HttpEntity; @@ -66,7 +67,7 @@ public void testGetInteraction() throws IOException { assert (ccmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD)); String cid = (String) ccmap.get(ActionConstants.CONVERSATION_ID_FIELD); - Map params = Map + Map params = Map .of( ActionConstants.INPUT_FIELD, "input", @@ -77,7 +78,7 @@ public void testGetInteraction() throws IOException { ActionConstants.PROMPT_TEMPLATE_FIELD, "promtp template", ActionConstants.ADDITIONAL_INFO_FIELD, - "some metadata" + Collections.singletonMap("metadata", "some metadata") ); Response ciresponse = TestHelper .makeRequest( @@ -111,7 +112,7 @@ public void testGetInteraction() throws IOException { HttpEntity gihttpEntity = giresponse.getEntity(); String gientityString = TestHelper.httpEntityToString(gihttpEntity); @SuppressWarnings("unchecked") - Map gimap = gson.fromJson(gientityString, Map.class); + Map gimap = gson.fromJson(gientityString, Map.class); assert (gimap.containsKey(ActionConstants.RESPONSE_INTERACTION_ID_FIELD) && gimap.get(ActionConstants.RESPONSE_INTERACTION_ID_FIELD).equals(iid)); assert (gimap.containsKey(ActionConstants.CONVERSATION_ID_FIELD) && gimap.get(ActionConstants.CONVERSATION_ID_FIELD).equals(cid)); @@ -122,6 +123,6 @@ public void testGetInteraction() throws IOException { assert (gimap.containsKey(ActionConstants.RESPONSE_ORIGIN_FIELD) && gimap.get(ActionConstants.RESPONSE_ORIGIN_FIELD).equals("origin")); assert (gimap.containsKey(ActionConstants.ADDITIONAL_INFO_FIELD) - && gimap.get(ActionConstants.ADDITIONAL_INFO_FIELD).equals("some metadata")); + && gimap.get(ActionConstants.ADDITIONAL_INFO_FIELD).equals(Collections.singletonMap("metadata", "some metadata"))); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java index eeba9e4aab..1c37662218 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.Map; import org.apache.hc.core5.http.HttpEntity; @@ -101,7 +102,7 @@ public void testGetInteractions_LastPage() throws IOException { assert (ccmap.containsKey("conversation_id")); String cid = (String) ccmap.get("conversation_id"); - Map params = Map + Map params = Map .of( ActionConstants.INPUT_FIELD, "input", @@ -112,7 +113,7 @@ public void testGetInteractions_LastPage() throws IOException { ActionConstants.PROMPT_TEMPLATE_FIELD, "promtp template", ActionConstants.ADDITIONAL_INFO_FIELD, - "some metadata" + Collections.singletonMap("meta data", "some meta") ); Response response = TestHelper .makeRequest( @@ -156,7 +157,7 @@ public void testGetInteractions_MorePages() throws IOException { assert (ccmap.containsKey("conversation_id")); String cid = (String) ccmap.get("conversation_id"); - Map params = Map + Map params = Map .of( ActionConstants.INPUT_FIELD, "input", @@ -167,7 +168,7 @@ public void testGetInteractions_MorePages() throws IOException { ActionConstants.PROMPT_TEMPLATE_FIELD, "promtp template", ActionConstants.ADDITIONAL_INFO_FIELD, - "some metadata" + Collections.singletonMap("meta data", "some meta") ); Response response = TestHelper .makeRequest( @@ -219,7 +220,7 @@ public void testGetInteractions_NextPage() throws IOException { assert (ccmap.containsKey("conversation_id")); String cid = (String) ccmap.get("conversation_id"); - Map params = Map + Map params = Map .of( ActionConstants.INPUT_FIELD, "input", @@ -230,7 +231,7 @@ public void testGetInteractions_NextPage() throws IOException { ActionConstants.PROMPT_TEMPLATE_FIELD, "promtp template", ActionConstants.ADDITIONAL_INFO_FIELD, - "some metadata" + Collections.singletonMap("meta data", "some meta") ); Response response = TestHelper .makeRequest( @@ -285,7 +286,7 @@ public void testGetInteractions_NextPage() throws IOException { assert (((ArrayList) map1.get("interactions")).size() == 1); @SuppressWarnings("unchecked") ArrayList interactions = (ArrayList) map1.get("interactions"); - assert (((String) interactions.get(0).get("interaction_id")).equals(iid2)); + assert (((String) interactions.get(0).get("interaction_id")).equals(iid)); assert (((Double) map1.get("next_token")).intValue() == 1); Response response3 = TestHelper @@ -307,7 +308,7 @@ public void testGetInteractions_NextPage() throws IOException { assert (((ArrayList) map3.get("interactions")).size() == 1); @SuppressWarnings("unchecked") ArrayList interactions3 = (ArrayList) map3.get("interactions"); - assert (((String) interactions3.get(0).get("interaction_id")).equals(iid)); + assert (((String) interactions3.get(0).get("interaction_id")).equals(iid2)); assert (((Double) map3.get("next_token")).intValue() == 2); } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java index 111437ab0f..09d9b0d896 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java @@ -171,7 +171,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp PromptUtil.getPromptTemplate(systemPrompt, userInstructions), answer, GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, - jsonArrayToString(searchResults) + Collections.singletonMap("metadata", jsonArrayToString(searchResults)) ); log.info("Created a new interaction: {} ({})", interactionId, getDuration(start)); } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java index eca29b3914..cb94b75748 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -67,7 +68,7 @@ public String createInteraction( String promptTemplate, String response, String origin, - String additionalInfo + Map additionalInfo ) { Preconditions.checkNotNull(conversationId); Preconditions.checkNotNull(input); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java index b62f0ab38f..0e97049e40 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.when; import java.time.Instant; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -149,7 +150,21 @@ public void testProcessResponse() throws Exception { ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); when(memoryClient.getInteractions(any(), anyInt())) - .thenReturn(List.of(new Interaction("0", Instant.now(), "1", "question", "", "answer", "foo", "{}"))); + .thenReturn( + List + .of( + new Interaction( + "0", + Instant.now(), + "1", + "question", + "", + "answer", + "foo", + Collections.singletonMap("meta data", "some meta") + ) + ) + ); processor.setMemoryClient(memoryClient); SearchRequest request = new SearchRequest(); @@ -209,7 +224,21 @@ public void testProcessResponseSmallerContextSize() throws Exception { ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); when(memoryClient.getInteractions(any(), anyInt())) - .thenReturn(List.of(new Interaction("0", Instant.now(), "1", "question", "", "answer", "foo", "{}"))); + .thenReturn( + List + .of( + new Interaction( + "0", + Instant.now(), + "1", + "question", + "", + "answer", + "foo", + Collections.singletonMap("meta data", "some meta") + ) + ) + ); processor.setMemoryClient(memoryClient); SearchRequest request = new SearchRequest(); @@ -270,7 +299,21 @@ public void testProcessResponseMissingContextField() throws Exception { ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); when(memoryClient.getInteractions(any(), anyInt())) - .thenReturn(List.of(new Interaction("0", Instant.now(), "1", "question", "", "answer", "foo", "{}"))); + .thenReturn( + List + .of( + new Interaction( + "0", + Instant.now(), + "1", + "question", + "", + "answer", + "foo", + Collections.singletonMap("meta data", "some meta") + ) + ) + ); processor.setMemoryClient(memoryClient); SearchRequest request = new SearchRequest(); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java index 7241ba40ed..43c84d4aec 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java @@ -21,6 +21,7 @@ import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.UUID; import java.util.stream.IntStream; @@ -182,7 +183,8 @@ public void testCreateInteraction() { ActionFuture future = mock(ActionFuture.class); when(future.actionGet(anyLong())).thenReturn(res); when(client.execute(eq(CreateInteractionAction.INSTANCE), any())).thenReturn(future); - String actual = memoryClient.createInteraction("cid", "input", "prompt", "answer", "origin", "hits"); + String actual = memoryClient + .createInteraction("cid", "input", "prompt", "answer", "origin", Collections.singletonMap("metadata", "hits")); assertEquals(id, actual); } } From 3949ef9efb87b4a81f9e358fa1400295dec76f8c Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 13 Dec 2023 06:33:27 -0800 Subject: [PATCH 2/7] agent meta classes in common (#1757) * agent meta classes Signed-off-by: Jing Zhang * address comments Signed-off-by: Jing Zhang * add one more UT Signed-off-by: Jing Zhang * add more UT Signed-off-by: Jing Zhang * more UT Signed-off-by: Jing Zhang --------- Signed-off-by: Jing Zhang --- .../org/opensearch/ml/common/CommonValue.java | 116 ++++++++ .../opensearch/ml/common/agent/LLMSpec.java | 102 +++++++ .../opensearch/ml/common/agent/MLAgent.java | 264 ++++++++++++++++++ .../ml/common/agent/MLMemorySpec.java | 108 +++++++ .../ml/common/agent/MLToolSpec.java | 143 ++++++++++ .../ml/common/agent/LLMSpecTest.java | 91 ++++++ .../ml/common/agent/MLAgentTest.java | 143 ++++++++++ .../ml/common/agent/MLMemorySpecTest.java | 69 +++++ .../ml/common/agent/MLToolSpecTest.java | 75 +++++ 9 files changed, 1111 insertions(+) create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java create mode 100644 common/src/test/java/org/opensearch/ml/common/agent/LLMSpecTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/agent/MLMemorySpecTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index f82e742866..53a12a4224 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -5,8 +5,25 @@ package org.opensearch.ml.common; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.connector.AbstractConnector; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INDEX_SCHEMA_VERSION; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_CREATED_TIME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_INDEX_SCHEMA_VERSION; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_UPDATED_TIME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.USER_FIELD; import static org.opensearch.ml.common.model.MLModelConfig.ALL_CONFIG_FIELD; import static org.opensearch.ml.common.model.MLModelConfig.MODEL_TYPE_FIELD; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.EMBEDDING_DIMENSION_FIELD; @@ -44,6 +61,12 @@ public class CommonValue { public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2; public static final String ML_MAP_RESPONSE_KEY = "response"; + public static final String ML_AGENT_INDEX = ".plugins-ml-agent"; + public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 1; + public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta"; + public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1; + public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message"; + public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1; public static final String USER_FIELD_MAPPING = " \"" + CommonValue.USER + "\": {\n" @@ -326,4 +349,97 @@ public class CommonValue { + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + " }\n" + "}"; + + public static final String ML_AGENT_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_AGENT_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLAgent.AGENT_NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + MLAgent.AGENT_TYPE_FIELD + + "\" : {\"type\":\"keyword\"},\n" + + " \"" + + MLAgent.DESCRIPTION_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + MLAgent.LLM_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.TOOLS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.PARAMETERS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.MEMORY_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLAgent.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; + + public static final String ML_MEMORY_META_INDEX_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + META_INDEX_SCHEMA_VERSION + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + META_NAME_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + META_CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + META_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + USER_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + APPLICATION_TYPE_FIELD + + "\": {\"type\": \"keyword\"}\n" + + " }\n" + + "}"; + + public static final String ML_MEMORY_MESSAGE_INDEX_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + INTERACTIONS_INDEX_SCHEMA_VERSION + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + INTERACTIONS_CONVERSATION_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + INTERACTIONS_INPUT_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_PROMPT_TEMPLATE_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_RESPONSE_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_ORIGIN_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_ADDITIONAL_INFO_FIELD + + "\": {\"type\": \"flat_object\"},\n" + + " \"" + + PARENT_INTERACTIONS_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_TRACE_NUMBER_FIELD + + "\": {\"type\": \"long\"}\n" + + " }\n" + + "}"; } diff --git a/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java new file mode 100644 index 0000000000..6c0fda289a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java @@ -0,0 +1,102 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + + +@Getter +public class LLMSpec implements ToXContentObject { + public static final String MODEL_ID_FIELD = "model_id"; + public static final String PARAMETERS_FIELD = "parameters"; + + private String modelId; + private Map parameters; + + + @Builder(toBuilder = true) + public LLMSpec(String modelId, Map parameters) { + if (modelId == null) { + throw new IllegalArgumentException("model id is null"); + } + this.modelId = modelId; + this.parameters = parameters; + } + + public LLMSpec(StreamInput input) throws IOException{ + modelId = input.readString(); + if (input.readBoolean()) { + parameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString); + } + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + if (parameters != null && parameters.size() > 0) { + out.writeBoolean(true); + out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (modelId != null) { + builder.field(MODEL_ID_FIELD, modelId); + } + if (parameters != null && parameters.size() > 0) { + builder.field(PARAMETERS_FIELD, parameters); + } + builder.endObject(); + return builder; + } + + public static LLMSpec parse(XContentParser parser) throws IOException { + String modelId = null; + Map parameters = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MODEL_ID_FIELD: + modelId = parser.text(); + break; + case PARAMETERS_FIELD: + parameters = getParameterMap(parser.map()); + break; + default: + parser.skipChildren(); + break; + } + } + return LLMSpec.builder() + .modelId(modelId) + .parameters(parameters) + .build(); + } + + public static LLMSpec fromStream(StreamInput in) throws IOException { + LLMSpec toolSpec = new LLMSpec(in); + return toolSpec; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java new file mode 100644 index 0000000000..ba2f241375 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -0,0 +1,264 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + + +@Getter +public class MLAgent implements ToXContentObject, Writeable { + public static final String AGENT_NAME_FIELD = "name"; + public static final String AGENT_TYPE_FIELD = "type"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String LLM_FIELD = "llm"; + public static final String TOOLS_FIELD = "tools"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String MEMORY_FIELD = "memory"; + public static final String MEMORY_ID_FIELD = "memory_id"; + public static final String CREATED_TIME_FIELD = "created_time"; + public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; + public static final String APP_TYPE_FIELD = "app_type"; + + private String name; + private String type; + private String description; + private LLMSpec llm; + private List tools; + private Map parameters; + private MLMemorySpec memory; + + private Instant createdTime; + private Instant lastUpdateTime; + private String appType; + + @Builder(toBuilder = true) + public MLAgent(String name, + String type, + String description, + LLMSpec llm, + List tools, + Map parameters, + MLMemorySpec memory, + Instant createdTime, + Instant lastUpdateTime, + String appType) { + if (name == null) { + throw new IllegalArgumentException("agent name is null"); + } + this.name = name; + this.type = type; + this.description = description; + this.llm = llm; + this.tools = tools; + this.parameters = parameters; + this.memory = memory; + this.createdTime = createdTime; + this.lastUpdateTime = lastUpdateTime; + this.appType = appType; + } + + public MLAgent(StreamInput input) throws IOException{ + name = input.readString(); + type = input.readString(); + description = input.readOptionalString(); + if (input.readBoolean()) { + llm = new LLMSpec(input); + } + if (input.readBoolean()) { + tools = new ArrayList<>(); + int size = input.readInt(); + for (int i=0; i toolNames = new HashSet<>(); + for (MLToolSpec toolSpec : tools) { + String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType()); + if (toolNames.contains(toolName)) { + throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName); + } + } + } + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + out.writeString(type); + out.writeOptionalString(description); + if (llm != null) { + out.writeBoolean(true); + llm.writeTo(out); + } else { + out.writeBoolean(false); + } + if (tools != null && tools.size() > 0) { + out.writeBoolean(true); + out.writeInt(tools.size()); + for (MLToolSpec tool : tools) { + tool.writeTo(out); + } + } else { + out.writeBoolean(false); + } + if (parameters != null && parameters.size() > 0) { + out.writeBoolean(true); + out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString); + } else { + out.writeBoolean(false); + } + if (memory != null) { + out.writeBoolean(true); + memory.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeInstant(createdTime); + out.writeInstant(lastUpdateTime); + out.writeString(appType); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (name != null) { + builder.field(AGENT_NAME_FIELD, name); + } + if (type != null) { + builder.field(AGENT_TYPE_FIELD, type); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (llm != null) { + builder.field(LLM_FIELD, llm); + } + if (tools != null && tools.size() > 0) { + builder.field(TOOLS_FIELD, tools); + } + if (parameters != null && parameters.size() > 0) { + builder.field(PARAMETERS_FIELD, parameters); + } + if (memory != null) { + builder.field(MEMORY_FIELD, memory); + } + if (createdTime != null) { + builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); + } + if (lastUpdateTime != null) { + builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli()); + } + if (appType != null) { + builder.field(APP_TYPE_FIELD, appType); + } + builder.endObject(); + return builder; + } + + public static MLAgent parse(XContentParser parser) throws IOException { + String name = null; + String type = null; + String description = null;; + LLMSpec llm = null; + List tools = null; + Map parameters = null; + MLMemorySpec memory = null; + Instant createdTime = null; + Instant lastUpdateTime = null; + String appType = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case AGENT_NAME_FIELD: + name = parser.text(); + break; + case AGENT_TYPE_FIELD: + type = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case LLM_FIELD: + llm = LLMSpec.parse(parser); + break; + case TOOLS_FIELD: + tools = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + tools.add(MLToolSpec.parse(parser)); + } + break; + case PARAMETERS_FIELD: + parameters = getParameterMap(parser.map()); + break; + case MEMORY_FIELD: + memory = MLMemorySpec.parse(parser); + break; + case CREATED_TIME_FIELD: + createdTime = Instant.ofEpochMilli(parser.longValue()); + break; + case LAST_UPDATED_TIME_FIELD: + lastUpdateTime = Instant.ofEpochMilli(parser.longValue()); + break; + case APP_TYPE_FIELD: + appType = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + return MLAgent.builder() + .name(name) + .type(type) + .description(description) + .llm(llm) + .tools(tools) + .parameters(parameters) + .memory(memory) + .createdTime(createdTime) + .lastUpdateTime(lastUpdateTime) + .appType(appType) + .build(); + } + + public static MLAgent fromStream(StreamInput in) throws IOException { + MLAgent agent = new MLAgent(in); + return agent; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java new file mode 100644 index 0000000000..aa192a7ee2 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + + +@Getter +public class MLMemorySpec implements ToXContentObject { + public static final String MEMORY_TYPE_FIELD = "type"; + public static final String WINDOW_SIZE_FIELD = "window_size"; + public static final String SESSION_ID_FIELD = "session_id"; + + private String type; + @Setter + private String sessionId; + private Integer windowSize; + + + @Builder(toBuilder = true) + public MLMemorySpec(String type, + String sessionId, + Integer windowSize) { + if (type == null) { + throw new IllegalArgumentException("agent name is null"); + } + this.type = type; + this.sessionId = sessionId; + this.windowSize = windowSize; + } + + public MLMemorySpec(StreamInput input) throws IOException{ + type = input.readString(); + sessionId = input.readOptionalString(); + windowSize = input.readOptionalInt(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + out.writeOptionalString(sessionId); + out.writeOptionalInt(windowSize); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MEMORY_TYPE_FIELD, type); + if (windowSize != null) { + builder.field(WINDOW_SIZE_FIELD, windowSize); + } + if (sessionId != null) { + builder.field(SESSION_ID_FIELD, sessionId); + } + builder.endObject(); + return builder; + } + + public static MLMemorySpec parse(XContentParser parser) throws IOException { + String type = null; + String sessionId = null; + Integer windowSize = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MEMORY_TYPE_FIELD: + type = parser.text(); + break; + case SESSION_ID_FIELD: + sessionId = parser.text(); + break; + case WINDOW_SIZE_FIELD: + windowSize = parser.intValue(); + break; + default: + parser.skipChildren(); + break; + } + } + return MLMemorySpec.builder() + .type(type) + .sessionId(sessionId) + .windowSize(windowSize) + .build(); + } + + public static MLMemorySpec fromStream(StreamInput in) throws IOException { + MLMemorySpec toolSpec = new MLMemorySpec(in); + return toolSpec; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java new file mode 100644 index 0000000000..055c59d449 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + + +@Getter +public class MLToolSpec implements ToXContentObject { + public static final String TOOL_TYPE_FIELD = "type"; + public static final String TOOL_NAME_FIELD = "name"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response"; + + private String type; + private String name; + private String description; + private Map parameters; + private boolean includeOutputInAgentResponse; + + + @Builder(toBuilder = true) + public MLToolSpec(String type, + String name, + String description, + Map parameters, + boolean includeOutputInAgentResponse) { + if (type == null) { + throw new IllegalArgumentException("tool type is null"); + } + this.type = type; + this.name = name; + this.description = description; + this.parameters = parameters; + this.includeOutputInAgentResponse = includeOutputInAgentResponse; + } + + public MLToolSpec(StreamInput input) throws IOException{ + type = input.readString(); + name = input.readOptionalString(); + description = input.readOptionalString(); + if (input.readBoolean()) { + parameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString); + } + includeOutputInAgentResponse = input.readBoolean(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + out.writeOptionalString(name); + out.writeOptionalString(description); + if (parameters != null && parameters.size() > 0) { + out.writeBoolean(true); + out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString); + } else { + out.writeBoolean(false); + } + out.writeBoolean(includeOutputInAgentResponse); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (type != null) { + builder.field(TOOL_TYPE_FIELD, type); + } + if (name != null) { + builder.field(TOOL_NAME_FIELD, name); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (parameters != null && parameters.size() > 0) { + builder.field(PARAMETERS_FIELD, parameters); + } + builder.field(INCLUDE_OUTPUT_IN_AGENT_RESPONSE, includeOutputInAgentResponse); + builder.endObject(); + return builder; + } + + public static MLToolSpec parse(XContentParser parser) throws IOException { + String type = null; + String name = null; + String description = null; + Map parameters = null; + boolean includeOutputInAgentResponse = false; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TOOL_TYPE_FIELD: + type = parser.text(); + break; + case TOOL_NAME_FIELD: + name = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case PARAMETERS_FIELD: + parameters = getParameterMap(parser.map()); + break; + case INCLUDE_OUTPUT_IN_AGENT_RESPONSE: + includeOutputInAgentResponse = parser.booleanValue(); + break; + default: + parser.skipChildren(); + break; + } + } + return MLToolSpec.builder() + .type(type) + .name(name) + .description(description) + .parameters(parameters) + .includeOutputInAgentResponse(includeOutputInAgentResponse) + .build(); + } + + public static MLToolSpec fromStream(StreamInput in) throws IOException { + MLToolSpec toolSpec = new MLToolSpec(in); + return toolSpec; + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/agent/LLMSpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/LLMSpecTest.java new file mode 100644 index 0000000000..0964efb7d1 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/agent/LLMSpecTest.java @@ -0,0 +1,91 @@ +package org.opensearch.ml.common.agent; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + +import static org.junit.Assert.*; + +public class LLMSpecTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Test + public void constructor_NonModelID() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("model id is null"); + + LLMSpec spec = new LLMSpec(null, Map.of("test_key", "test_value")); + } + + @Test + public void writeTo() throws IOException { + LLMSpec spec = new LLMSpec("test_model", Map.of("test_key", "test_value")); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + LLMSpec spec1 = new LLMSpec(output.bytes().streamInput()); + + Assert.assertEquals(spec.getModelId(), spec1.getModelId()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + } + + @Test + public void writeTo_EmptyParameters() throws IOException { + LLMSpec spec = new LLMSpec("test_model", Map.of()); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + LLMSpec spec1 = new LLMSpec(output.bytes().streamInput()); + + Assert.assertEquals(spec.getModelId(), spec1.getModelId()); + Assert.assertEquals(null, spec1.getParameters()); + } + + @Test + public void toXContent() throws IOException { + LLMSpec spec = new LLMSpec("test_model", Map.of("test_key", "test_value")); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + spec.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals("{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}}", content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + LLMSpec spec = LLMSpec.parse(parser); + + Assert.assertEquals(spec.getModelId(), "test_model"); + Assert.assertEquals(spec.getParameters(), Map.of("test_key", "test_value")); + } + + @Test + public void fromStream() throws IOException { + LLMSpec spec = new LLMSpec("test_model", Map.of("test_key", "test_value")); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + LLMSpec spec1 = LLMSpec.fromStream(output.bytes().streamInput()); + + Assert.assertEquals(spec.getModelId(), spec1.getModelId()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java new file mode 100644 index 0000000000..bfaec959c4 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -0,0 +1,143 @@ +package org.opensearch.ml.common.agent; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.*; + +public class MLAgentTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Test + public void constructor_NullName() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("agent name is null"); + + MLAgent agent = new MLAgent(null, "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test"); + } + + @Test + public void writeTo() throws IOException { + MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + BytesStreamOutput output = new BytesStreamOutput(); + agent.writeTo(output); + MLAgent agent1 = new MLAgent(output.bytes().streamInput()); + + Assert.assertEquals(agent.getAppType(), agent1.getAppType()); + Assert.assertEquals(agent.getDescription(), agent1.getDescription()); + Assert.assertEquals(agent.getCreatedTime(), agent1.getCreatedTime()); + Assert.assertEquals(agent.getName(), agent1.getName()); + Assert.assertEquals(agent.getParameters(), agent1.getParameters()); + Assert.assertEquals(agent.getType(), agent1.getType()); + } + + @Test + public void writeTo_NullLLM() throws IOException { + MLAgent agent = new MLAgent("test", "test", "test", null, List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + BytesStreamOutput output = new BytesStreamOutput(); + agent.writeTo(output); + MLAgent agent1 = new MLAgent(output.bytes().streamInput()); + + Assert.assertEquals(agent1.getLlm(), null); + } + + @Test + public void writeTo_NullTools() throws IOException { + MLAgent agent = new MLAgent("test", "flow", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + BytesStreamOutput output = new BytesStreamOutput(); + agent.writeTo(output); + MLAgent agent1 = new MLAgent(output.bytes().streamInput()); + + Assert.assertEquals(agent1.getTools(), null); + } + + @Test + public void writeTo_NullParameters() throws IOException { + MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + BytesStreamOutput output = new BytesStreamOutput(); + agent.writeTo(output); + MLAgent agent1 = new MLAgent(output.bytes().streamInput()); + + Assert.assertEquals(agent1.getParameters(), null); + } + + @Test + public void writeTo_NullMemory() throws IOException { + MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), null, Instant.EPOCH, Instant.EPOCH, "test"); + BytesStreamOutput output = new BytesStreamOutput(); + agent.writeTo(output); + MLAgent agent1 = new MLAgent(output.bytes().streamInput()); + + Assert.assertEquals(agent1.getMemory(), null); + } + + @Test + public void toXContent() throws IOException { + MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + agent.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + String expectedStr = "{\"name\":\"test\",\"type\":\"test\",\"description\":\"test\",\"llm\":{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}},\"tools\":[{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}],\"parameters\":{\"test\":\"test\"},\"memory\":{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"},\"created_time\":0,\"last_updated_time\":0,\"app_type\":\"test\"}"; + + Assert.assertEquals(content, expectedStr); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"name\":\"test\",\"type\":\"test\",\"description\":\"test\",\"llm\":{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}},\"tools\":[{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}],\"parameters\":{\"test\":\"test\"},\"memory\":{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"},\"created_time\":0,\"last_updated_time\":0,\"app_type\":\"test\"}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + MLAgent agent = MLAgent.parse(parser); + + Assert.assertEquals(agent.getName(), "test"); + Assert.assertEquals(agent.getType(), "test"); + Assert.assertEquals(agent.getDescription(), "test"); + Assert.assertEquals(agent.getLlm().getModelId(), "test_model"); + Assert.assertEquals(agent.getLlm().getParameters(), Map.of("test_key", "test_value")); + Assert.assertEquals(agent.getTools().get(0).getName(), "test"); + Assert.assertEquals(agent.getTools().get(0).getType(), "test"); + Assert.assertEquals(agent.getTools().get(0).getDescription(), "test"); + Assert.assertEquals(agent.getTools().get(0).getParameters(), Map.of("test", "test")); + Assert.assertEquals(agent.getTools().get(0).isIncludeOutputInAgentResponse(), false); + Assert.assertEquals(agent.getCreatedTime(), Instant.EPOCH); + Assert.assertEquals(agent.getLastUpdateTime(), Instant.EPOCH); + Assert.assertEquals(agent.getAppType(), "test"); + Assert.assertEquals(agent.getMemory().getSessionId(), "123"); + Assert.assertEquals(agent.getParameters(), Map.of("test", "test")); + } + + @Test + public void fromStream() throws IOException { + MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + BytesStreamOutput output = new BytesStreamOutput(); + agent.writeTo(output); + MLAgent agent1 = MLAgent.fromStream(output.bytes().streamInput()); + + Assert.assertEquals(agent.getAppType(), agent1.getAppType()); + Assert.assertEquals(agent.getDescription(), agent1.getDescription()); + Assert.assertEquals(agent.getCreatedTime(), agent1.getCreatedTime()); + Assert.assertEquals(agent.getName(), agent1.getName()); + Assert.assertEquals(agent.getParameters(), agent1.getParameters()); + Assert.assertEquals(agent.getType(), agent1.getType()); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLMemorySpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLMemorySpecTest.java new file mode 100644 index 0000000000..2d028985e0 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLMemorySpecTest.java @@ -0,0 +1,69 @@ +package org.opensearch.ml.common.agent; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + +import static org.junit.Assert.*; + +public class MLMemorySpecTest { + + @Test + public void writeTo() throws IOException { + MLMemorySpec spec = new MLMemorySpec("test", "123", 0); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLMemorySpec spec1 = new MLMemorySpec(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getSessionId(), spec1.getSessionId()); + Assert.assertEquals(spec.getWindowSize(), spec1.getWindowSize()); + } + + @Test + public void toXContent() throws IOException { + MLMemorySpec spec = new MLMemorySpec("test", "123", 0); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + spec.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals("{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"}", content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + MLMemorySpec spec = MLMemorySpec.parse(parser); + + Assert.assertEquals(spec.getType(), "test"); + Assert.assertEquals(spec.getWindowSize(), Integer.valueOf(0)); + Assert.assertEquals(spec.getSessionId(), "123"); + } + + @Test + public void fromStream() throws IOException { + MLMemorySpec spec = new MLMemorySpec("test", "123", 0); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLMemorySpec spec1 = MLMemorySpec.fromStream(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getSessionId(), spec1.getSessionId()); + Assert.assertEquals(spec.getWindowSize(), spec1.getWindowSize()); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java new file mode 100644 index 0000000000..d831611035 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java @@ -0,0 +1,75 @@ +package org.opensearch.ml.common.agent; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + +import static org.junit.Assert.*; + +public class MLToolSpecTest { + + @Test + public void writeTo() throws IOException { + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLToolSpec spec1 = new MLToolSpec(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getName(), spec1.getName()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + Assert.assertEquals(spec.getDescription(), spec1.getDescription()); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + } + + @Test + public void toXContent() throws IOException { + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + spec.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals("{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}", content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + MLToolSpec spec = MLToolSpec.parse(parser); + + Assert.assertEquals(spec.getType(), "test"); + Assert.assertEquals(spec.getName(), "test"); + Assert.assertEquals(spec.getDescription(), "test"); + Assert.assertEquals(spec.getParameters(), Map.of("test", "test")); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), false); + } + + @Test + public void fromStream() throws IOException { + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLToolSpec spec1 = MLToolSpec.fromStream(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getName(), spec1.getName()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + Assert.assertEquals(spec.getDescription(), spec1.getDescription()); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + } +} \ No newline at end of file From 48ad895e0f6beea5f2843da1d403c333f4b5ffda Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 13 Dec 2023 15:37:35 -0800 Subject: [PATCH 3/7] Add CatIndexTool (#1746) * Add CatIndexTool Signed-off-by: Daniel Widdis * Add test coverage Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis --- .gitignore | 3 + .../ml/engine/tools/CatIndexTool.java | 435 ++++++++++++++++++ .../ml/engine/tools/CatIndexToolTests.java | 245 ++++++++++ 3 files changed, 683 insertions(+) create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java diff --git a/.gitignore b/.gitignore index 154f424daf..2fc377955d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ .gradle/ build/ .idea/ +.project +.classpath +.settings client/build/ common/build/ ml-algorithms/build/ diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java new file mode 100644 index 0000000000..16cec3870d --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java @@ -0,0 +1,435 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Spliterators; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import org.apache.logging.log4j.util.Strings; +import org.opensearch.action.admin.cluster.health.ClusterHealthRequest; +import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; +import org.opensearch.action.admin.cluster.state.ClusterStateRequest; +import org.opensearch.action.admin.cluster.state.ClusterStateResponse; +import org.opensearch.action.admin.indices.settings.get.GetSettingsRequest; +import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; +import org.opensearch.action.admin.indices.stats.CommonStats; +import org.opensearch.action.admin.indices.stats.IndexStats; +import org.opensearch.action.admin.indices.stats.IndicesStatsRequest; +import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.client.Client; +import org.opensearch.cluster.health.ClusterIndexHealth; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Table; +import org.opensearch.common.Table.Cell; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.index.IndexSettings; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; + +import lombok.Getter; +import lombok.Setter; + +@ToolAnnotation(CatIndexTool.TYPE) +public class CatIndexTool implements Tool { + public static final String TYPE = "CatIndexTool"; + private static final String DEFAULT_DESCRIPTION = "Use this tool to get index information."; + + @Setter + @Getter + private String name = CatIndexTool.TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String version; + + private Client client; + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + @SuppressWarnings("unused") + private ClusterService clusterService; + + public CatIndexTool(Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + + outputParser = new Parser<>() { + @Override + public Object parse(Object o) { + @SuppressWarnings("unchecked") + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + @Override + public void run(Map parameters, ActionListener listener) { + // TODO: This logic exactly matches the OpenSearch _cat/indices REST action. If code at + // o.o.rest/action/cat/RestIndicesAction.java changes those changes need to be reflected here + // https://github.com/opensearch-project/ml-commons/pull/1582#issuecomment-1796962876 + @SuppressWarnings("unchecked") + List indexList = parameters.containsKey("indices") + ? gson.fromJson(parameters.get("indices"), List.class) + : Collections.emptyList(); + final String[] indices = indexList.toArray(Strings.EMPTY_ARRAY); + + final IndicesOptions indicesOptions = IndicesOptions.strictExpand(); + final boolean local = parameters.containsKey("local") ? Boolean.parseBoolean("local") : false; + final TimeValue clusterManagerNodeTimeout = DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; + final boolean includeUnloadedSegments = parameters.containsKey("include_unloaded_segments") + ? Boolean.parseBoolean(parameters.get("include_unloaded_segments")) + : false; + + final ActionListener internalListener = ActionListener.notifyOnce(ActionListener.wrap(table -> { + // Handle empty table + if (table.getRows().isEmpty()) { + @SuppressWarnings("unchecked") + T empty = (T) ("There were no results searching the indices parameter [" + parameters.get("indices") + "]."); + listener.onResponse(empty); + return; + } + StringBuilder sb = new StringBuilder( + // Currently using c.value which is short header matching _cat/indices + // May prefer to use c.attr.get("desc") for full description + table.getHeaders().stream().map(c -> c.value.toString()).collect(Collectors.joining("\t", "", "\n")) + ); + for (List row : table.getRows()) { + sb.append(row.stream().map(c -> c.value == null ? null : c.value.toString()).collect(Collectors.joining("\t", "", "\n"))); + } + @SuppressWarnings("unchecked") + T response = (T) sb.toString(); + listener.onResponse(response); + }, listener::onFailure)); + + sendGetSettingsRequest( + indices, + indicesOptions, + local, + clusterManagerNodeTimeout, + client, + new ActionListener() { + @Override + public void onResponse(final GetSettingsResponse getSettingsResponse) { + final GroupedActionListener groupedListener = createGroupedListener(4, internalListener); + groupedListener.onResponse(getSettingsResponse); + + // The list of indices that will be returned is determined by the indices returned from the Get Settings call. + // All the other requests just provide additional detail, and wildcards may be resolved differently depending on the + // type of request in the presence of security plugins (looking at you, ClusterHealthRequest), so + // force the IndicesOptions for all the sub-requests to be as inclusive as possible. + final IndicesOptions subRequestIndicesOptions = IndicesOptions.lenientExpandHidden(); + + sendIndicesStatsRequest( + indices, + subRequestIndicesOptions, + includeUnloadedSegments, + client, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) + ); + sendClusterStateRequest( + indices, + subRequestIndicesOptions, + local, + clusterManagerNodeTimeout, + client, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) + ); + sendClusterHealthRequest( + indices, + subRequestIndicesOptions, + local, + clusterManagerNodeTimeout, + client, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) + ); + } + + @Override + public void onFailure(final Exception e) { + internalListener.onFailure(e); + } + } + ); + } + + @Override + public String getType() { + return TYPE; + } + + /** + * We're using the Get Settings API here to resolve the authorized indices for the user. + * This is because the Cluster State and Cluster Health APIs do not filter output based + * on index privileges, so they can't be used to determine which indices are authorized + * or not. On top of this, the Indices Stats API cannot be used either to resolve indices + * as it does not provide information for all existing indices (for example recovering + * indices or non replicated closed indices are not reported in indices stats response). + */ + private void sendGetSettingsRequest( + final String[] indices, + final IndicesOptions indicesOptions, + final boolean local, + final TimeValue clusterManagerNodeTimeout, + final Client client, + final ActionListener listener + ) { + final GetSettingsRequest request = new GetSettingsRequest(); + request.indices(indices); + request.indicesOptions(indicesOptions); + request.local(local); + request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + request.names(IndexSettings.INDEX_SEARCH_THROTTLED.getKey()); + + client.admin().indices().getSettings(request, listener); + } + + private void sendClusterStateRequest( + final String[] indices, + final IndicesOptions indicesOptions, + final boolean local, + final TimeValue clusterManagerNodeTimeout, + final Client client, + final ActionListener listener + ) { + + final ClusterStateRequest request = new ClusterStateRequest(); + request.indices(indices); + request.indicesOptions(indicesOptions); + request.local(local); + request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + + client.admin().cluster().state(request, listener); + } + + private void sendClusterHealthRequest( + final String[] indices, + final IndicesOptions indicesOptions, + final boolean local, + final TimeValue clusterManagerNodeTimeout, + final Client client, + final ActionListener listener + ) { + + final ClusterHealthRequest request = new ClusterHealthRequest(); + request.indices(indices); + request.indicesOptions(indicesOptions); + request.local(local); + request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + + client.admin().cluster().health(request, listener); + } + + private void sendIndicesStatsRequest( + final String[] indices, + final IndicesOptions indicesOptions, + final boolean includeUnloadedSegments, + final Client client, + final ActionListener listener + ) { + + final IndicesStatsRequest request = new IndicesStatsRequest(); + request.indices(indices); + request.indicesOptions(indicesOptions); + request.all(); + request.includeUnloadedSegments(includeUnloadedSegments); + + client.admin().indices().stats(request, listener); + } + + private GroupedActionListener createGroupedListener(final int size, final ActionListener
listener) { + return new GroupedActionListener<>(new ActionListener>() { + @Override + public void onResponse(final Collection responses) { + try { + GetSettingsResponse settingsResponse = extractResponse(responses, GetSettingsResponse.class); + Map indicesSettings = StreamSupport + .stream(Spliterators.spliterator(settingsResponse.getIndexToSettings().entrySet(), 0), false) + .collect(Collectors.toMap(cursor -> cursor.getKey(), cursor -> cursor.getValue())); + + ClusterStateResponse stateResponse = extractResponse(responses, ClusterStateResponse.class); + Map indicesStates = StreamSupport + .stream(stateResponse.getState().getMetadata().spliterator(), false) + .collect(Collectors.toMap(indexMetadata -> indexMetadata.getIndex().getName(), Function.identity())); + + ClusterHealthResponse healthResponse = extractResponse(responses, ClusterHealthResponse.class); + Map indicesHealths = healthResponse.getIndices(); + + IndicesStatsResponse statsResponse = extractResponse(responses, IndicesStatsResponse.class); + Map indicesStats = statsResponse.getIndices(); + + Table responseTable = buildTable(indicesSettings, indicesHealths, indicesStats, indicesStates); + listener.onResponse(responseTable); + } catch (Exception e) { + onFailure(e); + } + } + + @Override + public void onFailure(final Exception e) { + listener.onFailure(e); + } + }, size); + } + + @Override + public boolean validate(Map parameters) { + if (parameters == null || parameters.size() == 0) { + return false; + } + return true; + } + + /** + * Factory for the {@link CatIndexTool} + */ + public static class Factory implements Tool.Factory { + private Client client; + private ClusterService clusterService; + + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (CatIndexTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * @param client The OpenSearch client + * @param clusterService The OpenSearch cluster service + */ + public void init(Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + } + + @Override + public CatIndexTool create(Map map) { + return new CatIndexTool(client, clusterService); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } + + private Table getTableWithHeader() { + Table table = new Table(); + table.startHeaders(); + // First param is cell.value which is currently returned + // Second param is cell.attr we may want to use attr.desc in the future + table.addCell("health", "alias:h;desc:current health status"); + table.addCell("status", "alias:s;desc:open/close status"); + table.addCell("index", "alias:i,idx;desc:index name"); + table.addCell("uuid", "alias:id,uuid;desc:index uuid"); + table.addCell("pri", "alias:p,shards.primary,shardsPrimary;text-align:right;desc:number of primary shards"); + table.addCell("rep", "alias:r,shards.replica,shardsReplica;text-align:right;desc:number of replica shards"); + table.addCell("docs.count", "alias:dc,docsCount;text-align:right;desc:available docs"); + table.addCell("docs.deleted", "alias:dd,docsDeleted;text-align:right;desc:deleted docs"); + table.addCell("store.size", "sibling:pri;alias:ss,storeSize;text-align:right;desc:store size of primaries & replicas"); + table.addCell("pri.store.size", "text-align:right;desc:store size of primaries"); + // Above includes all the default fields for cat indices. See RestIndicesAction for a lot more that could be included. + table.endHeaders(); + return table; + } + + private Table buildTable( + final Map indicesSettings, + final Map indicesHealths, + final Map indicesStats, + final Map indicesMetadatas + ) { + final Table table = getTableWithHeader(); + + indicesSettings.forEach((indexName, settings) -> { + if (indicesMetadatas.containsKey(indexName) == false) { + // the index exists in the Get Indices response but is not present in the cluster state: + // it is likely that the index was deleted in the meanwhile, so we ignore it. + return; + } + + final IndexMetadata indexMetadata = indicesMetadatas.get(indexName); + final IndexMetadata.State indexState = indexMetadata.getState(); + final IndexStats indexStats = indicesStats.get(indexName); + + final String health; + final ClusterIndexHealth indexHealth = indicesHealths.get(indexName); + if (indexHealth != null) { + health = indexHealth.getStatus().toString().toLowerCase(Locale.ROOT); + } else if (indexStats != null) { + health = "red*"; + } else { + health = ""; + } + + final CommonStats primaryStats; + final CommonStats totalStats; + + if (indexStats == null || indexState == IndexMetadata.State.CLOSE) { + primaryStats = new CommonStats(); + totalStats = new CommonStats(); + } else { + primaryStats = indexStats.getPrimaries(); + totalStats = indexStats.getTotal(); + } + table.startRow(); + table.addCell(health); + table.addCell(indexState.toString().toLowerCase(Locale.ROOT)); + table.addCell(indexName); + table.addCell(indexMetadata.getIndexUUID()); + table.addCell(indexHealth == null ? null : indexHealth.getNumberOfShards()); + table.addCell(indexHealth == null ? null : indexHealth.getNumberOfReplicas()); + + table.addCell(primaryStats.getDocs() == null ? null : primaryStats.getDocs().getCount()); + table.addCell(primaryStats.getDocs() == null ? null : primaryStats.getDocs().getDeleted()); + + table.addCell(totalStats.getStore() == null ? null : totalStats.getStore().size()); + table.addCell(primaryStats.getStore() == null ? null : primaryStats.getStore().size()); + + table.endRow(); + }); + + return table; + } + + @SuppressWarnings("unchecked") + private static A extractResponse(final Collection responses, Class c) { + return (A) responses.stream().filter(c::isInstance).findFirst().get(); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java new file mode 100644 index 0000000000..cffb0ff338 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java @@ -0,0 +1,245 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; +import org.opensearch.action.admin.cluster.state.ClusterStateResponse; +import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; +import org.opensearch.action.admin.indices.stats.CommonStats; +import org.opensearch.action.admin.indices.stats.CommonStatsFlags; +import org.opensearch.action.admin.indices.stats.IndexStats; +import org.opensearch.action.admin.indices.stats.IndexStats.IndexStatsBuilder; +import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.action.admin.indices.stats.ShardStats; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.health.ClusterIndexHealth; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.IndexMetadata.State; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.routing.IndexRoutingTable; +import org.opensearch.cluster.routing.IndexShardRoutingTable; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardRoutingState; +import org.opensearch.cluster.routing.TestShardRouting; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.UUIDs; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.shard.ShardPath; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.engine.tools.CatIndexTool.Factory; + +public class CatIndexToolTests { + + @Mock + private Client client; + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private ClusterAdminClient clusterAdminClient; + @Mock + private ClusterService clusterService; + @Mock + private ClusterState clusterState; + @Mock + private Metadata metadata; + @Mock + private GetSettingsResponse getSettingsResponse; + @Mock + private IndicesStatsResponse indicesStatsResponse; + @Mock + private ClusterStateResponse clusterStateResponse; + @Mock + private ClusterHealthResponse clusterHealthResponse; + @Mock + private IndexMetadata indexMetadata; + @Mock + private IndexRoutingTable indexRoutingTable; + + private Map indicesParams; + private Map otherParams; + private Map emptyParams; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + when(adminClient.indices()).thenReturn(indicesAdminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + when(client.admin()).thenReturn(adminClient); + + when(indexMetadata.getState()).thenReturn(State.OPEN); + when(indexMetadata.getCreationVersion()).thenReturn(Version.CURRENT); + + when(metadata.index(any(String.class))).thenReturn(indexMetadata); + when(clusterState.metadata()).thenReturn(metadata); + when(clusterService.state()).thenReturn(clusterState); + + CatIndexTool.Factory.getInstance().init(client, clusterService); + + indicesParams = Map.of("index", "[\"foo\"]"); + otherParams = Map.of("other", "[\"bar\"]"); + emptyParams = Collections.emptyMap(); + } + + @Test + public void testRunAsyncNoIndices() throws Exception { + @SuppressWarnings("unchecked") + ArgumentCaptor> settingsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).getSettings(any(), settingsActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> statsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).stats(any(), statsActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterStateActionListenerCaptor = ArgumentCaptor + .forClass(ActionListener.class); + doNothing().when(clusterAdminClient).state(any(), clusterStateActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterHealthActionListenerCaptor = ArgumentCaptor + .forClass(ActionListener.class); + doNothing().when(clusterAdminClient).health(any(), clusterHealthActionListenerCaptor.capture()); + + when(getSettingsResponse.getIndexToSettings()).thenReturn(Collections.emptyMap()); + when(indicesStatsResponse.getIndices()).thenReturn(Collections.emptyMap()); + when(clusterStateResponse.getState()).thenReturn(clusterState); + when(clusterState.getMetadata()).thenReturn(metadata); + when(metadata.spliterator()).thenReturn(Arrays.spliterator(new IndexMetadata[0])); + + when(clusterHealthResponse.getIndices()).thenReturn(Collections.emptyMap()); + + Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + tool.run(otherParams, listener); + settingsActionListenerCaptor.getValue().onResponse(getSettingsResponse); + statsActionListenerCaptor.getValue().onResponse(indicesStatsResponse); + clusterStateActionListenerCaptor.getValue().onResponse(clusterStateResponse); + clusterHealthActionListenerCaptor.getValue().onResponse(clusterHealthResponse); + + future.join(); + assertEquals("There were no results searching the indices parameter [null].", future.get()); + } + + @Test + public void testRunAsyncIndexStats() throws Exception { + String indexName = "foo"; + Index index = new Index(indexName, UUIDs.base64UUID()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> settingsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).getSettings(any(), settingsActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> statsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).stats(any(), statsActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterStateActionListenerCaptor = ArgumentCaptor + .forClass(ActionListener.class); + doNothing().when(clusterAdminClient).state(any(), clusterStateActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterHealthActionListenerCaptor = ArgumentCaptor + .forClass(ActionListener.class); + doNothing().when(clusterAdminClient).health(any(), clusterHealthActionListenerCaptor.capture()); + + when(getSettingsResponse.getIndexToSettings()).thenReturn(Map.of("foo", Settings.EMPTY)); + + int shardId = 0; + ShardId shId = new ShardId(index, shardId); + Path path = Files.createTempDirectory("temp").resolve("indices").resolve(index.getUUID()).resolve(String.valueOf(shardId)); + ShardPath shardPath = new ShardPath(false, path, path, shId); + ShardRouting routing = TestShardRouting.newShardRouting(shId, "node", true, ShardRoutingState.STARTED); + CommonStats commonStats = new CommonStats(CommonStatsFlags.ALL); + IndexStats fooStats = new IndexStatsBuilder(index.getName(), index.getUUID()) + .add(new ShardStats(routing, shardPath, commonStats, null, null, null)) + .build(); + when(indicesStatsResponse.getIndices()).thenReturn(Map.of(indexName, fooStats)); + + when(indexMetadata.getIndex()).thenReturn(index); + when(indexMetadata.getNumberOfShards()).thenReturn(5); + when(indexMetadata.getNumberOfReplicas()).thenReturn(1); + when(clusterStateResponse.getState()).thenReturn(clusterState); + when(clusterState.getMetadata()).thenReturn(metadata); + when(metadata.spliterator()).thenReturn(Arrays.spliterator(new IndexMetadata[] { indexMetadata })); + @SuppressWarnings("unchecked") + Iterator iterator = (Iterator) mock(Iterator.class); + when(iterator.hasNext()).thenReturn(false); + when(indexRoutingTable.iterator()).thenReturn(iterator); + ClusterIndexHealth fooHealth = new ClusterIndexHealth(indexMetadata, indexRoutingTable); + when(clusterHealthResponse.getIndices()).thenReturn(Map.of(indexName, fooHealth)); + + // Now make the call + Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + tool.run(otherParams, listener); + settingsActionListenerCaptor.getValue().onResponse(getSettingsResponse); + statsActionListenerCaptor.getValue().onResponse(indicesStatsResponse); + clusterStateActionListenerCaptor.getValue().onResponse(clusterStateResponse); + clusterHealthActionListenerCaptor.getValue().onResponse(clusterHealthResponse); + + future.orTimeout(10, TimeUnit.SECONDS).join(); + String response = future.get(); + String[] responseRows = response.trim().split("\\n"); + + assertEquals(2, responseRows.length); + String header = responseRows[0]; + String fooRow = responseRows[1]; + assertEquals(header.split("\\t").length, fooRow.split("\\t").length); + assertEquals("health\tstatus\tindex\tuuid\tpri\trep\tdocs.count\tdocs.deleted\tstore.size\tpri.store.size", header); + assertEquals("red\topen\tfoo\tnull\t5\t1\t0\t0\t0b\t0b", fooRow); + } + + @Test + public void testTool() { + Factory instance = CatIndexTool.Factory.getInstance(); + assertEquals(instance, CatIndexTool.Factory.getInstance()); + assertTrue(instance.getDefaultDescription().contains("tool")); + + Tool tool = instance.create(Collections.emptyMap()); + assertEquals(CatIndexTool.TYPE, tool.getType()); + assertTrue(tool.validate(indicesParams)); + assertTrue(tool.validate(otherParams)); + assertFalse(tool.validate(emptyParams)); + } +} From 00b5feaf6605e37eca769e32b7dccbf3b2270573 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Fri, 15 Dec 2023 15:29:56 -0800 Subject: [PATCH 4/7] Memory Manager and Update Memory Actions/APIs (#1761) * More memory actions, APS and tests Signed-off-by: Xun Zhang * refactor memory manager and Get Trace actions Signed-off-by: Xun Zhang * updates for some comments Signed-off-by: Xun Zhang * comments updated Signed-off-by: Xun Zhang --------- Signed-off-by: Xun Zhang --- .../common/conversation/ActionConstants.java | 9 + .../action/conversation/GetTracesAction.java | 23 ++ .../action/conversation/GetTracesRequest.java | 124 ++++++++ .../conversation/GetTracesResponse.java | 77 +++++ .../GetTracesTransportAction.java | 64 ++++ .../UpdateConversationAction.java | 18 ++ .../UpdateConversationRequest.java | 105 +++++++ .../UpdateConversationTransportAction.java | 68 +++++ .../conversation/UpdateInteractionAction.java | 19 ++ .../UpdateInteractionRequest.java | 110 +++++++ .../UpdateInteractionTransportAction.java | 70 +++++ .../conversation/GetTracesRequestTests.java | 101 +++++++ .../conversation/GetTracesResponseTests.java | 109 +++++++ .../GetTracesTransportActionTests.java | 159 ++++++++++ .../UpdateConversationRequestTests.java | 155 ++++++++++ ...pdateConversationTransportActionTests.java | 135 +++++++++ .../UpdateInteractionRequestTests.java | 172 +++++++++++ ...UpdateInteractionTransportActionTests.java | 134 +++++++++ ml-algorithms/build.gradle | 1 + .../ml/engine/memory/MLMemoryManager.java | 164 +++++++++++ .../engine/memory/MLMemoryManagerTests.java | 277 ++++++++++++++++++ .../ml/plugin/MachineLearningPlugin.java | 22 +- .../ml/rest/RestMemoryGetTracesAction.java | 37 +++ .../RestMemoryUpdateConversationAction.java | 59 ++++ .../RestMemoryUpdateInteractionAction.java | 60 ++++ .../rest/RestMemoryGetTracesActionTests.java | 64 ++++ .../RestMemoryUpdateConversationTests.java | 165 +++++++++++ ...estMemoryUpdateInteractionActionTests.java | 164 +++++++++++ 28 files changed, 2663 insertions(+), 2 deletions(-) create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportActionTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetTracesAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateConversationAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetTracesActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateConversationTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionActionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java index 8776c618b0..119d5a6659 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -29,6 +29,8 @@ public class ActionConstants { public final static String RESPONSE_CONVERSATION_LIST_FIELD = "conversations"; /** name of list on interactions in all responses */ public final static String RESPONSE_INTERACTION_LIST_FIELD = "interactions"; + /** name of list on traces in all responses */ + public final static String RESPONSE_TRACES_LIST_FIELD = "traces"; /** name of interaction Id field in all responses */ public final static String RESPONSE_INTERACTION_ID_FIELD = "interaction_id"; @@ -56,20 +58,27 @@ public class ActionConstants { public final static String SUCCESS_FIELD = "success"; private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation"; + private final static String BASE_REST_INTERACTION_PATH = "/_plugins/_ml/memory/interaction"; /** path for create conversation */ public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create"; /** path for get conversations */ public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list"; + /** path for update conversations */ + public final static String UPDATE_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_update"; /** path for create interaction */ public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create"; /** path for get interactions */ public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list"; + /** path for get traces */ + public final static String GET_TRACES_REST_PATH = "/_plugins/_ml/memory/trace" + "/{interaction_id}/_list"; /** path for delete conversation */ public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_delete"; /** path for search conversations */ public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search"; /** path for search interactions */ public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search"; + /** path for update interactions */ + public final static String UPDATE_INTERACTIONS_REST_PATH = BASE_REST_INTERACTION_PATH + "/{interaction_id}/_update"; /** path for get conversation */ public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}"; /** path for get interaction */ diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java new file mode 100644 index 0000000000..0117df94b5 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action to return the traces associated with an interaction + */ +public class GetTracesAction extends ActionType { + /** Instance of this */ + public static final GetTracesAction INSTANCE = new GetTracesAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/trace/get"; + + private GetTracesAction() { + super(NAME, GetTracesResponse::new); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java new file mode 100644 index 0000000000..9b65f78148 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.Getter; + +/** + * ActionRequest for get traces + */ +public class GetTracesRequest extends ActionRequest { + @Getter + private String interactionId; + @Getter + private int maxResults = ActionConstants.DEFAULT_MAX_RESULTS; + @Getter + private int from = 0; + + /** + * Constructor + * @param interactionId UID of the interaction to get traces from + */ + public GetTracesRequest(String interactionId) { + this.interactionId = interactionId; + } + + /** + * Constructor + * @param interactionId UID of the conversation to get interactions from + * @param maxResults number of interactions to retrieve + */ + public GetTracesRequest(String interactionId, int maxResults) { + this.interactionId = interactionId; + this.maxResults = maxResults; + } + + /** + * Constructor + * @param interactionId UID of the conversation to get interactions from + * @param maxResults number of interactions to retrieve + * @param from position of first interaction to retrieve + */ + public GetTracesRequest(String interactionId, int maxResults, int from) { + this.interactionId = interactionId; + this.maxResults = maxResults; + this.from = from; + } + + /** + * Constructor + * @param in streaminput to read this from. assumes there was a GetTracesRequest.writeTo + * @throws IOException if there wasn't a GIR in the stream + */ + public GetTracesRequest(StreamInput in) throws IOException { + super(in); + this.interactionId = in.readString(); + this.maxResults = in.readInt(); + this.from = in.readInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(interactionId); + out.writeInt(maxResults); + out.writeInt(from); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (interactionId == null) { + exception = addValidationError("Traces must be retrieved from an interaction", exception); + } + if (maxResults <= 0) { + exception = addValidationError("The number of traces to retrieve must be positive", exception); + } + if (from < 0) { + exception = addValidationError("The starting position must be nonnegative", exception); + } + + return exception; + } + + /** + * Makes a GetTracesRequest out of a RestRequest + * @param request Rest Request representing a get traces request + * @return a new GetTracesRequest + * @throws IOException if something goes wrong + */ + public static GetTracesRequest fromRestRequest(RestRequest request) throws IOException { + String cid = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); + if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) { + int from = Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD)); + if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { + int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)); + return new GetTracesRequest(cid, maxResults, from); + } else { + return new GetTracesRequest(cid, ActionConstants.DEFAULT_MAX_RESULTS, from); + } + } else { + if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { + int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)); + return new GetTracesRequest(cid, maxResults); + } else { + return new GetTracesRequest(cid); + } + } + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java new file mode 100644 index 0000000000..38486f1c1b --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.conversation.Interaction; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NonNull; + +/** + * Action Response for get traces for an interaction + */ +@AllArgsConstructor +public class GetTracesResponse extends ActionResponse implements ToXContentObject { + @Getter + @NonNull + private List traces; + @Getter + private int nextToken; + private boolean hasMoreTokens; + + /** + * Constructor + * @param in stream input; assumes GetTracesResponse.writeTo was called + * @throws IOException if there's not a G.I.R. in the stream + */ + public GetTracesResponse(StreamInput in) throws IOException { + super(in); + traces = in.readList(Interaction::fromStream); + nextToken = in.readInt(); + hasMoreTokens = in.readBoolean(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeList(traces); + out.writeInt(nextToken); + out.writeBoolean(hasMoreTokens); + } + + /** + * Are there more pages in this search results + * @return whether there are more traces in this search + */ + public boolean hasMorePages() { + return hasMoreTokens; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.startArray(ActionConstants.RESPONSE_TRACES_LIST_FIELD); + for (Interaction trace : traces) { + trace.toXContent(builder, params); + } + builder.endArray(); + if (hasMoreTokens) { + builder.field(ActionConstants.NEXT_TOKEN_FIELD, nextToken); + } + builder.endObject(); + return builder; + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java new file mode 100644 index 0000000000..698136fd95 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import java.util.List; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetTracesTransportAction extends HandledTransportAction { + private Client client; + private ConversationalMemoryHandler cmHandler; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + */ + @Inject + public GetTracesTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client + ) { + super(GetTracesAction.NAME, transportService, actionFilters, GetTracesRequest::new); + this.client = client; + this.cmHandler = cmHandler; + } + + @Override + public void doExecute(Task task, GetTracesRequest request, ActionListener actionListener) { + int maxResults = request.getMaxResults(); + int from = request.getFrom(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + // TODO: check this newStoredContext() method and remove it if it's redundant + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener> al = ActionListener.wrap(tracesList -> { + internalListener.onResponse(new GetTracesResponse(tracesList, from + maxResults, tracesList.size() == maxResults)); + }, e -> { internalListener.onFailure(e); }); + cmHandler.getTraces(request.getInteractionId(), from, maxResults, al); + } catch (Exception e) { + log.error("Failed to get traces for conversation " + request.getInteractionId(), e); + actionListener.onFailure(e); + } + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationAction.java new file mode 100644 index 0000000000..6c8023171e --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; +import org.opensearch.action.update.UpdateResponse; + +public class UpdateConversationAction extends ActionType { + public static final UpdateConversationAction INSTANCE = new UpdateConversationAction(); + public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/update"; + + private UpdateConversationAction() { + super(NAME, UpdateResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java new file mode 100644 index 0000000000..7afec5d0ab --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Getter; + +@Getter +public class UpdateConversationRequest extends ActionRequest { + private String conversationId; + private Map updateContent; + + private static final Set allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD)); + + @Builder + public UpdateConversationRequest(String conversationId, Map updateContent) { + this.conversationId = conversationId; + this.updateContent = filterUpdateContent(updateContent); + } + + public UpdateConversationRequest(StreamInput in) throws IOException { + super(in); + this.conversationId = in.readString(); + this.updateContent = filterUpdateContent(in.readMap()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.conversationId); + out.writeMap(this.getUpdateContent()); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.conversationId == null) { + exception = addValidationError("conversation id can't be null", exception); + } + if (this.updateContent == null) { + exception = addValidationError("Update conversation content can't be null", exception); + } + + return exception; + } + + public static UpdateConversationRequest parse(XContentParser parser, String conversationId) throws IOException { + Map dataAsMap = null; + dataAsMap = parser.map(); + + return UpdateConversationRequest.builder().conversationId(conversationId).updateContent(dataAsMap).build(); + } + + public static UpdateConversationRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof UpdateConversationRequest) { + return (UpdateConversationRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new UpdateConversationRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into UpdateConversationRequest", e); + } + } + + private Map filterUpdateContent(Map updateContent) { + if (updateContent == null) { + return new HashMap<>(); + } + return updateContent + .entrySet() + .stream() + .filter(map -> allowedList.contains(map.getKey())) + .collect(Collectors.toMap(map -> map.getKey(), map -> map.getValue())); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java new file mode 100644 index 0000000000..9f8c42f17b --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class UpdateConversationTransportAction extends HandledTransportAction { + Client client; + + @Inject + public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { + super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new); + this.client = client; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request); + String conversationId = updateConversationRequest.getConversationId(); + UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId); + updateRequest.doc(updateConversationRequest.getUpdateContent()); + updateRequest.docAsUpsert(true); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context)); + } catch (Exception e) { + log.error("Failed to update Conversation for conversation id" + conversationId, e); + listener.onFailure(e); + } + } + + private ActionListener getUpdateResponseListener( + String conversationId, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { + log.info("Successfully updated the Conversation with ID: {}", conversationId); + actionListener.onResponse(updateResponse); + } else { + log.info("Failed to update the Conversation with ID: {}", conversationId); + actionListener.onResponse(updateResponse); + } + }, exception -> { + log.error("Failed to update ML Conversation with ID " + conversationId, exception); + actionListener.onFailure(exception); + }), context::restore); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionAction.java new file mode 100644 index 0000000000..64f7e56846 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionAction.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; +import org.opensearch.action.update.UpdateResponse; + +public class UpdateInteractionAction extends ActionType { + public static final UpdateInteractionAction INSTANCE = new UpdateInteractionAction(); + public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/update"; + + private UpdateInteractionAction() { + super(NAME, UpdateResponse::new); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequest.java new file mode 100644 index 0000000000..96ef467590 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequest.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Getter; + +@Getter +public class UpdateInteractionRequest extends ActionRequest { + private String interactionId; + private Map updateContent; + + private static final Set allowedList = new HashSet<>(Arrays.asList(INTERACTIONS_ADDITIONAL_INFO_FIELD)); + + @Builder + public UpdateInteractionRequest(String interactionId, Map updateContent) { + this.interactionId = interactionId; + this.updateContent = filterUpdateContent(updateContent); + } + + public UpdateInteractionRequest(StreamInput in) throws IOException { + super(in); + this.interactionId = in.readString(); + this.updateContent = filterUpdateContent(in.readMap()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.interactionId); + out.writeMap(this.getUpdateContent()); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.interactionId == null) { + exception = addValidationError("interaction id can't be null", exception); + } + if (this.updateContent == null) { + exception = addValidationError("Update Interaction content can't be null", exception); + } + + return exception; + } + + public static UpdateInteractionRequest parse(XContentParser parser, String interactionId) throws IOException { + Map dataAsMap = null; + dataAsMap = parser.map(); + + if (dataAsMap == null) { + throw new OpenSearchParseException("Failed to parse UpdateInteractionRequest due to Null update content"); + } + + return UpdateInteractionRequest.builder().interactionId(interactionId).updateContent(dataAsMap).build(); + } + + public static UpdateInteractionRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof UpdateInteractionRequest) { + return (UpdateInteractionRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new UpdateInteractionRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into UpdateInteractionRequest", e); + } + } + + private Map filterUpdateContent(Map updateContent) { + if (updateContent == null) { + return new HashMap<>(); + } + return updateContent + .entrySet() + .stream() + .filter(map -> allowedList.contains(map.getKey())) + .collect(Collectors.toMap(map -> map.getKey(), map -> map.getValue())); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java new file mode 100644 index 0000000000..9abf8571c4 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class UpdateInteractionTransportAction extends HandledTransportAction { + Client client; + + @Inject + public UpdateInteractionTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { + super(UpdateInteractionAction.NAME, transportService, actionFilters, UpdateInteractionRequest::new); + this.client = client; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.fromActionRequest(request); + String interactionId = updateInteractionRequest.getInteractionId(); + UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId); + updateRequest.doc(updateInteractionRequest.getUpdateContent()); + updateRequest.docAsUpsert(true); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.update(updateRequest, getUpdateResponseListener(interactionId, listener, context)); + } catch (Exception e) { + log.error("Failed to update Interaction for interaction id " + interactionId, e); + listener.onFailure(e); + } + } + + private ActionListener getUpdateResponseListener( + String interactionId, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { + log.info("Successfully updated the interaction with ID: {}", interactionId); + actionListener.onResponse(updateResponse); + } else { + log.info("Failed to update the interaction with ID: {}", interactionId); + actionListener.onResponse(updateResponse); + } + }, exception -> { + log.error("Failed to update ML interaction with ID " + interactionId, exception); + actionListener.onFailure(exception); + }), context::restore); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java new file mode 100644 index 0000000000..0b88bd48c6 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class GetTracesRequestTests extends OpenSearchTestCase { + + public void testConstructorsAndStreaming() throws IOException { + GetTracesRequest request = new GetTracesRequest("test-iid"); + assert (request.validate() == null); + assert (request.getInteractionId().equals("test-iid")); + assert (request.getFrom() == 0); + assert (request.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS); + + GetTracesRequest req2 = new GetTracesRequest("test-iid2", 3); + assert (req2.validate() == null); + assert (req2.getInteractionId().equals("test-iid2")); + assert (req2.getFrom() == 0); + assert (req2.getMaxResults() == 3); + + GetTracesRequest req3 = new GetTracesRequest("test-iid3", 4, 5); + assert (req3.validate() == null); + assert (req3.getInteractionId().equals("test-iid3")); + assert (req3.getFrom() == 5); + assert (req3.getMaxResults() == 4); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetTracesRequest req4 = new GetTracesRequest(in); + assert (req4.validate() == null); + assert (req4.getInteractionId().equals("test-iid")); + assert (req4.getFrom() == 0); + assert (req4.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS); + } + + public void testBadValues_thenFail() { + String nullstr = null; + GetTracesRequest request = new GetTracesRequest(nullstr); + assert (request.validate().validationErrors().get(0).equals("Traces must be retrieved from an interaction")); + assert (request.validate().validationErrors().size() == 1); + + request = new GetTracesRequest("iid", -2); + assert (request.validate().validationErrors().size() == 1); + assert (request.validate().validationErrors().get(0).equals("The number of traces to retrieve must be positive")); + + request = new GetTracesRequest("iid", 2, -2); + assert (request.validate().validationErrors().size() == 1); + assert (request.validate().validationErrors().get(0).equals("The starting position must be nonnegative")); + } + + public void testFromRestRequest() throws IOException { + Map basic = Map.of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid1"); + Map maxResOnly = Map + .of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid2", ActionConstants.REQUEST_MAX_RESULTS_FIELD, "4"); + Map nextTokOnly = Map + .of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid3", ActionConstants.NEXT_TOKEN_FIELD, "6"); + Map bothFields = Map + .of( + ActionConstants.RESPONSE_INTERACTION_ID_FIELD, + "iid4", + ActionConstants.REQUEST_MAX_RESULTS_FIELD, + "2", + ActionConstants.NEXT_TOKEN_FIELD, + "7" + ); + RestRequest req1 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(basic).build(); + RestRequest req2 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(maxResOnly).build(); + RestRequest req3 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(nextTokOnly).build(); + RestRequest req4 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(bothFields).build(); + GetTracesRequest gir1 = GetTracesRequest.fromRestRequest(req1); + GetTracesRequest gir2 = GetTracesRequest.fromRestRequest(req2); + GetTracesRequest gir3 = GetTracesRequest.fromRestRequest(req3); + GetTracesRequest gir4 = GetTracesRequest.fromRestRequest(req4); + + assert (gir1.validate() == null && gir2.validate() == null && gir3.validate() == null && gir4.validate() == null); + assert (gir1.getInteractionId().equals("iid1") && gir2.getInteractionId().equals("iid2")); + assert (gir3.getInteractionId().equals("iid3") && gir4.getInteractionId().equals("iid4")); + assert (gir1.getFrom() == 0 && gir2.getFrom() == 0 && gir3.getFrom() == 6 && gir4.getFrom() == 7); + assert (gir1.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gir2.getMaxResults() == 4); + assert (gir3.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gir4.getMaxResults() == 2); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java new file mode 100644 index 0000000000..e013bcc518 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; + +import org.apache.lucene.search.spell.LevenshteinDistance; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.test.OpenSearchTestCase; + +public class GetTracesResponseTests extends OpenSearchTestCase { + List traces; + + @Before + public void setup() { + traces = List + .of( + new Interaction( + "id0", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 1 + ), + new Interaction( + "id1", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 2 + ), + new Interaction( + "id2", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 3 + + ) + ); + } + + public void testGetInteractionsResponseStreaming() throws IOException { + GetTracesResponse response = new GetTracesResponse(traces, 4, true); + assert (response.getTraces().equals(traces)); + assert (response.getNextToken() == 4); + assert (response.hasMorePages()); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetTracesResponse newResp = new GetTracesResponse(in); + assert (newResp.getTraces().equals(traces)); + assert (newResp.getNextToken() == 4); + assert (newResp.hasMorePages()); + } + + public void testToXContent_MoreTokens() throws IOException { + GetTracesResponse response = new GetTracesResponse(traces.subList(0, 1), 2, true); + Interaction trace = response.getTraces().get(0); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + System.out.println(result); + String expected = "{\"traces\":[{\"conversation_id\":\"cid\",\"interaction_id\":\"id0\",\"create_time\":" + + trace.getCreateTime() + + ",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":{\"metadata\":\"some meta\"},\"parent_interaction_id\":\"parent_id\",\"trace_number\":1}],\"next_token\":2}"; + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + assert (ld.getDistance(result, expected) > 0.95); + } + + @Test(expected = NullPointerException.class) + public void testConstructor_NullTraces() { + GetTracesResponse response = new GetTracesResponse(null, 0, false); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportActionTests.java new file mode 100644 index 0000000000..c6aef01097 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportActionTests.java @@ -0,0 +1,159 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetTracesTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + GetTracesRequest request; + GetTracesTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new GetTracesRequest("test-iid"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + + this.action = spy(new GetTracesTransportAction(transportService, actionFilters, cmHandler, client)); + } + + public void testGetTraces_noMorePages() { + Interaction testTrace = new Interaction( + "test-trace", + Instant.now(), + "test-cid", + "test-input", + "pt", + "test-response", + "test-origin", + Collections.singletonMap("metadata", "some meta"), + "parent-id", + 1 + ); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(List.of(testTrace)); + return null; + }).when(cmHandler).getTraces(any(), anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetTracesResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + List traces = argCaptor.getValue().getTraces(); + assert (traces.size() == 1); + Interaction trace = traces.get(0); + assert (trace.equals(testTrace)); + assert (!argCaptor.getValue().hasMorePages()); + } + + public void testGetTraces_MorePages() { + Interaction testTrace = new Interaction( + "test-trace", + Instant.now(), + "test-cid", + "test-input", + "pt", + "test-response", + "test-origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 1 + ); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(List.of(testTrace)); + return null; + }).when(cmHandler).getTraces(any(), anyInt(), anyInt(), any()); + GetTracesRequest shortPageRequest = new GetTracesRequest("test-trace", 1); + action.doExecute(null, shortPageRequest, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetTracesResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + List traces = argCaptor.getValue().getTraces(); + assert (traces.size() == 1); + Interaction trace = traces.get(0); + assert (trace.equals(testTrace)); + assert (argCaptor.getValue().hasMorePages()); + } + + public void testGetTracesFails_thenFail() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onFailure(new Exception("Testing Failure")); + return null; + }).when(cmHandler).getTraces(any(), anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Testing Failure")); + } + + public void testDoExecuteFails_thenFail() { + doThrow(new RuntimeException("Failure in doExecute")).when(cmHandler).getTraces(any(), anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in doExecute")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequestTests.java new file mode 100644 index 0000000000..3b25f1b174 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequestTests.java @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_UPDATED_TIME_FIELD; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; + +public class UpdateConversationRequestTests { + Map updateContent = new HashMap<>(); + + @Before + public void setUp() { + updateContent.put(META_NAME_FIELD, "new name"); + } + + @Test + public void testConstructor() throws IOException { + UpdateConversationRequest updateConversationRequest = new UpdateConversationRequest("conversationId", updateContent); + assert (updateConversationRequest.validate() == null); + assert (updateConversationRequest.getConversationId().equals("conversationId")); + assert (updateConversationRequest.getUpdateContent().size() == 1); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + updateConversationRequest.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + UpdateConversationRequest newRequest = new UpdateConversationRequest(in); + assert updateConversationRequest.getConversationId().equals(newRequest.getConversationId()); + assert updateConversationRequest.getUpdateContent().equals(newRequest.getUpdateContent()); + } + + @Test + public void testConstructor_UpdateContentNotAllowed() throws IOException { + Map updateCont = new HashMap<>(); + updateCont.put(META_UPDATED_TIME_FIELD, Instant.ofEpochMilli(123)); + UpdateConversationRequest updateConversationRequest = new UpdateConversationRequest("conversationId", updateCont); + assert (updateConversationRequest.validate() == null); + assert (updateConversationRequest.getConversationId().equals("conversationId")); + assert (updateConversationRequest.getUpdateContent().size() == 0); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + updateConversationRequest.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + UpdateConversationRequest newRequest = new UpdateConversationRequest(in); + assert updateConversationRequest.getConversationId().equals(newRequest.getConversationId()); + assert updateConversationRequest.getUpdateContent().equals(newRequest.getUpdateContent()); + assert (newRequest.getUpdateContent().size() == 0); + } + + @Test + public void testConstructor_NullConversationId() throws IOException { + UpdateConversationRequest updateConversationRequest = new UpdateConversationRequest(null, updateContent); + assert updateConversationRequest.validate().getMessage().equals("Validation Failed: 1: conversation id can't be null;"); + } + + @Test + public void testConstructor_NullUpdateContent() throws IOException { + UpdateConversationRequest updateConversationRequest = new UpdateConversationRequest(null, null); + assert updateConversationRequest.validate().getMessage().equals("Validation Failed: 1: conversation id can't be null;"); + } + + @Test + public void testParse_Success() throws IOException { + String jsonStr = "{\"name\":\"new name\",\"application_type\":\"new type\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.parse(parser, "conversationId"); + assertEquals(updateConversationRequest.getConversationId(), "conversationId"); + assertEquals("new name", updateConversationRequest.getUpdateContent().get("name")); + } + + @Test + public void fromActionRequest_Success() { + UpdateConversationRequest updateConversationRequest = UpdateConversationRequest + .builder() + .conversationId("conversationId") + .updateContent(updateContent) + .build(); + assertSame(UpdateConversationRequest.fromActionRequest(updateConversationRequest), updateConversationRequest); + } + + @Test + public void fromActionRequest_Success_fromActionRequest() { + UpdateConversationRequest updateConversationRequest = UpdateConversationRequest + .builder() + .conversationId("conversationId") + .updateContent(updateContent) + .build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + updateConversationRequest.writeTo(out); + } + }; + UpdateConversationRequest request = UpdateConversationRequest.fromActionRequest(actionRequest); + assertNotSame(request, updateConversationRequest); + assertEquals(updateConversationRequest.getConversationId(), request.getConversationId()); + assertEquals(updateConversationRequest.getUpdateContent(), request.getUpdateContent()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + UpdateConversationRequest.fromActionRequest(actionRequest); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java new file mode 100644 index 0000000000..ea713d99bb --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java @@ -0,0 +1,135 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_UPDATED_TIME_FIELD; + +import java.io.IOException; +import java.time.Instant; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class UpdateConversationTransportActionTests extends OpenSearchTestCase { + private UpdateConversationTransportAction transportUpdateConversationAction; + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private Task task; + + @Mock + private UpdateConversationRequest updateRequest; + + @Mock + private UpdateResponse updateResponse; + + @Mock + ActionListener actionListener; + + ThreadContext threadContext; + + private Settings settings; + + private ShardId shardId; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + settings = Settings.builder().build(); + + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + String conversationId = "test_conversation_id"; + Map updateContent = Map.of(META_NAME_FIELD, "new name", META_UPDATED_TIME_FIELD, Instant.ofEpochMilli(123)); + when(updateRequest.getConversationId()).thenReturn(conversationId); + when(updateRequest.getUpdateContent()).thenReturn(updateContent); + shardId = new ShardId(new Index("indexName", "uuid"), 1); + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + + transportUpdateConversationAction = new UpdateConversationTransportAction(transportService, actionFilters, client); + } + + public void test_execute_Success() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + public void test_execute_UpdateFailure() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Error in Update Request")); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Error in Update Request", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_UpdateWrongStatus() { + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + public void test_execute_ThrowException() { + doThrow(new RuntimeException("Error in Update Request")).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Error in Update Request", argumentCaptor.getValue().getMessage()); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequestTests.java new file mode 100644 index 0000000000..4db4b768e8 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequestTests.java @@ -0,0 +1,172 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; + +public class UpdateInteractionRequestTests { + + Map updateContent = new HashMap<>(); + + @Before + public void setUp() { + updateContent.put(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!")); + } + + @Test + public void testConstructor() throws IOException { + UpdateInteractionRequest updateInteractionRequest = new UpdateInteractionRequest("interaction_id", updateContent); + assert updateInteractionRequest.validate() == null; + assert updateInteractionRequest.getInteractionId().equals("interaction_id"); + assert updateInteractionRequest.getUpdateContent().size() == 1; + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + updateInteractionRequest.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + UpdateInteractionRequest newRequest = new UpdateInteractionRequest(in); + assert updateInteractionRequest.getInteractionId().equals(newRequest.getInteractionId()); + assert updateInteractionRequest.getUpdateContent().equals(newRequest.getUpdateContent()); + } + + @Test + public void testConstructor_UpdateContentNotAllowed() throws IOException { + updateContent.put(INTERACTIONS_RESPONSE_FIELD, "response"); + UpdateInteractionRequest updateInteractionRequest = new UpdateInteractionRequest("interaction_id", updateContent); + assert (updateInteractionRequest.validate() == null); + assert (updateInteractionRequest.getInteractionId().equals("interaction_id")); + assert (updateInteractionRequest.getUpdateContent().size() == 1); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + updateInteractionRequest.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + UpdateInteractionRequest newRequest = new UpdateInteractionRequest(in); + assert updateInteractionRequest.getInteractionId().equals(newRequest.getInteractionId()); + assert updateInteractionRequest.getUpdateContent().equals(newRequest.getUpdateContent()); + assert (newRequest.getUpdateContent().size() == 1); + } + + @Test + public void testConstructor_NullInteractionId() throws IOException { + UpdateInteractionRequest updateInteractionRequest = new UpdateInteractionRequest(null, updateContent); + assert updateInteractionRequest.validate().getMessage().equals("Validation Failed: 1: interaction id can't be null;"); + } + + @Test + public void testConstructor_NullUpdateContent() throws IOException { + UpdateInteractionRequest updateInteractionRequest = new UpdateInteractionRequest(null, null); + assert updateInteractionRequest.validate().getMessage().equals("Validation Failed: 1: interaction id can't be null;"); + } + + @Test + public void testParse_Success() throws IOException { + String jsonStr = "{\"additional_info\": {\n" + " \"feedback\": \"thumbs up!\"\n" + " }}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.parse(parser, "interaction_id"); + assertEquals(updateInteractionRequest.getInteractionId(), "interaction_id"); + assertEquals(Map.of("feedback", "thumbs up!"), updateInteractionRequest.getUpdateContent().get(INTERACTIONS_ADDITIONAL_INFO_FIELD)); + } + + @Test + public void testParse_UpdateContentNotAllowed() throws IOException { + String jsonStr = "{\"response\": \"new response!\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.parse(parser, "interaction_id"); + assertEquals(updateInteractionRequest.getInteractionId(), "interaction_id"); + assertEquals(0, updateInteractionRequest.getUpdateContent().size()); + assertNotEquals(null, updateInteractionRequest.getUpdateContent()); + } + + @Test + public void fromActionRequest_Success() { + UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest + .builder() + .interactionId("interaction_id") + .updateContent(updateContent) + .build(); + assertSame(UpdateInteractionRequest.fromActionRequest(updateInteractionRequest), updateInteractionRequest); + } + + @Test + public void fromActionRequest_Success_fromActionRequest() { + UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest + .builder() + .interactionId("interaction_id") + .updateContent(updateContent) + .build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + updateInteractionRequest.writeTo(out); + } + }; + UpdateInteractionRequest request = UpdateInteractionRequest.fromActionRequest(actionRequest); + assertNotSame(request, updateInteractionRequest); + assertEquals(updateInteractionRequest.getInteractionId(), request.getInteractionId()); + assertEquals(updateInteractionRequest.getUpdateContent(), request.getUpdateContent()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + UpdateInteractionRequest.fromActionRequest(actionRequest); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java new file mode 100644 index 0000000000..3dbd16ca64 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java @@ -0,0 +1,134 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; + +import java.io.IOException; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class UpdateInteractionTransportActionTests extends OpenSearchTestCase { + private UpdateInteractionTransportAction updateInteractionTransportAction; + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private Task task; + + @Mock + private UpdateInteractionRequest updateRequest; + + @Mock + private UpdateResponse updateResponse; + + @Mock + ActionListener actionListener; + + ThreadContext threadContext; + + private Settings settings; + + private ShardId shardId; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + settings = Settings.builder().build(); + + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + String interactionId = "test_interaction_id"; + Map updateContent = Map + .of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!"), INTERACTIONS_RESPONSE_FIELD, "response"); + when(updateRequest.getInteractionId()).thenReturn(interactionId); + when(updateRequest.getUpdateContent()).thenReturn(updateContent); + shardId = new ShardId(new Index("indexName", "uuid"), 1); + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + + updateInteractionTransportAction = new UpdateInteractionTransportAction(transportService, actionFilters, client); + } + + public void test_execute_Success() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + public void test_execute_UpdateFailure() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Error in Update Request")); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Error in Update Request", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_UpdateWrongStatus() { + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + public void test_execute_ThrowException() { + doThrow(new RuntimeException("Error in Update Request")).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Error in Update Request", argumentCaptor.getValue().getMessage()); + } +} diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 700e251676..e2e37c2212 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -20,6 +20,7 @@ dependencies { compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow') implementation project(':opensearch-ml-common') + implementation project(':opensearch-ml-memory') implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" testImplementation "org.opensearch.test:framework:${opensearch_version}" implementation "org.opensearch:common-utils:${common_utils_version}" diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java new file mode 100644 index 0000000000..dc99ef4438 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesRequest; +import org.opensearch.ml.memory.action.conversation.GetTracesResponse; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest; + +import com.google.common.base.Preconditions; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +/** + * Memory manager for Memories. It contains ML memory related operations like create, read interactions etc. + */ +@Log4j2 +@AllArgsConstructor +public class MLMemoryManager { + + private Client client; + + /** + * Create a new Conversation + * @param name the name of the conversation + * @param applicationType the application type that creates this conversation + * @param actionListener action listener to process the response + */ + public void createConversation(String name, String applicationType, ActionListener actionListener) { + try { + client.execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name, applicationType), actionListener); + } catch (Exception exception) { + actionListener.onFailure(exception); + } + } + + /** + * Adds an interaction to the conversation indicated, updating the conversational metadata + * @param conversationId the conversation to add the interaction to + * @param input the human input for the interaction + * @param promptTemplate the prompt template used for this interaction + * @param response the Gen AI response for this interaction + * @param origin the name of the GenAI agent in this interaction + * @param additionalInfo additional information used in constructing the LLM prompt + * @param parentIntId the parent interactionId of this interaction + * @param traceNum the trace number for a parent interaction + * @param actionListener gets the ID of the new interaction + */ + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + Map additionalInfo, + String parentIntId, + Integer traceNum, + ActionListener actionListener + ) { + Preconditions.checkNotNull(conversationId); + Preconditions.checkNotNull(input); + Preconditions.checkNotNull(response); + // additionalInfo cannot be null as flat object + additionalInfo = (additionalInfo == null) ? new HashMap<>() : additionalInfo; + try { + client + .execute( + CreateInteractionAction.INSTANCE, + new CreateInteractionRequest( + conversationId, + input, + promptTemplate, + response, + origin, + additionalInfo, + parentIntId, + traceNum + ), + actionListener + ); + } catch (Exception exception) { + actionListener.onFailure(exception); + } + } + + /** + * Get the interactions associate with this conversation that are not traces, sorted by recency + * @param conversationId the conversation whose interactions to get + * @param lastNInteraction Return how many interactions + * @param actionListener get all the final interactions that are not traces + */ + public void getFinalInteractions(String conversationId, int lastNInteraction, ActionListener> actionListener) { + Preconditions.checkNotNull(conversationId); + Preconditions.checkArgument(lastNInteraction > 0, "lastN must be at least 1."); + log.debug("Getting Interactions, conversationId {}, lastN {}", conversationId, lastNInteraction); + + ActionListener al = ActionListener.wrap(getInteractionsResponse -> { + actionListener.onResponse(getInteractionsResponse.getInteractions()); + }, e -> { actionListener.onFailure(e); }); + + try { + client.execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, lastNInteraction), al); + } catch (Exception exception) { + actionListener.onFailure(exception); + } + } + + /** + * Get the interactions associate with this conversation, sorted by recency + * @param parentInteractionId the parent interaction id whose traces to get + * @param actionListener get all the trace interactions that are only traces + */ + public void getTraces(String parentInteractionId, ActionListener> actionListener) { + Preconditions.checkNotNull(parentInteractionId); + log.debug("Getting traces for conversationId {}", parentInteractionId); + + ActionListener al = ActionListener.wrap(getTracesResponse -> { + actionListener.onResponse(getTracesResponse.getTraces()); + }, e -> { actionListener.onFailure(e); }); + + try { + client.execute(GetTracesAction.INSTANCE, new GetTracesRequest(parentInteractionId), al); + } catch (Exception exception) { + actionListener.onFailure(exception); + } + } + + /** + * Get the interactions associate with this conversation, sorted by recency + * @param interactionId the parent interaction id whose traces to get + * @param actionListener listener for the update response + */ + public void updateInteraction(String interactionId, Map updateContent, ActionListener actionListener) { + Preconditions.checkNotNull(interactionId); + Preconditions.checkNotNull(updateContent); + try { + client.execute(UpdateInteractionAction.INSTANCE, new UpdateInteractionRequest(interactionId, updateContent), actionListener); + } catch (Exception exception) { + actionListener.onFailure(exception); + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java new file mode 100644 index 0000000000..b3a5f0da56 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java @@ -0,0 +1,277 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; + +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesRequest; +import org.opensearch.ml.memory.action.conversation.GetTracesResponse; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest; + +public class MLMemoryManagerTests { + + @Mock + Client client; + + @Mock + MLMemoryManager mlMemoryManager; + + @Mock + ActionListener createConversationResponseActionListener; + + @Mock + ActionListener createInteractionResponseActionListener; + + @Mock + ActionListener> interactionListActionListener; + + @Mock + ActionListener updateResponseActionListener; + + String conversationName; + String applicationType; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + mlMemoryManager = new MLMemoryManager(client); + conversationName = "new conversation"; + applicationType = "ml application"; + } + + @Test + public void testCreateConversation() { + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateConversationRequest.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + al.onResponse(new CreateConversationResponse("conversation-id")); + return null; + }).when(client).execute(any(), any(), any()); + + mlMemoryManager.createConversation(conversationName, applicationType, createConversationResponseActionListener); + + verify(client, times(1)) + .execute(eq(CreateConversationAction.INSTANCE), captor.capture(), eq(createConversationResponseActionListener)); + assertEquals(conversationName, captor.getValue().getName()); + assertEquals(applicationType, captor.getValue().getApplicationType()); + } + + @Test + public void testCreateConversationFails_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager.createConversation(conversationName, applicationType, createConversationResponseActionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createConversationResponseActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } + + @Test + public void testCreateInteraction() { + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateInteractionRequest.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + al.onResponse(new CreateInteractionResponse("interaction-id")); + return null; + }).when(client).execute(any(), any(), any()); + + mlMemoryManager + .createInteraction( + "conversationId", + "input", + "prompt", + "response", + "origin", + Collections.singletonMap("feedback", "thumbsup"), + "parent-id", + 1, + createInteractionResponseActionListener + ); + verify(client, times(1)) + .execute(eq(CreateInteractionAction.INSTANCE), captor.capture(), eq(createInteractionResponseActionListener)); + assertEquals("conversationId", captor.getValue().getConversationId()); + assertEquals("input", captor.getValue().getInput()); + assertEquals("prompt", captor.getValue().getPromptTemplate()); + assertEquals("response", captor.getValue().getResponse()); + assertEquals("origin", captor.getValue().getOrigin()); + assertEquals(Collections.singletonMap("feedback", "thumbsup"), captor.getValue().getAdditionalInfo()); + assertEquals("parent-id", captor.getValue().getParentIid()); + assertEquals("1", captor.getValue().getTraceNumber().toString()); + } + + @Test + public void testCreateInteractionFails_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager + .createInteraction( + "conversationId", + "input", + "prompt", + "response", + "origin", + Collections.singletonMap("feedback", "thumbsup"), + "parent-id", + 1, + createInteractionResponseActionListener + ); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createInteractionResponseActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } + + @Test + public void testGetInteractions() { + List interactions = List + .of( + new Interaction( + "id0", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta") + ) + ); + ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + GetInteractionsResponse getInteractionsResponse = new GetInteractionsResponse(interactions, 4, false); + al.onResponse(getInteractionsResponse); + return null; + }).when(client).execute(any(), any(), any()); + + mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); + + verify(client, times(1)).execute(eq(GetInteractionsAction.INSTANCE), captor.capture(), any()); + assertEquals("cid", captor.getValue().getConversationId()); + assertEquals(0, captor.getValue().getFrom()); + assertEquals(10, captor.getValue().getMaxResults()); + } + + @Test + public void testGetInteractionFails_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(interactionListActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } + + @Test + public void testGetTraces() { + List traces = List + .of( + new Interaction( + "id0", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 1 + ) + ); + ArgumentCaptor captor = ArgumentCaptor.forClass(GetTracesRequest.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + GetTracesResponse getTracesResponse = new GetTracesResponse(traces, 4, false); + al.onResponse(getTracesResponse); + return null; + }).when(client).execute(any(), any(), any()); + + mlMemoryManager.getTraces("iid", interactionListActionListener); + + verify(client, times(1)).execute(eq(GetTracesAction.INSTANCE), captor.capture(), any()); + assertEquals("iid", captor.getValue().getInteractionId()); + assertEquals(0, captor.getValue().getFrom()); + assertEquals(10, captor.getValue().getMaxResults()); + } + + @Test + public void testGetTracesFails_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager.getTraces("cid", interactionListActionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(interactionListActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } + + @Test + public void testUpdateInteraction() { + Map updateContent = Map + .of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!"), INTERACTIONS_RESPONSE_FIELD, "response"); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + al.onResponse(updateResponse); + return null; + }).when(client).execute(any(), any(), any()); + + ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateInteractionRequest.class); + mlMemoryManager.updateInteraction("iid", updateContent, updateResponseActionListener); + verify(client, times(1)).execute(eq(UpdateInteractionAction.INSTANCE), captor.capture(), any()); + assertEquals("iid", captor.getValue().getInteractionId()); + assertEquals(1, captor.getValue().getUpdateContent().keySet().size()); + assertNotEquals(updateContent, captor.getValue().getUpdateContent()); + } + + @Test + public void testUpdateInteraction_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager + .updateInteraction( + "iid", + Map.of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!")), + updateResponseActionListener + ); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(updateResponseActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 13b4834d59..e986d7e3c1 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -147,10 +147,16 @@ import org.opensearch.ml.memory.action.conversation.GetInteractionTransportAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsTransportAction; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesTransportAction; import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; import org.opensearch.ml.memory.action.conversation.SearchConversationsTransportAction; import org.opensearch.ml.memory.action.conversation.SearchInteractionsAction; import org.opensearch.ml.memory.action.conversation.SearchInteractionsTransportAction; +import org.opensearch.ml.memory.action.conversation.UpdateConversationAction; +import org.opensearch.ml.memory.action.conversation.UpdateConversationTransportAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionTransportAction; import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; @@ -189,8 +195,11 @@ import org.opensearch.ml.rest.RestMemoryGetConversationsAction; import org.opensearch.ml.rest.RestMemoryGetInteractionAction; import org.opensearch.ml.rest.RestMemoryGetInteractionsAction; +import org.opensearch.ml.rest.RestMemoryGetTracesAction; import org.opensearch.ml.rest.RestMemorySearchConversationsAction; import org.opensearch.ml.rest.RestMemorySearchInteractionsAction; +import org.opensearch.ml.rest.RestMemoryUpdateConversationAction; +import org.opensearch.ml.rest.RestMemoryUpdateInteractionAction; import org.opensearch.ml.settings.MLCommonsSettings; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLClusterLevelStat; @@ -321,7 +330,10 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(SearchInteractionsAction.INSTANCE, SearchInteractionsTransportAction.class), new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class), new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class), - new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class) + new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class), + new ActionHandler<>(UpdateConversationAction.INSTANCE, UpdateConversationTransportAction.class), + new ActionHandler<>(UpdateInteractionAction.INSTANCE, UpdateInteractionTransportAction.class), + new ActionHandler<>(GetTracesAction.INSTANCE, GetTracesTransportAction.class) ); } @@ -577,6 +589,9 @@ public List getRestHandlers( RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); + RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction(); + RestMemoryUpdateInteractionAction restMemoryUpdateInteractionAction = new RestMemoryUpdateInteractionAction(); + RestMemoryGetTracesAction restMemoryGetTracesAction = new RestMemoryGetTracesAction(); return ImmutableList .of( restMLStatsAction, @@ -615,7 +630,10 @@ public List getRestHandlers( restSearchConversationsAction, restSearchInteractionsAction, restGetConversationAction, - restGetInteractionAction + restGetInteractionAction, + restMemoryUpdateConversationAction, + restMemoryUpdateInteractionAction, + restMemoryGetTracesAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetTracesAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetTracesAction.java new file mode 100644 index 0000000000..12c0815cc3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetTracesAction.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +public class RestMemoryGetTracesAction extends BaseRestHandler { + private final static String GET_TRACES_NAME = "conversational_get_traces"; + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.GET, ActionConstants.GET_TRACES_REST_PATH)); + } + + @Override + public String getName() { + return GET_TRACES_NAME; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + GetTracesRequest gtRequest = GetTracesRequest.fromRestRequest(request); + return channel -> client.execute(GetTracesAction.INSTANCE, gtRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateConversationAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateConversationAction.java new file mode 100644 index 0000000000..c0934056b6 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateConversationAction.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.UpdateConversationAction; +import org.opensearch.ml.memory.action.conversation.UpdateConversationRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; + +public class RestMemoryUpdateConversationAction extends BaseRestHandler { + private static final String ML_UPDATE_CONVERSATION_ACTION = "ml_update_conversation_action"; + + @Override + public String getName() { + return ML_UPDATE_CONVERSATION_ACTION; + } + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.PUT, ActionConstants.UPDATE_CONVERSATIONS_REST_PATH)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + UpdateConversationRequest updateConversationRequest = getRequest(request); + return restChannel -> client + .execute(UpdateConversationAction.INSTANCE, updateConversationRequest, new RestToXContentListener<>(restChannel)); + } + + @VisibleForTesting + private UpdateConversationRequest getRequest(RestRequest request) throws IOException { + if (!request.hasContent()) { + throw new OpenSearchParseException("Failed to update conversation: Request body is empty"); + } + + String conversationId = getParameterId(request, "conversation_id"); + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + return UpdateConversationRequest.parse(parser, conversationId); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionAction.java new file mode 100644 index 0000000000..dafc0352ec --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionAction.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; + +public class RestMemoryUpdateInteractionAction extends BaseRestHandler { + private static final String ML_UPDATE_INTERACTION_ACTION = "ml_update_interaction_action"; + + @Override + public String getName() { + return ML_UPDATE_INTERACTION_ACTION; + } + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.PUT, ActionConstants.UPDATE_INTERACTIONS_REST_PATH)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + UpdateInteractionRequest updateInteractionRequest = getRequest(request); + return restChannel -> client + .execute(UpdateInteractionAction.INSTANCE, updateInteractionRequest, new RestToXContentListener<>(restChannel)); + } + + @VisibleForTesting + private UpdateInteractionRequest getRequest(RestRequest request) throws IOException { + if (!request.hasContent()) { + throw new OpenSearchParseException("Failed to update interaction: Request body is empty"); + } + + String interactionId = getParameterId(request, "interaction_id"); + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + return UpdateInteractionRequest.parse(parser, interactionId); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetTracesActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetTracesActionTests.java new file mode 100644 index 0000000000..67a91db6e8 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetTracesActionTests.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class RestMemoryGetTracesActionTests extends OpenSearchTestCase { + + public void testBasics() { + RestMemoryGetTracesAction action = new RestMemoryGetTracesAction(); + assert (action.getName().equals("conversational_get_traces")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new RestHandler.Route(RestRequest.Method.GET, ActionConstants.GET_TRACES_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryGetTracesAction action = new RestMemoryGetTracesAction(); + Map params = Map + .of( + ActionConstants.RESPONSE_INTERACTION_ID_FIELD, + "iid", + ActionConstants.REQUEST_MAX_RESULTS_FIELD, + "2", + ActionConstants.NEXT_TOKEN_FIELD, + "7" + ); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetTracesRequest.class); + verify(client, times(1)).execute(eq(GetTracesAction.INSTANCE), argCaptor.capture(), any()); + GetTracesRequest req = argCaptor.getValue(); + assert (req.getInteractionId().equals("iid")); + assert (req.getFrom() == 7); + assert (req.getMaxResults() == 2); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateConversationTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateConversationTests.java new file mode 100644 index 0000000000..539527bdf5 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateConversationTests.java @@ -0,0 +1,165 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.memory.action.conversation.UpdateConversationAction; +import org.opensearch.ml.memory.action.conversation.UpdateConversationRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import com.google.gson.Gson; + +public class RestMemoryUpdateConversationTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private RestMemoryUpdateConversationAction restMemoryUpdateConversationAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(UpdateConversationAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction(); + assertNotNull(restMemoryUpdateConversationAction); + } + + public void testGetName() { + String actionName = restMemoryUpdateConversationAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_update_conversation_action", actionName); + } + + public void testRoutes() { + List routes = restMemoryUpdateConversationAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/memory/conversation/{conversation_id}/_update", route.getPath()); + } + + public void testUpdateConversationRequest() throws Exception { + RestRequest request = getRestRequest(); + restMemoryUpdateConversationAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateConversationRequest.class); + verify(client, times(1)).execute(eq(UpdateConversationAction.INSTANCE), argumentCaptor.capture(), any()); + UpdateConversationRequest updateConversationRequest = argumentCaptor.getValue(); + assertEquals("test_conversationId", updateConversationRequest.getConversationId()); + assertEquals("new name", updateConversationRequest.getUpdateContent().get(META_NAME_FIELD)); + } + + public void testUpdateConnectorRequestWithEmptyContent() throws Exception { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Failed to update conversation: Request body is empty"); + RestRequest request = getRestRequestWithEmptyContent(); + restMemoryUpdateConversationAction.handleRequest(request, channel, client); + } + + public void testUpdateConnectorRequestWithNullConversationId() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Request should contain conversation_id"); + RestRequest request = getRestRequestWithNullConversationId(); + restMemoryUpdateConversationAction.handleRequest(request, channel, client); + } + + private RestRequest getRestRequest() { + RestRequest.Method method = RestRequest.Method.POST; + final Map updateContent = Map.of(META_NAME_FIELD, "new name"); + String requestContent = new Gson().toJson(updateContent); + Map params = new HashMap<>(); + params.put("conversation_id", "test_conversationId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/memory/conversation/{conversation_id}/_update") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithEmptyContent() { + RestRequest.Method method = RestRequest.Method.POST; + Map params = new HashMap<>(); + params.put("conversation_id", "test_conversationId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/memory/conversation/{conversation_id}/_update") + .withParams(params) + .withContent(new BytesArray(""), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullConversationId() { + RestRequest.Method method = RestRequest.Method.POST; + final Map updateContent = Map.of(META_NAME_FIELD, "new name"); + String requestContent = new Gson().toJson(updateContent); + Map params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/memory/conversation/{conversation_id}/_update") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionActionTests.java new file mode 100644 index 0000000000..cdfdaa2b3c --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionActionTests.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import com.google.gson.Gson; + +public class RestMemoryUpdateInteractionActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private RestMemoryUpdateInteractionAction restMemoryUpdateInteractionAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + restMemoryUpdateInteractionAction = new RestMemoryUpdateInteractionAction(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(UpdateInteractionAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMemoryUpdateInteractionAction restMemoryUpdateInteractionAction = new RestMemoryUpdateInteractionAction(); + assertNotNull(restMemoryUpdateInteractionAction); + } + + public void testGetName() { + String actionName = restMemoryUpdateInteractionAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_update_interaction_action", actionName); + } + + public void testRoutes() { + List routes = restMemoryUpdateInteractionAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/memory/interaction/{interaction_id}/_update", route.getPath()); + } + + public void testUpdateInteractionRequest() throws Exception { + RestRequest request = getRestRequest(); + restMemoryUpdateInteractionAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateInteractionRequest.class); + verify(client, times(1)).execute(eq(UpdateInteractionAction.INSTANCE), argumentCaptor.capture(), any()); + UpdateInteractionRequest updateInteractionRequest = argumentCaptor.getValue(); + assertEquals("test_interactionId", updateInteractionRequest.getInteractionId()); + assertEquals(Map.of("feedback", "thumbs up!"), updateInteractionRequest.getUpdateContent().get(INTERACTIONS_ADDITIONAL_INFO_FIELD)); + } + + public void testUpdateInteractionRequestWithEmptyContent() throws Exception { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Failed to update interaction: Request body is empty"); + RestRequest request = getRestRequestWithEmptyContent(); + restMemoryUpdateInteractionAction.handleRequest(request, channel, client); + } + + public void testUpdateInteractionRequestWithNullInteractionId() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Request should contain interaction_id"); + RestRequest request = getRestRequestWithNullInteractionId(); + restMemoryUpdateInteractionAction.handleRequest(request, channel, client); + } + + private RestRequest getRestRequest() { + RestRequest.Method method = RestRequest.Method.POST; + final Map updateContent = Map.of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!")); + String requestContent = new Gson().toJson(updateContent); + Map params = new HashMap<>(); + params.put("interaction_id", "test_interactionId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/memory/interaction/{interaction_id}/_update") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithEmptyContent() { + RestRequest.Method method = RestRequest.Method.POST; + Map params = new HashMap<>(); + params.put("interaction_id", "test_interactionId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/memory/interaction/{interaction_id}/_update") + .withParams(params) + .withContent(new BytesArray(""), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullInteractionId() { + RestRequest.Method method = RestRequest.Method.POST; + final Map updateContent = Map.of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!")); + String requestContent = new Gson().toJson(updateContent); + Map params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/memory/interaction/{interaction_id}/_update") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } +} From bba62af9c28e6a726194b6079f9b8e51b602994c Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Fri, 15 Dec 2023 17:22:33 -0800 Subject: [PATCH 5/7] IndicesHandler and conversationIndexMemory (#1762) * memory structures, index handlers Signed-off-by: Jing Zhang * MLIndicesHandler and UT Signed-off-by: Jing Zhang * move indexHandler from plugin to algorithm Signed-off-by: Jing Zhang * address comments Signed-off-by: Jing Zhang * address more comments Signed-off-by: Jing Zhang * revert previous commit Signed-off-by: Jing Zhang * fix buges Signed-off-by: Jing Zhang --------- Signed-off-by: Jing Zhang --- client/build.gradle | 5 +- common/build.gradle | 81 ++++++ memory/build.gradle | 2 +- ml-algorithms/build.gradle | 5 +- .../ml/engine}/indices/MLIndex.java | 16 +- .../ml/engine}/indices/MLIndicesHandler.java | 14 +- .../indices/MLInputDatasetHandler.java | 15 +- .../ml/engine/memory/BaseMessage.java | 46 ++++ .../memory/ConversationIndexMemory.java | 207 +++++++++++++++ .../memory/ConversationIndexMessage.java | 59 +++++ .../engine/indices/MLIndicesHandlerTest.java | 194 ++++++++++++++ .../ml/engine/memory/BaseMessageTest.java | 29 ++ .../memory/ConversationIndexMemoryTest.java | 248 ++++++++++++++++++ .../memory/ConversationIndexMessageTest.java | 48 ++++ .../ml/engine/memory/MLMemoryManagerTest.java | 124 +++++++++ plugin/build.gradle | 2 +- .../TransportCreateConnectorAction.java | 2 +- .../TransportRegisterModelGroupAction.java | 2 +- .../TransportRegisterModelAction.java | 2 +- .../upload_chunk/MLModelChunkUploader.java | 2 +- .../MLCommonsClusterManagerEventListener.java | 2 +- .../opensearch/ml/cluster/MLSyncUpCron.java | 2 +- .../ml/model/MLModelGroupManager.java | 2 +- .../opensearch/ml/model/MLModelManager.java | 2 +- .../ml/plugin/MachineLearningPlugin.java | 4 +- .../ml/task/MLExecuteTaskRunner.java | 2 +- .../ml/task/MLPredictTaskRunner.java | 2 +- .../org/opensearch/ml/task/MLTaskManager.java | 2 +- .../ml/task/MLTrainAndPredictTaskRunner.java | 2 +- .../ml/task/MLTrainingTaskRunner.java | 4 +- .../TransportCreateConnectorActionTests.java | 2 +- ...ransportRegisterModelGroupActionTests.java | 2 +- .../TransportRegisterModelActionTests.java | 2 +- .../MLModelChunkUploaderTests.java | 2 +- .../ml/cluster/MLSyncUpCronTests.java | 2 +- .../ml/indices/MLIndicesHandlerTests.java | 199 -------------- .../indices/MLInputDatasetHandlerTests.java | 164 ------------ .../ml/model/MLModelGroupManagerTests.java | 2 +- .../ml/model/MLModelManagerTests.java | 2 +- .../ml/task/MLExecuteTaskRunnerTests.java | 2 +- .../ml/task/MLPredictTaskRunnerTests.java | 8 +- .../ml/task/MLTaskManagerTests.java | 2 +- .../MLTrainAndPredictTaskRunnerTests.java | 2 +- .../ml/task/MLTrainingTaskRunnerTests.java | 4 +- .../org/opensearch/ml/utils/MockHelper.java | 2 +- search-processors/build.gradle | 4 +- 46 files changed, 1103 insertions(+), 425 deletions(-) rename {plugin/src/main/java/org/opensearch/ml => ml-algorithms/src/main/java/org/opensearch/ml/engine}/indices/MLIndex.java (67%) rename {plugin/src/main/java/org/opensearch/ml => ml-algorithms/src/main/java/org/opensearch/ml/engine}/indices/MLIndicesHandler.java (94%) rename {plugin/src/main/java/org/opensearch/ml => ml-algorithms/src/main/java/org/opensearch/ml/engine}/indices/MLInputDatasetHandler.java (83%) create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/BaseMessageTest.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMessageTest.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTest.java delete mode 100644 plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java delete mode 100644 plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java diff --git a/client/build.gradle b/client/build.gradle index 7e89a3f117..a302842106 100644 --- a/client/build.gradle +++ b/client/build.gradle @@ -15,7 +15,7 @@ plugins { dependencies { implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow') - implementation project(':opensearch-ml-common') + implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.3.1' @@ -122,4 +122,5 @@ publishing { } } - +compileJava.dependsOn(':opensearch-ml-common:shadowJar') +delombok.dependsOn(':opensearch-ml-common:shadowJar') diff --git a/common/build.gradle b/common/build.gradle index 532e010373..25ca4832b2 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -6,8 +6,11 @@ //TODO: cleanup gradle config file, some overlap plugins { id 'java' + id 'com.github.johnrengelman.shadow' id 'jacoco' id "io.freefair.lombok" + id 'maven-publish' + id 'signing' } dependencies { @@ -21,6 +24,15 @@ dependencies { compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' compileOnly group: 'org.json', name: 'json', version: '20231013' + + implementation('com.google.guava:guava:32.1.2-jre') { + exclude group: 'com.google.guava', module: 'failureaccess' + exclude group: 'com.google.code.findbugs', module: 'jsr305' + exclude group: 'org.checkerframework', module: 'checker-qual' + exclude group: 'com.google.errorprone', module: 'error_prone_annotations' + exclude group: 'com.google.j2objc', module: 'j2objc-annotations' + exclude group: 'com.google.guava', module: 'listenablefuture' + } } lombok { @@ -54,5 +66,74 @@ jacocoTestCoverageVerification { } check.dependsOn jacocoTestCoverageVerification +shadowJar { + destinationDirectory = file("${project.buildDir}/distributions") + archiveClassifier.set(null) + exclude 'META-INF/maven/com.google.guava/**' + exclude 'com/google/thirdparty/**' + relocate 'com.google.common', 'org.opensearch.ml.repackage.com.google.common' // dependency of cron-utils +} +jar { + enabled false +} +task sourcesJar(type: Jar) { + archiveClassifier.set 'sources' + from sourceSets.main.allJava +} + +task javadocJar(type: Jar) { + archiveClassifier.set 'javadoc' + from javadoc.destinationDir + dependsOn javadoc +} + +publishing { + repositories { + maven { + name = 'staging' + url = "${rootProject.buildDir}/local-staging-repo" + } + maven { + name = "Snapshots" // optional target repository name + url = "https://aws.oss.sonatype.org/content/repositories/snapshots" + credentials { + username "$System.env.SONATYPE_USERNAME" + password "$System.env.SONATYPE_PASSWORD" + } + } + } + publications { + shadow(MavenPublication) { publication -> + project.shadow.component(publication) + artifact sourcesJar + artifact javadocJar + + pom { + name = "OpenSearch ML Commons Comm" + packaging = "jar" + url = "https://github.com/opensearch-project/ml-commons" + description = "OpenSearch ML Common" + scm { + connection = "scm:git@github.com:opensearch-project/ml-commons.git" + developerConnection = "scm:git@github.com:opensearch-project/ml-commons.git" + url = "git@github.com:opensearch-project/ml-commons.git" + } + licenses { + license { + name = "The Apache License, Version 2.0" + url = "http://www.apache.org/licenses/LICENSE-2.0.txt" + } + } + developers { + developer { + name = "OpenSearch" + url = "https://github.com/opensearch-project/ml-commons" + } + } + } + } + } +} +publishShadowPublicationToMavenLocal.mustRunAfter shadowJar \ No newline at end of file diff --git a/memory/build.gradle b/memory/build.gradle index 84b6947c3f..4bb4c4dbd5 100644 --- a/memory/build.gradle +++ b/memory/build.gradle @@ -24,7 +24,7 @@ plugins { } dependencies { - implementation project(":opensearch-ml-common") + implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: '5.2.2' implementation "org.opensearch:common-utils:${common_utils_version}" diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index e2e37c2212..cd79560e90 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -17,10 +17,10 @@ repositories { } dependencies { - compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow') - implementation project(':opensearch-ml-common') + implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') implementation project(':opensearch-ml-memory') + compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" testImplementation "org.opensearch.test:framework:${opensearch_version}" implementation "org.opensearch:common-utils:${common_utils_version}" @@ -103,6 +103,7 @@ jacocoTestCoverageVerification { dependsOn jacocoTestReport } check.dependsOn jacocoTestCoverageVerification +compileJava.dependsOn(':opensearch-ml-common:shadowJar') spotless { java { diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java similarity index 67% rename from plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java rename to ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java index b81682f07e..671f4e548a 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java @@ -3,14 +3,23 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.indices; +package org.opensearch.ml.engine.indices; +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_SCHEMA_VERSION; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_SCHEMA_VERSION; @@ -26,7 +35,10 @@ public enum MLIndex { MODEL(ML_MODEL_INDEX, false, ML_MODEL_INDEX_MAPPING, ML_MODEL_INDEX_SCHEMA_VERSION), TASK(ML_TASK_INDEX, false, ML_TASK_INDEX_MAPPING, ML_TASK_INDEX_SCHEMA_VERSION), CONNECTOR(ML_CONNECTOR_INDEX, false, ML_CONNECTOR_INDEX_MAPPING, ML_CONNECTOR_SCHEMA_VERSION), - CONFIG(ML_CONFIG_INDEX, false, ML_CONFIG_INDEX_MAPPING, ML_CONFIG_INDEX_SCHEMA_VERSION); + CONFIG(ML_CONFIG_INDEX, false, ML_CONFIG_INDEX_MAPPING, ML_CONFIG_INDEX_SCHEMA_VERSION), + AGENT(ML_AGENT_INDEX, false, ML_AGENT_INDEX_MAPPING, ML_AGENT_INDEX_SCHEMA_VERSION), + MEMORY_META(ML_MEMORY_META_INDEX, false, ML_MEMORY_META_INDEX_MAPPING, ML_MEMORY_META_INDEX_SCHEMA_VERSION), + MEMORY_MESSAGE(ML_MEMORY_MESSAGE_INDEX, false, ML_MEMORY_MESSAGE_INDEX_MAPPING, ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION); private final String indexName; // whether we use an alias for the index diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java similarity index 94% rename from plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java rename to ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index d278fa6415..ca5f88be78 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.indices; +package org.opensearch.ml.engine.indices; import static org.opensearch.ml.common.CommonValue.META; import static org.opensearch.ml.common.CommonValue.SCHEMA_VERSION_FIELD; @@ -62,10 +62,22 @@ public void initMLConnectorIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.CONNECTOR, listener); } + public void initMemoryMetaIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.MEMORY_META, listener); + } + + public void initMemoryMessageIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.MEMORY_MESSAGE, listener); + } + public void initMLConfigIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.CONFIG, listener); } + public void initMLAgentIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.AGENT, listener); + } + public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) { String indexName = index.getIndexName(); String mapping = index.getMapping(); diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLInputDatasetHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java similarity index 83% rename from plugin/src/main/java/org/opensearch/ml/indices/MLInputDatasetHandler.java rename to ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java index 1dcf6bdf77..452f836357 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLInputDatasetHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.indices; +package org.opensearch.ml.engine.indices; import java.util.ArrayList; import java.util.List; @@ -35,19 +35,6 @@ public class MLInputDatasetHandler { Client client; - // /** - // * Retrieve DataFrame from DataFrameInputDataset - // * @param mlInputDataset MLInputDataset - // * @return DataFrame - // */ - // public DataFrame parseDataFrameInput(MLInputDataset mlInputDataset) { - // if (!mlInputDataset.getInputDataType().equals(MLInputDataType.DATA_FRAME)) { - // throw new IllegalArgumentException("Input dataset is not DATA_FRAME type."); - // } - // DataFrameInputDataset inputDataset = (DataFrameInputDataset) mlInputDataset; - // return inputDataset.getDataFrame(); - // } - /** * Create DataFrame based on given search query * @param mlInputDataset MLInputDataset diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java new file mode 100644 index 0000000000..05b3185a34 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import java.io.IOException; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.spi.memory.Message; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +public class BaseMessage implements Message, ToXContentObject { + + @Getter + @Setter + protected String type; + @Getter + @Setter + protected String content; + + @Builder + public BaseMessage(String type, String content) { + this.type = type; + this.content = content; + } + + @Override + public String toString() { + return type + ": " + content; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("type", type); + builder.field("content", content); + builder.endObject(); + return builder; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java new file mode 100644 index 0000000000..8dcbe050bb --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java @@ -0,0 +1,207 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; + +import java.util.Map; + +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.spi.memory.Memory; +import org.opensearch.ml.common.spi.memory.Message; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@Getter +public class ConversationIndexMemory implements Memory { + public static final String TYPE = "conversation_index"; + public static final String CONVERSATION_ID = "conversation_id"; + public static final String FINAL_ANSWER = "final_answer"; + public static final String CREATED_TIME = "created_time"; + public static final String MEMORY_NAME = "memory_name"; + public static final String MEMORY_ID = "memory_id"; + public static final String APP_TYPE = "app_type"; + public static int LAST_N_INTERACTIONS = 10; + protected String memoryMetaIndexName; + protected String memoryMessageIndexName; + protected String conversationId; + protected boolean retrieveFinalAnswer = true; + protected final Client client; + private final MLIndicesHandler mlIndicesHandler; + private MLMemoryManager memoryManager; + + public ConversationIndexMemory( + Client client, + MLIndicesHandler mlIndicesHandler, + String memoryMetaIndexName, + String memoryMessageIndexName, + String conversationId, + MLMemoryManager memoryManager + ) { + this.client = client; + this.mlIndicesHandler = mlIndicesHandler; + this.memoryMetaIndexName = memoryMetaIndexName; + this.memoryMessageIndexName = memoryMessageIndexName; + this.conversationId = conversationId; + this.memoryManager = memoryManager; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public void save(String id, Message message) { + this.save(id, message, ActionListener.wrap(r -> { log.info("saved message into {} memory, session id: {}", TYPE, id); }, e -> { + log.error("Failed to save message to memory", e); + })); + } + + @Override + public void save(String id, Message message, ActionListener listener) { + mlIndicesHandler.initMemoryMessageIndex(ActionListener.wrap(created -> { + if (created) { + IndexRequest indexRequest = new IndexRequest(memoryMessageIndexName); + ConversationIndexMessage conversationIndexMessage = (ConversationIndexMessage) message; + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + conversationIndexMessage.toXContent(builder, ToXContent.EMPTY_PARAMS); + indexRequest.source(builder); + client.index(indexRequest, listener); + } else { + listener.onFailure(new RuntimeException("Failed to create memory message index")); + } + }, e -> { listener.onFailure(new RuntimeException("Failed to create memory message index", e)); })); + } + + public void save(Message message, String parentId, Integer traceNum, String action) { + this.save(message, parentId, traceNum, action, ActionListener.wrap(r -> { + log + .info( + "saved message into memory {}, parent id: {}, trace number: {}, interaction id: {}", + conversationId, + parentId, + traceNum, + r.getId() + ); + }, e -> { log.error("Failed to save interaction", e); })); + } + + public void save(Message message, String parentId, Integer traceNum, String action, ActionListener listener) { + ConversationIndexMessage msg = (ConversationIndexMessage) message; + memoryManager + .createInteraction(conversationId, msg.getQuestion(), null, msg.getResponse(), action, null, parentId, traceNum, listener); + } + + @Override + public void getMessages(String id, ActionListener listener) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(memoryMessageIndexName); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.size(10000); + QueryBuilder sessionIdQueryBuilder = new TermQueryBuilder(CONVERSATION_ID, id); + + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.must(sessionIdQueryBuilder); + + if (retrieveFinalAnswer) { + QueryBuilder finalAnswerQueryBuilder = new TermQueryBuilder(FINAL_ANSWER, true); + boolQueryBuilder.must(finalAnswerQueryBuilder); + } + + sourceBuilder.query(boolQueryBuilder); + sourceBuilder.sort(CREATED_TIME, SortOrder.ASC); + searchRequest.source(sourceBuilder); + client.search(searchRequest, listener); + } + + public void getMessages(ActionListener listener) { + memoryManager.getFinalInteractions(conversationId, LAST_N_INTERACTIONS, listener); + } + + @Override + public void clear() { + throw new RuntimeException("clear method is not supported in ConversationIndexMemory"); + } + + @Override + public void remove(String id) { + throw new RuntimeException("remove method is not supported in ConversationIndexMemory"); + } + + public static class Factory implements Memory.Factory { + private Client client; + private MLIndicesHandler mlIndicesHandler; + private String memoryMetaIndexName = ML_MEMORY_META_INDEX; + private String memoryMessageIndexName = ML_MEMORY_MESSAGE_INDEX; + private MLMemoryManager memoryManager; + + public void init(Client client, MLIndicesHandler mlIndicesHandler, MLMemoryManager memoryManager) { + this.client = client; + this.mlIndicesHandler = mlIndicesHandler; + this.memoryManager = memoryManager; + } + + @Override + public void create(Map map, ActionListener listener) { + if (map == null || map.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Invalid input parameter for creating ConversationIndexMemory")); + return; + } + + String memoryId = (String) map.get(MEMORY_ID); + String name = (String) map.get(MEMORY_NAME); + String appType = (String) map.get(APP_TYPE); + create(name, memoryId, appType, listener); + } + + public void create(String name, String memoryId, String appType, ActionListener listener) { + if (Strings.isEmpty(memoryId)) { + memoryManager.createConversation(name, appType, ActionListener.wrap(r -> { + create(r.getId(), listener); + log.debug("Created conversation on memory layer, conversation id: {}", r.getId()); + }, e -> { + log.error("Failed to save interaction", e); + listener.onFailure(e); + })); + } else { + create(memoryId, listener); + } + } + + public void create(String memoryId, ActionListener listener) { + listener + .onResponse( + new ConversationIndexMemory( + client, + mlIndicesHandler, + memoryMetaIndexName, + memoryMessageIndexName, + memoryId, + memoryManager + ) + ); + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java new file mode 100644 index 0000000000..2a084ee9b9 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Builder; +import lombok.Data; + +@Data +public class ConversationIndexMessage extends BaseMessage { + + private String sessionId; + private String question; + private String response; + private Boolean finalAnswer; + private Instant createdTime; + + @Builder(builderMethodName = "conversationIndexMessageBuilder") + public ConversationIndexMessage(String type, String sessionId, String question, String response, boolean finalAnswer) { + super(type, response); + this.sessionId = sessionId; + this.question = question; + this.response = response; + this.finalAnswer = finalAnswer; + this.createdTime = Instant.now(); + } + + @Override + public String toString() { + return "Human:" + question + "\nAI:" + response; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (sessionId != null) { + builder.field("session_id", sessionId); + } + if (question != null) { + builder.field("question", question); + } + if (response != null) { + builder.field("response", response); + } + if (finalAnswer != null) { + builder.field("final_answer", finalAnswer); + } + builder.field("created_time", createdTime); + builder.endObject(); + return builder; + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java new file mode 100644 index 0000000000..5ca7e2d31a --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java @@ -0,0 +1,194 @@ +package org.opensearch.ml.engine.indices; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.META; +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; +import static org.opensearch.ml.common.CommonValue.SCHEMA_VERSION_FIELD; + +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; + +public class MLIndicesHandlerTest { + + @Mock + Client client; + + @Mock + AdminClient adminClient; + + @Mock + IndicesAdminClient indicesAdminClient; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + + @Mock + Metadata metadata; + + @Mock + IndexMetadata indexMetadata; + + @Mock + MappingMetadata mappingMetadata; + + @Mock + private ThreadPool threadPool; + + Settings settings; + ThreadContext threadContext; + MLIndicesHandler indicesHandler; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + doNothing().when(client).execute(any(), any(), any()); + doNothing().when(client).update(any(), any()); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + doNothing().when(indicesAdminClient).create(any(), any()); + doNothing().when(indicesAdminClient).refresh(any(), any()); + doNothing().when(indicesAdminClient).putMapping(any(), any()); + doNothing().when(indicesAdminClient).updateSettings(any(), any()); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(clusterState.getMetadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(true); + when(metadata.indices()).thenReturn(Map.of(ML_AGENT_INDEX, indexMetadata, ML_MEMORY_META_INDEX, indexMetadata)); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + when(mappingMetadata.getSourceAsMap()).thenReturn(Map.of(META, Map.of(SCHEMA_VERSION_FIELD, Integer.valueOf(1)))); + settings = Settings.builder().put("test_key", 10).build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + indicesHandler = new MLIndicesHandler(clusterService, client); + } + + @Test + public void initMemoryMetaIndex() { + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(new AcknowledgedResponse(true)); + return null; + }).when(indicesAdminClient).putMapping(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMemoryMetaIndex(listener); + + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } + + @Test + public void initMemoryMetaIndexNoIndex() { + ActionListener listener = mock(ActionListener.class); + when(metadata.hasIndex(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(new CreateIndexResponse(true, true, ML_MEMORY_META_INDEX)); + return null; + }).when(indicesAdminClient).create(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMemoryMetaIndex(listener); + + verify(indicesAdminClient).create(isA(CreateIndexRequest.class), any()); + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } + + @Test + public void initMemoryMessageIndex() { + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(new AcknowledgedResponse(true)); + return null; + }).when(indicesAdminClient).putMapping(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMemoryMessageIndex(listener); + + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } + + @Test + public void initMemoryMessageIndexNoIndex() { + ActionListener listener = mock(ActionListener.class); + when(metadata.hasIndex(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(new CreateIndexResponse(true, true, ML_MEMORY_MESSAGE_INDEX)); + return null; + }).when(indicesAdminClient).create(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMemoryMessageIndex(listener); + + verify(indicesAdminClient).create(isA(CreateIndexRequest.class), any()); + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } + + @Test + public void initMLAgentIndex() { + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(new AcknowledgedResponse(true)); + return null; + }).when(indicesAdminClient).putMapping(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMLAgentIndex(listener); + + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } + + @Test + public void initMLAgentIndexNoIndex() { + ActionListener listener = mock(ActionListener.class); + when(metadata.hasIndex(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(new CreateIndexResponse(true, true, ML_AGENT_INDEX)); + return null; + }).when(indicesAdminClient).create(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMLAgentIndex(listener); + + verify(indicesAdminClient).create(isA(CreateIndexRequest.class), any()); + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/BaseMessageTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/BaseMessageTest.java new file mode 100644 index 0000000000..b66fd502ed --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/BaseMessageTest.java @@ -0,0 +1,29 @@ +package org.opensearch.ml.engine.memory; + +import java.io.IOException; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +public class BaseMessageTest { + + @Test + public void testToString() { + BaseMessage message = new BaseMessage("test", "test"); + Assert.assertEquals("test: test", message.toString()); + } + + @Test + public void toXContent() throws IOException { + BaseMessage baseMessage = new BaseMessage("test", "test"); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + baseMessage.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = BytesReference.bytes(builder).utf8ToString(); + + Assert.assertEquals("{\"type\":\"test\",\"content\":\"test\"}", content); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java new file mode 100644 index 0000000000..e186521d9c --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java @@ -0,0 +1,248 @@ +package org.opensearch.ml.engine.memory; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; + +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; + +public class ConversationIndexMemoryTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Mock + Client client; + + @Mock + MLIndicesHandler indicesHandler; + + @Mock + MLMemoryManager memoryManager; + + ConversationIndexMemory indexMemory; + ConversationIndexMemory.Factory memoryFactory; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + indexMemory = new ConversationIndexMemory(client, indicesHandler, "test", "test", "test", memoryManager); + doNothing().when(client).index(any(), any()); + doNothing().when(client).search(any(), any()); + doNothing().when(client).get(any(), any()); + doNothing().when(memoryManager).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + doNothing().when(memoryManager).getFinalInteractions(any(), anyInt(), any()); + doNothing().when(memoryManager).createConversation(any(), any(), any()); + doNothing().when(indicesHandler).initMemoryMetaIndex(any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("test failure")); + return null; + }).when(indicesHandler).initMemoryMessageIndex(any()); + memoryFactory = new ConversationIndexMemory.Factory(); + memoryFactory.init(client, indicesHandler, memoryManager); + } + + @Test + public void getType() { + Assert.assertEquals(indexMemory.getType(), ConversationIndexMemory.TYPE); + } + + @Test + public void save() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); + return null; + }).when(indicesHandler).initMemoryMessageIndex(any()); + indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false)); + + verify(indicesHandler).initMemoryMessageIndex(any()); + } + + @Test + public void save4() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException()); + return null; + }).when(indicesHandler).initMemoryMessageIndex(any()); + indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false)); + + verify(indicesHandler).initMemoryMessageIndex(any()); + } + + @Test + public void save1() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(8); + listener.onResponse(new CreateInteractionResponse("interaction_id")); + return null; + }).when(memoryManager).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + indexMemory.save(new ConversationIndexMessage("test", "123", "question", "response", false), "parent_id", 0, "action"); + + verify(memoryManager).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + } + + @Test + public void save6() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(8); + listener.onFailure(new RuntimeException()); + return null; + }).when(memoryManager).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + indexMemory.save(new ConversationIndexMessage("test", "123", "question", "response", false), "parent_id", 0, "action"); + + verify(memoryManager).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + } + + @Test + public void save2() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(Boolean.TRUE); + return null; + }).when(indicesHandler).initMemoryMessageIndex(any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true)); + return null; + }).when(client).index(any(), any()); + ActionListener actionListener = mock(ActionListener.class); + indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); + + verify(actionListener).onResponse(isA(IndexResponse.class)); + } + + @Test + public void save3() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException()); + return null; + }).when(indicesHandler).initMemoryMessageIndex(any()); + ActionListener actionListener = mock(ActionListener.class); + indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); + + verify(actionListener).onFailure(isA(RuntimeException.class)); + } + + @Test + public void save5() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(Boolean.FALSE); + return null; + }).when(indicesHandler).initMemoryMessageIndex(any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true)); + return null; + }).when(client).index(any(), any()); + ActionListener actionListener = mock(ActionListener.class); + indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); + + verify(actionListener).onFailure(isA(RuntimeException.class)); + } + + @Test + public void getMessages() { + ActionListener listener = mock(ActionListener.class); + indexMemory.getMessages("test_id", listener); + } + + @Test + public void getMessages1() { + ActionListener listener = mock(ActionListener.class); + indexMemory.getMessages(listener); + } + + @Test + public void clear() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("clear method is not supported in ConversationIndexMemory"); + indexMemory.clear(); + } + + @Test + public void remove() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("remove method is not supported in ConversationIndexMemory"); + indexMemory.remove("test_id"); + } + + @Test + public void factory_create_emptyMap() { + ActionListener listener = mock(ActionListener.class); + memoryFactory.create(Map.of(), listener); + + verify(listener).onFailure(isA(IllegalArgumentException.class)); + } + + @Test + public void factory_create() { + ActionListener listener = mock(ActionListener.class); + memoryFactory.create(Map.of(MEMORY_ID, "123", MEMORY_NAME, "name", APP_TYPE, "app"), listener); + + verify(listener).onResponse(isA(ConversationIndexMemory.class)); + } + + @Test + public void factory_create_only_memory_id() { + ActionListener listener = mock(ActionListener.class); + memoryFactory.create(Map.of(MEMORY_ID, "123"), listener); + + verify(listener).onResponse(isA(ConversationIndexMemory.class)); + } + + @Test + public void factory_create_empty_memory_id() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(new CreateConversationResponse("interaction_id")); + return null; + }).when(memoryManager).createConversation(any(), any(), any()); + ActionListener listener = mock(ActionListener.class); + memoryFactory.create(Map.of(MEMORY_NAME, "name", APP_TYPE, "app"), listener); + + verify(listener).onResponse(isA(ConversationIndexMemory.class)); + verify(memoryManager).createConversation(any(), any(), any()); + } + + @Test + public void factory_create_empty_memory_id_failure() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException()); + return null; + }).when(memoryManager).createConversation(any(), any(), any()); + ActionListener listener = mock(ActionListener.class); + memoryFactory.create(Map.of(MEMORY_NAME, "name", APP_TYPE, "app"), listener); + + verify(listener).onFailure(isA(RuntimeException.class)); + verify(memoryManager).createConversation(any(), any(), any()); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMessageTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMessageTest.java new file mode 100644 index 0000000000..9e91695a5b --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMessageTest.java @@ -0,0 +1,48 @@ +package org.opensearch.ml.engine.memory; + +import java.io.IOException; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +public class ConversationIndexMessageTest { + + ConversationIndexMessage message; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + message = ConversationIndexMessage + .conversationIndexMessageBuilder() + .type("test") + .sessionId("123") + .question("question") + .response("response") + .finalAnswer(false) + .build(); + } + + @Test + public void testToString() { + Assert.assertEquals("Human:question\nAI:response", message.toString()); + } + + @Test + public void toXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + message.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = BytesReference.bytes(builder).utf8ToString(); + + Assert.assertTrue(content.contains("\"session_id\":\"123\"")); + Assert.assertTrue(content.contains("\"question\":\"question\"")); + Assert.assertTrue(content.contains("\"response\":\"response\"")); + Assert.assertTrue(content.contains("\"final_answer\":false")); + Assert.assertTrue(content.contains("\"created_time\":")); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTest.java new file mode 100644 index 0000000000..a21a3ed60d --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTest.java @@ -0,0 +1,124 @@ +package org.opensearch.ml.engine.memory; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.index.ConversationMetaIndex; +import org.opensearch.threadpool.ThreadPool; + +public class MLMemoryManagerTest { + + @Mock + Client client; + + @Mock + AdminClient adminClient; + + @Mock + IndicesAdminClient indicesAdminClient; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + + @Mock + Metadata metadata; + + @Mock + ConversationMetaIndex conversationMetaIndex; + + @Mock + private ThreadPool threadPool; + + MLMemoryManager memoryManager; + Settings settings; + ThreadContext threadContext; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + memoryManager = new MLMemoryManager(client); + doNothing().when(client).execute(any(), any(), any()); + doNothing().when(client).update(any(), any()); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + doNothing().when(indicesAdminClient).refresh(any(), any()); + doNothing().when(conversationMetaIndex).checkAccess(any(), any()); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(true); + settings = Settings.builder().put("test_key", 10).build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + @Test + public void createConversation() { + ActionListener actionListener = mock(ActionListener.class); + memoryManager.createConversation("test", "test", actionListener); + } + + @Test + public void createInteraction() { + ActionListener actionListener = mock(ActionListener.class); + memoryManager.createInteraction("test", "test", "test", "test", "test", Map.of("feedback", "1"), "test", 0, actionListener); + } + + @Test + public void createInteractionNullAdditionalInfo() { + ActionListener actionListener = mock(ActionListener.class); + memoryManager.createInteraction("test", "test", "test", "test", "test", null, "test", 0, actionListener); + } + + @Test + public void getFinalInteractions() { + ActionListener> actionListener = mock(ActionListener.class); + memoryManager.getFinalInteractions("test", 1, actionListener); + } + + @Test + public void getTracesIndex() { + ActionListener> actionListener = mock(ActionListener.class); + memoryManager.getTraces("test", actionListener); + } + + @Test + public void getTracesNoIndex() { + ActionListener> actionListener = mock(ActionListener.class); + when(metadata.hasIndex(anyString())).thenReturn(false); + memoryManager.getTraces("test", actionListener); + } + + @Test + public void updateInteraction() { + ActionListener actionListener = mock(ActionListener.class); + memoryManager.updateInteraction("test", Map.of("feedback", "1"), actionListener); + } +} diff --git a/plugin/build.gradle b/plugin/build.gradle index 042ac13423..a6fbbf1851 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -47,7 +47,7 @@ opensearchplugin { dependencies { implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow') - implementation project(':opensearch-ml-common') + implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') implementation project(':opensearch-ml-algorithms') implementation project(':opensearch-ml-search-processors') implementation project(':opensearch-ml-memory') diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index e40bacc207..4cadcc936a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -37,8 +37,8 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.exceptions.MetaDataException; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.tasks.Task; diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java index 94d4b5a8a7..4e29db680d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java @@ -17,8 +17,8 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 8aca5cb140..0261d55b24 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -50,9 +50,9 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLStats; diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java index 1227703a21..683536e21e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java @@ -34,8 +34,8 @@ import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkInput; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkResponse; import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.RestActionUtils; import lombok.extern.log4j.Log4j2; diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java index c3012fdade..327ae2ddae 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java @@ -15,7 +15,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 3a5ea83347..12e37b7b4d 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -42,7 +42,7 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index 83523729e4..2187a4577e 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -34,8 +34,8 @@ import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 2402374a99..20286fd3c5 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -111,8 +111,8 @@ import org.opensearch.ml.engine.MLExecutable; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.Predictable; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.engine.utils.FileUtils; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.profile.MLModelProfile; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index e986d7e3c1..31197cd6be 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -128,10 +128,10 @@ import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; -import org.opensearch.ml.indices.MLInputDatasetHandler; import org.opensearch.ml.memory.ConversationalMemoryHandler; import org.opensearch.ml.memory.action.conversation.CreateConversationAction; import org.opensearch.ml.memory.action.conversation.CreateConversationTransportAction; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index 1498b39d07..3e82e7a20e 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -21,7 +21,7 @@ import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.engine.MLEngine; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 348618773a..92e05a5ba9 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -48,7 +48,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.Predictable; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java index f1feb4ec32..9e9dea5d22 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java @@ -42,7 +42,7 @@ import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableMap; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java index 9461c3adaf..fe78e88d1e 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java @@ -29,7 +29,7 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; import org.opensearch.ml.engine.MLEngine; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java index 711a94171f..88366b17f2 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java @@ -37,8 +37,8 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.engine.MLEngine; -import org.opensearch.ml.indices.MLIndicesHandler; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index 9fcc89d701..e16400bc56 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -41,8 +41,8 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelManager; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java index 269ac30d95..f54f034b64 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java @@ -29,8 +29,8 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index 0222b4efe1..b40a278289 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -54,9 +54,9 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLNodeLevelStat; diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java index 6fba3efe59..292183f8ed 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java @@ -39,8 +39,8 @@ import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkInput; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index 2242673a5e..b9f169e3d2 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -70,7 +70,7 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.utils.TestHelper; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; diff --git a/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java b/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java deleted file mode 100644 index 9acb84633a..0000000000 --- a/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.indices; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; -import static org.opensearch.ml.common.CommonValue.META; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX_SCHEMA_VERSION; -import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX_MAPPING; -import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX_SCHEMA_VERSION; -import static org.opensearch.ml.common.CommonValue.SCHEMA_VERSION_FIELD; - -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.ExecutionException; - -import org.junit.Before; -import org.opensearch.action.admin.indices.create.CreateIndexRequest; -import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.client.AdminClient; -import org.opensearch.client.Client; -import org.opensearch.client.IndicesAdminClient; -import org.opensearch.cluster.metadata.IndexMetadata; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.core.action.ActionListener; -import org.opensearch.test.OpenSearchIntegTestCase; - -@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) -public class MLIndicesHandlerTests extends OpenSearchIntegTestCase { - ClusterService clusterService; - Client client; - MLIndicesHandler mlIndicesHandler; - - String OLD_ML_MODEL_INDEX_MAPPING_V0 = "{\n" - + " \"properties\": {\n" - + " \"task_id\": { \"type\": \"keyword\" },\n" - + " \"algorithm\": {\"type\": \"keyword\"},\n" - + " \"model_name\" : { \"type\": \"keyword\"},\n" - + " \"model_version\" : { \"type\": \"keyword\"},\n" - + " \"model_content\" : { \"type\": \"binary\"}\n" - + " }\n" - + "}"; - - String OLD_ML_TASK_INDEX_MAPPING_V0 = "{\n" - + " \"properties\": {\n" - + " \"model_id\": {\"type\": \"keyword\"},\n" - + " \"task_type\": {\"type\": \"keyword\"},\n" - + " \"function_name\": {\"type\": \"keyword\"},\n" - + " \"state\": {\"type\": \"keyword\"},\n" - + " \"input_type\": {\"type\": \"keyword\"},\n" - + " \"progress\": {\"type\": \"float\"},\n" - + " \"output_index\": {\"type\": \"keyword\"},\n" - + " \"worker_node\": {\"type\": \"keyword\"},\n" - + " \"create_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"last_update_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"error\": {\"type\": \"text\"},\n" - + " \"user\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + " }\n" - + " }\n" - + " }\n" - + "}";; - - @Before - public void setup() { - clusterService = clusterService(); - client = client(); - mlIndicesHandler = new MLIndicesHandler(clusterService, client); - } - - public void testInitMLTaskIndex() { - ActionListener listener = ActionListener.wrap(r -> { assertTrue(r); }, e -> { throw new RuntimeException(e); }); - mlIndicesHandler.initMLTaskIndex(listener); - } - - public void testInitMLTaskIndexWithExistingIndex() throws ExecutionException, InterruptedException { - CreateIndexRequest request = new CreateIndexRequest(ML_TASK_INDEX).mapping(ML_TASK_INDEX_MAPPING); - client.admin().indices().create(request).get(); - testInitMLTaskIndex(); - } - - public void testInitMLModelIndexIfAbsentWithExistingIndex() throws ExecutionException, InterruptedException, IOException { - testInitMLIndexIfAbsentWithExistingIndex(ML_MODEL_INDEX, OLD_ML_MODEL_INDEX_MAPPING_V0, ML_MODEL_INDEX_SCHEMA_VERSION); - } - - public void testInitMLTaskIndexIfAbsentWithExistingIndex() throws ExecutionException, InterruptedException, IOException { - testInitMLIndexIfAbsentWithExistingIndex(ML_TASK_INDEX, OLD_ML_TASK_INDEX_MAPPING_V0, ML_TASK_INDEX_SCHEMA_VERSION); - } - - private void testInitMLIndexIfAbsentWithExistingIndex(String indexName, String oldIndexMapping, int schemaVersion) - throws ExecutionException, - InterruptedException, - IOException { - mlIndicesHandler.shouldUpdateIndex(indexName, 1, ActionListener.wrap(shouldUpdate -> { assertFalse(shouldUpdate); }, e -> { - throw new RuntimeException(e); - })); - CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(oldIndexMapping); - client.admin().indices().create(request).get(); - mlIndicesHandler.shouldUpdateIndex(indexName, 1, ActionListener.wrap(shouldUpdate -> { assertTrue(shouldUpdate); }, e -> { - throw new RuntimeException(e); - })); - assertNull(getIndexSchemaVersion(indexName)); - ActionListener listener = ActionListener.wrap(r -> { - assertTrue(r); - Integer indexSchemaVersion = getIndexSchemaVersion(indexName); - if (indexSchemaVersion != null) { - assertEquals(schemaVersion, indexSchemaVersion.intValue()); - mlIndicesHandler.shouldUpdateIndex(indexName, 1, ActionListener.wrap(shouldUpdate -> { assertFalse(shouldUpdate); }, e -> { - throw new RuntimeException(e); - })); - } - }, e -> { throw new RuntimeException(e); }); - mlIndicesHandler.initModelIndexIfAbsent(listener); - } - - public void testInitMLModelIndexIfAbsentWithNonExistingIndex() { - ActionListener listener = ActionListener.wrap(r -> { assertTrue(r); }, e -> { throw new RuntimeException(e); }); - mlIndicesHandler.initModelIndexIfAbsent(listener); - } - - public void testInitMLModelIndexIfAbsentWithNonExistingIndex_Exception() { - Client mockClient = mock(Client.class); - Object[] objects = setUpMockClient(mockClient); - IndicesAdminClient adminClient = (IndicesAdminClient) objects[0]; - MLIndicesHandler mlIndicesHandler = (MLIndicesHandler) objects[1]; - String errorMessage = "test exception"; - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new RuntimeException(errorMessage)); - return null; - }).when(adminClient).create(any(), any()); - ActionListener listener = ActionListener.wrap(r -> { throw new RuntimeException("unexpected result"); }, e -> { - assertEquals(errorMessage, e.getMessage()); - }); - mlIndicesHandler.initModelIndexIfAbsent(listener); - - when(mockClient.threadPool()).thenThrow(new RuntimeException(errorMessage)); - mlIndicesHandler.initModelIndexIfAbsent(listener); - } - - public void testInitMLModelIndexIfAbsentWithNonExistingIndex_FalseAcknowledge() { - Client mockClient = mock(Client.class); - Object[] objects = setUpMockClient(mockClient); - IndicesAdminClient adminClient = (IndicesAdminClient) objects[0]; - MLIndicesHandler mlIndicesHandler = (MLIndicesHandler) objects[1]; - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - CreateIndexResponse response = new CreateIndexResponse(false, false, ML_MODEL_INDEX); - actionListener.onResponse(response); - return null; - }).when(adminClient).create(any(), any()); - ActionListener listener = ActionListener.wrap(r -> { assertFalse(r); }, e -> { throw new RuntimeException(e); }); - mlIndicesHandler.initModelIndexIfAbsent(listener); - } - - private Object[] setUpMockClient(Client mockClient) { - AdminClient admin = spy(client.admin()); - when(mockClient.admin()).thenReturn(admin); - IndicesAdminClient adminClient = spy(client.admin().indices()); - - MLIndicesHandler mlIndicesHandler = new MLIndicesHandler(clusterService, mockClient); - when(admin.indices()).thenReturn(adminClient); - - when(mockClient.threadPool()).thenReturn(client.threadPool()); - - return new Object[] { adminClient, mlIndicesHandler }; - } - - private Integer getIndexSchemaVersion(String indexName) { - IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName); - if (indexMetaData == null) { - return null; - } - Integer oldVersion = null; - Map indexMapping = indexMetaData.mapping().getSourceAsMap(); - Object meta = indexMapping.get(META); - if (meta != null && meta instanceof Map) { - Map metaMapping = (Map) meta; - Object schemaVersion = metaMapping.get(SCHEMA_VERSION_FIELD); - if (schemaVersion instanceof Integer) { - oldVersion = (Integer) schemaVersion; - } - } - return oldVersion; - } -} diff --git a/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java b/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java deleted file mode 100644 index 5ec2ab686c..0000000000 --- a/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.indices; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.apache.lucene.search.TotalHits; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.ExpectedException; -import org.mockito.ArgumentCaptor; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.client.Client; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesArray; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.ml.common.dataframe.DataFrame; -import org.opensearch.ml.common.dataframe.DataFrameBuilder; -import org.opensearch.ml.common.dataset.DataFrameInputDataset; -import org.opensearch.ml.common.dataset.MLInputDataset; -import org.opensearch.ml.common.dataset.SearchQueryInputDataset; -import org.opensearch.search.SearchHit; -import org.opensearch.search.SearchHits; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.test.OpenSearchTestCase; - -public class MLInputDatasetHandlerTests extends OpenSearchTestCase { - Client client; - MLInputDatasetHandler mlInputDatasetHandler; - ActionListener listener; - DataFrame dataFrame; - SearchResponse searchResponse; - - @Rule - public ExpectedException expectedEx = ExpectedException.none(); - - @Before - public void setup() { - Map source = new HashMap<>(); - source.put("taskId", "111"); - List> mapList = new ArrayList<>(); - mapList.add(source); - dataFrame = DataFrameBuilder.load(mapList); - client = mock(Client.class); - mlInputDatasetHandler = new MLInputDatasetHandler(client); - listener = spy(new ActionListener() { - @Override - public void onResponse(MLInputDataset inputDataset) {} - - @Override - public void onFailure(Exception e) {} - }); - - } - - @SuppressWarnings("unchecked") - public void testSearchQueryInputDatasetWithHits() { - searchResponse = mock(SearchResponse.class); - BytesReference bytesArray = new BytesArray("{\"taskId\":\"111\"}"); - SearchHit hit = new SearchHit(1); - hit.sourceRef(bytesArray); - SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f); - when(searchResponse.getHits()).thenReturn(hits); - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[1]; - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); - - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset - .builder() - .indices(Collections.singletonList("index1")) - .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) - .build(); - mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener); - ArgumentCaptor captor = ArgumentCaptor.forClass(MLInputDataset.class); - verify(listener, times(1)).onResponse(captor.capture()); - Assert.assertEquals(captor.getAllValues().size(), 1); - } - - @SuppressWarnings("unchecked") - public void testSearchQueryInputDatasetWithoutHits() { - searchResponse = mock(SearchResponse.class); - SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f); - when(searchResponse.getHits()).thenReturn(hits); - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[1]; - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); - - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset - .builder() - .indices(Collections.singletonList("index1")) - .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) - .build(); - mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener); - verify(listener, times(1)).onFailure(any()); - } - - public void testSearchQueryInputDatasetWithNullHits() { - searchResponse = mock(SearchResponse.class); - when(searchResponse.getHits()).thenReturn(null); - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[1]; - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); - - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset - .builder() - .indices(Collections.singletonList("index1")) - .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) - .build(); - mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener); - verify(listener, times(1)).onFailure(any()); - } - - public void testSearchQueryInputDatasetWithNullResponse() { - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[1]; - listener.onResponse(null); - return null; - }).when(client).search(any(), any()); - - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset - .builder() - .indices(Collections.singletonList("index1")) - .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) - .build(); - mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener); - verify(listener, times(1)).onFailure(any()); - } - - public void testSearchQueryInputDatasetWrongType() { - expectedEx.expect(IllegalArgumentException.class); - expectedEx.expectMessage("Input dataset is not SEARCH_QUERY type."); - DataFrame testDataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() { - { - put("key1", 2.0D); - } - })); - DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder().dataFrame(testDataFrame).build(); - mlInputDatasetHandler.parseSearchQueryInput(dataFrameInputDataset, listener); - } - -} diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index ccedef9bc1..ce01b44026 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -42,8 +42,8 @@ import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.TestHelper; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 598d2db5a7..2f8ef74f66 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -104,7 +104,7 @@ import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 11e9bc3441..9011746797 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -41,7 +41,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 0d0c594458..13de526978 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -55,7 +55,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; @@ -214,7 +214,6 @@ public void testExecuteTask_OnLocalNode() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); - // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); verify(client).get(any(), any()); verify(mlTaskManager).remove(anyString()); @@ -237,7 +236,6 @@ public void testExecuteTask_OnLocalNode_QueryInput() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); - // verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); verify(client).get(any(), any()); verify(mlTaskManager).remove(anyString()); @@ -248,7 +246,6 @@ public void testExecuteTask_OnLocalNode_QueryInput_Failure() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); - // verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager, never()).add(any(MLTask.class)); verify(client, never()).get(any(), any()); } @@ -277,7 +274,6 @@ public void testExecuteTask_OnLocalNode_GetModelFail() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); - // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); verify(client).get(any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -291,7 +287,6 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); - // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); verify(client, never()).get(any(), any()); verify(mlTaskManager).remove(anyString()); @@ -305,7 +300,6 @@ public void testExecuteTask_OnLocalNode_NullGetResponse() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); - // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); verify(client).get(any(), any()); verify(mlTaskManager).remove(anyString()); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java index 69b5f613b8..ab5f82734e 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java @@ -39,7 +39,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index a40c5c87cf..ff7c963e8a 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -49,7 +49,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index ae397067bc..943bd5740d 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -52,8 +52,8 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLIndicesHandler; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; diff --git a/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java index 497a7c4229..78141b1781 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java @@ -28,7 +28,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.threadpool.ThreadPool; public class MockHelper { diff --git a/search-processors/build.gradle b/search-processors/build.gradle index cd5a28656d..8055cef61f 100644 --- a/search-processors/build.gradle +++ b/search-processors/build.gradle @@ -28,12 +28,10 @@ repositories { } dependencies { - + implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' implementation 'org.apache.commons:commons-lang3:3.12.0' - //implementation project(':opensearch-ml-client') - implementation project(':opensearch-ml-common') implementation project(':opensearch-ml-memory') implementation group: 'org.opensearch', name: 'common-utils', version: "${common_utils_version}" // https://mvnrepository.com/artifact/org.apache.httpcomponents.core5/httpcore5 From bab9439e17c98429f4bd9f1ac853c1db20ef3219 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Fri, 15 Dec 2023 21:39:00 -0700 Subject: [PATCH 6/7] adding mlmodeltool and agent tool with tests (#1768) * adding mlmodeltool and agent tool with tests Signed-off-by: Dhrubo Saha * updating tests Signed-off-by: Dhrubo Saha * removed connector Signed-off-by: Dhrubo Saha --------- Signed-off-by: Dhrubo Saha --- .../opensearch/ml/common/FunctionName.java | 3 +- .../input/execute/agent/AgentMLInput.java | 75 +++++++++ .../prediction/MLPredictionTaskRequest.java | 4 + .../execute/agent/AgentMLInputTests.java | 112 ++++++++++++++ .../MLPredictionTaskRequestTest.java | 13 +- .../opensearch/ml/engine/tools/AgentTool.java | 128 ++++++++++++++++ .../ml/engine/tools/MLModelTool.java | 143 ++++++++++++++++++ .../ml/engine/tools/AgentToolTests.java | 127 ++++++++++++++++ .../ml/engine/tools/MLModelToolTests.java | 127 ++++++++++++++++ 9 files changed, 730 insertions(+), 2 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java create mode 100644 common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java index 72810459a4..6eff55156d 100644 --- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java +++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java @@ -24,7 +24,8 @@ public enum FunctionName { SPARSE_ENCODING, SPARSE_TOKENIZE, METRICS_CORRELATION, - REMOTE; + REMOTE, + AGENT; public static FunctionName from(String value) { try { diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java new file mode 100644 index 0000000000..3aa3ac382b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.execute.agent; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.utils.StringUtils; + +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + + +@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.AGENT}) +public class AgentMLInput extends MLInput { + public static final String AGENT_ID_FIELD = "agent_id"; + public static final String PARAMETERS_FIELD = "parameters"; + + @Getter @Setter + private String agentId; + + @Builder(builderMethodName = "AgentMLInputBuilder") + public AgentMLInput(String agentId, FunctionName functionName, MLInputDataset inputDataset) { + this.agentId = agentId; + this.algorithm = functionName; + this.inputDataset = inputDataset; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(agentId); + } + + public AgentMLInput(StreamInput in) throws IOException { + super(in); + this.agentId = in.readString(); + } + + public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOException { + super(); + this.algorithm = functionName; + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case AGENT_ID_FIELD: + agentId = parser.text(); + break; + case PARAMETERS_FIELD: + Map parameters = StringUtils.getParameterMap(parser.map()); + inputDataset = new RemoteInferenceInputDataSet(parameters); + break; + default: + parser.skipChildren(); + break; + } + } + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java index 963892215f..8060b1c6af 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java @@ -47,6 +47,10 @@ public MLPredictionTaskRequest(String modelId, MLInput mlInput, boolean dispatch this.user = user; } + public MLPredictionTaskRequest(String modelId, MLInput mlInput) { + this(modelId, mlInput, true, null); + } + public MLPredictionTaskRequest(String modelId, MLInput mlInput, User user) { this(modelId, mlInput, true, user); } diff --git a/common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java b/common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java new file mode 100644 index 0000000000..36235adffe --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.execute.agent; + +import org.junit.Test; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class AgentMLInputTests { + + @Test + public void testConstructorWithAgentIdFunctionNameAndDataset() { + // Arrange + String agentId = "testAgentId"; + FunctionName functionName = FunctionName.AGENT; // Assuming FunctionName is an enum or similar + MLInputDataset dataset = mock(MLInputDataset.class); // Mock the MLInputDataset + + // Act + AgentMLInput input = new AgentMLInput(agentId, functionName, dataset); + + // Assert + assertEquals(agentId, input.getAgentId()); + assertEquals(functionName, input.getAlgorithm()); + assertEquals(dataset, input.getInputDataset()); + } + + @Test + public void testWriteTo() throws IOException { + // Arrange + String agentId = "testAgentId"; + AgentMLInput input = new AgentMLInput(agentId, FunctionName.AGENT, null); + StreamOutput out = mock(StreamOutput.class); + + // Act + input.writeTo(out); + + // Assert + verify(out).writeString(agentId); + } + + @Test + public void testConstructorWithStreamInput() throws IOException { + // Arrange + String agentId = "testAgentId"; + StreamInput in = mock(StreamInput.class); + when(in.readString()).thenReturn(agentId); + + // Act + AgentMLInput input = new AgentMLInput(in); + + // Assert + assertEquals(agentId, input.getAgentId()); + } + + @Test + public void testConstructorWithXContentParser() throws IOException { + // Arrange + XContentParser parser = mock(XContentParser.class); + + // Simulate parser behavior for START_OBJECT token + when(parser.currentToken()).thenReturn(XContentParser.Token.START_OBJECT); + when(parser.nextToken()).thenReturn(XContentParser.Token.FIELD_NAME) + .thenReturn(XContentParser.Token.VALUE_STRING) + .thenReturn(XContentParser.Token.FIELD_NAME) // For PARAMETERS_FIELD + .thenReturn(XContentParser.Token.START_OBJECT) // Start of PARAMETERS_FIELD map + .thenReturn(XContentParser.Token.FIELD_NAME) // Key in PARAMETERS_FIELD map + .thenReturn(XContentParser.Token.VALUE_STRING) // Value in PARAMETERS_FIELD map + .thenReturn(XContentParser.Token.END_OBJECT) // End of PARAMETERS_FIELD map + .thenReturn(XContentParser.Token.END_OBJECT); // End of the main object + + // Simulate parser behavior for agent_id + when(parser.currentName()).thenReturn("agent_id") + .thenReturn("parameters") + .thenReturn("paramKey"); + when(parser.text()).thenReturn("testAgentId") + .thenReturn("paramValue"); + + // Simulate parser behavior for parameters + Map paramMap = new HashMap<>(); + paramMap.put("paramKey", "paramValue"); + when(parser.map()).thenReturn(paramMap); + + // Act + AgentMLInput input = new AgentMLInput(parser, FunctionName.AGENT); + + // Assert + assertEquals("testAgentId", input.getAgentId()); + assertNotNull(input.getInputDataset()); + assertTrue(input.getInputDataset() instanceof RemoteInferenceInputDataSet); + // Additional assertions for RemoteInferenceInputDataSet + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) input.getInputDataset(); + assertEquals("paramValue", dataset.getParameters().get("paramKey")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java index ce96aa56c1..b9cbe7d700 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java @@ -16,6 +16,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.dataframe.ColumnType; @@ -53,9 +54,11 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { + User user = User.parse("admin|role-1|all_access"); MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() .mlInput(mlInput) + .user(user) .build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); @@ -73,13 +76,18 @@ public void writeTo_Success() throws IOException { assertEquals(1, dataFrame.getRow(0).size()); assertEquals(2.00, dataFrame.getRow(0).getValue(0).getValue()); + User userExpect = request.getUser(); + assertEquals(user.getName(), userExpect.getName()); + assertNull(request.getModelId()); } @Test public void validate_Success() { + User user = User.parse("admin|role-1|all_access"); MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() .mlInput(mlInput) + .user(user) .build(); assertNull(request.validate()); @@ -133,8 +141,10 @@ public void fromActionRequest_Success_WithNonMLPredictionTaskRequest_SearchQuery } private void fromActionRequest_Success_WithNonMLPredictionTaskRequest(MLInput mlInput) { + User user = User.parse("admin|role-1|all_access"); MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() .mlInput(mlInput) + .user(user) .build(); ActionRequest actionRequest = new ActionRequest() { @Override @@ -151,6 +161,7 @@ public void writeTo(StreamOutput out) throws IOException { assertNotSame(result, request); assertEquals(request.getMlInput().getAlgorithm(), result.getMlInput().getAlgorithm()); assertEquals(request.getMlInput().getInputDataset().getInputDataType(), result.getMlInput().getInputDataset().getInputDataType()); + assertEquals(request.getUser().getName(), request.getUser().getName()); } @Test(expected = UncheckedIOException.class) @@ -168,4 +179,4 @@ public void writeTo(StreamOutput out) throws IOException { }; MLPredictionTaskRequest.fromActionRequest(actionRequest); } -} \ No newline at end of file +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java new file mode 100644 index 0000000000..a4a3982505 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import java.util.Map; + +import org.opensearch.action.ActionRequest; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; +import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * This tool supports running any Agent. + */ +@Log4j2 +@ToolAnnotation(AgentTool.TYPE) +public class AgentTool implements Tool { + public static final String TYPE = "AgentTool"; + private final Client client; + + private String agentId; + @Setter + @Getter + private String name = TYPE; + + @VisibleForTesting + static String DEFAULT_DESCRIPTION = "Use this tool to run any agent."; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + + public AgentTool(Client client, String agentId) { + this.client = client; + this.agentId = agentId; + } + + @Override + public void run(Map parameters, ActionListener listener) { + AgentMLInput agentMLInput = AgentMLInput + .AgentMLInputBuilder() + .agentId(agentId) + .functionName(FunctionName.AGENT) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build()) + .build(); + ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false); + client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> { + ModelTensorOutput output = (ModelTensorOutput) r.getOutput(); + listener.onResponse((T) output); + }, e -> { + log.error("Failed to run agent " + agentId, e); + listener.onFailure(e); + })); + + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getVersion() { + return null; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public void setName(String s) { + this.name = s; + } + + @Override + public boolean validate(Map parameters) { + return true; + } + + public static class Factory implements Tool.Factory { + private Client client; + + private static Factory INSTANCE; + + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (AgentTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client) { + this.client = client; + } + + @Override + public AgentTool create(Map map) { + return new AgentTool(client, (String) map.get("agent_id")); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java new file mode 100644 index 0000000000..4b941e6333 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import java.util.List; +import java.util.Map; + +import org.opensearch.action.ActionRequest; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * This tool supports running any ml-commons model. + */ +@Log4j2 +@ToolAnnotation(MLModelTool.TYPE) +public class MLModelTool implements Tool { + public static final String TYPE = "MLModelTool"; + + @Setter + @Getter + private String name = TYPE; + @VisibleForTesting + static String DEFAULT_DESCRIPTION = "Use this tool to run any model."; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private Client client; + @Getter + private String modelId; + @Setter + private Parser inputParser; + @Setter + @Getter + private Parser outputParser; + + public MLModelTool(Client client, String modelId) { + this.client = client; + this.modelId = modelId; + + outputParser = o -> { + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + }; + } + + @Override + public void run(Map parameters, ActionListener listener) { + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + ActionRequest request = new MLPredictionTaskRequest( + modelId, + MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build() + ); + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput(); + modelTensorOutput.getMlModelOutputs(); + listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs())); + }, e -> { + log.error("Failed to run model " + modelId, e); + listener.onFailure(e); + })); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getVersion() { + return null; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public void setName(String s) { + this.name = s; + } + + @Override + public boolean validate(Map parameters) { + if (parameters == null || parameters.size() == 0) { + return false; + } + return true; + } + + public static class Factory implements Tool.Factory { + private Client client; + + private static Factory INSTANCE; + + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (MLModelTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client) { + this.client = client; + } + + @Override + public MLModelTool create(Map map) { + return new MLModelTool(client, (String) map.get("model_id")); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java new file mode 100644 index 0000000000..431e609bba --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.engine.tools.AgentTool.DEFAULT_DESCRIPTION; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; + +public class AgentToolTests { + + @Mock + private Client client; + private Map indicesParams; + private Map otherParams; + private Map emptyParams; + @Mock + private Parser mockOutputParser; + + @Mock + private MLExecuteTaskResponse mockResponse; + + @Mock + private ActionListener listener; + + private AgentTool agentTool; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + AgentTool.Factory.getInstance().init(client); + + indicesParams = Map.of("index", "[\"foo\"]"); + otherParams = Map.of("other", "[\"bar\"]"); + emptyParams = Collections.emptyMap(); + } + + @Test + public void testAgenttestRunMethod() { + Map parameters = new HashMap<>(); + parameters.put("testKey", "testValue"); + AgentMLInput agentMLInput = AgentMLInput + .AgentMLInputBuilder() + .agentId("agentId") + .functionName(FunctionName.AGENT) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build()) + .build(); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + Tool tool = AgentTool.Factory.getInstance().create(Map.of("agent_id", "modelId")); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLExecuteTaskResponse.builder().functionName(FunctionName.AGENT).output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + + tool.run(parameters, listener); + + // Verify interactions + verify(client).execute(any(), any(), any()); + verify(listener).onResponse(mlModelTensorOutput); + } + + @Test + public void testRunWithError() { + Map parameters = new HashMap<>(); + parameters.put("testKey", "testValue"); + + // Mocking the client.execute to simulate an error + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Test Exception")); + return null; + }).when(client).execute(any(), any(), any()); + + // Running the test + Tool tool = AgentTool.Factory.getInstance().create(Map.of("agent_id", "modelId")); + tool.run(parameters, listener); + + // Verifying that onFailure was called + verify(listener).onFailure(any(RuntimeException.class)); + } + + @Test + public void testTool() { + Tool tool = AgentTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(AgentTool.TYPE, tool.getName()); + assertEquals(AgentTool.TYPE, tool.getType()); + assertNull(tool.getVersion()); + assertTrue(tool.validate(indicesParams)); + assertTrue(tool.validate(otherParams)); + assertTrue(tool.validate(emptyParams)); + assertEquals(DEFAULT_DESCRIPTION, tool.getDescription()); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java new file mode 100644 index 0000000000..e4bcb9db5d --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.engine.tools.MLModelTool.DEFAULT_DESCRIPTION; + +import java.util.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; + +public class MLModelToolTests { + + @Mock + private Client client; + private Map indicesParams; + private Map otherParams; + private Map emptyParams; + @Mock + private Parser mockOutputParser; + + @Mock + private ActionListener listener; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + MLModelTool.Factory.getInstance().init(client); + + indicesParams = Map.of("index", "[\"foo\"]"); + otherParams = Map.of("other", "[\"bar\"]"); + emptyParams = Collections.emptyMap(); + } + + @Test + public void testMLModelsWithOutputParser() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + + ActionListener actionListener = invocation.getArgument(2); + + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + Tool tool = MLModelTool.Factory.getInstance().create(Map.of("model_id", "modelId")); + tool.setOutputParser(mockOutputParser); + tool.run(otherParams, listener); + + verify(client).execute(any(), any(), any()); + verify(mockOutputParser).parse(any()); + ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(ModelTensorOutput.class); + verify(listener).onResponse(dataFrameArgumentCaptor.capture()); + } + + @Test + public void testOutputParserLambda() { + // Create a mock ModelTensors object + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "testResponse", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + // Create the lambda expression for outputParser + Parser outputParser = o -> { + List outputs = (List) o; + return outputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + }; + + // Invoke the lambda with the mock data + Object result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs()); + + // Assert that the result matches the expected response + assertEquals("testResponse", result); + } + + @Test + public void testRunWithError() { + // Mocking the client.execute to simulate an error + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Test Exception")); + return null; + }).when(client).execute(any(), any(), any()); + + // Running the test + Tool tool = MLModelTool.Factory.getInstance().create(Map.of("model_id", "modelId")); + tool.setOutputParser(mockOutputParser); + tool.run(otherParams, listener); + + // Verifying that onFailure was called + verify(listener).onFailure(any(RuntimeException.class)); + } + + @Test + public void testTool() { + Tool tool = MLModelTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(MLModelTool.TYPE, tool.getName()); + assertEquals(MLModelTool.TYPE, tool.getType()); + assertNull(tool.getVersion()); + assertTrue(tool.validate(indicesParams)); + assertTrue(tool.validate(otherParams)); + assertFalse(tool.validate(emptyParams)); + assertEquals(DEFAULT_DESCRIPTION, tool.getDescription()); + } +} From 4d8d32e73845ead7574d9bac3ce77651c2158404 Mon Sep 17 00:00:00 2001 From: Mingshi Liu <113382730+mingshl@users.noreply.github.com> Date: Sat, 16 Dec 2023 11:27:00 -0800 Subject: [PATCH 7/7] Get and delete agent APIs (#1752) * get and delete agent APIs (#1703) Signed-off-by: Bhavana Ramaram Signed-off-by: Mingshi Liu * Add unit tests for Get and Delete APIs Signed-off-by: Mingshi Liu * Add header and increase code coverage Signed-off-by: Mingshi Liu * change IndexNotFoundException error message Signed-off-by: Mingshi Liu --------- Signed-off-by: Bhavana Ramaram Signed-off-by: Mingshi Liu Signed-off-by: Mingshi Liu <113382730+mingshl@users.noreply.github.com> --- .../transport/agent/MLAgentDeleteAction.java | 16 + .../transport/agent/MLAgentDeleteRequest.java | 71 +++++ .../transport/agent/MLAgentGetAction.java | 16 + .../transport/agent/MLAgentGetRequest.java | 71 +++++ .../transport/agent/MLAgentGetResponse.java | 62 ++++ .../agent/MLAgentDeleteActionTest.java | 19 ++ .../agent/MLAgentDeleteRequestTest.java | 65 ++++ .../transport/agent/MLAgentGetActionTest.java | 21 ++ .../agent/MLAgentGetRequestTest.java | 64 ++++ .../agent/MLAgentGetResponseTest.java | 106 +++++++ .../agents/DeleteAgentTransportAction.java | 70 +++++ .../agents/GetAgentTransportAction.java | 97 ++++++ .../ml/plugin/MachineLearningPlugin.java | 12 + .../ml/rest/RestMLDeleteAgentAction.java | 51 ++++ .../ml/rest/RestMLGetAgentAction.java | 63 ++++ .../opensearch/ml/utils/RestActionUtils.java | 1 + .../DeleteAgentTransportActionTests.java | 112 +++++++ .../agents/GetAgentTransportActionTests.java | 283 ++++++++++++++++++ .../ml/rest/RestMLDeleteAgentActionTests.java | 102 +++++++ .../ml/rest/RestMLGetAgentActionTests.java | 99 ++++++ 20 files changed, 1401 insertions(+) create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteActionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequestTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetActionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/agents/DeleteAgentTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/agents/GetAgentTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteAgentAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLGetAgentAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteAgentActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLGetAgentActionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteAction.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteAction.java new file mode 100644 index 0000000000..c23e810090 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteAction.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.opensearch.action.ActionType; +import org.opensearch.action.delete.DeleteResponse; + +public class MLAgentDeleteAction extends ActionType { + public static final MLAgentDeleteAction INSTANCE = new MLAgentDeleteAction(); + public static final String NAME = "cluster:admin/opensearch/ml/agents/delete"; + + private MLAgentDeleteAction() { super(NAME, DeleteResponse::new);} +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequest.java new file mode 100644 index 0000000000..ddc568fc60 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequest.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +public class MLAgentDeleteRequest extends ActionRequest { + @Getter + String agentId; + + @Builder + public MLAgentDeleteRequest(String agentId) { + this.agentId = agentId; + } + + public MLAgentDeleteRequest(StreamInput input) throws IOException { + super(input); + this.agentId = input.readString(); + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + super.writeTo(output); + output.writeString(agentId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.agentId == null) { + exception = addValidationError("ML agent id can't be null", exception); + } + + return exception; + } + + public static MLAgentDeleteRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLAgentDeleteRequest) { + return (MLAgentDeleteRequest)actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLAgentDeleteRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLAgentDeleteRequest", e); + } + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetAction.java new file mode 100644 index 0000000000..2a61035ce8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetAction.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.opensearch.action.ActionType; + +public class MLAgentGetAction extends ActionType { + public static final MLAgentGetAction INSTANCE = new MLAgentGetAction(); + public static final String NAME = "cluster:admin/opensearch/ml/agents/get"; + + private MLAgentGetAction() { super(NAME, MLAgentGetResponse::new);} + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java new file mode 100644 index 0000000000..4880a07abf --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +public class MLAgentGetRequest extends ActionRequest { + + String agentId; + + @Builder + public MLAgentGetRequest(String agentId) { + this.agentId = agentId; + } + + public MLAgentGetRequest(StreamInput in) throws IOException { + super(in); + this.agentId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.agentId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.agentId == null) { + exception = addValidationError("ML agent id can't be null", exception); + } + + return exception; + } + + public static MLAgentGetRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLAgentGetRequest) { + return (MLAgentGetRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLAgentGetRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLAgentGetRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java new file mode 100644 index 0000000000..a437ef0ed8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import lombok.Builder; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.agent.MLAgent; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +public class MLAgentGetResponse extends ActionResponse implements ToXContentObject { + MLAgent mlAgent; + + @Builder + public MLAgentGetResponse(MLAgent mlAgent) { + this.mlAgent = mlAgent; + } + + public MLAgentGetResponse(StreamInput in) throws IOException { + super(in); + mlAgent = MLAgent.fromStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException{ + mlAgent.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return mlAgent.toXContent(xContentBuilder, params); + } + + public static MLAgentGetResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLAgentGetResponse) { + return (MLAgentGetResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLAgentGetResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLAgentGetResponse", e); + } + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteActionTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteActionTest.java new file mode 100644 index 0000000000..7cc9e66793 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteActionTest.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.agent; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +public class MLAgentDeleteActionTest { + @Test + public void testMLAgentDeleteActionInstance() { + assertNotNull(MLAgentDeleteAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/agents/delete", MLAgentDeleteAction.NAME); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequestTest.java new file mode 100644 index 0000000000..135271ec47 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequestTest.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.agent; + +import org.junit.Test; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.action.ValidateActions.addValidationError; + +public class MLAgentDeleteRequestTest { + String agentId; + + @Test + public void constructor_AgentId() { + agentId = "test-abc"; + MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId); + assertEquals(mLAgentDeleteRequest.agentId,agentId); + } + + @Test + public void writeTo() throws IOException { + agentId = "test-hij"; + + MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId); + BytesStreamOutput output = new BytesStreamOutput(); + mLAgentDeleteRequest.writeTo(output); + + MLAgentDeleteRequest mLAgentDeleteRequest1 = new MLAgentDeleteRequest(output.bytes().streamInput()); + + assertEquals(mLAgentDeleteRequest.agentId, mLAgentDeleteRequest1.agentId); + assertEquals(agentId, mLAgentDeleteRequest1.agentId); + } + + @Test + public void validate_Success() { + agentId = "not-null"; + MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId); + + assertEquals(null, mLAgentDeleteRequest.validate()); + } + + @Test + public void validate_Failure() { + agentId = null; + MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId); + assertEquals(null,mLAgentDeleteRequest.agentId); + + ActionRequestValidationException exception = addValidationError("ML agent id can't be null", null); + mLAgentDeleteRequest.validate().equals(exception) ; + } + + @Test + public void fromActionRequest() throws IOException { + agentId = "test-lmn"; + MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId); + assertEquals(mLAgentDeleteRequest.fromActionRequest(mLAgentDeleteRequest), mLAgentDeleteRequest); + + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetActionTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetActionTest.java new file mode 100644 index 0000000000..cba838fb02 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetActionTest.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +public class MLAgentGetActionTest { + + @Test + public void testMLAgentGetActionInstance() { + assertNotNull(MLAgentGetAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/agents/get", MLAgentGetAction.NAME); + } + + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java new file mode 100644 index 0000000000..6a04f5a965 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.agent; + +import org.junit.Test; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; + +import java.io.IOException; +import static org.junit.Assert.assertEquals; +import static org.opensearch.action.ValidateActions.addValidationError; + +public class MLAgentGetRequestTest { + String agentId; + + @Test + public void constructor_AgentId() { + agentId = "test-abc"; + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + assertEquals(mLAgentGetRequest.getAgentId(),agentId); + } + + @Test + public void writeTo() throws IOException { + agentId = "test-hij"; + + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + BytesStreamOutput output = new BytesStreamOutput(); + mLAgentGetRequest.writeTo(output); + + MLAgentGetRequest mLAgentGetRequest1 = new MLAgentGetRequest(output.bytes().streamInput()); + + assertEquals(mLAgentGetRequest1.getAgentId(), mLAgentGetRequest.getAgentId()); + assertEquals(mLAgentGetRequest1.getAgentId(), agentId); + } + + @Test + public void validate_Success() { + agentId = "not-null"; + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + + assertEquals(null, mLAgentGetRequest.validate()); + } + + @Test + public void validate_Failure() { + agentId = null; + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + assertEquals(null,mLAgentGetRequest.agentId); + + ActionRequestValidationException exception = addValidationError("ML agent id can't be null", null); + mLAgentGetRequest.validate().equals(exception) ; + } + @Test + public void fromActionRequest() throws IOException { + agentId = "test-lmn"; + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + assertEquals(mLAgentGetRequest.fromActionRequest(mLAgentGetRequest), mLAgentGetRequest); + } +} + + diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java new file mode 100644 index 0000000000..7d733a4308 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.*; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.agent.MLToolSpec; + +import java.io.*; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +public class MLAgentGetResponseTest { + + MLAgent mlAgent; + + @Test + public void Create_MLAgentResponse_With_StreamInput() throws IOException { + // Create a BytesStreamOutput to simulate the StreamOutput + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + + //create a test agent using input + bytesStreamOutput.writeString("Test Agent"); + bytesStreamOutput.writeString("flow"); + bytesStreamOutput.writeBoolean(false); + bytesStreamOutput.writeBoolean(false); + bytesStreamOutput.writeBoolean(false); + bytesStreamOutput.writeBoolean(false); + bytesStreamOutput.writeBoolean(false); + bytesStreamOutput.writeInstant(Instant.parse("2023-12-31T12:00:00Z")); + bytesStreamOutput.writeInstant(Instant.parse("2023-12-31T12:00:00Z")); + bytesStreamOutput.writeString("test"); + + StreamInput testInputStream = bytesStreamOutput.bytes().streamInput(); + + MLAgentGetResponse mlAgentGetResponse = new MLAgentGetResponse(testInputStream); + MLAgent testMlAgent = mlAgentGetResponse.mlAgent; + assertEquals("flow",testMlAgent.getType()); + assertEquals("Test Agent",testMlAgent.getName()); + assertEquals("test",testMlAgent.getAppType()); + } + + @Test + public void mLAgentGetResponse_Builder() throws IOException { + + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() + .mlAgent(mlAgent) + .build(); + + assertEquals(mlAgentGetResponse.mlAgent, mlAgent); + } + @Test + public void writeTo() throws IOException { + //create ml agent using MLAgent and mlAgentGetResponse + mlAgent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() + .mlAgent(mlAgent) + .build(); + //use write out for both agents + BytesStreamOutput output = new BytesStreamOutput(); + mlAgent.writeTo(output); + mlAgentGetResponse.writeTo(output); + MLAgent agent1 = mlAgentGetResponse.mlAgent; + + assertEquals(mlAgent.getAppType(), agent1.getAppType()); + assertEquals(mlAgent.getDescription(), agent1.getDescription()); + assertEquals(mlAgent.getCreatedTime(), agent1.getCreatedTime()); + assertEquals(mlAgent.getName(), agent1.getName()); + assertEquals(mlAgent.getParameters(), agent1.getParameters()); + assertEquals(mlAgent.getType(), agent1.getType()); + } + + @Test + public void toXContent() throws IOException { + mlAgent = new MLAgent("mock", "flow", "test", null, null, null, null, null, null, "test"); + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() + .mlAgent(mlAgent) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + ToXContent.Params params = EMPTY_PARAMS; + XContentBuilder getResponseXContentBuilder = mlAgentGetResponse.toXContent(builder, params); + assertEquals(getResponseXContentBuilder, mlAgent.toXContent(builder, params)); + } + + @Test + public void FromActionResponse() throws IOException { + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() + .mlAgent(mlAgent) + .build(); + assertEquals(mlAgentGetResponse.fromActionResponse(mlAgentGetResponse), mlAgentGetResponse); + + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/DeleteAgentTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/DeleteAgentTransportAction.java new file mode 100644 index 0000000000..0376fdbba9 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/DeleteAgentTransportAction.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agents; + +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class DeleteAgentTransportAction extends HandledTransportAction { + + Client client; + NamedXContentRegistry xContentRegistry; + + @Inject + public DeleteAgentTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry + ) { + super(MLAgentDeleteAction.NAME, transportService, actionFilters, MLAgentDeleteRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLAgentDeleteRequest mlAgentDeleteRequest = MLAgentDeleteRequest.fromActionRequest(request); + String agentId = mlAgentDeleteRequest.getAgentId(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); + DeleteRequest deleteRequest = new DeleteRequest(ML_AGENT_INDEX, agentId); + client.delete(deleteRequest, new ActionListener() { + @Override + public void onResponse(DeleteResponse deleteResponse) { + log.debug("Completed Delete Agent Request, agent id:{} deleted", agentId); + wrappedListener.onResponse(deleteResponse); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to delete ML Agent " + agentId, e); + wrappedListener.onFailure(e); + } + }); + } catch (Exception e) { + log.error("Failed to delete ml agent " + agentId, e); + actionListener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/GetAgentTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/GetAgentTransportAction.java new file mode 100644 index 0000000000..59d17651a3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/GetAgentTransportAction.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agents; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.transport.agent.MLAgentGetAction; +import org.opensearch.ml.common.transport.agent.MLAgentGetRequest; +import org.opensearch.ml.common.transport.agent.MLAgentGetResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class GetAgentTransportAction extends HandledTransportAction { + + Client client; + NamedXContentRegistry xContentRegistry; + + @Inject + public GetAgentTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry + ) { + super(MLAgentGetAction.NAME, transportService, actionFilters, MLAgentGetRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLAgentGetRequest mlAgentGetRequest = MLAgentGetRequest.fromActionRequest(request); + String agentId = mlAgentGetRequest.getAgentId(); + GetRequest getRequest = new GetRequest(ML_AGENT_INDEX).id(agentId); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + log.debug("Completed Get Agent Request, id:{}", agentId); + + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLAgent mlAgent = MLAgent.parse(parser); + actionListener.onResponse(MLAgentGetResponse.builder().mlAgent(mlAgent).build()); + } catch (Exception e) { + log.error("Failed to parse ml agent" + r.getId(), e); + actionListener.onFailure(e); + } + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "Failed to find agent with the provided agent id: " + agentId, + RestStatus.NOT_FOUND + ) + ); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + log.error("Failed to get agent index", e); + actionListener.onFailure(new OpenSearchStatusException("Failed to get agent index", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to get ML agent " + agentId, e); + actionListener.onFailure(e); + } + }), context::restore)); + } catch (Exception e) { + log.error("Failed to get ML agent " + agentId, e); + actionListener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 31197cd6be..bbfd5fdfbc 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -36,6 +36,8 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; +import org.opensearch.ml.action.agents.DeleteAgentTransportAction; +import org.opensearch.ml.action.agents.GetAgentTransportAction; import org.opensearch.ml.action.connector.DeleteConnectorTransportAction; import org.opensearch.ml.action.connector.GetConnectorTransportAction; import org.opensearch.ml.action.connector.SearchConnectorTransportAction; @@ -90,6 +92,8 @@ import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; +import org.opensearch.ml.common.transport.agent.MLAgentGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; @@ -161,12 +165,14 @@ import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.rest.RestMLCreateConnectorAction; +import org.opensearch.ml.rest.RestMLDeleteAgentAction; import org.opensearch.ml.rest.RestMLDeleteConnectorAction; import org.opensearch.ml.rest.RestMLDeleteModelAction; import org.opensearch.ml.rest.RestMLDeleteModelGroupAction; import org.opensearch.ml.rest.RestMLDeleteTaskAction; import org.opensearch.ml.rest.RestMLDeployModelAction; import org.opensearch.ml.rest.RestMLExecuteAction; +import org.opensearch.ml.rest.RestMLGetAgentAction; import org.opensearch.ml.rest.RestMLGetConnectorAction; import org.opensearch.ml.rest.RestMLGetModelAction; import org.opensearch.ml.rest.RestMLGetModelGroupAction; @@ -331,6 +337,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class), new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class), new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class), + new ActionHandler<>(MLAgentGetAction.INSTANCE, GetAgentTransportAction.class), + new ActionHandler<>(MLAgentDeleteAction.INSTANCE, DeleteAgentTransportAction.class), new ActionHandler<>(UpdateConversationAction.INSTANCE, UpdateConversationTransportAction.class), new ActionHandler<>(UpdateInteractionAction.INSTANCE, UpdateInteractionTransportAction.class), new ActionHandler<>(GetTracesAction.INSTANCE, GetTracesTransportAction.class) @@ -589,6 +597,8 @@ public List getRestHandlers( RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); + RestMLGetAgentAction restMLGetAgentAction = new RestMLGetAgentAction(); + RestMLDeleteAgentAction restMLDeleteAgentAction = new RestMLDeleteAgentAction(); RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction(); RestMemoryUpdateInteractionAction restMemoryUpdateInteractionAction = new RestMemoryUpdateInteractionAction(); RestMemoryGetTracesAction restMemoryGetTracesAction = new RestMemoryGetTracesAction(); @@ -631,6 +641,8 @@ public List getRestHandlers( restSearchInteractionsAction, restGetConversationAction, restGetInteractionAction, + restMLGetAgentAction, + restMLDeleteAgentAction, restMemoryUpdateConversationAction, restMemoryUpdateInteractionAction, restMemoryGetTracesAction diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteAgentAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteAgentAction.java new file mode 100644 index 0000000000..c8a667055e --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteAgentAction.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to delete ML Agent. + */ +public class RestMLDeleteAgentAction extends BaseRestHandler { + private static final String ML_DELETE_AGENT_ACTION = "ml_delete_agent_action"; + + public void RestMLDeleteAgentAction() {} + + @Override + public String getName() { + return ML_DELETE_AGENT_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of(new Route(RestRequest.Method.DELETE, String.format(Locale.ROOT, "%s/agents/{%s}", ML_BASE_URI, PARAMETER_AGENT_ID))); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + String agentId = request.param(PARAMETER_AGENT_ID); + + MLAgentDeleteRequest mlAgentDeleteRequest = new MLAgentDeleteRequest(agentId); + return channel -> client.execute(MLAgentDeleteAction.INSTANCE, mlAgentDeleteRequest, new RestToXContentListener<>(channel)); + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetAgentAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetAgentAction.java new file mode 100644 index 0000000000..efed1d84c3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetAgentAction.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.transport.agent.MLAgentGetAction; +import org.opensearch.ml.common.transport.agent.MLAgentGetRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLGetAgentAction extends BaseRestHandler { + private static final String ML_GET_Agent_ACTION = "ml_get_agent_action"; + + /** + * Constructor + */ + public RestMLGetAgentAction() {} + + @Override + public String getName() { + return ML_GET_Agent_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/agents/{%s}", ML_BASE_URI, PARAMETER_AGENT_ID))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLAgentGetRequest mlAgentGetRequest = getRequest(request); + return channel -> client.execute(MLAgentGetAction.INSTANCE, mlAgentGetRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLAgentGetRequest from a RestRequest + * + * @param request RestRequest + * @return MLAgentGetRequest + */ + @VisibleForTesting + MLAgentGetRequest getRequest(RestRequest request) throws IOException { + String agentId = getParameterId(request, PARAMETER_AGENT_ID); + + return new MLAgentGetRequest(agentId); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index 98f5f87d22..3a2f9daae4 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -50,6 +50,7 @@ public class RestActionUtils { public static final String PARAMETER_ASYNC = "async"; public static final String PARAMETER_RETURN_CONTENT = "return_content"; public static final String PARAMETER_MODEL_ID = "model_id"; + public static final String PARAMETER_AGENT_ID = "agent_id"; public static final String PARAMETER_TASK_ID = "task_id"; public static final String PARAMETER_CONNECTOR_ID = "connector_id"; public static final String PARAMETER_DEPLOY_MODEL = "deploy"; diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java new file mode 100644 index 0000000000..212112841a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.action.agents; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class DeleteAgentTransportActionTests { + + @Mock + private Client client; + @Mock + ThreadPool threadPool; + @Mock + private NamedXContentRegistry xContentRegistry; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @InjectMocks + private DeleteAgentTransportAction deleteAgentTransportAction; + + ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + deleteAgentTransportAction = new DeleteAgentTransportAction(transportService, actionFilters, client, xContentRegistry); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + @Test + public void testConstructor() { + // Verify that the dependencies were correctly injected + assertEquals(deleteAgentTransportAction.client, client); + assertEquals(deleteAgentTransportAction.xContentRegistry, xContentRegistry); + } + + @Test + public void testDoExecute_Success() { + String agentId = "test-agent-id"; + DeleteResponse deleteResponse = mock(DeleteResponse.class); + + ActionListener actionListener = mock(ActionListener.class); + + MLAgentDeleteRequest deleteRequest = new MLAgentDeleteRequest(agentId); + + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + deleteAgentTransportAction.doExecute(task, deleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + + @Test + public void testDoExecute_Failure() { + String agentId = "test-non-existed-agent-id"; + + ActionListener actionListener = mock(ActionListener.class); + + MLAgentDeleteRequest deleteRequest = new MLAgentDeleteRequest(agentId); + + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + NullPointerException NullPointerException = new NullPointerException("Failed to delete ML Agent " + agentId); + listener.onFailure(NullPointerException); + return null; + }).when(client).delete(any(), any()); + + deleteAgentTransportAction.doExecute(task, deleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to delete ML Agent " + agentId, argumentCaptor.getValue().getMessage()); + + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java new file mode 100644 index 0000000000..07f406ac07 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java @@ -0,0 +1,283 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.action.agents; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.transport.agent.MLAgentGetRequest; +import org.opensearch.ml.common.transport.agent.MLAgentGetResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetAgentTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + @Mock + ThreadPool threadPool; + @Mock + private NamedXContentRegistry xContentRegistry; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @InjectMocks + private GetAgentTransportAction getAgentTransportAction; + + ThreadContext threadContext; + MLAgent mlAgent; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + getAgentTransportAction = new GetAgentTransportAction(transportService, actionFilters, client, xContentRegistry); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + } + + @Test + public void testDoExecute_Failure_Get_Agent() { + String agentId = "test-agent-id-no-existed"; + + ActionListener actionListener = mock(ActionListener.class); + + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + + Task task = mock(Task.class); + + Exception exceptionToThrow = new Exception("Failed to get ML agent " + agentId); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(exceptionToThrow); + return null; + }).when(client).get(any(), any()); + + getAgentTransportAction.doExecute(task, getRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get ML agent " + agentId, argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDoExecute_Failure_IndexNotFound() { + String agentId = "test-agent-id-IndexNotFound"; + + ActionListener actionListener = mock(ActionListener.class); + + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + + Task task = mock(Task.class); + + Exception exceptionToThrow = new IndexNotFoundException("Failed to get agent index " + agentId); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(exceptionToThrow); + return null; + }).when(client).get(any(), any()); + + getAgentTransportAction.doExecute(task, getRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get agent index", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDoExecute_Failure_OpenSearchStatus() throws IOException { + String agentId = "test-agent-id-OpenSearchStatus"; + + ActionListener actionListener = mock(ActionListener.class); + + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + + Task task = mock(Task.class); + + Exception exceptionToThrow = new OpenSearchStatusException( + "Failed to find agent with the provided agent id: " + agentId, + RestStatus.NOT_FOUND + ); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(exceptionToThrow); + return null; + }).when(client).get(any(), any()); + + getAgentTransportAction.doExecute(task, getRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find agent with the provided agent id: " + agentId, argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDoExecute_RuntimeException() { + String agentId = "test-agent-id-RuntimeException"; + Task task = mock(Task.class); + ActionListener actionListener = mock(ActionListener.class); + + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Failed to get ML agent " + agentId)); + return null; + }).when(client).get(any(), any()); + getAgentTransportAction.doExecute(task, getRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get ML agent " + agentId, argumentCaptor.getValue().getMessage()); + } + + @Test + public void testGetTask_NullResponse() { + String agentId = "test-agent-id-NullResponse"; + Task task = mock(Task.class); + ActionListener actionListener = mock(ActionListener.class); + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + getAgentTransportAction.doExecute(task, getRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find agent with the provided agent id: " + agentId, argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDoExecute_Failure_Context_Exception() { + String agentId = "test-agent-id"; + + ActionListener actionListener = mock(ActionListener.class); + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + Task task = mock(Task.class); + GetAgentTransportAction getAgentTransportActionNullContext = new GetAgentTransportAction( + transportService, + actionFilters, + client, + xContentRegistry + ); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenThrow(new RuntimeException()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException()); + return null; + }).when(client).get(any(), any()); + try { + getAgentTransportActionNullContext.doExecute(task, getRequest, actionListener); + } catch (Exception e) { + assertEquals(e.getClass(), RuntimeException.class); + } + } + + @Test + public void testDoExecute_NoAgentId() throws IOException { + GetResponse getResponse = prepareMLAgent(null); + String agentId = "test-agent-id"; + + ActionListener actionListener = mock(ActionListener.class); + MLAgentGetRequest request = new MLAgentGetRequest(agentId); + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + try { + getAgentTransportAction.doExecute(task, request, actionListener); + } catch (Exception e) { + assertEquals(e.getClass(), IllegalArgumentException.class); + } + } + + @Test + public void testDoExecute_Success() throws IOException { + + String agentId = "test-agent-id"; + GetResponse getResponse = prepareMLAgent(agentId); + ActionListener actionListener = mock(ActionListener.class); + MLAgentGetRequest request = new MLAgentGetRequest(agentId); + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + getAgentTransportAction.doExecute(task, request, actionListener); + verify(actionListener).onResponse(any(MLAgentGetResponse.class)); + } + + public GetResponse prepareMLAgent(String agentId) throws IOException { + + mlAgent = new MLAgent( + "test", + "test", + "test", + new LLMSpec("test_model", Map.of("test_key", "test_value")), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + Map.of("test", "test"), + new MLMemorySpec("test", "123", 0), + Instant.EPOCH, + Instant.EPOCH, + "test" + ); + + XContentBuilder content = mlAgent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", agentId, 111l, 111l, 111l, true, bytesReference, null, null); + GetResponse getResponse = new GetResponse(getResult); + return getResponse; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteAgentActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteAgentActionTests.java new file mode 100644 index 0000000000..19849294f8 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteAgentActionTests.java @@ -0,0 +1,102 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.times; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLDeleteAgentActionTests extends OpenSearchTestCase { + private RestMLDeleteAgentAction restMLDeleteAgentAction; + + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + restMLDeleteAgentAction = new RestMLDeleteAgentAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLAgentDeleteAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLDeleteAgentAction mLDeleteAgentAction = new RestMLDeleteAgentAction(); + assertNotNull(mLDeleteAgentAction); + } + + public void testGetName() { + String actionName = restMLDeleteAgentAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_delete_agent_action", actionName); + } + + public void testRoutes() { + List routes = restMLDeleteAgentAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.DELETE, route.getMethod()); + assertEquals("/_plugins/_ml/agents/{agent_id}", route.getPath()); + } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLDeleteAgentAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLAgentDeleteRequest.class); + verify(client, times(1)).execute(eq(MLAgentDeleteAction.INSTANCE), argumentCaptor.capture(), any()); + String agentId = argumentCaptor.getValue().getAgentId(); + assertEquals(agentId, "agent_id"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_AGENT_ID, "agent_id"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetAgentActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetAgentActionTests.java new file mode 100644 index 0000000000..7b2f4eaae8 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetAgentActionTests.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.agent.MLAgentGetAction; +import org.opensearch.ml.common.transport.agent.MLAgentGetRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLGetAgentActionTests extends OpenSearchTestCase { + private RestMLGetAgentAction restMLGetAgentAction; + NodeClient client; + private ThreadPool threadPool; + @Mock + RestChannel channel; + + @Before + public void setup() { + restMLGetAgentAction = new RestMLGetAgentAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLAgentGetAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLGetAgentAction mLGetAgentAction = new RestMLGetAgentAction(); + assertNotNull(mLGetAgentAction); + } + + public void testGetName() { + String actionName = restMLGetAgentAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_get_agent_action", actionName); + } + + public void testRoutes() { + List routes = restMLGetAgentAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/agents/{agent_id}", route.getPath()); + } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLGetAgentAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLAgentGetRequest.class); + verify(client, times(1)).execute(eq(MLAgentGetAction.INSTANCE), argumentCaptor.capture(), any()); + String agentId = argumentCaptor.getValue().getAgentId(); + assertEquals(agentId, "agent_id"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_AGENT_ID, "agent_id"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } + +}