From 6bff20d5ea2dd16d40904d7eef1c22ffeda5b9ea Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 31 Jan 2024 18:14:04 -0800 Subject: [PATCH] Move allow model setting from rest to transport (#1961) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Backport multiple PRs to main from 2.x (#1652) * fix parameter name in preprocess function; fix remote model function … (#1362) * fix parameter name in preprocess function; fix remote model function name Signed-off-by: Yaliang Wu * fix failed unit test Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu * throw exception when model group not found during update request (#1447) Signed-off-by: Bhavana Ramaram * add status code to model tensor (#1443) (#1453) Signed-off-by: Yaliang Wu * register new versions to a model group based on the name provided (#1452) Signed-off-by: Bhavana Ramaram * fixing metrics correlation algorithm (#1448) * fixing metrics correlation algorithm Signed-off-by: Dhrubo Saha * if model version fails to register, update model group accordingly (#1463) * if model version fails to register, update model group accordingly Signed-off-by: Bhavana Ramaram * Update Model API (#1350) * Update Model API POC Signed-off-by: Sicheng Song * Using GetRequest to get model Signed-off-by: Sicheng Song * Finalize model update API Signed-off-by: Sicheng Song * Fix compile Signed-off-by: Sicheng Song * Fix compileTest Signed-off-by: Sicheng Song * Add Unit Test Cases for Update Model API Signed-off-by: Sicheng Song * Tune back test coverage thereshold Signed-off-by: Sicheng Song * Add more unit tests on Update model API Signed-off-by: Sicheng Song * Add unit test for TransportUpdateModelAction class Signed-off-by: Sicheng Song * Fix a test error Signed-off-by: Sicheng Song * Change exception thrown to failure response Signed-off-by: Sicheng Song * Move the function judgement to the outter block Signed-off-by: Sicheng Song * Check if model is undeployed before update model Signed-off-by: Sicheng Song * Add more unit test for update model API Signed-off-by: Sicheng Song * Fix unit test due to blocking java 11 CI workflow Signed-off-by: Sicheng Song * Enabling auto bumping model version during registering to a new model group and address reviewers' other concern Signed-off-by: Sicheng Song * Autobump new model groups' latest version when register to a new model Signed-off-by: Sicheng Song * Change the REST API method from POST to PUT Signed-off-by: Sicheng Song * Change the update REST API endpoint Signed-off-by: Sicheng Song --------- Signed-off-by: Sicheng Song * Add a setting to control the update connector API (#1465) * Add a setting to control the update connector API Signed-off-by: Sicheng Song * Enabling the update connnector setting in unit test Signed-off-by: Sicheng Song * Enabling the update connnector setting in corresponding unit test Signed-off-by: Sicheng Song --------- Signed-off-by: Sicheng Song * fix update connector API (#1484) * fix update connector API Signed-off-by: Yaliang Wu * Performance enhacement for predict action by caching model info (#1472) (#1508) * Performance enhacement for predict action by caching model info Signed-off-by: zane-neo * Add context.restore() to avoid missing info Signed-off-by: zane-neo --------- Signed-off-by: zane-neo (cherry picked from commit a985f6ec6dc280072b7045dcf4851959aa575c54) Co-authored-by: zane-neo * fix failed ut from PR 1472 (#1479) (#1510) * fix failed ut from PR 1472 Signed-off-by: Yaliang Wu * exclude class for low coverage Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu (cherry picked from commit da5d82942385c34544016cf361b517c2bb3d36c4) Co-authored-by: Yaliang Wu * [Backport to 2.11] throw exception if remote model doesn't return 2xx status code; fix p… (#1477) (#1509) * throw exception if remote model doesn't return 2xx status code; fix p… (#1473) * throw exception if remote model doesn't return 2xx status code; fix predict runner Signed-off-by: Yaliang Wu * fix kmeans model deploy bug Signed-off-by: Yaliang Wu * support multiple docs for remote embedding model Signed-off-by: Yaliang Wu * fix ut Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu * fix wrong class Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu (cherry picked from commit 201c8a89b07126c5da9ed2e743c7f1b0e4806e12) Co-authored-by: Yaliang Wu * fix no worker node exception for remote embedding model (#1482) (#1511) * fix no worker node exception for remote embedding model Signed-off-by: Yaliang Wu * only add model info to cache if model cache exist Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu (cherry picked from commit 6f83b9fee002026d7a8d0fa3550fe8cf80b30371) Co-authored-by: Yaliang Wu * fix for delete model group API throwing incorrect error when model index not created (#1485) (#1486) (#1512) * fix for delete model group API throwing incorrect error when model index not created Signed-off-by: Bhavana Ramaram (cherry picked from commit 60ef0fd6dfeda18729f1dc2ec6ea9c0418c6ff69) Co-authored-by: Bhavana Ramaram (cherry picked from commit 55446819b7686e14cf9e1d10edf7956ed57148c7) Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> * fix no worker node error on multi-node cluster (#1487) (#1513) Signed-off-by: Yaliang Wu (cherry picked from commit cea1cd675cd95e37c29a3bd88c1cd9d58e81b20a) Co-authored-by: Yaliang Wu * add prefix to show the error is from remote service (#1499) (#1515) Signed-off-by: Yaliang Wu (cherry picked from commit 3897ad179437e683033d6918ebb6b4edf439dd4d) Co-authored-by: Yaliang Wu * fix multiple docs support (#1516) Signed-off-by: Yaliang Wu * adding another fix issue to the release note (#1498) (#1514) Signed-off-by: Dhrubo Saha (cherry picked from commit 440155c5c242113cd264b59818d1927b498c1480) Co-authored-by: Dhrubo Saha * add bedrockURL to trusted connector regex list (#1461) Signed-off-by: Bhavana Ramaram * return parsing exception 400 for parsing errors Signed-off-by: Xun Zhang * add more ut in restupdateconnector Signed-off-by: Xun Zhang * fix format violations Signed-off-by: Bhavana Ramaram * Fix model/connector update API to address security concern (#1595) * Fix model/connector update API to address appsec concern Signed-off-by: Sicheng Song * Fix compile and build failure Signed-off-by: Sicheng Song * Improve unit test coverage Signed-off-by: Sicheng Song * Fix spotless Signed-off-by: Sicheng Song * Merge update connector feature flag to remote inference feature flag Signed-off-by: Sicheng Song * Fix compile Signed-off-by: Sicheng Song * Fix exception status Signed-off-by: Sicheng Song * Keep fixing exception status Signed-off-by: Sicheng Song * Spotless fix Signed-off-by: Sicheng Song * Add UT on parsing exception Signed-off-by: Sicheng Song --------- Signed-off-by: Sicheng Song * change XContentFactory to MediaTypeRegistry builder in MLRegisterModelInputTest class Signed-off-by: Bhavana Ramaram --------- Signed-off-by: Yaliang Wu Signed-off-by: Bhavana Ramaram Signed-off-by: Dhrubo Saha Signed-off-by: Sicheng Song Signed-off-by: Xun Zhang Co-authored-by: Yaliang Wu Co-authored-by: Dhrubo Saha Co-authored-by: Sicheng Song Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Co-authored-by: zane-neo Co-authored-by: Xun Zhang * Allow model setting to transport from rest Signed-off-by: Owais Kazi * Added test Signed-off-by: Owais Kazi * Fixed integ test Signed-off-by: Owais Kazi --------- Signed-off-by: Yaliang Wu Signed-off-by: Bhavana Ramaram Signed-off-by: Dhrubo Saha Signed-off-by: Sicheng Song Signed-off-by: Xun Zhang Signed-off-by: Owais Kazi Co-authored-by: Bhavana Ramaram Co-authored-by: Yaliang Wu Co-authored-by: Dhrubo Saha Co-authored-by: Sicheng Song Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Co-authored-by: zane-neo Co-authored-by: Xun Zhang --- .../TransportRegisterModelAction.java | 10 ++++ .../ml/rest/RestMLRegisterModelAction.java | 9 ---- .../ml/action/models/SearchModelITTests.java | 8 ++++ .../TransportRegisterModelActionTests.java | 47 +++++++++++++++++++ .../rest/RestMLRegisterModelActionTests.java | 14 ------ 5 files changed, 65 insertions(+), 23 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index ca226f0251..e13ea03173 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTaskState.FAILED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; @@ -89,6 +90,7 @@ public class TransportRegisterModelAction extends HandledTransportAction trustedConnectorEndpointsRegex; ModelAccessControlHelper modelAccessControlHelper; + private volatile boolean isModelUrlAllowed; ConnectorAccessControlHelper connectorAccessControlHelper; MLModelGroupManager mlModelGroupManager; @@ -132,6 +134,9 @@ public TransportRegisterModelAction( trustedUrlRegex = ML_COMMONS_TRUSTED_URL_REGEX.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_TRUSTED_URL_REGEX, it -> trustedUrlRegex = it); + isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it); + trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); clusterService .getClusterSettings() @@ -142,6 +147,11 @@ public TransportRegisterModelAction( protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request); MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput(); + if (registerModelInput.getUrl() != null && !isModelUrlAllowed) { + throw new IllegalArgumentException( + "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models." + ); + } registerModelInput.setIsHidden(RestActionUtils.isSuperAdminUser(clusterService, client)); if (StringUtils.isEmpty(registerModelInput.getModelGroupId())) { mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(modelGroups -> { diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java index 9e76a48c97..631462e773 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java @@ -7,7 +7,6 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; @@ -35,7 +34,6 @@ public class RestMLRegisterModelAction extends BaseRestHandler { private static final String ML_REGISTER_MODEL_ACTION = "ml_register_model_action"; - private volatile boolean isModelUrlAllowed; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** @@ -51,8 +49,6 @@ public RestMLRegisterModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting * @param settings settings */ public RestMLRegisterModelAction(ClusterService clusterService, Settings settings, MLFeatureEnabledSetting mlFeatureEnabledSetting) { - isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it); this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @@ -103,11 +99,6 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException { if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); } - if (mlInput.getUrl() != null && !isModelUrlAllowed) { - throw new IllegalArgumentException( - "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models." - ); - } return new MLRegisterModelRequest(mlInput); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java index d5c1347e26..5b2eae3f4a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java @@ -5,11 +5,14 @@ package org.opensearch.ml.action.models; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; + import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.settings.Settings; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.action.MLCommonsIntegTestCase; @@ -187,4 +190,9 @@ private void test_matchPhrase_search() { assertEquals(1, response.getHits().getTotalHits().value); } + @Override + protected Settings nodeSettings(int ordinal) { + return Settings.builder().put(super.nodeSettings(ordinal)).put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true).build(); + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index b40a278289..83ecd01069 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -14,6 +14,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; @@ -155,12 +156,14 @@ public void setup() throws IOException { settings = Settings .builder() .put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex) + .put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true) .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) .build(); threadContext = new ThreadContext(settings); ClusterSettings clusterSettings = clusterSetting( settings, ML_COMMONS_TRUSTED_URL_REGEX, + ML_COMMONS_ALLOW_MODEL_URL, ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); @@ -294,6 +297,50 @@ public void testDoExecute_invalidURL() { assertEquals("URL can't match trusted url regex", argumentCaptor.getValue().getMessage()); } + public void testRegisterModelUrlNotAllowed() throws Exception { + Settings settings = Settings + .builder() + .put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex) + .put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false) + .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) + .build(); + ClusterSettings clusterSettings = clusterSetting( + settings, + ML_COMMONS_TRUSTED_URL_REGEX, + ML_COMMONS_ALLOW_MODEL_URL, + ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + transportRegisterModelAction = new TransportRegisterModelAction( + transportService, + actionFilters, + modelHelper, + mlIndicesHandler, + mlModelManager, + mlTaskManager, + clusterService, + settings, + threadPool, + client, + nodeFilter, + mlTaskDispatcher, + mlStats, + modelAccessControlHelper, + connectorAccessControlHelper, + mlModelGroupManager + ); + + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), actionListener) + ); + assertEquals( + e.getMessage(), + "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models." + ); + } + public void testDoExecute_successWithLocalNodeNotEqualToClusterNode() { when(node1.getId()).thenReturn("NodeId1"); when(node2.getId()).thenReturn("NodeId2"); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java index 8655a4eb06..b3f3e3f956 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java @@ -147,20 +147,6 @@ public void testRegisterModelRequestRemoteInferenceDisabled() throws Exception { restMLRegisterModelAction.handleRequest(request, channel, client); } - public void testRegisterModelUrlNotAllowed() throws Exception { - settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false).build(); - ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings, mlFeatureEnabledSetting); - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule - .expectMessage( - "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models." - ); - RestRequest request = getRestRequest(); - restMLRegisterModelAction.handleRequest(request, channel, client); - } - public void testRegisterModelRequestWithNullUrlAndUrlNotAllowed() throws Exception { settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false).build(); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL);