Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model access control dev rebase #928

Merged
merged 20 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
3beba53
model access control changes rebased to 2.x (#902)
rbhavna May 23, 2023
8c9285e
Model access control dev1 (#905)
zane-neo May 23, 2023
09b759e
fix undeploy model API (#906)
ylwu-amzn May 23, 2023
466b9bf
Fix some bugs for model access control (#907)
ylwu-amzn May 24, 2023
ca2737c
get 10k model group ids when search model (#908)
ylwu-amzn May 24, 2023
b3fb6f3
Fix no corresponding backend roles user can update model issue (#909)
zane-neo May 24, 2023
e24e754
1. revert model content hash change; 2.fix task search with permissio…
zane-neo May 24, 2023
d296fff
remove tags and add unit tests for register model group (#912)
rbhavna May 24, 2023
5ebb436
add access control to register version via local file and minor fixes…
rbhavna May 24, 2023
d5b7d66
Change model_access_control to optional when register model (#914)
zane-neo May 25, 2023
829b88a
Add ITs and UTs and refactor model group search code and fixed minor …
zane-neo May 25, 2023
9a8afca
fix metric correlation (#915)
ylwu-amzn May 25, 2023
45631a3
fix model meta API and add update model group UTs (#918)
rbhavna May 25, 2023
4ea9cff
Fix model search IT bugs (#920)
zane-neo May 26, 2023
1cbc374
Add undeploy models UTs (#926)
zane-neo May 30, 2023
a6ec9b9
Unit tests for model access control new classes (#927)
rbhavna May 30, 2023
4cbd86f
bump 2.8 to unblock 2.8 release (#896)
jngz-es May 23, 2023
6210255
spotlessApply styling
b4sjoo May 30, 2023
2115f54
Add Search Model Group Rest Action unit tests and minor fixes (#929)
rbhavna May 30, 2023
df49851
Add unit tests for new classes
b4sjoo May 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;

import java.util.Map;

Expand Down Expand Up @@ -200,13 +200,15 @@ default ActionFuture<SearchResponse> searchModel(SearchRequest searchRequest) {
return actionFuture;
}


/**
* For more info on search model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-model
* @param searchRequest searchRequest to search the ML Model
* @param listener action listener
*/
void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener);


/**
* For more info on search task, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-task
* @param searchRequest searchRequest to search the ML Task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,27 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.task.*;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
Expand Down Expand Up @@ -154,6 +160,7 @@ public void searchModel(SearchRequest searchRequest, ActionListener<SearchRespon
}, listener::onFailure));
}


@Override
public void getTask(String taskId, ActionListener<MLTask> listener) {
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder()
Expand Down
57 changes: 56 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ public class CommonValue {
public static String HOT_BOX_TYPE = "hot";
// warm node
public static String WARM_BOX_TYPE = "warm";

public static final String ML_MODEL_GROUP_INDEX = ".plugins-ml-model-group";
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 1;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 4;
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
Expand All @@ -43,6 +44,57 @@ public class CommonValue {
+ " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n"
+ " }\n"
+ " }\n";
public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" +
" \"_meta\": {\n" +
" \"schema_version\": "+ML_MODEL_GROUP_INDEX_SCHEMA_VERSION+"\n" +
" },\n" +
" \"properties\": {\n" +
" \""+MLModelGroup.MODEL_GROUP_NAME_FIELD+"\": {\n" +
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed backend_roles field

"backend_roles": {
					"type": "text",
					"fields": {
						"keyword": {
							"type": "keyword",
							"ignore_above": 256
						}
					}
				},

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the catch. I forgot adding it, I guess

" \"type\": \"text\",\n" +
" \"fields\": {\n" +
" \"keyword\": {\n" +
" \"type\": \"keyword\",\n" +
" \"ignore_above\": 256\n" +
" }\n" +
" }\n" +
" },\n" +
" \""+MLModelGroup.DESCRIPTION_FIELD+"\": {\n" +
" \"type\": \"text\"\n" +
" },\n" +
" \""+MLModelGroup.LATEST_VERSION_FIELD+"\": {\n" +
" \"type\": \"integer\"\n" +
" },\n" +
" \""+MLModelGroup.MODEL_GROUP_ID_FIELD+"\": {\n" +
" \"type\": \"keyword\"\n" +
" },\n" +
" \""+MLModelGroup.BACKEND_ROLES_FIELD+"\": {\n" +
" \"type\": \"text\",\n" +
" \"fields\": {\n" +
" \"keyword\": {\n" +
" \"type\": \"keyword\",\n" +
" \"ignore_above\": 256\n" +
" }\n" +
" }\n" +
" },\n" +
" \""+MLModelGroup.ACCESS+"\": {\n" +
" \"type\": \"keyword\"\n" +
" },\n" +
" \""+MLModelGroup.OWNER+"\": {\n" +
" \"type\": \"nested\",\n" +
" \"properties\": {\n" +
" \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" +
" \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" +
" \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" +
" \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" +
" }\n" +
" },\n" +
" \""+MLModelGroup.CREATED_TIME_FIELD+"\": {\n" +
" \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" +
" \""+MLModelGroup.LAST_UPDATED_TIME_FIELD+"\": {\n" +
" \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" +
" }\n" +
"}";

public static final String ML_MODEL_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\"schema_version\": "
+ ML_MODEL_INDEX_SCHEMA_VERSION
Expand All @@ -61,6 +113,9 @@ public class CommonValue {
+ MLModel.MODEL_VERSION_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.MODEL_GROUP_ID_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.MODEL_CONTENT_FIELD
+ "\" : {\"type\": \"binary\"},\n"
+ " \""
Expand Down
14 changes: 14 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
public class MLModel implements ToXContentObject {
public static final String ALGORITHM_FIELD = "algorithm";
public static final String MODEL_NAME_FIELD = "name";
public static final String MODEL_GROUP_ID_FIELD = "model_group_id";
// We use int type for version in first release 1.3. In 2.4, we changed to
// use String type for version. Keep this old version field for old models.
public static final String OLD_MODEL_VERSION_FIELD = "version";
Expand Down Expand Up @@ -71,6 +72,7 @@ public class MLModel implements ToXContentObject {
public static final String DEPLOY_TO_ALL_NODES_FIELD = "deploy_to_all_nodes";

private String name;
private String modelGroupId;
private FunctionName algorithm;
private String version;
private String content;
Expand Down Expand Up @@ -102,6 +104,7 @@ public class MLModel implements ToXContentObject {
private boolean deployToAllNodes;
@Builder(toBuilder = true)
public MLModel(String name,
String modelGroupId,
FunctionName algorithm,
String version,
String content,
Expand All @@ -125,6 +128,7 @@ public MLModel(String name,
String[] planningWorkerNodes,
boolean deployToAllNodes) {
this.name = name;
this.modelGroupId = modelGroupId;
this.algorithm = algorithm;
this.version = version;
this.content = content;
Expand Down Expand Up @@ -186,6 +190,7 @@ public MLModel(StreamInput input) throws IOException{
currentWorkerNodeCount = input.readOptionalInt();
planningWorkerNodes = input.readOptionalStringArray();
deployToAllNodes = input.readBoolean();
modelGroupId = input.readOptionalString();
}
}

Expand Down Expand Up @@ -234,6 +239,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(currentWorkerNodeCount);
out.writeOptionalStringArray(planningWorkerNodes);
out.writeBoolean(deployToAllNodes);
out.writeOptionalString(modelGroupId);
}

@Override
Expand All @@ -242,6 +248,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (name != null) {
builder.field(MODEL_NAME_FIELD, name);
}
if (modelGroupId != null) {
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
}
if (algorithm != null) {
builder.field(ALGORITHM_FIELD, algorithm);
}
Expand Down Expand Up @@ -317,6 +326,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par

public static MLModel parse(XContentParser parser, String algorithmName) throws IOException {
String name = null;
String modelGroupId = null;
FunctionName algorithm = null;
String version = null;
Integer oldVersion = null;
Expand Down Expand Up @@ -356,6 +366,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case MODEL_NAME_FIELD:
name = parser.text();
break;
case MODEL_GROUP_ID_FIELD:
modelGroupId = parser.text();
break;
case MODEL_CONTENT_FIELD:
content = parser.text();
break;
Expand Down Expand Up @@ -454,6 +467,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
}
return MLModel.builder()
.name(name)
.modelGroupId(modelGroupId)
.algorithm(algorithm)
.version(version == null ? oldVersion + "" : version)
.content(content == null ? oldContent : content)
Expand Down
Loading