diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index fef1af1196..80b488d418 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -383,7 +383,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws String oldContent = null; User user = null; - String description = null;; + String description = null; MLModelFormat modelFormat = null; MLModelState modelState = null; Long modelContentSizeInBytes = null; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponse.java index e2a0cbf084..685fc43cf7 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponse.java @@ -36,7 +36,7 @@ public MLDeployModelNodeResponse(DiscoveryNode node, Map modelDe public MLDeployModelNodeResponse(StreamInput in) throws IOException { super(in); if (in.readBoolean()) { - this.modelDeployStatus = in.readMap(s -> s.readString(), s-> s.readString()); + this.modelDeployStatus = in.readMap(StreamInput::readString, StreamInput::readString); } } @@ -58,7 +58,7 @@ public void writeTo(StreamOutput out) throws IOException { if (modelDeployStatus != null && modelDeployStatus.size() > 0) { out.writeBoolean(true); - out.writeMap(modelDeployStatus, (stream, v) -> stream.writeString(v), (stream, stats) -> stream.writeString(stats)); + out.writeMap(modelDeployStatus, StreamOutput::writeString, StreamOutput::writeString); } else { out.writeBoolean(false); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java index ca0a2f70d4..adfdd4f307 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -17,23 +17,27 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import java.io.IOException; -import java.util.Map; +import java.time.Instant; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.connector.Connector.createConnector; @Data public class MLUpdateModelInput implements ToXContentObject, Writeable { - public static final String MODEL_ID_FIELD = "model_id"; // mandatory + public static final String MODEL_ID_FIELD = "model_id"; // passively set when passing url to rest API public static final String DESCRIPTION_FIELD = "description"; // optional - public static final String MODEL_VERSION_FIELD = "model_version"; // optional + public static final String MODEL_VERSION_FIELD = "model_version"; // passively set when register model to a new model group public static final String MODEL_NAME_FIELD = "name"; // optional public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional public static final String MODEL_CONFIG_FIELD = "model_config"; // optional + public static final String CONNECTOR_FIELD = "connector"; // optional public static final String CONNECTOR_ID_FIELD = "connector_id"; // optional + // The field CONNECTOR_UPDATE_CONTENT_FIELD need to be declared because the update of Connector class relies on the MLCreateConnectorInput class + public static final String CONNECTOR_UPDATE_CONTENT_FIELD = "connector_update_content"; + public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; // passively set when sending update request @Getter private String modelId; @@ -42,29 +46,44 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { private String name; private String modelGroupId; private MLModelConfig modelConfig; + private Connector connector; private String connectorId; + private MLCreateConnectorInput connectorUpdateContent; + private Instant lastUpdateTime; @Builder(toBuilder = true) - public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, MLModelConfig modelConfig, String connectorId) { + public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, + MLModelConfig modelConfig, Connector connector, String connectorId, + MLCreateConnectorInput connectorUpdateContent, Instant lastUpdateTime) { this.modelId = modelId; this.description = description; this.version = version; this.name = name; this.modelGroupId = modelGroupId; this.modelConfig = modelConfig; + this.connector = connector; this.connectorId = connectorId; + this.connectorUpdateContent = connectorUpdateContent; + this.lastUpdateTime = lastUpdateTime; } public MLUpdateModelInput(StreamInput in) throws IOException { - this.modelId = in.readString(); - this.description = in.readOptionalString(); - this.version = in.readOptionalString(); - this.name = in.readOptionalString(); - this.modelGroupId = in.readOptionalString(); + modelId = in.readString(); + description = in.readOptionalString(); + version = in.readOptionalString(); + name = in.readOptionalString(); + modelGroupId = in.readOptionalString(); if (in.readBoolean()) { modelConfig = new TextEmbeddingModelConfig(in); } - this.connectorId = in.readOptionalString(); + if (in.readBoolean()) { + connector = Connector.fromStream(in); + } + connectorId = in.readOptionalString(); + if (in.readBoolean()) { + connectorUpdateContent = new MLCreateConnectorInput(in); + } + lastUpdateTime = in.readOptionalInstant(); } @Override @@ -86,9 +105,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (modelConfig != null) { builder.field(MODEL_CONFIG_FIELD, modelConfig); } + if (connector != null) { + builder.field(CONNECTOR_FIELD, connector); + } if (connectorId != null) { builder.field(CONNECTOR_ID_FIELD, connectorId); } + if (connectorUpdateContent != null) { + builder.field(CONNECTOR_UPDATE_CONTENT_FIELD, connectorUpdateContent); + } + if (lastUpdateTime != null) { + builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli()); + } builder.endObject(); return builder; } @@ -106,7 +134,20 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + if (connector != null) { + out.writeBoolean(true); + connector.writeTo(out); + } else { + out.writeBoolean(false); + } out.writeOptionalString(connectorId); + if (connectorUpdateContent != null) { + out.writeBoolean(true); + connectorUpdateContent.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalInstant(lastUpdateTime); } public static MLUpdateModelInput parse(XContentParser parser) throws IOException { @@ -116,7 +157,10 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException String name = null; String modelGroupId = null; MLModelConfig modelConfig = null; + Connector connector = null; String connectorId = null; + MLCreateConnectorInput connectorUpdateContent = null; + Instant lastUpdateTime = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -141,15 +185,24 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException case MODEL_CONFIG_FIELD: modelConfig = TextEmbeddingModelConfig.parse(parser); break; + case CONNECTOR_FIELD: + connector = Connector.createConnector(parser); + break; case CONNECTOR_ID_FIELD: connectorId = parser.text(); break; + case CONNECTOR_UPDATE_CONTENT_FIELD: + connectorUpdateContent = MLCreateConnectorInput.parse(parser, true); + break; + case LAST_UPDATED_TIME_FIELD: + lastUpdateTime = Instant.ofEpochMilli(parser.longValue()); + break; default: parser.skipChildren(); break; } } // Model ID can only be set through RestRequest. Model version can only be set automatically. - return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, modelConfig, connectorId); + return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, modelConfig, connector, connectorId, connectorUpdateContent, lastUpdateTime); } } \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java index 2af72a6d6a..99a7f39882 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java @@ -39,10 +39,10 @@ public MLUndeployModelNodeResponse(DiscoveryNode node, public MLUndeployModelNodeResponse(StreamInput in) throws IOException { super(in); if (in.readBoolean()) { - this.modelUndeployStatus = in.readMap(s -> s.readString(), s-> s.readString()); + this.modelUndeployStatus = in.readMap(StreamInput::readString, StreamInput::readString); } if (in.readBoolean()) { - this.modelWorkerNodeBeforeRemoval = in.readMap(s -> s.readString(), s-> s.readOptionalStringArray()); + this.modelWorkerNodeBeforeRemoval = in.readMap(StreamInput::readString, StreamInput::readOptionalStringArray); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheAction.java b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheAction.java new file mode 100644 index 0000000000..ef9d6ec063 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheAction.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.update_cache; + +import org.opensearch.action.ActionType; + +public class MLUpdateModelCacheAction extends ActionType { + public static final MLUpdateModelCacheAction INSTANCE = new MLUpdateModelCacheAction(); + public static final String NAME = "cluster:admin/opensearch/ml/models/update_model_cache"; + + private MLUpdateModelCacheAction() { super(NAME, MLUpdateModelCacheNodesResponse::new);} +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeRequest.java new file mode 100644 index 0000000000..7d79a3a9d4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeRequest.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.update_cache; + +import org.opensearch.transport.TransportRequest; +import java.io.IOException; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLUpdateModelCacheNodeRequest extends TransportRequest { + @Getter + private MLUpdateModelCacheNodesRequest updateModelCacheNodesRequest; + + public MLUpdateModelCacheNodeRequest(StreamInput in) throws IOException { + super(in); + this.updateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest(in); + } + + public MLUpdateModelCacheNodeRequest(MLUpdateModelCacheNodesRequest request) { + this.updateModelCacheNodesRequest = request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + updateModelCacheNodesRequest.writeTo(out); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponse.java new file mode 100644 index 0000000000..35a642c33c --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponse.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.update_cache; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; + +@Getter +@Log4j2 +public class MLUpdateModelCacheNodeResponse extends BaseNodeResponse implements ToXContentFragment { + private Map modelUpdateStatus; + + public MLUpdateModelCacheNodeResponse(DiscoveryNode node, Map modelUpdateStatus) { + super(node); + this.modelUpdateStatus = modelUpdateStatus; + } + + public MLUpdateModelCacheNodeResponse(StreamInput in) throws IOException { + super(in); + if (in.readBoolean()) { + this.modelUpdateStatus = in.readMap(StreamInput::readString, StreamInput::readString); + } + } + + public static MLUpdateModelCacheNodeResponse readStats(StreamInput in) throws IOException { + return new MLUpdateModelCacheNodeResponse(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + + if (!isModelUpdateStatusEmpty()) { + out.writeBoolean(true); + out.writeMap(modelUpdateStatus, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + } + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // Similar to deploy on node or undeploy on node response, this stat map is used to track update status on each node. + builder.startObject("stats"); + if (!isModelUpdateStatusEmpty()) { + for (Map.Entry stat : modelUpdateStatus.entrySet()) { + builder.field(stat.getKey(), stat.getValue()); + } + } + builder.endObject(); + return builder; + } + + public boolean isModelUpdateStatusEmpty() { + return modelUpdateStatus == null || modelUpdateStatus.size() == 0; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequest.java new file mode 100644 index 0000000000..566b838632 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequest.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.update_cache; + +import lombok.Getter; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import java.io.IOException; + +public class MLUpdateModelCacheNodesRequest extends BaseNodesRequest { + + @Getter + private String modelId; + @Getter + private boolean predictorUpdate; + + public MLUpdateModelCacheNodesRequest(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + this.predictorUpdate = in.readBoolean(); + } + + public MLUpdateModelCacheNodesRequest(String[] nodeIds, String modelId, boolean predictorUpdate) { + super(nodeIds); + this.modelId = modelId; + this.predictorUpdate = predictorUpdate; + } + + public MLUpdateModelCacheNodesRequest(DiscoveryNode[] nodeIds, String modelId, boolean predictorUpdate) { + super(nodeIds); + this.modelId = modelId; + this.predictorUpdate = predictorUpdate; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + out.writeBoolean(predictorUpdate); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponse.java new file mode 100644 index 0000000000..6e26174eb6 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponse.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.update_cache; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; + +public class MLUpdateModelCacheNodesResponse extends BaseNodesResponse implements ToXContentObject { + + public MLUpdateModelCacheNodesResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(MLUpdateModelCacheNodeResponse::readStats), in.readList(FailedNodeException::new)); + } + + public MLUpdateModelCacheNodesResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(MLUpdateModelCacheNodeResponse::readStats); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + String nodeId; + DiscoveryNode node; + builder.startObject(); + for (MLUpdateModelCacheNodeResponse updateStats : getNodes()) { + if (!updateStats.isModelUpdateStatusEmpty()) { + node = updateStats.getNode(); + nodeId = node.getId(); + builder.startObject(nodeId); + updateStats.toXContent(builder, params); + builder.endObject(); + } + } + builder.endObject(); + return builder; + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java index 6bafe81692..ed1f5568eb 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -13,6 +13,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collections; +import java.util.Map; import java.util.function.Consumer; import org.junit.Before; @@ -29,6 +30,11 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.search.SearchModule; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -38,22 +44,36 @@ public class MLUpdateModelInputTest { private MLUpdateModelInput updateModelInput; private final String expectedInputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector\":" + + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + + "\"test-connector_id\",\"connector_update_content\":{\"description\":\"updated description\",\"version\":\"1\"}}"; private final String expectedInputStrWithNullField = "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; private final String expectedOutputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; - private final String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"description\":\"description\",\"model_version\":\"2\",\"name\":\"name\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector\":" + + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + + "\"test-connector_id\",\"connector_update_content\":{\"description\":\"updated description\",\"version\":\"1\",\"parameters\":{},\"credential\":{}}}"; + private final String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\",\"illegal_field\":\"This field need to be skipped.\"}"; + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector\":" + + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + + "\"test-connector_id\",\"connector_update_content\":{\"description\":\"updated description\",\"version\":\"1\"},\"illegal_field\":\"This field need to be skipped.\"}"; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @Before public void setUp() throws Exception { - MLModelConfig config = TextEmbeddingModelConfig.builder() .modelType("testModelType") .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") @@ -61,6 +81,35 @@ public void setUp() throws Exception { .embeddingDimension(100) .build(); + Connector connector = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) + .actions( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://api.openai.com/v1/chat/completions") + .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") + .build() + ) + ) + .build(); + + MLCreateConnectorInput updateContent = MLCreateConnectorInput + .builder() + .updateConnector(true) + .version("1") + .description("updated description") + .build(); + updateModelInput = MLUpdateModelInput.builder() .modelId("test-model_id") .modelGroupId("modelGroupId") @@ -68,12 +117,14 @@ public void setUp() throws Exception { .name("name") .description("description") .modelConfig(config) + .connector(connector) .connectorId("test-connector_id") + .connectorUpdateContent(updateContent) .build(); } @Test - public void readInputStream_Success() throws IOException { + public void readInputStreamSuccess() throws IOException { readInputStream(updateModelInput, parsedInput -> { assertEquals("test-model_id", parsedInput.getModelId()); assertEquals(updateModelInput.getName(), parsedInput.getName()); @@ -81,7 +132,7 @@ public void readInputStream_Success() throws IOException { } @Test - public void readInputStream_SuccessWithNullFields() throws IOException { + public void readInputStreamSuccessWithNullFields() throws IOException { updateModelInput.setModelConfig(null); readInputStream(updateModelInput, parsedInput -> { assertNull(parsedInput.getModelConfig()); @@ -95,7 +146,7 @@ public void testToXContent() throws Exception { } @Test - public void testToXContent_Incomplete() throws Exception { + public void testToXContentIncomplete() throws Exception { String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}"; updateModelInput.setDescription(null); @@ -103,20 +154,22 @@ public void testToXContent_Incomplete() throws Exception { updateModelInput.setName(null); updateModelInput.setModelGroupId(null); updateModelInput.setModelConfig(null); + updateModelInput.setConnector(null); updateModelInput.setConnectorId(null); + updateModelInput.setConnectorUpdateContent(null); String jsonStr = serializationWithToXContent(updateModelInput); assertEquals(expectedIncompleteInputStr, jsonStr); } @Test - public void parse_Success() throws Exception { + public void parseSuccess() throws Exception { testParseFromJsonString(expectedInputStr, parsedInput -> { assertEquals("name", parsedInput.getName()); }); } @Test - public void parse_WithNullFieldWithoutModel() throws Exception { + public void parseWithNullFieldWithoutModel() throws Exception { exceptionRule.expect(IllegalStateException.class); testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { try { @@ -128,7 +181,7 @@ public void parse_WithNullFieldWithoutModel() throws Exception { } @Test - public void parse_WithIllegalFieldWithoutModel() throws Exception { + public void parseWithIllegalFieldWithoutModel() throws Exception { testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { try { assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java index cadf865b1c..ef0298df27 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java @@ -10,14 +10,13 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.junit.Test; -import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.rest.RestRequest; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; import java.io.IOException; import java.io.UncheckedIOException; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponseTest.java new file mode 100644 index 0000000000..e92faef338 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponseTest.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.update_cache; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.transport.TransportAddress; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +@RunWith(MockitoJUnitRunner.class) +public class MLUpdateModelCacheNodeResponseTest { + + @Mock + private DiscoveryNode localNode; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() throws Exception { + localNode = new DiscoveryNode( + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + } + + @Test + public void testSerializationDeserialization() throws IOException { + Map updateModelCacheStatus = new HashMap<>(); + updateModelCacheStatus.put("modelName:version", "response"); + MLUpdateModelCacheNodeResponse response = new MLUpdateModelCacheNodeResponse(localNode, updateModelCacheStatus); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLUpdateModelCacheNodeResponse newResponse = new MLUpdateModelCacheNodeResponse(output.bytes().streamInput()); + assertEquals(newResponse.getNode().getId(), response.getNode().getId()); + } + + @Test + public void testSerializationDeserializationNullModelUpdateModelCacheStatus() throws IOException { + MLUpdateModelCacheNodeResponse response = new MLUpdateModelCacheNodeResponse(localNode, null); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLUpdateModelCacheNodeResponse newResponse = new MLUpdateModelCacheNodeResponse(output.bytes().streamInput()); + assertEquals(newResponse.getNode().getId(), response.getNode().getId()); + } + + @Test + public void testReadProfile() throws IOException { + MLUpdateModelCacheNodeResponse response = new MLUpdateModelCacheNodeResponse(localNode, new HashMap<>()); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLUpdateModelCacheNodeResponse newResponse = MLUpdateModelCacheNodeResponse.readStats(output.bytes().streamInput()); + assertNotEquals(newResponse, response); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequestTest.java new file mode 100644 index 0000000000..b78cd6f263 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequestTest.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.update_cache; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.transport.TransportAddress; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +@RunWith(MockitoJUnitRunner.class) +public class MLUpdateModelCacheNodesRequestTest { + + @Test + public void testConstructorSerialization1() throws IOException { + String modelId = "testModelId"; + String[] nodeIds = {"nodeId1", "nodeId2", "nodeId3"}; + + MLUpdateModelCacheNodeRequest updateModelCacheNodeRequest = new MLUpdateModelCacheNodeRequest( + new MLUpdateModelCacheNodesRequest(nodeIds, modelId, true) + ); + BytesStreamOutput output = new BytesStreamOutput(); + + updateModelCacheNodeRequest.writeTo(output); + assertEquals("testModelId", updateModelCacheNodeRequest.getUpdateModelCacheNodesRequest().getModelId()); + } + + @Test + public void testConstructorSerialization2() { + String modelId = "testModelId"; + DiscoveryNode localNode1 = new DiscoveryNode( + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + DiscoveryNode localNode2 = new DiscoveryNode( + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + DiscoveryNode[] nodes = {localNode1, localNode2}; + MLUpdateModelCacheNodeRequest updateModelCacheNodeRequest = new MLUpdateModelCacheNodeRequest( + new MLUpdateModelCacheNodesRequest(nodes, modelId, true) + ); + assertEquals(2, updateModelCacheNodeRequest.getUpdateModelCacheNodesRequest().concreteNodes().length); + } + + @Test + public void testConstructorFromInputStream() throws IOException { + String modelId = "testModelId"; + String[] nodeIds = {"nodeId1", "nodeId2", "nodeId3"}; + + MLUpdateModelCacheNodeRequest updateModelCacheNodeRequest = new MLUpdateModelCacheNodeRequest( + new MLUpdateModelCacheNodesRequest(nodeIds, modelId, true) + ); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + updateModelCacheNodeRequest.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLUpdateModelCacheNodeRequest parsedNodeRequest = new MLUpdateModelCacheNodeRequest(streamInput); + + assertEquals(updateModelCacheNodeRequest.getUpdateModelCacheNodesRequest().getModelId(), parsedNodeRequest.getUpdateModelCacheNodesRequest().getModelId()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponseTest.java new file mode 100644 index 0000000000..f3d1ac668e --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponseTest.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.update_cache; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +@RunWith(MockitoJUnitRunner.class) +public class MLUpdateModelCacheNodesResponseTest { + + @Mock + private ClusterName clusterName; + private DiscoveryNode node1; + private DiscoveryNode node2; + private Map modelWorkerNodeCounts; + + @Before + public void setUp() throws Exception { + clusterName = new ClusterName("clusterName"); + node1 = new DiscoveryNode( + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + node2 = new DiscoveryNode( + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", 1); + } + + @Test + public void testSerializationDeserialization1() throws IOException { + List responseList = new ArrayList<>(); + List failuresList = new ArrayList<>(); + MLUpdateModelCacheNodesResponse response = new MLUpdateModelCacheNodesResponse(clusterName, responseList, failuresList); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLUpdateModelCacheNodesResponse newResponse = new MLUpdateModelCacheNodesResponse(output.bytes().streamInput()); + assertEquals(newResponse.getNodes().size(), response.getNodes().size()); + } + + @Test + public void testToXContent() throws IOException { + List nodes = new ArrayList<>(); + + Map updateModelCacheStatus1 = new HashMap<>(); + updateModelCacheStatus1.put("modelId1", "response"); + Map modelWorkerNodeCounts1 = new HashMap<>(); + modelWorkerNodeCounts1.put("modelId1", new String[]{"mockNode1"}); + nodes.add(new MLUpdateModelCacheNodeResponse(node1, updateModelCacheStatus1)); + + Map updateModelCacheStatus2 = new HashMap<>(); + updateModelCacheStatus2.put("modelId2", "response"); + Map modelWorkerNodeCounts2 = new HashMap<>(); + modelWorkerNodeCounts2.put("modelId2", new String[]{"mockNode2"}); + nodes.add(new MLUpdateModelCacheNodeResponse(node2, updateModelCacheStatus2)); + + List failures = new ArrayList<>(); + MLUpdateModelCacheNodesResponse response = new MLUpdateModelCacheNodesResponse(clusterName, nodes, failures); + XContentBuilder builder = XContentFactory.jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + assertEquals( + "{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", + jsonStr + ); + } + + @Test + public void testNullUpdateModelCacheStatusToXContent() throws IOException { + List nodes = new ArrayList<>(); + Map modelWorkerNodeCounts1 = new HashMap<>(); + modelWorkerNodeCounts1.put("modelId1", new String[]{"mockNode1"}); + nodes.add(new MLUpdateModelCacheNodeResponse(node1, null)); + List failures = new ArrayList<>(); + MLUpdateModelCacheNodesResponse response = new MLUpdateModelCacheNodesResponse(clusterName, nodes, failures); + XContentBuilder builder = XContentFactory.jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + assertEquals("{}",jsonStr); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java index 066ca5f8a7..970d94aa48 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -163,7 +163,7 @@ private ActionListener getUpdateResponseListener( ) { return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { - log.info("Failed to update the connector with ID: {}", connectorId); + log.error("Failed to update the connector with ID: {}", connectorId); actionListener.onResponse(updateResponse); return; } diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 124145b581..99e450567a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -65,7 +65,6 @@ import org.opensearch.transport.TransportService; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; import lombok.extern.log4j.Log4j2; @@ -288,7 +287,7 @@ private void deployModel( mlTaskManager .updateMLTask( taskId, - ImmutableMap.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)), + Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)), TASK_SEMAPHORE_TIMEOUT, true ); @@ -333,7 +332,7 @@ void deployRemoteModel( mlModelManager .updateModel( mlModel.getModelId(), - ImmutableMap + Map .of( MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOYING, @@ -359,7 +358,7 @@ private ActionListener deployModelNodesResponseListe ) { return ActionListener.wrap(r -> { if (mlTaskManager.contains(taskId)) { - mlTaskManager.updateMLTask(taskId, ImmutableMap.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false); + mlTaskManager.updateMLTask(taskId, Map.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false); } listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.COMPLETED.name())); }, e -> { @@ -367,11 +366,11 @@ private ActionListener deployModelNodesResponseListe mlTaskManager .updateMLTask( taskId, - ImmutableMap.of(MLTask.ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED), + Map.of(MLTask.ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED), TASK_SEMAPHORE_TIMEOUT, true ); - mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOY_FAILED)); + mlModelManager.updateModel(modelId, Map.of(MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOY_FAILED)); listener.onFailure(e); }); } @@ -401,25 +400,25 @@ void updateModelDeployStatusAndTriggerOnNodesAction( ); ActionListener actionListener = ActionListener.wrap(r -> { if (mlTaskManager.contains(taskId)) { - mlTaskManager.updateMLTask(taskId, ImmutableMap.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false); + mlTaskManager.updateMLTask(taskId, Map.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false); } }, e -> { log.error("Failed to deploy model " + modelId, e); mlTaskManager .updateMLTask( taskId, - ImmutableMap.of(MLTask.ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED), + Map.of(MLTask.ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED), TASK_SEMAPHORE_TIMEOUT, true ); - mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOY_FAILED)); + mlModelManager.updateModel(modelId, Map.of(MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOY_FAILED)); }); List workerNodes = eligibleNodes.stream().map(n -> n.getId()).collect(Collectors.toList()); mlModelManager .updateModel( modelId, - ImmutableMap + Map .of( MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOYING, diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index 5e88c50ef0..aa71c59570 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -9,9 +9,14 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.FunctionName.REMOTE; import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import java.io.IOException; import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -25,6 +30,7 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; @@ -38,11 +44,16 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; -import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheAction; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodeResponse; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesRequest; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesResponse; +import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; @@ -58,7 +69,7 @@ import lombok.extern.log4j.Log4j2; @Log4j2 -@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@FieldDefaults(level = AccessLevel.PRIVATE) public class UpdateModelTransportAction extends HandledTransportAction { Client client; @@ -68,6 +79,8 @@ public class UpdateModelTransportAction extends HandledTransportAction trustedConnectorEndpointsRegex; @Inject public UpdateModelTransportAction( @@ -79,7 +92,8 @@ public UpdateModelTransportAction( MLModelManager mlModelManager, MLModelGroupManager mlModelGroupManager, Settings settings, - ClusterService clusterService + ClusterService clusterService, + MLEngine mlEngine ) { super(MLUpdateModelAction.NAME, transportService, actionFilters, MLUpdateModelRequest::new); this.client = client; @@ -88,7 +102,12 @@ public UpdateModelTransportAction( this.mlModelManager = mlModelManager; this.mlModelGroupManager = mlModelGroupManager; this.clusterService = clusterService; + this.mlEngine = mlEngine; this.settings = settings; + trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, it -> trustedConnectorEndpointsRegex = it); } @Override @@ -102,66 +121,80 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { - FunctionName functionName = mlModel.getAlgorithm(); - MLModelState mlModelState = mlModel.getModelState(); - - if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { - if (mlModel.getIsHidden() != null && mlModel.getIsHidden()) { - if (isSuperAdmin) { - updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, actionListener, context); - } else { - actionListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model, model ID " + modelId, - RestStatus.FORBIDDEN - ) + if (!isModelDeploying(mlModel.getModelState())) { + boolean isModelDeployed = isModelDeployed(mlModel.getModelState()); + FunctionName functionName = mlModel.getAlgorithm(); + // TODO: Support update as well as model/user level throttling in all other DLModel categories + if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { + if (mlModel.getIsHidden() != null && mlModel.getIsHidden()) { + if (isSuperAdmin) { + updateRemoteOrTextEmbeddingModel( + modelId, + updateModelInput, + mlModel, + user, + wrappedListener, + isModelDeployed ); - } - } else { - modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { - if (hasPermission) { - if (isModelDeployed(mlModelState)) { - updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, actionListener, context); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model, model ID " + modelId, + RestStatus.FORBIDDEN + ) + ); + } + } else { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + if (hasPermission) { + updateRemoteOrTextEmbeddingModel( + modelId, + updateModelInput, + mlModel, + user, + wrappedListener, + isModelDeployed + ); } else { - actionListener + wrappedListener .onFailure( new OpenSearchStatusException( - "ML Model " - + modelId - + " is in deploying or deployed state, please undeploy the models first!", + "User doesn't have privilege to perform this operation on this model, model ID " + + modelId, RestStatus.FORBIDDEN ) ); } - } else { - actionListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model, model ID " + modelId, - RestStatus.FORBIDDEN - ) - ); - } - }, exception -> { - log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); - actionListener.onFailure(exception); - })); - } + }, exception -> { + log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); + wrappedListener.onFailure(exception); + })); + } + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "The function category " + functionName.toString() + " is not supported at this time.", + RestStatus.FORBIDDEN + ) + ); + } } else { - actionListener + wrappedListener .onFailure( - new MLValidationException( - "User doesn't have privilege to perform this operation on this function category: " - + functionName.toString() + new OpenSearchStatusException( + "Model is deploying, please wait for it complete. model ID " + modelId, + RestStatus.CONFLICT ) ); } }, - e -> actionListener + e -> wrappedListener .onFailure( new OpenSearchStatusException( "Failed to find model to update with the provided model id: " + modelId, @@ -180,81 +213,101 @@ private void updateRemoteOrTextEmbeddingModel( MLUpdateModelInput updateModelInput, MLModel mlModel, User user, - ActionListener actionListener, - ThreadContext.StoredContext context + ActionListener wrappedListener, + boolean isModelDeployed ) { String newModelGroupId = (Strings.hasLength(updateModelInput.getModelGroupId()) && !Objects.equals(updateModelInput.getModelGroupId(), mlModel.getModelGroupId())) ? updateModelInput.getModelGroupId() : null; - String relinkConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null; + String newConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null; if (mlModel.getAlgorithm() == TEXT_EMBEDDING) { - if (relinkConnectorId == null) { - updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateModelInput, actionListener, context); + if (newConnectorId == null && updateModelInput.getConnectorUpdateContent() == null) { + updateModelWithRegisteringToAnotherModelGroup( + modelId, + newModelGroupId, + user, + updateModelInput, + wrappedListener, + isModelDeployed + ); } else { - actionListener + wrappedListener .onFailure( new OpenSearchStatusException( - "Trying to update the connector or connector_id field on a local model", + "Trying to update the connector or connector_id field on a local model.", RestStatus.BAD_REQUEST ) ); } } else { // mlModel.getAlgorithm() == REMOTE - if (relinkConnectorId == null) { - updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateModelInput, actionListener, context); + if (newConnectorId == null) { + if (updateModelInput.getConnectorUpdateContent() != null) { + Connector connector = mlModel.getConnector(); + connector.update(updateModelInput.getConnectorUpdateContent(), mlEngine::encrypt); + connector.validateConnectorURL(trustedConnectorEndpointsRegex); + updateModelInput.setConnector(connector); + updateModelInput.setConnectorUpdateContent(null); + } + updateModelWithRegisteringToAnotherModelGroup( + modelId, + newModelGroupId, + user, + updateModelInput, + wrappedListener, + isModelDeployed + ); } else { - updateModelWithRelinkStandAloneConnector( + updateModelWithNewStandAloneConnector( modelId, newModelGroupId, - relinkConnectorId, + newConnectorId, mlModel, user, updateModelInput, - actionListener, - context + wrappedListener, + isModelDeployed ); } } } - private void updateModelWithRelinkStandAloneConnector( + private void updateModelWithNewStandAloneConnector( String modelId, String newModelGroupId, - String relinkConnectorId, + String newConnectorId, MLModel mlModel, User user, MLUpdateModelInput updateModelInput, - ActionListener actionListener, - ThreadContext.StoredContext context + ActionListener wrappedListener, + boolean isModelDeployed ) { if (Strings.hasLength(mlModel.getConnectorId())) { - connectorAccessControlHelper - .validateConnectorAccess(client, relinkConnectorId, ActionListener.wrap(hasRelinkConnectorPermission -> { - if (hasRelinkConnectorPermission) { - updateModelWithRegisteringToAnotherModelGroup( - modelId, - newModelGroupId, - user, - updateModelInput, - actionListener, - context + connectorAccessControlHelper.validateConnectorAccess(client, newConnectorId, ActionListener.wrap(hasNewConnectorPermission -> { + if (hasNewConnectorPermission) { + updateModelWithRegisteringToAnotherModelGroup( + modelId, + newModelGroupId, + user, + updateModelInput, + wrappedListener, + isModelDeployed + ); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "You don't have permission to update the connector, connector id: " + newConnectorId, + RestStatus.FORBIDDEN + ) ); - } else { - actionListener - .onFailure( - new OpenSearchStatusException( - "You don't have permission to update the connector, connector id: " + relinkConnectorId, - RestStatus.FORBIDDEN - ) - ); - } - }, exception -> { - log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", relinkConnectorId, exception); - actionListener.onFailure(exception); - })); + } + }, exception -> { + log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", newConnectorId, exception); + wrappedListener.onFailure(exception); + })); } else { - actionListener + wrappedListener .onFailure( new OpenSearchStatusException( "This remote does not have a connector_id field, maybe it uses an internal connector.", @@ -269,49 +322,55 @@ private void updateModelWithRegisteringToAnotherModelGroup( String newModelGroupId, User user, MLUpdateModelInput updateModelInput, - ActionListener actionListener, - ThreadContext.StoredContext context + ActionListener wrappedListener, + boolean isModelDeployed ) { UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); + // This flag is used to decide if we need to re-deploy the predictor(model) when performing the in-place update + boolean isPredictorUpdate = (updateModelInput.getConnector() != null || updateModelInput.getConnectorId() != null); + // This flag is used to decide if we need to perform an in-place update + boolean isUpdateModelCache = isModelDeployed && isPredictorUpdate; if (newModelGroupId != null) { - modelAccessControlHelper.validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasRelinkPermission -> { - if (hasRelinkPermission) { - mlModelGroupManager.getModelGroupResponse(newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { - updateRequestConstructor( - modelId, - newModelGroupId, - updateRequest, - updateModelInput, - newModelGroupResponse, - actionListener, - context - ); - }, - exception -> actionListener + modelAccessControlHelper + .validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasNewModelGroupPermission -> { + if (hasNewModelGroupPermission) { + mlModelGroupManager.getModelGroupResponse(newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { + updateRequestConstructor( + modelId, + newModelGroupId, + updateRequest, + updateModelInput, + newModelGroupResponse, + wrappedListener, + isUpdateModelCache, + isPredictorUpdate + ); + }, + exception -> wrappedListener + .onFailure( + new OpenSearchStatusException( + "Failed to find the model group with the provided model group id in the update model input, MODEL_GROUP_ID: " + + newModelGroupId, + RestStatus.NOT_FOUND + ) + ) + )); + } else { + wrappedListener .onFailure( new OpenSearchStatusException( - "Failed to find the model group with the provided model group id in the update model input, MODEL_GROUP_ID: " + "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID " + newModelGroupId, - RestStatus.NOT_FOUND + RestStatus.FORBIDDEN ) - ) - )); - } else { - actionListener - .onFailure( - new OpenSearchStatusException( - "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID " - + newModelGroupId, - RestStatus.FORBIDDEN - ) - ); - } - }, exception -> { - log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); - actionListener.onFailure(exception); - })); + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); + wrappedListener.onFailure(exception); + })); } else { - updateRequestConstructor(modelId, updateRequest, updateModelInput, actionListener, context); + updateRequestConstructor(modelId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache, isPredictorUpdate); } } @@ -319,16 +378,33 @@ private void updateRequestConstructor( String modelId, UpdateRequest updateRequest, MLUpdateModelInput updateModelInput, - ActionListener actionListener, - ThreadContext.StoredContext context + ActionListener wrappedListener, + boolean isUpdateModelCache, + boolean isPredictorUpdate ) { try { + updateModelInput.setLastUpdateTime(Instant.now()); updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); updateRequest.docAsUpsert(true); - client.update(updateRequest, getUpdateResponseListener(modelId, actionListener, context)); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + if (isUpdateModelCache) { + String[] targetNodeIds = getAllNodes(); + MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest( + targetNodeIds, + modelId, + isPredictorUpdate + ); + client + .update( + updateRequest, + getUpdateResponseListenerWithUpdateModelCache(modelId, wrappedListener, mlUpdateModelCacheNodesRequest) + ); + } else { + client.update(updateRequest, getUpdateResponseListener(modelId, wrappedListener)); + } } catch (IOException e) { - log.error("Failed to build update request."); - actionListener.onFailure(e); + log.error("Failed to build update request.", e); + wrappedListener.onFailure(e); } } @@ -338,12 +414,14 @@ private void updateRequestConstructor( UpdateRequest updateRequest, MLUpdateModelInput updateModelInput, GetResponse newModelGroupResponse, - ActionListener actionListener, - ThreadContext.StoredContext context + ActionListener wrappedListener, + boolean isUpdateModelCache, + boolean isPredictorUpdate ) { Map newModelGroupSourceMap = newModelGroupResponse.getSourceAsMap(); String updatedVersion = incrementLatestVersion(newModelGroupSourceMap); updateModelInput.setVersion(updatedVersion); + updateModelInput.setLastUpdateTime(Instant.now()); UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( newModelGroupSourceMap, newModelGroupId, @@ -351,43 +429,118 @@ private void updateRequestConstructor( newModelGroupResponse.getPrimaryTerm(), Integer.parseInt(updatedVersion) ); + updateModelGroupRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); try { updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); updateRequest.docAsUpsert(true); - client.update(updateModelGroupRequest, ActionListener.wrap(r -> { - client.update(updateRequest, getUpdateResponseListener(modelId, actionListener, context)); - }, e -> { - log - .error( - "Failed to register ML model with model ID {} to the new model group with model group ID {}", - modelId, - newModelGroupId - ); - actionListener.onFailure(e); - })); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + if (isUpdateModelCache) { + String[] targetNodeIds = getAllNodes(); + MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest( + targetNodeIds, + modelId, + isPredictorUpdate + ); + client.update(updateModelGroupRequest, ActionListener.wrap(r -> { + client + .update( + updateRequest, + getUpdateResponseListenerWithUpdateModelCache(modelId, wrappedListener, mlUpdateModelCacheNodesRequest) + ); + }, e -> { + log + .error( + "Failed to register ML model with model ID {} to the new model group with model group ID {}", + modelId, + newModelGroupId, + e + ); + wrappedListener.onFailure(e); + })); + } else { + client.update(updateModelGroupRequest, ActionListener.wrap(r -> { + client.update(updateRequest, getUpdateResponseListener(modelId, wrappedListener)); + }, e -> { + log + .error( + "Failed to register ML model with model ID {} to the new model group with model group ID {}", + modelId, + newModelGroupId, + e + ); + wrappedListener.onFailure(e); + })); + } } catch (IOException e) { log.error("Failed to build update request."); - actionListener.onFailure(e); + wrappedListener.onFailure(e); } } - private ActionListener getUpdateResponseListener( + private ActionListener getUpdateResponseListenerWithUpdateModelCache( String modelId, - ActionListener actionListener, - ThreadContext.StoredContext context + ActionListener wrappedListener, + MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest ) { - return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { - if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { - log.info("Model id:{} failed update", modelId); - actionListener.onResponse(updateResponse); - return; + return ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { + client.execute(MLUpdateModelCacheAction.INSTANCE, mlUpdateModelCacheNodesRequest, ActionListener.wrap(r -> { + if (isUpdateModelCacheSuccessOnAllNodes(modelId, r)) { + log.info("Successfully updated ML model cache with model ID {}", modelId); + wrappedListener.onResponse(updateResponse); + } else { + String[] nodeIds = getUpdateModelCacheFailedNodesList(modelId, r); + log + .error( + "Successfully update ML model index with model ID {} but update model cache was failed on following nodes {}, maybe retry?", + modelId, + Arrays.toString(nodeIds) + ); + wrappedListener + .onFailure( + new RuntimeException( + "Successfully update ML model index with model ID" + + modelId + + "but update model cache was failed on following nodes " + + Arrays.toString(nodeIds) + + ", maybe retry?" + ) + ); + } + }, e -> { + log.error("Failed to update ML model cache for model: " + modelId, e); + wrappedListener.onFailure(e); + })); + } else if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { + // The update response returned an unexpected status may indicate a failed update + log.warn("Model id:{} failed update with result {}", modelId, updateResponse.getResult()); + wrappedListener.onResponse(updateResponse); + } else { + log.error("Failed to update ML model: " + modelId); + wrappedListener.onFailure(new RuntimeException("Failed to update ML model: " + modelId)); + } + }, exception -> { + log.error("Failed to update ML model: " + modelId, exception); + wrappedListener.onFailure(exception); + }); + } + + private ActionListener getUpdateResponseListener(String modelId, ActionListener wrappedListener) { + return ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { + log.info("Successfully update ML model with model ID {}", modelId); + wrappedListener.onResponse(updateResponse); + } else if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { + log.warn("Model id:{} failed update with result {}", modelId, updateResponse.getResult()); + wrappedListener.onResponse(updateResponse); + } else { + log.error("Failed to update ML model: " + modelId); + wrappedListener.onFailure(new RuntimeException("Failed to update ML model: " + modelId)); } - log.info("Successfully update ML model with model ID {}", modelId); - actionListener.onResponse(updateResponse); }, exception -> { log.error("Failed to update ML model: " + modelId, exception); - actionListener.onFailure(exception); - }), context::restore); + wrappedListener.onFailure(exception); + }); } private String incrementLatestVersion(Map modelGroupSourceMap) { @@ -417,12 +570,52 @@ private UpdateRequest createUpdateModelGroupRequest( } private Boolean isModelDeployed(MLModelState mlModelState) { - return !mlModelState.equals(MLModelState.LOADED) - && !mlModelState.equals(MLModelState.LOADING) - && !mlModelState.equals(MLModelState.PARTIALLY_LOADED) - && !mlModelState.equals(MLModelState.DEPLOYED) - && !mlModelState.equals(MLModelState.DEPLOYING) - && !mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED); + return mlModelState.equals(MLModelState.LOADED) + || mlModelState.equals(MLModelState.PARTIALLY_LOADED) + || mlModelState.equals(MLModelState.DEPLOYED) + || mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED); + } + + private Boolean isModelDeploying(MLModelState mlModelState) { + return mlModelState.equals(MLModelState.LOADING) || mlModelState.equals(MLModelState.DEPLOYING); + } + + private String[] getAllNodes() { + Iterator iterator = clusterService.state().nodes().iterator(); + List nodeIds = new ArrayList<>(); + while (iterator.hasNext()) { + nodeIds.add(iterator.next().getId()); + } + return nodeIds.toArray(new String[0]); + } + + private boolean isUpdateModelCacheSuccessOnAllNodes(String modelId, MLUpdateModelCacheNodesResponse updateModelCacheNodesResponse) { + if (updateModelCacheNodesResponse == null) { + return false; + } else { + for (MLUpdateModelCacheNodeResponse mlUpdateModelCacheNodeResponse : updateModelCacheNodesResponse.getNodes()) { + if (mlUpdateModelCacheNodeResponse.isModelUpdateStatusEmpty() + || !Objects.equals(mlUpdateModelCacheNodeResponse.getModelUpdateStatus().get(modelId), "success")) { + return false; + } + } + return true; + } + } + + private String[] getUpdateModelCacheFailedNodesList(String modelId, MLUpdateModelCacheNodesResponse updateModelCacheNodesResponse) { + if (updateModelCacheNodesResponse == null) { + return getAllNodes(); + } else { + List nodeIds = new ArrayList<>(); + for (MLUpdateModelCacheNodeResponse mlUpdateModelCacheNodeResponse : updateModelCacheNodesResponse.getNodes()) { + if (mlUpdateModelCacheNodeResponse.isModelUpdateStatusEmpty() + || !Objects.equals(mlUpdateModelCacheNodeResponse.getModelUpdateStatus().get(modelId), "success")) { + nodeIds.add(mlUpdateModelCacheNodeResponse.getNode().getId()); + } + } + return nodeIds.toArray(new String[0]); + } } @VisibleForTesting diff --git a/plugin/src/main/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportAction.java new file mode 100644 index 0000000000..fb0b6e5717 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportAction.java @@ -0,0 +1,125 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.update_cache; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheAction; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodeRequest; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodeResponse; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesRequest; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesResponse; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class UpdateModelCacheTransportAction extends + TransportNodesAction { + private final MLModelManager mlModelManager; + private final ClusterService clusterService; + private final Client client; + private DiscoveryNodeHelper nodeFilter; + private final MLStats mlStats; + private NamedXContentRegistry xContentRegistry; + + private ModelAccessControlHelper modelAccessControlHelper; + + @Inject + public UpdateModelCacheTransportAction( + TransportService transportService, + ActionFilters actionFilters, + MLModelManager mlModelManager, + ClusterService clusterService, + ThreadPool threadPool, + Client client, + DiscoveryNodeHelper nodeFilter, + MLStats mlStats, + NamedXContentRegistry xContentRegistry, + ModelAccessControlHelper modelAccessControlHelper + ) { + super( + MLUpdateModelCacheAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + MLUpdateModelCacheNodesRequest::new, + MLUpdateModelCacheNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + MLUpdateModelCacheNodeResponse.class + ); + this.mlModelManager = mlModelManager; + this.clusterService = clusterService; + this.client = client; + this.nodeFilter = nodeFilter; + this.mlStats = mlStats; + this.xContentRegistry = xContentRegistry; + this.modelAccessControlHelper = modelAccessControlHelper; + } + + @Override + protected MLUpdateModelCacheNodesResponse newResponse( + MLUpdateModelCacheNodesRequest nodesRequest, + List responses, + List failures + ) { + return new MLUpdateModelCacheNodesResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected MLUpdateModelCacheNodeRequest newNodeRequest(MLUpdateModelCacheNodesRequest request) { + return new MLUpdateModelCacheNodeRequest(request); + } + + @Override + protected MLUpdateModelCacheNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new MLUpdateModelCacheNodeResponse(in); + } + + @Override + protected MLUpdateModelCacheNodeResponse nodeOperation(MLUpdateModelCacheNodeRequest request) { + return createUpdateModelCacheNodeResponse(request.getUpdateModelCacheNodesRequest()); + } + + private MLUpdateModelCacheNodeResponse createUpdateModelCacheNodeResponse( + MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest + ) { + String modelId = mlUpdateModelCacheNodesRequest.getModelId(); + boolean isPredictorUpdate = mlUpdateModelCacheNodesRequest.isPredictorUpdate(); + + Map modelUpdateStatus = new HashMap<>(); + modelUpdateStatus.put(modelId, "received"); + + String localNodeId = clusterService.localNode().getId(); + + mlModelManager.updateModelCache(modelId, isPredictorUpdate, ActionListener.wrap(r -> { + modelUpdateStatus.replace(modelId, "success"); + log.info("Successfully performing in-place update model {} on node {}", modelId, localNodeId); + }, e -> { + modelUpdateStatus.replace(modelId, "failed"); + log.error("Failed to perform in-place update model for model {} on node {}", modelId, localNodeId); + })); + return new MLUpdateModelCacheNodeResponse(clusterService.localNode(), modelUpdateStatus); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 20286fd3c5..04fa8cb75d 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -8,7 +8,6 @@ import static org.opensearch.common.xcontent.XContentType.JSON; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.NOT_FOUND; @@ -66,7 +65,6 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.IndicesOptions; @@ -78,7 +76,6 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -90,6 +87,7 @@ import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; @@ -120,6 +118,7 @@ import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.utils.MLExceptionUtils; +import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.script.ScriptService; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.threadpool.ThreadPool; @@ -291,7 +290,6 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput .lastUpdateTime(now) .build(); IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); - indexRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), EMPTY_PARAMS)); if (mlRegisterModelMetaInput.getIsHidden()) { indexRequest.id(modelName); @@ -665,7 +663,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas log.error("Failed to index model meta doc", e); handleException(functionName, taskId, e); }); - client.index(indexModelMetaRequest, threadedActionListener(REGISTER_THREAD_POOL, listener)); }, e -> { log.error("Failed to init model index", e); @@ -910,6 +907,47 @@ private void handleException(FunctionName functionName, String taskId, Exception mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true); } + public synchronized void updateModelCache(String modelId, boolean isPredictorUpdate, ActionListener listener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + getModel(modelId, ActionListener.wrap(mlModel -> { + if (isPredictorUpdate) { + assert FunctionName.REMOTE == mlModel.getAlgorithm() + : "In-place update is only supported on REMOTE models at this time."; + Map params = ImmutableMap + .of( + ML_ENGINE, + mlEngine, + SCRIPT_SERVICE, + scriptService, + CLIENT, + client, + XCONTENT_REGISTRY, + xContentRegistry, + CLUSTER_SERVICE, + clusterService + ); + if (mlModel.getConnector() != null) { + Predictable predictable = mlEngine.deploy(mlModel, params); + modelCacheHelper.setPredictor(modelId, predictable); + wrappedListener.onResponse("successfully performed in-place update for the model " + modelId); + log.info("Completed in-place update internal connector for the model {}", modelId); + } else { + getConnector(client, mlModel.getConnectorId(), ActionListener.wrap(connector -> { + mlModel.setConnector(connector); + Predictable predictable = mlEngine.deploy(mlModel, params); + modelCacheHelper.setPredictor(modelId, predictable); + wrappedListener.onResponse("successfully performed in-place update for the model " + modelId); + log.info("Completed in-place update stand-alone connector for the model {}", modelId); + }, wrappedListener::onFailure)); + } + wrappedListener.onResponse("successfully performed in-place update for the model " + modelId); + log.info("Completed in-place update for the model {}", modelId); + } + }, wrappedListener::onFailure)); + } + } + /** * Read model chunks from model index. Concat chunks into a whole model file, then load * into memory. @@ -948,7 +986,7 @@ public void deployModel( } modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); checkAndAddRunningTask(mlTask, maxDeployTasksPerNode); this.getModel(modelId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> { if (FunctionName.REMOTE == mlModel.getAlgorithm() @@ -974,28 +1012,12 @@ public void deployModel( return; } log.info("Set connector {} for the model: {}", mlModel.getConnectorId(), modelId); - GetRequest getConnectorRequest = new GetRequest(); - FetchSourceContext fetchContext = new FetchSourceContext(true, null, null); - getConnectorRequest.index(ML_CONNECTOR_INDEX).id(mlModel.getConnectorId()).fetchSourceContext(fetchContext); - // get connector and deploy remote model with standalone connector - client.get(getConnectorRequest, ActionListener.wrap(getResponse -> { - if (getResponse != null && getResponse.isExists()) { - try ( - XContentParser parser = createXContentParserFromRegistry( - xContentRegistry, - getResponse.getSourceAsBytesRef() - ) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Connector connector = Connector.createConnector(parser); - mlModel.setConnector(connector); - setupPredictable(modelId, mlModel, params); - wrappedListener.onResponse("successful"); - log.info("Completed setting connector {} in the model {}", mlModel.getConnectorId(), modelId); - } - } - }, e -> { wrappedListener.onFailure(e); })); - + getConnector(client, mlModel.getConnectorId(), ActionListener.wrap(connector -> { + mlModel.setConnector(connector); + setupPredictable(modelId, mlModel, params); + wrappedListener.onResponse("successful"); + log.info("Completed setting connector {} in the model {}", mlModel.getConnectorId(), modelId); + }, wrappedListener::onFailure)); return; } // check circuit breaker before deploying custom model chunks @@ -1102,8 +1124,7 @@ public void getModel(String modelId, String[] includes, String[] excludes, Actio if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - GetResponse getResponse = r; - String algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString(); + String algorithmName = r.getSource().get(ALGORITHM_FIELD).toString(); MLModel mlModel = MLModel.parse(parser, algorithmName); mlModel.setModelId(modelId); @@ -1115,7 +1136,31 @@ public void getModel(String modelId, String[] includes, String[] excludes, Actio } else { listener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); } - }, e -> { listener.onFailure(e); })); + }, listener::onFailure)); + } + + private void getConnector(Client client, String connectorId, ActionListener listener) { + GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector connector = Connector.createConnector(parser); + listener.onResponse(connector); + } catch (Exception e) { + log.error("Failed to parse connector:" + connectorId); + listener.onFailure(e); + } + } else { + listener.onFailure(new OpenSearchStatusException("Failed to find connector:" + connectorId, RestStatus.NOT_FOUND)); + } + }, e -> { + log.error("Failed to get connector", e); + listener.onFailure(new OpenSearchStatusException("Failed to get connector:" + connectorId, RestStatus.NOT_FOUND)); + })); } private void retrieveModelChunks(MLModel mlModelMeta, ActionListener listener) throws InterruptedException { @@ -1326,6 +1371,10 @@ public String[] getWorkerNodes(String modelId, FunctionName functionName, boolea return eligibleNodeIds; } + public int getWorkerNodesSize(String modelId, FunctionName functionName, boolean onlyEligibleNode) { + return getWorkerNodes(modelId, functionName, onlyEligibleNode).length; + } + /** * Get worker node of specific model without filtering eligible node. * @@ -1337,6 +1386,10 @@ public String[] getWorkerNodes(String modelId, FunctionName functionName) { return getWorkerNodes(modelId, functionName, false); } + public int getWorkerNodesSize(String modelId, FunctionName functionName) { + return getWorkerNodes(modelId, functionName, false).length; + } + /** * Get predictable instance with model id. * diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index bbfd5fdfbc..5fb925e077 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -71,6 +71,7 @@ import org.opensearch.ml.action.trainpredict.TransportTrainAndPredictionTaskAction; import org.opensearch.ml.action.undeploy.TransportUndeployModelAction; import org.opensearch.ml.action.undeploy.TransportUndeployModelsAction; +import org.opensearch.ml.action.update_cache.UpdateModelCacheTransportAction; import org.opensearch.ml.action.upload_chunk.MLModelChunkUploader; import org.opensearch.ml.action.upload_chunk.TransportRegisterModelMetaAction; import org.opensearch.ml.action.upload_chunk.TransportUploadModelChunkAction; @@ -122,6 +123,7 @@ import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheAction; import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaAction; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkAction; import org.opensearch.ml.engine.MLEngine; @@ -316,6 +318,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(MLRegisterModelMetaAction.INSTANCE, TransportRegisterModelMetaAction.class), new ActionHandler<>(MLUploadModelChunkAction.INSTANCE, TransportUploadModelChunkAction.class), new ActionHandler<>(MLUpdateModelAction.INSTANCE, UpdateModelTransportAction.class), + new ActionHandler<>(MLUpdateModelCacheAction.INSTANCE, UpdateModelCacheTransportAction.class), new ActionHandler<>(MLForwardAction.INSTANCE, TransportForwardAction.class), new ActionHandler<>(MLSyncUpAction.INSTANCE, TransportSyncUpOnNodeAction.class), new ActionHandler<>(MLRegisterModelGroupAction.INSTANCE, TransportRegisterModelGroupAction.class), diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java index 79959cbf26..fd5c828201 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java @@ -15,7 +15,9 @@ import java.util.Locale; import org.opensearch.OpenSearchParseException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; @@ -64,9 +66,16 @@ private MLUpdateModelRequest getRequest(RestRequest request) throws IOException ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); try { MLUpdateModelInput input = MLUpdateModelInput.parse(parser); - // Model ID can only be set here. Model version can only be set automatically. + if (input.getConnectorId() != null && input.getConnectorUpdateContent() != null) { + throw new OpenSearchStatusException( + "Model cannot have both stand-alone connector and internal connector. Please check your update input body.", + RestStatus.BAD_REQUEST + ); + } + // Model ID can only be set here. Model version as well as connector field can only be set automatically. input.setModelId(modelId); input.setVersion(null); + input.setConnector(null); return new MLUpdateModelRequest(input); } catch (IllegalStateException e) { throw new OpenSearchParseException(e.getMessage()); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java index e1bbcfa881..b73fd907a6 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java @@ -16,6 +16,7 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.UUID; import org.apache.lucene.search.TotalHits; @@ -61,11 +62,10 @@ import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { - private UpdateConnectorTransportAction transportUpdateConnectorAction; + private UpdateConnectorTransportAction updateConnectorTransportAction; @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; @@ -157,7 +157,7 @@ public void setup() throws IOException { Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor); - transportUpdateConnectorAction = new UpdateConnectorTransportAction( + updateConnectorTransportAction = new UpdateConnectorTransportAction( transportService, actionFilters, client, @@ -179,8 +179,8 @@ public void setup() throws IOException { .name("test") .protocol("http") .version("1") - .credential(ImmutableMap.of("api_key", "credential_value")) - .parameters(ImmutableMap.of("param1", "value1")) + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) .actions( Arrays .asList( @@ -189,7 +189,7 @@ public void setup() throws IOException { .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") .url("https://api.openai.com/v1/chat/completions") - .headers(ImmutableMap.of("Authorization", "Bearer ${credential.api_key}")) + .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") .build() ) @@ -203,7 +203,7 @@ public void setup() throws IOException { } @Test - public void test_execute_connectorAccessControl_success() { + public void testExecuteConnectorAccessControlSuccess() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -218,15 +218,15 @@ public void test_execute_connectorAccessControl_success() { return null; }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); - transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); verify(actionListener).onResponse(updateResponse); } @Test - public void test_execute_connectorAccessControl_NoPermission() { + public void testExecuteConnectorAccessControlNoPermission() { doReturn(false).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); - transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -236,31 +236,31 @@ public void test_execute_connectorAccessControl_NoPermission() { } @Test - public void test_execute_connectorAccessControl_AccessError() { + public void testExecuteConnectorAccessControlAccessError() { doThrow(new RuntimeException("Connector Access Control Error")) .when(connectorAccessControlHelper) .validateConnectorAccess(any(Client.class), any(Connector.class)); - transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Connector Access Control Error", argumentCaptor.getValue().getMessage()); } @Test - public void test_execute_connectorAccessControl_Exception() { + public void testExecuteConnectorAccessControlException() { doThrow(new RuntimeException("exception in access control")) .when(connectorAccessControlHelper) .validateConnectorAccess(any(Client.class), any(Connector.class)); - transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("exception in access control", argumentCaptor.getValue().getMessage()); } @Test - public void test_execute_UpdateWrongStatus() { + public void testExecuteUpdateWrongStatus() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -276,12 +276,12 @@ public void test_execute_UpdateWrongStatus() { return null; }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); - transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); verify(actionListener).onResponse(updateResponse); } @Test - public void test_execute_UpdateException() { + public void testExecuteUpdateException() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -296,14 +296,14 @@ public void test_execute_UpdateException() { return null; }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); - transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("update document failure", argumentCaptor.getValue().getMessage()); } @Test - public void test_execute_SearchResponseNotEmpty() { + public void testExecuteSearchResponseNotEmpty() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -312,7 +312,7 @@ public void test_execute_SearchResponseNotEmpty() { return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertTrue( @@ -321,7 +321,7 @@ public void test_execute_SearchResponseNotEmpty() { } @Test - public void test_execute_SearchResponseError() { + public void testExecuteSearchResponseError() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -330,14 +330,14 @@ public void test_execute_SearchResponseError() { return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Error in Search Request", argumentCaptor.getValue().getMessage()); } @Test - public void test_execute_SearchIndexNotFoundError() { + public void testExecuteSearchIndexNotFoundError() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -347,8 +347,8 @@ public void test_execute_SearchIndexNotFoundError() { .name("test") .protocol("http") .version("1") - .credential(ImmutableMap.of("api_key", "credential_value")) - .parameters(ImmutableMap.of("param1", "value1")) + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) .actions( Arrays .asList( @@ -357,7 +357,7 @@ public void test_execute_SearchIndexNotFoundError() { .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") .url("https://api.openai.com/v1/chat/completions") - .headers(ImmutableMap.of("Authorization", "Bearer ${credential.api_key}")) + .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") .build() ) @@ -381,7 +381,7 @@ public void test_execute_SearchIndexNotFoundError() { return null; }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); - transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); verify(actionListener).onResponse(updateResponse); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index 143482ccc4..02629c392a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -79,8 +79,6 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import com.google.common.collect.ImmutableMap; - public class TransportDeployModelActionTests extends OpenSearchTestCase { @Mock private MLTaskManager mlTaskManager; @@ -473,7 +471,7 @@ public void testUpdateModelDeployStatusAndTriggerOnNodesAction_success() throws clientField.setAccessible(true); clientField.set(mlModelManager, client); - doCallRealMethod().when(mlModelManager).updateModel(anyString(), any(ImmutableMap.class), isA(ActionListener.class)); + doCallRealMethod().when(mlModelManager).updateModel(anyString(), any(Map.class), isA(ActionListener.class)); MLDeployModelNodesResponse MLDeployModelNodesResponse = mock(MLDeployModelNodesResponse.class); doAnswer(invocation -> { @@ -513,7 +511,7 @@ public void testUpdateModelDeployStatusAndTriggerOnNodesAction_success() throws } public void testUpdateModelDeployStatusAndTriggerOnNodesAction_whenMLTaskManagerThrowException_ListenerOnFailureExecuted() { - doCallRealMethod().when(mlModelManager).updateModel(anyString(), any(ImmutableMap.class), isA(ActionListener.class)); + doCallRealMethod().when(mlModelManager).updateModel(anyString(), any(Map.class), isA(ActionListener.class)); transportDeployModelAction .updateModelDeployStatusAndTriggerOnNodesAction(modelId, "mock_task_id", mlModel, localNodeId, mlTask, eligibleNodes, false); verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index 0dca491658..1f09877c92 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -14,11 +14,17 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; +import static org.opensearch.ml.utils.TestHelper.setupTestClusterState; import java.io.IOException; import java.util.Arrays; +import java.util.List; +import java.util.Map; import org.junit.Before; +import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -32,7 +38,9 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; @@ -48,11 +56,15 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesResponse; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; @@ -63,6 +75,8 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import com.google.common.collect.ImmutableList; + public class UpdateModelTransportActionTests extends OpenSearchTestCase { @Mock ThreadPool threadPool; @@ -103,11 +117,18 @@ public class UpdateModelTransportActionTests extends OpenSearchTestCase { @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; + @Mock + MLUpdateModelCacheNodesResponse updateModelCacheNodesResponse; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); private ShardId shardId; + private Settings settings; + + Connector testConnector; + UpdateResponse updateResponse; UpdateModelTransportAction transportUpdateModelAction; @@ -116,25 +137,26 @@ public class UpdateModelTransportActionTests extends OpenSearchTestCase { MLUpdateModelInput updateLocalModelInput; - MLUpdateModelRequest updateRemoteModelRequest; - - MLUpdateModelInput updateRemoteModelInput; - MLModel mlModelWithNullFunctionName; MLModel localModel; ThreadContext threadContext; + + ClusterState testState; + @Mock ClusterService clusterService; @Mock MLEngine mlEngine; + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList.of("^https://api\\.test\\.com/.*$"); + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - + testState = setupTestClusterState(); updateLocalModelInput = MLUpdateModelInput .builder() .modelId("test_model_id") @@ -143,15 +165,6 @@ public void setup() throws IOException { .modelGroupId("updated_test_model_group_id") .build(); updateLocalModelRequest = MLUpdateModelRequest.builder().updateModelInput(updateLocalModelInput).build(); - updateRemoteModelInput = MLUpdateModelInput - .builder() - .modelId("test_model_id") - .name("updated_test_name") - .description("updated_test_description") - .modelGroupId("updated_test_model_group_id") - .connectorId("updated_test_connector_id") - .build(); - updateRemoteModelRequest = MLUpdateModelRequest.builder().updateModelInput(updateRemoteModelInput).build(); mlModelWithNullFunctionName = MLModel .builder() @@ -162,7 +175,22 @@ public void setup() throws IOException { .modelState(MLModelState.REGISTERED) .build(); - Settings settings = Settings.builder().build(); + settings = Settings + .builder() + .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) + .build(); + + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX); + + localModel = prepareMLModel("TEXT_EMBEDDING"); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(clusterService.state()).thenReturn(testState); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + shardId = new ShardId(new Index("indexName", "uuid"), 1); + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); transportUpdateModelAction = spy( new UpdateModelTransportAction( @@ -174,17 +202,32 @@ public void setup() throws IOException { mlModelManager, mlModelGroupManager, settings, - clusterService + clusterService, + mlEngine ) ); - localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING, false); - threadContext = new ThreadContext(settings); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - when(clusterService.getSettings()).thenReturn(settings); - shardId = new ShardId(new Index("indexName", "uuid"), 1); - updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + testConnector = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) + .actions( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://api.test.com/v1/test") + .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") + .build() + ) + ) + .build(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -247,205 +290,81 @@ public void testUpdateLocalModelSuccess() { } @Test - public void testUpdateModelStateLoadedException() { - doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); - doReturn("mockId").when(mockUpdateModelInput).getModelId(); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(mockModel); - return null; - }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); - - doReturn("test_model_group_id").when(mockModel).getModelGroupId(); - doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); - doReturn(MLModelState.LOADED).when(mockModel).getModelState(); - - transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "ML Model mockId is in deploying or deployed state, please undeploy the models first!", - argumentCaptor.getValue().getMessage() - ); - } - - @Test - public void testUpdateModelStateLoadingException() { - doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); - doReturn("mockId").when(mockUpdateModelInput).getModelId(); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(mockModel); - return null; - }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); - - doReturn("test_model_group_id").when(mockModel).getModelGroupId(); - doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); - doReturn(MLModelState.LOADING).when(mockModel).getModelState(); - - transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "ML Model mockId is in deploying or deployed state, please undeploy the models first!", - argumentCaptor.getValue().getMessage() - ); - } - - @Test - public void testUpdateModelStatePartiallyLoadedException() { - doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); - doReturn("mockId").when(mockUpdateModelInput).getModelId(); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(mockModel); - return null; - }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); - - doReturn("test_model_group_id").when(mockModel).getModelGroupId(); - doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); - doReturn(MLModelState.PARTIALLY_LOADED).when(mockModel).getModelState(); - - transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "ML Model mockId is in deploying or deployed state, please undeploy the models first!", - argumentCaptor.getValue().getMessage() - ); - } - - @Test - public void testUpdateModelStateDeployedException() { - doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); - doReturn("mockId").when(mockUpdateModelInput).getModelId(); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(mockModel); - return null; - }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); - - doReturn("test_model_group_id").when(mockModel).getModelGroupId(); - doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); - doReturn(MLModelState.DEPLOYED).when(mockModel).getModelState(); - - transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "ML Model mockId is in deploying or deployed state, please undeploy the models first!", - argumentCaptor.getValue().getMessage() - ); + public void testUpdateModelWithoutRegisterToNewModelGroupSuccess() { + updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); } @Test - public void testUpdateModelStateDeployingException() { - doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); - doReturn("mockId").when(mockUpdateModelInput).getModelId(); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(mockModel); - return null; - }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); - - doReturn("test_model_group_id").when(mockModel).getModelGroupId(); - doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); - doReturn(MLModelState.DEPLOYING).when(mockModel).getModelState(); - - transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "ML Model mockId is in deploying or deployed state, please undeploy the models first!", - argumentCaptor.getValue().getMessage() - ); + public void testUpdateModelWithRegisterToSameModelGroupSuccess() { + updateLocalModelRequest.getUpdateModelInput().setModelGroupId("test_model_group_id"); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); } @Test - public void testUpdateModelStatePartiallyDeployedException() { - doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); - doReturn("mockId").when(mockUpdateModelInput).getModelId(); - + public void testUpdateRemoteModelWithLocalInformationSuccess() { + MLModel remoteModel = prepareMLModel("REMOTE_EXTERNAL"); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); - listener.onResponse(mockModel); + listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); - - doReturn("test_model_group_id").when(mockModel).getModelGroupId(); - doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); - doReturn(MLModelState.PARTIALLY_DEPLOYED).when(mockModel).getModelState(); - - transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "ML Model mockId is in deploying or deployed state, please undeploy the models first!", - argumentCaptor.getValue().getMessage() - ); - } + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); - @Test - public void testUpdateModelWithoutRegisterToNewModelGroupSuccess() { - updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); verify(actionListener).onResponse(updateResponse); } @Test - public void testUpdateRemoteModelWithLocalInformationSuccess() { - MLModel remoteModel = prepareMLModel(FunctionName.REMOTE, false); + public void testUpdateExternalRemoteModelWithExternalRemoteInformationSuccess() { + MLModel remoteModel = prepareMLModel("REMOTE_EXTERNAL"); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(remoteModel); return null; }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); - transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); verify(actionListener).onResponse(updateResponse); } @Test - public void testUpdateRemoteModelWithRemoteInformationSuccess() { - MLModel remoteModel = prepareMLModel(FunctionName.REMOTE, false); + public void testUpdateInternalRemoteModelWithInternalRemoteInformationSuccess() { + MLModel remoteModel = prepareMLModel("REMOTE_INTERNAL"); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(remoteModel); return null; }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); - transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_INTERNAL"), actionListener); verify(actionListener).onResponse(updateResponse); } @Test public void testUpdateHiddenRemoteModelWithRemoteInformationSuccess() { - MLModel remoteModel = prepareMLModel(FunctionName.REMOTE, true); + MLModel remoteModel = prepareMLModel("REMOTE_INTERNAL", true); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(remoteModel); return null; }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); doReturn(true).when(transportUpdateModelAction).isSuperAdminUserWrapper(clusterService, client); - transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_INTERNAL"), actionListener); verify(actionListener).onResponse(updateResponse); } @Test public void testUpdateHiddenRemoteModelPermissionError() { - MLModel remoteModel = prepareMLModel(FunctionName.REMOTE, true); + MLModel remoteModel = prepareMLModel("REMOTE_INTERNAL", true); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(remoteModel); return null; }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); doReturn(false).when(transportUpdateModelAction).isSuperAdminUserWrapper(clusterService, client); - transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_INTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -455,7 +374,7 @@ public void testUpdateHiddenRemoteModelPermissionError() { } @Test - public void testUpdateRemoteModelWithNoStandAloneConnectorFound() { + public void testUpdateRemoteModelWithNoExternalConnectorFound() { MLModel remoteModelWithInternalConnector = prepareUnsupportedMLModel(FunctionName.REMOTE); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -463,7 +382,7 @@ public void testUpdateRemoteModelWithNoStandAloneConnectorFound() { return null; }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); - transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -474,7 +393,7 @@ public void testUpdateRemoteModelWithNoStandAloneConnectorFound() { @Test public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlNoPermission() { - MLModel remoteModel = prepareMLModel(FunctionName.REMOTE, false); + MLModel remoteModel = prepareMLModel("REMOTE_EXTERNAL"); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(remoteModel); @@ -487,7 +406,7 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl return null; }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); - transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -498,7 +417,7 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl @Test public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlOtherException() { - MLModel remoteModel = prepareMLModel(FunctionName.REMOTE, false); + MLModel remoteModel = prepareMLModel("REMOTE_EXTERNAL"); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(remoteModel); @@ -514,7 +433,7 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl return null; }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); - transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -650,11 +569,19 @@ public void testUpdateModelWithFunctionNameFieldNotFound() { } @Test - public void testUpdateLocalModelWithRemoteInformation() { - transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + public void testUpdateLocalModelWithExternalRemoteInformation() { + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Trying to update the connector or connector_id field on a local model.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateLocalModelWithInternalRemoteInformation() { + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_INTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Trying to update the connector or connector_id field on a local model", argumentCaptor.getValue().getMessage()); + assertEquals("Trying to update the connector or connector_id field on a local model.", argumentCaptor.getValue().getMessage()); } @Test @@ -666,13 +593,10 @@ public void testUpdateLocalModelWithUnsupportedFunction() { return null; }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); - transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "User doesn't have privilege to perform this operation on this function category: KMEANS", - argumentCaptor.getValue().getMessage() - ); + assertEquals("The function category KMEANS is not supported at this time.", argumentCaptor.getValue().getMessage()); } @Test @@ -816,69 +740,435 @@ public void testGetUpdateResponseListenerOtherException() { ); } - // TODO: Add UT to make sure that version incremented successfully. + @Test + public void testUpdateModelStateDeployingException() { + MLModel testDeployingModel = prepareMLModel("TEXT_EMBEDDING", MLModelState.DEPLOYING); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testDeployingModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); - private MLModel prepareMLModel(FunctionName functionName, boolean isHidden) throws IllegalArgumentException { - MLModel mlModel; - switch (functionName) { - case TEXT_EMBEDDING: - mlModel = MLModel - .builder() - .name("test_name") - .modelId("test_model_id") - .modelGroupId("test_model_group_id") - .description("test_description") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.TEXT_EMBEDDING) - .isHidden(isHidden) - .build(); - return mlModel; - case REMOTE: - mlModel = MLModel - .builder() - .name("test_name") - .modelId("test_model_id") - .modelGroupId("test_model_group_id") - .description("test_description") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.REMOTE) - .connectorId("test_connector_id") - .isHidden(isHidden) - .build(); - return mlModel; - default: - throw new IllegalArgumentException("Please choose from FunctionName.TEXT_EMBEDDING and FunctionName.REMOTE"); - } + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Model is deploying, please wait for it complete. model ID test_model_id", argumentCaptor.getValue().getMessage()); } - private MLModel prepareUnsupportedMLModel(FunctionName unsupportedCase) throws IllegalArgumentException { - MLModel mlModel; - switch (unsupportedCase) { - case REMOTE: - mlModel = MLModel - .builder() - .name("test_name") - .modelId("test_model_id") - .modelGroupId("test_model_group_id") - .description("test_description") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.REMOTE) - .connector(HttpConnector.builder().name("test_connector").protocol("http").build()) - .build(); - return mlModel; - case KMEANS: - mlModel = MLModel - .builder() - .name("test_name") - .modelId("test_model_id") - .modelGroupId("test_model_group_id") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.KMEANS) - .build(); - return mlModel; - default: - throw new IllegalArgumentException("Please choose from FunctionName.REMOTE and FunctionName.KMEANS"); - } + @Test + public void testUpdateModelCacheModelStateDeployedSuccess() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelCacheModelStateDeployedWrongStatus() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateWrongResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateWrongResponse); + } + + @Test + public void testUpdateModelCacheModelStateDeployedUpdateModelCacheException() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener + .onFailure( + new RuntimeException( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details." + ) + ); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelCacheModelStateDeployedUpdateException() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new RuntimeException( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details." + ) + ); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelCacheModelRegisterToNewModelGroupSuccess() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelCacheModelRegisterToNewModelGroupWrongStatus() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateWrongResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateWrongResponse); + } + + @Test + public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateModelCacheException() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener + .onFailure( + new RuntimeException( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details." + ) + ); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateException() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new RuntimeException( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details." + ) + ); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelCacheModelStateLoadedSuccess() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.LOADED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelCacheModelStatePartiallyDeployedSuccess() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.PARTIALLY_DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelCacheModelStatePartiallyLoadedSuccess() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.PARTIALLY_LOADED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + // TODO: Add UT to make sure that version incremented successfully. + private MLModel prepareMLModel(String functionName, MLModelState modelState, boolean isHidden) throws IllegalArgumentException { + MLModel mlModel; + switch (functionName) { + case "TEXT_EMBEDDING": + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(modelState) + .algorithm(FunctionName.TEXT_EMBEDDING) + .isHidden(isHidden) + .build(); + return mlModel; + case "REMOTE_EXTERNAL": + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(modelState) + .algorithm(FunctionName.REMOTE) + .connectorId("test_connector_id") + .isHidden(isHidden) + .build(); + return mlModel; + case "REMOTE_INTERNAL": + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(modelState) + .algorithm(FunctionName.REMOTE) + .connector(testConnector) + .isHidden(isHidden) + .build(); + return mlModel; + default: + throw new IllegalArgumentException("Please choose from TEXT_EMBEDDING, REMOTE_EXTERNAL, or REMOTE_INTERNAL"); + } + } + + private MLModel prepareMLModel(String functionName, MLModelState modelState) throws IllegalArgumentException { + return prepareMLModel(functionName, modelState, false); + } + + private MLModel prepareMLModel(String functionName, boolean isHidden) throws IllegalArgumentException { + return prepareMLModel(functionName, MLModelState.REGISTERED, isHidden); + } + + private MLModel prepareMLModel(String functionName) throws IllegalArgumentException { + return prepareMLModel(functionName, MLModelState.REGISTERED, false); + } + + private MLModel prepareUnsupportedMLModel(FunctionName unsupportedCase) throws IllegalArgumentException { + MLModel mlModel; + switch (unsupportedCase) { + case REMOTE: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.REMOTE) + .connector(HttpConnector.builder().name("test_connector").protocol("http").build()) + .build(); + return mlModel; + case KMEANS: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.KMEANS) + .build(); + return mlModel; + default: + throw new IllegalArgumentException("Please choose from FunctionName.REMOTE and FunctionName.KMEANS"); + } + } + + private MLUpdateModelRequest prepareRemoteRequest(String remoteRequestType) throws IllegalArgumentException { + MLUpdateModelInput updateRemoteModelInput; + switch (remoteRequestType) { + case "REMOTE_EXTERNAL": + updateRemoteModelInput = MLUpdateModelInput + .builder() + .modelId("test_model_id") + .name("updated_test_name") + .description("updated_test_description") + .modelGroupId("updated_test_model_group_id") + .connectorId("updated_test_connector_id") + .build(); + return MLUpdateModelRequest.builder().updateModelInput(updateRemoteModelInput).build(); + case "REMOTE_INTERNAL": + MLCreateConnectorInput updateContent = MLCreateConnectorInput + .builder() + .updateConnector(true) + .version("1") + .description("updated description") + .build(); + updateRemoteModelInput = MLUpdateModelInput + .builder() + .modelId("test_model_id") + .name("updated_test_name") + .description("updated_test_description") + .modelGroupId("updated_test_model_group_id") + .connectorUpdateContent(updateContent) + .build(); + return MLUpdateModelRequest.builder().updateModelInput(updateRemoteModelInput).build(); + default: + throw new IllegalArgumentException("Please choose from REMOTE_EXTERNAL or REMOTE_INTERNAL"); + } } private GetResponse prepareGetResponse(MLModelGroup mlModelGroup) throws IOException { @@ -887,4 +1177,104 @@ private GetResponse prepareGetResponse(MLModelGroup mlModelGroup) throws IOExcep GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); return new GetResponse(getResult); } + + @Ignore + @Test + public void testUpdateModelStateLoadingException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.LOADING).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Ignore + @Test + public void testUpdateModelStateLoadedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.LOADED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Ignore + @Test + public void testUpdateModelStatePartiallyLoadedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.PARTIALLY_LOADED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Ignore + @Test + public void testUpdateModelStatePartiallyDeployedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.PARTIALLY_DEPLOYED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportActionTests.java new file mode 100644 index 0000000000..ec659b872c --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportActionTests.java @@ -0,0 +1,179 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.update_cache; + +import static org.junit.Assert.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodeRequest; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodeResponse; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesRequest; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesResponse; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.transport.TransportService; + +@RunWith(MockitoJUnitRunner.class) +public class UpdateModelCacheTransportActionTests { + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private MLModelManager mlModelManager; + + @Mock + private ClusterService clusterService; + + @Mock + private Client client; + + @Mock + private DiscoveryNodeHelper nodeFilter; + + @Mock + private MLStats mlStats; + + @Mock + NamedXContentRegistry xContentRegistry; + + private UpdateModelCacheTransportAction action; + + private DiscoveryNode localNode; + + @Mock + private ModelAccessControlHelper modelAccessControlHelper; + + @Before + public void setUp() throws Exception { + action = new UpdateModelCacheTransportAction( + transportService, + actionFilters, + mlModelManager, + clusterService, + null, + client, + nodeFilter, + mlStats, + xContentRegistry, + modelAccessControlHelper + ); + + localNode = new DiscoveryNode( + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + when(clusterService.localNode()).thenReturn(localNode); + when(clusterService.getClusterName()).thenReturn(new ClusterName("Local Cluster")); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse("successful"); + return null; + }).when(mlModelManager).updateModelCache(any(), any(Boolean.class), any()); + } + + @Test + public void testNewResponses() { + final MLUpdateModelCacheNodesRequest nodesRequest = new MLUpdateModelCacheNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + "testModelId", + true + ); + Map modelUpdateModelCacheStatusMap = new HashMap<>(); + modelUpdateModelCacheStatusMap.put("modelName:version", "response"); + MLUpdateModelCacheNodeResponse response = new MLUpdateModelCacheNodeResponse(localNode, modelUpdateModelCacheStatusMap); + final List responses = List.of(response); + final List failures = new ArrayList<>(); + MLUpdateModelCacheNodesResponse response1 = action.newResponse(nodesRequest, responses, failures); + assertNotNull(response1); + } + + @Test + public void testNewNodeRequest() { + final MLUpdateModelCacheNodesRequest request = new MLUpdateModelCacheNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + "testModelId", + true + ); + final MLUpdateModelCacheNodeRequest updateModelCacheNodeRequest = action.newNodeRequest(request); + assertNotNull(updateModelCacheNodeRequest); + } + + @Test + public void testNewNodeStreamRequest() throws IOException { + Map updateModelCacheStatus = new HashMap<>(); + updateModelCacheStatus.put("modelId1", "response"); + MLUpdateModelCacheNodeResponse response = new MLUpdateModelCacheNodeResponse(localNode, updateModelCacheStatus); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + final MLUpdateModelCacheNodeResponse updateModelCacheNodeResponse = action.newNodeResponse(output.bytes().streamInput()); + assertNotNull(updateModelCacheNodeResponse); + } + + @Test + public void testNodeOperation() { + final MLUpdateModelCacheNodesRequest request = new MLUpdateModelCacheNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + "testModelId", + true + ); + final MLUpdateModelCacheNodeResponse response = action.nodeOperation(new MLUpdateModelCacheNodeRequest(request)); + assertNotNull(response); + } + + @Test + public void testNodeOperationException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("Test exception")); + return null; + }).when(mlModelManager).updateModelCache(any(), any(Boolean.class), any()); + final MLUpdateModelCacheNodesRequest request = new MLUpdateModelCacheNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + "testModelId", + true + ); + final MLUpdateModelCacheNodeResponse response = action.nodeOperation(new MLUpdateModelCacheNodeRequest(request)); + assertNotNull(response); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java index 28687d1c9c..10e78ecf2d 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.opensearch.ml.utils.TestHelper.toJsonString; import java.util.HashMap; import java.util.List; @@ -24,6 +25,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; @@ -32,6 +34,7 @@ import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; @@ -133,6 +136,40 @@ public void testUpdateModelRequestWithNullField() throws Exception { restMLUpdateModelAction.handleRequest(request, channel, client); } + @Test + public void testUpdateModelRequestWithConnectorIDAndConnectorUpdateContent() throws Exception { + exceptionRule.expect(OpenSearchStatusException.class); + exceptionRule + .expectMessage("Model cannot have both stand-alone connector and internal connector. Please check your update input body."); + RestRequest request = getRestRequestWithConnectorIDAndConnectorUpdateContent(); + restMLUpdateModelAction.handleRequest(request, channel, client); + } + + @Test + public void testUpdateModelRequestWithConnectorID() throws Exception { + RestRequest request = getRestRequestWithConnectorID(); + restMLUpdateModelAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLUpdateModelRequest.class); + verify(client, times(1)).execute(eq(MLUpdateModelAction.INSTANCE), argumentCaptor.capture(), any()); + MLUpdateModelInput updateModelInput = argumentCaptor.getValue().getUpdateModelInput(); + assertEquals("testModelName", updateModelInput.getName()); + assertEquals("testConnectorID", updateModelInput.getConnectorId()); + } + + @Test + public void testUpdateModelRequestWithConnectorUpdateContent() throws Exception { + RestRequest request = getRestRequestWithConnectorUpdateContent(); + restMLUpdateModelAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLUpdateModelRequest.class); + verify(client, times(1)).execute(eq(MLUpdateModelAction.INSTANCE), argumentCaptor.capture(), any()); + MLUpdateModelInput updateModelInput = argumentCaptor.getValue().getUpdateModelInput(); + assertEquals("testModelName", updateModelInput.getName()); + assertEquals( + "{\"description\":\"updated description\",\"version\":\"1\",\"parameters\":{},\"credential\":{}}", + toJsonString(updateModelInput.getConnectorUpdateContent()) + ); + } + private RestRequest getRestRequest() { RestRequest.Method method = RestRequest.Method.PUT; final Map modelContent = Map.of("name", "testModelName", "description", "This is test description"); @@ -188,4 +225,73 @@ private RestRequest getRestRequestWithNullField() { .build(); return request; } + + private RestRequest getRestRequestWithConnectorIDAndConnectorUpdateContent() { + RestRequest.Method method = RestRequest.Method.PUT; + MLCreateConnectorInput updateContent = MLCreateConnectorInput + .builder() + .updateConnector(true) + .version("1") + .description("updated description") + .build(); + final Map modelContent = Map + .of( + "name", + "testModelName", + "description", + "This is test description", + "connector_id", + "testConnectorID", + "connector_update_content", + updateContent + ); + String requestContent = new Gson().toJson(modelContent); + Map params = new HashMap<>(); + params.put("model_id", "test_modelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithConnectorID() { + RestRequest.Method method = RestRequest.Method.PUT; + final Map modelContent = Map + .of("name", "testModelName", "description", "This is test description", "connector_id", "testConnectorID"); + String requestContent = new Gson().toJson(modelContent); + Map params = new HashMap<>(); + params.put("model_id", "test_modelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithConnectorUpdateContent() { + RestRequest.Method method = RestRequest.Method.PUT; + MLCreateConnectorInput updateContent = MLCreateConnectorInput + .builder() + .updateConnector(true) + .version("1") + .description("updated description") + .build(); + final Map modelContent = Map + .of("name", "testModelName", "description", "This is test description", "connector_update_content", updateContent); + String requestContent = new Gson().toJson(modelContent); + Map params = new HashMap<>(); + params.put("model_id", "test_modelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } }