diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 84c3a96712..f82e742866 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -37,7 +37,7 @@ public class CommonValue { public static final String ML_MODEL_INDEX = ".plugins-ml-model"; public static final String ML_TASK_INDEX = ".plugins-ml-task"; public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; - public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 7; + public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 8; public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2; @@ -186,6 +186,9 @@ public class CommonValue { + MLModel.DEPLOY_TO_ALL_NODES_FIELD + "\": {\"type\": \"boolean\"},\n" + " \"" + + MLModel.IS_HIDDEN_FIELD + + "\": {\"type\": \"boolean\"},\n" + + " \"" + MLModel.MODEL_CONFIG_FIELD + "\" : {\"properties\":{\"" + MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index 78f6f4ac60..fef1af1196 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -75,6 +75,8 @@ public class MLModel implements ToXContentObject { public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count"; public static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes"; public static final String DEPLOY_TO_ALL_NODES_FIELD = "deploy_to_all_nodes"; + + public static final String IS_HIDDEN_FIELD = "is_hidden"; public static final String CONNECTOR_FIELD = "connector"; public static final String CONNECTOR_ID_FIELD = "connector_id"; @@ -110,6 +112,9 @@ public class MLModel implements ToXContentObject { private String[] planningWorkerNodes; // plan to deploy model to these nodes private boolean deployToAllNodes; + //is domain manager creates any special hidden model in the cluster this status will be true. Otherwise, + // False by default + private Boolean isHidden; @Setter private Connector connector; private String connectorId; @@ -139,6 +144,7 @@ public MLModel(String name, Integer currentWorkerNodeCount, String[] planningWorkerNodes, boolean deployToAllNodes, + Boolean isHidden, Connector connector, String connectorId) { this.name = name; @@ -166,6 +172,7 @@ public MLModel(String name, this.currentWorkerNodeCount = currentWorkerNodeCount; this.planningWorkerNodes = planningWorkerNodes; this.deployToAllNodes = deployToAllNodes; + this.isHidden = isHidden; this.connector = connector; this.connectorId = connectorId; } @@ -210,6 +217,7 @@ public MLModel(StreamInput input) throws IOException{ currentWorkerNodeCount = input.readOptionalInt(); planningWorkerNodes = input.readOptionalStringArray(); deployToAllNodes = input.readBoolean(); + isHidden = input.readOptionalBoolean(); modelGroupId = input.readOptionalString(); if (input.readBoolean()) { connector = Connector.fromStream(input); @@ -263,6 +271,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalInt(currentWorkerNodeCount); out.writeOptionalStringArray(planningWorkerNodes); out.writeBoolean(deployToAllNodes); + out.writeOptionalBoolean(isHidden); out.writeOptionalString(modelGroupId); if (connector != null) { out.writeBoolean(true); @@ -351,6 +360,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (deployToAllNodes) { builder.field(DEPLOY_TO_ALL_NODES_FIELD, deployToAllNodes); } + if (isHidden != null) { + builder.field(MLModel.IS_HIDDEN_FIELD, isHidden); + } if (connector != null) { builder.field(CONNECTOR_FIELD, connector); } @@ -393,6 +405,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws Integer currentWorkerNodeCount = null; List planningWorkerNodes = new ArrayList<>(); boolean deployToAllNodes = false; + boolean isHidden = false; Connector connector = null; String connectorId = null; @@ -476,6 +489,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws case DEPLOY_TO_ALL_NODES_FIELD: deployToAllNodes = parser.booleanValue(); break; + case IS_HIDDEN_FIELD: + isHidden = parser.booleanValue(); + break; case CONNECTOR_FIELD: connector = createConnector(parser); break; @@ -537,6 +553,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws .currentWorkerNodeCount(currentWorkerNodeCount) .planningWorkerNodes(planningWorkerNodes.toArray(new String[0])) .deployToAllNodes(deployToAllNodes) + .isHidden(isHidden) .connector(connector) .connectorId(connectorId) .build(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java index 71c5200971..7cad570f1d 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java @@ -31,17 +31,23 @@ public class MLModelGetRequest extends ActionRequest { String modelId; boolean returnContent; + // This is to identify if the get request is initiated by user or not. Sometimes during + // delete/update options, we also perform get operation. This field is to distinguish between + // these two situations. + boolean isUserInitiatedGetRequest; @Builder - public MLModelGetRequest(String modelId, boolean returnContent) { + public MLModelGetRequest(String modelId, boolean returnContent, boolean isUserInitiatedGetRequest) { this.modelId = modelId; this.returnContent = returnContent; + this.isUserInitiatedGetRequest = isUserInitiatedGetRequest; } public MLModelGetRequest(StreamInput in) throws IOException { super(in); this.modelId = in.readString(); this.returnContent = in.readBoolean(); + this.isUserInitiatedGetRequest = in.readBoolean(); } @Override @@ -49,6 +55,7 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(this.modelId); out.writeBoolean(returnContent); + out.writeBoolean(isUserInitiatedGetRequest); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index a871332f95..e8866bc8e4 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -15,6 +15,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; @@ -25,6 +26,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -42,7 +44,6 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final String DESCRIPTION_FIELD = "description"; public static final String VERSION_FIELD = "version"; public static final String URL_FIELD = "url"; - public static final String HASH_VALUE_FIELD = "model_content_hash_value"; public static final String MODEL_FORMAT_FIELD = "model_format"; public static final String MODEL_CONFIG_FIELD = "model_config"; public static final String DEPLOY_MODEL_FIELD = "deploy_model"; @@ -75,6 +76,8 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private AccessMode accessMode; private Boolean doesVersionCreateModelGroup; + private Boolean isHidden; + @Builder(toBuilder = true) public MLRegisterModelInput(FunctionName functionName, String modelName, @@ -92,13 +95,10 @@ public MLRegisterModelInput(FunctionName functionName, List backendRoles, Boolean addAllBackendRoles, AccessMode accessMode, - Boolean doesVersionCreateModelGroup + Boolean doesVersionCreateModelGroup, + Boolean isHidden ) { - if (functionName == null) { - this.functionName = FunctionName.TEXT_EMBEDDING; - } else { - this.functionName = functionName; - } + this.functionName = Objects.requireNonNullElse(functionName, FunctionName.TEXT_EMBEDDING); if (modelName == null) { throw new IllegalArgumentException("model name is null"); } @@ -126,6 +126,7 @@ public MLRegisterModelInput(FunctionName functionName, this.addAllBackendRoles = addAllBackendRoles; this.accessMode = accessMode; this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; + this.isHidden = isHidden; } @@ -161,6 +162,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException { this.accessMode = in.readEnum(AccessMode.class); } this.doesVersionCreateModelGroup = in.readOptionalBoolean(); + this.isHidden = in.readOptionalBoolean(); } @Override @@ -207,6 +209,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalBoolean(doesVersionCreateModelGroup); + out.writeOptionalBoolean(isHidden); } @Override @@ -227,7 +230,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(URL_FIELD, url); } if (hashValue != null) { - builder.field(HASH_VALUE_FIELD, hashValue); + builder.field(MODEL_CONTENT_HASH_VALUE_FIELD, hashValue); } if (modelFormat != null) { builder.field(MODEL_FORMAT_FIELD, modelFormat); @@ -257,6 +260,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (doesVersionCreateModelGroup != null) { builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup); } + if (isHidden != null) { + builder.field(MLModel.IS_HIDDEN_FIELD, isHidden); + } builder.endObject(); return builder; } @@ -276,6 +282,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName Boolean addAllBackendRoles = null; AccessMode accessMode = null; Boolean doesVersionCreateModelGroup = null; + Boolean isHidden = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -291,7 +298,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case URL_FIELD: url = parser.text(); break; - case HASH_VALUE_FIELD: + case MODEL_CONTENT_HASH_VALUE_FIELD: hashValue = parser.text(); break; case DESCRIPTION_FIELD: @@ -324,6 +331,9 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case ADD_ALL_BACKEND_ROLES_FIELD: addAllBackendRoles = parser.booleanValue(); break; + case MLModel.IS_HIDDEN_FIELD: + isHidden = parser.booleanValue(); + break; case ACCESS_MODE_FIELD: accessMode = AccessMode.from(parser.text()); break; @@ -335,7 +345,8 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName break; } } - return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); + return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden); + } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -355,6 +366,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo AccessMode accessMode = null; Boolean addAllBackendRoles = null; Boolean doesVersionCreateModelGroup = null; + Boolean isHidden = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -383,7 +395,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case CONNECTOR_FIELD: connector = createConnector(parser); break; - case HASH_VALUE_FIELD: + case MODEL_CONTENT_HASH_VALUE_FIELD: hashValue = parser.text(); break; case CONNECTOR_ID_FIELD: @@ -416,11 +428,14 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case DOES_VERSION_CREATE_MODEL_GROUP: doesVersionCreateModelGroup = parser.booleanValue(); break; + case MLModel.IS_HIDDEN_FIELD: + isHidden = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); + return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java index ecb03d9bb6..60e74da150 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java @@ -16,6 +16,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; @@ -68,12 +69,13 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ private AccessMode accessMode; private Boolean isAddAllBackendRoles; private Boolean doesVersionCreateModelGroup; + private Boolean isHidden; @Builder(toBuilder = true) public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List backendRoles, AccessMode accessMode, Boolean isAddAllBackendRoles, - Boolean doesVersionCreateModelGroup) { + Boolean doesVersionCreateModelGroup, Boolean isHidden) { if (name == null) { throw new IllegalArgumentException("model name is null"); } @@ -108,6 +110,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m this.accessMode = accessMode; this.isAddAllBackendRoles = isAddAllBackendRoles; this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; + this.isHidden = isHidden; } public MLRegisterModelMetaInput(StreamInput in) throws IOException{ @@ -134,6 +137,7 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{ } this.isAddAllBackendRoles = in.readOptionalBoolean(); this.doesVersionCreateModelGroup = in.readOptionalBoolean(); + this.isHidden = in.readOptionalBoolean(); } @Override @@ -178,21 +182,22 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeOptionalBoolean(isAddAllBackendRoles); out.writeOptionalBoolean(doesVersionCreateModelGroup); + out.writeOptionalBoolean(isHidden); } @Override public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(); - builder.field(MODEL_NAME_FIELD, name); - builder.field(FUNCTION_NAME_FIELD, functionName); + builder.field(MLModel.MODEL_NAME_FIELD, name); + builder.field(MLModel.FUNCTION_NAME_FIELD, functionName); if (modelGroupId != null) { - builder.field(MODEL_GROUP_ID_FIELD, modelGroupId); + builder.field(MLModel.MODEL_GROUP_ID_FIELD, modelGroupId); } if (version != null) { builder.field(VERSION_FIELD, version); } if (description != null) { - builder.field(DESCRIPTION_FIELD, description); + builder.field(MLModel.DESCRIPTION_FIELD, description); } builder.field(MODEL_FORMAT_FIELD, modelFormat); if (modelState != null) { @@ -216,6 +221,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (doesVersionCreateModelGroup != null) { builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup); } + if (isHidden != null) { + builder.field(MLModel.IS_HIDDEN_FIELD, isHidden); + } builder.endObject(); return builder; } @@ -236,6 +244,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc AccessMode accessMode = null; Boolean isAddAllBackendRoles = null; Boolean doesVersionCreateModelGroup = null; + Boolean isHidden = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -291,12 +300,15 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc case DOES_VERSION_CREATE_MODEL_GROUP: doesVersionCreateModelGroup = parser.booleanValue(); break; + case MLModel.IS_HIDDEN_FIELD: + isHidden = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup); + return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup, isHidden); } } diff --git a/common/src/test/java/org/opensearch/ml/common/MLModelTests.java b/common/src/test/java/org/opensearch/ml/common/MLModelTests.java index b59fa2ac2a..b493ea29e9 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLModelTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLModelTests.java @@ -59,6 +59,7 @@ public void setUp() { .modelId("model_id") .chunkNumber(1) .totalChunks(10) + .isHidden(false) .build(); function = parser -> { try { @@ -71,11 +72,11 @@ public void setUp() { @Test public void toXContent() throws IOException { - MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("model_name").version("1.0.0").content("test_content").build(); + MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("model_name").version("1.0.0").content("test_content").isHidden(true).build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"model_version\":\"1.0.0\",\"model_content\":\"test_content\"}", mlModelContent); + assertEquals("{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"model_version\":\"1.0.0\",\"model_content\":\"test_content\",\"is_hidden\":true}", mlModelContent); } @Test @@ -111,5 +112,15 @@ public void readInputStream(MLModel mlModel) throws IOException { assertEquals(mlModel.getVersion(), parsedMLModel.getVersion()); assertEquals(mlModel.getContent(), parsedMLModel.getContent()); assertEquals(mlModel.getUser(), parsedMLModel.getUser()); + assertEquals(mlModel.getIsHidden(), parsedMLModel.getIsHidden()); + assertEquals(mlModel.getDescription(), parsedMLModel.getDescription()); + assertEquals(mlModel.getModelContentHash(), parsedMLModel.getModelContentHash()); + assertEquals(mlModel.getCreatedTime(), parsedMLModel.getCreatedTime()); + assertEquals(mlModel.getLastRegisteredTime(), parsedMLModel.getLastRegisteredTime()); + assertEquals(mlModel.getLastDeployedTime(), parsedMLModel.getLastDeployedTime()); + assertEquals(mlModel.getLastUndeployedTime(), parsedMLModel.getLastUndeployedTime()); + assertEquals(mlModel.getModelId(), parsedMLModel.getModelId()); + assertEquals(mlModel.getChunkNumber(), parsedMLModel.getChunkNumber()); + assertEquals(mlModel.getTotalChunks(), parsedMLModel.getTotalChunks()); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java index 6de3788243..17e5292f70 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java @@ -44,10 +44,21 @@ public class MLRegisterModelInputTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"ONNX\"," + - "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"," + - "\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"" + - "},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; + private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\"," + + "\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"description\":\"test description\"," + + "\"url\":\"url\",\"model_content_hash_value\":\"hash_value_test\",\"model_format\":\"ONNX\"," + + "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100," + + "\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\"," + + "\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]," + + "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"},\"is_hidden\":false}"; private final FunctionName functionName = FunctionName.LINEAR_REGRESSION; private final String modelName = "modelName"; private final String version = "version"; @@ -63,6 +74,7 @@ public void setUp() throws Exception { .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) .embeddingDimension(100) .build(); + HttpConnector connector = HttpConnectorTest.createHttpConnector(); input = MLRegisterModelInput.builder() .functionName(functionName) @@ -74,6 +86,10 @@ public void setUp() throws Exception { .modelConfig(config) .deployModel(true) .modelNodeIds(new String[]{"modelNodeIds" }) + .isHidden(false) + .description("test description") + .hashValue("hash_value_test") + .connector(connector) .build(); } @@ -155,7 +171,16 @@ public void testToXContent() throws Exception { public void testToXContent_Incomplete() throws Exception { String expectedIncompleteInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\"," + - "\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"deploy_model\":true}"; + "\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"description\":\"test description\"," + + "\"model_content_hash_value\":\"hash_value_test\",\"deploy_model\":true,\"connector\":" + + "{\"name\":\"test_connector_name\",\"version\":\"1\",\"description\":\"this is a test connector\"," + + "\"protocol\":\"http\",\"parameters\":{\"input\":\"test input value\"}," + + "\"credential\":{\"key\":\"test_key_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":" + + "\"POST\",\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\",\"pre_process_function\":" + + "\"connector.pre_process.openai.embedding\",\"post_process_function\":" + + "\"connector.post_process.openai.embedding\"}],\"backend_roles\":[\"role1\",\"role2\"]," + + "\"access\":\"public\"},\"is_hidden\":false}"; input.setUrl(null); input.setModelConfig(null); input.setModelFormat(null); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java index 61e57d4ac6..82aa8becde 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java @@ -43,7 +43,9 @@ public void setup() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); + "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, + 200L, "123", config, 2, + null, null, false, false, false); } @Test @@ -63,27 +65,39 @@ private void readInputStream(MLRegisterModelMetaInput input) throws IOException StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); MLRegisterModelMetaInput newInput = new MLRegisterModelMetaInput(streamInput); assertEquals(input.getName(), newInput.getName()); + assertEquals(input.getFunctionName(), newInput.getFunctionName()); + assertEquals(input.getModelGroupId(), newInput.getModelGroupId()); + assertEquals(input.getVersion(), newInput.getVersion()); assertEquals(input.getDescription(), newInput.getDescription()); assertEquals(input.getModelFormat(), newInput.getModelFormat()); assertEquals(input.getModelConfig().getAllConfig(), newInput.getModelConfig().getAllConfig()); assertEquals(input.getModelConfig().getModelType(), newInput.getModelConfig().getModelType()); + assertEquals(input.getModelFormat(), newInput.getModelFormat()); + assertEquals(input.getModelState(), newInput.getModelState()); + assertEquals(input.getModelContentSizeInBytes(), newInput.getModelContentSizeInBytes()); + assertEquals(input.getModelContentHashValue(), newInput.getModelContentHashValue()); + assertEquals(input.getTotalChunks(), newInput.getTotalChunks()); + assertEquals(input.getBackendRoles(), newInput.getBackendRoles()); + assertEquals(input.getIsAddAllBackendRoles(), newInput.getIsAddAllBackendRoles()); + assertEquals(input.getAccessMode(), newInput.getAccessMode()); + assertEquals(input.getDoesVersionCreateModelGroup(), newInput.getDoesVersionCreateModelGroup()); + assertEquals(input.getIsHidden(), newInput.getIsHidden()); } @Test - public void testToXContent() throws IOException {{ - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS); - String mlModelContent = TestHelper.xContentBuilderToString(builder); - final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," + - "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}"; - assertEquals(expected, mlModelContent); - } + public void testToXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," + - "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}"; + final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":" + + "\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\"," + + "\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\"," + + "\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\"," + + "\"model_config\":{\"model_type\":\"Model Type\",\"embedding_dimension\":123," + + "\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\"," + + "\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2," + + "\"add_all_backend_roles\":false,\"does_version_create_model_group\":false,\"is_hidden\":false}"; assertEquals(expected, mlModelContent); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java index d7039780f0..5fdf55757c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java @@ -33,7 +33,7 @@ public void setUp() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); + "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null, null); } @Test diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index 77001b92e7..bb810b40d3 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -63,6 +63,7 @@ public void downloadPrebuiltModelConfig( String modelName = registerModelInput.getModelName(); String version = registerModelInput.getVersion(); MLModelFormat modelFormat = registerModelInput.getModelFormat(); + Boolean isHidden = registerModelInput.getIsHidden(); boolean deployModel = registerModelInput.isDeployModel(); String[] modelNodeIds = registerModelInput.getModelNodeIds(); String modelGroupId = registerModelInput.getModelGroupId(); @@ -94,6 +95,7 @@ public void downloadPrebuiltModelConfig( .url(modelZipFileUrl) .deployModel(deployModel) .modelNodeIds(modelNodeIds) + .isHidden(isHidden) .modelGroupId(modelGroupId) .functionName(FunctionName.from((String) config.get("model_task_type"))); @@ -139,7 +141,7 @@ public void downloadPrebuiltModelConfig( } builder.modelConfig(configBuilder.build()); break; - case MLRegisterModelInput.HASH_VALUE_FIELD: + case MLRegisterModelInput.MODEL_CONTENT_HASH_VALUE_FIELD: builder.hashValue(entry.getValue().toString()); break; default: diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java index ec2fc1d141..c44704f688 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java @@ -355,7 +355,7 @@ public MLTask getTask(String taskId) { } public MLModel getModel(String modelId) { - MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false); + MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false, false); ActionFuture future = client.execute(MLModelGetAction.INSTANCE, getRequest); MLModelGetResponse response = future.actionGet(5000); return response.getMlModel(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index b9d339a338..124145b581 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -23,6 +23,7 @@ import java.util.Set; import java.util.stream.Collectors; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -34,6 +35,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; @@ -41,7 +43,6 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; -import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelInput; @@ -76,6 +77,8 @@ public class TransportDeployModelAction extends HandledTransportAction wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); + Boolean isHidden = mlModel.getIsHidden(); if (functionName == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); } - modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - wrappedListener - .onFailure(new MLValidationException("User Doesn't have privilege to perform this operation on this model")); + if (isHidden != null && isHidden) { + if (isSuperAdmin) { + deployModel(deployModelRequest, mlModel, modelId, wrappedListener, listener); } else { - String[] targetNodeIds = deployModelRequest.getModelNodeIds(); - boolean deployToAllNodes = targetNodeIds == null || targetNodeIds.length == 0; - if (!allowCustomDeploymentPlan && !deployToAllNodes) { - throw new IllegalArgumentException("Don't allow custom deployment plan"); - } - DiscoveryNode[] allEligibleNodes = nodeFilter.getEligibleNodes(functionName); - Map nodeMapping = new HashMap<>(); - for (DiscoveryNode node : allEligibleNodes) { - nodeMapping.put(node.getId(), node); - } - - Set allEligibleNodeIds = Arrays - .stream(allEligibleNodes) - .map(DiscoveryNode::getId) - .collect(Collectors.toSet()); - - List eligibleNodes = new ArrayList<>(); - List nodeIds = new ArrayList<>(); - if (!deployToAllNodes) { - for (String nodeId : targetNodeIds) { - if (allEligibleNodeIds.contains(nodeId)) { - eligibleNodes.add(nodeMapping.get(nodeId)); - nodeIds.add(nodeId); - } - } - String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName); - if (workerNodes != null && workerNodes.length > 0) { - Set difference = new HashSet(Arrays.asList(workerNodes)); - difference.removeAll(Arrays.asList(targetNodeIds)); - if (difference.size() > 0) { - wrappedListener - .onFailure( - new IllegalArgumentException( - "Model already deployed to these nodes: " - + Arrays.toString(difference.toArray(new String[0])) - + ", but they are not included in target node ids. Undeploy model from these nodes if don't need them any more." - ) - ); - return; - } - } - } else { - nodeIds.addAll(allEligibleNodeIds); - eligibleNodes.addAll(Arrays.asList(allEligibleNodes)); - } - if (nodeIds.size() == 0) { - wrappedListener.onFailure(new IllegalArgumentException("no eligible node found")); - return; - } - - log.info("Will deploy model on these nodes: {}", String.join(",", nodeIds)); - String localNodeId = clusterService.localNode().getId(); - - FunctionName algorithm = mlModel.getAlgorithm(); - // TODO: Track deploy failure - // mlStats.createCounterStatIfAbsent(algorithm, ActionName.DEPLOY, - // MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); - MLTask mlTask = MLTask - .builder() - .async(true) - .modelId(modelId) - .taskType(MLTaskType.DEPLOY_MODEL) - .functionName(algorithm) - .createTime(Instant.now()) - .lastUpdateTime(Instant.now()) - .state(MLTaskState.CREATED) - .workerNodes(nodeIds) - .build(); - mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { - String taskId = response.getId(); - mlTask.setTaskId(taskId); - if (algorithm == FunctionName.REMOTE) { - mlTaskManager.add(mlTask, nodeIds); - deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener); - return; - } - try { - mlTaskManager.add(mlTask, nodeIds); + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } + } else { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { + if (!access) { wrappedListener - .onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name())); - threadPool - .executor(DEPLOY_THREAD_POOL) - .execute( - () -> updateModelDeployStatusAndTriggerOnNodesAction( - modelId, - taskId, - mlModel, - localNodeId, - mlTask, - eligibleNodes, - deployToAllNodes + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN ) ); - } catch (Exception ex) { - log.error("Failed to deploy model", ex); - mlTaskManager - .updateMLTask( - taskId, - ImmutableMap.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)), - TASK_SEMAPHORE_TIMEOUT, - true - ); - wrappedListener.onFailure(ex); + } else { + deployModel(deployModelRequest, mlModel, modelId, wrappedListener, listener); } - }, exception -> { - log.error("Failed to create deploy model task for " + modelId, exception); - wrappedListener.onFailure(exception); + }, e -> { + log.error("Failed to Validate Access for ModelId " + modelId, e); + wrappedListener.onFailure(e); })); - } - }, e -> { - log.error("Failed to Validate Access for ModelId " + modelId, e); - wrappedListener.onFailure(e); - })); + } + }, e -> { log.error("Failed to retrieve the ML model with ID: " + modelId, e); wrappedListener.onFailure(e); @@ -267,6 +187,119 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener, + ActionListener listener + ) { + String[] targetNodeIds = deployModelRequest.getModelNodeIds(); + boolean deployToAllNodes = targetNodeIds == null || targetNodeIds.length == 0; + if (!allowCustomDeploymentPlan && !deployToAllNodes) { + throw new IllegalArgumentException("Don't allow custom deployment plan"); + } + DiscoveryNode[] allEligibleNodes = nodeFilter.getEligibleNodes(mlModel.getAlgorithm()); + Map nodeMapping = new HashMap<>(); + for (DiscoveryNode node : allEligibleNodes) { + nodeMapping.put(node.getId(), node); + } + + Set allEligibleNodeIds = Arrays.stream(allEligibleNodes).map(DiscoveryNode::getId).collect(Collectors.toSet()); + + List eligibleNodes = new ArrayList<>(); + List nodeIds = new ArrayList<>(); + if (!deployToAllNodes) { + for (String nodeId : targetNodeIds) { + if (allEligibleNodeIds.contains(nodeId)) { + eligibleNodes.add(nodeMapping.get(nodeId)); + nodeIds.add(nodeId); + } + } + String[] workerNodes = mlModelManager.getWorkerNodes(modelId, mlModel.getAlgorithm()); + if (workerNodes != null && workerNodes.length > 0) { + Set difference = new HashSet(Arrays.asList(workerNodes)); + difference.removeAll(Arrays.asList(targetNodeIds)); + if (difference.size() > 0) { + wrappedListener + .onFailure( + new IllegalArgumentException( + "Model already deployed to these nodes: " + + Arrays.toString(difference.toArray(new String[0])) + + ", but they are not included in target node ids. Undeploy model from these nodes if don't need them any more." + ) + ); + return; + } + } + } else { + nodeIds.addAll(allEligibleNodeIds); + eligibleNodes.addAll(Arrays.asList(allEligibleNodes)); + } + if (nodeIds.size() == 0) { + wrappedListener.onFailure(new IllegalArgumentException("no eligible node found")); + return; + } + + log.info("Will deploy model on these nodes: {}", String.join(",", nodeIds)); + String localNodeId = clusterService.localNode().getId(); + + FunctionName algorithm = mlModel.getAlgorithm(); + // TODO: Track deploy failure + // mlStats.createCounterStatIfAbsent(algorithm, ActionName.DEPLOY, + // MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); + MLTask mlTask = MLTask + .builder() + .async(true) + .modelId(modelId) + .taskType(MLTaskType.DEPLOY_MODEL) + .functionName(algorithm) + .createTime(Instant.now()) + .lastUpdateTime(Instant.now()) + .state(MLTaskState.CREATED) + .workerNodes(nodeIds) + .build(); + mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { + String taskId = response.getId(); + mlTask.setTaskId(taskId); + if (algorithm == FunctionName.REMOTE) { + mlTaskManager.add(mlTask, nodeIds); + deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener); + return; + } + try { + mlTaskManager.add(mlTask, nodeIds); + wrappedListener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name())); + threadPool + .executor(DEPLOY_THREAD_POOL) + .execute( + () -> updateModelDeployStatusAndTriggerOnNodesAction( + modelId, + taskId, + mlModel, + localNodeId, + mlTask, + eligibleNodes, + deployToAllNodes + ) + ); + } catch (Exception ex) { + log.error("Failed to deploy model", ex); + mlTaskManager + .updateMLTask( + taskId, + ImmutableMap.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)), + TASK_SEMAPHORE_TIMEOUT, + true + ); + wrappedListener.onFailure(ex); + } + }, exception -> { + log.error("Failed to create deploy model task for " + modelId, exception); + wrappedListener.onFailure(exception); + })); + } + @VisibleForTesting void deployRemoteModel( MLModel mlModel, @@ -405,4 +438,10 @@ void updateModelDeployStatusAndTriggerOnNodesAction( ); } + // this method is only to stub static method. + @VisibleForTesting + boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) { + return RestActionUtils.isSuperAdminUser(clusterService, client); + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index 682189ad34..ecc0e3fd3c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -29,6 +29,7 @@ import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.ExistsQueryBuilder; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.indices.InvalidIndexNameException; import org.opensearch.ml.common.CommonValue; @@ -97,6 +98,29 @@ public void search(SearchRequest request, ActionListener actionL Optional.ofNullable(request.source()).map(SearchSourceBuilder::fetchSource).map(FetchSourceContext::includes).orElse(null), excludes.toArray(new String[0]) ); + + // Check if the original query is not null before adding it to the must clause + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery(); + if (request.source().query() != null) { + queryBuilder.must(request.source().query()); + } + + // Create a BoolQueryBuilder for the should clauses + BoolQueryBuilder shouldQuery = QueryBuilders.boolQuery(); + + // Add a should clause to include documents where IS_HIDDEN_FIELD is false + shouldQuery.should(QueryBuilders.termQuery(MLModel.IS_HIDDEN_FIELD, false)); + + // Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null + shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLModel.IS_HIDDEN_FIELD))); + + // Set minimum should match to 1 to ensure at least one of the should conditions is met + shouldQuery.minimumShouldMatch(1); + + // Add the shouldQuery to the main queryBuilder + queryBuilder.filter(shouldQuery); + + request.source().query(queryBuilder); request.source().fetchSource(rebuiltFetchSourceContext); if (modelAccessControlHelper.skipModelAccessControl(user)) { client.search(request, wrappedListener); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index 9070781d63..984dbdd451 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -8,6 +8,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; +import static org.opensearch.ml.common.MLModel.IS_HIDDEN_FIELD; import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; @@ -24,6 +25,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -35,7 +37,6 @@ import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; @@ -64,6 +65,8 @@ public class DeleteModelTransportAction extends HandledTransportAction actionListener) { MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.fromActionRequest(request); String modelId = mlModelDeleteRequest.getModelId(); - MLModelGetRequest mlModelGetRequest = new MLModelGetRequest(modelId, false); + MLModelGetRequest mlModelGetRequest = new MLModelGetRequest(modelId, false, false); FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGetRequest.isReturnContent()); GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchSourceContext); User user = RestActionUtils.getUserContext(client); + boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); @@ -103,34 +108,55 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - if (!access) { - wrappedListener - .onFailure( - new MLValidationException("User doesn't have privilege to perform this operation on this model") - ); - } else if (mlModelState.equals(MLModelState.LOADED) - || mlModelState.equals(MLModelState.LOADING) - || mlModelState.equals(MLModelState.PARTIALLY_LOADED) - || mlModelState.equals(MLModelState.DEPLOYED) - || mlModelState.equals(MLModelState.DEPLOYING) - || mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED)) { + if (isHidden != null && isHidden) { + if (!isSuperAdmin) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else { + if (isModelNotDeployed(mlModelState)) { + deleteModel(modelId, actionListener); + } else { wrappedListener .onFailure( new Exception( "Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete" ) ); - } else { - deleteModel(modelId, actionListener); } - }, e -> { - log.error("Failed to validate Access for Model Id " + modelId, e); - wrappedListener.onFailure(e); - })); + } + } else { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else if (isModelNotDeployed(mlModelState)) { + deleteModel(modelId, actionListener); + } else { + wrappedListener + .onFailure( + new Exception( + "Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete" + ) + ); + } + }, e -> { + log.error("Failed to validate Access for Model Id " + modelId, e); + wrappedListener.onFailure(e); + })); + } } catch (Exception e) { log.error("Failed to parse ml model " + r.getId(), e); wrappedListener.onFailure(e); @@ -201,4 +227,19 @@ public void onFailure(Exception e) { } }); } + + private Boolean isModelNotDeployed(MLModelState mlModelState) { + return !mlModelState.equals(MLModelState.LOADED) + && !mlModelState.equals(MLModelState.LOADING) + && !mlModelState.equals(MLModelState.PARTIALLY_LOADED) + && !mlModelState.equals(MLModelState.DEPLOYED) + && !mlModelState.equals(MLModelState.DEPLOYING) + && !mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED); + } + + // this method is only to stub static method. + @VisibleForTesting + boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) { + return RestActionUtils.isSuperAdminUser(clusterService, client); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index 3e508f1f64..0d5fc5f659 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -8,18 +8,19 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; +import static org.opensearch.ml.common.MLModel.IS_HIDDEN_FIELD; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -30,7 +31,6 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.exception.MLResourceNotFoundException; -import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.common.transport.model.MLModelGetResponse; @@ -40,6 +40,8 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import com.google.common.annotations.VisibleForTesting; + import lombok.AccessLevel; import lombok.experimental.FieldDefaults; import lombok.extern.log4j.Log4j2; @@ -54,17 +56,21 @@ public class GetModelTransportAction extends HandledTransportAction wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); @@ -84,30 +91,45 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - if (!access) { - wrappedListener - .onFailure( - new MLValidationException("User Doesn't have privilege to perform this operation on this model") - ); - } else { - log.debug("Completed Get Model Request, id:{}", modelId); - Connector connector = mlModel.getConnector(); - if (connector != null) { - connector.removeCredential(); + if (isHidden != null && isHidden) { + if (isSuperAdmin || !mlModelGetRequest.isUserInitiatedGetRequest()) { + wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } + } else { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else { + log.debug("Completed Get Model Request, id:{}", modelId); + Connector connector = mlModel.getConnector(); + if (connector != null) { + connector.removeCredential(); + } + wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); } - wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); - } - }, e -> { - log.error("Failed to validate Access for Model Id " + modelId, e); - wrappedListener.onFailure(e); - })); - + }, e -> { + log.error("Failed to validate Access for Model Id " + modelId, e); + wrappedListener.onFailure(e); + })); + } } catch (Exception e) { log.error("Failed to parse ml model " + r.getId(), e); wrappedListener.onFailure(e); @@ -135,4 +157,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { Client client; + + Settings settings; + ClusterService clusterService; ModelAccessControlHelper modelAccessControlHelper; ConnectorAccessControlHelper connectorAccessControlHelper; MLModelManager mlModelManager; @@ -70,7 +77,9 @@ public UpdateModelTransportAction( ConnectorAccessControlHelper connectorAccessControlHelper, ModelAccessControlHelper modelAccessControlHelper, MLModelManager mlModelManager, - MLModelGroupManager mlModelGroupManager + MLModelGroupManager mlModelGroupManager, + Settings settings, + ClusterService clusterService ) { super(MLUpdateModelAction.NAME, transportService, actionFilters, MLUpdateModelRequest::new); this.client = client; @@ -78,6 +87,8 @@ public UpdateModelTransportAction( this.connectorAccessControlHelper = connectorAccessControlHelper; this.mlModelManager = mlModelManager; this.mlModelGroupManager = mlModelGroupManager; + this.clusterService = clusterService; + this.settings = settings; } @Override @@ -86,6 +97,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { FunctionName functionName = mlModel.getAlgorithm(); MLModelState mlModelState = mlModel.getModelState(); + if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { - modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { - if (hasPermission) { - if (isModelDeployed(mlModelState)) { - updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, actionListener, context); + if (mlModel.getIsHidden() != null && mlModel.getIsHidden()) { + if (isSuperAdmin) { + updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, actionListener, context); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model, model ID " + modelId, + RestStatus.FORBIDDEN + ) + ); + } + } else { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + if (hasPermission) { + if (isModelDeployed(mlModelState)) { + updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, actionListener, context); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "ML Model " + + modelId + + " is in deploying or deployed state, please undeploy the models first!", + RestStatus.FORBIDDEN + ) + ); + } } else { actionListener .onFailure( new OpenSearchStatusException( - "ML Model " - + modelId - + " is in deploying or deployed state, please undeploy the models first!", + "User doesn't have privilege to perform this operation on this model, model ID " + modelId, RestStatus.FORBIDDEN ) ); } - } else { - actionListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model, model ID " + modelId, - RestStatus.FORBIDDEN - ) - ); - } - }, exception -> { - log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); - actionListener.onFailure(exception); - })); + }, exception -> { + log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); + actionListener.onFailure(exception); + })); + } + } else { actionListener .onFailure( @@ -396,4 +424,9 @@ private Boolean isModelDeployed(MLModelState mlModelState) { && !mlModelState.equals(MLModelState.DEPLOYING) && !mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED); } + + @VisibleForTesting + boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) { + return RestActionUtils.isSuperAdminUser(clusterService, client); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 98927a5e5a..8aca5cb140 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -79,6 +79,8 @@ public class TransportRegisterModelAction extends HandledTransportAction trustedUrlRegex = it); @@ -139,6 +142,7 @@ public TransportRegisterModelAction( protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request); MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput(); + registerModelInput.setIsHidden(RestActionUtils.isSuperAdminUser(clusterService, client)); if (StringUtils.isEmpty(registerModelInput.getModelGroupId())) { mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(modelGroups -> { if (modelGroups != null diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index a4f0b9f2f9..da030239ae 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -5,15 +5,18 @@ package org.opensearch.ml.action.undeploy; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.MLModel; @@ -33,6 +36,8 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import com.google.common.annotations.VisibleForTesting; + import lombok.extern.log4j.Log4j2; @Log4j2 @@ -43,6 +48,8 @@ public class TransportUndeployModelsAction extends HandledTransportAction { listener.onResponse(new MLUndeployModelsResponse(r)); - }, listener::onFailure)); + }, e -> { listener.onFailure(e); })); } else { listener.onFailure(new IllegalArgumentException("No permission to undeploy model " + modelId)); } @@ -114,18 +123,39 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener) { User user = RestActionUtils.getUserContext(client); + boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client); String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { mlModelManager.getModel(modelId, null, excludes, ActionListener.runBefore(ActionListener.wrap(mlModel -> { - modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, listener); + Boolean isHidden = mlModel.getIsHidden(); + if (isHidden != null && isHidden) { + if (isSuperAdmin) { + listener.onResponse(true); + } else { + listener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } + } else { + modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, listener); + } }, e -> { log.error("Failed to find Model", e); listener.onFailure(e); - }), () -> context.restore())); + }), context::restore)); } catch (Exception e) { log.error("Failed to undeploy ML model"); listener.onFailure(e); } } + @VisibleForTesting + boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) { + return RestActionUtils.isSuperAdminUser(clusterService, client); + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index dd1deac4ab..2402374a99 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -124,6 +124,7 @@ import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.threadpool.ThreadPool; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.io.Files; @@ -285,11 +286,17 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput .totalChunks(mlRegisterModelMetaInput.getTotalChunks()) .modelContentHash(mlRegisterModelMetaInput.getModelContentHashValue()) .modelContentSizeInBytes(mlRegisterModelMetaInput.getModelContentSizeInBytes()) + .isHidden(mlRegisterModelMetaInput.getIsHidden()) .createdTime(now) .lastUpdateTime(now) .build(); IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); indexRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), EMPTY_PARAMS)); + + if (mlRegisterModelMetaInput.getIsHidden()) { + indexRequest.id(modelName); + } + indexRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.index(indexRequest, ActionListener.wrap(response -> { @@ -510,9 +517,13 @@ private void indexRemoteModel( .modelConfig(registerModelInput.getModelConfig()) .createdTime(now) .lastUpdateTime(now) + .isHidden(registerModelInput.getIsHidden()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); + if (registerModelInput.getIsHidden()) { + indexModelMetaRequest.id(modelName); + } indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); indexModelMetaRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); @@ -542,7 +553,8 @@ private void indexRemoteModel( } } - private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) { + @VisibleForTesting + void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) { String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -568,8 +580,12 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml .modelConfig(registerModelInput.getModelConfig()) .createdTime(now) .lastUpdateTime(now) + .isHidden(registerModelInput.getIsHidden()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); + if (registerModelInput.getIsHidden()) { + indexModelMetaRequest.id(modelName); + } indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); indexModelMetaRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); // create model meta doc @@ -627,11 +643,15 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas .modelConfig(registerModelInput.getModelConfig()) .createdTime(now) .lastUpdateTime(now) + .isHidden(registerModelInput.getIsHidden()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); if (functionName == FunctionName.METRICS_CORRELATION) { indexModelMetaRequest.id(functionName.name()); } + if (registerModelInput.getIsHidden()) { + indexModelMetaRequest.id(modelName); + } indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); indexModelMetaRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); // create model meta doc @@ -706,8 +726,12 @@ private void registerModel( .content(Base64.getEncoder().encodeToString(bytes)) .createdTime(now) .lastUpdateTime(now) + .isHidden(registerModelInput.getIsHidden()) .build(); IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); + if (registerModelInput.getIsHidden()) { + indexRequest.id(modelName); + } String chunkId = getModelChunkId(modelId, chunkNum); indexRequest.id(chunkId); indexRequest.source(mlModel.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); @@ -822,7 +846,8 @@ private void updateModelRegisterStateAsDone( })); } - private void deployModelAfterRegistering(MLRegisterModelInput registerModelInput, String modelId) { + @VisibleForTesting + void deployModelAfterRegistering(MLRegisterModelInput registerModelInput, String modelId) { String[] modelNodeIds = registerModelInput.getModelNodeIds(); log.debug("start deploying model after registering, modelId: {} on nodes: {}", modelId, Arrays.toString(modelNodeIds)); MLDeployModelRequest request = new MLDeployModelRequest(modelId, modelNodeIds, false, true); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java index 86367fe4b1..097bc6fb77 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java @@ -60,6 +60,6 @@ MLModelGetRequest getRequest(RestRequest request) throws IOException { String modelId = getParameterId(request, PARAMETER_MODEL_ID); boolean returnContent = returnContent(request); - return new MLModelGetRequest(modelId, returnContent); + return new MLModelGetRequest(modelId, returnContent, true); } } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index 23bb3f21a4..98f5f87d22 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -9,10 +9,16 @@ import static org.opensearch.ml.common.MLModel.OLD_MODEL_CONTENT_FIELD; import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Optional; +import java.util.Set; + +import javax.naming.InvalidNameException; +import javax.naming.ldap.LdapName; import org.apache.commons.lang3.ArrayUtils; import org.apache.logging.log4j.LogManager; @@ -21,6 +27,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Nullable; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.common.Strings; @@ -37,6 +44,8 @@ public class RestActionUtils { private static final Logger logger = LogManager.getLogger(RestActionUtils.class); + public static final String SECURITY_AUTHCZ_ADMIN_DN = "plugins.security.authcz.admin_dn"; + public static final String PARAMETER_ALGORITHM = "algorithm"; public static final String PARAMETER_ASYNC = "async"; public static final String PARAMETER_RETURN_CONTENT = "return_content"; @@ -49,6 +58,13 @@ public class RestActionUtils { public static final String OPENSEARCH_DASHBOARDS_USER_AGENT = "OpenSearch Dashboards"; public static final String[] UI_METADATA_EXCLUDE = new String[] { "ui_metadata" }; + public static final String PARAMETER_TOOL_NAME = "tool_name"; + + public static final String OPENDISTRO_SECURITY_CONFIG_PREFIX = "_opendistro_security_"; + public static final String OPENDISTRO_SECURITY_SSL_PRINCIPAL = OPENDISTRO_SECURITY_CONFIG_PREFIX + "ssl_principal"; + + static final Set adminDn = new HashSet<>(); + public static String getAlgorithm(RestRequest request) { String algorithm = request.param(PARAMETER_ALGORITHM); if (Strings.isNullOrEmpty(algorithm)) { @@ -190,4 +206,43 @@ public static User getUserContext(Client client) { return User.parse(userStr); } + // TODO: Integration test needs to be added (MUST) + public static boolean isSuperAdminUser(ClusterService clusterService, Client client) { + + final List adminDnsA = clusterService.getSettings().getAsList(SECURITY_AUTHCZ_ADMIN_DN, Collections.emptyList()); + + for (String dn : adminDnsA) { + try { + logger.debug("{} is registered as an admin dn", dn); + adminDn.add(new LdapName(dn)); + } catch (final InvalidNameException e) { + logger.error("Unable to parse admin dn {}", dn, e); + } + } + + ThreadContext threadContext = client.threadPool().getThreadContext(); + final String sslPrincipal = threadContext.getTransient(OPENDISTRO_SECURITY_SSL_PRINCIPAL); + return isAdminDN(sslPrincipal); + } + + private static boolean isAdminDN(String dn) { + if (dn == null) + return false; + try { + return isAdminDN(new LdapName(dn)); + } catch (InvalidNameException e) { + return false; + } + } + + private static boolean isAdminDN(LdapName dn) { + if (dn == null) + return false; + boolean isAdmin = adminDn.contains(dn); + if (logger.isTraceEnabled()) { + logger.trace("Is principal {} an admin cert? {}", dn.toString(), isAdmin); + } + return isAdmin; + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java index 6a1fb3bb50..c4fb3f906c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java @@ -385,7 +385,7 @@ public MLTask getTask(String taskId) { } public MLModel getModel(String modelId) { - MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false); + MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false, true); MLModelGetResponse response = client().execute(MLModelGetAction.INSTANCE, getRequest).actionGet(5000); return response.getMlModel(); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index b1c2a49479..143482ccc4 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -12,6 +12,7 @@ import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doCallRealMethod; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.isA; @@ -39,6 +40,7 @@ import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionType; import org.opensearch.action.index.IndexResponse; @@ -146,6 +148,7 @@ public void setup() { settings = Settings.builder().put(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.getKey(), true).build(); clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN))); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); @@ -217,6 +220,97 @@ public void testDoExecute_success() { verify(deployModelResponseListener).onResponse(any(MLDeployModelResponse.class)); } + public void testDoExecute_success_hidden_model() { + transportDeployModelAction = spy( + new TransportDeployModelAction( + transportService, + actionFilters, + modelHelper, + mlTaskManager, + clusterService, + threadPool, + client, + namedXContentRegistry, + nodeFilter, + mlTaskDispatcher, + mlModelManager, + mlStats, + settings, + modelAccessControlHelper, + mlFeatureEnabledSetting + ) + ); + MLModel mlModel = mock(MLModel.class); + when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); + when(mlModel.getIsHidden()).thenReturn(true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + + IndexResponse indexResponse = mock(IndexResponse.class); + when(indexResponse.getId()).thenReturn("mockIndexId"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(mlTaskManager).createMLTask(any(MLTask.class), Mockito.isA(ActionListener.class)); + + ActionListener deployModelResponseListener = mock(ActionListener.class); + doReturn(true).when(transportDeployModelAction).isSuperAdminUserWrapper(clusterService, client); + transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); + verify(deployModelResponseListener).onResponse(any(MLDeployModelResponse.class)); + } + + public void testDoExecute_no_permission_hidden_model() { + transportDeployModelAction = spy( + new TransportDeployModelAction( + transportService, + actionFilters, + modelHelper, + mlTaskManager, + clusterService, + threadPool, + client, + namedXContentRegistry, + nodeFilter, + mlTaskDispatcher, + mlModelManager, + mlStats, + settings, + modelAccessControlHelper, + mlFeatureEnabledSetting + ) + ); + + MLModel mlModel = mock(MLModel.class); + when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); + when(mlModel.getIsHidden()).thenReturn(true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + + IndexResponse indexResponse = mock(IndexResponse.class); + when(indexResponse.getId()).thenReturn("mockIndexId"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(mlTaskManager).createMLTask(any(MLTask.class), Mockito.isA(ActionListener.class)); + + doReturn(false).when(transportDeployModelAction).isSuperAdminUserWrapper(clusterService, client); + ActionListener deployModelResponseListener = mock(ActionListener.class); + transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(deployModelResponseListener).onFailure(argumentCaptor.capture()); + assertEquals("User doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); + } + public void testDoExecute_userHasNoAccessException() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); @@ -236,7 +330,7 @@ public void testDoExecute_userHasNoAccessException() { transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(deployModelResponseListener).onFailure(argumentCaptor.capture()); - assertEquals("User Doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); + assertEquals("User doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); } public void testDoExecuteRemoteInferenceDisabled() { diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index 57051643a6..b01d28e20f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -7,6 +7,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -109,6 +110,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + settings, xContentRegistry, clusterService, modelAccessControlHelper @@ -122,6 +124,7 @@ public void setup() throws IOException { }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); threadContext = new ThreadContext(settings); + when(clusterService.getSettings()).thenReturn(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); } @@ -140,7 +143,7 @@ public void testDeleteModel_Success() throws IOException { return null; }).when(client).execute(any(), any(), any()); - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -151,6 +154,60 @@ public void testDeleteModel_Success() throws IOException { verify(actionListener).onResponse(deleteResponse); } + public void testDeleteHiddenModel_Success() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, true); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + doReturn(true).when(deleteModelTransportAction).isSuperAdminUserWrapper(clusterService, client); + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + verify(actionListener).onResponse(deleteResponse); + } + + public void testDeleteHiddenModel_NoSuperAdminPermission() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, true); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + doReturn(false).when(deleteModelTransportAction).isSuperAdminUserWrapper(clusterService, client); + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("User doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); + } + public void testDeleteModel_Success_AlgorithmNotNull() throws IOException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -165,7 +222,7 @@ public void testDeleteModel_Success_AlgorithmNotNull() throws IOException { return null; }).when(client).execute(any(), any(), any()); - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -177,7 +234,7 @@ public void testDeleteModel_Success_AlgorithmNotNull() throws IOException { } public void test_UserHasNoAccessException() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID", false); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -197,7 +254,7 @@ public void test_UserHasNoAccessException() throws IOException { } public void testDeleteModel_CheckModelState() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.DEPLOYING, null); + GetResponse getResponse = prepareMLModel(MLModelState.DEPLOYING, null, false); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -240,7 +297,7 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException { return null; }).when(client).execute(any(), any(), any()); - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -254,7 +311,7 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException { } public void test_ValidationFailedException() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -298,7 +355,7 @@ public void testDeleteModelChunks_Success() { } public void testDeleteModel_RuntimeException() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -392,8 +449,9 @@ public void test_FailToDeleteAllModelChunks_SearchFailure() { assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage()); } - public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID) throws IOException { - MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).modelGroupId(modelGroupID).build(); + public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID, boolean isHidden) throws IOException { + MLModel mlModel; + mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).modelGroupId(modelGroupID).isHidden(isHidden).build(); XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java index 238b9fb47e..8a95805aa8 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java @@ -7,6 +7,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -65,6 +66,8 @@ public class GetModelTransportActionTests extends OpenSearchTestCase { @Mock ClusterService clusterService; + private Settings settings; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -79,10 +82,18 @@ public class GetModelTransportActionTests extends OpenSearchTestCase { public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlModelGetRequest = MLModelGetRequest.builder().modelId("test_id").build(); - Settings settings = Settings.builder().build(); + settings = Settings.builder().build(); getModelTransportAction = spy( - new GetModelTransportAction(transportService, actionFilters, client, xContentRegistry, clusterService, modelAccessControlHelper) + new GetModelTransportAction( + transportService, + actionFilters, + client, + settings, + xContentRegistry, + clusterService, + modelAccessControlHelper + ) ); doAnswer(invocation -> { @@ -93,6 +104,7 @@ public void setup() throws IOException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); + when(clusterService.getSettings()).thenReturn(settings); when(threadPool.getThreadContext()).thenReturn(threadContext); } @@ -103,17 +115,57 @@ public void testGetModel_UserHasNodeAccess() throws IOException { return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); - GetResponse getResponse = prepareMLModel(); + GetResponse getResponse = prepareMLModel(false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("User doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); + } + + public void testGetModel_Success() throws IOException { + GetResponse getResponse = prepareMLModel(false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); + verify(actionListener).onResponse(any(MLModelGetResponse.class)); + } + + public void testGetModelHidden_Success() throws IOException { + GetResponse getResponse = prepareMLModel(true); + mlModelGetRequest = MLModelGetRequest.builder().modelId("test_id").isUserInitiatedGetRequest(true).build(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(getResponse); return null; }).when(client).get(any(), any()); + doReturn(true).when(getModelTransportAction).isSuperAdminUserWrapper(clusterService, client); + getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); + verify(actionListener).onResponse(any(MLModelGetResponse.class)); + } + public void testGetModelHidden_SuperUserPermissionError() throws IOException { + GetResponse getResponse = prepareMLModel(true); + mlModelGetRequest = MLModelGetRequest.builder().modelId("test_id").isUserInitiatedGetRequest(true).build(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + doReturn(false).when(getModelTransportAction).isSuperAdminUserWrapper(clusterService, client); getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("User Doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); + assertEquals("User doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); } public void testGetModel_ValidateAccessFailed() throws IOException { @@ -123,7 +175,7 @@ public void testGetModel_ValidateAccessFailed() throws IOException { return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); - GetResponse getResponse = prepareMLModel(); + GetResponse getResponse = prepareMLModel(false); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(getResponse); @@ -172,12 +224,13 @@ public void testGetModel_RuntimeException() { assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); } - public GetResponse prepareMLModel() throws IOException { + public GetResponse prepareMLModel(boolean isHidden) throws IOException { MLModel mlModel = MLModel .builder() .modelId("test_id") .modelState(MLModelState.REGISTERED) .algorithm(FunctionName.TEXT_EMBEDDING) + .isHidden(isHidden) .build(); XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index 85dfaa552a..0dca491658 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -32,6 +32,7 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; @@ -52,6 +53,7 @@ import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; @@ -123,6 +125,11 @@ public class UpdateModelTransportActionTests extends OpenSearchTestCase { MLModel localModel; ThreadContext threadContext; + @Mock + ClusterService clusterService; + + @Mock + MLEngine mlEngine; @Before public void setup() throws IOException { @@ -165,14 +172,17 @@ public void setup() throws IOException { connectorAccessControlHelper, modelAccessControlHelper, mlModelManager, - mlModelGroupManager + mlModelGroupManager, + settings, + clusterService ) ); - localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING, false); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(clusterService.getSettings()).thenReturn(settings); shardId = new ShardId(new Index("indexName", "uuid"), 1); updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); @@ -389,7 +399,7 @@ public void testUpdateModelWithoutRegisterToNewModelGroupSuccess() { @Test public void testUpdateRemoteModelWithLocalInformationSuccess() { - MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE, false); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(remoteModel); @@ -402,7 +412,7 @@ public void testUpdateRemoteModelWithLocalInformationSuccess() { @Test public void testUpdateRemoteModelWithRemoteInformationSuccess() { - MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE, false); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(remoteModel); @@ -413,6 +423,37 @@ public void testUpdateRemoteModelWithRemoteInformationSuccess() { verify(actionListener).onResponse(updateResponse); } + @Test + public void testUpdateHiddenRemoteModelWithRemoteInformationSuccess() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE, true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + doReturn(true).when(transportUpdateModelAction).isSuperAdminUserWrapper(clusterService, client); + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateHiddenRemoteModelPermissionError() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE, true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + doReturn(false).when(transportUpdateModelAction).isSuperAdminUserWrapper(clusterService, client); + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this model, model ID test_model_id", + argumentCaptor.getValue().getMessage() + ); + } + @Test public void testUpdateRemoteModelWithNoStandAloneConnectorFound() { MLModel remoteModelWithInternalConnector = prepareUnsupportedMLModel(FunctionName.REMOTE); @@ -433,7 +474,7 @@ public void testUpdateRemoteModelWithNoStandAloneConnectorFound() { @Test public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlNoPermission() { - MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE, false); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(remoteModel); @@ -457,7 +498,7 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl @Test public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlOtherException() { - MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE, false); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(remoteModel); @@ -777,7 +818,7 @@ public void testGetUpdateResponseListenerOtherException() { // TODO: Add UT to make sure that version incremented successfully. - private MLModel prepareMLModel(FunctionName functionName) throws IllegalArgumentException { + private MLModel prepareMLModel(FunctionName functionName, boolean isHidden) throws IllegalArgumentException { MLModel mlModel; switch (functionName) { case TEXT_EMBEDDING: @@ -789,6 +830,7 @@ private MLModel prepareMLModel(FunctionName functionName) throws IllegalArgument .description("test_description") .modelState(MLModelState.REGISTERED) .algorithm(FunctionName.TEXT_EMBEDDING) + .isHidden(isHidden) .build(); return mlModel; case REMOTE: @@ -801,6 +843,7 @@ private MLModel prepareMLModel(FunctionName functionName) throws IllegalArgument .modelState(MLModelState.REGISTERED) .algorithm(FunctionName.REMOTE) .connectorId("test_connector_id") + .isHidden(isHidden) .build(); return mlModel; default: diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index ac1f09dea1..0222b4efe1 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -164,6 +164,7 @@ public void setup() throws IOException { ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); transportRegisterModelAction = new TransportRegisterModelAction( transportService, actionFilters, diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java index 8e2599a6c1..42152f473d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java @@ -10,21 +10,28 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; @@ -35,6 +42,7 @@ import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; @@ -65,6 +73,9 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase { @Mock ThreadPool threadPool; + @Mock + private ClusterName clusterName; + @Mock Client client; @@ -83,6 +94,9 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase { @Mock MLModelManager mlModelManager; + @Mock + MLUndeployModelNodeResponse mlUndeployModelNodeResponse; + @Mock ModelAccessControlHelper modelAccessControlHelper; @@ -105,19 +119,23 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - transportUndeployModelsAction = new TransportUndeployModelsAction( - transportService, - actionFilters, - modelHelper, - mlTaskManager, - clusterService, - threadPool, - client, - xContentRegistry, - nodeFilter, - mlTaskDispatcher, - mlModelManager, - modelAccessControlHelper + Settings settings = Settings.builder().build(); + transportUndeployModelsAction = spy( + new TransportUndeployModelsAction( + transportService, + actionFilters, + modelHelper, + mlTaskManager, + clusterService, + threadPool, + client, + settings, + xContentRegistry, + nodeFilter, + mlTaskDispatcher, + mlModelManager, + modelAccessControlHelper + ) ); when(modelAccessControlHelper.isModelAccessControlEnabled()).thenReturn(true); @@ -126,7 +144,27 @@ public void setup() throws IOException { ThreadPool threadPool = mock(ThreadPool.class); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(clusterService.getSettings()).thenReturn(settings); + MLModel mlModel = MLModel + .builder() + .user(User.parse(USER_STRING)) + .modelGroupId("111") + .version("111") + .name("Test Model") + .modelId("someModelId") + .algorithm(FunctionName.BATCH_RCF) + .content("content") + .totalChunks(2) + .isHidden(false) + .build(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + } + public void testHiddenModelSuccess() { MLModel mlModel = MLModel .builder() .user(User.parse(USER_STRING)) @@ -137,12 +175,63 @@ public void setup() throws IOException { .algorithm(FunctionName.BATCH_RCF) .content("content") .totalChunks(2) + .isHidden(true) .build(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(mlModel); return null; }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + + List responseList = new ArrayList<>(); + List failuresList = new ArrayList<>(); + MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); + transportUndeployModelsAction.doExecute(task, request, actionListener); + verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); + } + + public void testHiddenModelPermissionError() { + MLModel mlModel = MLModel + .builder() + .user(User.parse(USER_STRING)) + .modelGroupId("111") + .version("111") + .name("Test Model") + .modelId("someModelId") + .algorithm(FunctionName.BATCH_RCF) + .content("content") + .totalChunks(2) + .isHidden(true) + .build(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + + List responseList = new ArrayList<>(); + List failuresList = new ArrayList<>(); + MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + doReturn(false).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); + transportUndeployModelsAction.doExecute(task, request, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("User doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); } public void testDoExecute() { @@ -152,15 +241,17 @@ public void testDoExecute() { return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); - MLUndeployModelsResponse mlUndeployModelsResponse = new MLUndeployModelsResponse(mock(MLUndeployModelNodesResponse.class)); + List responseList = new ArrayList<>(); + List failuresList = new ArrayList<>(); + MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(mlUndeployModelsResponse); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); return null; }).when(client).execute(any(), any(), isA(ActionListener.class)); MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); transportUndeployModelsAction.doExecute(task, request, actionListener); - verify(actionListener).onFailure(isA(Exception.class)); + verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); } public void testDoExecute_modelAccessControl_notEnabled() { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 14034aa93e..598d2db5a7 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -98,6 +98,7 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.ModelHelper; @@ -216,6 +217,7 @@ public void setup() throws URISyntaxException { .modelFormat(modelFormat) .modelConfig(modelConfig) .url(url) + .isHidden(false) .build(); Map> stats = new ConcurrentHashMap<>(); @@ -380,7 +382,7 @@ public void testRegisterMLModel_DownloadModelFileFailure() { verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any(), any()); } - public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionException { + public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionException, IOException { doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null); when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService); @@ -392,6 +394,23 @@ public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionE listener.onResponse(pretrainedInput); return null; }).when(modelHelper).downloadPrebuiltModelConfig(any(), any(), any()); + + mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); + doAnswer(invocation -> { + ActionListener indexResponseActionListener = (ActionListener) invocation.getArguments()[1]; + indexResponseActionListener.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + String[] newChunks = createTempChunkFiles(); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(7); + Map result = new HashMap<>(); + result.put(MODEL_SIZE_IN_BYTES, modelContentSize); + result.put(CHUNK_FILES, Arrays.asList(newChunks[0], newChunks[1])); + result.put(MODEL_FILE_HASH, randomAlphaOfLength(10)); + listener.onResponse(result); + return null; + }).when(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any(), any(), any()); MLTask pretrainedTask = MLTask .builder() .taskId("pretrained") @@ -409,6 +428,50 @@ public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionE ); } + public void testRegisterMLRemoteModel() throws PrivilegedActionException { + ActionListener listener = mock(ActionListener.class); + doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); + when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null); + when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService); + when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo")); + when(modelHelper.isModelAllowed(any(), any())).thenReturn(true); + MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true); + MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build(); + mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); + doAnswer(invocation -> { + ActionListener indexResponseActionListener = (ActionListener) invocation.getArguments()[1]; + indexResponseActionListener.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + when(indexResponse.getId()).thenReturn("mockIndexId"); + modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener); + assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE); + verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); + } + + public void testIndexRemoteModel() throws PrivilegedActionException { + ActionListener listener = mock(ActionListener.class); + doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); + when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null); + when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService); + when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo")); + when(modelHelper.isModelAllowed(any(), any())).thenReturn(true); + MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true); + MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build(); + mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); + doAnswer(invocation -> { + ActionListener indexResponseActionListener = (ActionListener) invocation.getArguments()[1]; + indexResponseActionListener.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + when(indexResponse.getId()).thenReturn("mockIndexId"); + modelManager.indexRemoteModel(pretrainedInput, pretrainedTask, "1.0.0"); + assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE); + verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); + verify(modelManager).deployModelAfterRegistering(any(), anyString()); + + } + @Ignore public void testRegisterMLModel_DownloadModelFile() throws IOException { doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); @@ -862,6 +925,9 @@ private void setUpMock_DownloadModelFile(String[] chunks, Long modelContentSize) @Mock private IndexResponse indexResponse; + @Mock + private UpdateResponse updateResponse; + private String[] createTempChunkFiles() throws IOException { String tmpFolder = randomAlphaOfLength(10); String newChunk0 = chunk0.substring(0, chunk0.length() - 2) + "/" + tmpFolder + "/0"; @@ -949,6 +1015,7 @@ private MLRegisterModelMetaInput prepareRequest() { ) ) .totalChunks(2) + .isHidden(true) .build(); return input; } @@ -961,6 +1028,20 @@ private MLRegisterModelInput mockPretrainedInput() { .modelGroupId("modelGroupId") .modelFormat(modelFormat) .functionName(FunctionName.SPARSE_ENCODING) + .isHidden(true) + .build(); + } + + private MLRegisterModelInput mockRemoteModelInput(boolean isHidden) { + return MLRegisterModelInput + .builder() + .modelName(modelName) + .version(version) + .modelGroupId("modelGroupId") + .modelFormat(modelFormat) + .functionName(FunctionName.REMOTE) + .isHidden(isHidden) + .deployModel(true) .build(); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java index 12bc3737fb..8655a4eb06 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java @@ -76,6 +76,7 @@ public void setup() { settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true).build(); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings, mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java index ee1272f2fb..bad05e5aad 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java @@ -7,6 +7,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.utils.RestActionUtils.OPENSEARCH_DASHBOARDS_USER_AGENT; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; @@ -14,15 +15,20 @@ import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.UI_METADATA_EXCLUDE; +import java.net.InetAddress; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import org.junit.Assert; import org.junit.Before; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; +import org.opensearch.Version; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.node.DiscoveryNode; @@ -32,6 +38,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.plugin.MachineLearningPlugin; @@ -59,6 +66,7 @@ public class RestActionUtilsTests extends OpenSearchTestCase { public void setup() { param = ImmutableMap.builder().put(PARAMETER_ALGORITHM, algoName).build(); fakeRestRequest = createRestRequest(param); + } private FakeRestRequest createRestRequest(Map param) { @@ -81,6 +89,52 @@ public void testGetAlgorithm_EmptyValue() { RestActionUtils.getAlgorithm(fakeRestRequest); } + @Test + public void testReturnContent() { + RestRequest request = mock(RestRequest.class); + when(request.paramAsBoolean("return_content", false)).thenReturn(true); + Assert.assertTrue(RestActionUtils.returnContent(request)); + + when(request.paramAsBoolean("return_content", false)).thenReturn(false); + Assert.assertFalse(RestActionUtils.returnContent(request)); + } + + @Test + public void testGetAllNodes() { + + DiscoveryNode localNode = new DiscoveryNode( + "mockClusterManagerNodeId", + "mockClusterManagerNodeId", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + ClusterState clusterState = mock(ClusterState.class); + DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); + DiscoveryNode dataNode = mock(DiscoveryNode.class); + ClusterService clusterService = mock(ClusterService.class); + + when(dataNode.getId()).thenReturn("mockDataNodeId"); + final Map dataNodes = Map.of("0", dataNode); + when(discoveryNodes.getDataNodes()).thenReturn(dataNodes); + when(clusterState.nodes()).thenReturn(discoveryNodes); + when(clusterService.localNode()).thenReturn(localNode); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.nodes()).thenReturn(discoveryNodes); + when(discoveryNodes.getSize()).thenReturn(2); // Assuming 2 nodes in the cluster + // Mock two discovery nodes + DiscoveryNode node1 = mock(DiscoveryNode.class); + DiscoveryNode node2 = mock(DiscoveryNode.class); + when(node1.getId()).thenReturn("node1"); + when(node2.getId()).thenReturn("node2"); + when(discoveryNodes.iterator()).thenReturn(Arrays.asList(node1, node2).iterator()); + + String[] nodeIds = RestActionUtils.getAllNodes(clusterService); + Assert.assertArrayEquals(new String[] { "node1", "node2" }, nodeIds); + } + public void testIsAsync() { fakeRestRequest = createRestRequest(ImmutableMap.builder().put(PARAMETER_ASYNC, "true").build()); boolean isAsync = RestActionUtils.isAsync(fakeRestRequest); @@ -230,4 +284,37 @@ public void test_getUserContext() { User user = RestActionUtils.getUserContext(client); assertNotNull(user); } + + @Test + public void testIsSuperAdminUser() { + ClusterService clusterService = mock(ClusterService.class); + Client client = mock(Client.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + when(clusterService.getSettings()) + .thenReturn(Settings.builder().putList(RestActionUtils.SECURITY_AUTHCZ_ADMIN_DN, "cn=admin").build()); + when(client.threadPool()).thenReturn(mock(ThreadPool.class)); + when(client.threadPool().getThreadContext()).thenReturn(threadContext); + + threadContext.putTransient(RestActionUtils.OPENDISTRO_SECURITY_SSL_PRINCIPAL, "cn=admin"); + + boolean isAdmin = RestActionUtils.isSuperAdminUser(clusterService, client); + Assert.assertTrue(isAdmin); + } + + @Test + public void testIsSuperAdminUser_NotAdmin() { + ClusterService clusterService = mock(ClusterService.class); + Client client = mock(Client.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + when(clusterService.getSettings()) + .thenReturn(Settings.builder().putList(RestActionUtils.SECURITY_AUTHCZ_ADMIN_DN, "cn=admin").build()); + when(client.threadPool()).thenReturn(mock(ThreadPool.class)); + when(client.threadPool().getThreadContext()).thenReturn(threadContext); + threadContext.putTransient(RestActionUtils.OPENDISTRO_SECURITY_SSL_PRINCIPAL, "cn=notadmin"); + + boolean isAdmin = RestActionUtils.isSuperAdminUser(clusterService, client); + Assert.assertFalse(isAdmin); + } }