diff --git a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java
index 6fd03b7b52..6df22475e9 100644
--- a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java
+++ b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java
@@ -11,6 +11,7 @@
 import static org.opensearch.ml.common.MLTaskState.FAILED;
 import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
 import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
+import static org.opensearch.ml.utils.MLExceptionUtils.OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG;
 import java.time.Instant;
 import java.util.List;
@@ -35,6 +36,7 @@
 import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
 import org.opensearch.ml.engine.MLEngineClassLoader;
 import org.opensearch.ml.engine.ingest.Ingestable;
+import org.opensearch.ml.settings.MLFeatureEnabledSetting;
 import org.opensearch.ml.task.MLTaskManager;
 import org.opensearch.ml.utils.MLExceptionUtils;
 import org.opensearch.tasks.Task;
@@ -55,6 +57,7 @@ public class TransportBatchIngestionAction extends HandledTransportAction
     remoteJobStatusFields;
     volatile Pattern remoteJobCompletedStatusRegexPattern;
@@ -111,6 +112,7 @@ public GetTaskTransportAction(
         EncryptorImpl encryptor,
         MLTaskManager mlTaskManager,
         MLModelManager mlModelManager,
+        MLFeatureEnabledSetting mlFeatureEnabledSetting,
         Settings settings
     ) {
         super(MLTaskGetAction.NAME, transportService, actionFilters, MLTaskGetRequest::new);
@@ -122,6 +124,7 @@ public GetTaskTransportAction(
         this.encryptor = encryptor;
         this.mlTaskManager = mlTaskManager;
         this.mlModelManager = mlModelManager;
+        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
         remoteJobStatusFields = ML_COMMONS_REMOTE_JOB_STATUS_FIELD.get(settings);
         clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_JOB_STATUS_FIELD, it -> remoteJobStatusFields = it);
@@ -178,6 +181,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener> getSettings() {
             MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX,
             MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX,
             MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX,
-            MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED
+            MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED,
+            MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED,
+            MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED
         );
         return settings;
     }
diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetTaskAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetTaskAction.java
index aeee474864..5284d42a32 100644
--- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetTaskAction.java
+++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetTaskAction.java
@@ -24,7 +24,7 @@
 import com.google.common.collect.ImmutableList;
 public class RestMLGetTaskAction extends BaseRestHandler {
-    private static final String ML_GET_Task_ACTION = "ml_get_task_action";
+    private static final String ML_GET_TASK_ACTION = "ml_get_task_action";
     /**
      * Constructor
      */
@@ -33,7 +33,7 @@ public RestMLGetTaskAction() {}
     @Override
     public String getName() {
-        return ML_GET_Task_ACTION;
+        return ML_GET_TASK_ACTION;
     }
     @Override
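The transport-action and REST changes in this PR all apply the same guard: the injected MLFeatureEnabledSetting is consulted at the top of the request path and the call is rejected with the shared error message when the flag is off. A minimal, self-contained sketch of that pattern follows; FeatureFlags and the Consumer-based failure callback are stand-ins so the snippet compiles on its own, not the actual ml-commons classes.

```java
import java.util.function.Consumer;

// Stand-in for MLFeatureEnabledSetting: exposes the dynamically updated flag.
interface FeatureFlags {
    boolean isOfflineBatchIngestionEnabled();
}

// Stand-in for the transport action; onFailure plays the role of ActionListener#onFailure in OpenSearch.
final class GuardedBatchIngestAction {
    static final String DISABLED_MSG = "Offline batch ingestion is currently disabled. "
        + "To enable it, update the setting \"plugins.ml_commons.offline_batch_ingestion_enabled\" to true.";

    private final FeatureFlags flags;

    GuardedBatchIngestAction(FeatureFlags flags) {
        this.flags = flags;
    }

    void doExecute(Runnable ingestWork, Consumer<Exception> onFailure) {
        // Guard clause: reject the request before any task bookkeeping happens.
        if (!flags.isOfflineBatchIngestionEnabled()) {
            onFailure.accept(new IllegalStateException(DISABLED_MSG));
            return;
        }
        ingestWork.run();
    }
}
```

The new unit test for TransportBatchIngestionAction asserts exactly this behavior by capturing the IllegalStateException handed to the listener.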
diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java
index 72b841eb7b..68c0146ab2 100644
--- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java
+++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java
@@ -7,6 +7,7 @@
 import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
+import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
 import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
 import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
 import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
@@ -131,6 +132,8 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest
             throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
         } else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
             throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
+        } else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
+            throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
         } else if (!ActionType.isValidActionInModelPrediction(actionType)) {
             throw new IllegalArgumentException("Wrong action type in the rest request path!");
         }
diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java
index 4b034c035a..5b0e110d52 100644
--- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java
+++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java
@@ -136,6 +136,12 @@ private MLCommonsSettings() {}
     public static final Setting<Boolean> ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED = Setting
         .boolSetting("plugins.ml_commons.connector_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
+    public static final Setting<Boolean> ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED = Setting
+        .boolSetting("plugins.ml_commons.offline_batch_ingestion_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
+
+    public static final Setting<Boolean> ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED = Setting
+        .boolSetting("plugins.ml_commons.offline_batch_inference_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
+
     public static final Setting<List<String>> ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX = Setting
         .listSetting(
             "plugins.ml_commons.trusted_connector_endpoints_regex",
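Both new settings are registered with NodeScope and Dynamic, so they can be flipped at runtime through the cluster settings API without a node restart. The sketch below uses the core transport Client; everything except the two setting keys is illustrative, and an equivalent `PUT _cluster/settings` REST call achieves the same thing.

```java
import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;

public final class BatchFeatureToggles {

    private BatchFeatureToggles() {}

    /** Disable both offline batch features cluster-wide until they are re-enabled. */
    public static void disableOfflineBatch(Client client) {
        ClusterUpdateSettingsRequest request = new ClusterUpdateSettingsRequest()
            .persistentSettings(
                Settings.builder()
                    .put("plugins.ml_commons.offline_batch_ingestion_enabled", false)
                    .put("plugins.ml_commons.offline_batch_inference_enabled", false)
            );
        // actionGet() blocks for brevity; production code would pass an ActionListener instead.
        client.admin().cluster().updateSettings(request).actionGet();
    }
}
```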
diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java
index 0fd11e7c72..93159125de 100644
--- a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java
+++ b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java
@@ -11,6 +11,8 @@
 import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED;
 import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED;
 import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
+import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED;
+import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED;
 import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -27,6 +29,8 @@ public class MLFeatureEnabledSetting {
     private volatile AtomicBoolean isConnectorPrivateIpEnabled;
     private volatile Boolean isControllerEnabled;
+    private volatile Boolean isBatchIngestionEnabled;
+    private volatile Boolean isBatchInferenceEnabled;
     public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
         isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
@@ -34,6 +38,8 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
         isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings);
         isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings));
         isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings);
+        isBatchIngestionEnabled = ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED.get(settings);
+        isBatchInferenceEnabled = ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED.get(settings);
         clusterService
             .getClusterSettings()
@@ -46,6 +52,12 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
             .getClusterSettings()
             .addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled.set(it));
         clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_CONTROLLER_ENABLED, it -> isControllerEnabled = it);
+        clusterService
+            .getClusterSettings()
+            .addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED, it -> isBatchIngestionEnabled = it);
+        clusterService
+            .getClusterSettings()
+            .addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED, it -> isBatchInferenceEnabled = it);
     }
     /**
@@ -84,4 +96,19 @@ public Boolean isControllerEnabled() {
         return isControllerEnabled;
     }
+    /**
+     * Whether the offline batch ingestion is enabled. If disabled, APIs in ml-commons will block offline batch ingestion.
+     * @return whether the feature is enabled.
+     */
+    public Boolean isOfflineBatchIngestionEnabled() {
+        return isBatchIngestionEnabled;
+    }
+
+    /**
+     * Whether the offline batch inference is enabled. If disabled, APIs in ml-commons will block offline batch inference.
+     * @return whether the feature is enabled.
+     */
+    public Boolean isOfflineBatchInferenceEnabled() {
+        return isBatchInferenceEnabled;
+    }
 }
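MLFeatureEnabledSetting stores each flag in a volatile field and reassigns it from a settings-update consumer; volatile matters because the consumer runs on the cluster-applier thread while transport actions read the flag from request threads. A condensed sketch of that pattern is below, where the registrar parameter stands in for ClusterSettings#addSettingsUpdateConsumer (only Setting and Settings are the real OpenSearch classes):

```java
import java.util.function.Consumer;

import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;

final class BatchIngestionFlag {

    static final Setting<Boolean> OFFLINE_BATCH_INGESTION_ENABLED = Setting
        .boolSetting("plugins.ml_commons.offline_batch_ingestion_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

    // volatile: written by the settings-update consumer, read by request threads.
    private volatile boolean enabled;

    BatchIngestionFlag(Settings initialSettings, Consumer<Consumer<Boolean>> registrar) {
        this.enabled = OFFLINE_BATCH_INGESTION_ENABLED.get(initialSettings);
        // In ml-commons this call is clusterService.getClusterSettings().addSettingsUpdateConsumer(...).
        registrar.accept(newValue -> this.enabled = newValue);
    }

    boolean isEnabled() {
        return enabled;
    }
}
```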
diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java
index 5340edba0f..7a056c762c 100644
--- a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java
+++ b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java
@@ -22,10 +22,14 @@ public class MLExceptionUtils {
         "Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true.";
     public static final String LOCAL_MODEL_DISABLED_ERR_MSG =
         "Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.";
+    public static final String BATCH_INFERENCE_DISABLED_ERR_MSG =
+        "Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true.";
     public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG =
         "Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true.";
     public static final String CONTROLLER_DISABLED_ERR_MSG =
         "Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true.";
+    public static final String OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG =
+        "Offline batch ingestion is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_ingestion_enabled\" to true.";
     public static String getRootCauseMessage(final Throwable throwable) {
         String message = ExceptionUtils.getRootCauseMessage(throwable);
diff --git a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java
index 092edfe951..f1ab6715f6 100644
--- a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java
@@ -46,6 +46,7 @@
 import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
 import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest;
 import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
+import org.opensearch.ml.settings.MLFeatureEnabledSetting;
 import org.opensearch.ml.task.MLTaskManager;
 import org.opensearch.tasks.Task;
 import org.opensearch.test.OpenSearchTestCase;
@@ -73,6 +74,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
     ThreadPool threadPool;
     @Mock
     ExecutorService executorService;
+    @Mock
+    private MLFeatureEnabledSetting mlFeatureEnabledSetting;
     private TransportBatchIngestionAction batchAction;
     private MLBatchIngestionInput batchInput;
@@ -81,7 +84,14 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
     @Before
     public void setup() {
         MockitoAnnotations.openMocks(this);
-        batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager, threadPool);
+        batchAction = new TransportBatchIngestionAction(
+            transportService,
+            actionFilters,
+            client,
+            mlTaskManager,
+            threadPool,
+            mlFeatureEnabledSetting
+        );
         Map fieldMap = new HashMap<>();
         fieldMap.put("chapter", "$.content[0]");
@@ -106,6 +116,8 @@ public void setup() {
             .dataSources(dataSource)
             .build();
         when(mlBatchIngestionRequest.getMlBatchIngestionInput()).thenReturn(batchInput);
+
+        when(mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()).thenReturn(true);
     }
     public void test_doExecute_success() {
@@ -181,6 +193,18 @@ public void test_doExecute_handleSuccessRate0() {
         );
     }
+    public void test_doExecute_batchIngestionDisabled() {
+        when(mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()).thenReturn(false);
+        batchAction.doExecute(task, mlBatchIngestionRequest, actionListener);
+
+        ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
+        verify(actionListener).onFailure(argumentCaptor.capture());
+        assertEquals(
+            "Offline batch ingestion is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_ingestion_enabled\" to true.",
+            argumentCaptor.getValue().getMessage()
+        );
+    }
+
     public void test_doExecute_noDataSource() {
         MLBatchIngestionInput batchInput = MLBatchIngestionInput
             .builder()
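The two task test classes below exercise the same rule from the read path: once the stored task has been fetched, a BATCH_PREDICTION task may only report or cancel its remote job while offline batch inference is enabled. Roughly, with stand-in types (the real check lives in the transport actions' doExecute and uses the ml-commons MLTask and ActionListener types):

```java
final class BatchJobStatusGuard {

    enum TaskType { PREDICTION, BATCH_PREDICTION, BATCH_INGEST }

    static final String DISABLED_MSG = "Offline Batch Inference is currently disabled. "
        + "To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true.";

    /** Rejects status/cancel requests for batch-prediction tasks while the feature flag is off. */
    static void checkBatchInferenceAllowed(TaskType taskType, boolean offlineBatchInferenceEnabled) {
        if (taskType == TaskType.BATCH_PREDICTION && !offlineBatchInferenceEnabled) {
            throw new IllegalStateException(DISABLED_MSG);
        }
    }
}
```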
diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java
index 99d9fbf8a1..0c6939ea77 100644
--- a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java
@@ -61,6 +61,7 @@
 import org.opensearch.ml.engine.encryptor.EncryptorImpl;
 import org.opensearch.ml.helper.ConnectorAccessControlHelper;
 import org.opensearch.ml.model.MLModelManager;
+import org.opensearch.ml.settings.MLFeatureEnabledSetting;
 import org.opensearch.ml.task.MLTaskManager;
 import org.opensearch.script.ScriptService;
 import org.opensearch.test.OpenSearchTestCase;
@@ -106,6 +107,9 @@ public class CancelBatchJobTransportActionTests extends OpenSearchTestCase {
     @Mock
     private MLTaskManager mlTaskManager;
+    @Mock
+    private MLFeatureEnabledSetting mlFeatureEnabledSetting;
+
     @Rule
     public ExpectedException exceptionRule = ExpectedException.none();
@@ -139,7 +143,8 @@ public void setup() throws IOException {
                 connectorAccessControlHelper,
                 encryptor,
                 mlTaskManager,
-                mlModelManager
+                mlModelManager,
+                mlFeatureEnabledSetting
             )
         );
@@ -182,7 +187,7 @@ public void setup() throws IOException {
             listener.onResponse(connector);
             return null;
         }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any());
-
+        when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(true);
     }
     public void testGetTask_NullResponse() {
@@ -221,6 +226,28 @@ public void testGetTask_IndexNotFoundException() {
         assertEquals("Fail to find task", argumentCaptor.getValue().getMessage());
     }
+    public void testGetTask_FeatureFlagDisabled() throws IOException {
+        Map remoteJob = new HashMap<>();
+        remoteJob.put("Status", "IN PROGRESS");
+        remoteJob.put("TransformJobName", "SM-offline-batch-transform13");
+
+        GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob);
+
+        doAnswer(invocation -> {
+            ActionListener actionListener = invocation.getArgument(1);
+            actionListener.onResponse(getResponse);
+            return null;
+        }).when(client).get(any(), any());
+        when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(false);
+        cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener);
+        ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
+        verify(actionListener).onFailure(argumentCaptor.capture());
+        assertEquals(
+            "Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true.",
+            argumentCaptor.getValue().getMessage()
+        );
+    }
+
     @Ignore
     public void testGetTask_SuccessBatchPredictCancel() throws IOException {
         Map remoteJob = new HashMap<>();
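Each test class follows the same Mockito recipe: stub the flag to true in setup() so the existing tests keep passing, then override it to false inside the new *_FeatureFlagDisabled test before calling doExecute or handleRequest. In isolation the recipe looks like this (FeatureFlags is a stand-in for the mocked MLFeatureEnabledSetting):

```java
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

final class FeatureFlagStubbingExample {

    interface FeatureFlags {
        Boolean isOfflineBatchInferenceEnabled();
    }

    public static void main(String[] args) {
        FeatureFlags flags = mock(FeatureFlags.class);

        // setup(): default every test to "feature enabled".
        when(flags.isOfflineBatchInferenceEnabled()).thenReturn(true);

        // *_FeatureFlagDisabled test: the most recent stubbing wins for that test.
        when(flags.isOfflineBatchInferenceEnabled()).thenReturn(false);

        System.out.println(flags.isOfflineBatchInferenceEnabled()); // prints false
    }
}
```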
diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java
index 1c9a1c449a..25c43eb9b6 100644
--- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java
@@ -71,6 +71,7 @@
 import org.opensearch.ml.engine.encryptor.EncryptorImpl;
 import org.opensearch.ml.helper.ConnectorAccessControlHelper;
 import org.opensearch.ml.model.MLModelManager;
+import org.opensearch.ml.settings.MLFeatureEnabledSetting;
 import org.opensearch.ml.task.MLTaskManager;
 import org.opensearch.script.ScriptService;
 import org.opensearch.test.OpenSearchTestCase;
@@ -116,6 +117,9 @@ public class GetTaskTransportActionTests extends OpenSearchTestCase {
     @Mock
     private MLTaskManager mlTaskManager;
+    @Mock
+    private MLFeatureEnabledSetting mlFeatureEnabledSetting;
+
     @Rule
     public ExpectedException exceptionRule = ExpectedException.none();
@@ -172,6 +176,7 @@ public void setup() throws IOException {
                 encryptor,
                 mlTaskManager,
                 mlModelManager,
+                mlFeatureEnabledSetting,
                 settings
             )
         );
@@ -215,7 +220,7 @@ public void setup() throws IOException {
             listener.onResponse(connector);
             return null;
         }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any());
-
+        when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(true);
     }
     public void testGetTask_NullResponse() {
@@ -299,6 +304,31 @@ public void test_BatchPredictStatus_NoConnector() throws IOException {
         assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage());
     }
+    public void test_BatchPredictStatus_FeatureFlagDisabled() throws IOException {
+        Map remoteJob = new HashMap<>();
+        remoteJob.put("Status", "IN PROGRESS");
+        remoteJob.put("TransformJobName", "SM-offline-batch-transform13");
+
+        when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(false);
+
+        GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob);
+
+        doAnswer(invocation -> {
+            ActionListener actionListener = invocation.getArgument(1);
+            actionListener.onResponse(getResponse);
+            return null;
+        }).when(client).get(any(), any());
+        when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(false);
+
+        getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener);
+        ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
+        verify(actionListener).onFailure(argumentCaptor.capture());
+        assertEquals(
+            "Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true.",
+            argumentCaptor.getValue().getMessage()
+        );
+    }
+
     public void test_BatchPredictStatus_NoAccessToConnector() throws IOException {
         Map remoteJob = new HashMap<>();
         remoteJob.put("Status", "IN PROGRESS");
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java
index 001b3709a8..c90f765ed0 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java
@@ -157,6 +157,7 @@ public void testPrepareRequest() throws Exception {
     public void testPrepareBatchRequest() throws Exception {
         RestRequest request = getBatchRestRequest();
+        when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(true);
         restMLPredictionAction.handleRequest(request, channel, client);
         ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class);
         verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argumentCaptor.capture(), any());
@@ -164,6 +165,18 @@ public void testPrepareBatchRequest() throws Exception {
         verifyParsedBatchMLInput(mlInput);
     }
+    public void testPrepareBatchRequest_FeatureFlagDisabled() throws Exception {
+        thrown.expect(IllegalStateException.class);
+        thrown
+            .expectMessage(
+                "Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true."
+            );
+
+        RestRequest request = getBatchRestRequest();
+        when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(false);
+        restMLPredictionAction.handleRequest(request, channel, client);
+    }
+
     public void testPrepareBatchRequest_WrongActionType() throws Exception {
         thrown.expect(IllegalArgumentException.class);
         thrown.expectMessage("Wrong Action Type");