From cd95e2976bb6f6bcee7b9b48a28ec167be2ddca1 Mon Sep 17 00:00:00 2001 From: Paolo Di Tommaso Date: Fri, 1 Apr 2022 23:17:58 +0200 Subject: [PATCH] Add Aws Batch native retry on spot reclaim --- docs/config.rst | 1 + .../aws/batch/AwsBatchTaskHandler.groovy | 57 +++++++------------ .../cloud/aws/batch/AwsOptions.groovy | 5 ++ .../aws/batch/AwsBatchTaskHandlerTest.groovy | 38 ++++--------- 4 files changed, 38 insertions(+), 63 deletions(-) diff --git a/docs/config.rst b/docs/config.rst index e7af2d76b0..0215e068b6 100644 --- a/docs/config.rst +++ b/docs/config.rst @@ -163,6 +163,7 @@ volumes One or more container mounts. Mounts can be specifie delayBetweenAttempts Delay between download attempts from S3 (default `10 sec`). maxParallelTransfers Max parallel upload/download transfer operations *per job* (default: ``4``). maxTransferAttempts Max number of downloads attempts from S3 (default: `1`). +maxSpotAttempts Max number of execution attempts of a job interrupted by an EC2 spot reclaim event (default: ``5``, requires ``22.04.0`` or later) =========================== ================ diff --git a/plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsBatchTaskHandler.groovy b/plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsBatchTaskHandler.groovy index a10ded29af..e32a5d43e9 100644 --- a/plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsBatchTaskHandler.groovy +++ b/plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsBatchTaskHandler.groovy @@ -17,11 +17,10 @@ package nextflow.cloud.aws.batch -import static AwsContainerOptionsMapper.* +import static nextflow.cloud.aws.batch.AwsContainerOptionsMapper.* import java.nio.file.Path import java.nio.file.Paths -import java.util.regex.Pattern import com.amazonaws.services.batch.AWSBatch import com.amazonaws.services.batch.model.AWSBatchException @@ -33,6 +32,7 @@ import com.amazonaws.services.batch.model.DescribeJobDefinitionsRequest import com.amazonaws.services.batch.model.DescribeJobDefinitionsResult import 
com.amazonaws.services.batch.model.DescribeJobsRequest import com.amazonaws.services.batch.model.DescribeJobsResult +import com.amazonaws.services.batch.model.EvaluateOnExit import com.amazonaws.services.batch.model.Host import com.amazonaws.services.batch.model.JobDefinition import com.amazonaws.services.batch.model.JobDefinitionType @@ -53,14 +53,12 @@ import groovy.transform.CompileStatic import groovy.util.logging.Slf4j import nextflow.cloud.types.CloudMachineInfo import nextflow.container.ContainerNameValidator -import nextflow.exception.NodeTerminationException import nextflow.exception.ProcessSubmitException import nextflow.exception.ProcessUnrecoverableException import nextflow.executor.BashWrapperBuilder import nextflow.executor.res.AcceleratorResource import nextflow.processor.BatchContext import nextflow.processor.BatchHandler -import nextflow.processor.ErrorStrategy import nextflow.processor.TaskBean import nextflow.processor.TaskHandler import nextflow.processor.TaskRun @@ -74,8 +72,6 @@ import nextflow.util.CacheHelper @Slf4j class AwsBatchTaskHandler extends TaskHandler implements BatchHandler { - private static Pattern TERMINATED = ~/^Host EC2 .* terminated.*/ - private final Path exitFile private final Path wrapperFile @@ -108,8 +104,6 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler environment - private boolean batchNativeRetry - final static private Map jobDefinitions = [:] /** @@ -256,23 +250,15 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler0 ) { + final retry = new RetryStrategy() + .withAttempts( attempts ) + .withEvaluateOnExit( new EvaluateOnExit().withOnReason('Host EC2*').withAction('RETRY') ) result.setRetryStrategy(retry) - this.batchNativeRetry = true } // set task timeout diff --git a/plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsOptions.groovy b/plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsOptions.groovy index de29f0112b..513a75df3e 100644 --- 
a/plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsOptions.groovy +++ b/plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsOptions.groovy @@ -41,6 +41,8 @@ class AwsOptions implements CloudTransferOptions { public static final int DEFAULT_AWS_MAX_ATTEMPTS = 5 + public static final int DEFAULT_MAX_SPOT_ATTEMPTS = 5 + private Map env = System.getenv() String cliPath @@ -61,6 +63,8 @@ class AwsOptions implements CloudTransferOptions { String retryMode + int maxSpotAttempts + volatile Boolean fetchInstanceType /** @@ -93,6 +97,7 @@ class AwsOptions implements CloudTransferOptions { maxParallelTransfers = session.config.navigate('aws.batch.maxParallelTransfers', MAX_TRANSFER) as int maxTransferAttempts = session.config.navigate('aws.batch.maxTransferAttempts', defaultMaxTransferAttempts()) as int delayBetweenAttempts = session.config.navigate('aws.batch.delayBetweenAttempts', DEFAULT_DELAY_BETWEEN_ATTEMPTS) as Duration + maxSpotAttempts = session.config.navigate('aws.batch.maxSpotAttempts', DEFAULT_MAX_SPOT_ATTEMPTS) as int region = session.config.navigate('aws.region') as String volumes = makeVols(session.config.navigate('aws.batch.volumes')) jobRole = session.config.navigate('aws.batch.jobRole') diff --git a/plugins/nf-amazon/src/test/nextflow/cloud/aws/batch/AwsBatchTaskHandlerTest.groovy b/plugins/nf-amazon/src/test/nextflow/cloud/aws/batch/AwsBatchTaskHandlerTest.groovy index 6581af1417..2d15f52791 100644 --- a/plugins/nf-amazon/src/test/nextflow/cloud/aws/batch/AwsBatchTaskHandlerTest.groovy +++ b/plugins/nf-amazon/src/test/nextflow/cloud/aws/batch/AwsBatchTaskHandlerTest.groovy @@ -25,6 +25,7 @@ import com.amazonaws.services.batch.model.DescribeJobDefinitionsRequest import com.amazonaws.services.batch.model.DescribeJobDefinitionsResult import com.amazonaws.services.batch.model.DescribeJobsRequest import com.amazonaws.services.batch.model.DescribeJobsResult +import com.amazonaws.services.batch.model.EvaluateOnExit import 
com.amazonaws.services.batch.model.JobDefinition import com.amazonaws.services.batch.model.JobDetail import com.amazonaws.services.batch.model.KeyValuePair @@ -36,7 +37,6 @@ import com.amazonaws.services.batch.model.SubmitJobResult import com.amazonaws.services.batch.model.TerminateJobRequest import nextflow.cloud.types.CloudMachineInfo import nextflow.cloud.types.PriceModel -import nextflow.exception.NodeTerminationException import nextflow.exception.ProcessUnrecoverableException import nextflow.executor.Executor import nextflow.processor.BatchContext @@ -84,6 +84,7 @@ class AwsBatchTaskHandlerTest extends Specification { when: def req = handler.newSubmitRequest(task) then: + 1 * handler.maxSpotAttempts() >> 5 1 * handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws') } 1 * handler.getJobQueue(task) >> 'queue1' 1 * handler.getJobDefinition(task) >> 'job-def:1' @@ -98,11 +99,12 @@ class AwsBatchTaskHandlerTest extends Specification { req.getContainerOverrides().getResourceRequirements().find { it.type=='MEMORY'}.getValue() == '8192' req.getContainerOverrides().getEnvironment() == [VAR_FOO, VAR_BAR] req.getContainerOverrides().getCommand() == ['bash', '-o','pipefail','-c', "trap \"{ ret=\$?; /bin/aws s3 cp --only-show-errors .command.log s3://bucket/test/.command.log||true; exit \$ret; }\" EXIT; /bin/aws s3 cp --only-show-errors s3://bucket/test/.command.run - | bash 2>&1 | tee .command.log".toString()] - req.getRetryStrategy() == null // <-- retry is managed by NF, hence this must be null + req.getRetryStrategy() == new RetryStrategy().withAttempts(5).withEvaluateOnExit( new EvaluateOnExit().withAction('RETRY').withOnReason('Host EC2*') ) when: req = handler.newSubmitRequest(task) then: + 1 * handler.maxSpotAttempts() >> 0 1 * handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws', region: 'eu-west-1') } 1 * handler.getJobQueue(task) >> 'queue1' 1 * handler.getJobDefinition(task) >> 'job-def:1' @@ -135,6 +137,7 @@ class AwsBatchTaskHandlerTest 
extends Specification { then: handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws', region: 'eu-west-1') } and: + 1 * handler.maxSpotAttempts() >> 0 1 * handler.getJobQueue(task) >> 'queue1' 1 * handler.getJobDefinition(task) >> 'job-def:1' and: @@ -160,6 +163,7 @@ class AwsBatchTaskHandlerTest extends Specification { task.getConfig() >> new TaskConfig() handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws') } and: + 1 * handler.maxSpotAttempts() >> 0 1 * handler.getJobQueue(task) >> 'queue1' 1 * handler.getJobDefinition(task) >> 'job-def:1' and: @@ -176,6 +180,7 @@ class AwsBatchTaskHandlerTest extends Specification { task.getConfig() >> new TaskConfig(time: '5 sec') handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws') } and: + 1 * handler.maxSpotAttempts() >> 0 1 * handler.getJobQueue(task) >> 'queue2' 1 * handler.getJobDefinition(task) >> 'job-def:2' and: @@ -193,6 +198,7 @@ class AwsBatchTaskHandlerTest extends Specification { task.getConfig() >> new TaskConfig(time: '1 hour') handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws') } and: + 1 * handler.maxSpotAttempts() >> 0 1 * handler.getJobQueue(task) >> 'queue3' 1 * handler.getJobDefinition(task) >> 'job-def:3' and: @@ -221,6 +227,7 @@ class AwsBatchTaskHandlerTest extends Specification { then: handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws', retryMode: 'adaptive', maxTransferAttempts: 10) } and: + 1 * handler.maxSpotAttempts() >> 3 1 * handler.getJobQueue(task) >> 'queue1' 1 * handler.getJobDefinition(task) >> 'job-def:1' 1 * handler.wrapperFile >> Paths.get('/bucket/test/.command.run') @@ -230,7 +237,7 @@ class AwsBatchTaskHandlerTest extends Specification { req.getJobQueue() == 'queue1' req.getJobDefinition() == 'job-def:1' // no error `retry` error strategy is defined by NF, use `maxRetries` to se Batch attempts - req.getRetryStrategy() == new RetryStrategy().withAttempts(3) + req.getRetryStrategy() == new 
RetryStrategy().withAttempts(3).withEvaluateOnExit( new EvaluateOnExit().withAction('RETRY').withOnReason('Host EC2*') ) req.getContainerOverrides().getEnvironment() == [VAR_RETRY_MODE, VAR_MAX_ATTEMPTS, VAR_METADATA_ATTEMPTS] } @@ -727,29 +734,4 @@ class AwsBatchTaskHandlerTest extends Specification { trace.machineInfo.priceModel == PriceModel.spot } - def 'should check spot termination' () { - given: - def JOB_ID = 'job-2' - def client = Mock(AWSBatch) - def task = new TaskRun() - def handler = Spy(AwsBatchTaskHandler) - handler.client = client - handler.jobId = JOB_ID - handler.task = task - and: - handler.isRunning() >> true - handler.describeJob(JOB_ID) >> Mock(JobDetail) { - getStatus() >> 'FAILED' - getStatusReason() >> "Host EC2 (instance i-0e2d5c2edc932b4e8) terminated." - } - - when: - def done = handler.checkIfCompleted() - then: - task.aborted - task.error instanceof NodeTerminationException - and: - done == true - - } }