Skip to content

Commit

Permalink
Add flatten operator [ci fast]
Browse files Browse the repository at this point in the history
Signed-off-by: Paolo Di Tommaso <[email protected]>
  • Loading branch information
pditommaso committed Jan 14, 2025
1 parent 2f643f9 commit 84f7806
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 45 deletions.
106 changes: 106 additions & 0 deletions modules/nextflow/src/main/groovy/nextflow/extension/FlattenOp.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright 2013-2024, Seqera Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package nextflow.extension


import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import groovyx.gpars.dataflow.DataflowReadChannel
import groovyx.gpars.dataflow.DataflowWriteChannel
import groovyx.gpars.dataflow.expression.DataflowExpression
import groovyx.gpars.dataflow.operator.DataflowEventAdapter
import groovyx.gpars.dataflow.operator.DataflowProcessor
import nextflow.Channel
import nextflow.Global
import nextflow.Session
import nextflow.extension.op.Op
/**
* Implements "flatten" operator
*
* @author Paolo Di Tommaso <[email protected]>
*/
@Slf4j
@CompileStatic
class FlattenOp {

private static Session getSession() { Global.getSession() as Session }

private DataflowReadChannel source
private DataflowWriteChannel target

FlattenOp withSource(DataflowReadChannel source) {
assert source!=null
this.source = source
return this
}

FlattenOp setTarget( DataflowWriteChannel target ) {
this.target = target
return this
}

DataflowWriteChannel apply() {
final target = CH.create()
final stopOnFirst = source instanceof DataflowExpression

final listener = new DataflowEventAdapter() {
@Override
void afterRun(final DataflowProcessor dp, final List<Object> messages) {
if( stopOnFirst )
dp.terminate()
}

@Override
void afterStop(final DataflowProcessor dp) {
Op.bind(dp, target, Channel.STOP)
}

boolean onException(final DataflowProcessor dp, final Throwable e) {
FlattenOp.log.error("@unknown", e)
session.abort(e)
return true;
}
}

new Op()
.withInput(source)
.withListener(listener)
.withCode { Object item ->
final dp = getDelegate() as DataflowProcessor
switch( item ) {
case Collection:
((Collection)item).flatten().each { value -> Op.bind(dp, target, value) }
break

case (Object[]):
((Collection)item).flatten().each { value -> Op.bind(dp, target, value) }
break

case Channel.VOID:
break

default:
Op.bind(dp, target, item)
}
}
.apply()

return target
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -548,51 +548,9 @@ class OperatorImpl {
}

DataflowWriteChannel flatten( final DataflowReadChannel source ) {

final listeners = []
final target = CH.create()
final stopOnFirst = source instanceof DataflowExpression

listeners << new DataflowEventAdapter() {
@Override
void afterRun(final DataflowProcessor dp, final List<Object> messages) {
if( stopOnFirst )
dp.terminate()
}

@Override
void afterStop(final DataflowProcessor dp) {
dp.bindOutput(Channel.STOP)
}

boolean onException(final DataflowProcessor dp, final Throwable e) {
OperatorImpl.log.error("@unknown", e)
session.abort(e)
return true;
}
}

newOperator(inputs: [source], outputs: [target], listeners: listeners) { item ->

def proc = ((DataflowProcessor) getDelegate())
switch( item ) {
case Collection:
((Collection)item).flatten().each { value -> proc.bindOutput(value) }
break

case (Object[]):
((Collection)item).flatten().each { value -> proc.bindOutput(value) }
break

case Channel.VOID:
break

default:
proc.bindOutput(item)
}
}

return target
new FlattenOp()
.withSource(source)
.apply()
}

/**
Expand Down
37 changes: 37 additions & 0 deletions modules/nextflow/src/test/groovy/nextflow/prov/ProvTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -735,4 +735,41 @@ class ProvTest extends Dsl2Spec {
upstreamTasksOf('p3 (4)')
.name == ['p2 (2)']
}

def 'should track provenance with flatten operator' () {
when:
dsl_eval(globalConfig(), '''
workflow {
channel.of([1,'a'], [2,'b']) \
| p1 \
| flatten \
| p2
}
process p1 {
input: val(x)
output: val(y)
exec:
y = x
}
process p2 {
input: val(x)
exec:
println x
}
''')
then:
upstreamTasksOf('p2 (1)')
.name == ['p1 (1)']
and:
upstreamTasksOf('p2 (2)')
.name == ['p1 (1)']
and:
upstreamTasksOf('p2 (3)')
.name == ['p1 (2)']
and:
upstreamTasksOf('p2 (4)')
.name == ['p1 (2)']
}
}

0 comments on commit 84f7806

Please sign in to comment.