diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 84c3a96712..3b23cbc06d 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -37,7 +37,7 @@ public class CommonValue { public static final String ML_MODEL_INDEX = ".plugins-ml-model"; public static final String ML_TASK_INDEX = ".plugins-ml-task"; public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; - public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 7; + public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 8; public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2; @@ -196,6 +196,15 @@ public class CommonValue { + MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\"" + ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n" + " \"" + + MLModel.QUOTA_FLAG_FIELD + + "\" : {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.RATE_LIMIT_NUMBER_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.RATE_LIMIT_UNIT_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + MLModel.MODEL_CONTENT_HASH_VALUE_FIELD + "\" : {\"type\": \"keyword\"},\n" + " \"" diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index 78f6f4ac60..4ea0157a04 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -23,10 +23,12 @@ import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import java.io.IOException; +import java.sql.Time; import java.time.Instant; import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.concurrent.TimeUnit; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.USER; @@ -50,9 +52,14 @@ public class MLModel implements ToXContentObject { public static final String MODEL_FORMAT_FIELD = "model_format"; public static final String MODEL_STATE_FIELD = "model_state"; public static final String MODEL_CONTENT_SIZE_IN_BYTES_FIELD = "model_content_size_in_bytes"; - //SHA256 hash value of model content. + // SHA256 hash value of model content. public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value"; + // Model level quota and throttling control + public static final String QUOTA_FLAG_FIELD = "quota_flag"; + public static final String RATE_LIMIT_NUMBER_FIELD = "rate_limit_number"; + public static final String RATE_LIMIT_UNIT_FIELD = "rate_limit_unit"; + public static final String MODEL_CONFIG_FIELD = "model_config"; public static final String CREATED_TIME_FIELD = "created_time"; public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; @@ -92,6 +99,9 @@ public class MLModel implements ToXContentObject { private Long modelContentSizeInBytes; private String modelContentHash; private MLModelConfig modelConfig; + private Boolean quotaFlag; + private String rateLimitNumber; + private TimeUnit rateLimitUnit; private Instant createdTime; private Instant lastUpdateTime; private Instant lastRegisteredTime; @@ -126,6 +136,9 @@ public MLModel(String name, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHash, + Boolean quotaFlag, + String rateLimitNumber, + TimeUnit rateLimitUnit, MLModelConfig modelConfig, Instant createdTime, Instant lastUpdateTime, @@ -152,6 +165,9 @@ public MLModel(String name, this.modelState = modelState; this.modelContentSizeInBytes = modelContentSizeInBytes; this.modelContentHash = modelContentHash; + this.quotaFlag = quotaFlag; + this.rateLimitNumber = rateLimitNumber; + this.rateLimitUnit = rateLimitUnit; this.modelConfig = modelConfig; this.createdTime = createdTime; this.lastUpdateTime = lastUpdateTime; @@ -197,6 +213,11 @@ public MLModel(StreamInput input) throws IOException{ modelConfig = new TextEmbeddingModelConfig(input); } } + quotaFlag = input.readOptionalBoolean(); + rateLimitNumber = input.readOptionalString(); + if (input.readBoolean()) { + rateLimitUnit = input.readEnum(TimeUnit.class); + } createdTime = input.readOptionalInstant(); lastUpdateTime = input.readOptionalInstant(); lastRegisteredTime = input.readOptionalInstant(); @@ -250,6 +271,14 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalBoolean(quotaFlag); + out.writeOptionalString(rateLimitNumber); + if (rateLimitUnit != null) { + out.writeBoolean(true); + out.writeEnum(rateLimitUnit); + } else { + out.writeBoolean(false); + } out.writeOptionalInstant(createdTime); out.writeOptionalInstant(lastUpdateTime); out.writeOptionalInstant(lastRegisteredTime); @@ -312,6 +341,15 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (modelConfig != null) { builder.field(MODEL_CONFIG_FIELD, modelConfig); } + if (quotaFlag != null) { + builder.field(QUOTA_FLAG_FIELD, quotaFlag); + } + if (rateLimitNumber != null) { + builder.field(RATE_LIMIT_NUMBER_FIELD, rateLimitNumber); + } + if (rateLimitUnit != null) { + builder.field(RATE_LIMIT_UNIT_FIELD, rateLimitUnit); + } if (createdTime != null) { builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); } @@ -371,12 +409,15 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws String oldContent = null; User user = null; - String description = null;; + String description = null; MLModelFormat modelFormat = null; MLModelState modelState = null; Long modelContentSizeInBytes = null; String modelContentHash = null; MLModelConfig modelConfig = null; + Boolean quotaFlag = null; + String rateLimitNumber = null; + TimeUnit rateLimitUnit = null; Instant createdTime = null; Instant lastUpdateTime = null; Instant lastUploadedTime = null; @@ -461,6 +502,15 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws modelConfig = TextEmbeddingModelConfig.parse(parser); } break; + case QUOTA_FLAG_FIELD: + quotaFlag = parser.booleanValue(); + break; + case RATE_LIMIT_NUMBER_FIELD: + rateLimitNumber = parser.text(); + break; + case RATE_LIMIT_UNIT_FIELD: + rateLimitUnit = TimeUnit.valueOf(parser.text()); + break; case PLANNING_WORKER_NODE_COUNT_FIELD: planningWorkerNodeCount = parser.intValue(); break; @@ -524,6 +574,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws .modelContentSizeInBytes(modelContentSizeInBytes) .modelContentHash(modelContentHash) .modelConfig(modelConfig) + .quotaFlag(quotaFlag) + .rateLimitNumber(rateLimitNumber) + .rateLimitUnit(rateLimitUnit) .createdTime(createdTime) .lastUpdateTime(lastUpdateTime) .lastRegisteredTime(lastRegisteredTime == null? lastUploadedTime : lastRegisteredTime) diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelAction.java new file mode 100644 index 0000000000..a8cecb43a0 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelAction.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import org.opensearch.action.ActionType; + +public class MLInPlaceUpdateModelAction extends ActionType { + public static final MLInPlaceUpdateModelAction INSTANCE = new MLInPlaceUpdateModelAction(); + public static final String NAME = "cluster:admin/opensearch/ml/models/in_place_update"; + + private MLInPlaceUpdateModelAction() { super(NAME, MLInPlaceUpdateModelNodesResponse::new);} +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelNodeRequest.java new file mode 100644 index 0000000000..0bcfcf729a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelNodeRequest.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import org.opensearch.action.support.nodes.BaseNodeRequest; +import java.io.IOException; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLInPlaceUpdateModelNodeRequest extends BaseNodeRequest { + @Getter + private MLInPlaceUpdateModelNodesRequest mlInPlaceUpdateModelNodesRequest; + + public MLInPlaceUpdateModelNodeRequest(StreamInput in) throws IOException { + super(in); + this.mlInPlaceUpdateModelNodesRequest = new MLInPlaceUpdateModelNodesRequest(in); + } + + public MLInPlaceUpdateModelNodeRequest(MLInPlaceUpdateModelNodesRequest request) { + this.mlInPlaceUpdateModelNodesRequest = request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + mlInPlaceUpdateModelNodesRequest.writeTo(out); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelNodeResponse.java new file mode 100644 index 0000000000..3faa9ca7f4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelNodeResponse.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; + +@Getter +@Log4j2 +public class MLInPlaceUpdateModelNodeResponse extends BaseNodeResponse implements ToXContentFragment { + private Map modelUpdateStatus; + + public MLInPlaceUpdateModelNodeResponse(DiscoveryNode node, Map modelUpdateStatus) { + super(node); + this.modelUpdateStatus = modelUpdateStatus; + } + + public MLInPlaceUpdateModelNodeResponse(StreamInput in) throws IOException { + super(in); + if (in.readBoolean()) { + this.modelUpdateStatus = in.readMap(StreamInput::readString, StreamInput::readString); + } + } + + public static MLInPlaceUpdateModelNodeResponse readStats(StreamInput in) throws IOException { + return new MLInPlaceUpdateModelNodeResponse(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + + if (!isEmpty()) { + out.writeBoolean(true); + out.writeMap(modelUpdateStatus, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + } + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject("stats"); + if (modelUpdateStatus != null && modelUpdateStatus.size() > 0) { + for (Map.Entry stat : modelUpdateStatus.entrySet()) { + builder.field(stat.getKey(), stat.getValue()); + } + } + builder.endObject(); + return builder; + } + + public boolean isEmpty() { + return modelUpdateStatus == null || modelUpdateStatus.size() == 0; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelNodesRequest.java new file mode 100644 index 0000000000..223becc1dc --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelNodesRequest.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import lombok.Getter; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import java.io.IOException; + +public class MLInPlaceUpdateModelNodesRequest extends BaseNodesRequest { + + @Getter + private String modelId; + @Getter + private boolean updatePredictorFlag; + + public MLInPlaceUpdateModelNodesRequest(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + this.updatePredictorFlag = in.readBoolean(); + } + + public MLInPlaceUpdateModelNodesRequest(String[] nodeIds, String modelId, boolean updatePredictorFlag) { + super(nodeIds); + this.modelId = modelId; + this.updatePredictorFlag = updatePredictorFlag; + } + + public MLInPlaceUpdateModelNodesRequest(DiscoveryNode[] nodeIds, String modelId, boolean updatePredictorFlag) { + super(nodeIds); + this.modelId = modelId; + this.updatePredictorFlag = updatePredictorFlag; + } + + public MLInPlaceUpdateModelNodesRequest(DiscoveryNode... nodes) { + super(nodes); + this.updatePredictorFlag = false; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + out.writeBoolean(updatePredictorFlag); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelNodesResponse.java new file mode 100644 index 0000000000..4c669fa8a4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLInPlaceUpdateModelNodesResponse.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; + +public class MLInPlaceUpdateModelNodesResponse extends BaseNodesResponse implements ToXContentObject { + + public MLInPlaceUpdateModelNodesResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(MLInPlaceUpdateModelNodeResponse::readStats), in.readList(FailedNodeException::new)); + } + + public MLInPlaceUpdateModelNodesResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(MLInPlaceUpdateModelNodeResponse::readStats); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + String nodeId; + DiscoveryNode node; + builder.startObject(); + for (MLInPlaceUpdateModelNodeResponse undeployStats : getNodes()) { + if (!undeployStats.isEmpty()) { + node = undeployStats.getNode(); + nodeId = node.getId(); + builder.startObject(nodeId); + undeployStats.toXContent(builder, params); + builder.endObject(); + } + } + builder.endObject(); + return builder; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java new file mode 100644 index 0000000000..2d584a0e73 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import org.opensearch.action.ActionType; +import org.opensearch.action.update.UpdateResponse; + +public class MLUpdateModelAction extends ActionType { + public static MLUpdateModelAction INSTANCE = new MLUpdateModelAction(); + public static final String NAME = "cluster:admin/opensearch/ml/models/update"; + + private MLUpdateModelAction() { + super(NAME, UpdateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java new file mode 100644 index 0000000000..7c5a4e0b01 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -0,0 +1,240 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import lombok.Data; +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +@Data +public class MLUpdateModelInput implements ToXContentObject, Writeable { + + public static final String MODEL_ID_FIELD = "model_id"; // mandatory + public static final String DESCRIPTION_FIELD = "description"; // optional + public static final String MODEL_VERSION_FIELD = "model_version"; // optional + public static final String MODEL_NAME_FIELD = "name"; // optional + public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional + public static final String QUOTA_FLAG_FIELD = "quota_flag"; // optional + public static final String RATE_LIMIT_NUMBER_FIELD = "rate_limit_number"; // optional + public static final String RATE_LIMIT_UNIT_FIELD = "rate_limit_unit"; // optional + public static final String MODEL_CONFIG_FIELD = "model_config"; // optional + public static final String CONNECTOR_FIELD = "connector"; // optional + public static final String CONNECTOR_ID_FIELD = "connector_id"; // optional + public static final String CONNECTOR_UPDATE_CONTENT_FIELD = "connector_update_content"; // optional + + @Getter + private String modelId; + private String description; + private String version; + private String name; + private String modelGroupId; + private Boolean quotaFlag; + private String rateLimitNumber; + private TimeUnit rateLimitUnit; + private MLModelConfig modelConfig; + private Connector connector; + private String connectorId; + private MLCreateConnectorInput connectorUpdateContent; + + @Builder(toBuilder = true) + public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, + Boolean quotaFlag, String rateLimitNumber, TimeUnit rateLimitUnit, MLModelConfig modelConfig, + Connector connector, String connectorId, MLCreateConnectorInput connectorUpdateContent) { + this.modelId = modelId; + this.description = description; + this.version = version; + this.name = name; + this.modelGroupId = modelGroupId; + this.quotaFlag = quotaFlag; + this.rateLimitNumber = rateLimitNumber; + this.rateLimitUnit = rateLimitUnit; + this.modelConfig = modelConfig; + this.connector = connector; + this.connectorId = connectorId; + this.connectorUpdateContent = connectorUpdateContent; + } + + public MLUpdateModelInput(StreamInput in) throws IOException { + modelId = in.readString(); + description = in.readOptionalString(); + version = in.readOptionalString(); + name = in.readOptionalString(); + modelGroupId = in.readOptionalString(); + quotaFlag = in.readOptionalBoolean(); + rateLimitNumber = in.readOptionalString(); + if (in.readBoolean()) { + rateLimitUnit = in.readEnum(TimeUnit.class); + } + if (in.readBoolean()) { + modelConfig = new TextEmbeddingModelConfig(in); + } + if (in.readBoolean()) { + connector = Connector.fromStream(in); + } + connectorId = in.readOptionalString(); + if (in.readBoolean()) { + connectorUpdateContent = new MLCreateConnectorInput(in); + } + } + + public MLUpdateModelInput() {} + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID_FIELD, modelId); + if (name != null) { + builder.field(MODEL_NAME_FIELD, name); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (version != null) { + builder.field(MODEL_VERSION_FIELD, version); + } + if (modelGroupId != null) { + builder.field(MODEL_GROUP_ID_FIELD, modelGroupId); + } + if (quotaFlag != null) { + builder.field(QUOTA_FLAG_FIELD, quotaFlag); + } + if (rateLimitNumber != null) { + builder.field(RATE_LIMIT_NUMBER_FIELD, rateLimitNumber); + } + if (rateLimitUnit != null) { + builder.field(RATE_LIMIT_UNIT_FIELD, rateLimitUnit); + } + if (modelConfig != null) { + builder.field(MODEL_CONFIG_FIELD, modelConfig); + } + if (connector != null) { + builder.field(CONNECTOR_FIELD, connector); + } + if (connectorId != null) { + builder.field(CONNECTOR_ID_FIELD, connectorId); + } + if (connectorUpdateContent != null) { + builder.field(CONNECTOR_UPDATE_CONTENT_FIELD, connectorUpdateContent); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeOptionalString(description); + out.writeOptionalString(version); + out.writeOptionalString(name); + out.writeOptionalString(modelGroupId); + out.writeOptionalBoolean(quotaFlag); + out.writeOptionalString(rateLimitNumber); + if (rateLimitUnit != null) { + out.writeBoolean(true); + out.writeEnum(rateLimitUnit); + } else { + out.writeBoolean(false); + } + if (modelConfig != null) { + out.writeBoolean(true); + modelConfig.writeTo(out); + } else { + out.writeBoolean(false); + } + if (connector != null) { + out.writeBoolean(true); + connector.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(connectorId); + if (connectorUpdateContent != null) { + out.writeBoolean(true); + connectorUpdateContent.writeTo(out); + } else { + out.writeBoolean(false); + } + } + + public static MLUpdateModelInput parse(XContentParser parser) throws IOException { + String modelId = null; + String description = null; + String version = null; + String name = null; + String modelGroupId = null; + Boolean quotaFlag = null; + String rateLimitNumber = null; + TimeUnit rateLimitUnit = null; + MLModelConfig modelConfig = null; + Connector connector = null; + String connectorId = null; + MLCreateConnectorInput connectorUpdateContent = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case MODEL_ID_FIELD: + modelId = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case MODEL_NAME_FIELD: + name = parser.text(); + break; + case MODEL_VERSION_FIELD: + version = parser.text(); + break; + case MODEL_GROUP_ID_FIELD: + modelGroupId = parser.text(); + break; + case QUOTA_FLAG_FIELD: + quotaFlag = parser.booleanValue(); + break; + case RATE_LIMIT_NUMBER_FIELD: + rateLimitNumber = parser.text(); + break; + case RATE_LIMIT_UNIT_FIELD: + rateLimitUnit = TimeUnit.valueOf(parser.text()); + break; + case MODEL_CONFIG_FIELD: + modelConfig = TextEmbeddingModelConfig.parse(parser); + break; + case CONNECTOR_FIELD: + connector = Connector.createConnector(parser); + break; + case CONNECTOR_ID_FIELD: + connectorId = parser.text(); + break; + case CONNECTOR_UPDATE_CONTENT_FIELD: + connectorUpdateContent = MLCreateConnectorInput.parse(parser, true); + break; + default: + parser.skipChildren(); + break; + } + } + // Model ID can only be set through RestRequest. Model version can only be set automatically. + return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, quotaFlag, rateLimitNumber, rateLimitUnit, modelConfig, connector, connectorId, connectorUpdateContent); + } +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java new file mode 100644 index 0000000000..b589f71ed4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLUpdateModelRequest extends ActionRequest { + + MLUpdateModelInput updateModelInput; + + @Builder + public MLUpdateModelRequest(MLUpdateModelInput updateModelInput) { + this.updateModelInput = updateModelInput; + } + + public MLUpdateModelRequest(StreamInput in) throws IOException { + super(in); + updateModelInput = new MLUpdateModelInput(in); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (updateModelInput == null) { + exception = addValidationError("Update Model Input can't be null", exception); + } + + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + this.updateModelInput.writeTo(out); + } + + public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest){ + if (actionRequest instanceof MLUpdateModelRequest) { + return (MLUpdateModelRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput in = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLUpdateModelRequest(in); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLUpdateModelRequest", e); + } + } +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index 6697461429..32ca3b622b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -22,9 +22,11 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import java.io.IOException; +import java.sql.Time; import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.concurrent.TimeUnit; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.connector.Connector.createConnector; @@ -41,6 +43,9 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; public static final String DESCRIPTION_FIELD = "description"; public static final String VERSION_FIELD = "version"; + public static final String QUOTA_FLAG_FIELD = "quota_flag"; + public static final String RATE_LIMIT_NUMBER_FIELD = "rate_limit_number"; + public static final String RATE_LIMIT_UNIT_FIELD = "rate_limit_unit"; public static final String URL_FIELD = "url"; public static final String HASH_VALUE_FIELD = "model_content_hash_value"; public static final String MODEL_FORMAT_FIELD = "model_format"; @@ -59,6 +64,9 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private String modelGroupId; private String version; private String description; + private Boolean quotaFlag; + private String rateLimitNumber; + private TimeUnit rateLimitUnit; private String url; private String hashValue; private MLModelFormat modelFormat; @@ -81,6 +89,9 @@ public MLRegisterModelInput(FunctionName functionName, String modelGroupId, String version, String description, + Boolean quotaFlag, + String rateLimitNumber, + TimeUnit rateLimitUnit, String url, String hashValue, MLModelFormat modelFormat, @@ -114,6 +125,9 @@ public MLRegisterModelInput(FunctionName functionName, this.modelGroupId = modelGroupId; this.version = version; this.description = description; + this.quotaFlag = quotaFlag; + this.rateLimitNumber = rateLimitNumber; + this.rateLimitUnit = rateLimitUnit; this.url = url; this.hashValue = hashValue; this.modelFormat = modelFormat; @@ -135,6 +149,11 @@ public MLRegisterModelInput(StreamInput in) throws IOException { this.modelGroupId = in.readOptionalString(); this.version = in.readOptionalString(); this.description = in.readOptionalString(); + this.quotaFlag = in.readOptionalBoolean(); + this.rateLimitNumber = in.readOptionalString(); + if (in.readBoolean()) { + rateLimitUnit = in.readEnum(TimeUnit.class); + } this.url = in.readOptionalString(); this.hashValue = in.readOptionalString(); if (in.readBoolean()) { @@ -170,6 +189,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(modelGroupId); out.writeOptionalString(version); out.writeOptionalString(description); + out.writeOptionalBoolean(quotaFlag); + out.writeOptionalString(rateLimitNumber); + if (rateLimitUnit != null) { + out.writeBoolean(true); + out.writeEnum(rateLimitUnit); + } else { + out.writeBoolean(false); + } out.writeOptionalString(url); out.writeOptionalString(hashValue); if (modelFormat != null) { @@ -223,6 +250,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (description != null) { builder.field(DESCRIPTION_FIELD, description); } + if (quotaFlag != null) { + builder.field(QUOTA_FLAG_FIELD, quotaFlag); + } + if (rateLimitNumber != null) { + builder.field(RATE_LIMIT_NUMBER_FIELD, rateLimitNumber); + } + if (rateLimitUnit != null) { + builder.field(RATE_LIMIT_UNIT_FIELD, rateLimitUnit); + } if (url != null) { builder.field(URL_FIELD, url); } @@ -264,6 +300,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public static MLRegisterModelInput parse(XContentParser parser, String modelName, String version, boolean deployModel) throws IOException { FunctionName functionName = null; String modelGroupId = null; + Boolean quotaFlag = null; + String rateLimitNumber = null; + TimeUnit rateLimitUnit = null; String url = null; String hashValue = null; String description = null; @@ -288,6 +327,15 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case MODEL_GROUP_ID_FIELD: modelGroupId = parser.text(); break; + case QUOTA_FLAG_FIELD: + quotaFlag = parser.booleanValue(); + break; + case RATE_LIMIT_NUMBER_FIELD: + rateLimitNumber = parser.text(); + break; + case RATE_LIMIT_UNIT_FIELD: + rateLimitUnit = TimeUnit.valueOf(parser.text()); + break; case URL_FIELD: url = parser.text(); break; @@ -335,7 +383,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName break; } } - return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); + return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, quotaFlag, rateLimitNumber, rateLimitUnit, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -343,6 +391,9 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo String name = null; String modelGroupId = null; String version = null; + Boolean quotaFlag = null; + String rateLimitNumber = null; + TimeUnit rateLimitUnit = null; String url = null; String hashValue = null; String description = null; @@ -377,6 +428,15 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case DESCRIPTION_FIELD: description = parser.text(); break; + case QUOTA_FLAG_FIELD: + quotaFlag = parser.booleanValue(); + break; + case RATE_LIMIT_NUMBER_FIELD: + rateLimitNumber = parser.text(); + break; + case RATE_LIMIT_UNIT_FIELD: + rateLimitUnit = TimeUnit.valueOf(parser.text()); + break; case URL_FIELD: url = parser.text(); break; @@ -421,6 +481,6 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo break; } } - return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); + return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, quotaFlag, rateLimitNumber, rateLimitUnit, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java index ecb03d9bb6..b157f7cc45 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java @@ -25,6 +25,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.concurrent.TimeUnit; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -34,7 +35,9 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ public static final String FUNCTION_NAME_FIELD = "function_name"; public static final String MODEL_NAME_FIELD = "name"; //mandatory public static final String DESCRIPTION_FIELD = "description"; //optional - + public static final String QUOTA_FLAG_FIELD = "quota_flag"; + public static final String RATE_LIMIT_NUMBER_FIELD = "rate_limit_number"; + public static final String RATE_LIMIT_UNIT_FIELD = "rate_limit_unit"; public static final String VERSION_FIELD = "version"; public static final String MODEL_FORMAT_FIELD = "model_format"; //mandatory public static final String MODEL_STATE_FIELD = "model_state"; @@ -55,7 +58,9 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ private String modelGroupId; private String description; private String version; - + private Boolean quotaFlag; + private String rateLimitNumber; + private TimeUnit rateLimitUnit; private MLModelFormat modelFormat; private MLModelState modelState; @@ -70,7 +75,7 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ private Boolean doesVersionCreateModelGroup; @Builder(toBuilder = true) - public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List backendRoles, + public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, Boolean quotaFlag, String rateLimitNumber, TimeUnit rateLimitUnit, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List backendRoles, AccessMode accessMode, Boolean isAddAllBackendRoles, Boolean doesVersionCreateModelGroup) { @@ -98,6 +103,9 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m this.modelGroupId = modelGroupId; this.version = version; this.description = description; + this.quotaFlag = quotaFlag; + this.rateLimitNumber = rateLimitNumber; + this.rateLimitUnit = rateLimitUnit; this.modelFormat = modelFormat; this.modelState = modelState; this.modelContentSizeInBytes = modelContentSizeInBytes; @@ -116,6 +124,11 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{ this.modelGroupId = in.readOptionalString(); this.version = in.readOptionalString(); this.description = in.readOptionalString(); + this.quotaFlag = in.readOptionalBoolean(); + this.rateLimitNumber = in.readOptionalString(); + if (in.readBoolean()) { + rateLimitUnit = in.readEnum(TimeUnit.class); + } if (in.readBoolean()) { modelFormat = in.readEnum(MLModelFormat.class); } @@ -143,6 +156,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(modelGroupId); out.writeOptionalString(version); out.writeOptionalString(description); + out.writeOptionalBoolean(quotaFlag); + out.writeOptionalString(rateLimitNumber); + if (rateLimitUnit != null) { + out.writeBoolean(true); + out.writeEnum(rateLimitUnit); + } else { + out.writeBoolean(false); + } if (modelFormat != null) { out.writeBoolean(true); out.writeEnum(modelFormat); @@ -194,6 +215,15 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (description != null) { builder.field(DESCRIPTION_FIELD, description); } + if (quotaFlag != null) { + builder.field(QUOTA_FLAG_FIELD, quotaFlag); + } + if (rateLimitNumber != null) { + builder.field(RATE_LIMIT_NUMBER_FIELD, rateLimitNumber); + } + if (rateLimitUnit != null) { + builder.field(RATE_LIMIT_UNIT_FIELD, rateLimitUnit); + } builder.field(MODEL_FORMAT_FIELD, modelFormat); if (modelState != null) { builder.field(MODEL_STATE_FIELD, modelState); @@ -226,6 +256,9 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc String modelGroupId = null; String version = null; String description = null; + Boolean quotaFlag = null; + String rateLimitNumber = null; + TimeUnit rateLimitUnit = null; MLModelFormat modelFormat = null; MLModelState modelState = null; Long modelContentSizeInBytes = null; @@ -257,6 +290,15 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc case DESCRIPTION_FIELD: description = parser.text(); break; + case QUOTA_FLAG_FIELD: + quotaFlag = parser.booleanValue(); + break; + case RATE_LIMIT_NUMBER_FIELD: + rateLimitNumber = parser.text(); + break; + case RATE_LIMIT_UNIT_FIELD: + rateLimitUnit = TimeUnit.valueOf(parser.text()); + break; case MODEL_FORMAT_FIELD: modelFormat = MLModelFormat.from(parser.text()); break; @@ -296,7 +338,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc break; } } - return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup); + return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, quotaFlag, rateLimitNumber, rateLimitUnit, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java new file mode 100644 index 0000000000..6bafe81692 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -0,0 +1,163 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.function.Consumer; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; + +public class MLUpdateModelInputTest { + + private MLUpdateModelInput updateModelInput; + private final String expectedInputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + private final String expectedInputStrWithNullField = "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + private final String expectedOutputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + private final String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"description\":\"description\",\"model_version\":\"2\",\"name\":\"name\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\",\"illegal_field\":\"This field need to be skipped.\"}"; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() throws Exception { + + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + updateModelInput = MLUpdateModelInput.builder() + .modelId("test-model_id") + .modelGroupId("modelGroupId") + .version("2") + .name("name") + .description("description") + .modelConfig(config) + .connectorId("test-connector_id") + .build(); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(updateModelInput, parsedInput -> { + assertEquals("test-model_id", parsedInput.getModelId()); + assertEquals(updateModelInput.getName(), parsedInput.getName()); + }); + } + + @Test + public void readInputStream_SuccessWithNullFields() throws IOException { + updateModelInput.setModelConfig(null); + readInputStream(updateModelInput, parsedInput -> { + assertNull(parsedInput.getModelConfig()); + }); + } + + @Test + public void testToXContent() throws Exception { + String jsonStr = serializationWithToXContent(updateModelInput); + assertEquals(expectedInputStr, jsonStr); + } + + @Test + public void testToXContent_Incomplete() throws Exception { + String expectedIncompleteInputStr = + "{\"model_id\":\"test-model_id\"}"; + updateModelInput.setDescription(null); + updateModelInput.setVersion(null); + updateModelInput.setName(null); + updateModelInput.setModelGroupId(null); + updateModelInput.setModelConfig(null); + updateModelInput.setConnectorId(null); + String jsonStr = serializationWithToXContent(updateModelInput); + assertEquals(expectedIncompleteInputStr, jsonStr); + } + + @Test + public void parse_Success() throws Exception { + testParseFromJsonString(expectedInputStr, parsedInput -> { + assertEquals("name", parsedInput.getName()); + }); + } + + @Test + public void parse_WithNullFieldWithoutModel() throws Exception { + exceptionRule.expect(IllegalStateException.class); + testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void parse_WithIllegalFieldWithoutModel() throws Exception { + testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + parser.nextToken(); + MLUpdateModelInput parsedInput = MLUpdateModelInput.parse(parser); + verify.accept(parsedInput); + } + + private void readInputStream(MLUpdateModelInput input, Consumer verify) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLUpdateModelInput parsedInput = new MLUpdateModelInput(streamInput); + verify.accept(parsedInput); + } + + private String serializationWithToXContent(MLUpdateModelInput input) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + return builder.toString(); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java new file mode 100644 index 0000000000..cadf865b1c --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import org.junit.Before; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; + +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + + +public class MLUpdateModelRequestTest { + + private MLUpdateModelRequest updateModelRequest; + + @Before + public void setUp(){ + MockitoAnnotations.openMocks(this); + + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + MLUpdateModelInput updateModelInput = MLUpdateModelInput.builder() + .modelId("test-model_id") + .modelGroupId("modelGroupId") + .name("name") + .description("description") + .modelConfig(config) + .build(); + + updateModelRequest = MLUpdateModelRequest.builder() + .updateModelInput(updateModelInput) + .build(); + + } + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + updateModelRequest.writeTo(bytesStreamOutput); + MLUpdateModelRequest parsedUpdateRequest = new MLUpdateModelRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals("test-model_id", parsedUpdateRequest.getUpdateModelInput().getModelId()); + assertEquals("name", parsedUpdateRequest.getUpdateModelInput().getName()); + } + + @Test + public void validate_Success() { + assertNull(updateModelRequest.validate()); + } + + @Test + public void validate_Exception_NullModelInput() { + MLUpdateModelRequest updateModelRequest = MLUpdateModelRequest.builder().build(); + Exception exception = updateModelRequest.validate(); + + assertEquals("Validation Failed: 1: Update Model Input can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success() { + assertSame(MLUpdateModelRequest.fromActionRequest(updateModelRequest), updateModelRequest); + } + + @Test + public void fromActionRequest_Success_fromActionRequest() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + updateModelRequest.writeTo(out); + } + }; + MLUpdateModelRequest request = MLUpdateModelRequest.fromActionRequest(actionRequest); + assertNotSame(request, updateModelRequest); + assertEquals(updateModelRequest.getUpdateModelInput().getName(), request.getUpdateModelInput().getName()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLUpdateModelRequest.fromActionRequest(actionRequest); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java index 61e57d4ac6..0abf977804 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java @@ -43,7 +43,7 @@ public void setup() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); + "Model Description", null, null, null, MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java index d7039780f0..faa3a0d300 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java @@ -33,7 +33,7 @@ public void setUp() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); + "Model Description", null, null, null, MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); } @Test diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java index b03e6028fa..7b953c341a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; @@ -19,11 +20,11 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.connector.Connector; -import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; @@ -79,7 +80,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { log.error("Failed to get connector index", e); - actionListener.onFailure(new IllegalArgumentException("Fail to find connector")); + actionListener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); } else { log.error("Failed to get ML connector " + connectorId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java index 86e4afe56f..066ca5f8a7 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -13,6 +13,7 @@ import java.util.Arrays; import java.util.List; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.search.SearchRequest; @@ -27,13 +28,13 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; import org.opensearch.ml.engine.MLEngine; @@ -136,10 +137,11 @@ private void updateUndeployedConnector( } listener .onFailure( - new MLValidationException( + new OpenSearchStatusException( searchHits.length + " models are still using this connector, please undeploy the models first: " - + Arrays.toString(modelIds.toArray(new String[0])) + + Arrays.toString(modelIds.toArray(new String[0])), + RestStatus.BAD_REQUEST ) ); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/InPlaceUpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/InPlaceUpdateModelTransportAction.java new file mode 100644 index 0000000000..67fd61bfdb --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/models/InPlaceUpdateModelTransportAction.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.models; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.transport.model.MLInPlaceUpdateModelAction; +import org.opensearch.ml.common.transport.model.MLInPlaceUpdateModelNodeRequest; +import org.opensearch.ml.common.transport.model.MLInPlaceUpdateModelNodeResponse; +import org.opensearch.ml.common.transport.model.MLInPlaceUpdateModelNodesRequest; +import org.opensearch.ml.common.transport.model.MLInPlaceUpdateModelNodesResponse; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class InPlaceUpdateModelTransportAction extends + TransportNodesAction { + private final MLModelManager mlModelManager; + private final ClusterService clusterService; + private final Client client; + private DiscoveryNodeHelper nodeFilter; + private final MLStats mlStats; + private NamedXContentRegistry xContentRegistry; + + private ModelAccessControlHelper modelAccessControlHelper; + + @Inject + public InPlaceUpdateModelTransportAction( + TransportService transportService, + ActionFilters actionFilters, + MLModelManager mlModelManager, + ClusterService clusterService, + ThreadPool threadPool, + Client client, + DiscoveryNodeHelper nodeFilter, + MLStats mlStats, + NamedXContentRegistry xContentRegistry, + ModelAccessControlHelper modelAccessControlHelper + ) { + super( + MLInPlaceUpdateModelAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + MLInPlaceUpdateModelNodesRequest::new, + MLInPlaceUpdateModelNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + MLInPlaceUpdateModelNodeResponse.class + ); + this.mlModelManager = mlModelManager; + this.clusterService = clusterService; + this.client = client; + this.nodeFilter = nodeFilter; + this.mlStats = mlStats; + this.xContentRegistry = xContentRegistry; + this.modelAccessControlHelper = modelAccessControlHelper; + } + + @Override + protected MLInPlaceUpdateModelNodesResponse newResponse( + MLInPlaceUpdateModelNodesRequest nodesRequest, + List responses, + List failures + ) { + return new MLInPlaceUpdateModelNodesResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected MLInPlaceUpdateModelNodeRequest newNodeRequest(MLInPlaceUpdateModelNodesRequest request) { + return new MLInPlaceUpdateModelNodeRequest(request); + } + + @Override + protected MLInPlaceUpdateModelNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new MLInPlaceUpdateModelNodeResponse(in); + } + + @Override + protected MLInPlaceUpdateModelNodeResponse nodeOperation(MLInPlaceUpdateModelNodeRequest request) { + return createInPlaceUpdateModelNodeResponse(request.getMlInPlaceUpdateModelNodesRequest()); + } + + private MLInPlaceUpdateModelNodeResponse createInPlaceUpdateModelNodeResponse( + MLInPlaceUpdateModelNodesRequest mlInPlaceUpdateModelNodesRequest + ) { + String modelId = mlInPlaceUpdateModelNodesRequest.getModelId(); + boolean updatePredictorFlag = mlInPlaceUpdateModelNodesRequest.isUpdatePredictorFlag(); + + Map modelUpdateStatus = new HashMap<>(); + modelUpdateStatus.put(modelId, "received"); + + String localNodeId = clusterService.localNode().getId(); + + mlModelManager.inplaceUpdateModel(modelId, updatePredictorFlag, ActionListener.wrap(r -> { + log.info("Successfully performing in-place update model {} on node {}", modelId, localNodeId); + }, e -> { log.error("Failed to perform in-place update model for model {} on node {}", modelId, localNodeId); })); + return new MLInPlaceUpdateModelNodeResponse(clusterService.localNode(), modelUpdateStatus); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java new file mode 100644 index 0000000000..566329240c --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -0,0 +1,531 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.models; + +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.FunctionName.REMOTE; +import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.transport.model.MLInPlaceUpdateModelAction; +import org.opensearch.ml.common.transport.model.MLInPlaceUpdateModelNodesRequest; +import org.opensearch.ml.common.transport.model.MLUpdateModelAction; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelGroupManager; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(level = AccessLevel.PRIVATE) +public class UpdateModelTransportAction extends HandledTransportAction { + Client client; + ClusterService clusterService; + ModelAccessControlHelper modelAccessControlHelper; + ConnectorAccessControlHelper connectorAccessControlHelper; + MLModelManager mlModelManager; + MLModelGroupManager mlModelGroupManager; + MLEngine mlEngine; + volatile List trustedConnectorEndpointsRegex; + + @Inject + public UpdateModelTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ConnectorAccessControlHelper connectorAccessControlHelper, + ModelAccessControlHelper modelAccessControlHelper, + MLModelManager mlModelManager, + MLModelGroupManager mlModelGroupManager, + Settings settings, + ClusterService clusterService, + MLEngine mlEngine + ) { + super(MLUpdateModelAction.NAME, transportService, actionFilters, MLUpdateModelRequest::new); + this.client = client; + this.modelAccessControlHelper = modelAccessControlHelper; + this.connectorAccessControlHelper = connectorAccessControlHelper; + this.mlModelManager = mlModelManager; + this.mlModelGroupManager = mlModelGroupManager; + this.clusterService = clusterService; + this.mlEngine = mlEngine; + trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, it -> trustedConnectorEndpointsRegex = it); + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLUpdateModelRequest updateModelRequest = MLUpdateModelRequest.fromActionRequest(request); + MLUpdateModelInput updateModelInput = updateModelRequest.getUpdateModelInput(); + String modelId = updateModelInput.getModelId(); + User user = RestActionUtils.getUserContext(client); + + String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + boolean modelDeployFlag = isModelDeployed(mlModel.getModelState()); + FunctionName functionName = mlModel.getAlgorithm(); + if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + if (hasPermission) { + updateRemoteOrTextEmbeddingModel( + modelId, + updateModelInput, + mlModel, + user, + actionListener, + context, + modelDeployFlag + ); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model, model ID " + modelId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); + actionListener.onFailure(exception); + })); + } else { + actionListener + .onFailure( + new MLValidationException( + "User doesn't have privilege to perform this operation on this function category: " + + functionName.toString() + ) + ); + } + }, + e -> actionListener + .onFailure( + new OpenSearchStatusException( + "Failed to find model to update with the provided model id: " + modelId, + RestStatus.NOT_FOUND + ) + ) + )); + } catch (Exception e) { + log.error("Failed to update ML model for " + modelId, e); + actionListener.onFailure(e); + } + } + + private void updateRemoteOrTextEmbeddingModel( + String modelId, + MLUpdateModelInput updateModelInput, + MLModel mlModel, + User user, + ActionListener actionListener, + ThreadContext.StoredContext context, + boolean modelDeployFlag + ) { + String newModelGroupId = (Strings.hasLength(updateModelInput.getModelGroupId()) + && !Objects.equals(updateModelInput.getModelGroupId(), mlModel.getModelGroupId())) ? updateModelInput.getModelGroupId() : null; + String relinkConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null; + + if (mlModel.getAlgorithm() == TEXT_EMBEDDING) { + if (relinkConnectorId == null && updateModelInput.getConnectorUpdateContent() == null) { + updateModelWithRegisteringToAnotherModelGroup( + modelId, + newModelGroupId, + user, + updateModelInput, + actionListener, + context, + modelDeployFlag + ); + } else { + actionListener + .onFailure( + new OpenSearchStatusException("Trying to update the connector_id field on a local model.", RestStatus.BAD_REQUEST) + ); + } + } else { + // mlModel.getAlgorithm() == REMOTE + if (relinkConnectorId == null) { + if (updateModelInput.getConnectorUpdateContent() != null) { + Connector connector = mlModel.getConnector(); + connector.update(updateModelInput.getConnectorUpdateContent(), mlEngine::encrypt); + connector.validateConnectorURL(trustedConnectorEndpointsRegex); + updateModelInput.setConnector(connector); + updateModelInput.setConnectorUpdateContent(null); + } + updateModelWithRegisteringToAnotherModelGroup( + modelId, + newModelGroupId, + user, + updateModelInput, + actionListener, + context, + modelDeployFlag + ); + } else { + updateModelWithRelinkStandAloneConnector( + modelId, + newModelGroupId, + relinkConnectorId, + mlModel, + user, + updateModelInput, + actionListener, + context, + modelDeployFlag + ); + } + } + } + + private void updateModelWithRelinkStandAloneConnector( + String modelId, + String newModelGroupId, + String relinkConnectorId, + MLModel mlModel, + User user, + MLUpdateModelInput updateModelInput, + ActionListener actionListener, + ThreadContext.StoredContext context, + boolean modelDeployFlag + ) { + if (Strings.hasLength(mlModel.getConnectorId())) { + connectorAccessControlHelper + .validateConnectorAccess(client, relinkConnectorId, ActionListener.wrap(hasRelinkConnectorPermission -> { + if (hasRelinkConnectorPermission) { + updateModelWithRegisteringToAnotherModelGroup( + modelId, + newModelGroupId, + user, + updateModelInput, + actionListener, + context, + modelDeployFlag + ); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "You don't have permission to update the connector, connector id: " + relinkConnectorId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", relinkConnectorId, exception); + actionListener.onFailure(exception); + })); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "This remote does not have a connector_id field, maybe it uses an internal connector.", + RestStatus.BAD_REQUEST + ) + ); + } + } + + private void updateModelWithRegisteringToAnotherModelGroup( + String modelId, + String newModelGroupId, + User user, + MLUpdateModelInput updateModelInput, + ActionListener actionListener, + ThreadContext.StoredContext context, + boolean modelDeployFlag + ) { + UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); + // This flag is used to decide if we need to re-deploy the predictor(model) when performing the in-place update + boolean updatePredictorFlag = (updateModelInput.getConnector() != null || updateModelInput.getConnectorId() != null); + boolean inPlaceUpdateFieldFlag = (updateModelInput.getQuotaFlag() != null + || updateModelInput.getRateLimitNumber() != null + || updateModelInput.getRateLimitUnit() != null + || updatePredictorFlag); + // This flag is used to decide if we need to perform an in-place update + boolean inPlaceUpdateFlag = modelDeployFlag && inPlaceUpdateFieldFlag; + if (newModelGroupId != null) { + modelAccessControlHelper.validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasRelinkPermission -> { + if (hasRelinkPermission) { + mlModelGroupManager.getModelGroupResponse(newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { + updateRequestConstructor( + modelId, + newModelGroupId, + updateRequest, + updateModelInput, + newModelGroupResponse, + actionListener, + context, + inPlaceUpdateFlag, + updatePredictorFlag + ); + }, + exception -> actionListener + .onFailure( + new OpenSearchStatusException( + "Failed to find the model group with the provided model group id in the update model input, MODEL_GROUP_ID: " + + newModelGroupId, + RestStatus.NOT_FOUND + ) + ) + )); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID " + + newModelGroupId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); + actionListener.onFailure(exception); + })); + } else { + updateRequestConstructor( + modelId, + updateRequest, + updateModelInput, + actionListener, + context, + inPlaceUpdateFlag, + updatePredictorFlag + ); + } + } + + private void updateRequestConstructor( + String modelId, + UpdateRequest updateRequest, + MLUpdateModelInput updateModelInput, + ActionListener actionListener, + ThreadContext.StoredContext context, + boolean inPlaceUpdateFlag, + boolean updatePredictorFlag + ) { + try { + updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + updateRequest.docAsUpsert(true); + if (inPlaceUpdateFlag) { + String[] targetNodeIds = getAllNodes(); + MLInPlaceUpdateModelNodesRequest mlInPlaceUpdateModelNodesRequest = new MLInPlaceUpdateModelNodesRequest( + targetNodeIds, + modelId, + updatePredictorFlag + ); + client.update(updateRequest, getUpdateResponseListener(modelId, actionListener, context, mlInPlaceUpdateModelNodesRequest)); + } else { + client.update(updateRequest, getUpdateResponseListener(modelId, actionListener, context)); + } + } catch (IOException e) { + log.error("Failed to build update request."); + actionListener.onFailure(e); + } + } + + private void updateRequestConstructor( + String modelId, + String newModelGroupId, + UpdateRequest updateRequest, + MLUpdateModelInput updateModelInput, + GetResponse newModelGroupResponse, + ActionListener actionListener, + ThreadContext.StoredContext context, + boolean inPlaceUpdateFlag, + boolean updatePredictorFlag + ) { + Map newModelGroupSourceMap = newModelGroupResponse.getSourceAsMap(); + String updatedVersion = incrementLatestVersion(newModelGroupSourceMap); + updateModelInput.setVersion(updatedVersion); + UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( + newModelGroupSourceMap, + newModelGroupId, + newModelGroupResponse.getSeqNo(), + newModelGroupResponse.getPrimaryTerm(), + Integer.parseInt(updatedVersion) + ); + try { + updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + updateRequest.docAsUpsert(true); + if (inPlaceUpdateFlag) { + String[] targetNodeIds = getAllNodes(); + MLInPlaceUpdateModelNodesRequest mlInPlaceUpdateModelNodesRequest = new MLInPlaceUpdateModelNodesRequest( + targetNodeIds, + modelId, + updatePredictorFlag + ); + client.update(updateModelGroupRequest, ActionListener.wrap(r -> { + client + .update( + updateRequest, + getUpdateResponseListener(modelId, actionListener, context, mlInPlaceUpdateModelNodesRequest) + ); + }, e -> { + log + .error( + "Failed to register ML model with model ID {} to the new model group with model group ID {}", + modelId, + newModelGroupId + ); + actionListener.onFailure(e); + })); + } else { + client.update(updateModelGroupRequest, ActionListener.wrap(r -> { + client.update(updateRequest, getUpdateResponseListener(modelId, actionListener, context)); + }, e -> { + log + .error( + "Failed to register ML model with model ID {} to the new model group with model group ID {}", + modelId, + newModelGroupId + ); + actionListener.onFailure(e); + })); + } + } catch (IOException e) { + log.error("Failed to build update request."); + actionListener.onFailure(e); + } + } + + private ActionListener getUpdateResponseListener( + String modelId, + ActionListener actionListener, + ThreadContext.StoredContext context, + MLInPlaceUpdateModelNodesRequest mlInPlaceUpdateModelNodesRequest + ) { + return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { + log.info("Model id:{} failed update", modelId); + actionListener.onResponse(updateResponse); + return; + } + client.execute(MLInPlaceUpdateModelAction.INSTANCE, mlInPlaceUpdateModelNodesRequest, ActionListener.wrap(r -> { + log.info("Successfully perform in-place update ML model with model ID {}", modelId); + actionListener.onResponse(updateResponse); + }, e -> { + log.error("Failed to perform in-place update for model: {}" + modelId, e); + actionListener.onFailure(e); + })); + }, exception -> { + log.error("Failed to update ML model: " + modelId, exception); + actionListener.onFailure(exception); + }), context::restore); + } + + private ActionListener getUpdateResponseListener( + String modelId, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { + log.info("Model id:{} failed update", modelId); + actionListener.onResponse(updateResponse); + return; + } + log.info("Successfully update ML model with model ID {}", modelId); + actionListener.onResponse(updateResponse); + }, exception -> { + log.error("Failed to update ML model: " + modelId, exception); + actionListener.onFailure(exception); + }), context::restore); + } + + private String incrementLatestVersion(Map modelGroupSourceMap) { + return Integer.toString((int) modelGroupSourceMap.get(MLModelGroup.LATEST_VERSION_FIELD) + 1); + } + + private UpdateRequest createUpdateModelGroupRequest( + Map modelGroupSourceMap, + String modelGroupId, + long seqNo, + long primaryTerm, + int updatedVersion + ) { + modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion); + modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); + UpdateRequest updateModelGroupRequest = new UpdateRequest(); + + updateModelGroupRequest + .index(ML_MODEL_GROUP_INDEX) + .id(modelGroupId) + .setIfSeqNo(seqNo) + .setIfPrimaryTerm(primaryTerm) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .doc(modelGroupSourceMap); + + return updateModelGroupRequest; + } + + private Boolean isModelDeployed(MLModelState mlModelState) { + return mlModelState.equals(MLModelState.LOADED) + || mlModelState.equals(MLModelState.PARTIALLY_LOADED) + || mlModelState.equals(MLModelState.DEPLOYED) + || mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED); + } + + private String[] getAllNodes() { + Iterator iterator = clusterService.state().nodes().iterator(); + List nodeIds = new ArrayList<>(); + while (iterator.hasNext()) { + nodeIds.add(iterator.next().getId()); + } + return nodeIds.toArray(new String[0]); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 63be5e2423..d60567653b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -5,6 +5,7 @@ package org.opensearch.ml.action.prediction; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -14,6 +15,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; @@ -104,7 +106,20 @@ public void onResponse(MLModel mlModel) { new MLValidationException("User Doesn't have privilege to perform this operation on this model") ); } else { - executePredict(mlPredictionTaskRequest, wrappedListener, modelId); + if (modelCacheHelper.getQuotaFlag(modelId) != null && !modelCacheHelper.getQuotaFlag(modelId)) { + wrappedListener + .onFailure(new OpenSearchStatusException("Quota is depleted.", RestStatus.TOO_MANY_REQUESTS)); + } else { + if (modelCacheHelper.getRateLimiter(modelId) != null + && !modelCacheHelper.getRateLimiter(modelId).request()) { + wrappedListener + .onFailure( + new OpenSearchStatusException("Request is throttled.", RestStatus.TOO_MANY_REQUESTS) + ); + } else { + executePredict(mlPredictionTaskRequest, wrappedListener, modelId); + } + } } }, e -> { log.error("Failed to Validate Access for ModelId " + modelId, e); diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java index b2c912b9d5..b1096e7e38 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -19,6 +20,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.BoolQueryBuilder; @@ -30,7 +32,6 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.AbstractConnector; import org.opensearch.ml.common.connector.Connector; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.search.builder.SearchSourceBuilder; @@ -100,11 +101,11 @@ public void getConnector(Client client, String connectorId, ActionListener { - log.error("Fail to get connector", e); - listener.onFailure(new IllegalStateException("Fail to get connector:" + connectorId)); + log.error("Failed to get connector", e); + listener.onFailure(new OpenSearchStatusException("Failed to get connector:" + connectorId, RestStatus.NOT_FOUND)); })); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java index 5fd7d71ce0..f3646a0a44 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java @@ -13,6 +13,7 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.stream.DoubleStream; +import org.opensearch.common.util.TokenBucket; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; @@ -33,6 +34,8 @@ public class MLModelCache { private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) FunctionName functionName; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Predictable predictor; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLExecutable executor; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) TokenBucket rateLimiter; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Boolean quotaFlag; private final Set targetWorkerNodes; private final Set workerNodes; private MLModel modelInfo; diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 553ffeb664..6f21014aae 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -17,6 +17,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.TokenBucket; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLLimitExceededException; @@ -75,6 +76,50 @@ public synchronized void setModelState(String modelId, MLModelState state) { getExistingModelCache(modelId).setModelState(state); } + /** + * Set a rate limiter to enable throttling + * @param modelId model id + * @param rateLimiter rate limiter + */ + public synchronized void setRateLimiter(String modelId, TokenBucket rateLimiter) { + log.debug("Setting the rate limiter for Model {}", modelId); + getExistingModelCache(modelId).setRateLimiter(rateLimiter); + } + + /** + * Get the current rate limiter for the model + * @param modelId model id + */ + public TokenBucket getRateLimiter(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + return null; + } + return modelCache.getRateLimiter(); + } + + /** + * Set a quota flag to control if the model can still receive request + * @param modelId model id + * @param quotaFlag rate limiter + */ + public synchronized void setQuotaFlag(String modelId, Boolean quotaFlag) { + log.debug("Setting the quota flag for Model {}", modelId); + getExistingModelCache(modelId).setQuotaFlag(quotaFlag); + } + + /** + * Get the current quota flag condition for the model + * @param modelId model id + */ + public Boolean getQuotaFlag(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + return null; + } + return modelCache.getQuotaFlag(); + } + /** * Set memory size estimation CPU/GPU * @param modelId model id diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index 94cbcf5364..83523729e4 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -11,6 +11,8 @@ import java.util.HashSet; import java.util.Iterator; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -30,6 +32,7 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.indices.MLIndicesHandler; @@ -207,6 +210,24 @@ public void validateUniqueModelGroupName(String name, ActionListener listener) { + GetRequest getRequest = new GetRequest(); + getRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + listener.onResponse(r); + } else { + listener.onFailure(new MLResourceNotFoundException("Failed to find model group with ID: " + modelGroupId)); + } + }, e -> { listener.onFailure(e); })); + } + private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) { if (input.getModelAccessMode() != null || input.getIsAddAllBackendRoles() != null diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index dd1deac4ab..009b142ec9 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -8,7 +8,6 @@ import static org.opensearch.common.xcontent.XContentType.JSON; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.NOT_FOUND; @@ -77,8 +76,8 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.TokenBucket; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -90,6 +89,7 @@ import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; @@ -120,6 +120,7 @@ import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.utils.MLExceptionUtils; +import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.script.ScriptService; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.threadpool.ThreadPool; @@ -279,6 +280,9 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput .version(version) .modelGroupId(mlRegisterModelMetaInput.getModelGroupId()) .description(mlRegisterModelMetaInput.getDescription()) + .quotaFlag(mlRegisterModelMetaInput.getQuotaFlag()) + .rateLimitNumber(mlRegisterModelMetaInput.getRateLimitNumber()) + .rateLimitUnit(mlRegisterModelMetaInput.getRateLimitUnit()) .modelFormat(mlRegisterModelMetaInput.getModelFormat()) .modelState(MLModelState.REGISTERING) .modelConfig(mlRegisterModelMetaInput.getModelConfig()) @@ -289,7 +293,7 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput .lastUpdateTime(now) .build(); IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); - indexRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), EMPTY_PARAMS)); + indexRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.index(indexRequest, ActionListener.wrap(response -> { @@ -503,6 +507,9 @@ private void indexRemoteModel( .modelGroupId(registerModelInput.getModelGroupId()) .version(version) .description(registerModelInput.getDescription()) + .quotaFlag(registerModelInput.getQuotaFlag()) + .rateLimitNumber(registerModelInput.getRateLimitNumber()) + .rateLimitUnit(registerModelInput.getRateLimitUnit()) .modelFormat(registerModelInput.getModelFormat()) .modelState(MLModelState.REGISTERED) .connector(registerModelInput.getConnector()) @@ -561,6 +568,9 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml .modelGroupId(registerModelInput.getModelGroupId()) .version(version) .description(registerModelInput.getDescription()) + .quotaFlag(registerModelInput.getQuotaFlag()) + .rateLimitNumber(registerModelInput.getRateLimitNumber()) + .rateLimitUnit(registerModelInput.getRateLimitUnit()) .modelFormat(registerModelInput.getModelFormat()) .modelState(MLModelState.REGISTERED) .connector(registerModelInput.getConnector()) @@ -622,6 +632,9 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas .algorithm(functionName) .version(version) .description(registerModelInput.getDescription()) + .quotaFlag(registerModelInput.getQuotaFlag()) + .rateLimitNumber(registerModelInput.getRateLimitNumber()) + .rateLimitUnit(registerModelInput.getRateLimitUnit()) .modelFormat(registerModelInput.getModelFormat()) .modelState(MLModelState.REGISTERING) .modelConfig(registerModelInput.getModelConfig()) @@ -645,7 +658,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas log.error("Failed to index model meta doc", e); handleException(functionName, taskId, e); }); - client.index(indexModelMetaRequest, threadedActionListener(REGISTER_THREAD_POOL, listener)); }, e -> { log.error("Failed to init model index", e); @@ -701,6 +713,9 @@ private void registerModel( .algorithm(functionName) .version(version) .modelFormat(registerModelInput.getModelFormat()) + .quotaFlag(registerModelInput.getQuotaFlag()) + .rateLimitNumber(registerModelInput.getRateLimitNumber()) + .rateLimitUnit(registerModelInput.getRateLimitUnit()) .chunkNumber(chunkNum) .totalChunks(chunkFiles.size()) .content(Base64.getEncoder().encodeToString(bytes)) @@ -885,6 +900,55 @@ private void handleException(FunctionName functionName, String taskId, Exception mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true); } + public synchronized Map inplaceUpdateModel( + String modelId, + boolean updatePredictorFlag, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + getModel(modelId, ActionListener.wrap(mlModel -> { + int NodeCount = getWorkerNodes(modelId, mlModel.getAlgorithm()).length; + modelCacheHelper.setRateLimiter(modelId, rateLimiterConstructor(NodeCount, mlModel)); + modelCacheHelper.setQuotaFlag(modelId, mlModel.getQuotaFlag()); + if (updatePredictorFlag) { + assert FunctionName.REMOTE == mlModel.getAlgorithm() + : "In-place update is only supported on REMOTE models at this time."; + Map params = ImmutableMap + .of( + ML_ENGINE, + mlEngine, + SCRIPT_SERVICE, + scriptService, + CLIENT, + client, + XCONTENT_REGISTRY, + xContentRegistry, + CLUSTER_SERVICE, + clusterService + ); + if (mlModel.getConnector() != null) { + Predictable predictable = mlEngine.deploy(mlModel, params); + modelCacheHelper.setPredictor(modelId, predictable); + wrappedListener.onResponse("successfully performed in-place update for the model " + modelId); + log.info("Completed in-place update internal connector for the model {}", modelId); + } else { + getConnector(client, mlModel.getConnectorId(), ActionListener.wrap(connector -> { + mlModel.setConnector(connector); + Predictable predictable = mlEngine.deploy(mlModel, params); + modelCacheHelper.setPredictor(modelId, predictable); + wrappedListener.onResponse("successfully performed in-place update for the model " + modelId); + log.info("Completed in-place update stand-alone connector for the model {}", modelId); + }, wrappedListener::onFailure)); + } + wrappedListener.onResponse("successfully performed in-place update for the model " + modelId); + log.info("Completed in-place update for the model {}", modelId); + } + }, wrappedListener::onFailure)); + } + return null; + } + /** * Read model chunks from model index. Concat chunks into a whole model file, then load * into memory. @@ -921,6 +985,7 @@ public void deployModel( listener.onFailure(new IllegalArgumentException("Exceed max local model per node limit")); return; } + int eligibleNodeCount = workerNodes.size(); modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); @@ -944,33 +1009,17 @@ public void deployModel( ); // deploy remote model with internal connector or model trained by built-in algorithm like kmeans if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) { - setupPredictable(modelId, mlModel, params); + setupPredictable(modelId, mlModel, params, eligibleNodeCount); wrappedListener.onResponse("successful"); return; } log.info("Set connector {} for the model: {}", mlModel.getConnectorId(), modelId); - GetRequest getConnectorRequest = new GetRequest(); - FetchSourceContext fetchContext = new FetchSourceContext(true, null, null); - getConnectorRequest.index(ML_CONNECTOR_INDEX).id(mlModel.getConnectorId()).fetchSourceContext(fetchContext); - // get connector and deploy remote model with standalone connector - client.get(getConnectorRequest, ActionListener.wrap(getResponse -> { - if (getResponse != null && getResponse.isExists()) { - try ( - XContentParser parser = createXContentParserFromRegistry( - xContentRegistry, - getResponse.getSourceAsBytesRef() - ) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Connector connector = Connector.createConnector(parser); - mlModel.setConnector(connector); - setupPredictable(modelId, mlModel, params); - wrappedListener.onResponse("successful"); - log.info("Completed setting connector {} in the model {}", mlModel.getConnectorId(), modelId); - } - } - }, e -> { wrappedListener.onFailure(e); })); - + getConnector(client, mlModel.getConnectorId(), ActionListener.wrap(connector -> { + mlModel.setConnector(connector); + setupPredictable(modelId, mlModel, params, eligibleNodeCount); + wrappedListener.onResponse("successful"); + log.info("Completed setting connector {} in the model {}", mlModel.getConnectorId(), modelId); + }, wrappedListener::onFailure)); return; } // check circuit breaker before deploying custom model chunks @@ -990,6 +1039,8 @@ public void deployModel( MLExecutable mlExecutable = mlEngine.deployExecute(mlModel, params); try { modelCacheHelper.setMLExecutor(modelId, mlExecutable); + modelCacheHelper.setRateLimiter(modelId, rateLimiterConstructor(eligibleNodeCount, mlModel)); + modelCacheHelper.setQuotaFlag(modelId, mlModel.getQuotaFlag()); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); wrappedListener.onResponse("successful"); @@ -1003,6 +1054,8 @@ public void deployModel( Predictable predictable = mlEngine.deploy(mlModel, params); try { modelCacheHelper.setPredictor(modelId, predictable); + modelCacheHelper.setQuotaFlag(modelId, mlModel.getQuotaFlag()); + modelCacheHelper.setRateLimiter(modelId, rateLimiterConstructor(eligibleNodeCount, mlModel)); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); Long modelContentSizeInBytes = mlModel.getModelContentSizeInBytes(); @@ -1044,13 +1097,32 @@ private void handleDeployModelException(String modelId, FunctionName functionNam listener.onFailure(e); } - private void setupPredictable(String modelId, MLModel mlModel, Map params) { + private void setupPredictable(String modelId, MLModel mlModel, Map params, Integer eligibleNodeCount) { Predictable predictable = mlEngine.deploy(mlModel, params); modelCacheHelper.setPredictor(modelId, predictable); + modelCacheHelper.setRateLimiter(modelId, rateLimiterConstructor(eligibleNodeCount, mlModel)); + modelCacheHelper.setQuotaFlag(modelId, mlModel.getQuotaFlag()); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); } + private TokenBucket rateLimiterConstructor(Integer eligibleNodeCount, MLModel mlModel) { + if (mlModel != null && mlModel.getRateLimitNumber() != null && mlModel.getRateLimitUnit() != null) { + double rateLimitNumber = Double.parseDouble(mlModel.getRateLimitNumber()); + TimeUnit rateLimitUnit = mlModel.getRateLimitUnit(); + log + .debug( + "Initializing the rate limiter for Model {}, with TPS limit {} and burst capacity {}, evenly distributed on {} nodes", + mlModel.getModelId(), + rateLimitNumber / rateLimitUnit.toSeconds(1), + rateLimitNumber, + eligibleNodeCount + ); + return new TokenBucket(System::nanoTime, rateLimitNumber / rateLimitUnit.toNanos(1) / eligibleNodeCount, rateLimitNumber); + } + return null; + } + /** * Get model from model index. * @@ -1093,6 +1165,30 @@ public void getModel(String modelId, String[] includes, String[] excludes, Actio }, e -> { listener.onFailure(e); })); } + private void getConnector(Client client, String connectorId, ActionListener listener) { + GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector connector = Connector.createConnector(parser); + listener.onResponse(connector); + } catch (Exception e) { + log.error("Failed to parse connector:" + connectorId); + listener.onFailure(e); + } + } else { + listener.onFailure(new OpenSearchStatusException("Failed to find connector:" + connectorId, RestStatus.NOT_FOUND)); + } + }, e -> { + log.error("Failed to get connector", e); + listener.onFailure(new OpenSearchStatusException("Failed to get connector:" + connectorId, RestStatus.NOT_FOUND)); + })); + } + private void retrieveModelChunks(MLModel mlModelMeta, ActionListener listener) throws InterruptedException { String modelId = mlModelMeta.getModelId(); String modelName = mlModelMeta.getName(); @@ -1301,6 +1397,10 @@ public String[] getWorkerNodes(String modelId, FunctionName functionName, boolea return eligibleNodeIds; } + public int getWorkerNodesSize(String modelId, FunctionName functionName, boolean onlyEligibleNode) { + return getWorkerNodes(modelId, functionName, onlyEligibleNode).length; + } + /** * Get worker node of specific model without filtering eligible node. * @@ -1312,6 +1412,10 @@ public String[] getWorkerNodes(String modelId, FunctionName functionName) { return getWorkerNodes(modelId, functionName, false); } + public int getWorkerNodesSize(String modelId, FunctionName functionName) { + return getWorkerNodes(modelId, functionName, false).length; + } + /** * Get predictable instance with model id. * diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 52f83c572e..adb638a8f5 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -52,7 +52,9 @@ import org.opensearch.ml.action.model_group.TransportUpdateModelGroupAction; import org.opensearch.ml.action.models.DeleteModelTransportAction; import org.opensearch.ml.action.models.GetModelTransportAction; +import org.opensearch.ml.action.models.InPlaceUpdateModelTransportAction; import org.opensearch.ml.action.models.SearchModelTransportAction; +import org.opensearch.ml.action.models.UpdateModelTransportAction; import org.opensearch.ml.action.prediction.TransportPredictionTaskAction; import org.opensearch.ml.action.profile.MLProfileAction; import org.opensearch.ml.action.profile.MLProfileTransportAction; @@ -97,9 +99,11 @@ import org.opensearch.ml.common.transport.deploy.MLDeployModelOnNodeAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.forward.MLForwardAction; +import org.opensearch.ml.common.transport.model.MLInPlaceUpdateModelAction; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction; @@ -166,6 +170,7 @@ import org.opensearch.ml.rest.RestMLTrainingAction; import org.opensearch.ml.rest.RestMLUndeployModelAction; import org.opensearch.ml.rest.RestMLUpdateConnectorAction; +import org.opensearch.ml.rest.RestMLUpdateModelAction; import org.opensearch.ml.rest.RestMLUpdateModelGroupAction; import org.opensearch.ml.rest.RestMLUploadModelChunkAction; import org.opensearch.ml.rest.RestMemoryCreateConversationAction; @@ -282,6 +287,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(MLUndeployModelsAction.INSTANCE, TransportUndeployModelsAction.class), new ActionHandler<>(MLRegisterModelMetaAction.INSTANCE, TransportRegisterModelMetaAction.class), new ActionHandler<>(MLUploadModelChunkAction.INSTANCE, TransportUploadModelChunkAction.class), + new ActionHandler<>(MLUpdateModelAction.INSTANCE, UpdateModelTransportAction.class), + new ActionHandler<>(MLInPlaceUpdateModelAction.INSTANCE, InPlaceUpdateModelTransportAction.class), new ActionHandler<>(MLForwardAction.INSTANCE, TransportForwardAction.class), new ActionHandler<>(MLSyncUpAction.INSTANCE, TransportSyncUpOnNodeAction.class), new ActionHandler<>(MLRegisterModelGroupAction.INSTANCE, TransportRegisterModelGroupAction.class), @@ -537,6 +544,7 @@ public List getRestHandlers( RestMLRegisterModelGroupAction restMLCreateModelGroupAction = new RestMLRegisterModelGroupAction(); RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction(); RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction(); + RestMLUpdateModelAction restMLUpdateModelAction = new RestMLUpdateModelAction(); RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(); RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting); RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(); @@ -558,6 +566,7 @@ public List getRestHandlers( restMLGetModelAction, restMLDeleteModelAction, restMLSearchModelAction, + restMLUpdateModelAction, restMLGetTaskAction, restMLDeleteTaskAction, restMLSearchTaskAction, @@ -687,7 +696,6 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES, MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES, MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED, - MLCommonsSettings.ML_COMMONS_UPDATE_CONNECTOR_ENABLED, MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED, MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED ); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java index 0c31695d1d..b6e3822318 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java @@ -8,7 +8,6 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; -import static org.opensearch.ml.utils.MLExceptionUtils.UPDATE_CONNECTOR_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; @@ -60,9 +59,7 @@ private MLUpdateConnectorRequest getRequest(RestRequest request) throws IOExcept if (!mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); } - if (!mlFeatureEnabledSetting.isUpdateConnectorEnabled()) { - throw new IllegalStateException(UPDATE_CONNECTOR_DISABLED_ERR_MSG); - } + if (!request.hasContent()) { throw new OpenSearchParseException("Failed to update connector: Request body is empty"); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java new file mode 100644 index 0000000000..44e9fc2cb6 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.transport.model.MLUpdateModelAction; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +public class RestMLUpdateModelAction extends BaseRestHandler { + + private static final String ML_UPDATE_MODEL_ACTION = "ml_update_model_action"; + + @Override + public String getName() { + return ML_UPDATE_MODEL_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of(new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/models/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLUpdateModelRequest updateModelRequest = getRequest(request); + return channel -> client.execute(MLUpdateModelAction.INSTANCE, updateModelRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLUpdateModelRequest from a RestRequest + * + * @param request RestRequest + * @return MLUpdateModelRequest + */ + private MLUpdateModelRequest getRequest(RestRequest request) throws IOException { + if (!request.hasContent()) { + throw new OpenSearchParseException("Model update request has empty body"); + } + + String modelId = getParameterId(request, PARAMETER_MODEL_ID); + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + try { + MLUpdateModelInput input = MLUpdateModelInput.parse(parser); + if (input.getConnectorId() != null && input.getConnectorUpdateContent() != null) { + throw new OpenSearchStatusException( + "Model cannot have both stand-alone connector and internal connector. Please check your update input body", + RestStatus.BAD_REQUEST + ); + } + // Model ID can only be set here. Model version as well as connector field can only be set automatically. + input.setModelId(modelId); + input.setVersion(null); + input.setConnector(null); + return new MLUpdateModelRequest(input); + } catch (IllegalStateException e) { + throw new OpenSearchParseException(e.getMessage()); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 03d4cf8647..bf200a3b02 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -114,9 +114,6 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_REMOTE_INFERENCE_ENABLED = Setting .boolSetting("plugins.ml_commons.remote_inference.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); - public static final Setting ML_COMMONS_UPDATE_CONNECTOR_ENABLED = Setting - .boolSetting("plugins.ml_commons.update_connector.enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); - public static final Setting ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED = Setting .boolSetting("plugins.ml_commons.model_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java index 8f231a061e..0a1a00ac4d 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java @@ -8,7 +8,6 @@ package org.opensearch.ml.settings; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_UPDATE_CONNECTOR_ENABLED; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -16,18 +15,13 @@ public class MLFeatureEnabledSetting { private volatile Boolean isRemoteInferenceEnabled; - private volatile Boolean isUpdateConnectorEnabled; public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings); - isUpdateConnectorEnabled = ML_COMMONS_UPDATE_CONNECTOR_ENABLED.get(settings); clusterService .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_REMOTE_INFERENCE_ENABLED, it -> isRemoteInferenceEnabled = it); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer(ML_COMMONS_UPDATE_CONNECTOR_ENABLED, it -> isUpdateConnectorEnabled = it); } /** @@ -38,8 +32,4 @@ public boolean isRemoteInferenceEnabled() { return isRemoteInferenceEnabled; } - public boolean isUpdateConnectorEnabled() { - return isUpdateConnectorEnabled; - } - } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java index 6f051615e3..da42d95382 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java @@ -20,8 +20,6 @@ public class MLExceptionUtils { public static final String NOT_SERIALIZABLE_EXCEPTION_WRAPPER = "NotSerializableExceptionWrapper: "; public static final String REMOTE_INFERENCE_DISABLED_ERR_MSG = "Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true."; - public static final String UPDATE_CONNECTOR_DISABLED_ERR_MSG = - "Update connector API is currently disabled. To enable it, update the setting \"plugins.ml_commons.update_connector.enabled\" to true."; public static String getRootCauseMessage(final Throwable throwable) { String message = ExceptionUtils.getRootCauseMessage(throwable); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java index a7fb34a4b5..c2cbf81cf1 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java @@ -150,7 +150,7 @@ public void testGetConnector_IndexNotFoundException() { getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Fail to find connector", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to find connector", argumentCaptor.getValue().getMessage()); } public void testGetConnector_RuntimeException() { diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java similarity index 99% rename from plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java rename to plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java index bb3d5ecebd..e1bbcfa881 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java @@ -20,6 +20,7 @@ import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -62,7 +63,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -public class TransportUpdateConnectorActionTests extends OpenSearchTestCase { +public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { private UpdateConnectorTransportAction transportUpdateConnectorAction; @@ -201,6 +202,7 @@ public void setup() throws IOException { }).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class)); } + @Test public void test_execute_connectorAccessControl_success() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); @@ -220,6 +222,7 @@ public void test_execute_connectorAccessControl_success() { verify(actionListener).onResponse(updateResponse); } + @Test public void test_execute_connectorAccessControl_NoPermission() { doReturn(false).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); @@ -232,6 +235,7 @@ public void test_execute_connectorAccessControl_NoPermission() { ); } + @Test public void test_execute_connectorAccessControl_AccessError() { doThrow(new RuntimeException("Connector Access Control Error")) .when(connectorAccessControlHelper) @@ -243,6 +247,7 @@ public void test_execute_connectorAccessControl_AccessError() { assertEquals("Connector Access Control Error", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_connectorAccessControl_Exception() { doThrow(new RuntimeException("exception in access control")) .when(connectorAccessControlHelper) @@ -254,6 +259,7 @@ public void test_execute_connectorAccessControl_Exception() { assertEquals("exception in access control", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_UpdateWrongStatus() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); @@ -274,6 +280,7 @@ public void test_execute_UpdateWrongStatus() { verify(actionListener).onResponse(updateResponse); } + @Test public void test_execute_UpdateException() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); @@ -295,6 +302,7 @@ public void test_execute_UpdateException() { assertEquals("update document failure", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_SearchResponseNotEmpty() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); @@ -312,6 +320,7 @@ public void test_execute_SearchResponseNotEmpty() { ); } + @Test public void test_execute_SearchResponseError() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); @@ -327,6 +336,7 @@ public void test_execute_SearchResponseError() { assertEquals("Error in Search Request", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_SearchIndexNotFoundError() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java new file mode 100644 index 0000000000..9e5d34647a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -0,0 +1,861 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.models; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelGroupManager; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class UpdateModelTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + Task task; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + MLUpdateModelInput mockUpdateModelInput; + + @Mock + MLUpdateModelRequest mockUpdateModelRequest; + + @Mock + MLModel mockModel; + + @Mock + MLModelManager mlModelManager; + + @Mock + MLModelGroupManager mlModelGroupManager; + + @Mock + private ModelAccessControlHelper modelAccessControlHelper; + + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + + @Mock + Settings settings; + + @Mock + private ClusterService clusterService; + + @Mock + private MLEngine mlEngine; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private ShardId shardId; + + UpdateResponse updateResponse; + + UpdateModelTransportAction transportUpdateModelAction; + + MLUpdateModelRequest updateLocalModelRequest; + + MLUpdateModelInput updateLocalModelInput; + + MLUpdateModelRequest updateRemoteModelRequest; + + MLUpdateModelInput updateRemoteModelInput; + + MLModel mlModelWithNullFunctionName; + + MLModel localModel; + + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + updateLocalModelInput = MLUpdateModelInput + .builder() + .modelId("test_model_id") + .name("updated_test_name") + .description("updated_test_description") + .modelGroupId("updated_test_model_group_id") + .build(); + updateLocalModelRequest = MLUpdateModelRequest.builder().updateModelInput(updateLocalModelInput).build(); + updateRemoteModelInput = MLUpdateModelInput + .builder() + .modelId("test_model_id") + .name("updated_test_name") + .description("updated_test_description") + .modelGroupId("updated_test_model_group_id") + .connectorId("updated_test_connector_id") + .build(); + updateRemoteModelRequest = MLUpdateModelRequest.builder().updateModelInput(updateRemoteModelInput).build(); + + mlModelWithNullFunctionName = MLModel + .builder() + .modelId("test_model_id") + .name("test_name") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .build(); + + Settings settings = Settings.builder().build(); + + transportUpdateModelAction = spy( + new UpdateModelTransportAction( + transportService, + actionFilters, + client, + connectorAccessControlHelper, + modelAccessControlHelper, + mlModelManager, + mlModelGroupManager, + settings, + clusterService, + mlEngine + ) + ); + + localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + shardId = new ShardId(new Index("indexName", "uuid"), 1); + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("test_model_group_id"), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }) + .when(connectorAccessControlHelper) + .validateConnectorAccess(any(Client.class), eq("updated_test_connector_id"), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(localModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + MLModelGroup modelGroup = MLModelGroup + .builder() + .modelGroupId("updated_test_model_group_id") + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); + + GetResponse getResponse = prepareGetResponse(modelGroup); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(mlModelGroupManager).getModelGroupResponse(eq("updated_test_model_group_id"), isA(ActionListener.class)); + } + + @Test + public void testUpdateLocalModelSuccess() { + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelStateLoadedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.LOADED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStateLoadingException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.LOADING).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStatePartiallyLoadedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.PARTIALLY_LOADED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStateDeployedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.DEPLOYED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStateDeployingException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.DEPLOYING).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStatePartiallyDeployedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.PARTIALLY_DEPLOYED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithoutRegisterToNewModelGroupSuccess() { + updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateRemoteModelWithLocalInformationSuccess() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateRemoteModelWithRemoteInformationSuccess() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateRemoteModelWithNoStandAloneConnectorFound() { + MLModel remoteModelWithInternalConnector = prepareUnsupportedMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModelWithInternalConnector); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "This remote does not have a connector_id field, maybe it uses an internal connector.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlNoPermission() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(false); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "You don't have permission to update the connector, connector id: updated_test_connector_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlOtherException() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener + .onFailure( + new RuntimeException("Any other connector access control Exception occurred. Please check log for more details.") + ); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other connector access control Exception occurred. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithModelAccessControlNoPermission() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this model, model ID test_model_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithModelAccessControlOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener + .onFailure( + new RuntimeException( + "Any other model access control Exception occurred during update the model. Please check log for more details." + ) + ); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other model access control Exception occurred during update the model. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlNoPermission() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID updated_test_model_group_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener + .onFailure( + new RuntimeException( + "Any other model access control Exception occurred during re-linking the model group. Please check log for more details." + ) + ); + return null; + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other model access control Exception occurred during re-linking the model group. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithRegisterToNewModelGroupNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new MLResourceNotFoundException("Model group not found with MODEL_GROUP_ID: updated_test_model_group_id")); + return null; + }).when(mlModelGroupManager).getModelGroupResponse(eq("updated_test_model_group_id"), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Failed to find the model group with the provided model group id in the update model input, MODEL_GROUP_ID: updated_test_model_group_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithModelNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(null); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find model to update with the provided model id: test_model_id", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateModelWithFunctionNameFieldNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModelWithNullFunctionName); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + } + + @Test + public void testUpdateLocalModelWithRemoteInformation() { + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Trying to update the connector or connector_id field on a local model", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateLocalModelWithUnsupportedFunction() { + MLModel localModelWithUnsupportedFunction = prepareUnsupportedMLModel(FunctionName.KMEANS); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(localModelWithUnsupportedFunction); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this function category: KMEANS", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateRequestDocIOException() throws IOException { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.REGISTERED).when(mockModel).getModelState(); + + doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IOException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred during building update request.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IOException { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.REGISTERED).when(mockModel).getModelState(); + + doReturn("mockUpdateModelGroupId").when(mockUpdateModelInput).getModelGroupId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("mockUpdateModelGroupId"), any(), isA(ActionListener.class)); + + MLModelGroup modelGroup = MLModelGroup + .builder() + .modelGroupId("updated_test_model_group_id") + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); + + GetResponse getResponse = prepareGetResponse(modelGroup); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(mlModelGroupManager).getModelGroupResponse(eq("mockUpdateModelGroupId"), isA(ActionListener.class)); + + doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IOException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred during building update request.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testGetUpdateResponseListenerWithVersionBumpWrongStatus() { + UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateWrongResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateWrongResponse); + } + + @Test + public void testGetUpdateResponseListenerWithVersionBumpOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new RuntimeException( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details." + ) + ); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testGetUpdateResponseListenerWrongStatus() { + UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateWrongResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateWrongResponse); + } + + @Test + public void testGetUpdateResponseListenerOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new RuntimeException( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details." + ) + ); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + // TODO: Add UT to make sure that version incremented successfully. + + private MLModel prepareMLModel(FunctionName functionName) throws IllegalArgumentException { + MLModel mlModel; + switch (functionName) { + case TEXT_EMBEDDING: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.TEXT_EMBEDDING) + .build(); + return mlModel; + case REMOTE: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.REMOTE) + .connectorId("test_connector_id") + .build(); + return mlModel; + default: + throw new IllegalArgumentException("Please choose from FunctionName.TEXT_EMBEDDING and FunctionName.REMOTE"); + } + } + + private MLModel prepareUnsupportedMLModel(FunctionName unsupportedCase) throws IllegalArgumentException { + MLModel mlModel; + switch (unsupportedCase) { + case REMOTE: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.REMOTE) + .connector(HttpConnector.builder().name("test_connector").protocol("http").build()) + .build(); + return mlModel; + case KMEANS: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.KMEANS) + .build(); + return mlModel; + default: + throw new IllegalArgumentException("Please choose from FunctionName.REMOTE and FunctionName.KMEANS"); + } + } + + private GetResponse prepareGetResponse(MLModelGroup mlModelGroup) throws IOException { + XContentBuilder content = mlModelGroup.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + return new GetResponse(getResult); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java index 7f48d9f32c..30c9f6191c 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java @@ -22,6 +22,7 @@ import org.junit.Before; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -42,7 +43,6 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.ConnectorProtocols; import org.opensearch.ml.common.connector.HttpConnector; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -207,7 +207,7 @@ public void test_validateConnectorAccess_connectorNotFound_return_false() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); - verify(actionListener, times(1)).onFailure(any(MLResourceNotFoundException.class)); + verify(actionListener, times(1)).onFailure(any(OpenSearchStatusException.class)); } public void test_validateConnectorAccess_searchConnectorException_return_false() { @@ -222,7 +222,7 @@ public void test_validateConnectorAccess_searchConnectorException_return_false() threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); - verify(actionListener).onFailure(any(IllegalStateException.class)); + verify(actionListener).onFailure(any(OpenSearchStatusException.class)); } public void test_skipConnectorAccessControl_userIsNull_return_true() { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index f7eb759026..ccedef9bc1 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -23,35 +23,37 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; -import org.opensearch.ml.action.model_group.TransportRegisterModelGroupAction; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.TestHelper; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; -import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; public class MLModelGroupManagerTests extends OpenSearchTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - @Mock - private TransportService transportService; - @Mock private MLIndicesHandler mlIndicesHandler; @@ -61,26 +63,23 @@ public class MLModelGroupManagerTests extends OpenSearchTestCase { @Mock private ThreadPool threadPool; - @Mock - private Task task; - @Mock private Client client; - @Mock - private ActionFilters actionFilters; @Mock private ActionListener actionListener; + @Mock + private ActionListener modelGroupListener; + @Mock private IndexResponse indexResponse; ThreadContext threadContext; - private TransportRegisterModelGroupAction transportRegisterModelGroupAction; - @Mock private ModelAccessControlHelper modelAccessControlHelper; + @Mock private MLModelGroupManager mlModelGroupManager; @@ -335,6 +334,61 @@ public void test_ExceptionInitModelGroupIndexIfAbsent() { assertEquals("Index Not Found", argumentCaptor.getValue().getMessage()); } + public void test_SuccessGetModelGroup() throws IOException { + MLModelGroup modelGroup = MLModelGroup + .builder() + .modelGroupId("testModelGroupID") + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); + + GetResponse getResponse = prepareGetResponse(modelGroup); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); + verify(modelGroupListener).onResponse(getResponse); + } + + public void test_OtherExceptionGetModelGroup() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new RuntimeException("Any other Exception occurred during getting the model group. Please check log for more details.") + ); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(modelGroupListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during getting the model group. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + public void test_NotFoundGetModelGroup() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(modelGroupListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find model group with ID: testModelGroupID", argumentCaptor.getValue().getMessage()); + } + private MLRegisterModelGroupInput prepareRequest(List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { return MLRegisterModelGroupInput .builder() @@ -363,4 +417,10 @@ private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOE return searchResponse; } + private GetResponse prepareGetResponse(MLModelGroup mlModelGroup) throws IOException { + XContentBuilder content = mlModelGroup.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + return new GetResponse(getResult); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java index d880baaa2e..1c6a3d2ae7 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java @@ -67,7 +67,6 @@ public void setup() { client = spy(new NodeClient(Settings.EMPTY, threadPool)); when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); - when(mlFeatureEnabledSetting.isUpdateConnectorEnabled()).thenReturn(true); restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java new file mode 100644 index 0000000000..28687d1c9c --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.model.MLUpdateModelAction; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import com.google.gson.Gson; + +public class RestMLUpdateModelActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private RestMLUpdateModelAction restMLUpdateModelAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + restMLUpdateModelAction = new RestMLUpdateModelAction(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLUpdateModelAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + @Test + public void testConstructor() { + RestMLUpdateModelAction UpdateModelAction = new RestMLUpdateModelAction(); + assertNotNull(UpdateModelAction); + } + + @Test + public void testGetName() { + String actionName = restMLUpdateModelAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_update_model_action", actionName); + } + + @Test + public void testRoutes() { + List routes = restMLUpdateModelAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/models/{model_id}", route.getPath()); + } + + @Test + public void testUpdateModelRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLUpdateModelAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLUpdateModelRequest.class); + verify(client, times(1)).execute(eq(MLUpdateModelAction.INSTANCE), argumentCaptor.capture(), any()); + MLUpdateModelInput updateModelInput = argumentCaptor.getValue().getUpdateModelInput(); + assertEquals("testModelName", updateModelInput.getName()); + assertEquals("This is test description", updateModelInput.getDescription()); + } + + @Test + public void testUpdateModelRequestWithEmptyContent() throws Exception { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Model update request has empty body"); + RestRequest request = getRestRequestWithEmptyContent(); + restMLUpdateModelAction.handleRequest(request, channel, client); + } + + @Test + public void testUpdateModelRequestWithNullModelId() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Request should contain model_id"); + RestRequest request = getRestRequestWithNullModelId(); + restMLUpdateModelAction.handleRequest(request, channel, client); + } + + @Test + public void testUpdateModelRequestWithNullField() throws Exception { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Can't get text on a VALUE_NULL"); + RestRequest request = getRestRequestWithNullField(); + restMLUpdateModelAction.handleRequest(request, channel, client); + } + + private RestRequest getRestRequest() { + RestRequest.Method method = RestRequest.Method.PUT; + final Map modelContent = Map.of("name", "testModelName", "description", "This is test description"); + String requestContent = new Gson().toJson(modelContent); + Map params = new HashMap<>(); + params.put("model_id", "test_modelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithEmptyContent() { + RestRequest.Method method = RestRequest.Method.PUT; + Map params = new HashMap<>(); + params.put("model_id", "test_modelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(""), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullModelId() { + RestRequest.Method method = RestRequest.Method.PUT; + final Map modelContent = Map.of("name", "testModelName", "description", "This is test description"); + String requestContent = new Gson().toJson(modelContent); + Map params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullField() { + RestRequest.Method method = RestRequest.Method.PUT; + String requestContent = "{\"name\":\"testModelName\",\"description\":null}"; + Map params = new HashMap<>(); + params.put("model_id", "test_modelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } +}