diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/Ingestable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/Ingestable.java index e020dcdd60..6cfd159258 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/Ingestable.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/Ingestable.java @@ -13,7 +13,7 @@ public interface Ingestable { * @param mlBatchIngestionInput batch ingestion input data * @return successRate (0 - 100) */ - default double ingest(MLBatchIngestionInput mlBatchIngestionInput) { + default double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) { throw new IllegalStateException("Ingest is not implemented"); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java index 8dc94894ef..955db0a038 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java @@ -39,7 +39,7 @@ public OpenAIDataIngestion(Client client) { } @Override - public double ingest(MLBatchIngestionInput mlBatchIngestionInput) { + public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) { List sources = (List) mlBatchIngestionInput.getDataSources().get(SOURCE); if (Objects.isNull(sources) || sources.isEmpty()) { return 100; @@ -48,13 +48,19 @@ public double ingest(MLBatchIngestionInput mlBatchIngestionInput) { boolean isSoleSource = sources.size() == 1; List successRates = Collections.synchronizedList(new ArrayList<>()); for (int sourceIndex = 0; sourceIndex < sources.size(); sourceIndex++) { - successRates.add(ingestSingleSource(sources.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource)); + successRates.add(ingestSingleSource(sources.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource, bulkSize)); } return calculateSuccessRate(successRates); } - private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIngestionInput, int sourceIndex, boolean isSoleSource) { + private double ingestSingleSource( + String fileId, + MLBatchIngestionInput mlBatchIngestionInput, + int sourceIndex, + boolean isSoleSource, + int bulkSize + ) { double successRate = 0; try { String apiKey = mlBatchIngestionInput.getCredential().get(API_KEY); @@ -82,8 +88,8 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn linesBuffer.add(line); lineCount++; - // Process every 100 lines - if (lineCount % 100 == 0) { + // Process every bulkSize lines + if (lineCount % bulkSize == 0) { // Create a CompletableFuture that will be completed by the bulkResponseListener CompletableFuture future = new CompletableFuture<>(); batchIngest( diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java index b6fb3e1226..27aafd72d8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java @@ -53,7 +53,7 @@ public S3DataIngestion(Client client) { } @Override - public double ingest(MLBatchIngestionInput mlBatchIngestionInput) { + public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) { S3Client s3 = initS3Client(mlBatchIngestionInput); List s3Uris = (List) mlBatchIngestionInput.getDataSources().get(SOURCE); @@ -63,7 +63,7 @@ public double ingest(MLBatchIngestionInput mlBatchIngestionInput) { boolean isSoleSource = s3Uris.size() == 1; List successRates = Collections.synchronizedList(new ArrayList<>()); for (int sourceIndex = 0; sourceIndex < s3Uris.size(); sourceIndex++) { - successRates.add(ingestSingleSource(s3, s3Uris.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource)); + successRates.add(ingestSingleSource(s3, s3Uris.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource, bulkSize)); } return calculateSuccessRate(successRates); @@ -74,7 +74,8 @@ public double ingestSingleSource( String s3Uri, MLBatchIngestionInput mlBatchIngestionInput, int sourceIndex, - boolean isSoleSource + boolean isSoleSource, + int bulkSize ) { String bucketName = getS3BucketName(s3Uri); String keyName = getS3KeyName(s3Uri); @@ -99,8 +100,8 @@ public double ingestSingleSource( linesBuffer.add(line); lineCount++; - // Process every 100 lines - if (lineCount % 100 == 0) { + // Process every bulkSize lines + if (lineCount % bulkSize == 0) { // Create a CompletableFuture that will be completed by the bulkResponseListener CompletableFuture future = new CompletableFuture<>(); batchIngest( diff --git a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java index aa78c119f9..cdb32dde84 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.common.MLTaskState.COMPLETED; import static org.opensearch.ml.common.MLTaskState.FAILED; import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; import static org.opensearch.ml.utils.MLExceptionUtils.OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG; @@ -24,12 +25,15 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.transport.batch.MLBatchIngestionAction; import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput; import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest; @@ -60,16 +64,19 @@ public class TransportBatchIngestionAction extends HandledTransportAction batchIngestionBulkSize = it); + } @Override @@ -131,33 +144,44 @@ protected void createMLTaskandExecute(MLBatchIngestionInput mlBatchIngestionInpu .state(MLTaskState.CREATED) .build(); - mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { - String taskId = response.getId(); - try { - mlTask.setTaskId(taskId); - mlTaskManager.add(mlTask); - listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name())); - String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE); - Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class); - threadPool.executor(INGEST_THREAD_POOL).execute(() -> { - executeWithErrorHandling(() -> { - double successRate = ingestable.ingest(mlBatchIngestionInput); - handleSuccessRate(successRate, taskId); - }, taskId); - }); - } catch (Exception ex) { - log.error("Failed in batch ingestion", ex); - mlTaskManager - .updateMLTask( - taskId, - Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)), - TASK_SEMAPHORE_TIMEOUT, - true - ); - listener.onFailure(ex); + mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedLimits -> { + if (exceedLimits) { + String error = "exceed maximum BATCH_INGEST Task limits"; + log.warn(error + " in task " + mlTask.getTaskId()); + listener.onFailure(new MLLimitExceededException(error)); + } else { + mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { + String taskId = response.getId(); + try { + mlTask.setTaskId(taskId); + mlTaskManager.add(mlTask); + listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name())); + String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE); + Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class); + threadPool.executor(INGEST_THREAD_POOL).execute(() -> { + executeWithErrorHandling(() -> { + double successRate = ingestable.ingest(mlBatchIngestionInput, batchIngestionBulkSize); + handleSuccessRate(successRate, taskId); + }, taskId); + }); + } catch (Exception ex) { + log.error("Failed in batch ingestion", ex); + mlTaskManager + .updateMLTask( + taskId, + Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)), + TASK_SEMAPHORE_TIMEOUT, + true + ); + listener.onFailure(ex); + } + }, exception -> { + log.error("Failed to create batch ingestion task", exception); + listener.onFailure(exception); + })); } }, exception -> { - log.error("Failed to create batch ingestion task", exception); + log.error("Failed to check the maximum BATCH_INGEST Task limits", exception); listener.onFailure(exception); })); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 46114881a5..bdbd0687c9 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -40,6 +40,8 @@ import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE; @@ -107,6 +109,7 @@ import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.controller.MLRateLimiter; @@ -177,6 +180,8 @@ public class MLModelManager { private volatile Integer maxModelPerNode; private volatile Integer maxRegisterTasksPerNode; private volatile Integer maxDeployTasksPerNode; + private volatile Integer maxBatchInferenceTasks; + private volatile Integer maxBatchIngestionTasks; public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet .of( @@ -232,6 +237,16 @@ public MLModelManager( clusterService .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it); + + maxBatchInferenceTasks = ML_COMMONS_MAX_BATCH_INFERENCE_TASKS.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_MAX_BATCH_INFERENCE_TASKS, it -> maxBatchInferenceTasks = it); + + maxBatchIngestionTasks = ML_COMMONS_MAX_BATCH_INGESTION_TASKS.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_MAX_BATCH_INGESTION_TASKS, it -> maxBatchIngestionTasks = it); } public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener listener) { @@ -867,6 +882,18 @@ public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) { mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit); } + /** + * Check if exceed batch job task limit + * + * @param mlTask ML task + * @param listener ActionListener if the limit is exceeded + */ + public void checkMaxBatchJobTask(MLTask mlTask, ActionListener listener) { + MLTaskType taskType = mlTask.getTaskType(); + int maxLimit = taskType.equals(MLTaskType.BATCH_PREDICTION) ? maxBatchInferenceTasks : maxBatchIngestionTasks; + mlTaskManager.checkMaxBatchJobTask(taskType, maxLimit, listener); + } + private void updateModelRegisterStateAsDone( MLRegisterModelInput registerModelInput, String taskId, diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 4d5d63acc5..065e0ec371 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -972,7 +972,10 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX, MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED, MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED, - MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED + MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED, + MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS, + MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS, + MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 5b0e110d52..718427949f 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -34,6 +34,15 @@ private MLCommonsSettings() {} Setting.Property.NodeScope, Setting.Property.Dynamic ); + + public static final Setting ML_COMMONS_MAX_BATCH_INFERENCE_TASKS = Setting + .intSetting("plugins.ml_commons.max_batch_inference_tasks", 10, 0, 500, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting ML_COMMONS_MAX_BATCH_INGESTION_TASKS = Setting + .intSetting("plugins.ml_commons.max_batch_ingestion_tasks", 10, 0, 500, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting ML_COMMONS_BATCH_INGESTION_BULK_SIZE = Setting + .intSetting("plugins.ml_commons.batch_ingestion_bulk_size", 500, 100, 100000, Setting.Property.NodeScope, Setting.Property.Dynamic); public static final Setting ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE = Setting .intSetting("plugins.ml_commons.max_deploy_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic); public static final Setting ML_COMMONS_MAX_ML_TASK_PER_NODE = Setting diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index a59d7bbe2b..17b22d82d5 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -55,6 +55,7 @@ import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; @@ -253,6 +254,32 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener { + if (exceedLimits) { + String error = "exceed maximum BATCH_PREDICTION Task limits"; + log.warn(error + " in task " + mlTask.getTaskId()); + listener.onFailure(new MLLimitExceededException(error)); + } else { + executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, listener); + } + }, exception -> { + log.error("Failed to check the maximum BATCH_PREDICTION Task limits", exception); + listener.onFailure(exception); + })); + return; + } + executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, listener); + } + + private void executePredictionByInputDataType( + MLInputDataType inputDataType, + String modelId, + MLInput mlInput, + MLTask mlTask, + FunctionName functionName, + ActionListener listener + ) { switch (inputDataType) { case SEARCH_QUERY: ActionListener dataFrameActionListener = ActionListener.wrap(dataSet -> { diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java index ca5b5b0abb..7204eb6738 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java @@ -8,6 +8,9 @@ import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import static org.opensearch.ml.common.MLTask.LAST_UPDATE_TIME_FIELD; import static org.opensearch.ml.common.MLTask.STATE_FIELD; +import static org.opensearch.ml.common.MLTask.TASK_TYPE_FIELD; +import static org.opensearch.ml.common.MLTaskState.CREATED; +import static org.opensearch.ml.common.MLTaskState.RUNNING; import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.utils.MLExceptionUtils.logException; @@ -25,10 +28,13 @@ import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.client.Requests; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; @@ -36,6 +42,8 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; @@ -43,6 +51,7 @@ import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableMap; @@ -61,7 +70,6 @@ public class MLTaskManager { private final ThreadPool threadPool; private final MLIndicesHandler mlIndicesHandler; private final Map runningTasksCount; - public static final ImmutableSet TASK_DONE_STATES = ImmutableSet .of(MLTaskState.COMPLETED, MLTaskState.COMPLETED_WITH_ERROR, MLTaskState.FAILED, MLTaskState.CANCELLED); @@ -91,14 +99,55 @@ public synchronized void checkLimitAndAddRunningTask(MLTask mlTask, Integer limi throw new MLLimitExceededException(error); } if (contains(mlTask.getTaskId())) { - getMLTask(mlTask.getTaskId()).setState(MLTaskState.RUNNING); + getMLTask(mlTask.getTaskId()).setState(RUNNING); } else { - mlTask.setState(MLTaskState.RUNNING); + mlTask.setState(RUNNING); add(mlTask); } runningTaskCount.incrementAndGet(); } + public synchronized void checkMaxBatchJobTask(MLTaskType mlTaskType, Integer maxTaskLimit, ActionListener listener) { + try { + BoolQueryBuilder boolQuery = QueryBuilders + .boolQuery() + .must(QueryBuilders.termQuery(TASK_TYPE_FIELD, mlTaskType.name())) + .must( + QueryBuilders + .boolQuery() + .should(QueryBuilders.termQuery(STATE_FIELD, CREATED)) + .should(QueryBuilders.termQuery(STATE_FIELD, RUNNING)) + ); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQuery); + SearchRequest searchRequest = new SearchRequest(ML_TASK_INDEX); + searchRequest.source(searchSourceBuilder); + + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(ActionListener.wrap(searchResponse -> { + long matchedCount = searchResponse.getHits().getHits().length; + Boolean exceedLimit = false; + if (matchedCount >= maxTaskLimit) { + exceedLimit = true; + } + listener.onResponse(exceedLimit); + }, exception -> { listener.onFailure(exception); }), () -> threadContext.restore()); + + client.admin().indices().refresh(Requests.refreshRequest(ML_TASK_INDEX), ActionListener.wrap(refreshResponse -> { + client.search(searchRequest, internalListener); + }, e -> { + log.error("Failed to refresh Task index during search MLTaskType for " + mlTaskType, e); + internalListener.onFailure(e); + })); + } catch (Exception e) { + listener.onFailure(e); + } + } catch (Exception e) { + log.error("Failed to search ML task for " + mlTaskType, e); + listener.onFailure(e); + } + } + /** * Put ML task into cache. * If ML task is already in cache, will throw {@link IllegalArgumentException} @@ -140,7 +189,7 @@ public void remove(String taskId) { MLTaskCache taskCache = taskCaches.remove(taskId); MLTask mlTask = taskCache.getMlTask(); - if (mlTask.getState() != MLTaskState.CREATED) { + if (mlTask.getState() != CREATED) { // Task initial state is CREATED. It will move forward to RUNNING state once it starts on worker node. // When finished or failed, it's possible to move to COMPLETED/FAILED state. // So if its state is not CREATED when remove it, the task already started on worker node, we should @@ -205,7 +254,7 @@ public int getRunningTaskCount() { int res = 0; for (Map.Entry entry : taskCaches.entrySet()) { MLTask mlTask = entry.getValue().getMlTask(); - if (mlTask.getState() != null && mlTask.getState() == MLTaskState.RUNNING) { + if (mlTask.getState() != null && mlTask.getState() == RUNNING) { res++; } } @@ -252,9 +301,9 @@ public void updateTaskStateAsRunning(String taskId, boolean isAsyncTask) { throw new IllegalArgumentException("Task not found"); } MLTask task = getMLTask(taskId); - task.setState(MLTaskState.RUNNING); + task.setState(RUNNING); if (isAsyncTask) { - updateMLTask(taskId, ImmutableMap.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false); + updateMLTask(taskId, ImmutableMap.of(STATE_FIELD, RUNNING), TASK_SEMAPHORE_TIMEOUT, false); } } @@ -387,7 +436,7 @@ public List getLocalRunningDeployModelTasks() { List runningDeployModelIds = new ArrayList<>(); for (Map.Entry entry : taskCaches.entrySet()) { MLTask mlTask = entry.getValue().getMlTask(); - if (mlTask.getTaskType() == MLTaskType.DEPLOY_MODEL && mlTask.getState() != MLTaskState.CREATED) { + if (mlTask.getTaskType() == MLTaskType.DEPLOY_MODEL && mlTask.getState() != CREATED) { runningDeployModelTaskIds.add(entry.getKey()); runningDeployModelIds.add(mlTask.getModelId()); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java index 3ad8ba2d07..795d92bbf7 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java @@ -22,7 +22,10 @@ import static org.opensearch.ml.common.MLTaskState.FAILED; import static org.opensearch.ml.engine.ingest.S3DataIngestion.SOURCE; import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.util.ArrayList; import java.util.Arrays; @@ -38,6 +41,9 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; @@ -79,6 +85,9 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase { ExecutorService executorService; @Mock private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock + private ClusterService clusterService; + private Settings settings; private TransportBatchIngestionAction batchAction; private MLBatchIngestionInput batchInput; @@ -89,14 +98,23 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_BATCH_INGESTION_BULK_SIZE.getKey(), 100).build(); + ClusterSettings clusterSettings = clusterSetting( + settings, + ML_COMMONS_BATCH_INGESTION_BULK_SIZE, + ML_COMMONS_MAX_BATCH_INGESTION_TASKS + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); batchAction = new TransportBatchIngestionAction( + clusterService, transportService, actionFilters, client, mlTaskManager, threadPool, mlModelManager, - mlFeatureEnabledSetting + mlFeatureEnabledSetting, + settings ); Map fieldMap = new HashMap<>(); @@ -133,6 +151,11 @@ public void setup() { .build(); when(mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()).thenReturn(true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(false); + return null; + }).when(mlModelManager).checkMaxBatchJobTask(any(MLTask.class), isA(ActionListener.class)); } public void test_doExecute_success() { @@ -149,7 +172,6 @@ public void test_doExecute_success() { runnable.run(); return null; }).when(executorService).execute(any(Runnable.class)); - batchAction.doExecute(task, mlBatchIngestionRequest, actionListener); verify(actionListener).onResponse(any(MLBatchIngestionResponse.class)); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index a379325cc0..8542891c22 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -28,6 +28,9 @@ import static org.opensearch.ml.model.MLModelManager.TIMEOUT_IN_MILLIS; import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE; @@ -194,12 +197,18 @@ public void setup() throws URISyntaxException { settings = Settings.builder().put(ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE.getKey(), 10).build(); settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), 10).build(); settings = Settings.builder().put(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE.getKey(), 10).build(); + settings = Settings.builder().put(ML_COMMONS_MAX_BATCH_INFERENCE_TASKS.getKey(), 10).build(); + settings = Settings.builder().put(ML_COMMONS_MAX_BATCH_INGESTION_TASKS.getKey(), 10).build(); + settings = Settings.builder().put(ML_COMMONS_BATCH_INGESTION_BULK_SIZE.getKey(), 100).build(); ClusterSettings clusterSettings = clusterSetting( settings, ML_COMMONS_MAX_MODELS_PER_NODE, ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE, ML_COMMONS_MONITORING_REQUEST_COUNT, - ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE + ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, + ML_COMMONS_MAX_BATCH_INFERENCE_TASKS, + ML_COMMONS_MAX_BATCH_INGESTION_TASKS, + ML_COMMONS_BATCH_INGESTION_BULK_SIZE ); clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterApplierService)); xContentRegistry = NamedXContentRegistry.EMPTY;