Skip to content

Commit

Permalink
add register and deploy api in client (opensearch-project#1359) (open…
Browse files Browse the repository at this point in the history
…search-project#1388)

Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es authored Sep 26, 2023
1 parent b48546a commit 234f81e
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.Map;

Expand Down Expand Up @@ -226,4 +229,42 @@ default ActionFuture<SearchResponse> searchTask(SearchRequest searchRequest) {
* @param listener action listener
*/
void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener);

/**
* Register model
* For additional info on register, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#registering-a-model
* @param mlInput ML input
*/
default ActionFuture<MLRegisterModelResponse> register(MLRegisterModelInput mlInput) {
PlainActionFuture<MLRegisterModelResponse> actionFuture = PlainActionFuture.newFuture();
register(mlInput, actionFuture);
return actionFuture;
}

/**
* Register model
* For additional info on register, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#registering-a-model
* @param mlInput ML input
* @param listener a listener to be notified of the result
*/
void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener);

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* @param modelId the model id
*/
default ActionFuture<MLDeployModelResponse> deploy(String modelId) {
PlainActionFuture<MLDeployModelResponse> actionFuture = PlainActionFuture.newFuture();
deploy(modelId, actionFuture);
return actionFuture;
}

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* @param modelId the model id
* @param listener a listener to be notified of the result
*/
void deploy(String modelId, ActionListener<MLDeployModelResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,34 @@
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelInput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
Expand All @@ -30,6 +45,10 @@
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
Expand All @@ -40,9 +59,12 @@
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;

import java.io.IOException;
import java.time.Instant;
import java.util.Map;
import java.util.function.Function;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.input.Constants.ASYNC;
import static org.opensearch.ml.common.input.Constants.MODELID;
import static org.opensearch.ml.common.input.Constants.PREDICT;
Expand Down Expand Up @@ -190,6 +212,22 @@ public void searchTask(SearchRequest searchRequest, ActionListener<SearchRespons
}, listener::onFailure));
}

@Override
public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener) {
MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput);
client.execute(MLRegisterModelAction.INSTANCE, registerRequest, ActionListener.wrap(listener::onResponse, e -> {
listener.onFailure(e);
}));
}

@Override
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false);
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, ActionListener.wrap(listener::onResponse, e -> {
listener.onFailure(e);
}));
}

private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(predictionResponse -> {
listener.onResponse(predictionResponse.getOutput());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,19 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -58,6 +68,12 @@ public class MachineLearningClientTest {
@Mock
SearchResponse searchResponse;

@Mock
MLRegisterModelResponse registerModelResponse;

@Mock
MLDeployModelResponse deployModelResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
Expand Down Expand Up @@ -132,6 +148,16 @@ public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
listener.onResponse(searchResponse);
}

@Override
public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener) {
listener.onResponse(registerModelResponse);
}

@Override
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
listener.onResponse(deployModelResponse);
}
};
}

Expand Down Expand Up @@ -251,4 +277,31 @@ public void deleteTask() {
public void searchTask() {
assertEquals(searchResponse, machineLearningClient.searchTask(new SearchRequest()).actionGet());
}

@Test
public void register() {
MLModelConfig config = TextEmbeddingModelConfig.builder()
.modelType("testModelType")
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
.embeddingDimension(100)
.build();
MLRegisterModelInput mlInput = MLRegisterModelInput.builder()
.functionName(FunctionName.KMEANS)
.modelName("testModelName")
.version("testModelVersion")
.modelGroupId("modelGroupId")
.url("url")
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
.deployModel(true)
.modelNodeIds(new String[]{"modelNodeIds" })
.build();
assertEquals(registerModelResponse, machineLearningClient.register(mlInput).actionGet());
}

@Test
public void deploy() {
assertEquals(deployModelResponse, machineLearningClient.deploy("modelId").actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,20 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
Expand All @@ -46,6 +53,10 @@
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
Expand Down Expand Up @@ -109,6 +120,12 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<SearchResponse> searchTaskActionListener;

@Mock
ActionListener<MLRegisterModelResponse> RegisterModelActionListener;

@Mock
ActionListener<MLDeployModelResponse> DeployModelActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -554,6 +571,67 @@ public void searchTask() {
assertEquals(modelId, source.get(MLTask.MODEL_ID_FIELD));
}

@Test
public void register() {
String taskId = "taskId";
String status = MLTaskState.CREATED.name();
FunctionName functionName = FunctionName.KMEANS;
doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(2);
MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLRegisterModelAction.INSTANCE), any(), any());

ArgumentCaptor<MLRegisterModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class);
MLModelConfig config = TextEmbeddingModelConfig.builder()
.modelType("testModelType")
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
.embeddingDimension(100)
.build();
MLRegisterModelInput mlInput = MLRegisterModelInput.builder()
.functionName(functionName)
.modelName("testModelName")
.version("testModelVersion")
.modelGroupId("modelGroupId")
.url("url")
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
.deployModel(true)
.modelNodeIds(new String[]{"modelNodeIds" })
.build();
machineLearningNodeClient.register(mlInput, RegisterModelActionListener);

verify(client).execute(eq(MLRegisterModelAction.INSTANCE), isA(MLRegisterModelRequest.class), any());
verify(RegisterModelActionListener).onResponse(argumentCaptor.capture());
assertEquals(taskId, (argumentCaptor.getValue()).getTaskId());
assertEquals(status, (argumentCaptor.getValue()).getStatus());
}

@Test
public void deploy() {
String taskId = "taskId";
String status = MLTaskState.CREATED.name();
MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL;
String modelId = "modelId";
FunctionName functionName = FunctionName.KMEANS;
doAnswer(invocation -> {
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any());

ArgumentCaptor<MLDeployModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLDeployModelResponse.class);
machineLearningNodeClient.deploy(modelId, DeployModelActionListener);

verify(client).execute(eq(MLDeployModelAction.INSTANCE), isA(MLDeployModelRequest.class), any());
verify(DeployModelActionListener).onResponse(argumentCaptor.capture());
assertEquals(taskId, (argumentCaptor.getValue()).getTaskId());
assertEquals(status, (argumentCaptor.getValue()).getStatus());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down

0 comments on commit 234f81e

Please sign in to comment.