diff --git a/src/main/java/org/broadinstitute/hellbender/engine/spark/AssemblyRegionArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/engine/spark/AssemblyRegionArgumentCollection.java new file mode 100644 index 00000000000..4b6a4536039 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/engine/spark/AssemblyRegionArgumentCollection.java @@ -0,0 +1,60 @@ +package org.broadinstitute.hellbender.engine.spark; + +import org.broadinstitute.barclay.argparser.Advanced; +import org.broadinstitute.barclay.argparser.Argument; + +import java.io.Serializable; + +public abstract class AssemblyRegionArgumentCollection implements Serializable { + private static final long serialVersionUID = 1L; + + @Argument(fullName = "minAssemblyRegionSize", shortName = "minAssemblyRegionSize", doc = "Minimum size of an assembly region", optional = true) + public int minAssemblyRegionSize = defaultMinAssemblyRegionSize(); + + @Argument(fullName = "maxAssemblyRegionSize", shortName = "maxAssemblyRegionSize", doc = "Maximum size of an assembly region", optional = true) + public int maxAssemblyRegionSize = defaultMaxAssemblyRegionSize(); + + @Argument(fullName = "assemblyRegionPadding", shortName = "assemblyRegionPadding", doc = "Number of additional bases of context to include around each assembly region", optional = true) + public int assemblyRegionPadding = defaultAssemblyRegionPadding(); + + @Argument(fullName = "maxReadsPerAlignmentStart", shortName = "maxReadsPerAlignmentStart", doc = "Maximum number of reads to retain per alignment start position. Reads above this threshold will be downsampled. Set to 0 to disable.", optional = true) + public int maxReadsPerAlignmentStart = defaultMaxReadsPerAlignmentStart(); + + @Advanced + @Argument(fullName = "activeProbabilityThreshold", shortName = "activeProbabilityThreshold", doc="Minimum probability for a locus to be considered active.", optional = true) + public double activeProbThreshold = defaultActiveProbThreshold(); + + @Advanced + @Argument(fullName = "maxProbPropagationDistance", shortName = "maxProbPropagationDistance", doc="Upper limit on how many bases away probability mass can be moved around when calculating the boundaries between active and inactive assembly regions", optional = true) + public int maxProbPropagationDistance = defaultMaxProbPropagationDistance(); + + /** + * @return Default value for the {@link #minAssemblyRegionSize} parameter, if none is provided on the command line + */ + protected abstract int defaultMinAssemblyRegionSize(); + + /** + * @return Default value for the {@link #maxAssemblyRegionSize} parameter, if none is provided on the command line + */ + protected abstract int defaultMaxAssemblyRegionSize(); + + /** + * @return Default value for the {@link #assemblyRegionPadding} parameter, if none is provided on the command line + */ + protected abstract int defaultAssemblyRegionPadding(); + + /** + * @return Default value for the {@link #maxReadsPerAlignmentStart} parameter, if none is provided on the command line + */ + protected abstract int defaultMaxReadsPerAlignmentStart(); + + /** + * @return Default value for the {@link #activeProbThreshold} parameter, if none is provided on the command line + */ + protected abstract double defaultActiveProbThreshold(); + + /** + * @return Default value for the {@link #maxProbPropagationDistance} parameter, if none is provided on the command line + */ + protected abstract int defaultMaxProbPropagationDistance(); +} diff --git 
a/src/main/java/org/broadinstitute/hellbender/engine/spark/AssemblyRegionReadShardArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/engine/spark/AssemblyRegionReadShardArgumentCollection.java new file mode 100644 index 00000000000..40195755144 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/engine/spark/AssemblyRegionReadShardArgumentCollection.java @@ -0,0 +1,18 @@ +package org.broadinstitute.hellbender.engine.spark; + +import org.broadinstitute.barclay.argparser.Argument; + +import java.io.Serializable; + +public class AssemblyRegionReadShardArgumentCollection implements Serializable { + private static final long serialVersionUID = 1L; + + public static final int DEFAULT_READSHARD_SIZE = 5000; + public static final int DEFAULT_READSHARD_PADDING_SIZE = 100; + + @Argument(fullName="readShardSize", shortName="readShardSize", doc = "Maximum size of each read shard, in bases. For good performance, this should be much larger than the maximum assembly region size.", optional = true) + public int readShardSize = DEFAULT_READSHARD_SIZE; + + @Argument(fullName="readShardPadding", shortName="readShardPadding", doc = "Each read shard has this many bases of extra context on each side. Read shards must have as much or more padding than assembly regions.", optional = true) + public int readShardPadding = DEFAULT_READSHARD_PADDING_SIZE; +} diff --git a/src/main/java/org/broadinstitute/hellbender/engine/spark/AssemblyRegionWalkerSpark.java b/src/main/java/org/broadinstitute/hellbender/engine/spark/AssemblyRegionWalkerSpark.java index ee40479590d..dd9ae834374 100644 --- a/src/main/java/org/broadinstitute/hellbender/engine/spark/AssemblyRegionWalkerSpark.java +++ b/src/main/java/org/broadinstitute/hellbender/engine/spark/AssemblyRegionWalkerSpark.java @@ -1,5 +1,6 @@ package org.broadinstitute.hellbender.engine.spark; +import com.google.common.collect.Iterators; import htsjdk.samtools.SAMFileHeader; import htsjdk.samtools.SAMSequenceDictionary; import org.apache.spark.SparkFiles; @@ -7,22 +8,25 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.broadinstitute.barclay.argparser.Advanced; import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.ArgumentCollection; import org.broadinstitute.hellbender.engine.*; -import org.broadinstitute.hellbender.engine.spark.datasources.ReferenceMultiSparkSource; import org.broadinstitute.hellbender.engine.filters.ReadFilter; import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary; import org.broadinstitute.hellbender.engine.filters.WellformedReadFilter; +import org.broadinstitute.hellbender.tools.DownsampleableSparkReadShard; import org.broadinstitute.hellbender.utils.IntervalUtils; import org.broadinstitute.hellbender.utils.SimpleInterval; import org.broadinstitute.hellbender.utils.Utils; +import org.broadinstitute.hellbender.utils.downsampling.PositionalDownsampler; +import org.broadinstitute.hellbender.utils.downsampling.ReadsDownsampler; import org.broadinstitute.hellbender.utils.io.IOUtils; import org.broadinstitute.hellbender.utils.read.GATKRead; import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.function.Supplier; import java.util.stream.Collectors; /** @@ -32,71 +36,16 @@ public abstract class AssemblyRegionWalkerSpark extends GATKSparkTool { private static final long serialVersionUID = 1L; - 
@Argument(fullName="readShardSize", shortName="readShardSize", doc = "Maximum size of each read shard, in bases. For good performance, this should be much larger than the maximum assembly region size.", optional = true) - protected int readShardSize = defaultReadShardSize(); + @ArgumentCollection + public final AssemblyRegionReadShardArgumentCollection shardingArgs = new AssemblyRegionReadShardArgumentCollection(); - @Argument(fullName="readShardPadding", shortName="readShardPadding", doc = "Each read shard has this many bases of extra context on each side. Read shards must have as much or more padding than assembly regions.", optional = true) - protected int readShardPadding = defaultReadShardPadding(); - - @Argument(fullName = "minAssemblyRegionSize", shortName = "minAssemblyRegionSize", doc = "Minimum size of an assembly region", optional = true) - protected int minAssemblyRegionSize = defaultMinAssemblyRegionSize(); - - @Argument(fullName = "maxAssemblyRegionSize", shortName = "maxAssemblyRegionSize", doc = "Maximum size of an assembly region", optional = true) - protected int maxAssemblyRegionSize = defaultMaxAssemblyRegionSize(); - - @Argument(fullName = "assemblyRegionPadding", shortName = "assemblyRegionPadding", doc = "Number of additional bases of context to include around each assembly region", optional = true) - protected int assemblyRegionPadding = defaultAssemblyRegionPadding(); - - @Argument(fullName = "maxReadsPerAlignmentStart", shortName = "maxReadsPerAlignmentStart", doc = "Maximum number of reads to retain per alignment start position. Reads above this threshold will be downsampled. Set to 0 to disable.", optional = true) - protected int maxReadsPerAlignmentStart = defaultMaxReadsPerAlignmentStart(); - - @Advanced - @Argument(fullName = "activeProbabilityThreshold", shortName = "activeProbabilityThreshold", doc="Minimum probability for a locus to be considered active.", optional = true) - protected double activeProbThreshold = defaultActiveProbThreshold(); - - @Advanced - @Argument(fullName = "maxProbPropagationDistance", shortName = "maxProbPropagationDistance", doc="Upper limit on how many bases away probability mass can be moved around when calculating the boundaries between active and inactive assembly regions", optional = true) - protected int maxProbPropagationDistance = defaultMaxProbPropagationDistance(); - - /** - * @return Default value for the {@link #readShardSize} parameter, if none is provided on the command line - */ - protected abstract int defaultReadShardSize(); - - /** - * @return Default value for the {@link #readShardPadding} parameter, if none is provided on the command line - */ - protected abstract int defaultReadShardPadding(); - - /** - * @return Default value for the {@link #minAssemblyRegionSize} parameter, if none is provided on the command line - */ - protected abstract int defaultMinAssemblyRegionSize(); - - /** - * @return Default value for the {@link #maxAssemblyRegionSize} parameter, if none is provided on the command line - */ - protected abstract int defaultMaxAssemblyRegionSize(); - - /** - * @return Default value for the {@link #assemblyRegionPadding} parameter, if none is provided on the command line - */ - protected abstract int defaultAssemblyRegionPadding(); - - /** - * @return Default value for the {@link #maxReadsPerAlignmentStart} parameter, if none is provided on the command line - */ - protected abstract int defaultMaxReadsPerAlignmentStart(); + @ArgumentCollection + public final AssemblyRegionArgumentCollection 
assemblyRegionArgs = getAssemblyRegionArgumentCollection(); /** - * @return Default value for the {@link #activeProbThreshold} parameter, if none is provided on the command line + * @return a subclass of {@link AssemblyRegionArgumentCollection} with the default values filled in. */ - protected abstract double defaultActiveProbThreshold(); - - /** - * @return Default value for the {@link #maxProbPropagationDistance} parameter, if none is provided on the command line - */ - protected abstract int defaultMaxProbPropagationDistance(); + protected abstract AssemblyRegionArgumentCollection getAssemblyRegionArgumentCollection(); /** * subclasses can override this to control if reads with deletions should be included in isActive pileups @@ -128,6 +77,20 @@ public List getDefaultReadFilters() { */ public abstract AssemblyRegionEvaluator assemblyRegionEvaluator(); + /** + * Tools that use an evaluator that is expensive to create, and/or that is not compatible with Spark broadcast, can + * override this method to return a broadcast of a supplier of the evaluator. The supplier will be invoked once for + * each Spark partition, thus each partition will have its own evaluator instance. + */ + protected Broadcast> assemblyRegionEvaluatorSupplierBroadcast(final JavaSparkContext ctx) { + return assemblyRegionEvaluatorSupplierBroadcastFunction(ctx, assemblyRegionEvaluator()); + } + + private static Broadcast> assemblyRegionEvaluatorSupplierBroadcastFunction(final JavaSparkContext ctx, final AssemblyRegionEvaluator assemblyRegionEvaluator) { + Supplier supplier = () -> assemblyRegionEvaluator; + return ctx.broadcast(supplier); + } + private List intervalShards; /** @@ -138,10 +101,10 @@ protected List editIntervals(List rawIntervals) SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary(); List intervals = rawIntervals == null ? 
IntervalUtils.getAllIntervalsForReference(sequenceDictionary) : rawIntervals; intervalShards = intervals.stream() - .flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, sequenceDictionary).stream()) + .flatMap(interval -> Shard.divideIntervalIntoShards(interval, shardingArgs.readShardSize, shardingArgs.readShardPadding, sequenceDictionary).stream()) .collect(Collectors.toList()); List paddedIntervalsForReads = - intervals.stream().map(interval -> interval.expandWithinContig(readShardPadding, sequenceDictionary)).collect(Collectors.toList()); + intervals.stream().map(interval -> interval.expandWithinContig(shardingArgs.readShardPadding, sequenceDictionary)).collect(Collectors.toList()); return paddedIntervalsForReads; } @@ -154,37 +117,61 @@ protected List editIntervals(List rawIntervals) */ protected JavaRDD getAssemblyRegions(JavaSparkContext ctx) { SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary(); - JavaRDD> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShards, readShardSize, shuffle); + return getAssemblyRegions(ctx, getReads(), getHeaderForReads(), sequenceDictionary, referenceFileName, features, + intervalShards, assemblyRegionEvaluatorSupplierBroadcast(ctx), shardingArgs, assemblyRegionArgs, + includeReadsWithDeletionsInIsActivePileups(), shuffle); + } + + protected static JavaRDD getAssemblyRegions( + final JavaSparkContext ctx, + final JavaRDD reads, + final SAMFileHeader header, + final SAMSequenceDictionary sequenceDictionary, + final String referenceFileName, + final FeatureManager features, + final List intervalShards, + final Broadcast> assemblyRegionEvaluatorSupplierBroadcast, + final AssemblyRegionReadShardArgumentCollection shardingArgs, + final AssemblyRegionArgumentCollection assemblyRegionArgs, + final boolean includeReadsWithDeletionsInIsActivePileups, + final boolean shuffle) { + JavaRDD> shardedReads = SparkSharder.shard(ctx, reads, GATKRead.class, sequenceDictionary, intervalShards, shardingArgs.readShardSize, shuffle); Broadcast bFeatureManager = features == null ? 
null : ctx.broadcast(features); - return shardedReads.flatMap(getAssemblyRegionsFunction(referenceFileName, bFeatureManager, getHeaderForReads(), - assemblyRegionEvaluator(), minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold, maxProbPropagationDistance, includeReadsWithDeletionsInIsActivePileups())); + return shardedReads.mapPartitions(getAssemblyRegionsFunction(referenceFileName, bFeatureManager, header, + assemblyRegionEvaluatorSupplierBroadcast, assemblyRegionArgs, includeReadsWithDeletionsInIsActivePileups)); } - private static FlatMapFunction, AssemblyRegionWalkerContext> getAssemblyRegionsFunction( + private static FlatMapFunction>, AssemblyRegionWalkerContext> getAssemblyRegionsFunction( final String referenceFileName, final Broadcast bFeatureManager, final SAMFileHeader header, - final AssemblyRegionEvaluator evaluator, - final int minAssemblyRegionSize, - final int maxAssemblyRegionSize, - final int assemblyRegionPadding, - final double activeProbThreshold, - final int maxProbPropagationDistance, + final Broadcast> supplierBroadcast, + final AssemblyRegionArgumentCollection assemblyRegionArgs, final boolean includeReadsWithDeletionsInIsActivePileups) { - return (FlatMapFunction, AssemblyRegionWalkerContext>) shardedRead -> { + return (FlatMapFunction>, AssemblyRegionWalkerContext>) shardedReadIterator -> { ReferenceDataSource reference = referenceFileName == null ? null : new ReferenceFileSource(IOUtils.getPath(SparkFiles.get(referenceFileName))); final FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue(); - - final Iterator assemblyRegionIter = new AssemblyRegionIterator( - new ShardToMultiIntervalShardAdapter<>(shardedRead), - header, reference, features, evaluator, - minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold, - maxProbPropagationDistance, includeReadsWithDeletionsInIsActivePileups); - final Iterable assemblyRegions = () -> assemblyRegionIter; - return Utils.stream(assemblyRegions).map(assemblyRegion -> - new AssemblyRegionWalkerContext(assemblyRegion, - new ReferenceContext(reference, assemblyRegion.getExtendedSpan()), - new FeatureContext(features, assemblyRegion.getExtendedSpan()))).iterator(); + AssemblyRegionEvaluator assemblyRegionEvaluator = supplierBroadcast.getValue().get(); // one AssemblyRegionEvaluator instance per Spark partition + final ReadsDownsampler readsDownsampler = assemblyRegionArgs.maxReadsPerAlignmentStart > 0 ? 
+ new PositionalDownsampler(assemblyRegionArgs.maxReadsPerAlignmentStart, header) : null; + + Iterator> iterators = Utils.stream(shardedReadIterator) + .map(shardedRead -> new ShardToMultiIntervalShardAdapter<>( + new DownsampleableSparkReadShard( + new ShardBoundary(shardedRead.getInterval(), shardedRead.getPaddedInterval()), shardedRead, readsDownsampler))) + .map(shardedRead -> { + final Iterator assemblyRegionIter = new AssemblyRegionIterator( + new ShardToMultiIntervalShardAdapter<>(shardedRead), + header, reference, features, assemblyRegionEvaluator, + assemblyRegionArgs.minAssemblyRegionSize, assemblyRegionArgs.maxAssemblyRegionSize, + assemblyRegionArgs.assemblyRegionPadding, assemblyRegionArgs.activeProbThreshold, + assemblyRegionArgs.maxProbPropagationDistance, includeReadsWithDeletionsInIsActivePileups); + return Utils.stream(assemblyRegionIter).map(assemblyRegion -> + new AssemblyRegionWalkerContext(assemblyRegion, + new ReferenceContext(reference, assemblyRegion.getExtendedSpan()), + new FeatureContext(features, assemblyRegion.getExtendedSpan()))).iterator(); + }).iterator(); + return Iterators.concat(iterators); }; } diff --git a/src/main/java/org/broadinstitute/hellbender/tools/HaplotypeCallerSpark.java b/src/main/java/org/broadinstitute/hellbender/tools/HaplotypeCallerSpark.java index eebbb13429f..2db87368bb6 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/HaplotypeCallerSpark.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/HaplotypeCallerSpark.java @@ -1,57 +1,53 @@ package org.broadinstitute.hellbender.tools; -import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Iterators; import htsjdk.samtools.SAMFileHeader; import htsjdk.samtools.SAMSequenceDictionary; -import htsjdk.samtools.reference.ReferenceSequence; import htsjdk.samtools.reference.ReferenceSequenceFile; import htsjdk.samtools.util.IOUtil; import htsjdk.variant.variantcontext.VariantContext; +import org.apache.logging.log4j.Logger; import org.apache.spark.SparkFiles; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.broadinstitute.barclay.argparser.*; -import org.broadinstitute.barclay.argparser.Advanced; import org.broadinstitute.barclay.argparser.Argument; import org.broadinstitute.barclay.argparser.ArgumentCollection; +import org.broadinstitute.barclay.argparser.BetaFeature; import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; import org.broadinstitute.barclay.help.DocumentedFeature; -import org.broadinstitute.hellbender.cmdline.*; +import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup; -import org.broadinstitute.hellbender.engine.*; -import org.broadinstitute.hellbender.engine.spark.datasources.ReferenceMultiSparkSource; +import org.broadinstitute.hellbender.engine.AssemblyRegion; +import org.broadinstitute.hellbender.engine.AssemblyRegionEvaluator; +import org.broadinstitute.hellbender.engine.FeatureContext; +import org.broadinstitute.hellbender.engine.ShardBoundary; import org.broadinstitute.hellbender.engine.filters.ReadFilter; -import org.broadinstitute.hellbender.engine.spark.GATKSparkTool; -import org.broadinstitute.hellbender.engine.ShardToMultiIntervalShardAdapter; -import org.broadinstitute.hellbender.engine.spark.SparkSharder; +import 
org.broadinstitute.hellbender.engine.spark.AssemblyRegionArgumentCollection; +import org.broadinstitute.hellbender.engine.spark.AssemblyRegionReadShardArgumentCollection; +import org.broadinstitute.hellbender.engine.spark.AssemblyRegionWalkerContext; +import org.broadinstitute.hellbender.engine.spark.AssemblyRegionWalkerSpark; import org.broadinstitute.hellbender.engine.spark.datasources.VariantsSparkSink; -import org.broadinstitute.hellbender.exceptions.GATKException; import org.broadinstitute.hellbender.exceptions.UserException; -import org.broadinstitute.hellbender.tools.walkers.annotator.*; +import org.broadinstitute.hellbender.tools.walkers.annotator.Annotation; +import org.broadinstitute.hellbender.tools.walkers.annotator.VariantAnnotatorEngine; import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCaller; import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCallerArgumentCollection; import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCallerEngine; import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.ReferenceConfidenceMode; -import org.broadinstitute.hellbender.utils.IntervalUtils; -import org.broadinstitute.hellbender.utils.SimpleInterval; import org.broadinstitute.hellbender.utils.Utils; -import org.broadinstitute.hellbender.utils.downsampling.PositionalDownsampler; -import org.broadinstitute.hellbender.utils.downsampling.ReadsDownsampler; import org.broadinstitute.hellbender.utils.fasta.CachingIndexedFastaSequenceFile; import org.broadinstitute.hellbender.utils.io.IOUtils; import org.broadinstitute.hellbender.utils.read.GATKRead; -import org.broadinstitute.hellbender.utils.reference.ReferenceBases; -import scala.Tuple2; import java.io.IOException; -import java.io.Serializable; import java.nio.file.Path; -import java.util.*; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; +import java.util.Collection; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.function.Supplier; /** * ******************************************************************************** @@ -79,59 +75,46 @@ @CommandLineProgramProperties(summary = "HaplotypeCaller on Spark", oneLineSummary = "HaplotypeCaller on Spark", programGroup = ShortVariantDiscoveryProgramGroup.class) @DocumentedFeature @BetaFeature -public final class HaplotypeCallerSpark extends GATKSparkTool { +public final class HaplotypeCallerSpark extends AssemblyRegionWalkerSpark { private static final long serialVersionUID = 1L; public static final int DEFAULT_READSHARD_SIZE = 5000; - private static final boolean INCLUDE_READS_WITH_DELETIONS_IN_IS_ACTIVE_PILEUPS = true; @Argument(fullName= StandardArgumentDefinitions.OUTPUT_LONG_NAME, shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, doc = "Single file to which variants should be written") public String output; - @ArgumentCollection - public final ShardingArgumentCollection shardingArgs = new ShardingArgumentCollection(); - - public static class ShardingArgumentCollection implements Serializable { + public static class HaplotypeCallerAssemblyRegionArgumentCollection extends AssemblyRegionArgumentCollection { private static final long serialVersionUID = 1L; - @Argument(fullName="read-shard-size", doc = "Maximum size of each read shard, in bases. 
For good performance, this should be much larger than the maximum assembly region size.", optional = true) - public int readShardSize = DEFAULT_READSHARD_SIZE; - - @Argument(fullName="read-shard-padding", doc = "Each read shard has this many bases of extra context on each side. Read shards must have as much or more padding than assembly regions.", optional = true) - public int readShardPadding = HaplotypeCaller.DEFAULT_ASSEMBLY_REGION_PADDING; - - @Argument(fullName = AssemblyRegionWalker.MIN_ASSEMBLY_LONG_NAME, doc = "Minimum size of an assembly region", optional = true) - public int minAssemblyRegionSize = HaplotypeCaller.DEFAULT_MIN_ASSEMBLY_REGION_SIZE; + @Override + protected int defaultMinAssemblyRegionSize() { return HaplotypeCaller.DEFAULT_MIN_ASSEMBLY_REGION_SIZE; } - @Argument(fullName = AssemblyRegionWalker.MAX_ASSEMBLY_LONG_NAME, doc = "Maximum size of an assembly region", optional = true) - public int maxAssemblyRegionSize = HaplotypeCaller.DEFAULT_MAX_ASSEMBLY_REGION_SIZE; + @Override + protected int defaultMaxAssemblyRegionSize() { return HaplotypeCaller.DEFAULT_MAX_ASSEMBLY_REGION_SIZE; } - @Argument(fullName = AssemblyRegionWalker.ASSEMBLY_PADDING_LONG_NAME, doc = "Number of additional bases of context to include around each assembly region", optional = true) - public int assemblyRegionPadding = HaplotypeCaller.DEFAULT_ASSEMBLY_REGION_PADDING; + @Override + protected int defaultAssemblyRegionPadding() { return HaplotypeCaller.DEFAULT_ASSEMBLY_REGION_PADDING; } - @Argument(fullName = AssemblyRegionWalker.MAX_STARTS_LONG_NAME, doc = "Maximum number of reads to retain per alignment start position. Reads above this threshold will be downsampled. Set to 0 to disable.", optional = true) - public int maxReadsPerAlignmentStart = HaplotypeCaller.DEFAULT_MAX_READS_PER_ALIGNMENT; + @Override + protected int defaultMaxReadsPerAlignmentStart() { return HaplotypeCaller.DEFAULT_MAX_READS_PER_ALIGNMENT; } - @Advanced - @Argument(fullName = AssemblyRegionWalker.THRESHOLD_LONG_NAME, doc="Minimum probability for a locus to be considered active.", optional = true) - public double activeProbThreshold = HaplotypeCaller.DEFAULT_ACTIVE_PROB_THRESHOLD; + @Override + protected double defaultActiveProbThreshold() { return HaplotypeCaller.DEFAULT_ACTIVE_PROB_THRESHOLD; } - @Advanced - @Argument(fullName = AssemblyRegionWalker.PROPAGATION_LONG_NAME, doc="Upper limit on how many bases away probability mass can be moved around when calculating the boundaries between active and inactive assembly regions", optional = true) - public int maxProbPropagationDistance = HaplotypeCaller.DEFAULT_MAX_PROB_PROPAGATION_DISTANCE; + @Override + protected int defaultMaxProbPropagationDistance() { return HaplotypeCaller.DEFAULT_MAX_PROB_PROPAGATION_DISTANCE; } + } + @Override + protected AssemblyRegionArgumentCollection getAssemblyRegionArgumentCollection() { + return new HaplotypeCallerAssemblyRegionArgumentCollection(); } @ArgumentCollection public HaplotypeCallerArgumentCollection hcArgs = new HaplotypeCallerArgumentCollection(); @Override - public boolean requiresReads(){ - return true; - } - - @Override - public boolean requiresReference(){ + protected boolean includeReadsWithDeletionsInIsActivePileups() { return true; } @@ -149,12 +132,27 @@ public Collection makeVariantAnnotations() { } @Override - protected void runTool(final JavaSparkContext ctx) { + protected void processAssemblyRegions(JavaRDD rdd, JavaSparkContext ctx) { + processAssemblyRegions(rdd, ctx, getHeaderForReads(), 
referenceArguments.getReferenceFileName(), hcArgs, output, makeVariantAnnotations(), logger); + } + + private static void processAssemblyRegions( + JavaRDD rdd, + final JavaSparkContext ctx, + final SAMFileHeader header, + final String reference, + final HaplotypeCallerArgumentCollection hcArgs, + final String output, + final Collection annotations, + final Logger logger) { //TODO remove me when https://github.com/broadinstitute/gatk/issues/4303 are fixed if (output.endsWith(IOUtil.BCF_FILE_EXTENSION) || output.endsWith(IOUtil.BCF_FILE_EXTENSION + ".gz")) { throw new UserException.UnimplementedFeature("It is currently not possible to write a BCF file on spark. See https://github.com/broadinstitute/gatk/issues/4303 for more details ."); } - SAMFileHeader header = getHeaderForReads(); + Utils.validateArg(hcArgs.dbsnp.dbsnp == null, "HaplotypeCallerSpark does not yet support -D or --dbsnp arguments" ); + Utils.validateArg(hcArgs.comps.isEmpty(), "HaplotypeCallerSpark does not yet support -comp or --comp arguments" ); + Utils.validateArg(hcArgs.bamOutputPath == null, "HaplotypeCallerSpark does not yet support -bamout or --bamOutput"); + Utils.validate(header.getSortOrder() == SAMFileHeader.SortOrder.coordinate, "The reads must be coordinate sorted."); logger.info("********************************************************************************"); logger.info("The output of this tool DOES NOT match the output of HaplotypeCaller. "); @@ -162,49 +160,18 @@ protected void runTool(final JavaSparkContext ctx) { logger.info("For evaluation only."); logger.info("Use the non-spark HaplotypeCaller if you care about the results. "); logger.info("********************************************************************************"); - addReferenceFilesForSpark(ctx, referenceArguments.getReferenceFileName()); - final List intervals = hasUserSuppliedIntervals() ? getIntervals() : IntervalUtils.getAllIntervalsForReference(header.getSequenceDictionary()); - callVariantsWithHaplotypeCallerAndWriteOutput(ctx, getReads(), header, referenceArguments.getReferenceFileName(), intervals, hcArgs, shardingArgs, numReducers, output, makeVariantAnnotations()); - } - @Override - public List getDefaultReadFilters() { - return HaplotypeCallerEngine.makeStandardHCReadFilters(); - } - - /** - * Call Variants using HaplotypeCaller on Spark and write out a VCF file. 
- * - * This may be called from any spark pipeline in order to call variants from an RDD of GATKRead - * - * @param ctx the spark context - * @param reads the reads variants should be called from - * @param header the header that goes with the reads - * @param reference the full path to the reference file (must have been added via {@code SparkContext#addFile()}) - * @param intervals the intervals to restrict calling to - * @param hcArgs haplotype caller arguments - * @param shardingArgs arguments to control how the assembly regions are sharded - * @param numReducers the number of reducers to use when sorting - * @param output the output path for the VCF - */ - public static void callVariantsWithHaplotypeCallerAndWriteOutput( - final JavaSparkContext ctx, - final JavaRDD reads, - final SAMFileHeader header, - final String reference, - final List intervals, - final HaplotypeCallerArgumentCollection hcArgs, - final ShardingArgumentCollection shardingArgs, - final int numReducers, - final String output, - final Collection annotations) { final VariantAnnotatorEngine variantannotatorEngine = new VariantAnnotatorEngine(annotations, hcArgs.dbsnp.dbsnp, hcArgs.comps, hcArgs.emitReferenceConfidence != ReferenceConfidenceMode.NONE); final Path referencePath = IOUtils.getPath(reference); final ReferenceSequenceFile driverReferenceSequenceFile = new CachingIndexedFastaSequenceFile(referencePath); final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgs, false, false, header, driverReferenceSequenceFile, variantannotatorEngine); final String referenceFileName = referencePath.getFileName().toString(); - final JavaRDD variants = callVariantsWithHaplotypeCaller(ctx, reads, header, referenceFileName, intervals, hcArgs, shardingArgs, variantannotatorEngine); + final Broadcast hcArgsBroadcast = ctx.broadcast(hcArgs); + final Broadcast annotatorEngineBroadcast = ctx.broadcast(variantannotatorEngine); + + final JavaRDD variants = rdd.mapPartitions(assemblyFunction(header, referenceFileName, hcArgsBroadcast, annotatorEngineBroadcast)); + variants.cache(); // without caching, computations are run twice as a side effect of finding partition boundaries for sorting try { VariantsSparkSink.writeVariants(ctx, output, variants, hcEngine.makeVCFHeader(header.getSequenceDictionary(), new HashSet<>()), @@ -214,199 +181,111 @@ public static void callVariantsWithHaplotypeCallerAndWriteOutput( } } - /** - * Call Variants using HaplotypeCaller on Spark and return an RDD of {@link VariantContext} - * - * This may be called from any spark pipeline in order to call variants from an RDD of GATKRead - * - * @param ctx the spark context - * @param reads the reads variants should be called from - * @param header the header that goes with the reads - * @param referenceFileName the name of the reference file added via {@code SparkContext#addFile()} - * @param intervals the intervals to restrict calling to - * @param hcArgs haplotype caller arguments - * @param shardingArgs arguments to control how the assembly regions are sharded - * @param variantannotatorEngine - * @return an RDD of Variants - */ - public static JavaRDD callVariantsWithHaplotypeCaller( - final JavaSparkContext ctx, - final JavaRDD reads, - final SAMFileHeader header, - final String referenceFileName, - final List intervals, - final HaplotypeCallerArgumentCollection hcArgs, - final ShardingArgumentCollection shardingArgs, - final VariantAnnotatorEngine variantannotatorEngine) { - Utils.validateArg(hcArgs.dbsnp.dbsnp == null, "HaplotypeCallerSpark does 
not yet support -D or --dbsnp arguments" ); - Utils.validateArg(hcArgs.comps.isEmpty(), "HaplotypeCallerSpark does not yet support -comp or --comp arguments" ); - Utils.validateArg(hcArgs.bamOutputPath == null, "HaplotypeCallerSpark does not yet support -bamout or --bamOutput"); - - final Broadcast hcArgsBroadcast = ctx.broadcast(hcArgs); - - final Broadcast annotatorEngineBroadcast = ctx.broadcast(variantannotatorEngine); - - final List shardBoundaries = getShardBoundaries(header, intervals, shardingArgs.readShardSize, shardingArgs.readShardPadding); - - final int maxReadLength = reads.map(r -> r.getEnd() - r.getStart() + 1).reduce(Math::max); - - final JavaRDD> readShards = SparkSharder.shard(ctx, reads, GATKRead.class, header.getSequenceDictionary(), shardBoundaries, maxReadLength); - - final JavaRDD> assemblyRegions = readShards - .mapPartitions(shardsToAssemblyRegions(referenceFileName, - hcArgsBroadcast, shardingArgs, header, annotatorEngineBroadcast)); + private static FlatMapFunction, VariantContext> assemblyFunction(final SAMFileHeader header, + final String referenceFileName, + final Broadcast hcArgsBroadcast, + final Broadcast annotatorEngineBroadcast) { + return (FlatMapFunction, VariantContext>) contexts -> { + // HaplotypeCallerEngine isn't serializable but is expensive to instantiate, so construct and reuse one for every partition + final ReferenceSequenceFile taskReferenceSequenceFile = taskReferenceSequenceFile(referenceFileName); + final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgsBroadcast.value(), false, false, header, taskReferenceSequenceFile, annotatorEngineBroadcast.getValue()); + Iterator> iterators = Utils.stream(contexts).map(context -> { + AssemblyRegion region = context.getAssemblyRegion(); + FeatureContext featureContext = context.getFeatureContext(); + return hcEngine.callRegion(region, featureContext).iterator(); + }).iterator(); - return assemblyRegions.mapPartitions(callVariantsFromAssemblyRegions(header, referenceFileName, hcArgsBroadcast, annotatorEngineBroadcast)); + return Iterators.concat(iterators); + }; } - /** - * Call variants from Tuples of AssemblyRegion and Simple Interval - * The interval should be the non-padded shard boundary for the shard that the corresponding AssemblyRegion was - * created in, it's used to eliminate redundant variant calls at the edge of shard boundaries. 
- */ - private static FlatMapFunction>, VariantContext> callVariantsFromAssemblyRegions( - final SAMFileHeader header, - final String referenceFileName, - final Broadcast hcArgsBroadcast, - final Broadcast annotatorEngineBroadcast) { - return regionAndIntervals -> { - //HaplotypeCallerEngine isn't serializable but is expensive to instantiate, so construct and reuse one for every partition - final String pathOnExecutor = SparkFiles.get(referenceFileName); - final ReferenceSequenceFile taskReferenceSequenceFile = new CachingIndexedFastaSequenceFile(IOUtils.getPath(pathOnExecutor)); - final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgsBroadcast.value(), false, false, header, taskReferenceSequenceFile, annotatorEngineBroadcast.getValue()); - return Utils.stream(regionAndIntervals).flatMap(regionToVariants(hcEngine)).iterator(); - }; + @Override + public List getDefaultReadFilters() { + return HaplotypeCallerEngine.makeStandardHCReadFilters(); } - private static Function, Stream> regionToVariants(HaplotypeCallerEngine hcEngine) { - return regionAndInterval -> { - final List variantContexts = hcEngine.callRegion(regionAndInterval._1(), new FeatureContext()); - final SimpleInterval shardBoundary = regionAndInterval._2(); - return variantContexts.stream() - .filter(vc -> shardBoundary.contains(new SimpleInterval(vc.getContig(), vc.getStart(), vc.getStart()))); - }; + @Override + public AssemblyRegionEvaluator assemblyRegionEvaluator() { + return null; // not used (see assemblyRegionEvaluatorSupplierBroadcast) } - /** - * @return a list of {@link ShardBoundary} - * based on the -L intervals - */ - private static List getShardBoundaries(final SAMFileHeader - header, final List intervals, final int readShardSize, final int readShardPadding) { - return intervals.stream() - .flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, header.getSequenceDictionary()).stream()) - .collect(Collectors.toList()); + @Override + protected Broadcast> assemblyRegionEvaluatorSupplierBroadcast(final JavaSparkContext ctx) { + final Path referencePath = IOUtils.getPath(referenceArguments.getReferenceFileName()); + final String referenceFileName = referencePath.getFileName().toString(); + final String pathOnExecutor = SparkFiles.get(referenceFileName); + final ReferenceSequenceFile taskReferenceSequenceFile = new CachingIndexedFastaSequenceFile(IOUtils.getPath(pathOnExecutor)); + final Collection annotations = makeVariantAnnotations(); + final VariantAnnotatorEngine annotatorEngine = new VariantAnnotatorEngine(annotations, hcArgs.dbsnp.dbsnp, hcArgs.comps, hcArgs.emitReferenceConfidence != ReferenceConfidenceMode.NONE); + return assemblyRegionEvaluatorSupplierBroadcastFunction(ctx, hcArgs, getHeaderForReads(), taskReferenceSequenceFile, annotatorEngine); } - /** - * @return and RDD of {@link Tuple2} which pairs each AssemblyRegion with the - * interval it was generated in - */ - private static FlatMapFunction>, Tuple2> shardsToAssemblyRegions( - final String referenceFileName, - final Broadcast hcArgsBroadcast, - final ShardingArgumentCollection assemblyArgs, + private static Broadcast> assemblyRegionEvaluatorSupplierBroadcast( + final JavaSparkContext ctx, + final HaplotypeCallerArgumentCollection hcArgs, final SAMFileHeader header, - final Broadcast annotatorEngineBroadcast) { - return shards -> { - final String pathOnExecutor = SparkFiles.get(referenceFileName); - final ReferenceSequenceFile taskReferenceSequenceFile = new 
CachingIndexedFastaSequenceFile(IOUtils.getPath(pathOnExecutor)); - final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgsBroadcast.value(), false, false, header, taskReferenceSequenceFile, annotatorEngineBroadcast.getValue()); - - final ReferenceDataSource taskReferenceDataSource = new ReferenceFileSource(IOUtils.getPath(pathOnExecutor)); // TODO: share with taskReferenceSequenceFile - final ReadsDownsampler readsDownsampler = assemblyArgs.maxReadsPerAlignmentStart > 0 ? - new PositionalDownsampler(assemblyArgs.maxReadsPerAlignmentStart, header) : null; - return Utils.stream(shards) - //TODO we've hacked multi interval shards here with a shim, but we should investigate as smarter approach https://github.com/broadinstitute/gatk/issues/4299 - .map(shard -> new ShardToMultiIntervalShardAdapter<>( - new DownsampleableSparkReadShard(new ShardBoundary(shard.getInterval(), shard.getPaddedInterval()), shard, readsDownsampler))) - .flatMap(shardToRegion(assemblyArgs, header, taskReferenceDataSource, hcEngine)).iterator(); - }; + final String reference, + final Collection annotations) { + final Path referencePath = IOUtils.getPath(reference); + final String referenceFileName = referencePath.getFileName().toString(); + final ReferenceSequenceFile taskReferenceSequenceFile = taskReferenceSequenceFile(referenceFileName); + final VariantAnnotatorEngine annotatorEngine = new VariantAnnotatorEngine(annotations, hcArgs.dbsnp.dbsnp, hcArgs.comps, hcArgs.emitReferenceConfidence != ReferenceConfidenceMode.NONE); + return assemblyRegionEvaluatorSupplierBroadcastFunction(ctx, hcArgs, header, taskReferenceSequenceFile, annotatorEngine); } - private static Function, Stream>> shardToRegion( - ShardingArgumentCollection assemblyArgs, - SAMFileHeader header, - ReferenceDataSource referenceDataSource, - HaplotypeCallerEngine evaluator) { - return shard -> { - //TODO load features as a side input - final FeatureManager featureManager = null; - - final Iterator assemblyRegionIter = new AssemblyRegionIterator(shard, header, referenceDataSource, featureManager, evaluator, assemblyArgs.minAssemblyRegionSize, assemblyArgs.maxAssemblyRegionSize, assemblyArgs.assemblyRegionPadding, assemblyArgs.activeProbThreshold, assemblyArgs.maxProbPropagationDistance, - INCLUDE_READS_WITH_DELETIONS_IN_IS_ACTIVE_PILEUPS); + private static ReferenceSequenceFile taskReferenceSequenceFile(final String referenceFileName) { + final String pathOnExecutor = SparkFiles.get(referenceFileName); + return new CachingIndexedFastaSequenceFile(IOUtils.getPath(pathOnExecutor)); + } - return Utils.stream(assemblyRegionIter) - .map(a -> new Tuple2<>(a, shard.getIntervals().get(0))); + private static Broadcast> assemblyRegionEvaluatorSupplierBroadcastFunction( + final JavaSparkContext ctx, + final HaplotypeCallerArgumentCollection hcArgs, + final SAMFileHeader header, + final ReferenceSequenceFile taskReferenceSequenceFile, + final VariantAnnotatorEngine annotatorEngine) { + Supplier supplier = new Supplier() { + @Override + public AssemblyRegionEvaluator get() { + return new HaplotypeCallerEngine(hcArgs, false, false, header, taskReferenceSequenceFile, annotatorEngine); + } }; + return ctx.broadcast(supplier); } /** - * Adapter to allow a 2bit reference to be used in HaplotypeCallerEngine. - * This is not intended as a general purpose adapter, it only enables the operations needed in {@link HaplotypeCallerEngine} - * This should not be used outside of this class except for testing purposes. 
+ * Call Variants using HaplotypeCaller on Spark and write out a VCF file. + * + * This may be called from any spark pipeline in order to call variants from an RDD of GATKRead + * @param ctx the spark context + * @param reads the reads variants should be called from + * @param header the header that goes with the reads + * @param reference the full path to the reference file (must have been added via {@code SparkContext#addFile()}) + * @param intervalShards the interval shards to restrict calling to + * @param hcArgs haplotype caller arguments + * @param shardingArgs arguments to control how the assembly regions are sharded + * @param output the output path for the VCF + * @param logger */ - @VisibleForTesting - public static final class ReferenceMultiSourceAdapter implements ReferenceSequenceFile, ReferenceDataSource, Serializable{ - private static final long serialVersionUID = 1L; - - private final ReferenceMultiSparkSource source; - private final SAMSequenceDictionary sequenceDictionary; - - public ReferenceMultiSourceAdapter(final ReferenceMultiSparkSource source) { - this.source = source; - sequenceDictionary = source.getReferenceSequenceDictionary(null); - } - - @Override - public ReferenceSequence queryAndPrefetch(final String contig, final long start, final long stop) { - return getSubsequenceAt(contig, start, stop); - } - - @Override - public SAMSequenceDictionary getSequenceDictionary() { - return source.getReferenceSequenceDictionary(null); - } - - @Override - public ReferenceSequence nextSequence() { - throw new UnsupportedOperationException("nextSequence is not implemented"); - } - - @Override - public void reset() { - throw new UnsupportedOperationException("reset is not implemented"); - } - - @Override - public boolean isIndexed() { - return true; - } - - @Override - public ReferenceSequence getSequence(final String contig) { - throw new UnsupportedOperationException("getSequence is not supported"); - } - - @Override - public ReferenceSequence getSubsequenceAt(final String contig, final long start, final long stop) { - try { - final ReferenceBases bases = source.getReferenceBases(new SimpleInterval(contig, (int) start, (int) stop)); - return new ReferenceSequence(contig, sequenceDictionary.getSequenceIndex(contig), bases.getBases()); - } catch (final IOException e) { - throw new GATKException(String.format("Failed to load reference bases for %s:%d-%d", contig, start, stop)); - } - } - - @Override - public void close() { - // doesn't do anything because you can't close a two-bit file - } + public static void callVariantsWithHaplotypeCallerAndWriteOutput( + final JavaSparkContext ctx, + final JavaRDD reads, + final SAMFileHeader header, + final SAMSequenceDictionary sequenceDictionary, + final String reference, + final List intervalShards, + final HaplotypeCallerArgumentCollection hcArgs, + final AssemblyRegionReadShardArgumentCollection shardingArgs, + final AssemblyRegionArgumentCollection assemblyRegionArgs, + final boolean includeReadsWithDeletionsInIsActivePileups, + final String output, + final Collection annotations, + final Logger logger) { - @Override - public Iterator iterator() { - throw new UnsupportedOperationException("iterator is not supported"); - } + final Path referencePath = IOUtils.getPath(reference); + final String referenceFileName = referencePath.getFileName().toString(); + Broadcast> assemblyRegionEvaluatorSupplierBroadcast = assemblyRegionEvaluatorSupplierBroadcast(ctx, hcArgs, header, reference, annotations); + JavaRDD assemblyRegions = 
getAssemblyRegions(ctx, reads, header, sequenceDictionary, referenceFileName, null, intervalShards, assemblyRegionEvaluatorSupplierBroadcast, shardingArgs, assemblyRegionArgs, includeReadsWithDeletionsInIsActivePileups, false); + processAssemblyRegions(assemblyRegions, ctx, header, reference, hcArgs, output, annotations, logger); } - } diff --git a/src/main/java/org/broadinstitute/hellbender/tools/examples/ExampleAssemblyRegionWalkerSpark.java b/src/main/java/org/broadinstitute/hellbender/tools/examples/ExampleAssemblyRegionWalkerSpark.java index 060ae954e08..83b32c3116b 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/examples/ExampleAssemblyRegionWalkerSpark.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/examples/ExampleAssemblyRegionWalkerSpark.java @@ -5,10 +5,12 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.ArgumentCollection; import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; import org.broadinstitute.hellbender.cmdline.programgroups.ExampleProgramGroup; import org.broadinstitute.hellbender.engine.*; +import org.broadinstitute.hellbender.engine.spark.AssemblyRegionArgumentCollection; import org.broadinstitute.hellbender.engine.spark.AssemblyRegionWalkerContext; import org.broadinstitute.hellbender.utils.SimpleInterval; import org.broadinstitute.hellbender.utils.activityprofile.ActivityProfileState; @@ -37,29 +39,32 @@ public final class ExampleAssemblyRegionWalkerSpark extends AssemblyRegionWalker @Argument(fullName="knownVariants", shortName="knownVariants", doc="Known set of variants", optional=true) private FeatureInput knownVariants; - @Override - protected int defaultReadShardSize() { return 5000; } + public static class ExampleAssemblyRegionArgumentCollection extends AssemblyRegionArgumentCollection { + private static final long serialVersionUID = 1L; - @Override - protected int defaultReadShardPadding() { return 100; } + @Override + protected int defaultMinAssemblyRegionSize() { return 50; } - @Override - protected int defaultMinAssemblyRegionSize() { return 50; } + @Override + protected int defaultMaxAssemblyRegionSize() { return 300; } - @Override - protected int defaultMaxAssemblyRegionSize() { return 300; } + @Override + protected int defaultAssemblyRegionPadding() { return 100; } - @Override - protected int defaultAssemblyRegionPadding() { return 100; } + @Override + protected int defaultMaxReadsPerAlignmentStart() { return 50; } - @Override - protected int defaultMaxReadsPerAlignmentStart() { return 50; } + @Override + protected double defaultActiveProbThreshold() { return 0.002; } - @Override - protected double defaultActiveProbThreshold() { return 0.002; } + @Override + protected int defaultMaxProbPropagationDistance() { return 50; } + } @Override - protected int defaultMaxProbPropagationDistance() { return 50; } + protected AssemblyRegionArgumentCollection getAssemblyRegionArgumentCollection() { + return new ExampleAssemblyRegionArgumentCollection(); + } @Override protected boolean includeReadsWithDeletionsInIsActivePileups() { diff --git a/src/main/java/org/broadinstitute/hellbender/tools/spark/pipelines/ReadsPipelineSpark.java b/src/main/java/org/broadinstitute/hellbender/tools/spark/pipelines/ReadsPipelineSpark.java index 88e02747529..9b5e9cb8092 100644 --- 
a/src/main/java/org/broadinstitute/hellbender/tools/spark/pipelines/ReadsPipelineSpark.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/spark/pipelines/ReadsPipelineSpark.java @@ -1,6 +1,7 @@ package org.broadinstitute.hellbender.tools.spark.pipelines; import htsjdk.samtools.SAMFileHeader; +import htsjdk.samtools.SAMSequenceDictionary; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -13,7 +14,11 @@ import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; import org.broadinstitute.hellbender.cmdline.argumentcollections.MarkDuplicatesSparkArgumentCollection; import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup; +import org.broadinstitute.hellbender.engine.Shard; +import org.broadinstitute.hellbender.engine.ShardBoundary; import org.broadinstitute.hellbender.engine.filters.ReadFilter; +import org.broadinstitute.hellbender.engine.spark.AssemblyRegionArgumentCollection; +import org.broadinstitute.hellbender.engine.spark.AssemblyRegionReadShardArgumentCollection; import org.broadinstitute.hellbender.engine.spark.GATKSparkTool; import org.broadinstitute.hellbender.utils.spark.JoinReadsWithVariants; import org.broadinstitute.hellbender.tools.ApplyBQSRUniqueArgumentCollection; @@ -39,6 +44,7 @@ import java.util.Collection; import java.util.List; +import java.util.stream.Collectors; /** * ReadsPipelineSpark is our standard pipeline that takes unaligned or aligned reads and runs BWA (if specified), MarkDuplicates, @@ -118,7 +124,10 @@ public class ReadsPipelineSpark extends GATKSparkTool { private final RecalibrationArgumentCollection bqsrArgs = new RecalibrationArgumentCollection(); @ArgumentCollection - public final HaplotypeCallerSpark.ShardingArgumentCollection shardingArgs = new HaplotypeCallerSpark.ShardingArgumentCollection(); + public final AssemblyRegionReadShardArgumentCollection shardingArgs = new AssemblyRegionReadShardArgumentCollection(); + + @ArgumentCollection + public final AssemblyRegionArgumentCollection assemblyRegionArgs = new HaplotypeCallerSpark.HaplotypeCallerAssemblyRegionArgumentCollection(); /** * command-line arguments to fine tune the apply BQSR step. @@ -203,7 +212,13 @@ protected void runTool(final JavaSparkContext ctx) { final ReadFilter hcReadFilter = ReadFilter.fromList(HaplotypeCallerEngine.makeStandardHCReadFilters(), header); final JavaRDD filteredReadsForHC = finalReads.filter(hcReadFilter::test); final List intervals = hasUserSuppliedIntervals() ? 
getIntervals() : IntervalUtils.getAllIntervalsForReference(header.getSequenceDictionary()); - HaplotypeCallerSpark.callVariantsWithHaplotypeCallerAndWriteOutput(ctx, filteredReadsForHC, header, referenceArguments.getReferenceFileName(), intervals, hcArgs, shardingArgs, numReducers, output, makeVariantAnnotations()); + + SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary(); + List intervalShards = intervals.stream() + .flatMap(interval -> Shard.divideIntervalIntoShards(interval, shardingArgs.readShardSize, shardingArgs.readShardPadding, sequenceDictionary).stream()) + .collect(Collectors.toList()); + + HaplotypeCallerSpark.callVariantsWithHaplotypeCallerAndWriteOutput(ctx, filteredReadsForHC, readsHeader, sequenceDictionary, referenceArguments.getReferenceFileName(), intervalShards, hcArgs, shardingArgs, assemblyRegionArgs, true, output, makeVariantAnnotations(), logger); if (bwaEngine != null) { bwaEngine.close(); diff --git a/src/test/java/org/broadinstitute/hellbender/tools/HaplotypeCallerSparkIntegrationTest.java b/src/test/java/org/broadinstitute/hellbender/tools/HaplotypeCallerSparkIntegrationTest.java index 898815c8372..08c5ba779dd 100644 --- a/src/test/java/org/broadinstitute/hellbender/tools/HaplotypeCallerSparkIntegrationTest.java +++ b/src/test/java/org/broadinstitute/hellbender/tools/HaplotypeCallerSparkIntegrationTest.java @@ -188,15 +188,6 @@ public void testGVCFModeIsConcordantWithGATK3_8AlelleSpecificResults(String exte Assert.assertTrue(concordance >= 0.99, "Concordance with GATK 3.8 in AS GVCF mode is < 99% (" + concordance + ")"); } - @Test - public void testReferenceAdapterIsSerializable() throws IOException { - final ReferenceMultiSparkSource referenceMultiSource = new ReferenceMultiSparkSource(b37_reference_20_21, ReferenceWindowFunctions.IDENTITY_FUNCTION); - SparkTestUtils.roundTripInKryo(referenceMultiSource, ReferenceMultiSparkSource.class, SparkContextFactory.getTestSparkContext().getConf()); - final HaplotypeCallerSpark.ReferenceMultiSourceAdapter adapter = new HaplotypeCallerSpark.ReferenceMultiSourceAdapter(referenceMultiSource); - SparkTestUtils.roundTripInKryo(adapter, HaplotypeCallerSpark.ReferenceMultiSourceAdapter.class, SparkContextFactory.getTestSparkContext().getConf()); - - } - @Test public void testGenotypeCalculationArgumentCollectionIsSerializable() { final GenotypeCalculationArgumentCollection args = new GenotypeCalculationArgumentCollection();
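
The new `assemblyRegionEvaluatorSupplierBroadcast` hook broadcasts a `Supplier<AssemblyRegionEvaluator>` rather than the evaluator itself: the small serializable supplier is what crosses the wire, and the expensive (and, for `HaplotypeCallerEngine`, non-serializable) evaluator is only constructed on the executor, once per partition. The standalone sketch below illustrates the same pattern outside GATK; `ExpensiveEvaluator` and `scoreRecords` are hypothetical stand-ins, not classes from this patch.

```java
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

// Minimal sketch (not GATK code) of the broadcast-a-Supplier pattern: the small
// serializable Supplier is shipped to executors; the expensive, possibly
// non-serializable evaluator is only built there, once per partition.
public class SupplierBroadcastSketch {

    // Hypothetical stand-in for an evaluator that is costly to construct and not Spark-serializable.
    static class ExpensiveEvaluator {
        int score(final String record) { return record.length(); }
    }

    public static JavaRDD<Integer> scoreRecords(final JavaSparkContext ctx, final JavaRDD<String> records) {
        // Broadcast a Supplier rather than the evaluator itself; the intersection cast
        // makes the lambda serializable for the default Java serializer.
        final Supplier<ExpensiveEvaluator> supplier =
                (Supplier<ExpensiveEvaluator> & Serializable) ExpensiveEvaluator::new;
        final Broadcast<Supplier<ExpensiveEvaluator>> bSupplier = ctx.broadcast(supplier);

        // mapPartitions: invoke the supplier once per partition and reuse the instance for
        // every record in that partition, mirroring what getAssemblyRegionsFunction does.
        return records.mapPartitions(partition -> {
            final ExpensiveEvaluator evaluator = bSupplier.getValue().get();
            final List<Integer> out = new ArrayList<>();
            while (partition.hasNext()) {
                out.add(evaluator.score(partition.next()));
            }
            return out.iterator();
        });
    }
}
```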
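
`editIntervals()` and `getAssemblyRegions()` rely on `Shard.divideIntervalIntoShards` to cut the traversal intervals into `readShardSize` pieces, each padded by `readShardPadding` bases on both sides and clamped to the contig. A minimal sketch of that call in isolation, assuming the GATK/htsjdk classes shown; the contig name and lengths are invented for illustration.

```java
import java.util.Collections;
import java.util.List;

import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMSequenceRecord;
import org.broadinstitute.hellbender.engine.Shard;
import org.broadinstitute.hellbender.engine.ShardBoundary;
import org.broadinstitute.hellbender.utils.SimpleInterval;

// Sketch of the sharding step performed in editIntervals(): one interval is divided into
// fixed-size shards whose padded boundaries are later used to shard the reads.
public class ReadShardingSketch {
    public static void main(final String[] args) {
        final SAMSequenceDictionary dictionary =
                new SAMSequenceDictionary(Collections.singletonList(new SAMSequenceRecord("chr1", 20_000)));
        final SimpleInterval interval = new SimpleInterval("chr1", 1, 12_000);

        // 5000-base shards with 100 bases of padding: the defaults declared in
        // AssemblyRegionReadShardArgumentCollection.
        final List<ShardBoundary> shards = Shard.divideIntervalIntoShards(interval, 5000, 100, dictionary);
        System.out.println(shards.size() + " shards: " + shards);
    }
}
```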
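
The `variants.cache()` call added before `VariantsSparkSink.writeVariants` matters because the sorted write first samples the RDD to find partition boundaries and then shuffles it, so an uncached parent is computed twice. A small self-contained Spark sketch of the same effect; the sleep stands in for expensive per-record work and all names are illustrative.

```java
import java.util.Arrays;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

// Why caching precedes a sorted write: range-partitioning for the sort samples the data
// to pick partition boundaries, then shuffles it, evaluating an uncached parent twice.
public class CacheBeforeSortSketch {
    public static void main(final String[] args) {
        final JavaSparkContext ctx = new JavaSparkContext("local[*]", "cache-before-sort-sketch");

        final JavaRDD<String> expensive = ctx.parallelize(Arrays.asList("chr2:5", "chr1:7", "chr1:3"))
                .map(record -> {
                    Thread.sleep(10);   // pretend each record is costly to compute
                    return record;
                });
        expensive.cache();              // computed once, reused by the sampling pass and the shuffle

        final JavaPairRDD<String, String> sorted = expensive
                .mapToPair(record -> new Tuple2<>(record, record))
                .sortByKey();
        sorted.values().collect().forEach(System.out::println);

        ctx.stop();
    }
}
```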