Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make HaplotypeCallerSpark extend AssemblyRegionWalkerSpark #5386

Merged
merged 2 commits into from
Nov 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package org.broadinstitute.hellbender.engine.spark;

import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;

import java.io.Serializable;

public abstract class AssemblyRegionArgumentCollection implements Serializable {
private static final long serialVersionUID = 1L;

@Argument(fullName = "minAssemblyRegionSize", shortName = "minAssemblyRegionSize", doc = "Minimum size of an assembly region", optional = true)
public int minAssemblyRegionSize = defaultMinAssemblyRegionSize();

@Argument(fullName = "maxAssemblyRegionSize", shortName = "maxAssemblyRegionSize", doc = "Maximum size of an assembly region", optional = true)
public int maxAssemblyRegionSize = defaultMaxAssemblyRegionSize();

@Argument(fullName = "assemblyRegionPadding", shortName = "assemblyRegionPadding", doc = "Number of additional bases of context to include around each assembly region", optional = true)
public int assemblyRegionPadding = defaultAssemblyRegionPadding();

@Argument(fullName = "maxReadsPerAlignmentStart", shortName = "maxReadsPerAlignmentStart", doc = "Maximum number of reads to retain per alignment start position. Reads above this threshold will be downsampled. Set to 0 to disable.", optional = true)
public int maxReadsPerAlignmentStart = defaultMaxReadsPerAlignmentStart();

@Advanced
@Argument(fullName = "activeProbabilityThreshold", shortName = "activeProbabilityThreshold", doc="Minimum probability for a locus to be considered active.", optional = true)
public double activeProbThreshold = defaultActiveProbThreshold();

@Advanced
@Argument(fullName = "maxProbPropagationDistance", shortName = "maxProbPropagationDistance", doc="Upper limit on how many bases away probability mass can be moved around when calculating the boundaries between active and inactive assembly regions", optional = true)
public int maxProbPropagationDistance = defaultMaxProbPropagationDistance();

/**
* @return Default value for the {@link #minAssemblyRegionSize} parameter, if none is provided on the command line
*/
protected abstract int defaultMinAssemblyRegionSize();

/**
* @return Default value for the {@link #maxAssemblyRegionSize} parameter, if none is provided on the command line
*/
protected abstract int defaultMaxAssemblyRegionSize();

/**
* @return Default value for the {@link #assemblyRegionPadding} parameter, if none is provided on the command line
*/
protected abstract int defaultAssemblyRegionPadding();

/**
* @return Default value for the {@link #maxReadsPerAlignmentStart} parameter, if none is provided on the command line
*/
protected abstract int defaultMaxReadsPerAlignmentStart();

/**
* @return Default value for the {@link #activeProbThreshold} parameter, if none is provided on the command line
*/
protected abstract double defaultActiveProbThreshold();

/**
* @return Default value for the {@link #maxProbPropagationDistance} parameter, if none is provided on the command line
*/
protected abstract int defaultMaxProbPropagationDistance();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.broadinstitute.hellbender.engine.spark;

import org.broadinstitute.barclay.argparser.Argument;

import java.io.Serializable;

public class AssemblyRegionReadShardArgumentCollection implements Serializable {
private static final long serialVersionUID = 1L;

public static final int DEFAULT_READSHARD_SIZE = 5000;
public static final int DEFAULT_READSHARD_PADDING_SIZE = 100;

@Argument(fullName="readShardSize", shortName="readShardSize", doc = "Maximum size of each read shard, in bases. For good performance, this should be much larger than the maximum assembly region size.", optional = true)
public int readShardSize = DEFAULT_READSHARD_SIZE;

@Argument(fullName="readShardPadding", shortName="readShardPadding", doc = "Each read shard has this many bases of extra context on each side. Read shards must have as much or more padding than assembly regions.", optional = true)
public int readShardPadding = DEFAULT_READSHARD_PADDING_SIZE;
}
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
package org.broadinstitute.hellbender.engine.spark;

import com.google.common.collect.Iterators;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import org.apache.spark.SparkFiles;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.hellbender.engine.*;
import org.broadinstitute.hellbender.engine.spark.datasources.ReferenceMultiSparkSource;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.engine.filters.WellformedReadFilter;
import org.broadinstitute.hellbender.tools.DownsampleableSparkReadShard;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.downsampling.PositionalDownsampler;
import org.broadinstitute.hellbender.utils.downsampling.ReadsDownsampler;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
Expand All @@ -32,71 +36,16 @@
public abstract class AssemblyRegionWalkerSpark extends GATKSparkTool {
private static final long serialVersionUID = 1L;

@Argument(fullName="readShardSize", shortName="readShardSize", doc = "Maximum size of each read shard, in bases. For good performance, this should be much larger than the maximum assembly region size.", optional = true)
protected int readShardSize = defaultReadShardSize();
@ArgumentCollection
public final AssemblyRegionReadShardArgumentCollection shardingArgs = new AssemblyRegionReadShardArgumentCollection();

@Argument(fullName="readShardPadding", shortName="readShardPadding", doc = "Each read shard has this many bases of extra context on each side. Read shards must have as much or more padding than assembly regions.", optional = true)
protected int readShardPadding = defaultReadShardPadding();

@Argument(fullName = "minAssemblyRegionSize", shortName = "minAssemblyRegionSize", doc = "Minimum size of an assembly region", optional = true)
protected int minAssemblyRegionSize = defaultMinAssemblyRegionSize();

@Argument(fullName = "maxAssemblyRegionSize", shortName = "maxAssemblyRegionSize", doc = "Maximum size of an assembly region", optional = true)
protected int maxAssemblyRegionSize = defaultMaxAssemblyRegionSize();

@Argument(fullName = "assemblyRegionPadding", shortName = "assemblyRegionPadding", doc = "Number of additional bases of context to include around each assembly region", optional = true)
protected int assemblyRegionPadding = defaultAssemblyRegionPadding();

@Argument(fullName = "maxReadsPerAlignmentStart", shortName = "maxReadsPerAlignmentStart", doc = "Maximum number of reads to retain per alignment start position. Reads above this threshold will be downsampled. Set to 0 to disable.", optional = true)
protected int maxReadsPerAlignmentStart = defaultMaxReadsPerAlignmentStart();

@Advanced
@Argument(fullName = "activeProbabilityThreshold", shortName = "activeProbabilityThreshold", doc="Minimum probability for a locus to be considered active.", optional = true)
protected double activeProbThreshold = defaultActiveProbThreshold();

@Advanced
@Argument(fullName = "maxProbPropagationDistance", shortName = "maxProbPropagationDistance", doc="Upper limit on how many bases away probability mass can be moved around when calculating the boundaries between active and inactive assembly regions", optional = true)
protected int maxProbPropagationDistance = defaultMaxProbPropagationDistance();

/**
* @return Default value for the {@link #readShardSize} parameter, if none is provided on the command line
*/
protected abstract int defaultReadShardSize();

/**
* @return Default value for the {@link #readShardPadding} parameter, if none is provided on the command line
*/
protected abstract int defaultReadShardPadding();

/**
* @return Default value for the {@link #minAssemblyRegionSize} parameter, if none is provided on the command line
*/
protected abstract int defaultMinAssemblyRegionSize();

/**
* @return Default value for the {@link #maxAssemblyRegionSize} parameter, if none is provided on the command line
*/
protected abstract int defaultMaxAssemblyRegionSize();

/**
* @return Default value for the {@link #assemblyRegionPadding} parameter, if none is provided on the command line
*/
protected abstract int defaultAssemblyRegionPadding();

/**
* @return Default value for the {@link #maxReadsPerAlignmentStart} parameter, if none is provided on the command line
*/
protected abstract int defaultMaxReadsPerAlignmentStart();
@ArgumentCollection
public final AssemblyRegionArgumentCollection assemblyRegionArgs = getAssemblyRegionArgumentCollection();

/**
* @return Default value for the {@link #activeProbThreshold} parameter, if none is provided on the command line
* @return a subclass of {@link AssemblyRegionArgumentCollection} with the default values filled in.
*/
protected abstract double defaultActiveProbThreshold();

/**
* @return Default value for the {@link #maxProbPropagationDistance} parameter, if none is provided on the command line
*/
protected abstract int defaultMaxProbPropagationDistance();
protected abstract AssemblyRegionArgumentCollection getAssemblyRegionArgumentCollection();

/**
* subclasses can override this to control if reads with deletions should be included in isActive pileups
Expand Down Expand Up @@ -128,6 +77,20 @@ public List<ReadFilter> getDefaultReadFilters() {
*/
public abstract AssemblyRegionEvaluator assemblyRegionEvaluator();

/**
* Tools that use an evaluator that is expensive to create, and/or that is not compatible with Spark broadcast, can
* override this method to return a broadcast of a supplier of the evaluator. The supplier will be invoked once for
* each Spark partition, thus each partition will have its own evaluator instance.
*/
protected Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast(final JavaSparkContext ctx) {
return assemblyRegionEvaluatorSupplierBroadcastFunction(ctx, assemblyRegionEvaluator());
}

private static Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcastFunction(final JavaSparkContext ctx, final AssemblyRegionEvaluator assemblyRegionEvaluator) {
Supplier<AssemblyRegionEvaluator> supplier = () -> assemblyRegionEvaluator;
return ctx.broadcast(supplier);
}

private List<ShardBoundary> intervalShards;

/**
Expand All @@ -138,10 +101,10 @@ protected List<SimpleInterval> editIntervals(List<SimpleInterval> rawIntervals)
SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
List<SimpleInterval> intervals = rawIntervals == null ? IntervalUtils.getAllIntervalsForReference(sequenceDictionary) : rawIntervals;
intervalShards = intervals.stream()
.flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, sequenceDictionary).stream())
.flatMap(interval -> Shard.divideIntervalIntoShards(interval, shardingArgs.readShardSize, shardingArgs.readShardPadding, sequenceDictionary).stream())
.collect(Collectors.toList());
List<SimpleInterval> paddedIntervalsForReads =
intervals.stream().map(interval -> interval.expandWithinContig(readShardPadding, sequenceDictionary)).collect(Collectors.toList());
intervals.stream().map(interval -> interval.expandWithinContig(shardingArgs.readShardPadding, sequenceDictionary)).collect(Collectors.toList());
return paddedIntervalsForReads;
}

Expand All @@ -154,37 +117,61 @@ protected List<SimpleInterval> editIntervals(List<SimpleInterval> rawIntervals)
*/
protected JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegions(JavaSparkContext ctx) {
SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShards, readShardSize, shuffle);
return getAssemblyRegions(ctx, getReads(), getHeaderForReads(), sequenceDictionary, referenceFileName, features,
intervalShards, assemblyRegionEvaluatorSupplierBroadcast(ctx), shardingArgs, assemblyRegionArgs,
includeReadsWithDeletionsInIsActivePileups(), shuffle);
}

protected static JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegions(
final JavaSparkContext ctx,
final JavaRDD<GATKRead> reads,
final SAMFileHeader header,
final SAMSequenceDictionary sequenceDictionary,
final String referenceFileName,
final FeatureManager features,
final List<ShardBoundary> intervalShards,
final Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast,
final AssemblyRegionReadShardArgumentCollection shardingArgs,
final AssemblyRegionArgumentCollection assemblyRegionArgs,
final boolean includeReadsWithDeletionsInIsActivePileups,
final boolean shuffle) {
JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, reads, GATKRead.class, sequenceDictionary, intervalShards, shardingArgs.readShardSize, shuffle);
Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
return shardedReads.flatMap(getAssemblyRegionsFunction(referenceFileName, bFeatureManager, getHeaderForReads(),
assemblyRegionEvaluator(), minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold, maxProbPropagationDistance, includeReadsWithDeletionsInIsActivePileups()));
return shardedReads.mapPartitions(getAssemblyRegionsFunction(referenceFileName, bFeatureManager, header,
assemblyRegionEvaluatorSupplierBroadcast, assemblyRegionArgs, includeReadsWithDeletionsInIsActivePileups));
}

private static FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext> getAssemblyRegionsFunction(
private static FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegionWalkerContext> getAssemblyRegionsFunction(
final String referenceFileName,
final Broadcast<FeatureManager> bFeatureManager,
final SAMFileHeader header,
final AssemblyRegionEvaluator evaluator,
final int minAssemblyRegionSize,
final int maxAssemblyRegionSize,
final int assemblyRegionPadding,
final double activeProbThreshold,
final int maxProbPropagationDistance,
final Broadcast<Supplier<AssemblyRegionEvaluator>> supplierBroadcast,
final AssemblyRegionArgumentCollection assemblyRegionArgs,
final boolean includeReadsWithDeletionsInIsActivePileups) {
return (FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext>) shardedRead -> {
return (FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegionWalkerContext>) shardedReadIterator -> {
ReferenceDataSource reference = referenceFileName == null ? null : new ReferenceFileSource(IOUtils.getPath(SparkFiles.get(referenceFileName)));
final FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();

final Iterator<AssemblyRegion> assemblyRegionIter = new AssemblyRegionIterator(
new ShardToMultiIntervalShardAdapter<>(shardedRead),
header, reference, features, evaluator,
minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold,
maxProbPropagationDistance, includeReadsWithDeletionsInIsActivePileups);
final Iterable<AssemblyRegion> assemblyRegions = () -> assemblyRegionIter;
return Utils.stream(assemblyRegions).map(assemblyRegion ->
new AssemblyRegionWalkerContext(assemblyRegion,
new ReferenceContext(reference, assemblyRegion.getExtendedSpan()),
new FeatureContext(features, assemblyRegion.getExtendedSpan()))).iterator();
AssemblyRegionEvaluator assemblyRegionEvaluator = supplierBroadcast.getValue().get(); // one AssemblyRegionEvaluator instance per Spark partition
final ReadsDownsampler readsDownsampler = assemblyRegionArgs.maxReadsPerAlignmentStart > 0 ?
new PositionalDownsampler(assemblyRegionArgs.maxReadsPerAlignmentStart, header) : null;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the downsampler want to be configurable at this stage?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure that's needed yet since the examples and HaplotypeCaller both use the PositionalDownsampler. There would be a bit of work to ensure the downsampler is serializable.


Iterator<Iterator<AssemblyRegionWalkerContext>> iterators = Utils.stream(shardedReadIterator)
.map(shardedRead -> new ShardToMultiIntervalShardAdapter<>(
new DownsampleableSparkReadShard(
new ShardBoundary(shardedRead.getInterval(), shardedRead.getPaddedInterval()), shardedRead, readsDownsampler)))
.map(shardedRead -> {
final Iterator<AssemblyRegion> assemblyRegionIter = new AssemblyRegionIterator(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A potential optimization we may want to look at for the haplotype caller is being able to generate these assembly regions and reconstruct the constituent up the reads later to save on holding an entire bam in an rdd at a time. That doesn't seem easy to do given the way this class is structured.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've got a branch based on this one that does exactly that :-)

new ShardToMultiIntervalShardAdapter<>(shardedRead),
header, reference, features, assemblyRegionEvaluator,
assemblyRegionArgs.minAssemblyRegionSize, assemblyRegionArgs.maxAssemblyRegionSize,
assemblyRegionArgs.assemblyRegionPadding, assemblyRegionArgs.activeProbThreshold,
assemblyRegionArgs.maxProbPropagationDistance, includeReadsWithDeletionsInIsActivePileups);
return Utils.stream(assemblyRegionIter).map(assemblyRegion ->
new AssemblyRegionWalkerContext(assemblyRegion,
new ReferenceContext(reference, assemblyRegion.getExtendedSpan()),
new FeatureContext(features, assemblyRegion.getExtendedSpan()))).iterator();
}).iterator();
return Iterators.concat(iterators);
};
}

Expand Down
Loading