Skip to content

Commit

Permalink
Added register model group API for MLClient (#1493)
Browse files Browse the repository at this point in the history
* Added register model group API for MLClient

Signed-off-by: Owais Kazi <[email protected]>

* Resolved formatting errors

Signed-off-by: Owais Kazi <[email protected]>

---------

Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 authored Oct 12, 2023
1 parent d265fc7 commit 66d4cd4
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 376 deletions.
10 changes: 10 additions & 0 deletions client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ plugins {
id 'jacoco'
id 'com.github.johnrengelman.shadow'
id 'maven-publish'
id 'com.diffplug.spotless' version '6.18.0'
id 'signing'
}

Expand All @@ -20,6 +21,15 @@ dependencies {

}

spotless {
java {
removeUnusedImports()
importOrder 'java', 'javax', 'org', 'com'

eclipse().configFile rootProject.file('.eclipseformat.xml')
}
}

jacocoTestReport {
reports {
xml.getRequired().set(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,26 @@

package org.opensearch.ml.client;

import java.util.Map;

import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.Map;

/**
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
*/
Expand Down Expand Up @@ -84,7 +85,6 @@ default ActionFuture<MLOutput> train(MLInput mlInput, boolean asyncTask) {
return actionFuture;
}


/**
* Do the training machine learning job. The training job will be always async process. The job id will be returned in this method.
* For more info on train model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#train-model
Expand Down Expand Up @@ -205,15 +205,13 @@ default ActionFuture<SearchResponse> searchModel(SearchRequest searchRequest) {
return actionFuture;
}


/**
* For more info on search model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-model
* @param searchRequest searchRequest to search the ML Model
* @param listener action listener
*/
void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener);


/**
* For more info on search task, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-task
* @param searchRequest searchRequest to search the ML Task
Expand Down Expand Up @@ -282,4 +280,23 @@ default ActionFuture<MLCreateConnectorResponse> createConnector(MLCreateConnecto
}

void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener);

/**
* Register model group
* For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group
* @param mlRegisterModelGroupInput model group input
*/
default ActionFuture<MLRegisterModelGroupResponse> registerModelGroup(MLRegisterModelGroupInput mlRegisterModelGroupInput) {
PlainActionFuture<MLRegisterModelGroupResponse> actionFuture = PlainActionFuture.newFuture();
registerModelGroup(mlRegisterModelGroupInput, actionFuture);
return actionFuture;
}

/**
* Register model group
* For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group
* @param mlRegisterModelGroupInput model group input
* @param listener a listener to be notified of the result
*/
void registerModelGroup(MLRegisterModelGroupInput mlRegisterModelGroupInput, ActionListener<MLRegisterModelGroupResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,36 @@

package org.opensearch.ml.client;

import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import static org.opensearch.ml.common.input.Constants.ASYNC;
import static org.opensearch.ml.common.input.Constants.MODELID;
import static org.opensearch.ml.common.input.Constants.PREDICT;
import static org.opensearch.ml.common.input.Constants.TRAIN;
import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT;
import static org.opensearch.ml.common.input.InputHelper.convertArgumentToMLParameter;
import static org.opensearch.ml.common.input.InputHelper.getAction;
import static org.opensearch.ml.common.input.InputHelper.getFunctionName;

import java.util.Map;
import java.util.function.Function;

import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelInput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
Expand All @@ -46,7 +43,10 @@
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
Expand All @@ -63,20 +63,9 @@
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;

import java.io.IOException;
import java.time.Instant;
import java.util.Map;
import java.util.function.Function;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.input.Constants.ASYNC;
import static org.opensearch.ml.common.input.Constants.MODELID;
import static org.opensearch.ml.common.input.Constants.PREDICT;
import static org.opensearch.ml.common.input.Constants.TRAIN;
import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT;
import static org.opensearch.ml.common.input.InputHelper.convertArgumentToMLParameter;
import static org.opensearch.ml.common.input.InputHelper.getAction;
import static org.opensearch.ml.common.input.InputHelper.getFunctionName;
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import lombok.experimental.FieldDefaults;

@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@RequiredArgsConstructor
Expand All @@ -88,33 +77,32 @@ public class MachineLearningNodeClient implements MachineLearningClient {
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
validateMLInput(mlInput, true);

MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder()
.mlInput(mlInput)
.modelId(modelId)
.dispatchTask(true)
.build();
MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest
.builder()
.mlInput(mlInput)
.modelId(modelId)
.dispatchTask(true)
.build();
client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener));
}

@Override
public void trainAndPredict(MLInput mlInput, ActionListener<MLOutput> listener) {
validateMLInput(mlInput, true);
MLTrainingTaskRequest request = MLTrainingTaskRequest.builder()
.mlInput(mlInput)
.dispatchTask(true)
.build();
MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).dispatchTask(true).build();

client.execute(MLTrainAndPredictionTaskAction.INSTANCE, request, getMlPredictionTaskResponseActionListener(listener));
}

@Override
public void train(MLInput mlInput, boolean asyncTask, ActionListener<MLOutput> listener) {
validateMLInput(mlInput, true);
MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest.builder()
.mlInput(mlInput)
.async(asyncTask)
.dispatchTask(true)
.build();
MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest
.builder()
.mlInput(mlInput)
.async(asyncTask)
.dispatchTask(true)
.build();

client.execute(MLTrainingTaskAction.INSTANCE, trainingTaskRequest, getMlPredictionTaskResponseActionListener(listener));
}
Expand Down Expand Up @@ -144,15 +132,13 @@ public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutp
trainAndPredict(mlInput, listener);
break;
default:
throw new IllegalArgumentException("Unsupported action.");
throw new IllegalArgumentException("Unsupported action.");
}
}

@Override
public void getModel(String modelId, ActionListener<MLModel> listener) {
MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder()
.modelId(modelId)
.build();
MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build();

client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, getMlGetModelResponseActionListener(listener));
}
Expand All @@ -170,9 +156,7 @@ private ActionListener<MLModelGetResponse> getMlGetModelResponseActionListener(A

@Override
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder()
.modelId(modelId)
.build();
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build();

client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
Expand All @@ -181,17 +165,26 @@ public void deleteModel(String modelId, ActionListener<DeleteResponse> listener)

@Override
public void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
client.execute(MLModelSearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> {
listener.onResponse(searchResponse);
}, listener::onFailure));
client
.execute(
MLModelSearchAction.INSTANCE,
searchRequest,
ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure)
);
}

@Override
public void registerModelGroup(
MLRegisterModelGroupInput mlRegisterModelGroupInput,
ActionListener<MLRegisterModelGroupResponse> listener
) {
MLRegisterModelGroupRequest mlRegisterModelGroupRequest = new MLRegisterModelGroupRequest(mlRegisterModelGroupInput);
client.execute(MLRegisterModelGroupAction.INSTANCE, mlRegisterModelGroupRequest, listener);
}

@Override
public void getTask(String taskId, ActionListener<MLTask> listener) {
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder()
.taskId(taskId)
.build();
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build();

client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest, ActionListener.wrap(response -> {
listener.onResponse(MLTaskGetResponse.fromActionResponse(response).getMlTask());
Expand All @@ -200,9 +193,7 @@ public void getTask(String taskId, ActionListener<MLTask> listener) {

@Override
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder()
.taskId(taskId)
.build();
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId(taskId).build();

client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
Expand All @@ -211,25 +202,34 @@ public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {

@Override
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
client.execute(MLTaskSearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> {
listener.onResponse(searchResponse);
}, listener::onFailure));
client
.execute(
MLTaskSearchAction.INSTANCE,
searchRequest,
ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure)
);
}

@Override
public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener) {
MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput);
client.execute(MLRegisterModelAction.INSTANCE, registerRequest, ActionListener.wrap(listener::onResponse, e -> {
listener.onFailure(e);
}));
client
.execute(
MLRegisterModelAction.INSTANCE,
registerRequest,
ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); })
);
}

@Override
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false);
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, ActionListener.wrap(listener::onResponse, e -> {
listener.onFailure(e);
}));
client
.execute(
MLDeployModelAction.INSTANCE,
deployModelRequest,
ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); })
);
}

@Override
Expand All @@ -249,20 +249,22 @@ private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener
return actionListener;
}

private <T extends ActionResponse> ActionListener<T> wrapActionListener(final ActionListener<T> listener, final Function<ActionResponse, T> recreate) {
ActionListener<T> actionListener = ActionListener.wrap(r-> {
listener.onResponse(recreate.apply(r));;
}, e->{
listener.onFailure(e);
});
private <T extends ActionResponse> ActionListener<T> wrapActionListener(
final ActionListener<T> listener,
final Function<ActionResponse, T> recreate
) {
ActionListener<T> actionListener = ActionListener.wrap(r -> {
listener.onResponse(recreate.apply(r));
;
}, e -> { listener.onFailure(e); });
return actionListener;
}

private void validateMLInput(MLInput mlInput, boolean requireInput) {
if (mlInput == null) {
throw new IllegalArgumentException("ML Input can't be null");
}
if(requireInput && mlInput.getInputDataset() == null) {
if (requireInput && mlInput.getInputDataset() == null) {
throw new IllegalArgumentException("input data set can't be null");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.client;
package org.opensearch.ml.client;
Loading

0 comments on commit 66d4cd4

Please sign in to comment.