From 6f4394b729269fef06457ca3c3ec52a8c0bc869d Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Sun, 12 Nov 2023 10:33:54 -0800 Subject: [PATCH] add memory factory; fix tool interface Signed-off-by: Yaliang Wu --- .../org/opensearch/ml/common/CommonValue.java | 44 ++++++-- .../opensearch/ml/common/agent/MLAgent.java | 49 ++------- .../algorithms/agent/MLAgentExecutor.java | 10 +- .../algorithms/agent/MLFlowAgentRunner.java | 6 +- .../algorithms/agent/MLReActAgentRunner.java | 77 +++++++------ .../ml/engine}/indices/MLIndex.java | 12 +- .../ml/engine}/indices/MLIndicesHandler.java | 10 +- .../indices/MLInputDatasetHandler.java | 2 +- .../memory/ConversationIndexMemory.java | 104 +++++++++++++++--- .../opensearch/ml/engine/tools/AgentTool.java | 18 +-- .../ml/engine/tools/CatIndexTool.java | 13 ++- .../ml/engine/tools/MLModelTool.java | 19 +--- .../opensearch/ml/engine/tools/MathTool.java | 18 +-- .../ml/engine/tools/PainlessScriptTool.java | 19 +--- .../ml/engine/tools/VectorDBTool.java | 18 +-- .../ml/engine/tools/CatIndexToolTests.java | 2 +- .../agents/TransportRegisterAgentAction.java | 2 +- .../TransportCreateConnectorAction.java | 2 +- .../TransportRegisterModelGroupAction.java | 2 +- .../TransportRegisterModelAction.java | 2 +- .../upload_chunk/MLModelChunkUploader.java | 2 +- .../MLCommonsClusterManagerEventListener.java | 2 +- .../opensearch/ml/cluster/MLSyncUpCron.java | 2 +- .../ml/model/MLModelGroupManager.java | 2 +- .../opensearch/ml/model/MLModelManager.java | 2 +- .../ml/plugin/MachineLearningPlugin.java | 34 +++--- .../ml/task/MLExecuteTaskRunner.java | 2 +- .../ml/task/MLPredictTaskRunner.java | 2 +- .../org/opensearch/ml/task/MLTaskManager.java | 2 +- .../ml/task/MLTrainAndPredictTaskRunner.java | 2 +- .../ml/task/MLTrainingTaskRunner.java | 4 +- .../TransportCreateConnectorActionTests.java | 2 +- ...ransportRegisterModelGroupActionTests.java | 2 +- .../TransportRegisterModelActionTests.java | 2 +- .../MLModelChunkUploaderTests.java | 2 +- .../ml/cluster/MLSyncUpCronTests.java | 2 +- .../ml/indices/MLIndicesHandlerTests.java | 1 + .../indices/MLInputDatasetHandlerTests.java | 1 + .../ml/model/MLModelGroupManagerTests.java | 2 +- .../ml/model/MLModelManagerTests.java | 2 +- .../ml/task/MLExecuteTaskRunnerTests.java | 2 +- .../ml/task/MLPredictTaskRunnerTests.java | 2 +- .../ml/task/MLTaskManagerTests.java | 2 +- .../MLTrainAndPredictTaskRunnerTests.java | 2 +- .../ml/task/MLTrainingTaskRunnerTests.java | 4 +- .../org/opensearch/ml/utils/MockHelper.java | 2 +- .../ml/common/spi/memory/Memory.java | 12 ++ 47 files changed, 296 insertions(+), 231 deletions(-) rename {plugin/src/main/java/org/opensearch/ml => ml-algorithms/src/main/java/org/opensearch/ml/engine}/indices/MLIndex.java (76%) rename {plugin/src/main/java/org/opensearch/ml => ml-algorithms/src/main/java/org/opensearch/ml/engine}/indices/MLIndicesHandler.java (96%) rename {plugin/src/main/java/org/opensearch/ml => ml-algorithms/src/main/java/org/opensearch/ml/engine}/indices/MLInputDatasetHandler.java (98%) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 1cd4e4131d..1aa520c676 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -47,6 +47,10 @@ public class CommonValue { public static final String ML_MAP_RESPONSE_KEY = "response"; public static final String ML_AGENT_INDEX = ".plugins-ml-agent"; public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 1; + public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta"; + public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1; + public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message"; + public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1; public static final String USER_FIELD_MAPPING = " \"" + CommonValue.USER @@ -328,7 +332,6 @@ public class CommonValue { + " }\n" + "}"; - public static final String ML_AGENT_INDEX_MAPPING = "{\n" + " \"_meta\": {\"schema_version\": " + ML_AGENT_INDEX_SCHEMA_VERSION @@ -339,17 +342,11 @@ public class CommonValue { + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + " \"" + MLAgent.AGENT_TYPE_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + "\" : {\"type\":\"keyword\"},\n" + " \"" + MLAgent.DESCRIPTION_FIELD + "\" : {\"type\": \"text\"},\n" + " \"" - + MLAgent.PROMPT_FIELD - + "\" : {\"type\": \"text\"},\n" - + " \"" - + MLAgent.MODEL_ID_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" - + " \"" + MLAgent.LLM_FIELD + "\" : {\"type\": \"flat_object\"},\n" + " \"" @@ -362,14 +359,37 @@ public class CommonValue { + MLAgent.MEMORY_FIELD + "\" : {\"type\": \"flat_object\"},\n" + " \"" - + MLAgent.MEMORY_ID_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" - + " \"" - + MLAgent.CREATED_TIME_FIELD + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + " \"" + MLAgent.LAST_UPDATED_TIME_FIELD + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + " }\n" + "}"; + + public static final String ML_MEMORY_META_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_MEMORY_META_INDEX_SCHEMA_VERSION + + " },\n" + + " \"properties\": {\n" + + " \"name\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"application_type\" : {\"type\":\"keyword\"},\n" + + " \"created_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"last_updated_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; + + public static final String ML_MEMORY_MESSAGE_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION + + " },\n" + + " \"properties\": {\n" + + " \"question\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"response\" : {\"type\":\"text\"},\n" + + " \"session_id\" : {\"type\":\"keyword\"},\n" + + " \"final_answer\" : {\"type\":\"boolean\"},\n" + + " \"created_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"last_updated_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; + } diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index 7e85ec411b..a162762637 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -17,8 +17,11 @@ import java.io.IOException; import java.time.Instant; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.Set; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; @@ -29,8 +32,6 @@ public class MLAgent implements ToXContentObject, Writeable { public static final String AGENT_NAME_FIELD = "name"; public static final String AGENT_TYPE_FIELD = "type"; public static final String DESCRIPTION_FIELD = "description"; - public static final String PROMPT_FIELD = "prompt"; - public static final String MODEL_ID_FIELD = "model_id"; public static final String LLM_FIELD = "llm"; public static final String TOOLS_FIELD = "tools"; public static final String PARAMETERS_FIELD = "parameters"; @@ -57,13 +58,10 @@ public class MLAgent implements ToXContentObject, Writeable { public MLAgent(String name, String type, String description, - String prompt, - String modelId, LLMSpec llm, List tools, Map parameters, MLMemorySpec memory, - String memoryId, Instant createdTime, Instant lastUpdateTime) { if (name == null) { @@ -72,13 +70,10 @@ public MLAgent(String name, this.name = name; this.type = type; this.description = description; - this.prompt = prompt; - this.modelId = modelId; this.llm = llm; this.tools = tools; this.parameters = parameters; this.memory = memory; - this.memoryId = memoryId; this.createdTime = createdTime; this.lastUpdateTime = lastUpdateTime; } @@ -87,8 +82,6 @@ public MLAgent(StreamInput input) throws IOException{ name = input.readString(); type = input.readString(); description = input.readOptionalString(); - prompt = input.readOptionalString(); - modelId = input.readString(); if (input.readBoolean()) { llm = new LLMSpec(input); } @@ -105,17 +98,23 @@ public MLAgent(StreamInput input) throws IOException{ if (input.readBoolean()) { memory = new MLMemorySpec(input); } - memoryId = input.readOptionalString(); createdTime = input.readInstant(); lastUpdateTime = input.readInstant(); + if (!"flow".equals(type)) { + Set toolNames = new HashSet<>(); + for (MLToolSpec toolSpec : tools) { + String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType()); + if (toolNames.contains(toolName)) { + throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName); + } + } + } } public void writeTo(StreamOutput out) throws IOException { out.writeString(name); out.writeString(type); out.writeOptionalString(description); - out.writeOptionalString(prompt); - out.writeString(modelId); if (llm != null) { out.writeBoolean(true); llm.writeTo(out); @@ -160,12 +159,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (description != null) { builder.field(DESCRIPTION_FIELD, description); } - if (prompt != null) { - builder.field(PROMPT_FIELD, prompt); - } - if (modelId != null) { - builder.field(MODEL_ID_FIELD, modelId); - } if (llm != null) { builder.field(LLM_FIELD, llm); } @@ -178,9 +171,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (memory != null) { builder.field(MEMORY_FIELD, memory); } - if (memoryId != null) { - builder.field(MEMORY_ID_FIELD, memoryId); - } if (createdTime != null) { builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); } @@ -195,13 +185,10 @@ public static MLAgent parse(XContentParser parser) throws IOException { String name = null; String type = null; String description = null;; - String prompt = null; - String modelId = null; LLMSpec llm = null; List tools = null; Map parameters = null; MLMemorySpec memory = null; - String memoryId = null; Instant createdTime = null; Instant lastUpdateTime = null; @@ -220,12 +207,6 @@ public static MLAgent parse(XContentParser parser) throws IOException { case DESCRIPTION_FIELD: description = parser.text(); break; - case PROMPT_FIELD: - prompt = parser.text(); - break; - case MODEL_ID_FIELD: - modelId = parser.text(); - break; case LLM_FIELD: llm = LLMSpec.parse(parser); break; @@ -242,9 +223,6 @@ public static MLAgent parse(XContentParser parser) throws IOException { case MEMORY_FIELD: memory = MLMemorySpec.parse(parser); break; - case MEMORY_ID_FIELD: - memoryId = parser.text(); - break; case CREATED_TIME_FIELD: createdTime = Instant.ofEpochMilli(parser.longValue()); break; @@ -260,13 +238,10 @@ public static MLAgent parse(XContentParser parser) throws IOException { .name(name) .type(type) .description(description) - .prompt(prompt) - .modelId(modelId) .llm(llm) .tools(tools) .parameters(parameters) .memory(memory) - .memoryId(memoryId) .createdTime(createdTime) .lastUpdateTime(lastUpdateTime) .build(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index c160d9676d..cf37ebd7a3 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -58,15 +58,15 @@ public class MLAgentExecutor implements Executable { private ClusterService clusterService; private NamedXContentRegistry xContentRegistry; private Map toolFactories; - private Map memoryMap; + private Map memoryFactoryMap; - public MLAgentExecutor(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map toolFactories, Map memoryMap) { + public MLAgentExecutor(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map toolFactories, Map memoryFactoryMap) { this.client = client; this.settings = settings; this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; this.toolFactories = toolFactories; - this.memoryMap = memoryMap; + this.memoryFactoryMap = memoryFactoryMap; } @Override @@ -130,10 +130,10 @@ public void execute(Input input, ActionListener listener) { listener.onFailure(ex); }); if ("flow".equals(mlAgent.getType())) { - MLFlowAgentRunner flowAgentExecutor = new MLFlowAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryMap); + MLFlowAgentRunner flowAgentExecutor = new MLFlowAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryFactoryMap); flowAgentExecutor.run(mlAgent, inputDataSet.getParameters(), agentActionListener); } else if ("cot".equals(mlAgent.getType())) { - MLReActAgentRunner reactAgentExecutor = new MLReActAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryMap); + MLReActAgentRunner reactAgentExecutor = new MLReActAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryFactoryMap); reactAgentExecutor.run(mlAgent, inputDataSet.getParameters(), agentActionListener); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index b1c5521ede..e4da99d35c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -42,15 +42,15 @@ public class MLFlowAgentRunner { private ClusterService clusterService; private NamedXContentRegistry xContentRegistry; private Map toolFactories; - private Map memoryMap; + private Map memoryFactoryMap; - public MLFlowAgentRunner(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map toolFactories, Map memoryMap) { + public MLFlowAgentRunner(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map toolFactories, Map memoryFactoryMap) { this.client = client; this.settings = settings; this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; this.toolFactories = toolFactories; - this.memoryMap = memoryMap; + this.memoryFactoryMap = memoryFactoryMap; } public void run(MLAgent mlAgent, Map params, ActionListener listener) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLReActAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLReActAgentRunner.java index be00ec6695..f9f5d2f8ef 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLReActAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLReActAgentRunner.java @@ -81,15 +81,15 @@ public class MLReActAgentRunner { private ClusterService clusterService; private NamedXContentRegistry xContentRegistry; private Map toolFactories; - private Map memoryMap; + private Map memoryFactoryMap; - public MLReActAgentRunner(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map toolFactories, Map memoryMap) { + public MLReActAgentRunner(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map toolFactories, Map memoryFactoryMap) { this.client = client; this.settings = settings; this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; this.toolFactories = toolFactories; - this.memoryMap = memoryMap; + this.memoryFactoryMap = memoryFactoryMap; } public void run(MLAgent mlAgent, Map params, ActionListener listener) { @@ -97,39 +97,43 @@ public void run(MLAgent mlAgent, Map params, ActionListenerwrap(r -> { //TODO: support onlyIncludeFinalAnswerInChatHistory parameters - List messageList = new ArrayList<>(); - Iterator iterator = r.getHits().iterator(); - while(iterator.hasNext()) { - SearchHit next = iterator.next(); - Map map = next.getSourceAsMap(); - String question = (String)map.get("question"); - String response = (String)map.get("response"); - messageList.add(ConversationIndexMessage.conversationIndexMessageBuilder().sessionId(sessionId).question(question).response(response).build()); - } + memoryFactoryMap.get(memoryType).create(params, ActionListener.wrap(memory->{ + if (clusterService.state().metadata().hasIndex(memory.getMemoryMessageIndexName())) { + memory.getMessages(sessionId, ActionListener.wrap(r -> { //TODO: support onlyIncludeFinalAnswerInChatHistory parameters + List messageList = new ArrayList<>(); + Iterator iterator = r.getHits().iterator(); + while(iterator.hasNext()) { + SearchHit next = iterator.next(); + Map map = next.getSourceAsMap(); + String question = (String)map.get("question"); + String response = (String)map.get("response"); + messageList.add(ConversationIndexMessage.conversationIndexMessageBuilder().sessionId(sessionId).question(question).response(response).build()); + } - StringBuilder chatHistoryBuilder = new StringBuilder(); - if (messageList.size() > 0) { - chatHistoryBuilder.append("Below is Chat History between Human and AI which sorted by time with asc order:\n"); - for (Message message : messageList) { - chatHistoryBuilder.append(message.toString()).append("\n"); + StringBuilder chatHistoryBuilder = new StringBuilder(); + if (messageList.size() > 0) { + chatHistoryBuilder.append("Below is Chat History between Human and AI which sorted by time with asc order:\n"); + for (Message message : messageList) { + chatHistoryBuilder.append(message.toString()).append("\n"); + } + params.put(CHAT_HISTORY, chatHistoryBuilder.toString()); } - params.put(CHAT_HISTORY, chatHistoryBuilder.toString()); - } + runAgent(mlAgent, params, listener, toolSpecs, memory, sessionId); + }, e-> { + log.error("Failed to get session history", e); + listener.onFailure(e); + })); + } else { runAgent(mlAgent, params, listener, toolSpecs, memory, sessionId); - }, e-> { - log.error("Failed to get session history", e); - listener.onFailure(e); - })); - } else { - runAgent(mlAgent, params, listener, toolSpecs, memory, sessionId); - } + } + }, e->{ + listener.onFailure(e); + })); + } else { runAgent(mlAgent, params, listener, toolSpecs, null, sessionId); } @@ -152,13 +156,17 @@ private void runAgent(MLAgent mlAgent, Map params, ActionListene executeParams.put(key.replace(toolSpec.getName()+".", ""), params.get(key)); } } - Tool tool = toolFactories.get(toolSpec.getName()).create(toolParams); + if (!toolFactories.containsKey(toolSpec.getType())) { + listener.onFailure(new IllegalArgumentException("No tool factory found for " + toolSpec.getType())); + return; + } + Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams); tool.setName(toolSpec.getName()); if (toolSpec.getDescription() != null) { tool.setDescription(toolSpec.getDescription()); } - String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getName()); + String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType()); tools.put(toolName, tool); toolSpecMap.put(toolName, toolSpec); } @@ -207,8 +215,9 @@ private void runReAct(LLMSpec llm, Map tools, Map entry : tools.entrySet()) { - String toolName = Optional.ofNullable(entry.getValue().getName()).orElse(entry.getValue().getName()); - inputTools.add(toolName); +// String toolName = Optional.ofNullable(entry.getValue().getName()).orElse(entry.getValue().getType()); +// String toolName = Optional.ofNullable(entry.getKey()).orElse(entry.getValue().getType()); + inputTools.add(entry.getKey()); } } diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java similarity index 76% rename from plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java rename to ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java index 85348e5230..671f4e548a 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.indices; +package org.opensearch.ml.engine.indices; import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX_MAPPING; @@ -14,6 +14,12 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_SCHEMA_VERSION; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_SCHEMA_VERSION; @@ -30,7 +36,9 @@ public enum MLIndex { TASK(ML_TASK_INDEX, false, ML_TASK_INDEX_MAPPING, ML_TASK_INDEX_SCHEMA_VERSION), CONNECTOR(ML_CONNECTOR_INDEX, false, ML_CONNECTOR_INDEX_MAPPING, ML_CONNECTOR_SCHEMA_VERSION), CONFIG(ML_CONFIG_INDEX, false, ML_CONFIG_INDEX_MAPPING, ML_CONFIG_INDEX_SCHEMA_VERSION), - AGENT(ML_AGENT_INDEX, false, ML_AGENT_INDEX_MAPPING, ML_AGENT_INDEX_SCHEMA_VERSION); + AGENT(ML_AGENT_INDEX, false, ML_AGENT_INDEX_MAPPING, ML_AGENT_INDEX_SCHEMA_VERSION), + MEMORY_META(ML_MEMORY_META_INDEX, false, ML_MEMORY_META_INDEX_MAPPING, ML_MEMORY_META_INDEX_SCHEMA_VERSION), + MEMORY_MESSAGE(ML_MEMORY_MESSAGE_INDEX, false, ML_MEMORY_MESSAGE_INDEX_MAPPING, ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION); private final String indexName; // whether we use an alias for the index diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java similarity index 96% rename from plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java rename to ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index 61fdc1350e..ca5f88be78 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.indices; +package org.opensearch.ml.engine.indices; import static org.opensearch.ml.common.CommonValue.META; import static org.opensearch.ml.common.CommonValue.SCHEMA_VERSION_FIELD; @@ -62,6 +62,14 @@ public void initMLConnectorIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.CONNECTOR, listener); } + public void initMemoryMetaIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.MEMORY_META, listener); + } + + public void initMemoryMessageIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.MEMORY_MESSAGE, listener); + } + public void initMLConfigIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.CONFIG, listener); } diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLInputDatasetHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java similarity index 98% rename from plugin/src/main/java/org/opensearch/ml/indices/MLInputDatasetHandler.java rename to ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java index 1dcf6bdf77..0858476b73 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLInputDatasetHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.indices; +package org.opensearch.ml.engine.indices; import java.util.ArrayList; import java.util.List; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java index 10fcfddd29..9234774d19 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java @@ -7,6 +7,7 @@ import lombok.Getter; import lombok.extern.log4j.Log4j2; +import org.opensearch.action.get.GetRequest; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.client.Client; @@ -19,22 +20,36 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.memory.Message; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.SortOrder; +import software.amazon.awssdk.utils.ImmutableMap; import java.io.IOException; +import java.time.Instant; +import java.util.Map; + +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; +import static org.opensearch.ml.engine.algorithms.agent.MLReActAgentRunner.SESSION_ID; @Log4j2 +@Getter public class ConversationIndexMemory implements Memory { public static final String TYPE = "conversation_index"; - @Getter - protected String indexName; + protected String memoryMetaIndexName; + protected String memoryMessageIndexName; + protected String conversationId; protected boolean retrieveFinalAnswer = true; protected final Client client; + private final MLIndicesHandler mlIndicesHandler; - public ConversationIndexMemory(Client client) { + public ConversationIndexMemory(Client client, MLIndicesHandler mlIndicesHandler, String memoryMetaIndexName, String memoryMessageIndexName, String conversationId) { this.client = client; - this.indexName = "my_sessions"; + this.mlIndicesHandler = mlIndicesHandler; + this.memoryMetaIndexName = memoryMetaIndexName; + this.memoryMessageIndexName = memoryMessageIndexName; + this.conversationId = conversationId; } @Override @@ -53,25 +68,29 @@ public void save(String id, Message message) { @Override public void save(String id, Message message, ActionListener listener) { - IndexRequest indexRequest = new IndexRequest(indexName); - try { - ConversationIndexMessage conversationIndexMessage = (ConversationIndexMessage)message; - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - conversationIndexMessage.toXContent(builder, ToXContent.EMPTY_PARAMS); - indexRequest.source(builder); - client.index(indexRequest, listener); - } catch (IOException e) { - throw new RuntimeException(e); - } + mlIndicesHandler.initMemoryMessageIndex(ActionListener.wrap(created -> { + if (created) { + IndexRequest indexRequest = new IndexRequest(memoryMessageIndexName); + ConversationIndexMessage conversationIndexMessage = (ConversationIndexMessage) message; + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + conversationIndexMessage.toXContent(builder, ToXContent.EMPTY_PARAMS); + indexRequest.source(builder); + client.index(indexRequest, listener); + } else { + listener.onFailure(new RuntimeException("Failed to create memory message index")); + } + }, e -> { + listener.onFailure(new RuntimeException("Failed to create memory message index")); + })); } @Override public void getMessages(String id, ActionListener listener) { SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(indexName); + searchRequest.indices(memoryMessageIndexName); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); sourceBuilder.size(10000); - QueryBuilder sessionIdQueryBuilder = new TermQueryBuilder("session_id", id); + QueryBuilder sessionIdQueryBuilder = new TermQueryBuilder(SESSION_ID, id); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); boolQueryBuilder.must(sessionIdQueryBuilder); @@ -95,4 +114,57 @@ public void clear() { public void remove(String id) { } + public static class Factory implements Memory.Factory { + private Client client; + private MLIndicesHandler mlIndicesHandler; + private String memoryMetaIndexName = ML_MEMORY_META_INDEX; + private String memoryMessageIndexName = ML_MEMORY_MESSAGE_INDEX; + + public void init(Client client, MLIndicesHandler mlIndicesHandler) { + this.client = client; + this.mlIndicesHandler = mlIndicesHandler; + } + + @Override + public void create(Map map, ActionListener listener) { + if (map != null) { + if (map.containsKey("memory_index_name")) { + memoryMetaIndexName = (String) map.get("memory_index_name"); + } + if (map.containsKey("memory_message_index_name")) { + memoryMessageIndexName = (String) map.get("memory_message_index_name"); + } + if (map.containsKey(SESSION_ID)) { + String conversationId = (String) map.get(SESSION_ID); + GetRequest getRequest = new GetRequest(memoryMetaIndexName).id(conversationId); + client.get(getRequest, ActionListener.wrap(r -> { + listener.onResponse(new ConversationIndexMemory(client, mlIndicesHandler, memoryMetaIndexName, memoryMessageIndexName, r.getId())); + }, e-> { + listener.onFailure(new IllegalArgumentException("Can't find conversation " + conversationId)); + })); + } else if (map.containsKey("question")) { + String question = (String) map.get("question"); + mlIndicesHandler.initMemoryMetaIndex(ActionListener.wrap(created -> { + if (created) { + IndexRequest indexRequest = new IndexRequest(memoryMetaIndexName); + indexRequest.source(ImmutableMap.of("name", question, "created_time", Instant.now().toEpochMilli())); + client.index(indexRequest, ActionListener.wrap(r-> { + listener.onResponse(new ConversationIndexMemory(client, mlIndicesHandler, memoryMetaIndexName, memoryMessageIndexName, r.getId())); + }, e-> { + listener.onFailure(e); + })); + } else { + listener.onFailure(new RuntimeException("Failed to create memory meta index")); + } + }, e -> { + listener.onFailure(new RuntimeException("Failed to create memory meta index")); + })); + } else { + listener.onFailure(new IllegalArgumentException("Invalid input parameter. Must set conversation id or question")); + } + } else { + listener.onFailure(new IllegalArgumentException("Invalid input parameter for creating ConversationIndexMemory")); + } + } + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index 2485881dca..7f82ed165a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -26,14 +26,14 @@ * This tool supports running any Agent. */ @Log4j2 -@ToolAnnotation(AgentTool.NAME) +@ToolAnnotation(AgentTool.TYPE) public class AgentTool implements Tool { - public static final String NAME = "AgentTool"; + public static final String TYPE = "AgentTool"; private final Client client; private String agentId; @Setter @Getter - private String alias; + private String name = TYPE; private static String DEFAULT_DESCRIPTION = "Use this tool to run any agent."; @Getter @Setter @@ -60,7 +60,7 @@ public void run(Map parameters, ActionListener listener) @Override public String getType() { - return null; + return TYPE; } @Override @@ -68,16 +68,6 @@ public String getVersion() { return null; } - @Override - public String getName() { - return AgentTool.NAME; - } - - @Override - public void setName(String s) { - - } - @Override public boolean validate(Map parameters) { return true; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java index adacf88a8c..bfb6194631 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java @@ -48,20 +48,18 @@ import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; import static org.opensearch.ml.common.utils.StringUtils.gson; -@ToolAnnotation(CatIndexTool.NAME) +@ToolAnnotation(CatIndexTool.TYPE) public class CatIndexTool implements Tool { - public static final String NAME = "CatIndexTool"; + public static final String TYPE = "CatIndexTool"; private static final String DEFAULT_DESCRIPTION = "Use this tool to get index information."; @Setter @Getter - private String name = CatIndexTool.NAME; + private String name = CatIndexTool.TYPE; @Getter @Setter private String description = DEFAULT_DESCRIPTION; @Getter - private String type; - @Getter private String version; private Client client; @@ -176,6 +174,11 @@ public void onFailure(final Exception e) { ); } + @Override + public String getType() { + return TYPE; + } + /** * We're using the Get Settings API here to resolve the authorized indices for the user. * This is because the Cluster State and Cluster Health APIs do not filter output based diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index bf2a79b38f..549362fe89 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -30,12 +30,12 @@ * This tool supports running any ml-commons model. */ @Log4j2 -@ToolAnnotation(MLModelTool.NAME) +@ToolAnnotation(MLModelTool.TYPE) public class MLModelTool implements Tool { - public static final String NAME = "MLModelTool"; + public static final String TYPE = "MLModelTool"; @Setter @Getter - private String alias; + private String name = TYPE; private static String DEFAULT_DESCRIPTION = "Use this tool to run any model."; @Getter @Setter private String description = DEFAULT_DESCRIPTION; @@ -80,7 +80,7 @@ public void run(Map parameters, ActionListener listener) @Override public String getType() { - return null; + return TYPE; } @Override @@ -88,17 +88,6 @@ public String getVersion() { return null; } - - @Override - public String getName() { - return MLModelTool.NAME; - } - - @Override - public void setName(String s) { - - } - @Override public boolean validate(Map parameters) { if (parameters == null || parameters.size() == 0) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java index be14d9647c..aeb8d0be33 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java @@ -19,12 +19,12 @@ import static org.opensearch.ml.engine.utils.ScriptUtils.executeScript; -@ToolAnnotation(MathTool.NAME) +@ToolAnnotation(MathTool.TYPE) public class MathTool implements Tool { - public static final String NAME = "MathTool"; + public static final String TYPE = "MathTool"; @Setter @Getter - private String alias; + private String name = TYPE; @Setter private ScriptService scriptService; @@ -58,7 +58,7 @@ public void run(Map parameters, ActionListener listener) @Override public String getType() { - return null; + return TYPE; } @Override @@ -66,16 +66,6 @@ public String getVersion() { return null; } - @Override - public String getName() { - return MathTool.NAME; - } - - @Override - public void setName(String s) { - - } - @Override public boolean validate(Map parameters) { try { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java index e5898a00d2..e4ed766b41 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java @@ -24,12 +24,12 @@ @Log4j2 -@ToolAnnotation(PainlessScriptTool.NAME) +@ToolAnnotation(PainlessScriptTool.TYPE) public class PainlessScriptTool implements Tool { - public static final String NAME = "PainlessScriptTool"; + public static final String TYPE = "PainlessScriptTool"; @Setter @Getter - private String alias; + private String name = TYPE; private static String DEFAULT_DESCRIPTION = "Use this tool to get index information."; @Getter @Setter private String description = DEFAULT_DESCRIPTION; @@ -66,7 +66,7 @@ public void run(Map parameters, ActionListener listener) @Override public String getType() { - return null; + return TYPE; } @Override @@ -74,17 +74,6 @@ public String getVersion() { return null; } - - @Override - public String getName() { - return PainlessScriptTool.NAME; - } - - @Override - public void setName(String s) { - - } - @Override public boolean validate(Map parameters) { if (parameters == null || parameters.size() == 0) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java index 68c5e7a320..ce5d1ab0f8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java @@ -34,11 +34,11 @@ * This tool supports neural search with embedding models and knn index. */ @Log4j2 -@ToolAnnotation(VectorDBTool.NAME) +@ToolAnnotation(VectorDBTool.TYPE) public class VectorDBTool implements Tool { - public static final String NAME = "VectorDBTool"; + public static final String TYPE = "VectorDBTool"; @Setter @Getter - private String alias; + private String name = TYPE; private static String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; @Getter @Setter private String description = DEFAULT_DESCRIPTION; @@ -113,7 +113,7 @@ public void run(Map parameters, ActionListener listener) @Override public String getType() { - return null; + return TYPE; } @Override @@ -121,16 +121,6 @@ public String getVersion() { return null; } - @Override - public String getName() { - return NAME; - } - - @Override - public void setName(String s) { - - } - @Override public boolean validate(Map parameters) { if (parameters == null || parameters.size() == 0) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java index b696ecec73..12353177db 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java @@ -238,7 +238,7 @@ public void testRunAsyncIndexStats() throws Exception { @Test public void testTool() { Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); - assertEquals(CatIndexTool.NAME, tool.getName()); + assertEquals(CatIndexTool.TYPE, tool.getName()); assertTrue(tool.validate(indicesParams)); assertTrue(tool.validate(otherParams)); assertFalse(tool.validate(emptyParams)); diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java index c49026c637..346967fa9f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java @@ -28,9 +28,9 @@ import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLStats; diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index e40bacc207..4cadcc936a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -37,8 +37,8 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.exceptions.MetaDataException; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.tasks.Task; diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java index 94d4b5a8a7..4e29db680d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java @@ -17,8 +17,8 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index a5f12876eb..5abb3a7497 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -50,9 +50,9 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java index 1227703a21..683536e21e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java @@ -34,8 +34,8 @@ import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkInput; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkResponse; import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.RestActionUtils; import lombok.extern.log4j.Log4j2; diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java index c3012fdade..327ae2ddae 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java @@ -15,7 +15,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 89c8776aa1..aaa85f76fe 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -42,7 +42,7 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.search.SearchHit; diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index 94cbcf5364..fa758b01c5 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -31,8 +31,8 @@ import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 1c6ca7ec57..e72046a5c0 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -110,8 +110,8 @@ import org.opensearch.ml.engine.MLExecutable; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.Predictable; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.engine.utils.FileUtils; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.profile.MLModelProfile; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableSet; diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 0b12d607a4..0b411ab013 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -135,7 +135,8 @@ import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.engine.memory.ConversationBufferWindowMemory; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.tools.AgentTool; import org.opensearch.ml.engine.tools.CatIndexTool; @@ -145,8 +146,6 @@ import org.opensearch.ml.engine.tools.VectorDBTool; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; -import org.opensearch.ml.indices.MLInputDatasetHandler; import org.opensearch.ml.memory.ConversationalMemoryHandler; import org.opensearch.ml.memory.action.conversation.CreateConversationAction; import org.opensearch.ml.memory.action.conversation.CreateConversationTransportAction; @@ -475,21 +474,30 @@ public Collection createComponents( AgentTool.Factory.getInstance().init(client); CatIndexTool.Factory.getInstance().init(client, clusterService); PainlessScriptTool.Factory.getInstance().init(client, scriptService); - toolFactories.put(MLModelTool.NAME, MLModelTool.Factory.getInstance()); - toolFactories.put(MathTool.NAME, MathTool.Factory.getInstance()); - toolFactories.put(VectorDBTool.NAME, VectorDBTool.Factory.getInstance()); - toolFactories.put(AgentTool.NAME, AgentTool.Factory.getInstance()); - toolFactories.put(CatIndexTool.NAME, CatIndexTool.Factory.getInstance()); - toolFactories.put(PainlessScriptTool.NAME, PainlessScriptTool.Factory.getInstance()); + toolFactories.put(MLModelTool.TYPE, MLModelTool.Factory.getInstance()); + toolFactories.put(MathTool.TYPE, MathTool.Factory.getInstance()); + toolFactories.put(VectorDBTool.TYPE, VectorDBTool.Factory.getInstance()); + toolFactories.put(AgentTool.TYPE, AgentTool.Factory.getInstance()); + toolFactories.put(CatIndexTool.TYPE, CatIndexTool.Factory.getInstance()); + toolFactories.put(PainlessScriptTool.TYPE, PainlessScriptTool.Factory.getInstance()); if (externalToolFactories != null) { toolFactories.putAll(externalToolFactories); } - Map memoryMap = new HashMap<>(); - memoryMap.put(ConversationBufferWindowMemory.TYPE, new ConversationBufferWindowMemory()); - memoryMap.put(ConversationIndexMemory.TYPE, new ConversationIndexMemory(client)); - MLAgentExecutor agentExecutor = new MLAgentExecutor(client, settings, clusterService, xContentRegistry, toolFactories, memoryMap); + Map memoryFactoryMap = new HashMap<>(); + ConversationIndexMemory.Factory conversationIndexMemoryFactory = new ConversationIndexMemory.Factory(); + conversationIndexMemoryFactory.init(client, mlIndicesHandler); + memoryFactoryMap.put(ConversationIndexMemory.TYPE, conversationIndexMemoryFactory); + + MLAgentExecutor agentExecutor = new MLAgentExecutor( + client, + settings, + clusterService, + xContentRegistry, + toolFactories, + memoryFactoryMap + ); MLEngineClassLoader.register(FunctionName.LOCAL_SAMPLE_CALCULATOR, localSampleCalculator); MLEngineClassLoader.register(FunctionName.AGENT, agentExecutor); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index 0c76e6f684..fb526e6e55 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -20,7 +20,7 @@ import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.engine.MLEngine; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 87e96a0d67..41b665a416 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -48,7 +48,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.Predictable; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.stats.ActionName; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java index ca4a192f16..23bb1a13bf 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java @@ -42,7 +42,7 @@ import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableSet; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java index a69f5ea0f4..9ae4d9991e 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java @@ -29,7 +29,7 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; import org.opensearch.ml.engine.MLEngine; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java index cdb6a4dc13..81a50ac760 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java @@ -37,8 +37,8 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.engine.MLEngine; -import org.opensearch.ml.indices.MLIndicesHandler; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index 7ac8662b18..3ffbfdfd6f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -41,8 +41,8 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java index 269ac30d95..f54f034b64 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java @@ -29,8 +29,8 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index e50a352d50..77c2127c66 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -54,9 +54,9 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java index 6fba3efe59..292183f8ed 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java @@ -39,8 +39,8 @@ import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkInput; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index 728dc03f5f..cd5a71ea04 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -70,7 +70,7 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableSet; import org.opensearch.ml.utils.TestHelper; diff --git a/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java b/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java index 9acb84633a..3904097aae 100644 --- a/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java @@ -31,6 +31,7 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.test.OpenSearchIntegTestCase; @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) diff --git a/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java b/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java index 5ec2ab686c..93c1bd2603 100644 --- a/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java @@ -36,6 +36,7 @@ import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index f7eb759026..fe22a1bcb9 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -35,8 +35,8 @@ import org.opensearch.ml.action.model_group.TransportRegisterModelGroupAction; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.TestHelper; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 1adb97f93a..fe9ee1b717 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -99,7 +99,7 @@ import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableSet; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 11e9bc3441..9011746797 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -41,7 +41,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 1dd5423649..b6cbbb496b 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -55,7 +55,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.stats.MLNodeLevelStat; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java index 7d0d6d11c0..4714a6c43b 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java @@ -39,7 +39,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index 94d256a4bd..b779c186ad 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -49,7 +49,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index d99f1e4bbc..a21abe900e 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -52,8 +52,8 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLIndicesHandler; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; diff --git a/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java index 497a7c4229..78141b1781 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java @@ -28,7 +28,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.threadpool.ThreadPool; public class MockHelper { diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java b/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java index 98136c12b7..61e695f6a0 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java +++ b/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java @@ -7,6 +7,8 @@ import org.opensearch.core.action.ActionListener; +import java.util.Map; + /** * A general memory interface. * @param @@ -46,4 +48,14 @@ default void getMessages(String id, ActionListener listener){} * @param id memory id */ void remove(String id); + + interface Factory { + /** + * Create an instance of this Memory. + * + * @param params Parameters for the memory + * @param listener Action listern for the memory creation action + */ + void create(Map params, ActionListener listener); + } }