From 9b10b2384fc37385452c83667597ef3f5c13a282 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Thu, 12 Oct 2023 08:43:31 -0700 Subject: [PATCH] Adds Register and Deploy Model Step for remote model (#52) * Initial UploadModel integration Signed-off-by: Owais Kazi * Implemented Register Model Step Signed-off-by: Owais Kazi * Integrated register for remote model Signed-off-by: Owais Kazi * Integrated deploy model Signed-off-by: Owais Kazi * Separated Register and Deploy Steps Signed-off-by: Owais Kazi * Added tests Signed-off-by: Owais Kazi * Added NodeClient Signed-off-by: Owais Kazi * Added javadocs Signed-off-by: Owais Kazi * Addressed PR comments Signed-off-by: Owais Kazi * Addressed PR comments Signed-off-by: Owais Kazi * Addressed PR comments - 2 Signed-off-by: Owais Kazi * Fixed test failure Signed-off-by: Owais Kazi * Addressed PR comments Signed-off-by: Owais Kazi --------- Signed-off-by: Owais Kazi --- build.gradle | 3 + gradle.properties | 38 ++++ src/main/java/demo/DemoWorkflowStep.java | 2 +- src/main/java/demo/TemplateParseDemo.java | 1 - .../flowframework/client/MLClient.java | 8 +- .../flowframework/common/CommonValue.java | 9 + .../workflow/DeployModelStep.java | 81 ++++++++ .../flowframework/workflow/GetTask.java | 65 ++++++ .../workflow/RegisterModelStep.java | 152 ++++++++++++++ .../workflow/WorkflowStepFactory.java | 3 + .../FlowFrameworkPluginTests.java | 12 +- .../workflow/DeployModelStepTests.java | 82 ++++++++ .../workflow/RegisterModelStepTests.java | 108 ++++++++++ .../workflow/WorkflowProcessSorterTests.java | 1 + src/test/resources/template/demo.json | 33 +++ .../resources/template/finaltemplate.json | 190 +++++++++--------- 16 files changed, 686 insertions(+), 102 deletions(-) create mode 100644 gradle.properties create mode 100644 src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java create mode 100644 src/main/java/org/opensearch/flowframework/workflow/GetTask.java create mode 100644 src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java diff --git a/build.gradle b/build.gradle index 6265e5bd9..36e2a6605 100644 --- a/build.gradle +++ b/build.gradle @@ -112,6 +112,9 @@ publishing { allprojects { group = opensearch_group version = "${opensearch_build}" +} + +java { targetCompatibility = JavaVersion.VERSION_11 sourceCompatibility = JavaVersion.VERSION_11 } diff --git a/gradle.properties b/gradle.properties new file mode 100644 index 000000000..200c21212 --- /dev/null +++ b/gradle.properties @@ -0,0 +1,38 @@ +# +# SPDX-License-Identifier: Apache-2.0 +# +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# +# Modifications Copyright OpenSearch Contributors. See +# GitHub history for details. +# + + +# Enable build caching +org.gradle.caching=true +org.gradle.warning.mode=none +org.gradle.parallel=true +# Workaround for https://github.com/diffplug/spotless/issues/834 +org.gradle.jvmargs=-Xmx3g -XX:+HeapDumpOnOutOfMemoryError -Xss2m \ + --add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED +options.forkOptions.memoryMaximumSize=2g + +# Disable duplicate project id detection +# See https://docs.gradle.org/current/userguide/upgrading_version_6.html#duplicate_project_names_may_cause_publication_to_fail +systemProp.org.gradle.dependency.duplicate.project.detection=false + +# Enforce the build to fail on deprecated gradle api usage +systemProp.org.gradle.warning.mode=fail + +# forcing to use TLS1.2 to avoid failure in vault +# see https://github.com/hashicorp/vault/issues/8750#issuecomment-631236121 +systemProp.jdk.tls.client.protocols=TLSv1.2 + +# jvm args for faster test execution by default +systemProp.tests.jvm.argline=-XX:TieredStopAtLevel=1 -XX:ReservedCodeCacheSize=64m diff --git a/src/main/java/demo/DemoWorkflowStep.java b/src/main/java/demo/DemoWorkflowStep.java index 037d9b6f6..267a8c8ab 100644 --- a/src/main/java/demo/DemoWorkflowStep.java +++ b/src/main/java/demo/DemoWorkflowStep.java @@ -37,7 +37,7 @@ public CompletableFuture execute(List data) { CompletableFuture.runAsync(() -> { try { Thread.sleep(this.delay); - future.complete(null); + future.complete(WorkflowData.EMPTY); } catch (InterruptedException e) { future.completeExceptionally(e); } diff --git a/src/main/java/demo/TemplateParseDemo.java b/src/main/java/demo/TemplateParseDemo.java index e9bddb749..a2d0db443 100644 --- a/src/main/java/demo/TemplateParseDemo.java +++ b/src/main/java/demo/TemplateParseDemo.java @@ -55,7 +55,6 @@ public static void main(String[] args) throws IOException { } ClusterService clusterService = new ClusterService(null, null, null); Client client = new NodeClient(null, null); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); diff --git a/src/main/java/org/opensearch/flowframework/client/MLClient.java b/src/main/java/org/opensearch/flowframework/client/MLClient.java index a1ef7d61e..977e24588 100644 --- a/src/main/java/org/opensearch/flowframework/client/MLClient.java +++ b/src/main/java/org/opensearch/flowframework/client/MLClient.java @@ -8,7 +8,7 @@ */ package org.opensearch.flowframework.client; -import org.opensearch.client.node.NodeClient; +import org.opensearch.client.Client; import org.opensearch.ml.client.MachineLearningNodeClient; /** @@ -22,12 +22,12 @@ private MLClient() {} /** * Creates machine learning client. * - * @param nodeClient node client of OpenSearch. + * @param client client of OpenSearch. * @return machine learning client from ml-commons. */ - public static MachineLearningNodeClient createMLClient(NodeClient nodeClient) { + public static MachineLearningNodeClient createMLClient(Client client) { if (INSTANCE == null) { - INSTANCE = new MachineLearningNodeClient(nodeClient); + INSTANCE = new MachineLearningNodeClient(client); } return INSTANCE; } diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index a8fdf2929..0bf8ae890 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -19,4 +19,13 @@ public class CommonValue { public static final String GLOBAL_CONTEXT_INDEX = ".plugins-ai-global-context"; public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json"; public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1; + public static final String MODEL_ID = "model_id"; + public static final String FUNCTION_NAME = "function_name"; + public static final String MODEL_NAME = "name"; + public static final String MODEL_VERSION = "model_version"; + public static final String MODEL_GROUP_ID = "model_group_id"; + public static final String DESCRIPTION = "description"; + public static final String CONNECTOR_ID = "connector_id"; + public static final String MODEL_FORMAT = "model_format"; + public static final String MODEL_CONFIG = "model_config"; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java new file mode 100644 index 000000000..e4c9b1a14 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.client.MLClient; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; + +/** + * Step to deploy a model + */ +public class DeployModelStep implements WorkflowStep { + private static final Logger logger = LogManager.getLogger(DeployModelStep.class); + + private Client client; + static final String NAME = "deploy_model"; + + /** + * Instantiate this class + * @param client client to instantiate MLClient + */ + public DeployModelStep(Client client) { + this.client = client; + } + + @Override + public CompletableFuture execute(List data) { + + CompletableFuture deployModelFuture = new CompletableFuture<>(); + + MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); + + ActionListener actionListener = new ActionListener<>() { + @Override + public void onResponse(MLDeployModelResponse mlDeployModelResponse) { + logger.info("Model deployment state {}", mlDeployModelResponse.getStatus()); + deployModelFuture.complete( + new WorkflowData(Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus()))) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Model deployment failed"); + deployModelFuture.completeExceptionally(e); + } + }; + + String modelId = null; + + for (WorkflowData workflowData : data) { + if (workflowData.getContent().containsKey(MODEL_ID)) { + modelId = (String) workflowData.getContent().get(MODEL_ID); + break; + } + } + machineLearningNodeClient.deploy(modelId, actionListener); + return deployModelFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetTask.java b/src/main/java/org/opensearch/flowframework/workflow/GetTask.java new file mode 100644 index 000000000..a3d1caa4e --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/GetTask.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.transport.task.MLTaskGetResponse; + +/** + * Step to get modelID of a registered local model + */ +@SuppressForbidden(reason = "This class is for the future work of registering local model") +public class GetTask { + + private static final Logger logger = LogManager.getLogger(GetTask.class); + private MachineLearningNodeClient machineLearningNodeClient; + private String taskId; + + /** + * Instantiate this class + * @param machineLearningNodeClient client to instantiate ml-commons APIs + * @param taskId taskID of the model + */ + public GetTask(MachineLearningNodeClient machineLearningNodeClient, String taskId) { + this.machineLearningNodeClient = machineLearningNodeClient; + this.taskId = taskId; + } + + /** + * Invokes get task API of ml-commons + */ + public void getTask() { + + ActionListener actionListener = new ActionListener<>() { + @Override + public void onResponse(MLTask mlTask) { + if (mlTask.getState() == MLTaskState.COMPLETED) { + logger.info("Model registration successful"); + MLTaskGetResponse response = MLTaskGetResponse.builder().mlTask(mlTask).build(); + logger.info("Response from task {}", response); + } + } + + @Override + public void onFailure(Exception e) { + logger.error("Model registration failed"); + } + }; + + machineLearningNodeClient.getTask(taskId, actionListener); + + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java new file mode 100644 index 000000000..b97c56d57 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -0,0 +1,152 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.client.MLClient; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION; +import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; +import static org.opensearch.flowframework.common.CommonValue.MODEL_CONFIG; +import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; +import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; +import static org.opensearch.flowframework.common.CommonValue.MODEL_NAME; +import static org.opensearch.flowframework.common.CommonValue.MODEL_VERSION; + +/** + * Step to register a remote model + */ +public class RegisterModelStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); + + private Client client; + + static final String NAME = "register_model"; + + /** + * Instantiate this class + * @param client client to instantiate MLClient + */ + public RegisterModelStep(Client client) { + this.client = client; + } + + @Override + public CompletableFuture execute(List data) { + + CompletableFuture registerModelFuture = new CompletableFuture<>(); + + MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); + + ActionListener actionListener = new ActionListener<>() { + @Override + public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { + logger.info("Model registration successful"); + registerModelFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry("model_id", mlRegisterModelResponse.getModelId()), + Map.entry("register_model_status", mlRegisterModelResponse.getStatus()) + ) + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to register model"); + registerModelFuture.completeExceptionally(new IOException("Failed to register model ")); + } + }; + + FunctionName functionName = null; + String modelName = null; + String modelVersion = null; + String modelGroupId = null; + String connectorId = null; + String description = null; + MLModelFormat modelFormat = null; + MLModelConfig modelConfig = null; + + for (WorkflowData workflowData : data) { + Map content = workflowData.getContent(); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case FUNCTION_NAME: + functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); + break; + case MODEL_NAME: + modelName = (String) content.get(MODEL_NAME); + break; + case MODEL_VERSION: + modelVersion = (String) content.get(MODEL_VERSION); + break; + case MODEL_GROUP_ID: + modelGroupId = (String) content.get(MODEL_GROUP_ID); + break; + case MODEL_FORMAT: + modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT)); + break; + case MODEL_CONFIG: + modelConfig = (MLModelConfig) content.get(MODEL_CONFIG); + break; + case DESCRIPTION: + description = (String) content.get(DESCRIPTION); + break; + case CONNECTOR_ID: + connectorId = (String) content.get(CONNECTOR_ID); + break; + default: + break; + + } + } + } + + if (Stream.of(functionName, modelName, description, connectorId).allMatch(x -> x != null)) { + + // TODO: Add model Config and type cast correctly + MLRegisterModelInput mlInput = MLRegisterModelInput.builder() + .functionName(functionName) + .modelName(modelName) + .description(description) + .connectorId(connectorId) + .build(); + + machineLearningNodeClient.register(mlInput, actionListener); + } + + return registerModelFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 73468f5f6..fdb82ef0b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -31,6 +31,7 @@ public class WorkflowStepFactory { * @param clusterService The OpenSearch cluster service * @param client The OpenSearch client steps can use */ + public WorkflowStepFactory(ClusterService clusterService, Client client) { populateMap(clusterService, client); } @@ -38,6 +39,8 @@ public WorkflowStepFactory(ClusterService clusterService, Client client) { private void populateMap(ClusterService clusterService, Client client) { stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); + stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(client)); + stepMap.put(DeployModelStep.NAME, new DeployModelStep(client)); // TODO: These are from the demo class as placeholders, remove when demos are deleted stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index d211e3928..ea8a3b520 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -10,6 +10,8 @@ import org.opensearch.client.AdminClient; import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.client.node.NodeClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -23,13 +25,21 @@ public class FlowFrameworkPluginTests extends OpenSearchTestCase { private Client client; + private NodeClient nodeClient; + + private AdminClient adminClient; + + private ClusterAdminClient clusterAdminClient; private ThreadPool threadPool; @Override public void setUp() throws Exception { super.setUp(); client = mock(Client.class); - when(client.admin()).thenReturn(mock(AdminClient.class)); + adminClient = mock(AdminClient.class); + clusterAdminClient = mock(ClusterAdminClient.class); + when(client.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); threadPool = new TestThreadPool(FlowFrameworkPluginTests.class.getName()); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java new file mode 100644 index 000000000..87db208c2 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.client.NoOpNodeClient; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class DeployModelStepTests extends OpenSearchTestCase { + + private WorkflowData inputData = WorkflowData.EMPTY; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private NodeClient nodeClient; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId"))); + + MockitoAnnotations.openMocks(this); + + nodeClient = new NoOpNodeClient("xyz"); + + } + + public void testDeployModel() { + + String taskId = "taskId"; + String status = MLTaskState.CREATED.name(); + MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; + + DeployModelStep deployModel = new DeployModelStep(nodeClient); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + + CompletableFuture future = deployModel.execute(List.of(inputData)); + + // TODO: Find a way to verify the below + // verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java new file mode 100644 index 000000000..b1a2b2fc0 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.client.NoOpNodeClient; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class RegisterModelStepTests extends OpenSearchTestCase { + private WorkflowData inputData = WorkflowData.EMPTY; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private NodeClient nodeClient; + + @Mock + ActionListener registerModelActionListener; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData( + Map.ofEntries( + Map.entry("function_name", "remote"), + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry("connector_id", "abcdefg") + ) + ); + + nodeClient = new NoOpNodeClient("xyz"); + } + + public void testRegisterModel() throws ExecutionException, InterruptedException { + + String taskId = "abcd"; + String modelId = "efgh"; + String status = MLTaskState.CREATED.name(); + MLRegisterModelInput mlInput = MLRegisterModelInput.builder() + .functionName(FunctionName.from("REMOTE")) + .modelName("testModelName") + .description("description") + .connectorId("abcdefgh") + .build(); + + RegisterModelStep registerModelStep = new RegisterModelStep(nodeClient); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = registerModelStep.execute(List.of(inputData)); + + // TODO: Find a way to verify the below + // verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + + } + +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index eab29121d..e8ada0e15 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -60,6 +60,7 @@ public static void setup() { AdminClient adminClient = mock(AdminClient.class); ClusterService clusterService = mock(ClusterService.class); Client client = mock(Client.class); + when(client.admin()).thenReturn(adminClient); testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName()); diff --git a/src/test/resources/template/demo.json b/src/test/resources/template/demo.json index e27158bff..8719bf2fe 100644 --- a/src/test/resources/template/demo.json +++ b/src/test/resources/template/demo.json @@ -11,6 +11,10 @@ "id": "fetch_model", "type": "demo_delay_3" }, + { + "id": "create_index", + "type": "demo_delay_3" + }, { "id": "create_ingest_pipeline", "type": "demo_delay_3" @@ -22,11 +26,32 @@ { "id": "create_neural_search_index", "type": "demo_delay_3" + }, + { + "id": "register_model", + "type": "demo_delay_3", + "inputs": { + "name": "openAI-gpt-3.5-turbo", + "function_name": "remote", + "description": "test model", + "connector_id": "uDna54oB76l1MtYJF84U" + } + }, + { + "id": "deploy_model", + "type": "demo_delay_3", + "inputs": { + "model_id": "abc" + } } ], "edges": [ { "source": "fetch_model", + "dest": "create_index" + }, + { + "source": "create_index", "dest": "create_ingest_pipeline" }, { @@ -40,6 +65,14 @@ { "source": "create_search_pipeline", "dest": "create_neural_search_index" + }, + { + "source": "create_neural_search_index", + "dest": "register_model" + }, + { + "source": "register_model", + "dest": "deploy_model" } ] } diff --git a/src/test/resources/template/finaltemplate.json b/src/test/resources/template/finaltemplate.json index fe1a57e36..a950f069f 100644 --- a/src/test/resources/template/finaltemplate.json +++ b/src/test/resources/template/finaltemplate.json @@ -1,101 +1,101 @@ { - "name": "semantic-search", - "description": "My semantic search use case", - "use_case": "SEMANTIC_SEARCH", - "operations": [ - "PROVISION", - "INGEST", - "QUERY" - ], - "version": { - "template": "1.0.0", - "compatibility": [ - "2.9.0", - "3.0.0" - ] + "name": "semantic-search", + "description": "My semantic search use case", + "use_case": "SEMANTIC_SEARCH", + "operations": [ + "PROVISION", + "INGEST", + "QUERY" + ], + "version": { + "template": "1.0.0", + "compatibility": [ + "2.9.0", + "3.0.0" + ] + }, + "user_inputs": { + "index_name": "my-knn-index", + "index_settings": {} + }, + "workflows": { + "provision": { + "nodes": [{ + "id": "create_index", + "type": "create_index", + "inputs": { + "name": "user_inputs.index_name", + "settings": "user_inputs.index_settings", + "node_timeout": "10s" + } + }, + { + "id": "create_ingest_pipeline", + "type": "create_ingest_pipeline", + "inputs": { + "name": "my-ingest-pipeline", + "description": "some description", + "processors": [{ + "type": "text_embedding", + "params": { + "model_id": "my-existing-model-id", + "input_field": "text_passage", + "output_field": "text_embedding" + } + }], + "node_timeout": "10s" + } + } + ], + "edges": [{ + "source": "create_index", + "dest": "create_ingest_pipeline" + }] }, - "user_inputs": { - "index_name": "my-knn-index", - "index_settings": {} + "ingest": { + "user_params": { + "document": "doc" + }, + "nodes": [{ + "id": "ingest_index", + "type": "ingest_index", + "inputs": { + "index": "user_inputs.index_name", + "ingest_pipeline": "my-ingest-pipeline", + "document": "user_params.document", + "node_timeout": "10s" + } + }] }, - "workflows": { - "provision": { - "nodes": [{ - "id": "create_index", - "type": "create_index", - "inputs": { - "name": "user_inputs.index_name", - "settings": "user_inputs.index_settings", - "node_timeout": "10s" - } - }, - { - "id": "create_ingest_pipeline", - "type": "create_ingest_pipeline", - "inputs": { - "name": "my-ingest-pipeline", - "description": "some description", - "processors": [{ - "type": "text_embedding", - "params": { - "model_id": "my-existing-model-id", - "input_field": "text_passage", - "output_field": "text_embedding" - } - }], - "node_timeout": "10s" - } - } - ], - "edges": [{ - "source": "create_index", - "dest": "create_ingest_pipeline" - }] - }, - "ingest": { - "user_params": { - "document": "doc" - }, - "nodes": [{ - "id": "ingest_index", - "type": "ingest_index", - "inputs": { - "index": "user_inputs.index_name", - "ingest_pipeline": "my-ingest-pipeline", - "document": "user_params.document", - "node_timeout": "10s" - } - }] - }, - "query": { - "user_params": { - "plaintext": "string" - }, - "nodes": [{ - "id": "transform_query", - "type": "transform_query", - "inputs": { - "template": "neural-search-template-1", - "plaintext": "user_params.plaintext", - "node_timeout": "10s" - } - }, - { - "id": "query_index", - "type": "query_index", - "inputs": { - "index": "user_inputs.index_name", - "query": "{{output-from-prev-step}}.query", - "search_request_processors": [], - "search_response_processors": [], - "node_timeout": "10s" - } - } - ], - "edges": [{ - "source": "transform_query", - "dest": "query_index" - }] + "query": { + "user_params": { + "plaintext": "string" + }, + "nodes": [{ + "id": "transform_query", + "type": "transform_query", + "inputs": { + "template": "neural-search-template-1", + "plaintext": "user_params.plaintext", + "node_timeout": "10s" + } + }, + { + "id": "query_index", + "type": "query_index", + "inputs": { + "index": "user_inputs.index_name", + "query": "{{output-from-prev-step}}.query", + "search_request_processors": [], + "search_response_processors": [], + "node_timeout": "10s" + } } + ], + "edges": [{ + "source": "transform_query", + "dest": "query_index" + }] } + } }