diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 5a849fd89..ac2cb2a86 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -81,6 +81,10 @@ private CommonValue() {} public static final String OUTPUT_FIELD_NAME = "output_field_name"; /** Model Id field */ public static final String MODEL_ID = "model_id"; + /** Task Id field */ + public static final String TASK_ID = "task_id"; + /** Register Model Status field */ + public static final String REGISTER_MODEL_STATUS = "register_model_status"; /** Function Name field */ public static final String FUNCTION_NAME = "function_name"; /** Name field */ @@ -95,8 +99,20 @@ private CommonValue() {} public static final String CONNECTOR_ID = "connector_id"; /** Model format field */ public static final String MODEL_FORMAT = "model_format"; + /** Model content hash value field */ + public static final String MODEL_CONTENT_HASH_VALUE = "model_content_hash_value"; + /** URL field */ + public static final String URL = "url"; /** Model config field */ public static final String MODEL_CONFIG = "model_config"; + /** Model type field */ + public static final String MODEL_TYPE = "model_type"; + /** Embedding dimension field */ + public static final String EMBEDDING_DIMENSION = "embedding_dimension"; + /** Framework type field */ + public static final String FRAMEWORK_TYPE = "framework_type"; + /** All config field */ + public static final String ALL_CONFIG = "all_config"; /** Version field */ public static final String VERSION_FIELD = "version"; /** Connector protocol field */ diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 2d2008b3d..00c029a3e 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -335,7 +335,7 @@ public void putInitialStateToWorkflowState(String workflowId, User user, ActionL } }, e -> { - String errorMessage = "Failed to create global_context index"; + String errorMessage = "Failed to create workflow_state index"; logger.error(errorMessage, e); listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); })); diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java new file mode 100644 index 000000000..893f34a0d --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTask; + +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; + +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.TASK_ID; + +/** + * Step to retrieve an ML Task + */ +public class GetMLTaskStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(GetMLTaskStep.class); + private MachineLearningNodeClient mlClient; + static final String NAME = "get_ml_task"; + + /** + * Instantiate this class + * @param mlClient client to instantiate MLClient + */ + public GetMLTaskStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute(List data) { + + CompletableFuture getMLTaskFuture = new CompletableFuture<>(); + + ActionListener actionListener = ActionListener.wrap(response -> { + + // TODO : Add retry capability if response status is not COMPLETED : + // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/158 + + logger.info("ML Task retrieval successful"); + getMLTaskFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())) + ) + ); + }, exception -> { + logger.error("Failed to retrieve ML Task"); + getMLTaskFuture.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + }); + + String taskId = null; + + for (WorkflowData workflowData : data) { + Map content = workflowData.getContent(); + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case TASK_ID: + taskId = (String) content.get(TASK_ID); + break; + default: + break; + } + } + } + + if (taskId == null) { + logger.error("Failed to retrieve ML Task"); + getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)); + } else { + mlClient.getTask(taskId, actionListener); + } + + return getMLTaskFuture; + } + + @Override + public String getName() { + return NAME; + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java index 9e1010ec1..35a3bdfff 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -38,7 +38,7 @@ */ public class ModelGroupStep implements WorkflowStep { - private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); + private static final Logger logger = LogManager.getLogger(ModelGroupStep.class); private MachineLearningNodeClient mlClient; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java new file mode 100644 index 000000000..17dd0b068 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -0,0 +1,200 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig.TextEmbeddingModelConfigBuilder; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; + +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +import static org.opensearch.flowframework.common.CommonValue.ALL_CONFIG; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; +import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE; +import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; +import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; +import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; +import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.TASK_ID; +import static org.opensearch.flowframework.common.CommonValue.URL; +import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; + +/** + * Step to register a local model + */ +public class RegisterLocalModelStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(RegisterLocalModelStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "register_local_model"; + + /** + * Instantiate this class + * @param mlClient client to instantiate MLClient + */ + public RegisterLocalModelStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute(List data) { + + CompletableFuture registerLocalModelFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + @Override + public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { + logger.info("Local Model registration task creation successful"); + registerLocalModelFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()), + Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) + ) + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to register local model"); + registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + String modelName = null; + String modelVersion = null; + String description = null; + MLModelFormat modelFormat = null; + String modelGroupId = null; + String modelContentHashValue = null; + String modelType = null; + String embeddingDimension = null; + FrameworkType frameworkType = null; + String allConfig = null; + String url = null; + + for (WorkflowData workflowData : data) { + Map content = workflowData.getContent(); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case NAME_FIELD: + modelName = (String) content.get(NAME_FIELD); + break; + case VERSION_FIELD: + modelVersion = (String) content.get(VERSION_FIELD); + break; + case DESCRIPTION_FIELD: + description = (String) content.get(DESCRIPTION_FIELD); + break; + case MODEL_FORMAT: + modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT)); + break; + case MODEL_GROUP_ID: + modelGroupId = (String) content.get(MODEL_GROUP_ID); + break; + case MODEL_TYPE: + modelType = (String) content.get(MODEL_TYPE); + break; + case EMBEDDING_DIMENSION: + embeddingDimension = (String) content.get(EMBEDDING_DIMENSION); + break; + case FRAMEWORK_TYPE: + frameworkType = FrameworkType.from((String) content.get(FRAMEWORK_TYPE)); + break; + case ALL_CONFIG: + allConfig = (String) content.get(ALL_CONFIG); + break; + case MODEL_CONTENT_HASH_VALUE: + modelContentHashValue = (String) content.get(MODEL_CONTENT_HASH_VALUE); + break; + case URL: + url = (String) content.get(URL); + break; + default: + break; + + } + } + } + + if (Stream.of( + modelName, + modelVersion, + modelFormat, + modelGroupId, + modelType, + embeddingDimension, + frameworkType, + modelContentHashValue, + url + ).allMatch(x -> x != null)) { + + // Create Model configudation + TextEmbeddingModelConfigBuilder modelConfigBuilder = TextEmbeddingModelConfig.builder() + .modelType(modelType) + .embeddingDimension(Integer.valueOf(embeddingDimension)) + .frameworkType(frameworkType); + if (allConfig != null) { + modelConfigBuilder.allConfig(allConfig); + } + MLModelConfig modelConfig = modelConfigBuilder.build(); + + // Create register local model input + MLRegisterModelInputBuilder mlInputBuilder = MLRegisterModelInput.builder() + .modelName(modelName) + .version(modelVersion) + .modelFormat(modelFormat) + .modelGroupId(modelGroupId) + .hashValue(modelContentHashValue) + .modelConfig(modelConfig) + .url(url); + if (description != null) { + mlInputBuilder.description(description); + } + + MLRegisterModelInput mlInput = mlInputBuilder.build(); + + mlClient.register(mlInput, actionListener); + } else { + registerLocalModelFuture.completeExceptionally( + new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) + ); + } + + return registerLocalModelFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java similarity index 65% rename from src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java rename to src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index b6ae176d3..4dedc8bf2 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -16,9 +16,8 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.model.MLModelConfig; -import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import java.util.List; @@ -31,45 +30,44 @@ import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; -import static org.opensearch.flowframework.common.CommonValue.MODEL_CONFIG; -import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; -import static org.opensearch.flowframework.common.CommonValue.MODEL_VERSION; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; /** * Step to register a remote model */ -public class RegisterModelStep implements WorkflowStep { +public class RegisterRemoteModelStep implements WorkflowStep { - private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); + private static final Logger logger = LogManager.getLogger(RegisterRemoteModelStep.class); private MachineLearningNodeClient mlClient; - static final String NAME = "register_model"; + static final String NAME = "register_remote_model"; /** * Instantiate this class * @param mlClient client to instantiate MLClient */ - public RegisterModelStep(MachineLearningNodeClient mlClient) { + public RegisterRemoteModelStep(MachineLearningNodeClient mlClient) { this.mlClient = mlClient; } @Override public CompletableFuture execute(List data) { - CompletableFuture registerModelFuture = new CompletableFuture<>(); + CompletableFuture registerRemoteModelFuture = new CompletableFuture<>(); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { - logger.info("Model registration successful"); - registerModelFuture.complete( + logger.info("Remote Model registration successful"); + registerRemoteModelFuture.complete( new WorkflowData( Map.ofEntries( - Map.entry("model_id", mlRegisterModelResponse.getModelId()), - Map.entry("register_model_status", mlRegisterModelResponse.getStatus()) + Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()), + Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) ) ) ); @@ -77,43 +75,33 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { @Override public void onFailure(Exception e) { - logger.error("Failed to register model"); - registerModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + logger.error("Failed to register remote model"); + registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; - FunctionName functionName = null; String modelName = null; - String modelVersion = null; + FunctionName functionName = null; String modelGroupId = null; - String connectorId = null; String description = null; - MLModelFormat modelFormat = null; - MLModelConfig modelConfig = null; + String connectorId = null; + + // TODO : Handle inline connector configuration : https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/149 for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); + Map content = workflowData.getContent(); for (Entry entry : content.entrySet()) { switch (entry.getKey()) { - case FUNCTION_NAME: - functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); - break; case NAME_FIELD: modelName = (String) content.get(NAME_FIELD); break; - case MODEL_VERSION: - modelVersion = (String) content.get(MODEL_VERSION); + case FUNCTION_NAME: + functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); break; case MODEL_GROUP_ID: modelGroupId = (String) content.get(MODEL_GROUP_ID); break; - case MODEL_FORMAT: - modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT)); - break; - case MODEL_CONFIG: - modelConfig = (MLModelConfig) content.get(MODEL_CONFIG); - break; case DESCRIPTION_FIELD: description = (String) content.get(DESCRIPTION_FIELD); break; @@ -127,28 +115,34 @@ public void onFailure(Exception e) { } } - if (Stream.of(functionName, modelName, description, connectorId).allMatch(x -> x != null)) { + if (Stream.of(modelName, functionName, connectorId).allMatch(x -> x != null)) { - // TODO: Add model Config and type cast correctly - MLRegisterModelInput mlInput = MLRegisterModelInput.builder() + MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() .functionName(functionName) .modelName(modelName) - .description(description) - .connectorId(connectorId) - .build(); + .connectorId(connectorId); + + if (modelGroupId != null) { + builder.modelGroupId(modelGroupId); + } + if (description != null) { + builder.description(description); + } + MLRegisterModelInput mlInput = builder.build(); mlClient.register(mlInput, actionListener); } else { - registerModelFuture.completeExceptionally( + registerRemoteModelFuture.completeExceptionally( new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) ); } - return registerModelFuture; + return registerRemoteModelFuture; } @Override public String getName() { return NAME; } + } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index c30bdf87c..af83f0ad9 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -40,10 +40,12 @@ private void populateMap(ClusterService clusterService, Client client, MachineLe stepMap.put(NoOpStep.NAME, new NoOpStep()); stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); - stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient)); + stepMap.put(RegisterLocalModelStep.NAME, new RegisterLocalModelStep(mlClient)); + stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient)); stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); + stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(mlClient)); } /** diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 241a8ecbc..5bd88147b 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -39,14 +39,30 @@ "connector_id" ] }, - "register_model": { + "register_local_model": { "inputs":[ - "function_name", "name", - "description", - "connector_id" + "version", + "model_format", + "model_group_id", + "model_content_hash_value", + "model_type", + "embedding_dimension", + "framework_type", + "url" ], "outputs":[ + "task_id", + "register_model_status" + ] + }, + "register_remote_model": { + "inputs": [ + "name", + "function_name", + "connector_id" + ], + "outputs": [ "model_id", "register_model_status" ] @@ -67,5 +83,14 @@ "model_group_id", "model_group_status" ] + }, + "get_ml_task": { + "inputs":[ + "task_id" + ], + "outputs":[ + "model_id", + "register_model_status" + ] } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java new file mode 100644 index 000000000..3a83b1fdd --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.TASK_ID; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class GetMLTaskStepTests extends OpenSearchTestCase { + + private GetMLTaskStep getMLTaskStep; + private WorkflowData workflowData; + + @Mock + MachineLearningNodeClient mlNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + this.getMLTaskStep = new GetMLTaskStep(mlNodeClient); + this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test"))); + } + + public void testGetMLTaskSuccess() throws Exception { + String taskId = "test"; + String modelId = "abcd"; + MLTaskState status = MLTaskState.COMPLETED; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask(taskId, modelId, null, null, status, null, null, null, null, null, null, null, null, false); + actionListener.onResponse(output); + return null; + }).when(mlNodeClient).getTask(any(), any()); + + CompletableFuture future = this.getMLTaskStep.execute(List.of(workflowData)); + + verify(mlNodeClient, times(1)).getTask(any(), any()); + + assertTrue(future.isDone()); + assertTrue(!future.isCompletedExceptionally()); + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status.name(), future.get().getContent().get(REGISTER_MODEL_STATUS)); + } + + public void testGetMLTaskFailure() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new IllegalArgumentException("test")); + return null; + }).when(mlNodeClient).getTask(any(), any()); + + CompletableFuture future = this.getMLTaskStep.execute(List.of(workflowData)); + assertTrue(future.isDone()); + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("test", ex.getCause().getMessage()); + } + + public void testMissingInputs() { + CompletableFuture future = this.getMLTaskStep.execute(List.of(WorkflowData.EMPTY)); + assertTrue(future.isDone()); + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Required fields are not provided", ex.getCause().getMessage()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java similarity index 56% rename from src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java rename to src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index ea1518d75..d41096624 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -11,10 +11,8 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -27,17 +25,20 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.TASK_ID; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) -public class RegisterModelStepTests extends OpenSearchTestCase { - private WorkflowData inputData = WorkflowData.EMPTY; +public class RegisterLocalModelStepTests extends OpenSearchTestCase { + + private RegisterLocalModelStep registerLocalModelStep; + private WorkflowData workflowData; @Mock MachineLearningNodeClient machineLearningNodeClient; @@ -46,6 +47,9 @@ public class RegisterModelStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); + MockitoAnnotations.openMocks(this); + this.registerLocalModelStep = new RegisterLocalModelStep(machineLearningNodeClient); + MLModelConfig config = TextEmbeddingModelConfig.builder() .modelType("testModelType") .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") @@ -53,73 +57,68 @@ public void setUp() throws Exception { .embeddingDimension(100) .build(); - MockitoAnnotations.openMocks(this); - - inputData = new WorkflowData( + this.workflowData = new WorkflowData( Map.ofEntries( - Map.entry("function_name", "remote"), Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), Map.entry("description", "description"), - Map.entry("connector_id", "abcdefg") + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry("model_group_id", "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("model_type", "bert"), + Map.entry("embedding_dimension", "384"), + Map.entry("framework_type", "sentence_transformers"), + Map.entry("url", "something.com") ) ); } - public void testRegisterModel() throws ExecutionException, InterruptedException { + public void testRegisterLocalModelSuccess() throws Exception { String taskId = "abcd"; - String modelId = "efgh"; String status = MLTaskState.CREATED.name(); - MLRegisterModelInput mlInput = MLRegisterModelInput.builder() - .functionName(FunctionName.from("REMOTE")) - .modelName("testModelName") - .description("description") - .connectorId("abcdefgh") - .build(); - - RegisterModelStep registerModelStep = new RegisterModelStep(machineLearningNodeClient); - - @SuppressWarnings("unchecked") - ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = registerModelStep.execute(List.of(inputData)); - - verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + CompletableFuture future = registerLocalModelStep.execute(List.of(workflowData)); + verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); assertTrue(future.isDone()); - assertEquals(modelId, future.get().getContent().get("model_id")); - assertEquals(status, future.get().getContent().get("register_model_status")); + assertTrue(!future.isCompletedExceptionally()); + assertEquals(taskId, future.get().getContent().get(TASK_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } - public void testRegisterModelFailure() { - RegisterModelStep registerModelStep = new RegisterModelStep(machineLearningNodeClient); - - @SuppressWarnings("unchecked") - ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + public void testRegisterLocalModelFailure() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new FlowFrameworkException("Failed to register model", RestStatus.INTERNAL_SERVER_ERROR)); + actionListener.onFailure(new IllegalArgumentException("test")); return null; - }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = registerModelStep.execute(List.of(inputData)); - - verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + CompletableFuture future = this.registerLocalModelStep.execute(List.of(workflowData)); + assertTrue(future.isDone()); + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("test", ex.getCause().getMessage()); + } + public void testMissingInputs() { + CompletableFuture future = registerLocalModelStep.execute(List.of(WorkflowData.EMPTY)); + assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("Failed to register model", ex.getCause().getMessage()); + assertEquals("Required fields are not provided", ex.getCause().getMessage()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java new file mode 100644 index 000000000..ca9d5e7a5 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class RegisterRemoteModelStepTests extends OpenSearchTestCase { + + private RegisterRemoteModelStep registerRemoteModelStep; + private WorkflowData workflowData; + + @Mock + MachineLearningNodeClient mlNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + this.registerRemoteModelStep = new RegisterRemoteModelStep(mlNodeClient); + this.workflowData = new WorkflowData( + Map.ofEntries( + Map.entry("function_name", "remote"), + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry("connector_id", "abcdefg") + ) + ); + } + + public void testRegisterRemoteModelSuccess() throws Exception { + + String taskId = "abcd"; + String modelId = "efgh"; + String status = MLTaskState.CREATED.name(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + actionListener.onResponse(output); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + CompletableFuture future = this.registerRemoteModelStep.execute(List.of(workflowData)); + + verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); + + assertTrue(future.isDone()); + assertTrue(!future.isCompletedExceptionally()); + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + + } + + public void testRegisterRemoteModelFailure() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new IllegalArgumentException("test")); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + CompletableFuture future = this.registerRemoteModelStep.execute(List.of(workflowData)); + assertTrue(future.isDone()); + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("test", ex.getCause().getMessage()); + + } + + public void testMissingInputs() { + CompletableFuture future = this.registerRemoteModelStep.execute(List.of(WorkflowData.EMPTY)); + assertTrue(future.isDone()); + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Required fields are not provided", ex.getCause().getMessage()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 65fccbb7e..3b1d55a69 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -236,7 +236,7 @@ public void testSuccessfulGraphValidation() throws Exception { ); WorkflowNode registerModel = new WorkflowNode( "workflow_step_2", - RegisterModelStep.NAME, + RegisterRemoteModelStep.NAME, Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) ); @@ -261,7 +261,7 @@ public void testFailedGraphValidation() { // Create Register Model workflow node with missing connector_id field WorkflowNode registerModel = new WorkflowNode( "workflow_step_1", - RegisterModelStep.NAME, + RegisterRemoteModelStep.NAME, Map.of(), Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) );