diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index ca226f0251..e13ea03173 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTaskState.FAILED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; @@ -89,6 +90,7 @@ public class TransportRegisterModelAction extends HandledTransportAction trustedConnectorEndpointsRegex; ModelAccessControlHelper modelAccessControlHelper; + private volatile boolean isModelUrlAllowed; ConnectorAccessControlHelper connectorAccessControlHelper; MLModelGroupManager mlModelGroupManager; @@ -132,6 +134,9 @@ public TransportRegisterModelAction( trustedUrlRegex = ML_COMMONS_TRUSTED_URL_REGEX.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_TRUSTED_URL_REGEX, it -> trustedUrlRegex = it); + isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it); + trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); clusterService .getClusterSettings() @@ -142,6 +147,11 @@ public TransportRegisterModelAction( protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request); MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput(); + if (registerModelInput.getUrl() != null && !isModelUrlAllowed) { + throw new IllegalArgumentException( + "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models." + ); + } registerModelInput.setIsHidden(RestActionUtils.isSuperAdminUser(clusterService, client)); if (StringUtils.isEmpty(registerModelInput.getModelGroupId())) { mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(modelGroups -> { diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java index 9e76a48c97..631462e773 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java @@ -7,7 +7,6 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; @@ -35,7 +34,6 @@ public class RestMLRegisterModelAction extends BaseRestHandler { private static final String ML_REGISTER_MODEL_ACTION = "ml_register_model_action"; - private volatile boolean isModelUrlAllowed; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** @@ -51,8 +49,6 @@ public RestMLRegisterModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting * @param settings settings */ public RestMLRegisterModelAction(ClusterService clusterService, Settings settings, MLFeatureEnabledSetting mlFeatureEnabledSetting) { - isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it); this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @@ -103,11 +99,6 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException { if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); } - if (mlInput.getUrl() != null && !isModelUrlAllowed) { - throw new IllegalArgumentException( - "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models." - ); - } return new MLRegisterModelRequest(mlInput); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java index d5c1347e26..5b2eae3f4a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java @@ -5,11 +5,14 @@ package org.opensearch.ml.action.models; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; + import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.settings.Settings; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.action.MLCommonsIntegTestCase; @@ -187,4 +190,9 @@ private void test_matchPhrase_search() { assertEquals(1, response.getHits().getTotalHits().value); } + @Override + protected Settings nodeSettings(int ordinal) { + return Settings.builder().put(super.nodeSettings(ordinal)).put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true).build(); + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index b40a278289..83ecd01069 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -14,6 +14,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; @@ -155,12 +156,14 @@ public void setup() throws IOException { settings = Settings .builder() .put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex) + .put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true) .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) .build(); threadContext = new ThreadContext(settings); ClusterSettings clusterSettings = clusterSetting( settings, ML_COMMONS_TRUSTED_URL_REGEX, + ML_COMMONS_ALLOW_MODEL_URL, ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); @@ -294,6 +297,50 @@ public void testDoExecute_invalidURL() { assertEquals("URL can't match trusted url regex", argumentCaptor.getValue().getMessage()); } + public void testRegisterModelUrlNotAllowed() throws Exception { + Settings settings = Settings + .builder() + .put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex) + .put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false) + .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) + .build(); + ClusterSettings clusterSettings = clusterSetting( + settings, + ML_COMMONS_TRUSTED_URL_REGEX, + ML_COMMONS_ALLOW_MODEL_URL, + ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + transportRegisterModelAction = new TransportRegisterModelAction( + transportService, + actionFilters, + modelHelper, + mlIndicesHandler, + mlModelManager, + mlTaskManager, + clusterService, + settings, + threadPool, + client, + nodeFilter, + mlTaskDispatcher, + mlStats, + modelAccessControlHelper, + connectorAccessControlHelper, + mlModelGroupManager + ); + + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), actionListener) + ); + assertEquals( + e.getMessage(), + "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models." + ); + } + public void testDoExecute_successWithLocalNodeNotEqualToClusterNode() { when(node1.getId()).thenReturn("NodeId1"); when(node2.getId()).thenReturn("NodeId2"); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java index 8655a4eb06..b3f3e3f956 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java @@ -147,20 +147,6 @@ public void testRegisterModelRequestRemoteInferenceDisabled() throws Exception { restMLRegisterModelAction.handleRequest(request, channel, client); } - public void testRegisterModelUrlNotAllowed() throws Exception { - settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false).build(); - ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings, mlFeatureEnabledSetting); - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule - .expectMessage( - "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models." - ); - RestRequest request = getRestRequest(); - restMLRegisterModelAction.handleRequest(request, channel, client); - } - public void testRegisterModelRequestWithNullUrlAndUrlNotAllowed() throws Exception { settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false).build(); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL);