Skip to content

Commit

Permalink
Enable in-place update model (#1673)
Browse files Browse the repository at this point in the history
* Enable in-place update model

Signed-off-by: Sicheng Song <[email protected]>

* Refactor inplace update api as well as adding more doc/comments/annotations for clarification

Signed-off-by: Sicheng Song <[email protected]>

* Address review concern

Signed-off-by: Sicheng Song <[email protected]>

* Refactor updatemodelcache pacakge

Signed-off-by: Sicheng Song <[email protected]>

* Fix UT

Signed-off-by: Sicheng Song <[email protected]>

---------

Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
b4sjoo authored Dec 20, 2023
1 parent 06f8b2a commit 9aea9a0
Show file tree
Hide file tree
Showing 26 changed files with 2,171 additions and 508 deletions.
2 changes: 1 addition & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
String oldContent = null;
User user = null;

String description = null;;
String description = null;
MLModelFormat modelFormat = null;
MLModelState modelState = null;
Long modelContentSizeInBytes = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public MLDeployModelNodeResponse(DiscoveryNode node, Map<String, String> modelDe
public MLDeployModelNodeResponse(StreamInput in) throws IOException {
super(in);
if (in.readBoolean()) {
this.modelDeployStatus = in.readMap(s -> s.readString(), s-> s.readString());
this.modelDeployStatus = in.readMap(StreamInput::readString, StreamInput::readString);
}

}
Expand All @@ -58,7 +58,7 @@ public void writeTo(StreamOutput out) throws IOException {

if (modelDeployStatus != null && modelDeployStatus.size() > 0) {
out.writeBoolean(true);
out.writeMap(modelDeployStatus, (stream, v) -> stream.writeString(v), (stream, stats) -> stream.writeString(stats));
out.writeMap(modelDeployStatus, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,27 @@
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;

import java.io.IOException;
import java.util.Map;
import java.time.Instant;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.connector.Connector.createConnector;

@Data
public class MLUpdateModelInput implements ToXContentObject, Writeable {

public static final String MODEL_ID_FIELD = "model_id"; // mandatory
public static final String MODEL_ID_FIELD = "model_id"; // passively set when passing url to rest API
public static final String DESCRIPTION_FIELD = "description"; // optional
public static final String MODEL_VERSION_FIELD = "model_version"; // optional
public static final String MODEL_VERSION_FIELD = "model_version"; // passively set when register model to a new model group
public static final String MODEL_NAME_FIELD = "name"; // optional
public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional
public static final String MODEL_CONFIG_FIELD = "model_config"; // optional
public static final String CONNECTOR_FIELD = "connector"; // optional
public static final String CONNECTOR_ID_FIELD = "connector_id"; // optional
// The field CONNECTOR_UPDATE_CONTENT_FIELD need to be declared because the update of Connector class relies on the MLCreateConnectorInput class
public static final String CONNECTOR_UPDATE_CONTENT_FIELD = "connector_update_content";
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; // passively set when sending update request

@Getter
private String modelId;
Expand All @@ -42,29 +46,44 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable {
private String name;
private String modelGroupId;
private MLModelConfig modelConfig;
private Connector connector;
private String connectorId;
private MLCreateConnectorInput connectorUpdateContent;
private Instant lastUpdateTime;

@Builder(toBuilder = true)
public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, MLModelConfig modelConfig, String connectorId) {
public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId,
MLModelConfig modelConfig, Connector connector, String connectorId,
MLCreateConnectorInput connectorUpdateContent, Instant lastUpdateTime) {
this.modelId = modelId;
this.description = description;
this.version = version;
this.name = name;
this.modelGroupId = modelGroupId;
this.modelConfig = modelConfig;
this.connector = connector;
this.connectorId = connectorId;
this.connectorUpdateContent = connectorUpdateContent;
this.lastUpdateTime = lastUpdateTime;
}

public MLUpdateModelInput(StreamInput in) throws IOException {
this.modelId = in.readString();
this.description = in.readOptionalString();
this.version = in.readOptionalString();
this.name = in.readOptionalString();
this.modelGroupId = in.readOptionalString();
modelId = in.readString();
description = in.readOptionalString();
version = in.readOptionalString();
name = in.readOptionalString();
modelGroupId = in.readOptionalString();
if (in.readBoolean()) {
modelConfig = new TextEmbeddingModelConfig(in);
}
this.connectorId = in.readOptionalString();
if (in.readBoolean()) {
connector = Connector.fromStream(in);
}
connectorId = in.readOptionalString();
if (in.readBoolean()) {
connectorUpdateContent = new MLCreateConnectorInput(in);
}
lastUpdateTime = in.readOptionalInstant();
}

@Override
Expand All @@ -86,9 +105,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (modelConfig != null) {
builder.field(MODEL_CONFIG_FIELD, modelConfig);
}
if (connector != null) {
builder.field(CONNECTOR_FIELD, connector);
}
if (connectorId != null) {
builder.field(CONNECTOR_ID_FIELD, connectorId);
}
if (connectorUpdateContent != null) {
builder.field(CONNECTOR_UPDATE_CONTENT_FIELD, connectorUpdateContent);
}
if (lastUpdateTime != null) {
builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli());
}
builder.endObject();
return builder;
}
Expand All @@ -106,7 +134,20 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (connector != null) {
out.writeBoolean(true);
connector.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(connectorId);
if (connectorUpdateContent != null) {
out.writeBoolean(true);
connectorUpdateContent.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalInstant(lastUpdateTime);
}

public static MLUpdateModelInput parse(XContentParser parser) throws IOException {
Expand All @@ -116,7 +157,10 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException
String name = null;
String modelGroupId = null;
MLModelConfig modelConfig = null;
Connector connector = null;
String connectorId = null;
MLCreateConnectorInput connectorUpdateContent = null;
Instant lastUpdateTime = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -141,15 +185,24 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException
case MODEL_CONFIG_FIELD:
modelConfig = TextEmbeddingModelConfig.parse(parser);
break;
case CONNECTOR_FIELD:
connector = Connector.createConnector(parser);
break;
case CONNECTOR_ID_FIELD:
connectorId = parser.text();
break;
case CONNECTOR_UPDATE_CONTENT_FIELD:
connectorUpdateContent = MLCreateConnectorInput.parse(parser, true);
break;
case LAST_UPDATED_TIME_FIELD:
lastUpdateTime = Instant.ofEpochMilli(parser.longValue());
break;
default:
parser.skipChildren();
break;
}
}
// Model ID can only be set through RestRequest. Model version can only be set automatically.
return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, modelConfig, connectorId);
return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, modelConfig, connector, connectorId, connectorUpdateContent, lastUpdateTime);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ public MLUndeployModelNodeResponse(DiscoveryNode node,
public MLUndeployModelNodeResponse(StreamInput in) throws IOException {
super(in);
if (in.readBoolean()) {
this.modelUndeployStatus = in.readMap(s -> s.readString(), s-> s.readString());
this.modelUndeployStatus = in.readMap(StreamInput::readString, StreamInput::readString);
}
if (in.readBoolean()) {
this.modelWorkerNodeBeforeRemoval = in.readMap(s -> s.readString(), s-> s.readOptionalStringArray());
this.modelWorkerNodeBeforeRemoval = in.readMap(StreamInput::readString, StreamInput::readOptionalStringArray);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.update_cache;

import org.opensearch.action.ActionType;

public class MLUpdateModelCacheAction extends ActionType<MLUpdateModelCacheNodesResponse> {
public static final MLUpdateModelCacheAction INSTANCE = new MLUpdateModelCacheAction();
public static final String NAME = "cluster:admin/opensearch/ml/models/update_model_cache";

private MLUpdateModelCacheAction() { super(NAME, MLUpdateModelCacheNodesResponse::new);}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.update_cache;

import org.opensearch.transport.TransportRequest;
import java.io.IOException;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

public class MLUpdateModelCacheNodeRequest extends TransportRequest {
@Getter
private MLUpdateModelCacheNodesRequest updateModelCacheNodesRequest;

public MLUpdateModelCacheNodeRequest(StreamInput in) throws IOException {
super(in);
this.updateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest(in);
}

public MLUpdateModelCacheNodeRequest(MLUpdateModelCacheNodesRequest request) {
this.updateModelCacheNodesRequest = request;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
updateModelCacheNodesRequest.writeTo(out);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.update_cache;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.action.support.nodes.BaseNodeResponse;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentFragment;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;

@Getter
@Log4j2
public class MLUpdateModelCacheNodeResponse extends BaseNodeResponse implements ToXContentFragment {
private Map<String, String> modelUpdateStatus;

public MLUpdateModelCacheNodeResponse(DiscoveryNode node, Map<String, String> modelUpdateStatus) {
super(node);
this.modelUpdateStatus = modelUpdateStatus;
}

public MLUpdateModelCacheNodeResponse(StreamInput in) throws IOException {
super(in);
if (in.readBoolean()) {
this.modelUpdateStatus = in.readMap(StreamInput::readString, StreamInput::readString);
}
}

public static MLUpdateModelCacheNodeResponse readStats(StreamInput in) throws IOException {
return new MLUpdateModelCacheNodeResponse(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);

if (!isModelUpdateStatusEmpty()) {
out.writeBoolean(true);
out.writeMap(modelUpdateStatus, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
}

public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
// Similar to deploy on node or undeploy on node response, this stat map is used to track update status on each node.
builder.startObject("stats");
if (!isModelUpdateStatusEmpty()) {
for (Map.Entry<String, String> stat : modelUpdateStatus.entrySet()) {
builder.field(stat.getKey(), stat.getValue());
}
}
builder.endObject();
return builder;
}

public boolean isModelUpdateStatusEmpty() {
return modelUpdateStatus == null || modelUpdateStatus.size() == 0;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.update_cache;

import lombok.Getter;
import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import java.io.IOException;

public class MLUpdateModelCacheNodesRequest extends BaseNodesRequest<MLUpdateModelCacheNodesRequest> {

@Getter
private String modelId;
@Getter
private boolean predictorUpdate;

public MLUpdateModelCacheNodesRequest(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.predictorUpdate = in.readBoolean();
}

public MLUpdateModelCacheNodesRequest(String[] nodeIds, String modelId, boolean predictorUpdate) {
super(nodeIds);
this.modelId = modelId;
this.predictorUpdate = predictorUpdate;
}

public MLUpdateModelCacheNodesRequest(DiscoveryNode[] nodeIds, String modelId, boolean predictorUpdate) {
super(nodeIds);
this.modelId = modelId;
this.predictorUpdate = predictorUpdate;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(modelId);
out.writeBoolean(predictorUpdate);
}
}
Loading

0 comments on commit 9aea9a0

Please sign in to comment.