Skip to content

Commit

Permalink
model access control changes rebased to 2.x (#902)
Browse files Browse the repository at this point in the history
* Backport changes to model-access-control feature branch (#837)

* rename model meta/chunk API (#827)

Signed-off-by: Yaliang Wu <[email protected]>

* Change the ziputil dependency to fix a potential security concern (#824)

Signed-off-by: Yaliang Wu <[email protected]>

* add text docs ML input (#830)

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
Co-authored-by: Sicheng Song <[email protected]>

* add model group (#840)

* add model group

Signed-off-by: Yaliang Wu <[email protected]>

* rename create model group as register model group

Signed-off-by: Yaliang Wu <[email protected]>

* fix class name in build.gradle

Signed-off-by: Yaliang Wu <[email protected]>

* remove unused code; stash thread context

Signed-off-by: Yaliang Wu <[email protected]>

* exclude class for low coverage

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>

* access validation for register/update model-group, register/get/delete/deploy/predict model

Signed-off-by: Bhavana Ramaram <[email protected]>

* changes to security utils and register/search model group

Signed-off-by: Bhavana Ramaram <[email protected]>

* update last_updated_time in model group when new version added

Signed-off-by: Bhavana Ramaram <[email protected]>

* fix undeploy model action

Signed-off-by: Bhavana Ramaram <[email protected]>

* fix format violations

Signed-off-by: Bhavana Ramaram <[email protected]>

* rebased to 2.x and prediction/search API fix

Signed-off-by: Bhavana Ramaram <[email protected]>

* fix rebase conflicts

Signed-off-by: Bhavana Ramaram <[email protected]>

* new delete model group API

Signed-off-by: Bhavana Ramaram <[email protected]>

* fix formal violations

Signed-off-by: Bhavana Ramaram <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Bhavana Ramaram <[email protected]>
Co-authored-by: Yaliang Wu <[email protected]>
Co-authored-by: Sicheng Song <[email protected]>
  • Loading branch information
3 people authored May 23, 2023
1 parent 134e065 commit e49227b
Show file tree
Hide file tree
Showing 65 changed files with 3,070 additions and 330 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;

import java.util.Map;

Expand Down Expand Up @@ -200,13 +200,15 @@ default ActionFuture<SearchResponse> searchModel(SearchRequest searchRequest) {
return actionFuture;
}


/**
* For more info on search model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-model
* @param searchRequest searchRequest to search the ML Model
* @param listener action listener
*/
void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener);


/**
* For more info on search task, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-task
* @param searchRequest searchRequest to search the ML Task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,27 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.task.*;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
Expand Down Expand Up @@ -154,6 +160,7 @@ public void searchModel(SearchRequest searchRequest, ActionListener<SearchRespon
}, listener::onFailure));
}


@Override
public void getTask(String taskId, ActionListener<MLTask> listener) {
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder()
Expand Down
48 changes: 47 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ public class CommonValue {
public static String HOT_BOX_TYPE = "hot";
// warm node
public static String WARM_BOX_TYPE = "warm";

public static final String ML_MODEL_GROUP_INDEX = ".plugins-ml-model-group";
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 1;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 4;
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
Expand All @@ -43,6 +44,48 @@ public class CommonValue {
+ " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n"
+ " }\n"
+ " }\n";
public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" +
" \"_meta\": {\n" +
" \"schema_version\": "+ML_MODEL_GROUP_INDEX_SCHEMA_VERSION+"\n" +
" },\n" +
" \"properties\": {\n" +
" \""+MLModelGroup.MODEL_GROUP_NAME_FIELD+"\": {\n" +
" \"type\": \"text\",\n" +
" \"fields\": {\n" +
" \"keyword\": {\n" +
" \"type\": \"keyword\",\n" +
" \"ignore_above\": 256\n" +
" }\n" +
" }\n" +
" },\n" +
" \""+MLModelGroup.DESCRIPTION_FIELD+"\": {\n" +
" \"type\": \"text\"\n" +
" },\n" +
" \""+MLModelGroup.LATEST_VERSION_FIELD+"\": {\n" +
" \"type\": \"integer\"\n" +
" },\n" +
" \""+MLModelGroup.MODEL_GROUP_ID_FIELD+"\": {\n" +
" \"type\": \"keyword\"\n" +
" },\n" +
" \""+MLModelGroup.ACCESS+"\": {\n" +
" \"type\": \"keyword\"\n" +
" },\n" +
" \""+MLModelGroup.OWNER+"\": {\n" +
" \"type\": \"nested\",\n" +
" \"properties\": {\n" +
" \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" +
" \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" +
" \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" +
" \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" +
" }\n" +
" },\n" +
" \""+MLModelGroup.CREATED_TIME_FIELD+"\": {\n" +
" \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" +
" \""+MLModelGroup.LAST_UPDATED_TIME_FIELD+"\": {\n" +
" \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" +
" }\n" +
"}";

public static final String ML_MODEL_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\"schema_version\": "
+ ML_MODEL_INDEX_SCHEMA_VERSION
Expand All @@ -61,6 +104,9 @@ public class CommonValue {
+ MLModel.MODEL_VERSION_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.MODEL_GROUP_ID_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.MODEL_CONTENT_FIELD
+ "\" : {\"type\": \"binary\"},\n"
+ " \""
Expand Down
14 changes: 14 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
public class MLModel implements ToXContentObject {
public static final String ALGORITHM_FIELD = "algorithm";
public static final String MODEL_NAME_FIELD = "name";
public static final String MODEL_GROUP_ID_FIELD = "model_group_id";
// We use int type for version in first release 1.3. In 2.4, we changed to
// use String type for version. Keep this old version field for old models.
public static final String OLD_MODEL_VERSION_FIELD = "version";
Expand Down Expand Up @@ -71,6 +72,7 @@ public class MLModel implements ToXContentObject {
public static final String DEPLOY_TO_ALL_NODES_FIELD = "deploy_to_all_nodes";

private String name;
private String modelGroupId;
private FunctionName algorithm;
private String version;
private String content;
Expand Down Expand Up @@ -102,6 +104,7 @@ public class MLModel implements ToXContentObject {
private boolean deployToAllNodes;
@Builder(toBuilder = true)
public MLModel(String name,
String modelGroupId,
FunctionName algorithm,
String version,
String content,
Expand All @@ -125,6 +128,7 @@ public MLModel(String name,
String[] planningWorkerNodes,
boolean deployToAllNodes) {
this.name = name;
this.modelGroupId = modelGroupId;
this.algorithm = algorithm;
this.version = version;
this.content = content;
Expand Down Expand Up @@ -186,6 +190,7 @@ public MLModel(StreamInput input) throws IOException{
currentWorkerNodeCount = input.readOptionalInt();
planningWorkerNodes = input.readOptionalStringArray();
deployToAllNodes = input.readBoolean();
modelGroupId = input.readOptionalString();
}
}

Expand Down Expand Up @@ -234,6 +239,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(currentWorkerNodeCount);
out.writeOptionalStringArray(planningWorkerNodes);
out.writeBoolean(deployToAllNodes);
out.writeOptionalString(modelGroupId);
}

@Override
Expand All @@ -242,6 +248,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (name != null) {
builder.field(MODEL_NAME_FIELD, name);
}
if (modelGroupId != null) {
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
}
if (algorithm != null) {
builder.field(ALGORITHM_FIELD, algorithm);
}
Expand Down Expand Up @@ -317,6 +326,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par

public static MLModel parse(XContentParser parser, String algorithmName) throws IOException {
String name = null;
String modelGroupId = null;
FunctionName algorithm = null;
String version = null;
Integer oldVersion = null;
Expand Down Expand Up @@ -356,6 +366,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case MODEL_NAME_FIELD:
name = parser.text();
break;
case MODEL_GROUP_ID_FIELD:
modelGroupId = parser.text();
break;
case MODEL_CONTENT_FIELD:
content = parser.text();
break;
Expand Down Expand Up @@ -454,6 +467,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
}
return MLModel.builder()
.name(name)
.modelGroupId(modelGroupId)
.algorithm(algorithm)
.version(version == null ? oldVersion + "" : version)
.content(content == null ? oldContent : content)
Expand Down
Loading

0 comments on commit e49227b

Please sign in to comment.