diff --git a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java index 9e2fbb7133..0ff6754e6b 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java @@ -39,6 +39,7 @@ public class MLModelGroup implements ToXContentObject { @Setter private String name; private String description; + @Setter private int latestVersion; private List backendRoles; private User owner; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java new file mode 100644 index 0000000000..2d584a0e73 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import org.opensearch.action.ActionType; +import org.opensearch.action.update.UpdateResponse; + +public class MLUpdateModelAction extends ActionType { + public static MLUpdateModelAction INSTANCE = new MLUpdateModelAction(); + public static final String NAME = "cluster:admin/opensearch/ml/models/update"; + + private MLUpdateModelAction() { + super(NAME, UpdateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java new file mode 100644 index 0000000000..ca0a2f70d4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import lombok.Data; +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; + +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.connector.Connector.createConnector; + +@Data +public class MLUpdateModelInput implements ToXContentObject, Writeable { + + public static final String MODEL_ID_FIELD = "model_id"; // mandatory + public static final String DESCRIPTION_FIELD = "description"; // optional + public static final String MODEL_VERSION_FIELD = "model_version"; // optional + public static final String MODEL_NAME_FIELD = "name"; // optional + public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional + public static final String MODEL_CONFIG_FIELD = "model_config"; // optional + public static final String CONNECTOR_ID_FIELD = "connector_id"; // optional + + @Getter + private String modelId; + private String description; + private String version; + private String name; + private String modelGroupId; + private MLModelConfig modelConfig; + private String connectorId; + + @Builder(toBuilder = true) + public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, MLModelConfig modelConfig, String connectorId) { + this.modelId = modelId; + this.description = description; + this.version = version; + this.name = name; + this.modelGroupId = modelGroupId; + this.modelConfig = modelConfig; + this.connectorId = connectorId; + } + + public MLUpdateModelInput(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.description = in.readOptionalString(); + this.version = in.readOptionalString(); + this.name = in.readOptionalString(); + this.modelGroupId = in.readOptionalString(); + if (in.readBoolean()) { + modelConfig = new TextEmbeddingModelConfig(in); + } + this.connectorId = in.readOptionalString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID_FIELD, modelId); + if (name != null) { + builder.field(MODEL_NAME_FIELD, name); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (version != null) { + builder.field(MODEL_VERSION_FIELD, version); + } + if (modelGroupId != null) { + builder.field(MODEL_GROUP_ID_FIELD, modelGroupId); + } + if (modelConfig != null) { + builder.field(MODEL_CONFIG_FIELD, modelConfig); + } + if (connectorId != null) { + builder.field(CONNECTOR_ID_FIELD, connectorId); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeOptionalString(description); + out.writeOptionalString(version); + out.writeOptionalString(name); + out.writeOptionalString(modelGroupId); + if (modelConfig != null) { + out.writeBoolean(true); + modelConfig.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(connectorId); + } + + public static MLUpdateModelInput parse(XContentParser parser) throws IOException { + String modelId = null; + String description = null; + String version = null; + String name = null; + String modelGroupId = null; + MLModelConfig modelConfig = null; + String connectorId = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case MODEL_ID_FIELD: + modelId = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case MODEL_NAME_FIELD: + name = parser.text(); + break; + case MODEL_VERSION_FIELD: + version = parser.text(); + break; + case MODEL_GROUP_ID_FIELD: + modelGroupId = parser.text(); + break; + case MODEL_CONFIG_FIELD: + modelConfig = TextEmbeddingModelConfig.parse(parser); + break; + case CONNECTOR_ID_FIELD: + connectorId = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + // Model ID can only be set through RestRequest. Model version can only be set automatically. + return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, modelConfig, connectorId); + } +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java new file mode 100644 index 0000000000..b589f71ed4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLUpdateModelRequest extends ActionRequest { + + MLUpdateModelInput updateModelInput; + + @Builder + public MLUpdateModelRequest(MLUpdateModelInput updateModelInput) { + this.updateModelInput = updateModelInput; + } + + public MLUpdateModelRequest(StreamInput in) throws IOException { + super(in); + updateModelInput = new MLUpdateModelInput(in); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (updateModelInput == null) { + exception = addValidationError("Update Model Input can't be null", exception); + } + + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + this.updateModelInput.writeTo(out); + } + + public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest){ + if (actionRequest instanceof MLUpdateModelRequest) { + return (MLUpdateModelRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput in = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLUpdateModelRequest(in); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLUpdateModelRequest", e); + } + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java new file mode 100644 index 0000000000..eaa1474709 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.function.Consumer; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; + +public class MLUpdateModelInputTest { + + private MLUpdateModelInput updateModelInput; + private final String expectedInputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + private final String expectedOutputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + private final String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"description\":\"description\",\"model_version\":\"2\",\"name\":\"name\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\",\"illegal_field\":\"This field need to be skipped.\"}"; + + @Before + public void setUp() throws Exception { + + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + updateModelInput = MLUpdateModelInput.builder() + .modelId("test-model_id") + .modelGroupId("modelGroupId") + .version("2") + .name("name") + .description("description") + .modelConfig(config) + .connectorId("test-connector_id") + .build(); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(updateModelInput, parsedInput -> { + assertEquals("test-model_id", parsedInput.getModelId()); + assertEquals(updateModelInput.getName(), parsedInput.getName()); + }); + } + + @Test + public void readInputStream_SuccessWithNullFields() throws IOException { + updateModelInput.setModelConfig(null); + readInputStream(updateModelInput, parsedInput -> { + assertNull(parsedInput.getModelConfig()); + }); + } + + @Test + public void testToXContent() throws Exception { + String jsonStr = serializationWithToXContent(updateModelInput); + assertEquals(expectedInputStr, jsonStr); + } + + @Test + public void testToXContent_Incomplete() throws Exception { + String expectedIncompleteInputStr = + "{\"model_id\":\"test-model_id\"}"; + updateModelInput.setDescription(null); + updateModelInput.setVersion(null); + updateModelInput.setName(null); + updateModelInput.setModelGroupId(null); + updateModelInput.setModelConfig(null); + updateModelInput.setConnectorId(null); + String jsonStr = serializationWithToXContent(updateModelInput); + assertEquals(expectedIncompleteInputStr, jsonStr); + } + + @Test + public void parse_Success() throws Exception { + testParseFromJsonString(expectedInputStr, parsedInput -> { + assertEquals("name", parsedInput.getName()); + }); + } + + @Test + public void parse_WithIllegalFieldWithoutModel() throws Exception { + testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + parser.nextToken(); + MLUpdateModelInput parsedInput = MLUpdateModelInput.parse(parser); + verify.accept(parsedInput); + } + + private void readInputStream(MLUpdateModelInput input, Consumer verify) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLUpdateModelInput parsedInput = new MLUpdateModelInput(streamInput); + verify.accept(parsedInput); + } + + private String serializationWithToXContent(MLUpdateModelInput input) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + return builder.toString(); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java new file mode 100644 index 0000000000..cadf865b1c --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import org.junit.Before; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; + +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + + +public class MLUpdateModelRequestTest { + + private MLUpdateModelRequest updateModelRequest; + + @Before + public void setUp(){ + MockitoAnnotations.openMocks(this); + + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + MLUpdateModelInput updateModelInput = MLUpdateModelInput.builder() + .modelId("test-model_id") + .modelGroupId("modelGroupId") + .name("name") + .description("description") + .modelConfig(config) + .build(); + + updateModelRequest = MLUpdateModelRequest.builder() + .updateModelInput(updateModelInput) + .build(); + + } + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + updateModelRequest.writeTo(bytesStreamOutput); + MLUpdateModelRequest parsedUpdateRequest = new MLUpdateModelRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals("test-model_id", parsedUpdateRequest.getUpdateModelInput().getModelId()); + assertEquals("name", parsedUpdateRequest.getUpdateModelInput().getName()); + } + + @Test + public void validate_Success() { + assertNull(updateModelRequest.validate()); + } + + @Test + public void validate_Exception_NullModelInput() { + MLUpdateModelRequest updateModelRequest = MLUpdateModelRequest.builder().build(); + Exception exception = updateModelRequest.validate(); + + assertEquals("Validation Failed: 1: Update Model Input can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success() { + assertSame(MLUpdateModelRequest.fromActionRequest(updateModelRequest), updateModelRequest); + } + + @Test + public void fromActionRequest_Success_fromActionRequest() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + updateModelRequest.writeTo(out); + } + }; + MLUpdateModelRequest request = MLUpdateModelRequest.fromActionRequest(actionRequest); + assertNotSame(request, updateModelRequest); + assertEquals(updateModelRequest.getUpdateModelInput().getName(), request.getUpdateModelInput().getName()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLUpdateModelRequest.fromActionRequest(actionRequest); + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java new file mode 100644 index 0000000000..ea4116c365 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -0,0 +1,324 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.models; + +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.FunctionName.REMOTE; +import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.transport.model.MLUpdateModelAction; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelGroupManager; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class UpdateModelTransportAction extends HandledTransportAction { + Client client; + ModelAccessControlHelper modelAccessControlHelper; + ConnectorAccessControlHelper connectorAccessControlHelper; + MLModelManager mlModelManager; + MLModelGroupManager mlModelGroupManager; + + @Inject + public UpdateModelTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ConnectorAccessControlHelper connectorAccessControlHelper, + ModelAccessControlHelper modelAccessControlHelper, + MLModelManager mlModelManager, + MLModelGroupManager mlModelGroupManager + ) { + super(MLUpdateModelAction.NAME, transportService, actionFilters, MLUpdateModelRequest::new); + this.client = client; + this.modelAccessControlHelper = modelAccessControlHelper; + this.connectorAccessControlHelper = connectorAccessControlHelper; + this.mlModelManager = mlModelManager; + this.mlModelGroupManager = mlModelGroupManager; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLUpdateModelRequest updateModelRequest = MLUpdateModelRequest.fromActionRequest(request); + MLUpdateModelInput updateModelInput = updateModelRequest.getUpdateModelInput(); + String modelId = updateModelInput.getModelId(); + User user = RestActionUtils.getUserContext(client); + + String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + mlModelManager.getModel(modelId, null, excludes, ActionListener.runBefore(ActionListener.wrap(mlModel -> { + FunctionName functionName = mlModel.getAlgorithm(); + MLModelState mlModelState = mlModel.getModelState(); + if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + if (hasPermission) { + if (!mlModelState.equals(MLModelState.LOADED) + && !mlModelState.equals(MLModelState.LOADING) + && !mlModelState.equals(MLModelState.PARTIALLY_LOADED) + && !mlModelState.equals(MLModelState.DEPLOYED) + && !mlModelState.equals(MLModelState.DEPLOYING) + && !mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED)) { + updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, actionListener); + } else { + actionListener + .onFailure( + new MLValidationException( + "ML Model " + + modelId + + " is in deploying or deployed state, please undeploy the models first!" + ) + ); + } + } else { + actionListener + .onFailure( + new MLValidationException( + "User doesn't have privilege to perform this operation on this model, model ID " + modelId + ) + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); + actionListener.onFailure(exception); + })); + } else { + actionListener + .onFailure( + new MLValidationException( + "User doesn't have privilege to perform this operation on this function category: " + + functionName.toString() + ) + ); + } + }, + e -> actionListener + .onFailure(new MLResourceNotFoundException("Failed to find model to update with the provided model id: " + modelId)) + ), () -> context.restore())); + } catch (Exception e) { + log.error("Failed to update ML model for " + modelId, e); + actionListener.onFailure(e); + } + } + + private void updateRemoteOrTextEmbeddingModel( + String modelId, + MLUpdateModelInput updateModelInput, + MLModel mlModel, + User user, + ActionListener actionListener + ) { + String newModelGroupId = Strings.hasLength(updateModelInput.getModelGroupId()) ? updateModelInput.getModelGroupId() : null; + String relinkConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null; + + if (mlModel.getAlgorithm() == TEXT_EMBEDDING) { + if (relinkConnectorId == null) { + updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateModelInput, actionListener); + } else { + actionListener + .onFailure(new IllegalArgumentException("Trying to update the connector or connector_id field on a local model")); + } + } else { + // mlModel.getAlgorithm() == REMOTE + if (relinkConnectorId == null) { + updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateModelInput, actionListener); + } else { + updateModelWithRelinkStandAloneConnector( + modelId, + newModelGroupId, + relinkConnectorId, + mlModel, + user, + updateModelInput, + actionListener + ); + } + } + } + + private void updateModelWithRelinkStandAloneConnector( + String modelId, + String newModelGroupId, + String relinkConnectorId, + MLModel mlModel, + User user, + MLUpdateModelInput updateModelInput, + ActionListener actionListener + ) { + if (Strings.hasLength(mlModel.getConnectorId())) { + connectorAccessControlHelper + .validateConnectorAccess(client, relinkConnectorId, ActionListener.wrap(hasRelinkConnectorPermission -> { + if (hasRelinkConnectorPermission) { + updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateModelInput, actionListener); + } else { + actionListener + .onFailure( + new MLValidationException( + "You don't have permission to update the connector, connector id: " + relinkConnectorId + ) + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", relinkConnectorId, exception); + actionListener.onFailure(exception); + })); + } else { + actionListener + .onFailure( + new IllegalArgumentException("This remote does not have a connector_id field, maybe it uses an internal connector.") + ); + } + } + + private void updateModelWithRegisteringToAnotherModelGroup( + String modelId, + String newModelGroupId, + User user, + MLUpdateModelInput updateModelInput, + ActionListener actionListener + ) { + UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); + if (newModelGroupId != null) { + modelAccessControlHelper.validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasRelinkPermission -> { + if (hasRelinkPermission) { + mlModelGroupManager.getModelGroup(newModelGroupId, ActionListener.wrap(newModelGroup -> { + updateRequestConstructor(modelId, updateRequest, updateModelInput, newModelGroup, actionListener); + }, + exception -> actionListener + .onFailure( + new MLResourceNotFoundException( + "Failed to find the model group with the provided model group id in the update model input, MODEL_GROUP_ID: " + + newModelGroupId + ) + ) + )); + } else { + actionListener + .onFailure( + new MLValidationException( + "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID " + + newModelGroupId + ) + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); + actionListener.onFailure(exception); + })); + } else { + updateRequestConstructor(modelId, updateRequest, updateModelInput, actionListener); + } + } + + private void updateRequestConstructor( + String modelId, + UpdateRequest updateRequest, + MLUpdateModelInput updateModelInput, + ActionListener actionListener + ) { + try { + updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + updateRequest.docAsUpsert(true); + client.update(updateRequest, getUpdateResponseListener(modelId, actionListener)); + } catch (IOException e) { + log.error("Failed to build update request."); + actionListener.onFailure(e); + } + } + + private void updateRequestConstructor( + String modelId, + UpdateRequest updateRequest, + MLUpdateModelInput updateModelInput, + MLModelGroup newModelGroup, + ActionListener actionListener + ) { + String updatedVersion = incrementLatestVersion(newModelGroup); + updateModelInput.setVersion(updatedVersion); + try { + updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + updateRequest.docAsUpsert(true); + client.update(updateRequest, getUpdateResponseListener(modelId, newModelGroup, updatedVersion, actionListener)); + } catch (IOException e) { + log.error("Failed to build update request."); + actionListener.onFailure(e); + } + } + + private ActionListener getUpdateResponseListener(String modelId, ActionListener actionListener) { + return ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { + log.info("Model id:{} failed update", modelId); + actionListener.onResponse(updateResponse); + return; + } + log.info("Completed Update Model Request, model id:{} updated", modelId); + actionListener.onResponse(updateResponse); + }, exception -> { + log.error("Failed to update ML model: " + modelId, exception); + actionListener.onFailure(exception); + }); + } + + private ActionListener getUpdateResponseListener( + String modelId, + MLModelGroup newModelGroup, + String updatedVersion, + ActionListener actionListener + ) { + return ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { + log.info("Model id:{} failed update", modelId); + actionListener.onResponse(updateResponse); + return; + } + log.info("Completed Update Model Request, model id:{} updated", modelId); + newModelGroup.setLatestVersion(Integer.parseInt(updatedVersion)); + actionListener.onResponse(updateResponse); + }, exception -> { + log.error("Failed to update ML model: " + modelId, exception); + actionListener.onFailure(exception); + }); + } + + private String incrementLatestVersion(MLModelGroup mlModelGroup) { + return Integer.toString(mlModelGroup.getLatestVersion() + 1); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index 94cbcf5364..a28185d7dc 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -5,12 +5,14 @@ package org.opensearch.ml.model; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import java.time.Instant; import java.util.HashSet; import java.util.Iterator; +import org.opensearch.action.get.GetRequest; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -23,16 +25,20 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; @@ -207,6 +213,34 @@ public void validateUniqueModelGroupName(String name, ActionListener listener) { + GetRequest getRequest = new GetRequest(); + getRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + listener.onResponse(mlModelGroup); + } catch (Exception e) { + log.error("Failed to parse ml model group.", e); + listener.onFailure(e); + } + } else { + listener.onFailure(new MLResourceNotFoundException("Failed to find model group with ID: " + modelGroupId)); + } + }, e -> { listener.onFailure(e); })); + } + private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) { if (input.getModelAccessMode() != null || input.getIsAddAllBackendRoles() != null diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index ee1213c057..b4787cc5ed 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -53,6 +53,7 @@ import org.opensearch.ml.action.models.DeleteModelTransportAction; import org.opensearch.ml.action.models.GetModelTransportAction; import org.opensearch.ml.action.models.SearchModelTransportAction; +import org.opensearch.ml.action.models.UpdateModelTransportAction; import org.opensearch.ml.action.prediction.TransportPredictionTaskAction; import org.opensearch.ml.action.profile.MLProfileAction; import org.opensearch.ml.action.profile.MLProfileTransportAction; @@ -100,6 +101,7 @@ import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction; @@ -166,6 +168,7 @@ import org.opensearch.ml.rest.RestMLTrainingAction; import org.opensearch.ml.rest.RestMLUndeployModelAction; import org.opensearch.ml.rest.RestMLUpdateConnectorAction; +import org.opensearch.ml.rest.RestMLUpdateModelAction; import org.opensearch.ml.rest.RestMLUpdateModelGroupAction; import org.opensearch.ml.rest.RestMLUploadModelChunkAction; import org.opensearch.ml.rest.RestMemoryCreateConversationAction; @@ -282,6 +285,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(MLUndeployModelsAction.INSTANCE, TransportUndeployModelsAction.class), new ActionHandler<>(MLRegisterModelMetaAction.INSTANCE, TransportRegisterModelMetaAction.class), new ActionHandler<>(MLUploadModelChunkAction.INSTANCE, TransportUploadModelChunkAction.class), + new ActionHandler<>(MLUpdateModelAction.INSTANCE, UpdateModelTransportAction.class), new ActionHandler<>(MLForwardAction.INSTANCE, TransportForwardAction.class), new ActionHandler<>(MLSyncUpAction.INSTANCE, TransportSyncUpOnNodeAction.class), new ActionHandler<>(MLRegisterModelGroupAction.INSTANCE, TransportRegisterModelGroupAction.class), @@ -537,6 +541,7 @@ public List getRestHandlers( RestMLRegisterModelGroupAction restMLCreateModelGroupAction = new RestMLRegisterModelGroupAction(); RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction(); RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction(); + RestMLUpdateModelAction restMLUpdateModelAction = new RestMLUpdateModelAction(); RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(); RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting); RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(); @@ -558,6 +563,7 @@ public List getRestHandlers( restMLGetModelAction, restMLDeleteModelAction, restMLSearchModelAction, + restMLUpdateModelAction, restMLGetTaskAction, restMLDeleteTaskAction, restMLSearchTaskAction, diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java new file mode 100644 index 0000000000..002a899f5d --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.transport.model.MLUpdateModelAction; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +public class RestMLUpdateModelAction extends BaseRestHandler { + + private static final String ML_UPDATE_MODEL_ACTION = "ml_update_model_action"; + + @Override + public String getName() { + return ML_UPDATE_MODEL_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of(new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/models/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLUpdateModelRequest updateModelRequest = getRequest(request); + return channel -> client.execute(MLUpdateModelAction.INSTANCE, updateModelRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLUpdateModelRequest from a RestRequest + * + * @param request RestRequest + * @return MLUpdateModelRequest + */ + private MLUpdateModelRequest getRequest(RestRequest request) throws IOException { + if (!request.hasContent()) { + throw new IOException("Model update request has empty body"); + } + + String modelId = getParameterId(request, PARAMETER_MODEL_ID); + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLUpdateModelInput input = MLUpdateModelInput.parse(parser); + // Model ID can only be set here. Model version can only be set automatically. + input.setModelId(modelId); + input.setVersion(null); + return new MLUpdateModelRequest(input); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java index fc6020474a..974d5ed1ca 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java @@ -21,6 +21,7 @@ import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -152,6 +153,7 @@ public void setup() throws IOException { updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); } + @Test public void test_execute_connectorAccessControl_success() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -175,6 +177,7 @@ public void test_execute_connectorAccessControl_success() { verify(actionListener).onResponse(updateResponse); } + @Test public void test_execute_connectorAccessControl_NoPermission() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -191,6 +194,7 @@ public void test_execute_connectorAccessControl_NoPermission() { ); } + @Test public void test_execute_connectorAccessControl_AccessError() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -204,6 +208,7 @@ public void test_execute_connectorAccessControl_AccessError() { assertEquals("Connector Access Control Error", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_connectorAccessControl_Exception() { doThrow(new RuntimeException("exception in access control")) .when(connectorAccessControlHelper) @@ -215,6 +220,7 @@ public void test_execute_connectorAccessControl_Exception() { assertEquals("exception in access control", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_UpdateWrongStatus() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -239,6 +245,7 @@ public void test_execute_UpdateWrongStatus() { verify(actionListener).onResponse(updateResponse); } + @Test public void test_execute_UpdateException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -264,6 +271,7 @@ public void test_execute_UpdateException() { assertEquals("update document failure", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_SearchResponseNotEmpty() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -283,6 +291,7 @@ public void test_execute_SearchResponseNotEmpty() { assertEquals("1 models are still using this connector, please undeploy the models first!", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_SearchResponseError() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -302,6 +311,7 @@ public void test_execute_SearchResponseError() { assertEquals("Error in Search Request", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_SearchIndexNotFoundError() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java new file mode 100644 index 0000000000..11fcc75b10 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -0,0 +1,804 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.models; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelGroupManager; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class UpdateModelTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + Task task; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + MLUpdateModelInput mockUpdateModelInput; + + @Mock + MLUpdateModelRequest mockUpdateModelRequest; + + @Mock + MLModel mockModel; + + @Mock + MLModelGroup mockModelGroup; + + @Mock + MLModelManager mlModelManager; + + @Mock + MLModelGroupManager mlModelGroupManager; + + @Mock + private ModelAccessControlHelper modelAccessControlHelper; + + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private ShardId shardId; + + UpdateResponse updateResponse; + + UpdateModelTransportAction transportUpdateModelAction; + + MLUpdateModelRequest updateLocalModelRequest; + + MLUpdateModelInput updateLocalModelInput; + + MLUpdateModelRequest updateRemoteModelRequest; + + MLUpdateModelInput updateRemoteModelInput; + + MLModel mlModelWithNullFunctionName; + + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + updateLocalModelInput = MLUpdateModelInput + .builder() + .modelId("test_model_id") + .name("updated_test_name") + .description("updated_test_description") + .modelGroupId("updated_test_model_group_id") + .build(); + updateLocalModelRequest = MLUpdateModelRequest.builder().updateModelInput(updateLocalModelInput).build(); + updateRemoteModelInput = MLUpdateModelInput + .builder() + .modelId("test_model_id") + .name("updated_test_name") + .description("updated_test_description") + .modelGroupId("updated_test_model_group_id") + .connectorId("updated_test_connector_id") + .build(); + updateRemoteModelRequest = MLUpdateModelRequest.builder().updateModelInput(updateRemoteModelInput).build(); + + mlModelWithNullFunctionName = MLModel + .builder() + .modelId("test_model_id") + .name("test_name") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .build(); + + Settings settings = Settings.builder().build(); + + transportUpdateModelAction = spy( + new UpdateModelTransportAction( + transportService, + actionFilters, + client, + connectorAccessControlHelper, + modelAccessControlHelper, + mlModelManager, + mlModelGroupManager + ) + ); + + MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + shardId = new ShardId(new Index("indexName", "uuid"), 1); + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("test_model_group_id"), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }) + .when(connectorAccessControlHelper) + .validateConnectorAccess(any(Client.class), eq("updated_test_connector_id"), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(localModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockModelGroup); + return null; + }).when(mlModelGroupManager).getModelGroup(eq("updated_test_model_group_id"), isA(ActionListener.class)); + } + + @Test + public void testUpdateLocalModelSuccess() { + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelStateLoadedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.LOADED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStateLoadingException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.LOADING).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStatePartiallyLoadedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.PARTIALLY_LOADED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStateDeployedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.DEPLOYED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStateDeployingException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.DEPLOYING).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStatePartiallyDeployedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.PARTIALLY_DEPLOYED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithoutRegisterToNewModelGroupSuccess() { + updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateRemoteModelWithLocalInformationSuccess() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateRemoteModelWithRemoteInformationSuccess() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateRemoteModelWithNoStandAloneConnectorFound() { + MLModel remoteModelWithInternalConnector = prepareUnsupportedMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModelWithInternalConnector); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "This remote does not have a connector_id field, maybe it uses an internal connector.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlNoPermission() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(false); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "You don't have permission to update the connector, connector id: updated_test_connector_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlOtherException() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener + .onFailure( + new RuntimeException("Any other connector access control Exception occurred. Please check log for more details.") + ); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other connector access control Exception occurred. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithModelAccessControlNoPermission() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this model, model ID test_model_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithModelAccessControlOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener + .onFailure( + new RuntimeException( + "Any other model access control Exception occurred during update the model. Please check log for more details." + ) + ); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other model access control Exception occurred during update the model. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlNoPermission() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID updated_test_model_group_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener + .onFailure( + new RuntimeException( + "Any other model access control Exception occurred during re-linking the model group. Please check log for more details." + ) + ); + return null; + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other model access control Exception occurred during re-linking the model group. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithRegisterToNewModelGroupNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new MLResourceNotFoundException("Model group not found with MODEL_GROUP_ID: updated_test_model_group_id")); + return null; + }).when(mlModelGroupManager).getModelGroup(eq("updated_test_model_group_id"), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Failed to find the model group with the provided model group id in the update model input, MODEL_GROUP_ID: updated_test_model_group_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithModelNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(null); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find model to update with the provided model id: test_model_id", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateModelWithFunctionNameFieldNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModelWithNullFunctionName); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + } + + @Test + public void testUpdateLocalModelWithRemoteInformation() { + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Trying to update the connector or connector_id field on a local model", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateLocalModelWithUnsupportedFunction() { + MLModel localModelWithUnsupportedFunction = prepareUnsupportedMLModel(FunctionName.KMEANS); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(localModelWithUnsupportedFunction); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this function category: KMEANS", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateRequestDocIOException() throws IOException { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.REGISTERED).when(mockModel).getModelState(); + + doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IOException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred during building update request.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IOException { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.REGISTERED).when(mockModel).getModelState(); + + doReturn("mockUpdateModelGroupId").when(mockUpdateModelInput).getModelGroupId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("mockUpdateModelGroupId"), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockModelGroup); + return null; + }).when(mlModelGroupManager).getModelGroup(eq("mockUpdateModelGroupId"), isA(ActionListener.class)); + + doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IOException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred during building update request.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testGetUpdateResponseListenerWithVersionBumpWrongStatus() { + UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateWrongResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateWrongResponse); + } + + @Test + public void testGetUpdateResponseListenerWithVersionBumpOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new RuntimeException( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details." + ) + ); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testGetUpdateResponseListenerWrongStatus() { + UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateWrongResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateWrongResponse); + } + + @Test + public void testGetUpdateResponseListenerOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new RuntimeException( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details." + ) + ); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + private MLModel prepareMLModel(FunctionName functionName) throws IllegalArgumentException { + MLModel mlModel; + switch (functionName) { + case TEXT_EMBEDDING: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.TEXT_EMBEDDING) + .build(); + return mlModel; + case REMOTE: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.REMOTE) + .connectorId("test_connector_id") + .build(); + return mlModel; + default: + throw new IllegalArgumentException("Please choose from FunctionName.TEXT_EMBEDDING and FunctionName.REMOTE"); + } + } + + private MLModel prepareUnsupportedMLModel(FunctionName unsupportedCase) throws IllegalArgumentException { + MLModel mlModel; + switch (unsupportedCase) { + case REMOTE: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.REMOTE) + .connector(HttpConnector.builder().name("test_connector").protocol("http").build()) + .build(); + return mlModel; + case KMEANS: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.KMEANS) + .build(); + return mlModel; + default: + throw new IllegalArgumentException("Please choose from FunctionName.REMOTE and FunctionName.KMEANS"); + } + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index f7eb759026..5425775331 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -23,35 +23,37 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; -import org.opensearch.ml.action.model_group.TransportRegisterModelGroupAction; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.TestHelper; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; -import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; public class MLModelGroupManagerTests extends OpenSearchTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - @Mock - private TransportService transportService; - @Mock private MLIndicesHandler mlIndicesHandler; @@ -61,26 +63,23 @@ public class MLModelGroupManagerTests extends OpenSearchTestCase { @Mock private ThreadPool threadPool; - @Mock - private Task task; - @Mock private Client client; - @Mock - private ActionFilters actionFilters; @Mock private ActionListener actionListener; + @Mock + private ActionListener modelGroupListener; + @Mock private IndexResponse indexResponse; ThreadContext threadContext; - private TransportRegisterModelGroupAction transportRegisterModelGroupAction; - @Mock private ModelAccessControlHelper modelAccessControlHelper; + @Mock private MLModelGroupManager mlModelGroupManager; @@ -335,6 +334,62 @@ public void test_ExceptionInitModelGroupIndexIfAbsent() { assertEquals("Index Not Found", argumentCaptor.getValue().getMessage()); } + public void test_SuccessGetModelGroup() throws IOException { + MLModelGroup modelGroup = MLModelGroup + .builder() + .modelGroupId("testModelGroupID") + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); + + GetResponse getResponse = prepareGetResponse(modelGroup); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + mlModelGroupManager.getModelGroup("testModelGroupID", modelGroupListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelGroup.class); + verify(modelGroupListener).onResponse(argumentCaptor.capture()); + } + + public void test_OtherExceptionGetModelGroup() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new RuntimeException("Any other Exception occurred during getting the model group. Please check log for more details.") + ); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + mlModelGroupManager.getModelGroup("testModelGroupID", modelGroupListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(modelGroupListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during getting the model group. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + public void test_NotFoundGetModelGroup() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + mlModelGroupManager.getModelGroup("testModelGroupID", modelGroupListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(modelGroupListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find model group with ID: testModelGroupID", argumentCaptor.getValue().getMessage()); + } + private MLRegisterModelGroupInput prepareRequest(List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { return MLRegisterModelGroupInput .builder() @@ -363,4 +418,10 @@ private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOE return searchResponse; } + private GetResponse prepareGetResponse(MLModelGroup mlModelGroup) throws IOException { + XContentBuilder content = mlModelGroup.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + return new GetResponse(getResult); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java new file mode 100644 index 0000000000..e4511df9dc --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java @@ -0,0 +1,169 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.model.MLUpdateModelAction; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import com.google.gson.Gson; + +public class RestMLUpdateModelActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private RestMLUpdateModelAction restMLUpdateModelAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + restMLUpdateModelAction = new RestMLUpdateModelAction(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLUpdateModelAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + @Test + public void testConstructor() { + RestMLUpdateModelAction UpdateModelAction = new RestMLUpdateModelAction(); + assertNotNull(UpdateModelAction); + } + + @Test + public void testGetName() { + String actionName = restMLUpdateModelAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_update_model_action", actionName); + } + + @Test + public void testRoutes() { + List routes = restMLUpdateModelAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/models/{model_id}", route.getPath()); + } + + @Test + public void testUpdateModelRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLUpdateModelAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLUpdateModelRequest.class); + verify(client, times(1)).execute(eq(MLUpdateModelAction.INSTANCE), argumentCaptor.capture(), any()); + MLUpdateModelInput updateModelInput = argumentCaptor.getValue().getUpdateModelInput(); + assertEquals("testModelName", updateModelInput.getName()); + assertEquals("This is test description", updateModelInput.getDescription()); + } + + @Test + public void testUpdateModelRequestWithEmptyContent() throws Exception { + exceptionRule.expect(IOException.class); + exceptionRule.expectMessage("Model update request has empty body"); + RestRequest request = getRestRequestWithEmptyContent(); + restMLUpdateModelAction.handleRequest(request, channel, client); + } + + @Test + public void testUpdateModelRequestWithNullModelId() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Request should contain model_id"); + RestRequest request = getRestRequestWithNullModelId(); + restMLUpdateModelAction.handleRequest(request, channel, client); + } + + private RestRequest getRestRequest() { + RestRequest.Method method = RestRequest.Method.PUT; + final Map modelContent = Map.of("name", "testModelName", "description", "This is test description"); + String requestContent = new Gson().toJson(modelContent); + Map params = new HashMap<>(); + params.put("model_id", "test_modelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithEmptyContent() { + RestRequest.Method method = RestRequest.Method.PUT; + Map params = new HashMap<>(); + params.put("model_id", "test_modelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(""), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullModelId() { + RestRequest.Method method = RestRequest.Method.PUT; + final Map modelContent = Map.of("name", "testModelName", "description", "This is test description"); + String requestContent = new Gson().toJson(modelContent); + Map params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } +}