Skip to content

Commit

Permalink
add rate limiting for offline batch jobs, set default bulk size to 500
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Oct 16, 2024
1 parent 09ee93f commit a22e926
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public interface Ingestable {
* @param mlBatchIngestionInput batch ingestion input data
* @return successRate (0 - 100)
*/
default double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
default double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
throw new IllegalStateException("Ingest is not implemented");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public OpenAIDataIngestion(Client client) {
}

@Override
public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
List<String> sources = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
if (Objects.isNull(sources) || sources.isEmpty()) {
return 100;
Expand All @@ -48,13 +48,19 @@ public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
boolean isSoleSource = sources.size() == 1;
List<Double> successRates = Collections.synchronizedList(new ArrayList<>());
for (int sourceIndex = 0; sourceIndex < sources.size(); sourceIndex++) {
successRates.add(ingestSingleSource(sources.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource));
successRates.add(ingestSingleSource(sources.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource, bulkSize));
}

return calculateSuccessRate(successRates);
}

private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIngestionInput, int sourceIndex, boolean isSoleSource) {
private double ingestSingleSource(
String fileId,
MLBatchIngestionInput mlBatchIngestionInput,
int sourceIndex,
boolean isSoleSource,
int bulkSize
) {
double successRate = 0;
try {
String apiKey = mlBatchIngestionInput.getCredential().get(API_KEY);
Expand Down Expand Up @@ -82,8 +88,8 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn
linesBuffer.add(line);
lineCount++;

// Process every 100 lines
if (lineCount % 100 == 0) {
// Process every bulkSize lines
if (lineCount % bulkSize == 0) {
// Create a CompletableFuture that will be completed by the bulkResponseListener
CompletableFuture<Void> future = new CompletableFuture<>();
batchIngest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public S3DataIngestion(Client client) {
}

@Override
public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
S3Client s3 = initS3Client(mlBatchIngestionInput);

List<String> s3Uris = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
Expand All @@ -63,7 +63,7 @@ public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
boolean isSoleSource = s3Uris.size() == 1;
List<Double> successRates = Collections.synchronizedList(new ArrayList<>());
for (int sourceIndex = 0; sourceIndex < s3Uris.size(); sourceIndex++) {
successRates.add(ingestSingleSource(s3, s3Uris.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource));
successRates.add(ingestSingleSource(s3, s3Uris.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource, bulkSize));
}

return calculateSuccessRate(successRates);
Expand All @@ -74,7 +74,8 @@ public double ingestSingleSource(
String s3Uri,
MLBatchIngestionInput mlBatchIngestionInput,
int sourceIndex,
boolean isSoleSource
boolean isSoleSource,
int bulkSize
) {
String bucketName = getS3BucketName(s3Uri);
String keyName = getS3KeyName(s3Uri);
Expand All @@ -99,8 +100,8 @@ public double ingestSingleSource(
linesBuffer.add(line);
lineCount++;

// Process every 100 lines
if (lineCount % 100 == 0) {
// Process every bulkSize lines
if (lineCount % bulkSize == 0) {
// Create a CompletableFuture that will be completed by the bulkResponseListener
CompletableFuture<Void> future = new CompletableFuture<>();
batchIngest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG;

Expand All @@ -24,12 +25,15 @@
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionAction;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest;
Expand Down Expand Up @@ -60,16 +64,19 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
private final Client client;
private ThreadPool threadPool;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
private volatile Integer batchIngestionBulkSize;

@Inject
public TransportBatchIngestionAction(
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
Client client,
MLTaskManager mlTaskManager,
ThreadPool threadPool,
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
MLFeatureEnabledSetting mlFeatureEnabledSetting,
Settings settings
) {
super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
this.transportService = transportService;
Expand All @@ -78,6 +85,12 @@ public TransportBatchIngestionAction(
this.threadPool = threadPool;
this.mlModelManager = mlModelManager;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;

batchIngestionBulkSize = ML_COMMONS_BATCH_INGESTION_BULK_SIZE.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_BATCH_INGESTION_BULK_SIZE, it -> batchIngestionBulkSize = it);

}

@Override
Expand Down Expand Up @@ -131,33 +144,44 @@ protected void createMLTaskandExecute(MLBatchIngestionInput mlBatchIngestionInpu
.state(MLTaskState.CREATED)
.build();

mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
try {
mlTask.setTaskId(taskId);
mlTaskManager.add(mlTask);
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
executeWithErrorHandling(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput);
handleSuccessRate(successRate, taskId);
}, taskId);
});
} catch (Exception ex) {
log.error("Failed in batch ingestion", ex);
mlTaskManager
.updateMLTask(
taskId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
TASK_SEMAPHORE_TIMEOUT,
true
);
listener.onFailure(ex);
mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedLimits -> {
if (exceedLimits) {
String error = "exceed maximum BATCH_INGEST Task limits";
log.warn(error + " in task " + mlTask.getTaskId());
listener.onFailure(new MLLimitExceededException(error));
} else {
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
try {
mlTask.setTaskId(taskId);
mlTaskManager.add(mlTask);
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
executeWithErrorHandling(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput, batchIngestionBulkSize);
handleSuccessRate(successRate, taskId);
}, taskId);
});
} catch (Exception ex) {
log.error("Failed in batch ingestion", ex);
mlTaskManager
.updateMLTask(
taskId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
TASK_SEMAPHORE_TIMEOUT,
true
);
listener.onFailure(ex);
}
}, exception -> {
log.error("Failed to create batch ingestion task", exception);
listener.onFailure(exception);
}));
}
}, exception -> {
log.error("Failed to create batch ingestion task", exception);
log.error("Failed to check the maximum BATCH_INGEST Task limits", exception);
listener.onFailure(exception);
}));
}
Expand Down
27 changes: 27 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE;
Expand Down Expand Up @@ -107,6 +109,7 @@
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.controller.MLController;
import org.opensearch.ml.common.controller.MLRateLimiter;
Expand Down Expand Up @@ -177,6 +180,8 @@ public class MLModelManager {
private volatile Integer maxModelPerNode;
private volatile Integer maxRegisterTasksPerNode;
private volatile Integer maxDeployTasksPerNode;
private volatile Integer maxBatchInferenceTasks;
private volatile Integer maxBatchIngestionTasks;

public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet
.of(
Expand Down Expand Up @@ -232,6 +237,16 @@ public MLModelManager(
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it);

maxBatchInferenceTasks = ML_COMMONS_MAX_BATCH_INFERENCE_TASKS.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_MAX_BATCH_INFERENCE_TASKS, it -> maxBatchInferenceTasks = it);

maxBatchIngestionTasks = ML_COMMONS_MAX_BATCH_INGESTION_TASKS.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_MAX_BATCH_INGESTION_TASKS, it -> maxBatchIngestionTasks = it);
}

public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener<String> listener) {
Expand Down Expand Up @@ -867,6 +882,18 @@ public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
}

/**
* Check if exceed batch job task limit
*
* @param mlTask ML task
* @param listener ActionListener if the limit is exceeded
*/
public void checkMaxBatchJobTask(MLTask mlTask, ActionListener<Boolean> listener) {
MLTaskType taskType = mlTask.getTaskType();
int maxLimit = taskType.equals(MLTaskType.BATCH_PREDICTION) ? maxBatchInferenceTasks : maxBatchIngestionTasks;
mlTaskManager.checkMaxBatchJobTask(taskType, maxLimit, listener);
}

private void updateModelRegisterStateAsDone(
MLRegisterModelInput registerModelInput,
String taskId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,10 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX,
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED,
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED,
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED,
MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS,
MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS,
MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ private MLCommonsSettings() {}
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

public static final Setting<Integer> ML_COMMONS_MAX_BATCH_INFERENCE_TASKS = Setting
.intSetting("plugins.ml_commons.max_batch_inference_tasks", 10, 0, 500, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Integer> ML_COMMONS_MAX_BATCH_INGESTION_TASKS = Setting
.intSetting("plugins.ml_commons.max_batch_ingestion_tasks", 10, 0, 500, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Integer> ML_COMMONS_BATCH_INGESTION_BULK_SIZE = Setting
.intSetting("plugins.ml_commons.batch_ingestion_bulk_size", 500, 100, 100000, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Integer> ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE = Setting
.intSetting("plugins.ml_commons.max_deploy_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Integer> ML_COMMONS_MAX_ML_TASK_PER_NODE = Setting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
Expand Down Expand Up @@ -253,6 +254,32 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTas
.lastUpdateTime(now)
.async(false)
.build();
if (actionType.equals(ActionType.BATCH_PREDICT)) {
mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedLimits -> {
if (exceedLimits) {
String error = "exceed maximum BATCH_PREDICTION Task limits";
log.warn(error + " in task " + mlTask.getTaskId());
listener.onFailure(new MLLimitExceededException(error));
} else {
executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, listener);
}
}, exception -> {
log.error("Failed to check the maximum BATCH_PREDICTION Task limits", exception);
listener.onFailure(exception);
}));
return;
}
executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, listener);
}

private void executePredictionByInputDataType(
MLInputDataType inputDataType,
String modelId,
MLInput mlInput,
MLTask mlTask,
FunctionName functionName,
ActionListener<MLTaskResponse> listener
) {
switch (inputDataType) {
case SEARCH_QUERY:
ActionListener<MLInputDataset> dataFrameActionListener = ActionListener.wrap(dataSet -> {
Expand Down
Loading

0 comments on commit a22e926

Please sign in to comment.