Skip to content

Commit

Permalink
Model access control dev rebase (#928)
Browse files Browse the repository at this point in the history
(cherry picked from commit 6e1541a)
  • Loading branch information
b4sjoo authored and github-actions[bot] committed May 31, 2023
1 parent a1e55c8 commit aadb649
Show file tree
Hide file tree
Showing 115 changed files with 7,540 additions and 686 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;

import java.util.Map;

Expand Down Expand Up @@ -200,13 +200,15 @@ default ActionFuture<SearchResponse> searchModel(SearchRequest searchRequest) {
return actionFuture;
}


/**
* For more info on search model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-model
* @param searchRequest searchRequest to search the ML Model
* @param listener action listener
*/
void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener);


/**
* For more info on search task, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-task
* @param searchRequest searchRequest to search the ML Task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,27 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.task.*;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
Expand Down Expand Up @@ -154,6 +160,7 @@ public void searchModel(SearchRequest searchRequest, ActionListener<SearchRespon
}, listener::onFailure));
}


@Override
public void getTask(String taskId, ActionListener<MLTask> listener) {
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder()
Expand Down
57 changes: 56 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ public class CommonValue {
public static String HOT_BOX_TYPE = "hot";
// warm node
public static String WARM_BOX_TYPE = "warm";

public static final String ML_MODEL_GROUP_INDEX = ".plugins-ml-model-group";
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 1;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 4;
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
Expand All @@ -43,6 +44,57 @@ public class CommonValue {
+ " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n"
+ " }\n"
+ " }\n";
public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" +
" \"_meta\": {\n" +
" \"schema_version\": "+ML_MODEL_GROUP_INDEX_SCHEMA_VERSION+"\n" +
" },\n" +
" \"properties\": {\n" +
" \""+MLModelGroup.MODEL_GROUP_NAME_FIELD+"\": {\n" +
" \"type\": \"text\",\n" +
" \"fields\": {\n" +
" \"keyword\": {\n" +
" \"type\": \"keyword\",\n" +
" \"ignore_above\": 256\n" +
" }\n" +
" }\n" +
" },\n" +
" \""+MLModelGroup.DESCRIPTION_FIELD+"\": {\n" +
" \"type\": \"text\"\n" +
" },\n" +
" \""+MLModelGroup.LATEST_VERSION_FIELD+"\": {\n" +
" \"type\": \"integer\"\n" +
" },\n" +
" \""+MLModelGroup.MODEL_GROUP_ID_FIELD+"\": {\n" +
" \"type\": \"keyword\"\n" +
" },\n" +
" \""+MLModelGroup.BACKEND_ROLES_FIELD+"\": {\n" +
" \"type\": \"text\",\n" +
" \"fields\": {\n" +
" \"keyword\": {\n" +
" \"type\": \"keyword\",\n" +
" \"ignore_above\": 256\n" +
" }\n" +
" }\n" +
" },\n" +
" \""+MLModelGroup.ACCESS+"\": {\n" +
" \"type\": \"keyword\"\n" +
" },\n" +
" \""+MLModelGroup.OWNER+"\": {\n" +
" \"type\": \"nested\",\n" +
" \"properties\": {\n" +
" \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" +
" \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" +
" \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" +
" \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" +
" }\n" +
" },\n" +
" \""+MLModelGroup.CREATED_TIME_FIELD+"\": {\n" +
" \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" +
" \""+MLModelGroup.LAST_UPDATED_TIME_FIELD+"\": {\n" +
" \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" +
" }\n" +
"}";

public static final String ML_MODEL_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\"schema_version\": "
+ ML_MODEL_INDEX_SCHEMA_VERSION
Expand All @@ -61,6 +113,9 @@ public class CommonValue {
+ MLModel.MODEL_VERSION_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.MODEL_GROUP_ID_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.MODEL_CONTENT_FIELD
+ "\" : {\"type\": \"binary\"},\n"
+ " \""
Expand Down
14 changes: 14 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
public class MLModel implements ToXContentObject {
public static final String ALGORITHM_FIELD = "algorithm";
public static final String MODEL_NAME_FIELD = "name";
public static final String MODEL_GROUP_ID_FIELD = "model_group_id";
// We use int type for version in first release 1.3. In 2.4, we changed to
// use String type for version. Keep this old version field for old models.
public static final String OLD_MODEL_VERSION_FIELD = "version";
Expand Down Expand Up @@ -71,6 +72,7 @@ public class MLModel implements ToXContentObject {
public static final String DEPLOY_TO_ALL_NODES_FIELD = "deploy_to_all_nodes";

private String name;
private String modelGroupId;
private FunctionName algorithm;
private String version;
private String content;
Expand Down Expand Up @@ -102,6 +104,7 @@ public class MLModel implements ToXContentObject {
private boolean deployToAllNodes;
@Builder(toBuilder = true)
public MLModel(String name,
String modelGroupId,
FunctionName algorithm,
String version,
String content,
Expand All @@ -125,6 +128,7 @@ public MLModel(String name,
String[] planningWorkerNodes,
boolean deployToAllNodes) {
this.name = name;
this.modelGroupId = modelGroupId;
this.algorithm = algorithm;
this.version = version;
this.content = content;
Expand Down Expand Up @@ -186,6 +190,7 @@ public MLModel(StreamInput input) throws IOException{
currentWorkerNodeCount = input.readOptionalInt();
planningWorkerNodes = input.readOptionalStringArray();
deployToAllNodes = input.readBoolean();
modelGroupId = input.readOptionalString();
}
}

Expand Down Expand Up @@ -234,6 +239,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(currentWorkerNodeCount);
out.writeOptionalStringArray(planningWorkerNodes);
out.writeBoolean(deployToAllNodes);
out.writeOptionalString(modelGroupId);
}

@Override
Expand All @@ -242,6 +248,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (name != null) {
builder.field(MODEL_NAME_FIELD, name);
}
if (modelGroupId != null) {
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
}
if (algorithm != null) {
builder.field(ALGORITHM_FIELD, algorithm);
}
Expand Down Expand Up @@ -317,6 +326,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par

public static MLModel parse(XContentParser parser, String algorithmName) throws IOException {
String name = null;
String modelGroupId = null;
FunctionName algorithm = null;
String version = null;
Integer oldVersion = null;
Expand Down Expand Up @@ -356,6 +366,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case MODEL_NAME_FIELD:
name = parser.text();
break;
case MODEL_GROUP_ID_FIELD:
modelGroupId = parser.text();
break;
case MODEL_CONTENT_FIELD:
content = parser.text();
break;
Expand Down Expand Up @@ -454,6 +467,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
}
return MLModel.builder()
.name(name)
.modelGroupId(modelGroupId)
.algorithm(algorithm)
.version(version == null ? oldVersion + "" : version)
.content(content == null ? oldContent : content)
Expand Down
Loading

0 comments on commit aadb649

Please sign in to comment.