From 0a1b4b51ccf59447f5d7ecfa80416e100d1f7422 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 13 Dec 2023 08:50:02 -0800 Subject: [PATCH] agent meta classes in common (#1757) (#1759) * agent meta classes Signed-off-by: Jing Zhang * address comments Signed-off-by: Jing Zhang * add one more UT Signed-off-by: Jing Zhang * add more UT Signed-off-by: Jing Zhang * more UT Signed-off-by: Jing Zhang --------- Signed-off-by: Jing Zhang (cherry picked from commit 3949ef9efb87b4a81f9e358fa1400295dec76f8c) Co-authored-by: Jing Zhang --- .../org/opensearch/ml/common/CommonValue.java | 116 ++++++++ .../opensearch/ml/common/agent/LLMSpec.java | 102 +++++++ .../opensearch/ml/common/agent/MLAgent.java | 264 ++++++++++++++++++ .../ml/common/agent/MLMemorySpec.java | 108 +++++++ .../ml/common/agent/MLToolSpec.java | 143 ++++++++++ .../ml/common/agent/LLMSpecTest.java | 91 ++++++ .../ml/common/agent/MLAgentTest.java | 143 ++++++++++ .../ml/common/agent/MLMemorySpecTest.java | 69 +++++ .../ml/common/agent/MLToolSpecTest.java | 75 +++++ 9 files changed, 1111 insertions(+) create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java create mode 100644 common/src/test/java/org/opensearch/ml/common/agent/LLMSpecTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/agent/MLMemorySpecTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index f82e742866..53a12a4224 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -5,8 +5,25 @@ package org.opensearch.ml.common; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.connector.AbstractConnector; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INDEX_SCHEMA_VERSION; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_CREATED_TIME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_INDEX_SCHEMA_VERSION; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_UPDATED_TIME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.USER_FIELD; import static org.opensearch.ml.common.model.MLModelConfig.ALL_CONFIG_FIELD; import static org.opensearch.ml.common.model.MLModelConfig.MODEL_TYPE_FIELD; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.EMBEDDING_DIMENSION_FIELD; @@ -44,6 +61,12 @@ public class CommonValue { public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2; public static final String ML_MAP_RESPONSE_KEY = "response"; + public static final String ML_AGENT_INDEX = ".plugins-ml-agent"; + public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 1; + public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta"; + public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1; + public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message"; + public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1; public static final String USER_FIELD_MAPPING = " \"" + CommonValue.USER + "\": {\n" @@ -326,4 +349,97 @@ public class CommonValue { + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + " }\n" + "}"; + + public static final String ML_AGENT_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_AGENT_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLAgent.AGENT_NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + MLAgent.AGENT_TYPE_FIELD + + "\" : {\"type\":\"keyword\"},\n" + + " \"" + + MLAgent.DESCRIPTION_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + MLAgent.LLM_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.TOOLS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.PARAMETERS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.MEMORY_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLAgent.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; + + public static final String ML_MEMORY_META_INDEX_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + META_INDEX_SCHEMA_VERSION + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + META_NAME_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + META_CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + META_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + USER_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + APPLICATION_TYPE_FIELD + + "\": {\"type\": \"keyword\"}\n" + + " }\n" + + "}"; + + public static final String ML_MEMORY_MESSAGE_INDEX_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + INTERACTIONS_INDEX_SCHEMA_VERSION + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + INTERACTIONS_CONVERSATION_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + INTERACTIONS_INPUT_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_PROMPT_TEMPLATE_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_RESPONSE_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_ORIGIN_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_ADDITIONAL_INFO_FIELD + + "\": {\"type\": \"flat_object\"},\n" + + " \"" + + PARENT_INTERACTIONS_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_TRACE_NUMBER_FIELD + + "\": {\"type\": \"long\"}\n" + + " }\n" + + "}"; } diff --git a/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java new file mode 100644 index 0000000000..6c0fda289a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java @@ -0,0 +1,102 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + + +@Getter +public class LLMSpec implements ToXContentObject { + public static final String MODEL_ID_FIELD = "model_id"; + public static final String PARAMETERS_FIELD = "parameters"; + + private String modelId; + private Map parameters; + + + @Builder(toBuilder = true) + public LLMSpec(String modelId, Map parameters) { + if (modelId == null) { + throw new IllegalArgumentException("model id is null"); + } + this.modelId = modelId; + this.parameters = parameters; + } + + public LLMSpec(StreamInput input) throws IOException{ + modelId = input.readString(); + if (input.readBoolean()) { + parameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString); + } + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + if (parameters != null && parameters.size() > 0) { + out.writeBoolean(true); + out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (modelId != null) { + builder.field(MODEL_ID_FIELD, modelId); + } + if (parameters != null && parameters.size() > 0) { + builder.field(PARAMETERS_FIELD, parameters); + } + builder.endObject(); + return builder; + } + + public static LLMSpec parse(XContentParser parser) throws IOException { + String modelId = null; + Map parameters = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MODEL_ID_FIELD: + modelId = parser.text(); + break; + case PARAMETERS_FIELD: + parameters = getParameterMap(parser.map()); + break; + default: + parser.skipChildren(); + break; + } + } + return LLMSpec.builder() + .modelId(modelId) + .parameters(parameters) + .build(); + } + + public static LLMSpec fromStream(StreamInput in) throws IOException { + LLMSpec toolSpec = new LLMSpec(in); + return toolSpec; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java new file mode 100644 index 0000000000..ba2f241375 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -0,0 +1,264 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + + +@Getter +public class MLAgent implements ToXContentObject, Writeable { + public static final String AGENT_NAME_FIELD = "name"; + public static final String AGENT_TYPE_FIELD = "type"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String LLM_FIELD = "llm"; + public static final String TOOLS_FIELD = "tools"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String MEMORY_FIELD = "memory"; + public static final String MEMORY_ID_FIELD = "memory_id"; + public static final String CREATED_TIME_FIELD = "created_time"; + public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; + public static final String APP_TYPE_FIELD = "app_type"; + + private String name; + private String type; + private String description; + private LLMSpec llm; + private List tools; + private Map parameters; + private MLMemorySpec memory; + + private Instant createdTime; + private Instant lastUpdateTime; + private String appType; + + @Builder(toBuilder = true) + public MLAgent(String name, + String type, + String description, + LLMSpec llm, + List tools, + Map parameters, + MLMemorySpec memory, + Instant createdTime, + Instant lastUpdateTime, + String appType) { + if (name == null) { + throw new IllegalArgumentException("agent name is null"); + } + this.name = name; + this.type = type; + this.description = description; + this.llm = llm; + this.tools = tools; + this.parameters = parameters; + this.memory = memory; + this.createdTime = createdTime; + this.lastUpdateTime = lastUpdateTime; + this.appType = appType; + } + + public MLAgent(StreamInput input) throws IOException{ + name = input.readString(); + type = input.readString(); + description = input.readOptionalString(); + if (input.readBoolean()) { + llm = new LLMSpec(input); + } + if (input.readBoolean()) { + tools = new ArrayList<>(); + int size = input.readInt(); + for (int i=0; i toolNames = new HashSet<>(); + for (MLToolSpec toolSpec : tools) { + String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType()); + if (toolNames.contains(toolName)) { + throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName); + } + } + } + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + out.writeString(type); + out.writeOptionalString(description); + if (llm != null) { + out.writeBoolean(true); + llm.writeTo(out); + } else { + out.writeBoolean(false); + } + if (tools != null && tools.size() > 0) { + out.writeBoolean(true); + out.writeInt(tools.size()); + for (MLToolSpec tool : tools) { + tool.writeTo(out); + } + } else { + out.writeBoolean(false); + } + if (parameters != null && parameters.size() > 0) { + out.writeBoolean(true); + out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString); + } else { + out.writeBoolean(false); + } + if (memory != null) { + out.writeBoolean(true); + memory.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeInstant(createdTime); + out.writeInstant(lastUpdateTime); + out.writeString(appType); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (name != null) { + builder.field(AGENT_NAME_FIELD, name); + } + if (type != null) { + builder.field(AGENT_TYPE_FIELD, type); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (llm != null) { + builder.field(LLM_FIELD, llm); + } + if (tools != null && tools.size() > 0) { + builder.field(TOOLS_FIELD, tools); + } + if (parameters != null && parameters.size() > 0) { + builder.field(PARAMETERS_FIELD, parameters); + } + if (memory != null) { + builder.field(MEMORY_FIELD, memory); + } + if (createdTime != null) { + builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); + } + if (lastUpdateTime != null) { + builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli()); + } + if (appType != null) { + builder.field(APP_TYPE_FIELD, appType); + } + builder.endObject(); + return builder; + } + + public static MLAgent parse(XContentParser parser) throws IOException { + String name = null; + String type = null; + String description = null;; + LLMSpec llm = null; + List tools = null; + Map parameters = null; + MLMemorySpec memory = null; + Instant createdTime = null; + Instant lastUpdateTime = null; + String appType = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case AGENT_NAME_FIELD: + name = parser.text(); + break; + case AGENT_TYPE_FIELD: + type = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case LLM_FIELD: + llm = LLMSpec.parse(parser); + break; + case TOOLS_FIELD: + tools = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + tools.add(MLToolSpec.parse(parser)); + } + break; + case PARAMETERS_FIELD: + parameters = getParameterMap(parser.map()); + break; + case MEMORY_FIELD: + memory = MLMemorySpec.parse(parser); + break; + case CREATED_TIME_FIELD: + createdTime = Instant.ofEpochMilli(parser.longValue()); + break; + case LAST_UPDATED_TIME_FIELD: + lastUpdateTime = Instant.ofEpochMilli(parser.longValue()); + break; + case APP_TYPE_FIELD: + appType = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + return MLAgent.builder() + .name(name) + .type(type) + .description(description) + .llm(llm) + .tools(tools) + .parameters(parameters) + .memory(memory) + .createdTime(createdTime) + .lastUpdateTime(lastUpdateTime) + .appType(appType) + .build(); + } + + public static MLAgent fromStream(StreamInput in) throws IOException { + MLAgent agent = new MLAgent(in); + return agent; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java new file mode 100644 index 0000000000..aa192a7ee2 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + + +@Getter +public class MLMemorySpec implements ToXContentObject { + public static final String MEMORY_TYPE_FIELD = "type"; + public static final String WINDOW_SIZE_FIELD = "window_size"; + public static final String SESSION_ID_FIELD = "session_id"; + + private String type; + @Setter + private String sessionId; + private Integer windowSize; + + + @Builder(toBuilder = true) + public MLMemorySpec(String type, + String sessionId, + Integer windowSize) { + if (type == null) { + throw new IllegalArgumentException("agent name is null"); + } + this.type = type; + this.sessionId = sessionId; + this.windowSize = windowSize; + } + + public MLMemorySpec(StreamInput input) throws IOException{ + type = input.readString(); + sessionId = input.readOptionalString(); + windowSize = input.readOptionalInt(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + out.writeOptionalString(sessionId); + out.writeOptionalInt(windowSize); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MEMORY_TYPE_FIELD, type); + if (windowSize != null) { + builder.field(WINDOW_SIZE_FIELD, windowSize); + } + if (sessionId != null) { + builder.field(SESSION_ID_FIELD, sessionId); + } + builder.endObject(); + return builder; + } + + public static MLMemorySpec parse(XContentParser parser) throws IOException { + String type = null; + String sessionId = null; + Integer windowSize = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MEMORY_TYPE_FIELD: + type = parser.text(); + break; + case SESSION_ID_FIELD: + sessionId = parser.text(); + break; + case WINDOW_SIZE_FIELD: + windowSize = parser.intValue(); + break; + default: + parser.skipChildren(); + break; + } + } + return MLMemorySpec.builder() + .type(type) + .sessionId(sessionId) + .windowSize(windowSize) + .build(); + } + + public static MLMemorySpec fromStream(StreamInput in) throws IOException { + MLMemorySpec toolSpec = new MLMemorySpec(in); + return toolSpec; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java new file mode 100644 index 0000000000..055c59d449 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + + +@Getter +public class MLToolSpec implements ToXContentObject { + public static final String TOOL_TYPE_FIELD = "type"; + public static final String TOOL_NAME_FIELD = "name"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response"; + + private String type; + private String name; + private String description; + private Map parameters; + private boolean includeOutputInAgentResponse; + + + @Builder(toBuilder = true) + public MLToolSpec(String type, + String name, + String description, + Map parameters, + boolean includeOutputInAgentResponse) { + if (type == null) { + throw new IllegalArgumentException("tool type is null"); + } + this.type = type; + this.name = name; + this.description = description; + this.parameters = parameters; + this.includeOutputInAgentResponse = includeOutputInAgentResponse; + } + + public MLToolSpec(StreamInput input) throws IOException{ + type = input.readString(); + name = input.readOptionalString(); + description = input.readOptionalString(); + if (input.readBoolean()) { + parameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString); + } + includeOutputInAgentResponse = input.readBoolean(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + out.writeOptionalString(name); + out.writeOptionalString(description); + if (parameters != null && parameters.size() > 0) { + out.writeBoolean(true); + out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString); + } else { + out.writeBoolean(false); + } + out.writeBoolean(includeOutputInAgentResponse); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (type != null) { + builder.field(TOOL_TYPE_FIELD, type); + } + if (name != null) { + builder.field(TOOL_NAME_FIELD, name); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (parameters != null && parameters.size() > 0) { + builder.field(PARAMETERS_FIELD, parameters); + } + builder.field(INCLUDE_OUTPUT_IN_AGENT_RESPONSE, includeOutputInAgentResponse); + builder.endObject(); + return builder; + } + + public static MLToolSpec parse(XContentParser parser) throws IOException { + String type = null; + String name = null; + String description = null; + Map parameters = null; + boolean includeOutputInAgentResponse = false; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TOOL_TYPE_FIELD: + type = parser.text(); + break; + case TOOL_NAME_FIELD: + name = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case PARAMETERS_FIELD: + parameters = getParameterMap(parser.map()); + break; + case INCLUDE_OUTPUT_IN_AGENT_RESPONSE: + includeOutputInAgentResponse = parser.booleanValue(); + break; + default: + parser.skipChildren(); + break; + } + } + return MLToolSpec.builder() + .type(type) + .name(name) + .description(description) + .parameters(parameters) + .includeOutputInAgentResponse(includeOutputInAgentResponse) + .build(); + } + + public static MLToolSpec fromStream(StreamInput in) throws IOException { + MLToolSpec toolSpec = new MLToolSpec(in); + return toolSpec; + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/agent/LLMSpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/LLMSpecTest.java new file mode 100644 index 0000000000..0964efb7d1 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/agent/LLMSpecTest.java @@ -0,0 +1,91 @@ +package org.opensearch.ml.common.agent; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + +import static org.junit.Assert.*; + +public class LLMSpecTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Test + public void constructor_NonModelID() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("model id is null"); + + LLMSpec spec = new LLMSpec(null, Map.of("test_key", "test_value")); + } + + @Test + public void writeTo() throws IOException { + LLMSpec spec = new LLMSpec("test_model", Map.of("test_key", "test_value")); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + LLMSpec spec1 = new LLMSpec(output.bytes().streamInput()); + + Assert.assertEquals(spec.getModelId(), spec1.getModelId()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + } + + @Test + public void writeTo_EmptyParameters() throws IOException { + LLMSpec spec = new LLMSpec("test_model", Map.of()); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + LLMSpec spec1 = new LLMSpec(output.bytes().streamInput()); + + Assert.assertEquals(spec.getModelId(), spec1.getModelId()); + Assert.assertEquals(null, spec1.getParameters()); + } + + @Test + public void toXContent() throws IOException { + LLMSpec spec = new LLMSpec("test_model", Map.of("test_key", "test_value")); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + spec.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals("{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}}", content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + LLMSpec spec = LLMSpec.parse(parser); + + Assert.assertEquals(spec.getModelId(), "test_model"); + Assert.assertEquals(spec.getParameters(), Map.of("test_key", "test_value")); + } + + @Test + public void fromStream() throws IOException { + LLMSpec spec = new LLMSpec("test_model", Map.of("test_key", "test_value")); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + LLMSpec spec1 = LLMSpec.fromStream(output.bytes().streamInput()); + + Assert.assertEquals(spec.getModelId(), spec1.getModelId()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java new file mode 100644 index 0000000000..bfaec959c4 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -0,0 +1,143 @@ +package org.opensearch.ml.common.agent; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.*; + +public class MLAgentTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Test + public void constructor_NullName() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("agent name is null"); + + MLAgent agent = new MLAgent(null, "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test"); + } + + @Test + public void writeTo() throws IOException { + MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + BytesStreamOutput output = new BytesStreamOutput(); + agent.writeTo(output); + MLAgent agent1 = new MLAgent(output.bytes().streamInput()); + + Assert.assertEquals(agent.getAppType(), agent1.getAppType()); + Assert.assertEquals(agent.getDescription(), agent1.getDescription()); + Assert.assertEquals(agent.getCreatedTime(), agent1.getCreatedTime()); + Assert.assertEquals(agent.getName(), agent1.getName()); + Assert.assertEquals(agent.getParameters(), agent1.getParameters()); + Assert.assertEquals(agent.getType(), agent1.getType()); + } + + @Test + public void writeTo_NullLLM() throws IOException { + MLAgent agent = new MLAgent("test", "test", "test", null, List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + BytesStreamOutput output = new BytesStreamOutput(); + agent.writeTo(output); + MLAgent agent1 = new MLAgent(output.bytes().streamInput()); + + Assert.assertEquals(agent1.getLlm(), null); + } + + @Test + public void writeTo_NullTools() throws IOException { + MLAgent agent = new MLAgent("test", "flow", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + BytesStreamOutput output = new BytesStreamOutput(); + agent.writeTo(output); + MLAgent agent1 = new MLAgent(output.bytes().streamInput()); + + Assert.assertEquals(agent1.getTools(), null); + } + + @Test + public void writeTo_NullParameters() throws IOException { + MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + BytesStreamOutput output = new BytesStreamOutput(); + agent.writeTo(output); + MLAgent agent1 = new MLAgent(output.bytes().streamInput()); + + Assert.assertEquals(agent1.getParameters(), null); + } + + @Test + public void writeTo_NullMemory() throws IOException { + MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), null, Instant.EPOCH, Instant.EPOCH, "test"); + BytesStreamOutput output = new BytesStreamOutput(); + agent.writeTo(output); + MLAgent agent1 = new MLAgent(output.bytes().streamInput()); + + Assert.assertEquals(agent1.getMemory(), null); + } + + @Test + public void toXContent() throws IOException { + MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + agent.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + String expectedStr = "{\"name\":\"test\",\"type\":\"test\",\"description\":\"test\",\"llm\":{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}},\"tools\":[{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}],\"parameters\":{\"test\":\"test\"},\"memory\":{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"},\"created_time\":0,\"last_updated_time\":0,\"app_type\":\"test\"}"; + + Assert.assertEquals(content, expectedStr); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"name\":\"test\",\"type\":\"test\",\"description\":\"test\",\"llm\":{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}},\"tools\":[{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}],\"parameters\":{\"test\":\"test\"},\"memory\":{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"},\"created_time\":0,\"last_updated_time\":0,\"app_type\":\"test\"}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + MLAgent agent = MLAgent.parse(parser); + + Assert.assertEquals(agent.getName(), "test"); + Assert.assertEquals(agent.getType(), "test"); + Assert.assertEquals(agent.getDescription(), "test"); + Assert.assertEquals(agent.getLlm().getModelId(), "test_model"); + Assert.assertEquals(agent.getLlm().getParameters(), Map.of("test_key", "test_value")); + Assert.assertEquals(agent.getTools().get(0).getName(), "test"); + Assert.assertEquals(agent.getTools().get(0).getType(), "test"); + Assert.assertEquals(agent.getTools().get(0).getDescription(), "test"); + Assert.assertEquals(agent.getTools().get(0).getParameters(), Map.of("test", "test")); + Assert.assertEquals(agent.getTools().get(0).isIncludeOutputInAgentResponse(), false); + Assert.assertEquals(agent.getCreatedTime(), Instant.EPOCH); + Assert.assertEquals(agent.getLastUpdateTime(), Instant.EPOCH); + Assert.assertEquals(agent.getAppType(), "test"); + Assert.assertEquals(agent.getMemory().getSessionId(), "123"); + Assert.assertEquals(agent.getParameters(), Map.of("test", "test")); + } + + @Test + public void fromStream() throws IOException { + MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + BytesStreamOutput output = new BytesStreamOutput(); + agent.writeTo(output); + MLAgent agent1 = MLAgent.fromStream(output.bytes().streamInput()); + + Assert.assertEquals(agent.getAppType(), agent1.getAppType()); + Assert.assertEquals(agent.getDescription(), agent1.getDescription()); + Assert.assertEquals(agent.getCreatedTime(), agent1.getCreatedTime()); + Assert.assertEquals(agent.getName(), agent1.getName()); + Assert.assertEquals(agent.getParameters(), agent1.getParameters()); + Assert.assertEquals(agent.getType(), agent1.getType()); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLMemorySpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLMemorySpecTest.java new file mode 100644 index 0000000000..2d028985e0 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLMemorySpecTest.java @@ -0,0 +1,69 @@ +package org.opensearch.ml.common.agent; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + +import static org.junit.Assert.*; + +public class MLMemorySpecTest { + + @Test + public void writeTo() throws IOException { + MLMemorySpec spec = new MLMemorySpec("test", "123", 0); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLMemorySpec spec1 = new MLMemorySpec(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getSessionId(), spec1.getSessionId()); + Assert.assertEquals(spec.getWindowSize(), spec1.getWindowSize()); + } + + @Test + public void toXContent() throws IOException { + MLMemorySpec spec = new MLMemorySpec("test", "123", 0); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + spec.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals("{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"}", content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + MLMemorySpec spec = MLMemorySpec.parse(parser); + + Assert.assertEquals(spec.getType(), "test"); + Assert.assertEquals(spec.getWindowSize(), Integer.valueOf(0)); + Assert.assertEquals(spec.getSessionId(), "123"); + } + + @Test + public void fromStream() throws IOException { + MLMemorySpec spec = new MLMemorySpec("test", "123", 0); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLMemorySpec spec1 = MLMemorySpec.fromStream(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getSessionId(), spec1.getSessionId()); + Assert.assertEquals(spec.getWindowSize(), spec1.getWindowSize()); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java new file mode 100644 index 0000000000..d831611035 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java @@ -0,0 +1,75 @@ +package org.opensearch.ml.common.agent; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + +import static org.junit.Assert.*; + +public class MLToolSpecTest { + + @Test + public void writeTo() throws IOException { + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLToolSpec spec1 = new MLToolSpec(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getName(), spec1.getName()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + Assert.assertEquals(spec.getDescription(), spec1.getDescription()); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + } + + @Test + public void toXContent() throws IOException { + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + spec.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals("{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}", content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + MLToolSpec spec = MLToolSpec.parse(parser); + + Assert.assertEquals(spec.getType(), "test"); + Assert.assertEquals(spec.getName(), "test"); + Assert.assertEquals(spec.getDescription(), "test"); + Assert.assertEquals(spec.getParameters(), Map.of("test", "test")); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), false); + } + + @Test + public void fromStream() throws IOException { + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLToolSpec spec1 = MLToolSpec.fromStream(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getName(), spec1.getName()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + Assert.assertEquals(spec.getDescription(), spec1.getDescription()); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + } +} \ No newline at end of file