Skip to content

Commit

Permalink
Added more tests and updated MLClient initialization
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Oct 19, 2023
1 parent 75fb51b commit e2b640e
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 43 deletions.
34 changes: 0 additions & 34 deletions src/main/java/org/opensearch/flowframework/client/MLClient.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,17 @@
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.*;

import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;

/**
* Step to create a connector for a remote model
*/
public class CreateConnectorStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;

Expand Down Expand Up @@ -53,8 +55,8 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) {

@Override
public void onFailure(Exception e) {
logger.error("Model deployment failed");
deployModelFuture.completeExceptionally(e);
logger.error("Failed to deploy model");
deployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -76,7 +77,7 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
@Override
public void onFailure(Exception e) {
logger.error("Failed to register model");
registerModelFuture.completeExceptionally(new IOException("Failed to register model"));
registerModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class WorkflowStepFactory {
*
* @param clusterService The OpenSearch cluster service
* @param client The OpenSearch client steps can use
* @param mlClient Machine Learning client to perform ml operations
*/

public WorkflowStepFactory(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
package org.opensearch.flowframework.workflow;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
Expand All @@ -18,11 +20,13 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -59,7 +63,7 @@ public void setUp() throws Exception {

}

public void testCreateConnector() throws IOException {
public void testCreateConnector() throws IOException, ExecutionException, InterruptedException {

String connectorId = "connect";
CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient);
Expand All @@ -78,7 +82,29 @@ public void testCreateConnector() throws IOException {
verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());

assertTrue(future.isDone());
assertEquals(connectorId, future.get().getContent().get("connector-id"));

}

public void testCreateConnectorFailure() throws IOException {
CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient);

ArgumentCaptor<ActionListener<MLCreateConnectorResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new FlowFrameworkException("Failed to create connector", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());

CompletableFuture<WorkflowData> future = createConnectorStep.execute(List.of(inputData));

verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to create connector", ex.getCause().getMessage());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
Expand All @@ -20,6 +22,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
Expand Down Expand Up @@ -47,8 +50,7 @@ public void setUp() throws Exception {

}

public void testDeployModel() {

public void testDeployModel() throws ExecutionException, InterruptedException {
String taskId = "taskId";
String status = MLTaskState.CREATED.name();
MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL;
Expand All @@ -70,6 +72,28 @@ public void testDeployModel() {
verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture());

assertTrue(future.isDone());
assertEquals(status, future.get().getContent().get("deploy_model_status"));
}

public void testDeployModelFailure() {
DeployModelStep deployModel = new DeployModelStep(machineLearningNodeClient);

@SuppressWarnings("unchecked")
ArgumentCaptor<ActionListener<MLDeployModelResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new FlowFrameworkException("Failed to deploy model", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture());

CompletableFuture<WorkflowData> future = deployModel.execute(List.of(inputData));

verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture());

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to deploy model", ex.getCause().getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTaskState;
Expand Down Expand Up @@ -95,7 +97,30 @@ public void testRegisterModel() throws ExecutionException, InterruptedException
verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture());

assertTrue(future.isDone());
assertEquals(modelId, future.get().getContent().get("model_id"));
assertEquals(status, future.get().getContent().get("register_model_status"));

}

public void testRegisterModelFailure() {
RegisterModelStep registerModelStep = new RegisterModelStep(machineLearningNodeClient);

ArgumentCaptor<ActionListener<MLRegisterModelResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new FlowFrameworkException("Failed to register model", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture());

CompletableFuture<WorkflowData> future = registerModelStep.execute(List.of(inputData));

verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture());

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to register model", ex.getCause().getMessage());
}

}

0 comments on commit e2b640e

Please sign in to comment.