Make HaplotypeCallerSpark extend AssemblyRegionWalkerSpark (broadinstitute#5386)

* Factor out command line options for assembly region walkers (Spark)

* Make HaplotypeCallerSpark extend AssemblyRegionWalkerSpark, while still allowing ReadsPipelineSpark to share common code.
tomwhite authored and EdwardDixon committed Nov 9, 2018
1 parent 56d443d commit 575b014
Showing 7 changed files with 336 additions and 381 deletions.
AssemblyRegionArgumentCollection.java (new file)
@@ -0,0 +1,60 @@
package org.broadinstitute.hellbender.engine.spark;

import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;

import java.io.Serializable;

public abstract class AssemblyRegionArgumentCollection implements Serializable {
    private static final long serialVersionUID = 1L;

    @Argument(fullName = "minAssemblyRegionSize", shortName = "minAssemblyRegionSize", doc = "Minimum size of an assembly region", optional = true)
    public int minAssemblyRegionSize = defaultMinAssemblyRegionSize();

    @Argument(fullName = "maxAssemblyRegionSize", shortName = "maxAssemblyRegionSize", doc = "Maximum size of an assembly region", optional = true)
    public int maxAssemblyRegionSize = defaultMaxAssemblyRegionSize();

    @Argument(fullName = "assemblyRegionPadding", shortName = "assemblyRegionPadding", doc = "Number of additional bases of context to include around each assembly region", optional = true)
    public int assemblyRegionPadding = defaultAssemblyRegionPadding();

    @Argument(fullName = "maxReadsPerAlignmentStart", shortName = "maxReadsPerAlignmentStart", doc = "Maximum number of reads to retain per alignment start position. Reads above this threshold will be downsampled. Set to 0 to disable.", optional = true)
    public int maxReadsPerAlignmentStart = defaultMaxReadsPerAlignmentStart();

    @Advanced
    @Argument(fullName = "activeProbabilityThreshold", shortName = "activeProbabilityThreshold", doc="Minimum probability for a locus to be considered active.", optional = true)
    public double activeProbThreshold = defaultActiveProbThreshold();

    @Advanced
    @Argument(fullName = "maxProbPropagationDistance", shortName = "maxProbPropagationDistance", doc="Upper limit on how many bases away probability mass can be moved around when calculating the boundaries between active and inactive assembly regions", optional = true)
    public int maxProbPropagationDistance = defaultMaxProbPropagationDistance();

    /**
     * @return Default value for the {@link #minAssemblyRegionSize} parameter, if none is provided on the command line
     */
    protected abstract int defaultMinAssemblyRegionSize();

    /**
     * @return Default value for the {@link #maxAssemblyRegionSize} parameter, if none is provided on the command line
     */
    protected abstract int defaultMaxAssemblyRegionSize();

    /**
     * @return Default value for the {@link #assemblyRegionPadding} parameter, if none is provided on the command line
     */
    protected abstract int defaultAssemblyRegionPadding();

    /**
     * @return Default value for the {@link #maxReadsPerAlignmentStart} parameter, if none is provided on the command line
     */
    protected abstract int defaultMaxReadsPerAlignmentStart();

    /**
     * @return Default value for the {@link #activeProbThreshold} parameter, if none is provided on the command line
     */
    protected abstract double defaultActiveProbThreshold();

    /**
     * @return Default value for the {@link #maxProbPropagationDistance} parameter, if none is provided on the command line
     */
    protected abstract int defaultMaxProbPropagationDistance();
}
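
Each concrete tool supplies its own defaults by implementing the abstract methods above. A minimal sketch of what such a subclass could look like — the class name and every numeric value below are illustrative placeholders, not values from this commit:

public class MyToolAssemblyRegionArgumentCollection extends AssemblyRegionArgumentCollection {
    private static final long serialVersionUID = 1L;

    @Override protected int defaultMinAssemblyRegionSize() { return 50; }       // hypothetical default
    @Override protected int defaultMaxAssemblyRegionSize() { return 300; }      // hypothetical default
    @Override protected int defaultAssemblyRegionPadding() { return 100; }      // hypothetical default
    @Override protected int defaultMaxReadsPerAlignmentStart() { return 50; }   // hypothetical default
    @Override protected double defaultActiveProbThreshold() { return 0.002; }   // hypothetical default
    @Override protected int defaultMaxProbPropagationDistance() { return 50; }  // hypothetical default
}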
AssemblyRegionReadShardArgumentCollection.java (new file)
@@ -0,0 +1,18 @@
package org.broadinstitute.hellbender.engine.spark;

import org.broadinstitute.barclay.argparser.Argument;

import java.io.Serializable;

public class AssemblyRegionReadShardArgumentCollection implements Serializable {
    private static final long serialVersionUID = 1L;

    public static final int DEFAULT_READSHARD_SIZE = 5000;
    public static final int DEFAULT_READSHARD_PADDING_SIZE = 100;

    @Argument(fullName="readShardSize", shortName="readShardSize", doc = "Maximum size of each read shard, in bases. For good performance, this should be much larger than the maximum assembly region size.", optional = true)
    public int readShardSize = DEFAULT_READSHARD_SIZE;

    @Argument(fullName="readShardPadding", shortName="readShardPadding", doc = "Each read shard has this many bases of extra context on each side. Read shards must have as much or more padding than assembly regions.", optional = true)
    public int readShardPadding = DEFAULT_READSHARD_PADDING_SIZE;
}
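
The two doc strings state an invariant: a read shard's padding must be at least as large as the assembly region padding, otherwise reads near a region boundary could fall outside the shard. A sketch of how a tool might enforce this at startup — the check itself is an assumption, not part of this diff:

if (shardingArgs.readShardPadding < assemblyRegionArgs.assemblyRegionPadding) {
    // Hypothetical validation; field names follow the argument collections above.
    throw new IllegalArgumentException(String.format(
            "readShardPadding (%d) must be at least assemblyRegionPadding (%d)",
            shardingArgs.readShardPadding, assemblyRegionArgs.assemblyRegionPadding));
}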
AssemblyRegionWalkerSpark.java
@@ -1,28 +1,32 @@
package org.broadinstitute.hellbender.engine.spark;

import com.google.common.collect.Iterators;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import org.apache.spark.SparkFiles;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.hellbender.engine.*;
import org.broadinstitute.hellbender.engine.spark.datasources.ReferenceMultiSparkSource;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.engine.filters.WellformedReadFilter;
import org.broadinstitute.hellbender.tools.DownsampleableSparkReadShard;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.downsampling.PositionalDownsampler;
import org.broadinstitute.hellbender.utils.downsampling.ReadsDownsampler;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
@@ -32,71 +36,16 @@
public abstract class AssemblyRegionWalkerSpark extends GATKSparkTool {
private static final long serialVersionUID = 1L;

@Argument(fullName="readShardSize", shortName="readShardSize", doc = "Maximum size of each read shard, in bases. For good performance, this should be much larger than the maximum assembly region size.", optional = true)
protected int readShardSize = defaultReadShardSize();
@ArgumentCollection
public final AssemblyRegionReadShardArgumentCollection shardingArgs = new AssemblyRegionReadShardArgumentCollection();

@Argument(fullName="readShardPadding", shortName="readShardPadding", doc = "Each read shard has this many bases of extra context on each side. Read shards must have as much or more padding than assembly regions.", optional = true)
protected int readShardPadding = defaultReadShardPadding();

@Argument(fullName = "minAssemblyRegionSize", shortName = "minAssemblyRegionSize", doc = "Minimum size of an assembly region", optional = true)
protected int minAssemblyRegionSize = defaultMinAssemblyRegionSize();

@Argument(fullName = "maxAssemblyRegionSize", shortName = "maxAssemblyRegionSize", doc = "Maximum size of an assembly region", optional = true)
protected int maxAssemblyRegionSize = defaultMaxAssemblyRegionSize();

@Argument(fullName = "assemblyRegionPadding", shortName = "assemblyRegionPadding", doc = "Number of additional bases of context to include around each assembly region", optional = true)
protected int assemblyRegionPadding = defaultAssemblyRegionPadding();

@Argument(fullName = "maxReadsPerAlignmentStart", shortName = "maxReadsPerAlignmentStart", doc = "Maximum number of reads to retain per alignment start position. Reads above this threshold will be downsampled. Set to 0 to disable.", optional = true)
protected int maxReadsPerAlignmentStart = defaultMaxReadsPerAlignmentStart();

@Advanced
@Argument(fullName = "activeProbabilityThreshold", shortName = "activeProbabilityThreshold", doc="Minimum probability for a locus to be considered active.", optional = true)
protected double activeProbThreshold = defaultActiveProbThreshold();

@Advanced
@Argument(fullName = "maxProbPropagationDistance", shortName = "maxProbPropagationDistance", doc="Upper limit on how many bases away probability mass can be moved around when calculating the boundaries between active and inactive assembly regions", optional = true)
protected int maxProbPropagationDistance = defaultMaxProbPropagationDistance();

/**
* @return Default value for the {@link #readShardSize} parameter, if none is provided on the command line
*/
protected abstract int defaultReadShardSize();

/**
* @return Default value for the {@link #readShardPadding} parameter, if none is provided on the command line
*/
protected abstract int defaultReadShardPadding();

/**
* @return Default value for the {@link #minAssemblyRegionSize} parameter, if none is provided on the command line
*/
protected abstract int defaultMinAssemblyRegionSize();

/**
* @return Default value for the {@link #maxAssemblyRegionSize} parameter, if none is provided on the command line
*/
protected abstract int defaultMaxAssemblyRegionSize();

/**
* @return Default value for the {@link #assemblyRegionPadding} parameter, if none is provided on the command line
*/
protected abstract int defaultAssemblyRegionPadding();

/**
* @return Default value for the {@link #maxReadsPerAlignmentStart} parameter, if none is provided on the command line
*/
protected abstract int defaultMaxReadsPerAlignmentStart();
@ArgumentCollection
public final AssemblyRegionArgumentCollection assemblyRegionArgs = getAssemblyRegionArgumentCollection();

/**
* @return Default value for the {@link #activeProbThreshold} parameter, if none is provided on the command line
* @return a subclass of {@link AssemblyRegionArgumentCollection} with the default values filled in.
*/
protected abstract double defaultActiveProbThreshold();

/**
* @return Default value for the {@link #maxProbPropagationDistance} parameter, if none is provided on the command line
*/
protected abstract int defaultMaxProbPropagationDistance();
protected abstract AssemblyRegionArgumentCollection getAssemblyRegionArgumentCollection();

/**
* subclasses can override this to control if reads with deletions should be included in isActive pileups
@@ -128,6 +77,20 @@ public List<ReadFilter> getDefaultReadFilters() {
*/
public abstract AssemblyRegionEvaluator assemblyRegionEvaluator();

/**
* Tools that use an evaluator that is expensive to create, and/or that is not compatible with Spark broadcast, can
* override this method to return a broadcast of a supplier of the evaluator. The supplier will be invoked once for
* each Spark partition, thus each partition will have its own evaluator instance.
*/
protected Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast(final JavaSparkContext ctx) {
return assemblyRegionEvaluatorSupplierBroadcastFunction(ctx, assemblyRegionEvaluator());
}

private static Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcastFunction(final JavaSparkContext ctx, final AssemblyRegionEvaluator assemblyRegionEvaluator) {
Supplier<AssemblyRegionEvaluator> supplier = () -> assemblyRegionEvaluator;
return ctx.broadcast(supplier);
}
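
A tool whose evaluator is costly to construct, or not safe to broadcast directly, can override the hook so that each partition builds its own instance. A sketch under assumed names — MyToolArgumentCollection, myToolArgs, and MyCostlyEvaluator are hypothetical:

@Override
protected Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast(final JavaSparkContext ctx) {
    // Capture only cheap, serializable state; the evaluator itself is built lazily,
    // once per partition, when the supplier is first invoked.
    final SAMFileHeader header = getHeaderForReads();
    final MyToolArgumentCollection args = myToolArgs; // hypothetical serializable arguments
    Supplier<AssemblyRegionEvaluator> supplier = () -> new MyCostlyEvaluator(args, header);
    return ctx.broadcast(supplier);
}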

private List<ShardBoundary> intervalShards;

/**
@@ -138,10 +101,10 @@ protected List<SimpleInterval> editIntervals(List<SimpleInterval> rawIntervals)
SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
List<SimpleInterval> intervals = rawIntervals == null ? IntervalUtils.getAllIntervalsForReference(sequenceDictionary) : rawIntervals;
intervalShards = intervals.stream()
.flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, sequenceDictionary).stream())
.flatMap(interval -> Shard.divideIntervalIntoShards(interval, shardingArgs.readShardSize, shardingArgs.readShardPadding, sequenceDictionary).stream())
.collect(Collectors.toList());
List<SimpleInterval> paddedIntervalsForReads =
intervals.stream().map(interval -> interval.expandWithinContig(readShardPadding, sequenceDictionary)).collect(Collectors.toList());
intervals.stream().map(interval -> interval.expandWithinContig(shardingArgs.readShardPadding, sequenceDictionary)).collect(Collectors.toList());
return paddedIntervalsForReads;
}
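
For intuition (coordinates hypothetical): with the default readShardSize of 5000 and readShardPadding of 100, an interval such as chr1:1-12000 would be divided into shard boundaries of at most 5000 bases each, roughly chr1:1-5000, chr1:5001-10000, and chr1:10001-12000, with each shard padded by 100 bases on each side where the contig allows.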

@@ -154,37 +117,61 @@ protected List<SimpleInterval> editIntervals(List<SimpleInterval> rawIntervals)
*/
protected JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegions(JavaSparkContext ctx) {
SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShards, readShardSize, shuffle);
return getAssemblyRegions(ctx, getReads(), getHeaderForReads(), sequenceDictionary, referenceFileName, features,
intervalShards, assemblyRegionEvaluatorSupplierBroadcast(ctx), shardingArgs, assemblyRegionArgs,
includeReadsWithDeletionsInIsActivePileups(), shuffle);
}

protected static JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegions(
final JavaSparkContext ctx,
final JavaRDD<GATKRead> reads,
final SAMFileHeader header,
final SAMSequenceDictionary sequenceDictionary,
final String referenceFileName,
final FeatureManager features,
final List<ShardBoundary> intervalShards,
final Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast,
final AssemblyRegionReadShardArgumentCollection shardingArgs,
final AssemblyRegionArgumentCollection assemblyRegionArgs,
final boolean includeReadsWithDeletionsInIsActivePileups,
final boolean shuffle) {
JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, reads, GATKRead.class, sequenceDictionary, intervalShards, shardingArgs.readShardSize, shuffle);
Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
return shardedReads.flatMap(getAssemblyRegionsFunction(referenceFileName, bFeatureManager, getHeaderForReads(),
assemblyRegionEvaluator(), minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold, maxProbPropagationDistance, includeReadsWithDeletionsInIsActivePileups()));
return shardedReads.mapPartitions(getAssemblyRegionsFunction(referenceFileName, bFeatureManager, header,
assemblyRegionEvaluatorSupplierBroadcast, assemblyRegionArgs, includeReadsWithDeletionsInIsActivePileups));
}

private static FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext> getAssemblyRegionsFunction(
private static FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegionWalkerContext> getAssemblyRegionsFunction(
final String referenceFileName,
final Broadcast<FeatureManager> bFeatureManager,
final SAMFileHeader header,
final AssemblyRegionEvaluator evaluator,
final int minAssemblyRegionSize,
final int maxAssemblyRegionSize,
final int assemblyRegionPadding,
final double activeProbThreshold,
final int maxProbPropagationDistance,
final Broadcast<Supplier<AssemblyRegionEvaluator>> supplierBroadcast,
final AssemblyRegionArgumentCollection assemblyRegionArgs,
final boolean includeReadsWithDeletionsInIsActivePileups) {
return (FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext>) shardedRead -> {
return (FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegionWalkerContext>) shardedReadIterator -> {
ReferenceDataSource reference = referenceFileName == null ? null : new ReferenceFileSource(IOUtils.getPath(SparkFiles.get(referenceFileName)));
final FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();

final Iterator<AssemblyRegion> assemblyRegionIter = new AssemblyRegionIterator(
new ShardToMultiIntervalShardAdapter<>(shardedRead),
header, reference, features, evaluator,
minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold,
maxProbPropagationDistance, includeReadsWithDeletionsInIsActivePileups);
final Iterable<AssemblyRegion> assemblyRegions = () -> assemblyRegionIter;
return Utils.stream(assemblyRegions).map(assemblyRegion ->
new AssemblyRegionWalkerContext(assemblyRegion,
new ReferenceContext(reference, assemblyRegion.getExtendedSpan()),
new FeatureContext(features, assemblyRegion.getExtendedSpan()))).iterator();
AssemblyRegionEvaluator assemblyRegionEvaluator = supplierBroadcast.getValue().get(); // one AssemblyRegionEvaluator instance per Spark partition
final ReadsDownsampler readsDownsampler = assemblyRegionArgs.maxReadsPerAlignmentStart > 0 ?
new PositionalDownsampler(assemblyRegionArgs.maxReadsPerAlignmentStart, header) : null;

Iterator<Iterator<AssemblyRegionWalkerContext>> iterators = Utils.stream(shardedReadIterator)
.map(shardedRead -> new ShardToMultiIntervalShardAdapter<>(
new DownsampleableSparkReadShard(
new ShardBoundary(shardedRead.getInterval(), shardedRead.getPaddedInterval()), shardedRead, readsDownsampler)))
.map(shardedRead -> {
final Iterator<AssemblyRegion> assemblyRegionIter = new AssemblyRegionIterator(
new ShardToMultiIntervalShardAdapter<>(shardedRead),
header, reference, features, assemblyRegionEvaluator,
assemblyRegionArgs.minAssemblyRegionSize, assemblyRegionArgs.maxAssemblyRegionSize,
assemblyRegionArgs.assemblyRegionPadding, assemblyRegionArgs.activeProbThreshold,
assemblyRegionArgs.maxProbPropagationDistance, includeReadsWithDeletionsInIsActivePileups);
return Utils.stream(assemblyRegionIter).map(assemblyRegion ->
new AssemblyRegionWalkerContext(assemblyRegion,
new ReferenceContext(reference, assemblyRegion.getExtendedSpan()),
new FeatureContext(features, assemblyRegion.getExtendedSpan()))).iterator();
}).iterator();
return Iterators.concat(iterators);
};
}
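
With the logic moved into the protected static overload of getAssemblyRegions above, a subclass no longer needs the full walker machinery to reuse it — which is how, per the commit message, ReadsPipelineSpark keeps sharing this code with HaplotypeCallerSpark. A sketch of such a call, with all variable names illustrative:

JavaRDD<AssemblyRegionWalkerContext> regions = getAssemblyRegions(
        ctx, reads, header, sequenceDictionary, referenceFileName, features, intervalShards,
        evaluatorSupplierBroadcast, shardingArgs, assemblyRegionArgs,
        true,   // includeReadsWithDeletionsInIsActivePileups
        false); // shuffle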

[Diffs for the remaining four changed files are not shown.]