Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add register and deploy api in client (#1359) #1388

Merged
merged 1 commit into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.Map;

Expand Down Expand Up @@ -226,4 +229,42 @@ default ActionFuture<SearchResponse> searchTask(SearchRequest searchRequest) {
* @param listener action listener
*/
void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener);

/**
* Register model
* For additional info on register, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#registering-a-model
* @param mlInput ML input
*/
default ActionFuture<MLRegisterModelResponse> register(MLRegisterModelInput mlInput) {
PlainActionFuture<MLRegisterModelResponse> actionFuture = PlainActionFuture.newFuture();
register(mlInput, actionFuture);
return actionFuture;
}

/**
* Register model
* For additional info on register, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#registering-a-model
* @param mlInput ML input
* @param listener a listener to be notified of the result
*/
void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener);

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* @param modelId the model id
*/
default ActionFuture<MLDeployModelResponse> deploy(String modelId) {
PlainActionFuture<MLDeployModelResponse> actionFuture = PlainActionFuture.newFuture();
deploy(modelId, actionFuture);
return actionFuture;
}

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* @param modelId the model id
* @param listener a listener to be notified of the result
*/
void deploy(String modelId, ActionListener<MLDeployModelResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,34 @@
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelInput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
Expand All @@ -30,6 +45,10 @@
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
Expand All @@ -40,9 +59,12 @@
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;

import java.io.IOException;
import java.time.Instant;
import java.util.Map;
import java.util.function.Function;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.input.Constants.ASYNC;
import static org.opensearch.ml.common.input.Constants.MODELID;
import static org.opensearch.ml.common.input.Constants.PREDICT;
Expand Down Expand Up @@ -190,6 +212,22 @@ public void searchTask(SearchRequest searchRequest, ActionListener<SearchRespons
}, listener::onFailure));
}

@Override
public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener) {
MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput);
client.execute(MLRegisterModelAction.INSTANCE, registerRequest, ActionListener.wrap(listener::onResponse, e -> {
listener.onFailure(e);
}));
}

@Override
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false);
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, ActionListener.wrap(listener::onResponse, e -> {
listener.onFailure(e);
}));
}

private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(predictionResponse -> {
listener.onResponse(predictionResponse.getOutput());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,19 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -58,6 +68,12 @@ public class MachineLearningClientTest {
@Mock
SearchResponse searchResponse;

@Mock
MLRegisterModelResponse registerModelResponse;

@Mock
MLDeployModelResponse deployModelResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
Expand Down Expand Up @@ -132,6 +148,16 @@ public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
listener.onResponse(searchResponse);
}

@Override
public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener) {
listener.onResponse(registerModelResponse);
}

@Override
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
listener.onResponse(deployModelResponse);
}
};
}

Expand Down Expand Up @@ -251,4 +277,31 @@ public void deleteTask() {
public void searchTask() {
assertEquals(searchResponse, machineLearningClient.searchTask(new SearchRequest()).actionGet());
}

@Test
public void register() {
MLModelConfig config = TextEmbeddingModelConfig.builder()
.modelType("testModelType")
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
.embeddingDimension(100)
.build();
MLRegisterModelInput mlInput = MLRegisterModelInput.builder()
.functionName(FunctionName.KMEANS)
.modelName("testModelName")
.version("testModelVersion")
.modelGroupId("modelGroupId")
.url("url")
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
.deployModel(true)
.modelNodeIds(new String[]{"modelNodeIds" })
.build();
assertEquals(registerModelResponse, machineLearningClient.register(mlInput).actionGet());
}

@Test
public void deploy() {
assertEquals(deployModelResponse, machineLearningClient.deploy("modelId").actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,20 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
Expand All @@ -46,6 +53,10 @@
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
Expand Down Expand Up @@ -109,6 +120,12 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<SearchResponse> searchTaskActionListener;

@Mock
ActionListener<MLRegisterModelResponse> RegisterModelActionListener;

@Mock
ActionListener<MLDeployModelResponse> DeployModelActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -554,6 +571,67 @@ public void searchTask() {
assertEquals(modelId, source.get(MLTask.MODEL_ID_FIELD));
}

@Test
public void register() {
String taskId = "taskId";
String status = MLTaskState.CREATED.name();
FunctionName functionName = FunctionName.KMEANS;
doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(2);
MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLRegisterModelAction.INSTANCE), any(), any());

ArgumentCaptor<MLRegisterModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class);
MLModelConfig config = TextEmbeddingModelConfig.builder()
.modelType("testModelType")
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
.embeddingDimension(100)
.build();
MLRegisterModelInput mlInput = MLRegisterModelInput.builder()
.functionName(functionName)
.modelName("testModelName")
.version("testModelVersion")
.modelGroupId("modelGroupId")
.url("url")
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
.deployModel(true)
.modelNodeIds(new String[]{"modelNodeIds" })
.build();
machineLearningNodeClient.register(mlInput, RegisterModelActionListener);

verify(client).execute(eq(MLRegisterModelAction.INSTANCE), isA(MLRegisterModelRequest.class), any());
verify(RegisterModelActionListener).onResponse(argumentCaptor.capture());
assertEquals(taskId, (argumentCaptor.getValue()).getTaskId());
assertEquals(status, (argumentCaptor.getValue()).getStatus());
}

@Test
public void deploy() {
String taskId = "taskId";
String status = MLTaskState.CREATED.name();
MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL;
String modelId = "modelId";
FunctionName functionName = FunctionName.KMEANS;
doAnswer(invocation -> {
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any());

ArgumentCaptor<MLDeployModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLDeployModelResponse.class);
machineLearningNodeClient.deploy(modelId, DeployModelActionListener);

verify(client).execute(eq(MLDeployModelAction.INSTANCE), isA(MLDeployModelRequest.class), any());
verify(DeployModelActionListener).onResponse(argumentCaptor.capture());
assertEquals(taskId, (argumentCaptor.getValue()).getTaskId());
assertEquals(status, (argumentCaptor.getValue()).getStatus());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down
Loading