Skip to content

Commit

Permalink
hidden model implementation (#1736)
Browse files Browse the repository at this point in the history
* hidden model implementation

Signed-off-by: Dhrubo Saha <[email protected]>

* adding more unit tests

* adding more unit tests

Signed-off-by: Dhrubo Saha <[email protected]>

* fixing spotless

Signed-off-by: Dhrubo Saha <[email protected]>

* adding more unit test in MLModelManager

Signed-off-by: Dhrubo Saha <[email protected]>

* addressing comments

Signed-off-by: Dhrubo Saha <[email protected]>

---------

Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os authored Dec 7, 2023
1 parent 14a3882 commit d71c77f
Show file tree
Hide file tree
Showing 31 changed files with 1,164 additions and 270 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class CommonValue {
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 7;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 8;
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
Expand Down Expand Up @@ -186,6 +186,9 @@ public class CommonValue {
+ MLModel.DEPLOY_TO_ALL_NODES_FIELD
+ "\": {\"type\": \"boolean\"},\n"
+ " \""
+ MLModel.IS_HIDDEN_FIELD
+ "\": {\"type\": \"boolean\"},\n"
+ " \""
+ MLModel.MODEL_CONFIG_FIELD
+ "\" : {\"properties\":{\""
+ MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
Expand Down
17 changes: 17 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ public class MLModel implements ToXContentObject {
public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count";
public static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes";
public static final String DEPLOY_TO_ALL_NODES_FIELD = "deploy_to_all_nodes";

public static final String IS_HIDDEN_FIELD = "is_hidden";
public static final String CONNECTOR_FIELD = "connector";
public static final String CONNECTOR_ID_FIELD = "connector_id";

Expand Down Expand Up @@ -110,6 +112,9 @@ public class MLModel implements ToXContentObject {
private String[] planningWorkerNodes; // plan to deploy model to these nodes
private boolean deployToAllNodes;

//is domain manager creates any special hidden model in the cluster this status will be true. Otherwise,
// False by default
private Boolean isHidden;
@Setter
private Connector connector;
private String connectorId;
Expand Down Expand Up @@ -139,6 +144,7 @@ public MLModel(String name,
Integer currentWorkerNodeCount,
String[] planningWorkerNodes,
boolean deployToAllNodes,
Boolean isHidden,
Connector connector,
String connectorId) {
this.name = name;
Expand Down Expand Up @@ -166,6 +172,7 @@ public MLModel(String name,
this.currentWorkerNodeCount = currentWorkerNodeCount;
this.planningWorkerNodes = planningWorkerNodes;
this.deployToAllNodes = deployToAllNodes;
this.isHidden = isHidden;
this.connector = connector;
this.connectorId = connectorId;
}
Expand Down Expand Up @@ -210,6 +217,7 @@ public MLModel(StreamInput input) throws IOException{
currentWorkerNodeCount = input.readOptionalInt();
planningWorkerNodes = input.readOptionalStringArray();
deployToAllNodes = input.readBoolean();
isHidden = input.readOptionalBoolean();
modelGroupId = input.readOptionalString();
if (input.readBoolean()) {
connector = Connector.fromStream(input);
Expand Down Expand Up @@ -263,6 +271,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(currentWorkerNodeCount);
out.writeOptionalStringArray(planningWorkerNodes);
out.writeBoolean(deployToAllNodes);
out.writeOptionalBoolean(isHidden);
out.writeOptionalString(modelGroupId);
if (connector != null) {
out.writeBoolean(true);
Expand Down Expand Up @@ -351,6 +360,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (deployToAllNodes) {
builder.field(DEPLOY_TO_ALL_NODES_FIELD, deployToAllNodes);
}
if (isHidden != null) {
builder.field(MLModel.IS_HIDDEN_FIELD, isHidden);
}
if (connector != null) {
builder.field(CONNECTOR_FIELD, connector);
}
Expand Down Expand Up @@ -393,6 +405,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
Integer currentWorkerNodeCount = null;
List<String> planningWorkerNodes = new ArrayList<>();
boolean deployToAllNodes = false;
boolean isHidden = false;
Connector connector = null;
String connectorId = null;

Expand Down Expand Up @@ -476,6 +489,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case DEPLOY_TO_ALL_NODES_FIELD:
deployToAllNodes = parser.booleanValue();
break;
case IS_HIDDEN_FIELD:
isHidden = parser.booleanValue();
break;
case CONNECTOR_FIELD:
connector = createConnector(parser);
break;
Expand Down Expand Up @@ -537,6 +553,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.currentWorkerNodeCount(currentWorkerNodeCount)
.planningWorkerNodes(planningWorkerNodes.toArray(new String[0]))
.deployToAllNodes(deployToAllNodes)
.isHidden(isHidden)
.connector(connector)
.connectorId(connectorId)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,31 @@ public class MLModelGetRequest extends ActionRequest {

String modelId;
boolean returnContent;
// This is to identify if the get request is initiated by user or not. Sometimes during
// delete/update options, we also perform get operation. This field is to distinguish between
// these two situations.
boolean isUserInitiatedGetRequest;

@Builder
public MLModelGetRequest(String modelId, boolean returnContent) {
public MLModelGetRequest(String modelId, boolean returnContent, boolean isUserInitiatedGetRequest) {
this.modelId = modelId;
this.returnContent = returnContent;
this.isUserInitiatedGetRequest = isUserInitiatedGetRequest;
}

public MLModelGetRequest(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.returnContent = in.readBoolean();
this.isUserInitiatedGetRequest = in.readBoolean();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.modelId);
out.writeBoolean(returnContent);
out.writeBoolean(isUserInitiatedGetRequest);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
Expand All @@ -25,6 +26,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

Expand All @@ -42,7 +44,6 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
public static final String DESCRIPTION_FIELD = "description";
public static final String VERSION_FIELD = "version";
public static final String URL_FIELD = "url";
public static final String HASH_VALUE_FIELD = "model_content_hash_value";
public static final String MODEL_FORMAT_FIELD = "model_format";
public static final String MODEL_CONFIG_FIELD = "model_config";
public static final String DEPLOY_MODEL_FIELD = "deploy_model";
Expand Down Expand Up @@ -75,6 +76,8 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
private AccessMode accessMode;
private Boolean doesVersionCreateModelGroup;

private Boolean isHidden;

@Builder(toBuilder = true)
public MLRegisterModelInput(FunctionName functionName,
String modelName,
Expand All @@ -92,13 +95,10 @@ public MLRegisterModelInput(FunctionName functionName,
List<String> backendRoles,
Boolean addAllBackendRoles,
AccessMode accessMode,
Boolean doesVersionCreateModelGroup
Boolean doesVersionCreateModelGroup,
Boolean isHidden
) {
if (functionName == null) {
this.functionName = FunctionName.TEXT_EMBEDDING;
} else {
this.functionName = functionName;
}
this.functionName = Objects.requireNonNullElse(functionName, FunctionName.TEXT_EMBEDDING);
if (modelName == null) {
throw new IllegalArgumentException("model name is null");
}
Expand Down Expand Up @@ -126,6 +126,7 @@ public MLRegisterModelInput(FunctionName functionName,
this.addAllBackendRoles = addAllBackendRoles;
this.accessMode = accessMode;
this.doesVersionCreateModelGroup = doesVersionCreateModelGroup;
this.isHidden = isHidden;
}


Expand Down Expand Up @@ -161,6 +162,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
this.accessMode = in.readEnum(AccessMode.class);
}
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
this.isHidden = in.readOptionalBoolean();
}

@Override
Expand Down Expand Up @@ -207,6 +209,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalBoolean(doesVersionCreateModelGroup);
out.writeOptionalBoolean(isHidden);
}

@Override
Expand All @@ -227,7 +230,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(URL_FIELD, url);
}
if (hashValue != null) {
builder.field(HASH_VALUE_FIELD, hashValue);
builder.field(MODEL_CONTENT_HASH_VALUE_FIELD, hashValue);
}
if (modelFormat != null) {
builder.field(MODEL_FORMAT_FIELD, modelFormat);
Expand Down Expand Up @@ -257,6 +260,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (doesVersionCreateModelGroup != null) {
builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup);
}
if (isHidden != null) {
builder.field(MLModel.IS_HIDDEN_FIELD, isHidden);
}
builder.endObject();
return builder;
}
Expand All @@ -276,6 +282,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
Boolean addAllBackendRoles = null;
AccessMode accessMode = null;
Boolean doesVersionCreateModelGroup = null;
Boolean isHidden = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -291,7 +298,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
case URL_FIELD:
url = parser.text();
break;
case HASH_VALUE_FIELD:
case MODEL_CONTENT_HASH_VALUE_FIELD:
hashValue = parser.text();
break;
case DESCRIPTION_FIELD:
Expand Down Expand Up @@ -324,6 +331,9 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
case ADD_ALL_BACKEND_ROLES_FIELD:
addAllBackendRoles = parser.booleanValue();
break;
case MLModel.IS_HIDDEN_FIELD:
isHidden = parser.booleanValue();
break;
case ACCESS_MODE_FIELD:
accessMode = AccessMode.from(parser.text());
break;
Expand All @@ -335,7 +345,8 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
break;
}
}
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup);
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden);

}

public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException {
Expand All @@ -355,6 +366,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
AccessMode accessMode = null;
Boolean addAllBackendRoles = null;
Boolean doesVersionCreateModelGroup = null;
Boolean isHidden = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -383,7 +395,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
case CONNECTOR_FIELD:
connector = createConnector(parser);
break;
case HASH_VALUE_FIELD:
case MODEL_CONTENT_HASH_VALUE_FIELD:
hashValue = parser.text();
break;
case CONNECTOR_ID_FIELD:
Expand Down Expand Up @@ -416,11 +428,14 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
case DOES_VERSION_CREATE_MODEL_GROUP:
doesVersionCreateModelGroup = parser.booleanValue();
break;
case MLModel.IS_HIDDEN_FIELD:
isHidden = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup);
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden);
}
}
Loading

0 comments on commit d71c77f

Please sign in to comment.