Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge update model API and model level throttling/quota #1624

Merged
11 changes: 10 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class CommonValue {
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 7;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 8;
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
Expand Down Expand Up @@ -204,6 +204,15 @@ public class CommonValue {
+ MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\""
+ ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n"
+ " \""
+ MLModel.QUOTA_FLAG_FIELD
+ "\" : {\"type\": \"boolean\"},\n"
+ " \""
+ MLModel.RATE_LIMIT_NUMBER_FIELD
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to fix this in this PR. I suggest merging RATE_LIMIT_NUMBER_FIELD and RATE_LIMIT_UNIT_FIELD into one field.

+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.RATE_LIMIT_UNIT_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.MODEL_CONTENT_HASH_VALUE_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
Expand Down
57 changes: 55 additions & 2 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;

import java.io.IOException;
import java.sql.Time;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.TimeUnit;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.USER;
Expand All @@ -50,9 +52,14 @@ public class MLModel implements ToXContentObject {
public static final String MODEL_FORMAT_FIELD = "model_format";
public static final String MODEL_STATE_FIELD = "model_state";
public static final String MODEL_CONTENT_SIZE_IN_BYTES_FIELD = "model_content_size_in_bytes";
//SHA256 hash value of model content.
// SHA256 hash value of model content.
public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value";

// Model level quota and throttling control
public static final String QUOTA_FLAG_FIELD = "quota_flag";
public static final String RATE_LIMIT_NUMBER_FIELD = "rate_limit_number";
public static final String RATE_LIMIT_UNIT_FIELD = "rate_limit_unit";

public static final String MODEL_CONFIG_FIELD = "model_config";
public static final String CREATED_TIME_FIELD = "created_time";
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
Expand Down Expand Up @@ -92,6 +99,9 @@ public class MLModel implements ToXContentObject {
private Long modelContentSizeInBytes;
private String modelContentHash;
private MLModelConfig modelConfig;
private Boolean quotaFlag;
private String rateLimitNumber;
private TimeUnit rateLimitUnit;
private Instant createdTime;
private Instant lastUpdateTime;
private Instant lastRegisteredTime;
Expand Down Expand Up @@ -126,6 +136,9 @@ public MLModel(String name,
MLModelState modelState,
Long modelContentSizeInBytes,
String modelContentHash,
Boolean quotaFlag,
String rateLimitNumber,
TimeUnit rateLimitUnit,
MLModelConfig modelConfig,
Instant createdTime,
Instant lastUpdateTime,
Expand All @@ -152,6 +165,9 @@ public MLModel(String name,
this.modelState = modelState;
this.modelContentSizeInBytes = modelContentSizeInBytes;
this.modelContentHash = modelContentHash;
this.quotaFlag = quotaFlag;
this.rateLimitNumber = rateLimitNumber;
this.rateLimitUnit = rateLimitUnit;
this.modelConfig = modelConfig;
this.createdTime = createdTime;
this.lastUpdateTime = lastUpdateTime;
Expand Down Expand Up @@ -197,6 +213,11 @@ public MLModel(StreamInput input) throws IOException{
modelConfig = new TextEmbeddingModelConfig(input);
}
}
quotaFlag = input.readOptionalBoolean();
rateLimitNumber = input.readOptionalString();
if (input.readBoolean()) {
rateLimitUnit = input.readEnum(TimeUnit.class);
}
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
lastRegisteredTime = input.readOptionalInstant();
Expand Down Expand Up @@ -250,6 +271,14 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(quotaFlag);
out.writeOptionalString(rateLimitNumber);
if (rateLimitUnit != null) {
out.writeBoolean(true);
out.writeEnum(rateLimitUnit);
} else {
out.writeBoolean(false);
}
out.writeOptionalInstant(createdTime);
out.writeOptionalInstant(lastUpdateTime);
out.writeOptionalInstant(lastRegisteredTime);
Expand Down Expand Up @@ -312,6 +341,15 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (modelConfig != null) {
builder.field(MODEL_CONFIG_FIELD, modelConfig);
}
if (quotaFlag != null) {
builder.field(QUOTA_FLAG_FIELD, quotaFlag);
}
if (rateLimitNumber != null) {
builder.field(RATE_LIMIT_NUMBER_FIELD, rateLimitNumber);
}
if (rateLimitUnit != null) {
builder.field(RATE_LIMIT_UNIT_FIELD, rateLimitUnit);
}
if (createdTime != null) {
builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli());
}
Expand Down Expand Up @@ -371,12 +409,15 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
String oldContent = null;
User user = null;

String description = null;;
String description = null;
MLModelFormat modelFormat = null;
MLModelState modelState = null;
Long modelContentSizeInBytes = null;
String modelContentHash = null;
MLModelConfig modelConfig = null;
Boolean quotaFlag = null;
String rateLimitNumber = null;
TimeUnit rateLimitUnit = null;
Instant createdTime = null;
Instant lastUpdateTime = null;
Instant lastUploadedTime = null;
Expand Down Expand Up @@ -461,6 +502,15 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
modelConfig = TextEmbeddingModelConfig.parse(parser);
}
break;
case QUOTA_FLAG_FIELD:
quotaFlag = parser.booleanValue();
break;
case RATE_LIMIT_NUMBER_FIELD:
rateLimitNumber = parser.text();
break;
case RATE_LIMIT_UNIT_FIELD:
rateLimitUnit = TimeUnit.valueOf(parser.text());
break;
case PLANNING_WORKER_NODE_COUNT_FIELD:
planningWorkerNodeCount = parser.intValue();
break;
Expand Down Expand Up @@ -524,6 +574,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.modelContentSizeInBytes(modelContentSizeInBytes)
.modelContentHash(modelContentHash)
.modelConfig(modelConfig)
.quotaFlag(quotaFlag)
.rateLimitNumber(rateLimitNumber)
.rateLimitUnit(rateLimitUnit)
.createdTime(createdTime)
.lastUpdateTime(lastUpdateTime)
.lastRegisteredTime(lastRegisteredTime == null? lastUploadedTime : lastRegisteredTime)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import org.opensearch.action.ActionType;

/**
 * Transport action type for the in-place model update operation.
 *
 * <p>Registered under {@link #NAME}; responses are deserialized through
 * {@code MLInPlaceUpdateModelNodesResponse::new}. Use the shared {@link #INSTANCE};
 * the constructor is private.
 */
public class MLInPlaceUpdateModelAction extends ActionType<MLInPlaceUpdateModelNodesResponse> {

    /** Name under which this action is registered. */
    public static final String NAME = "cluster:admin/opensearch/ml/models/in_place_update";

    /** The single shared instance of this action type. */
    public static final MLInPlaceUpdateModelAction INSTANCE = new MLInPlaceUpdateModelAction();

    // Private so that INSTANCE is the only instance.
    private MLInPlaceUpdateModelAction() {
        super(NAME, MLInPlaceUpdateModelNodesResponse::new);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import org.opensearch.action.support.nodes.BaseNodeRequest;
import java.io.IOException;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

/**
 * Node-level request for the in-place model update action.
 *
 * <p>It is a thin wrapper around the {@link MLInPlaceUpdateModelNodesRequest} it was
 * created from; serialization writes the base node request first and then the wrapped
 * request.
 */
@Getter
public class MLInPlaceUpdateModelNodeRequest extends BaseNodeRequest {

    /** The nodes-level request carried by this node-level request. */
    private MLInPlaceUpdateModelNodesRequest mlInPlaceUpdateModelNodesRequest;

    /**
     * Wraps an existing nodes-level request.
     *
     * @param request the nodes-level request to carry
     */
    public MLInPlaceUpdateModelNodeRequest(MLInPlaceUpdateModelNodesRequest request) {
        this.mlInPlaceUpdateModelNodesRequest = request;
    }

    /**
     * Deserializes a request from the stream, reading fields in the order
     * {@link #writeTo(StreamOutput)} writes them.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public MLInPlaceUpdateModelNodeRequest(StreamInput in) throws IOException {
        super(in);
        this.mlInPlaceUpdateModelNodesRequest = new MLInPlaceUpdateModelNodesRequest(in);
    }

    /**
     * Serializes the base node request followed by the wrapped nodes-level request.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        this.mlInPlaceUpdateModelNodesRequest.writeTo(out);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.action.support.nodes.BaseNodeResponse;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentFragment;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;

/**
 * Per-node response of the in-place model update action.
 *
 * <p>Carries a string-to-string map of update statuses and can render it as an
 * XContent fragment under a {@code "stats"} object.
 *
 * <p>Wire format: base node response, then a boolean presence flag, then the map
 * (only when the flag is {@code true}). A {@code null} or empty map is written as
 * absent, so an empty map deserializes to {@code null} on the receiving side.
 */
@Getter
@Log4j2
public class MLInPlaceUpdateModelNodeResponse extends BaseNodeResponse implements ToXContentFragment {

    /** Update statuses reported by the node; may be {@code null}. */
    private Map<String, String> modelUpdateStatus;

    /**
     * Creates a response for the given node.
     *
     * @param node the node this response originates from
     * @param modelUpdateStatus the update statuses to report; may be {@code null} or empty
     */
    public MLInPlaceUpdateModelNodeResponse(DiscoveryNode node, Map<String, String> modelUpdateStatus) {
        super(node);
        this.modelUpdateStatus = modelUpdateStatus;
    }

    /**
     * Deserializes a response from the stream.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public MLInPlaceUpdateModelNodeResponse(StreamInput in) throws IOException {
        super(in);
        // The map is only present when the writer flagged it as non-empty.
        if (in.readBoolean()) {
            this.modelUpdateStatus = in.readMap(StreamInput::readString, StreamInput::readString);
        }
    }

    /**
     * Static reader that builds a response from the stream.
     *
     * @param in the stream to read from
     * @return the deserialized response
     * @throws IOException if reading from the stream fails
     */
    public static MLInPlaceUpdateModelNodeResponse readStats(StreamInput in) throws IOException {
        return new MLInPlaceUpdateModelNodeResponse(in);
    }

    /**
     * Serializes the base node response, a presence flag, and (when non-empty) the status map.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);

        if (isEmpty()) {
            out.writeBoolean(false);
        } else {
            out.writeBoolean(true);
            out.writeMap(modelUpdateStatus, StreamOutput::writeString, StreamOutput::writeString);
        }
    }

    /**
     * Writes a {@code "stats"} object containing one field per status entry.
     * The object is emitted (empty) even when there are no entries.
     *
     * @param builder the builder to write into
     * @param params rendering parameters (unused here)
     * @return the same builder, for chaining
     * @throws IOException if writing to the builder fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject("stats");
        if (!isEmpty()) {
            for (Map.Entry<String, String> stat : modelUpdateStatus.entrySet()) {
                builder.field(stat.getKey(), stat.getValue());
            }
        }
        builder.endObject();
        return builder;
    }

    /**
     * Tells whether this response carries any status entries.
     *
     * @return {@code true} if the status map is {@code null} or has no entries
     */
    public boolean isEmpty() {
        return modelUpdateStatus == null || modelUpdateStatus.isEmpty();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import lombok.Getter;
import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import java.io.IOException;

/**
 * Nodes-level request for the in-place model update action.
 *
 * <p>Carries a model id and a boolean {@code updatePredictorFlag}. Wire format:
 * base nodes request, then the model id as a string, then the flag as a boolean.
 */
@Getter
public class MLInPlaceUpdateModelNodesRequest extends BaseNodesRequest<MLInPlaceUpdateModelNodesRequest> {

    /** Id of the model this request refers to. */
    private String modelId;

    /** Flag carried alongside the model id; defaults to {@code false}. */
    private boolean updatePredictorFlag;

    /**
     * Creates a request addressed to nodes by id.
     *
     * @param nodeIds ids of the target nodes
     * @param modelId id of the model
     * @param updatePredictorFlag value of the predictor-update flag
     */
    public MLInPlaceUpdateModelNodesRequest(String[] nodeIds, String modelId, boolean updatePredictorFlag) {
        super(nodeIds);
        this.modelId = modelId;
        this.updatePredictorFlag = updatePredictorFlag;
    }

    /**
     * Creates a request addressed to concrete nodes.
     *
     * @param nodeIds the target nodes
     * @param modelId id of the model
     * @param updatePredictorFlag value of the predictor-update flag
     */
    public MLInPlaceUpdateModelNodesRequest(DiscoveryNode[] nodeIds, String modelId, boolean updatePredictorFlag) {
        super(nodeIds);
        this.modelId = modelId;
        this.updatePredictorFlag = updatePredictorFlag;
    }

    /**
     * Creates a request addressed to concrete nodes with no model id and the flag
     * set to {@code false}.
     *
     * <p>NOTE(review): this leaves {@code modelId} null while {@link #writeTo(StreamOutput)}
     * passes it to {@code writeString}; confirm that requests built this way are never
     * serialized, or that a null is acceptable there.
     *
     * @param nodes the target nodes
     */
    public MLInPlaceUpdateModelNodesRequest(DiscoveryNode... nodes) {
        this(nodes, null, false);
    }

    /**
     * Deserializes a request from the stream, reading fields in the order
     * {@link #writeTo(StreamOutput)} writes them.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public MLInPlaceUpdateModelNodesRequest(StreamInput in) throws IOException {
        super(in);
        this.modelId = in.readString();
        this.updatePredictorFlag = in.readBoolean();
    }

    /**
     * Serializes the base nodes request, then the model id, then the flag.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeString(this.modelId);
        out.writeBoolean(this.updatePredictorFlag);
    }
}
Loading
Loading