Skip to content

Commit

Permalink
Adding UT coverage for in-cache update and fine-tuning throttling fea…
Browse files Browse the repository at this point in the history
  • Loading branch information
b4sjoo authored Jan 23, 2024
1 parent a14521d commit d0895bb
Show file tree
Hide file tree
Showing 88 changed files with 3,232 additions and 2,842 deletions.
842 changes: 422 additions & 420 deletions common/src/main/java/org/opensearch/ml/common/CommonValue.java

Large diffs are not rendered by default.

115 changes: 58 additions & 57 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ public class MLModel implements ToXContentObject {

// Model level quota and throttling control
public static final String IS_ENABLED_FIELD = "is_enabled";
public static final String MODEL_RATE_LIMITER_CONFIG_FIELD = "model_rate_limiter_config";
public static final String IS_MODEL_CONTROLLER_ENABLED_FIELD = "is_model_controller_enabled";
public static final String RATE_LIMITER_FIELD = "rate_limiter";
public static final String IS_CONTROLLER_ENABLED_FIELD = "is_controller_enabled";
public static final String MODEL_CONFIG_FIELD = "model_config";
public static final String CREATED_TIME_FIELD = "created_time";
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
Expand Down Expand Up @@ -100,8 +100,8 @@ public class MLModel implements ToXContentObject {
private String modelContentHash;
private MLModelConfig modelConfig;
private Boolean isEnabled;
private Boolean isModelControllerEnabled;
private MLRateLimiter modelRateLimiterConfig;
private Boolean isControllerEnabled;
private MLRateLimiter rateLimiter;
private Instant createdTime;
private Instant lastUpdateTime;
private Instant lastRegisteredTime;
Expand All @@ -120,7 +120,8 @@ public class MLModel implements ToXContentObject {
private String[] planningWorkerNodes; // plan to deploy model to these nodes
private boolean deployToAllNodes;

//is domain manager creates any special hidden model in the cluster this status will be true. Otherwise,
// is domain manager creates any special hidden model in the cluster this status
// will be true. Otherwise,
// False by default
private Boolean isHidden;
@Setter
Expand All @@ -129,35 +130,35 @@ public class MLModel implements ToXContentObject {

@Builder(toBuilder = true)
public MLModel(String name,
String modelGroupId,
FunctionName algorithm,
String version,
String content,
User user,
String description,
MLModelFormat modelFormat,
MLModelState modelState,
Long modelContentSizeInBytes,
String modelContentHash,
Boolean isEnabled,
Boolean isModelControllerEnabled,
MLRateLimiter modelRateLimiterConfig,
MLModelConfig modelConfig,
Instant createdTime,
Instant lastUpdateTime,
Instant lastRegisteredTime,
Instant lastDeployedTime,
Instant lastUndeployedTime,
Integer autoRedeployRetryTimes,
String modelId, Integer chunkNumber,
Integer totalChunks,
Integer planningWorkerNodeCount,
Integer currentWorkerNodeCount,
String[] planningWorkerNodes,
boolean deployToAllNodes,
Boolean isHidden,
Connector connector,
String connectorId) {
String modelGroupId,
FunctionName algorithm,
String version,
String content,
User user,
String description,
MLModelFormat modelFormat,
MLModelState modelState,
Long modelContentSizeInBytes,
String modelContentHash,
Boolean isEnabled,
Boolean isControllerEnabled,
MLRateLimiter rateLimiter,
MLModelConfig modelConfig,
Instant createdTime,
Instant lastUpdateTime,
Instant lastRegisteredTime,
Instant lastDeployedTime,
Instant lastUndeployedTime,
Integer autoRedeployRetryTimes,
String modelId, Integer chunkNumber,
Integer totalChunks,
Integer planningWorkerNodeCount,
Integer currentWorkerNodeCount,
String[] planningWorkerNodes,
boolean deployToAllNodes,
Boolean isHidden,
Connector connector,
String connectorId) {
this.name = name;
this.modelGroupId = modelGroupId;
this.algorithm = algorithm;
Expand All @@ -170,8 +171,8 @@ public MLModel(String name,
this.modelContentSizeInBytes = modelContentSizeInBytes;
this.modelContentHash = modelContentHash;
this.isEnabled = isEnabled;
this.isModelControllerEnabled = isModelControllerEnabled;
this.modelRateLimiterConfig = modelRateLimiterConfig;
this.isControllerEnabled = isControllerEnabled;
this.rateLimiter = rateLimiter;
this.modelConfig = modelConfig;
this.createdTime = createdTime;
this.lastUpdateTime = lastUpdateTime;
Expand All @@ -191,7 +192,7 @@ public MLModel(String name,
this.connectorId = connectorId;
}

public MLModel(StreamInput input) throws IOException{
public MLModel(StreamInput input) throws IOException {
name = input.readOptionalString();
algorithm = input.readEnum(FunctionName.class);
version = input.readString();
Expand Down Expand Up @@ -219,9 +220,9 @@ public MLModel(StreamInput input) throws IOException{
}
}
isEnabled = input.readOptionalBoolean();
isModelControllerEnabled = input.readOptionalBoolean();
isControllerEnabled = input.readOptionalBoolean();
if (input.readBoolean()) {
modelRateLimiterConfig = new MLRateLimiter(input);
rateLimiter = new MLRateLimiter(input);
}
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
Expand Down Expand Up @@ -278,10 +279,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isEnabled);
out.writeOptionalBoolean(isModelControllerEnabled);
if (modelRateLimiterConfig != null) {
out.writeOptionalBoolean(isControllerEnabled);
if (rateLimiter != null) {
out.writeBoolean(true);
modelRateLimiterConfig.writeTo(out);
rateLimiter.writeTo(out);
} else {
out.writeBoolean(false);
}
Expand Down Expand Up @@ -351,11 +352,11 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (isEnabled != null) {
builder.field(IS_ENABLED_FIELD, isEnabled);
}
if (isModelControllerEnabled != null) {
builder.field(IS_MODEL_CONTROLLER_ENABLED_FIELD, isModelControllerEnabled);
if (isControllerEnabled != null) {
builder.field(IS_CONTROLLER_ENABLED_FIELD, isControllerEnabled);
}
if (modelRateLimiterConfig != null) {
builder.field(MODEL_RATE_LIMITER_CONFIG_FIELD, modelRateLimiterConfig);
if (rateLimiter != null) {
builder.field(RATE_LIMITER_FIELD, rateLimiter);
}
if (createdTime != null) {
builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli());
Expand Down Expand Up @@ -426,8 +427,8 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
String modelContentHash = null;
MLModelConfig modelConfig = null;
Boolean isEnabled = null;
Boolean isModelControllerEnabled = null;
MLRateLimiter modelRateLimiterConfig = null;
Boolean isControllerEnabled = null;
MLRateLimiter rateLimiter = null;
Instant createdTime = null;
Instant lastUpdateTime = null;
Instant lastUploadedTime = null;
Expand Down Expand Up @@ -516,11 +517,11 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case IS_ENABLED_FIELD:
isEnabled = parser.booleanValue();
break;
case IS_MODEL_CONTROLLER_ENABLED_FIELD:
isModelControllerEnabled = parser.booleanValue();
case IS_CONTROLLER_ENABLED_FIELD:
isControllerEnabled = parser.booleanValue();
break;
case MODEL_RATE_LIMITER_CONFIG_FIELD:
modelRateLimiterConfig = MLRateLimiter.parse(parser);
case RATE_LIMITER_FIELD:
rateLimiter = MLRateLimiter.parse(parser);
break;
case PLANNING_WORKER_NODE_COUNT_FIELD:
planningWorkerNodeCount = parser.intValue();
Expand Down Expand Up @@ -589,13 +590,13 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.modelContentHash(modelContentHash)
.modelConfig(modelConfig)
.isEnabled(isEnabled)
.isModelControllerEnabled(isModelControllerEnabled)
.modelRateLimiterConfig(modelRateLimiterConfig)
.isControllerEnabled(isControllerEnabled)
.rateLimiter(rateLimiter)
.createdTime(createdTime)
.lastUpdateTime(lastUpdateTime)
.lastRegisteredTime(lastRegisteredTime == null? lastUploadedTime : lastRegisteredTime)
.lastDeployedTime(lastDeployedTime == null? lastLoadedTime : lastDeployedTime)
.lastUndeployedTime(lastUndeployedTime == null? lastUnloadedTime : lastUndeployedTime)
.lastRegisteredTime(lastRegisteredTime == null ? lastUploadedTime : lastRegisteredTime)
.lastDeployedTime(lastDeployedTime == null ? lastLoadedTime : lastDeployedTime)
.lastUndeployedTime(lastUndeployedTime == null ? lastUnloadedTime : lastUndeployedTime)
.modelId(modelId)
.autoRedeployRetryTimes(autoRedeployRetryTimes)
.chunkNumber(chunkNumber)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,26 @@
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

@Data
public class MLModelController implements ToXContentObject, Writeable {
public class MLController implements ToXContentObject, Writeable {

public static final String MODEL_ID_FIELD = "model_id"; // mandatory
public static final String USER_RATE_LIMITER_CONFIG = "user_rate_limiter_config";
public static final String USER_RATE_LIMITER = "user_rate_limiter";

@Getter
private String modelId;
// The String is the username field where the MLRateLimiter is its corresponding rate limiter config.
private Map<String, MLRateLimiter> userRateLimiterConfig;
// The String is the username field where the MLRateLimiter is its corresponding
// rate limiter config.
private Map<String, MLRateLimiter> userRateLimiter;

@Builder(toBuilder = true)
public MLModelController(String modelId, Map<String, MLRateLimiter> userRateLimiterConfig) {
public MLController(String modelId, Map<String, MLRateLimiter> userRateLimiter) {
this.modelId = modelId;
this.userRateLimiterConfig = userRateLimiterConfig;
this.userRateLimiter = userRateLimiter;
}

public static MLModelController parse(XContentParser parser) throws IOException {
public static MLController parse(XContentParser parser) throws IOException {
String modelId = null;
Map<String, MLRateLimiter> userRateLimiterConfig = new HashMap<>();
Map<String, MLRateLimiter> userRateLimiter = new HashMap<>();

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -58,15 +59,16 @@ public static MLModelController parse(XContentParser parser) throws IOException
case MODEL_ID_FIELD:
modelId = parser.text();
break;
case USER_RATE_LIMITER_CONFIG:
Map<String, String> userRateLimiterConfigStringMap = getParameterMap(parser.map());
userRateLimiterConfigStringMap.forEach((user, rateLimiterString) -> {
case USER_RATE_LIMITER:
Map<String, String> userRateLimiterStringMap = getParameterMap(parser.map());
userRateLimiterStringMap.forEach((user, rateLimiterString) -> {
try {
XContentParser rateLimiterParser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, rateLimiterString);
XContentParser rateLimiterParser = XContentType.JSON.xContent().createParser(
NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, rateLimiterString);
rateLimiterParser.nextToken();
MLRateLimiter rateLimiter = MLRateLimiter.parse(rateLimiterParser);
if (!rateLimiter.isEmpty()) {
userRateLimiterConfig.put(user, rateLimiter);
userRateLimiter.put(user, rateLimiter);
}
} catch (IOException e) {
throw new RuntimeException(e);
Expand All @@ -79,22 +81,23 @@ public static MLModelController parse(XContentParser parser) throws IOException
}
}
// Model ID can only be set through RestRequest.
return new MLModelController(modelId, userRateLimiterConfig);
return new MLController(modelId, userRateLimiter);
}

public MLModelController(StreamInput in) throws IOException{
public MLController(StreamInput in) throws IOException {
modelId = in.readString();
if (in.readBoolean()) {
userRateLimiterConfig = in.readMap(StreamInput::readString, MLRateLimiter::new);
userRateLimiter = in.readMap(StreamInput::readString, MLRateLimiter::new);
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
if (userRateLimiterConfig != null) {
if (userRateLimiter != null) {
out.writeBoolean(true);
out.writeMap(userRateLimiterConfig, StreamOutput::writeString, (streamOutput, rateLimiter) -> rateLimiter.writeTo(streamOutput));
out.writeMap(userRateLimiter, StreamOutput::writeString,
(streamOutput, rateLimiter) -> rateLimiter.writeTo(streamOutput));
} else {
out.writeBoolean(false);
}
Expand All @@ -104,28 +107,28 @@ public void writeTo(StreamOutput out) throws IOException {
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(MODEL_ID_FIELD, modelId);
if (userRateLimiterConfig != null) {
builder.field(USER_RATE_LIMITER_CONFIG, userRateLimiterConfig);
if (userRateLimiter != null) {
builder.field(USER_RATE_LIMITER, userRateLimiter);
}
builder.endObject();
return builder;
}


/**
* Checks if a deployment is required after updating the MLModelController.
* Checks if a deployment is required after updating the MLController.
*
* @param updateContent The updated MLModelController object.
* @param updateContent The updated MLController object.
* @return True if a deployment is required, false otherwise.
*/
public boolean isDeployRequiredAfterUpdate(MLModelController updateContent) {
if (updateContent != null && updateContent.getUserRateLimiterConfig() != null && !updateContent.getUserRateLimiterConfig().isEmpty()) {
Map<String, MLRateLimiter> updateUserRateLimiterConfig = updateContent.getUserRateLimiterConfig();
for (Map.Entry<String, MLRateLimiter> entry : updateUserRateLimiterConfig.entrySet()) {
public boolean isDeployRequiredAfterUpdate(MLController updateContent) {
if (updateContent != null && updateContent.getUserRateLimiter() != null
&& !updateContent.getUserRateLimiter().isEmpty()) {
Map<String, MLRateLimiter> updateUserRateLimiter = updateContent.getUserRateLimiter();
for (Map.Entry<String, MLRateLimiter> entry : updateUserRateLimiter.entrySet()) {
String newUser = entry.getKey();
MLRateLimiter newRateLimiter = entry.getValue();
if (this.userRateLimiterConfig.containsKey(newUser)) {
MLRateLimiter oldRateLimiter = this.userRateLimiterConfig.get(newUser);
if (this.userRateLimiter.containsKey(newUser)) {
MLRateLimiter oldRateLimiter = this.userRateLimiter.get(newUser);
if (MLRateLimiter.isDeployRequiredAfterUpdate(oldRateLimiter, newRateLimiter)) {
return true;
}
Expand All @@ -139,16 +142,16 @@ public boolean isDeployRequiredAfterUpdate(MLModelController updateContent) {
return false;
}

public void update(MLModelController updateContent) {
Map<String, MLRateLimiter> updateUserRateLimiterConfig = updateContent.getUserRateLimiterConfig();
if (updateUserRateLimiterConfig != null && !updateUserRateLimiterConfig.isEmpty()) {
updateUserRateLimiterConfig.forEach((user, rateLimiter) -> {
public void update(MLController updateContent) {
Map<String, MLRateLimiter> updateUserRateLimiter = updateContent.getUserRateLimiter();
if (updateUserRateLimiter != null && !updateUserRateLimiter.isEmpty()) {
updateUserRateLimiter.forEach((user, rateLimiter) -> {
// rateLimiter can't be null due to parsing exception
if (this.userRateLimiterConfig.containsKey(user)) {
this.userRateLimiterConfig.get(user).update(rateLimiter);
} else {
this.userRateLimiterConfig.put(user, rateLimiter);
}
if (this.userRateLimiter.containsKey(user)) {
this.userRateLimiter.get(user).update(rateLimiter);
} else {
this.userRateLimiter.put(user, rateLimiter);
}
});
}
}
Expand Down
Loading

0 comments on commit d0895bb

Please sign in to comment.