Skip to content

Commit

Permalink
Add Platform workflow prefix in AWS Batch job names (#5318) [e2e test]
Browse files Browse the repository at this point in the history
Signed-off-by: Paolo Di Tommaso <[email protected]>
  • Loading branch information
pditommaso committed Oct 2, 2024
1 parent d687033 commit 42dd4ba
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -273,4 +273,21 @@ abstract class TaskHandler {
return true
return false
}

/**
* Prepend the workflow Id to the job/task name. The workflow id is defined
* by the environment variable {@code TOWER_WORKFLOW_ID}
*
* @param name
* The desired job name
* @param env
* A map representing the variables in the host environment
* @return
* The job having the prefix {@code tw-<ID>} when the variable {@code TOWER_WORKFLOW_ID}
* is defined in the host environment or just {@code name} otherwise
*/
static String prependWorkflowPrefix(String name, Map<String,String> env) {
final workflowId = env.get("TOWER_WORKFLOW_ID")
return workflowId ? "tw-${workflowId}-${name}" : name
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -249,4 +249,18 @@ class TaskHandlerTest extends Specification {
TaskStatus.RUNNING | false | false | true | true | false
TaskStatus.COMPLETED| false | false | false | false | true
}

@Unroll
def 'should include the tower prefix'() {
given:
def name = 'job_1'

expect:
TaskHandler.prependWorkflowPrefix(name, ENV) == EXPECTED

where:
ENV | EXPECTED
[:] | "job_1"
[TOWER_WORKFLOW_ID: '1234'] | "tw-1234-job_1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import groovy.transform.CompileStatic
import groovy.transform.Memoized
import groovy.util.logging.Slf4j
import nextflow.BuildInfo
import nextflow.SysEnv
import nextflow.cloud.types.CloudMachineInfo
import nextflow.container.ContainerNameValidator
import nextflow.exception.ProcessException
Expand Down Expand Up @@ -111,7 +112,7 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job

private CloudMachineInfo machineInfo

private Map<String,String> environment
private Map<String,String> environment = Map<String,String>.of()

final static private Map<String,String> jobDefinitions = [:]

Expand All @@ -133,7 +134,7 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
super(task)
this.executor = executor
this.client = executor.client
this.environment = System.getenv()
this.environment = SysEnv.get()
this.logFile = task.workDir.resolve(TaskRun.CMD_LOG)
this.scriptFile = task.workDir.resolve(TaskRun.CMD_SCRIPT)
this.inputFile = task.workDir.resolve(TaskRun.CMD_INFILE)
Expand Down Expand Up @@ -698,6 +699,11 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
return executor.awsOptions.maxSpotAttempts
}

protected String getJobName(TaskRun task) {
final result = prependWorkflowPrefix(task.name, environment)
return normalizeJobName(result)
}

/**
* Create a new Batch job request for the given NF {@link TaskRun}
*
Expand All @@ -712,7 +718,7 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
final opts = getAwsOptions()
final labels = task.config.getResourceLabels()
final result = new SubmitJobRequest()
result.setJobName(normalizeJobName(task.name))
result.setJobName(getJobName(task))
result.setJobQueue(getJobQueue(task))
result.setJobDefinition(getJobDefinition(task))
if( labels ) {
Expand Down Expand Up @@ -838,11 +844,10 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
* @return A job name without invalid characters
*/
protected String normalizeJobName(String name) {
def result = name.replaceAll(' ','_').replaceAll(/[^a-zA-Z0-9_]/,'')
def result = name.replaceAll(' ','_').replaceAll(/[^a-zA-Z0-9_-]/,'')
result.size()>128 ? result.substring(0,128) : result
}


protected CloudMachineInfo getMachineInfo() {
if( machineInfo )
return machineInfo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class AwsBatchTaskHandlerTest extends Specification {
expect:
handler.normalizeJobName('foo') == 'foo'
handler.normalizeJobName('foo (12)') == 'foo_12'
handler.normalizeJobName('foo-12') == 'foo-12'

when:
def looong = '012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789'
Expand Down Expand Up @@ -97,7 +98,7 @@ class AwsBatchTaskHandlerTest extends Specification {
1 * handler.getJobDefinition(task) >> 'job-def:1'
1 * handler.getEnvironmentVars() >> [VAR_FOO, VAR_BAR]

req.getJobName() == 'batchtask'
req.getJobName() == 'batch-task'
req.getJobQueue() == 'queue1'
req.getJobDefinition() == 'job-def:1'
req.getContainerOverrides().getResourceRequirements().find { it.type=='VCPU'}.getValue() == '4'
Expand All @@ -118,7 +119,7 @@ class AwsBatchTaskHandlerTest extends Specification {
1 * handler.getJobDefinition(task) >> 'job-def:1'
1 * handler.getEnvironmentVars() >> [VAR_FOO, VAR_BAR]

req.getJobName() == 'batchtask'
req.getJobName() == 'batch-task'
req.getJobQueue() == 'queue1'
req.getJobDefinition() == 'job-def:1'
req.getContainerOverrides().getResourceRequirements().find { it.type=='VCPU'}.getValue() == '4'
Expand Down Expand Up @@ -148,7 +149,7 @@ class AwsBatchTaskHandlerTest extends Specification {
1 * handler.getJobDefinition(task) >> 'job-def:1'
1 * handler.getEnvironmentVars() >> []

req.getJobName() == 'batchtask'
req.getJobName() == 'batch-task'
req.getJobQueue() == 'queue1'
req.getJobDefinition() == 'job-def:1'
req.getContainerOverrides().getResourceRequirements().find { it.type=='VCPU'}.getValue() == '4'
Expand All @@ -165,7 +166,7 @@ class AwsBatchTaskHandlerTest extends Specification {
1 * handler.getJobDefinition(task) >> 'job-def:1'
1 * handler.getEnvironmentVars() >> []

req2.getJobName() == 'batchtask'
req2.getJobName() == 'batch-task'
req2.getJobQueue() == 'queue1'
req2.getJobDefinition() == 'job-def:1'
req2.getContainerOverrides().getResourceRequirements().find { it.type=='VCPU'}.getValue() == '4'
Expand Down Expand Up @@ -232,7 +233,7 @@ class AwsBatchTaskHandlerTest extends Specification {
1 * handler.getJobQueue(task) >> 'queue1'
1 * handler.getJobDefinition(task) >> 'job-def:1'
and:
req.getJobName() == 'batchtask'
req.getJobName() == 'batch-task'
req.getJobQueue() == 'queue1'
req.getJobDefinition() == 'job-def:1'
req.getTimeout() == null
Expand All @@ -249,7 +250,7 @@ class AwsBatchTaskHandlerTest extends Specification {
1 * handler.getJobQueue(task) >> 'queue2'
1 * handler.getJobDefinition(task) >> 'job-def:2'
and:
req.getJobName() == 'batchtask'
req.getJobName() == 'batch-task'
req.getJobQueue() == 'queue2'
req.getJobDefinition() == 'job-def:2'
// minimal allowed timeout is 60 seconds
Expand All @@ -268,7 +269,7 @@ class AwsBatchTaskHandlerTest extends Specification {
1 * handler.getJobQueue(task) >> 'queue3'
1 * handler.getJobDefinition(task) >> 'job-def:3'
and:
req.getJobName() == 'batchtask'
req.getJobName() == 'batch-task'
req.getJobQueue() == 'queue3'
req.getJobDefinition() == 'job-def:3'
// minimal allowed timeout is 60 seconds
Expand Down Expand Up @@ -299,7 +300,7 @@ class AwsBatchTaskHandlerTest extends Specification {
1 * handler.getJobQueue(task) >> 'queue1'
1 * handler.getJobDefinition(task) >> 'job-def:1'
and:
req.getJobName() == 'batchtask'
req.getJobName() == 'batch-task'
req.getJobQueue() == 'queue1'
req.getJobDefinition() == 'job-def:1'
// no error `retry` error strategy is defined by NF, use `maxRetries` to se Batch attempts
Expand Down Expand Up @@ -1013,7 +1014,7 @@ class AwsBatchTaskHandlerTest extends Specification {
1 * handler.getJobDefinition(task) >> 'job-def:1'
1 * handler.getEnvironmentVars() >> [VAR_FOO, VAR_BAR]

req.getJobName() == 'batchtask'
req.getJobName() == 'batch-task'
req.getJobQueue() == 'queue1'
req.getJobDefinition() == 'job-def:1'
req.getContainerOverrides().getResourceRequirements().find { it.type=='VCPU'}.getValue() == '4'
Expand Down Expand Up @@ -1100,4 +1101,25 @@ class AwsBatchTaskHandlerTest extends Specification {
'job1' | 'job1'
'job1:task2' | 'job1'
}

def 'should get job name' () {
given:
def handler = Spy(new AwsBatchTaskHandler(environment: ENV))
def task = Mock(TaskRun)

when:
def result = handler.getJobName(task)
then:
task.getName() >> NAME
and:
result == EXPECTED

where:
ENV | NAME | EXPECTED
[:] | 'foo' | 'foo'
[TOWER_WORKFLOW_ID: '12345'] | 'foo' | 'tw-12345-foo'
[TOWER_WORKFLOW_ID: '12345'] | 'foo' | 'tw-12345-foo'
[TOWER_WORKFLOW_ID: '12345'] | 'foo(12)' | 'tw-12345-foo12'

}
}

0 comments on commit 42dd4ba

Please sign in to comment.