From e2b640e47a56506e31f1bba499043fd0a1c9b1b6 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Thu, 19 Oct 2023 14:51:56 -0700 Subject: [PATCH] Added more tests and updated MLClient initialization Signed-off-by: Owais Kazi --- .../flowframework/client/MLClient.java | 34 ------------------- .../workflow/CreateConnectorStep.java | 13 +++++-- .../workflow/DeployModelStep.java | 6 ++-- .../workflow/RegisterModelStep.java | 5 +-- .../workflow/WorkflowStepFactory.java | 1 + .../workflow/CreateConnectorStepTests.java | 28 ++++++++++++++- .../workflow/DeployModelStepTests.java | 28 +++++++++++++-- .../workflow/RegisterModelStepTests.java | 25 ++++++++++++++ 8 files changed, 97 insertions(+), 43 deletions(-) delete mode 100644 src/main/java/org/opensearch/flowframework/client/MLClient.java diff --git a/src/main/java/org/opensearch/flowframework/client/MLClient.java b/src/main/java/org/opensearch/flowframework/client/MLClient.java deleted file mode 100644 index 977e24588..000000000 --- a/src/main/java/org/opensearch/flowframework/client/MLClient.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.client; - -import org.opensearch.client.Client; -import org.opensearch.ml.client.MachineLearningNodeClient; - -/** - Class to initiate an instance of MLClient - */ -public class MLClient { - private static MachineLearningNodeClient INSTANCE; - - private MLClient() {} - - /** - * Creates machine learning client. - * - * @param client client of OpenSearch. - * @return machine learning client from ml-commons. - */ - public static MachineLearningNodeClient createMLClient(Client client) { - if (INSTANCE == null) { - INSTANCE = new MachineLearningNodeClient(client); - } - return INSTANCE; - } -} diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index afae055bc..5e040baaf 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -29,8 +29,17 @@ import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; -import static org.opensearch.flowframework.common.CommonValue.*; - +import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; + +/** + * Step to create a connector for a remote model + */ public class CreateConnectorStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class); diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 07558fe0c..9dd183e37 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -11,6 +11,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -53,8 +55,8 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { @Override public void onFailure(Exception e) { - logger.error("Model deployment failed"); - deployModelFuture.completeExceptionally(e); + logger.error("Failed to deploy model"); + deployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); } }; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index bdbda66e4..9a430c6f5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -11,6 +11,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.model.MLModelConfig; @@ -18,7 +20,6 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.io.IOException; import java.util.List; import java.util.Locale; import java.util.Map; @@ -76,7 +77,7 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { @Override public void onFailure(Exception e) { logger.error("Failed to register model"); - registerModelFuture.completeExceptionally(new IOException("Failed to register model")); + registerModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); } }; diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index be52a5fcd..5aabd679f 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -31,6 +31,7 @@ public class WorkflowStepFactory { * * @param clusterService The OpenSearch cluster service * @param client The OpenSearch client steps can use + * @param mlClient Machine Learning client to perform ml operations */ public WorkflowStepFactory(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 65329661f..9fd125a1a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -9,6 +9,8 @@ package org.opensearch.flowframework.workflow; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; @@ -18,11 +20,13 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; @@ -59,7 +63,7 @@ public void setUp() throws Exception { } - public void testCreateConnector() throws IOException { + public void testCreateConnector() throws IOException, ExecutionException, InterruptedException { String connectorId = "connect"; CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); @@ -78,7 +82,29 @@ public void testCreateConnector() throws IOException { verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); assertTrue(future.isDone()); + assertEquals(connectorId, future.get().getContent().get("connector-id")); } + public void testCreateConnectorFailure() throws IOException { + CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to create connector", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = createConnectorStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to create connector", ex.getCause().getMessage()); + } + } diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index e32a7c75f..4cdfaebae 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -11,6 +11,8 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; @@ -20,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -47,8 +50,7 @@ public void setUp() throws Exception { } - public void testDeployModel() { - + public void testDeployModel() throws ExecutionException, InterruptedException { String taskId = "taskId"; String status = MLTaskState.CREATED.name(); MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; @@ -70,6 +72,28 @@ public void testDeployModel() { verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); assertTrue(future.isDone()); + assertEquals(status, future.get().getContent().get("deploy_model_status")); + } + + public void testDeployModelFailure() { + DeployModelStep deployModel = new DeployModelStep(machineLearningNodeClient); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to deploy model", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + + CompletableFuture future = deployModel.execute(List.of(inputData)); + + verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to deploy model", ex.getCause().getMessage()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java index 1e8026a15..59fb1b173 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java @@ -11,6 +11,8 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLTaskState; @@ -95,7 +97,30 @@ public void testRegisterModel() throws ExecutionException, InterruptedException verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); assertTrue(future.isDone()); + assertEquals(modelId, future.get().getContent().get("model_id")); + assertEquals(status, future.get().getContent().get("register_model_status")); } + public void testRegisterModelFailure() { + RegisterModelStep registerModelStep = new RegisterModelStep(machineLearningNodeClient); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to register model", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = registerModelStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to register model", ex.getCause().getMessage()); + } + }