Skip to content

Commit

Permalink
Merge branch 'opensearch-project:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
dtaivpp authored Oct 3, 2023
2 parents 278389f + a0c20e6 commit e87f0db
Show file tree
Hide file tree
Showing 66 changed files with 2,062 additions and 505 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Machine Learning Commons for OpenSearch is a new solution that make it easy to d
To date, building a new machine learning feature inside OpenSearch has been a significant challenge. The reasons include:

* **Disruption to OpenSearch Core features**. Machine learning is very computationally intensive, but currently there is no way to add dedicated computation resources in OpenSearch for machine learning jobs. Hence these jobs have to share the same resources with Core features such as indexing and searching. That might increase the latency of search requests and cause circuit breaker exceptions on memory usage. To address this, we have to carefully distribute models and limit the data size to run the AD job. As more and more ML features are added to OpenSearch, this will become much harder to manage.
* **Lack of support for machine learning algorithms.** Customers need more algorighms within Opensearch, otherwise the data need be exported to outside of elasticsearch, such as s3 first to do the job, which will bring extra cost and latency.
* **Lack of support for machine learning algorithms.** Customers need more algorithms within Opensearch, otherwise the data need be exported to outside of elasticsearch, such as s3 first to do the job, which will bring extra cost and latency.
* **Lack of resource management mechanism between multiple machine learning jobs.** It's hard to coordinate resources across multiple features.


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.Map;

Expand Down Expand Up @@ -226,4 +229,42 @@ default ActionFuture<SearchResponse> searchTask(SearchRequest searchRequest) {
* @param listener action listener
*/
void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener);

/**
 * Register model and hand back a future for the outcome.
 * For additional info on register, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#registering-a-model
 * <p>
 * Convenience overload: a new {@link PlainActionFuture} is created and passed as the
 * listener to {@link #register(MLRegisterModelInput, ActionListener)}.
 *
 * @param mlInput ML input
 * @return the future that was supplied as the listener to the two-argument overload
 */
default ActionFuture<MLRegisterModelResponse> register(MLRegisterModelInput mlInput) {
    final PlainActionFuture<MLRegisterModelResponse> registerFuture = PlainActionFuture.newFuture();
    this.register(mlInput, registerFuture);
    return registerFuture;
}

/**
* Register model, delivering the outcome to the supplied listener.
* For additional info on register, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#registering-a-model
* @param mlInput ML input
* @param listener a listener to be notified of the result (an {@code MLRegisterModelResponse})
*/
void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener);

/**
 * Deploy model and hand back a future for the outcome.
 * For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
 * <p>
 * Convenience overload: a new {@link PlainActionFuture} is created and passed as the
 * listener to {@link #deploy(String, ActionListener)}.
 *
 * @param modelId the model id
 * @return the future that was supplied as the listener to the two-argument overload
 */
default ActionFuture<MLDeployModelResponse> deploy(String modelId) {
    final PlainActionFuture<MLDeployModelResponse> deployFuture = PlainActionFuture.newFuture();
    this.deploy(modelId, deployFuture);
    return deployFuture;
}

/**
* Deploy model, delivering the outcome to the supplied listener.
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* @param modelId the model id
* @param listener a listener to be notified of the result (an {@code MLDeployModelResponse})
*/
void deploy(String modelId, ActionListener<MLDeployModelResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,34 @@
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelInput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
Expand All @@ -30,6 +45,10 @@
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
Expand All @@ -40,9 +59,12 @@
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;

import java.io.IOException;
import java.time.Instant;
import java.util.Map;
import java.util.function.Function;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.input.Constants.ASYNC;
import static org.opensearch.ml.common.input.Constants.MODELID;
import static org.opensearch.ml.common.input.Constants.PREDICT;
Expand Down Expand Up @@ -190,6 +212,22 @@ public void searchTask(SearchRequest searchRequest, ActionListener<SearchRespons
}, listener::onFailure));
}

/**
 * Wraps the input in an {@link MLRegisterModelRequest} and executes
 * {@link MLRegisterModelAction} on the node client.
 *
 * @param mlInput  ML input carried by the request
 * @param listener receives the register response, or the exception if the action fails
 */
@Override
public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener) {
    MLRegisterModelRequest request = new MLRegisterModelRequest(mlInput);
    // Forward both outcomes unchanged to the caller's listener.
    ActionListener<MLRegisterModelResponse> forwardingListener =
        ActionListener.wrap(listener::onResponse, listener::onFailure);
    client.execute(MLRegisterModelAction.INSTANCE, request, forwardingListener);
}

/**
 * Builds an {@link MLDeployModelRequest} for the model id and executes
 * {@link MLDeployModelAction} on the node client.
 *
 * @param modelId  the model id
 * @param listener receives the deploy response, or the exception if the action fails
 */
@Override
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
    // NOTE(review): the second constructor argument is always false here; its meaning
    // is defined by MLDeployModelRequest and is not visible in this file - confirm.
    MLDeployModelRequest request = new MLDeployModelRequest(modelId, false);
    // Forward both outcomes unchanged to the caller's listener.
    ActionListener<MLDeployModelResponse> forwardingListener =
        ActionListener.wrap(listener::onResponse, listener::onFailure);
    client.execute(MLDeployModelAction.INSTANCE, request, forwardingListener);
}

private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(predictionResponse -> {
listener.onResponse(predictionResponse.getOutput());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,19 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -58,6 +68,12 @@ public class MachineLearningClientTest {
@Mock
SearchResponse searchResponse;

@Mock
MLRegisterModelResponse registerModelResponse;

@Mock
MLDeployModelResponse deployModelResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
Expand Down Expand Up @@ -132,6 +148,16 @@ public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
// Test stub: answer immediately with the mocked searchResponse; the request is ignored.
listener.onResponse(searchResponse);
}

@Override
public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener) {
// Test stub: answer immediately with the mocked registerModelResponse; the input is ignored.
listener.onResponse(registerModelResponse);
}

@Override
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
// Test stub: answer immediately with the mocked deployModelResponse; the model id is ignored.
listener.onResponse(deployModelResponse);
}
};
}

Expand Down Expand Up @@ -251,4 +277,31 @@ public void deleteTask() {
public void searchTask() {
// The future-returning overload should yield the same mocked response that the
// stubbed listener-based searchTask passes to its listener.
assertEquals(searchResponse, machineLearningClient.searchTask(new SearchRequest()).actionGet());
}

/**
 * Checks that the future-returning {@code register} overload yields the response
 * that the stubbed listener-based overload passes to its listener.
 */
@Test
public void register() {
    MLModelConfig modelConfig = TextEmbeddingModelConfig.builder()
        .modelType("testModelType")
        .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
        .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
        .embeddingDimension(100)
        .build();
    String[] nodeIds = new String[]{"modelNodeIds"};
    MLRegisterModelInput registerInput = MLRegisterModelInput.builder()
        .functionName(FunctionName.KMEANS)
        .modelName("testModelName")
        .version("testModelVersion")
        .modelGroupId("modelGroupId")
        .url("url")
        .modelFormat(MLModelFormat.ONNX)
        .modelConfig(modelConfig)
        .deployModel(true)
        .modelNodeIds(nodeIds)
        .build();

    MLRegisterModelResponse actualResponse = machineLearningClient.register(registerInput).actionGet();

    assertEquals(registerModelResponse, actualResponse);
}

/**
 * Checks that the future-returning {@code deploy} overload yields the response
 * that the stubbed listener-based overload passes to its listener.
 */
@Test
public void deploy() {
    MLDeployModelResponse actualResponse = machineLearningClient.deploy("modelId").actionGet();

    assertEquals(deployModelResponse, actualResponse);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,20 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
Expand All @@ -46,6 +53,10 @@
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
Expand Down Expand Up @@ -109,6 +120,12 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<SearchResponse> searchTaskActionListener;

// Mock listeners passed to register(...) and deploy(...) in the tests of this class.
// NOTE(review): these two field names start with an upper-case letter, unlike the
// sibling mock searchTaskActionListener; renaming them to lowerCamelCase requires
// updating the register() and deploy() tests that reference them.
@Mock
ActionListener<MLRegisterModelResponse> RegisterModelActionListener;

@Mock
ActionListener<MLDeployModelResponse> DeployModelActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -554,6 +571,67 @@ public void searchTask() {
assertEquals(modelId, source.get(MLTask.MODEL_ID_FIELD));
}

/**
 * Checks that {@code register} executes {@link MLRegisterModelAction} with an
 * {@link MLRegisterModelRequest} and passes the client's response on to the
 * caller's listener.
 */
@Test
public void register() {
    final String expectedTaskId = "taskId";
    final String expectedStatus = MLTaskState.CREATED.name();

    MLModelConfig modelConfig = TextEmbeddingModelConfig.builder()
        .modelType("testModelType")
        .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
        .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
        .embeddingDimension(100)
        .build();
    MLRegisterModelInput registerInput = MLRegisterModelInput.builder()
        .functionName(FunctionName.KMEANS)
        .modelName("testModelName")
        .version("testModelVersion")
        .modelGroupId("modelGroupId")
        .url("url")
        .modelFormat(MLModelFormat.ONNX)
        .modelConfig(modelConfig)
        .deployModel(true)
        .modelNodeIds(new String[]{"modelNodeIds"})
        .build();

    // Stub the client so the listener (third argument of execute) is answered
    // with a response built from the expected task id and status.
    doAnswer(invocation -> {
        ActionListener<MLRegisterModelResponse> responseListener = invocation.getArgument(2);
        responseListener.onResponse(new MLRegisterModelResponse(expectedTaskId, expectedStatus));
        return null;
    }).when(client).execute(eq(MLRegisterModelAction.INSTANCE), any(), any());

    machineLearningNodeClient.register(registerInput, RegisterModelActionListener);

    verify(client).execute(eq(MLRegisterModelAction.INSTANCE), isA(MLRegisterModelRequest.class), any());
    ArgumentCaptor<MLRegisterModelResponse> responseCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class);
    verify(RegisterModelActionListener).onResponse(responseCaptor.capture());
    MLRegisterModelResponse delivered = responseCaptor.getValue();
    assertEquals(expectedTaskId, delivered.getTaskId());
    assertEquals(expectedStatus, delivered.getStatus());
}

/**
 * Checks that {@code deploy} executes {@link MLDeployModelAction} with an
 * {@link MLDeployModelRequest} and passes the client's response on to the
 * caller's listener.
 */
@Test
public void deploy() {
    String taskId = "taskId";
    String status = MLTaskState.CREATED.name();
    MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL;
    String modelId = "modelId";
    // Stub the client so the listener (third argument of execute) is answered
    // with a response built from the expected task id, task type and status.
    doAnswer(invocation -> {
        ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
        MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status);
        actionListener.onResponse(output);
        return null;
    }).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any());

    ArgumentCaptor<MLDeployModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLDeployModelResponse.class);
    machineLearningNodeClient.deploy(modelId, DeployModelActionListener);

    verify(client).execute(eq(MLDeployModelAction.INSTANCE), isA(MLDeployModelRequest.class), any());
    verify(DeployModelActionListener).onResponse(argumentCaptor.capture());
    assertEquals(taskId, (argumentCaptor.getValue()).getTaskId());
    assertEquals(status, (argumentCaptor.getValue()).getStatus());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class CommonValue {
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2;
public static final String ML_MAP_RESPONSE_KEY = "response";
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
+ "\": {\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public enum FunctionName {
RCF_SUMMARIZE,
LOGISTIC_REGRESSION,
TEXT_EMBEDDING,
SPARSE_ENCODING,
SPARSE_TOKENIZE,
METRICS_CORRELATION,
REMOTE;

Expand All @@ -33,7 +35,7 @@ public static FunctionName from(String value) {
* @return true for deep learning model.
*/
public static boolean isDLModel(FunctionName functionName) {
if (functionName == TEXT_EMBEDDING) {
if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == SPARSE_TOKENIZE) {
return true;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Map;
import java.util.Optional;

import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

@Getter
Expand Down Expand Up @@ -101,7 +102,7 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
return;
}
if (response instanceof String && isJson((String)response)) {
Map<String, Object> data = StringUtils.fromJson((String) response, "response");
Map<String, Object> data = StringUtils.fromJson((String) response, ML_MAP_RESPONSE_KEY);
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
} else {
Map<String, Object> map = new HashMap<>();
Expand Down
Loading

0 comments on commit e87f0db

Please sign in to comment.