diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 8452dcc3c2..5bbe1f029e 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -16,6 +16,9 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import java.util.Map; @@ -226,4 +229,42 @@ default ActionFuture searchTask(SearchRequest searchRequest) { * @param listener action listener */ void searchTask(SearchRequest searchRequest, ActionListener listener); + + /** + * Register model + * For additional info on register, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#registering-a-model + * @param mlInput ML input + */ + default ActionFuture register(MLRegisterModelInput mlInput) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + register(mlInput, actionFuture); + return actionFuture; + } + + /** + * Register model + * For additional info on register, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#registering-a-model + * @param mlInput ML input + * @param listener a listener to be notified of the result + */ + void register(MLRegisterModelInput mlInput, ActionListener listener); + + /** + * Deploy model + * For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model + * @param modelId the model id + */ + default ActionFuture deploy(String modelId) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + deploy(modelId, actionFuture); + return actionFuture; + } + + /** + * Deploy model + * For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model + * @param modelId the model id + * @param listener a listener to be notified of the result + */ + void deploy(String modelId, ActionListener listener); } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index d7328b9766..d9774e657b 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -8,19 +8,34 @@ import lombok.AccessLevel; import lombok.RequiredArgsConstructor; import lombok.experimental.FieldDefaults; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; +import org.opensearch.ml.common.transport.deploy.MLDeployModelInput; +import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.common.transport.model.MLModelGetAction; @@ -30,6 +45,10 @@ import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.transport.register.MLRegisterModelAction; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.common.transport.task.MLTaskDeleteAction; import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest; import org.opensearch.ml.common.transport.task.MLTaskGetAction; @@ -40,9 +59,12 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; +import java.io.IOException; +import java.time.Instant; import java.util.Map; import java.util.function.Function; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.input.Constants.ASYNC; import static org.opensearch.ml.common.input.Constants.MODELID; import static org.opensearch.ml.common.input.Constants.PREDICT; @@ -190,6 +212,22 @@ public void searchTask(SearchRequest searchRequest, ActionListener listener) { + MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput); + client.execute(MLRegisterModelAction.INSTANCE, registerRequest, ActionListener.wrap(listener::onResponse, e -> { + listener.onFailure(e); + })); + } + + @Override + public void deploy(String modelId, ActionListener listener) { + MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false); + client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, ActionListener.wrap(listener::onResponse, e -> { + listener.onFailure(e); + })); + } + private ActionListener getMlPredictionTaskResponseActionListener(ActionListener listener) { ActionListener internalListener = ActionListener.wrap(predictionResponse -> { listener.onResponse(predictionResponse.getOutput()); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 2a98812090..a57275292d 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -21,9 +21,19 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.output.MLTrainingOutput; +import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; +import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.ml.common.transport.register.MLRegisterModelAction; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import java.util.HashMap; import java.util.Map; @@ -58,6 +68,12 @@ public class MachineLearningClientTest { @Mock SearchResponse searchResponse; + @Mock + MLRegisterModelResponse registerModelResponse; + + @Mock + MLDeployModelResponse deployModelResponse; + private String modekId = "test_model_id"; private MLModel mlModel; private MLTask mlTask; @@ -132,6 +148,16 @@ public void deleteTask(String taskId, ActionListener listener) { public void searchTask(SearchRequest searchRequest, ActionListener listener) { listener.onResponse(searchResponse); } + + @Override + public void register(MLRegisterModelInput mlInput, ActionListener listener) { + listener.onResponse(registerModelResponse); + } + + @Override + public void deploy(String modelId, ActionListener listener) { + listener.onResponse(deployModelResponse); + } }; } @@ -251,4 +277,31 @@ public void deleteTask() { public void searchTask() { assertEquals(searchResponse, machineLearningClient.searchTask(new SearchRequest()).actionGet()); } + + @Test + public void register() { + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + MLRegisterModelInput mlInput = MLRegisterModelInput.builder() + .functionName(FunctionName.KMEANS) + .modelName("testModelName") + .version("testModelVersion") + .modelGroupId("modelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[]{"modelNodeIds" }) + .build(); + assertEquals(registerModelResponse, machineLearningClient.register(mlInput).actionGet()); + } + + @Test + public void deploy() { + assertEquals(deployModelResponse, machineLearningClient.deploy("modelId").actionGet()); + } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index b515fc71cc..a2d2b494a3 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -31,13 +31,20 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; +import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.common.transport.model.MLModelGetAction; @@ -46,6 +53,10 @@ import org.opensearch.ml.common.transport.model.MLModelSearchAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.transport.register.MLRegisterModelAction; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.common.transport.task.MLTaskDeleteAction; import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest; import org.opensearch.ml.common.transport.task.MLTaskGetAction; @@ -109,6 +120,12 @@ public class MachineLearningNodeClientTest { @Mock ActionListener searchTaskActionListener; + @Mock + ActionListener RegisterModelActionListener; + + @Mock + ActionListener DeployModelActionListener; + @InjectMocks MachineLearningNodeClient machineLearningNodeClient; @@ -554,6 +571,67 @@ public void searchTask() { assertEquals(modelId, source.get(MLTask.MODEL_ID_FIELD)); } + @Test + public void register() { + String taskId = "taskId"; + String status = MLTaskState.CREATED.name(); + FunctionName functionName = FunctionName.KMEANS; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLRegisterModelAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + MLRegisterModelInput mlInput = MLRegisterModelInput.builder() + .functionName(functionName) + .modelName("testModelName") + .version("testModelVersion") + .modelGroupId("modelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[]{"modelNodeIds" }) + .build(); + machineLearningNodeClient.register(mlInput, RegisterModelActionListener); + + verify(client).execute(eq(MLRegisterModelAction.INSTANCE), isA(MLRegisterModelRequest.class), any()); + verify(RegisterModelActionListener).onResponse(argumentCaptor.capture()); + assertEquals(taskId, (argumentCaptor.getValue()).getTaskId()); + assertEquals(status, (argumentCaptor.getValue()).getStatus()); + } + + @Test + public void deploy() { + String taskId = "taskId"; + String status = MLTaskState.CREATED.name(); + MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; + String modelId = "modelId"; + FunctionName functionName = FunctionName.KMEANS; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLDeployModelResponse.class); + machineLearningNodeClient.deploy(modelId, DeployModelActionListener); + + verify(client).execute(eq(MLDeployModelAction.INSTANCE), isA(MLDeployModelRequest.class), any()); + verify(DeployModelActionListener).onResponse(argumentCaptor.capture()); + assertEquals(taskId, (argumentCaptor.getValue()).getTaskId()); + assertEquals(status, (argumentCaptor.getValue()).getStatus()); + } + private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);