Skip to content

Commit

Permalink
fixed any potential worry about pileup memory overhead by disabling t…
Browse files Browse the repository at this point in the history
…racking by default
  • Loading branch information
jamesemery committed Feb 4, 2022
1 parent 578d271 commit a319cd9
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public AssemblyRegionIterator(final MultiIntervalShard<GATKRead> readShard,
final ReferenceDataSource reference,
final FeatureManager features,
final AssemblyRegionEvaluator evaluator,
final AssemblyRegionArgumentCollection assemblyRegionArgs) {
final AssemblyRegionArgumentCollection assemblyRegionArgs,
final boolean trackPileups ) {

Utils.nonNull(readShard);
Utils.nonNull(readHeader);
Expand All @@ -87,7 +88,7 @@ public AssemblyRegionIterator(final MultiIntervalShard<GATKRead> readShard,
this.readCachingIterator = new ReadCachingIterator(readShard.iterator());
this.readCache = new ArrayDeque<>();
this.activityProfile = new BandPassActivityProfile(assemblyRegionArgs.maxProbPropagationDistance, assemblyRegionArgs.activeProbThreshold, BandPassActivityProfile.MAX_FILTER_SIZE, BandPassActivityProfile.DEFAULT_SIGMA, readHeader);
this.pendingAlignmentData = new ArrayDeque<>();
this.pendingAlignmentData = trackPileups? new ArrayDeque<>(): null;

// We wrap our LocusIteratorByState inside an IntervalAlignmentContextIterator so that we get empty loci
// for uncovered locations. This is critical for reproducing GATK 3.x behavior!
Expand Down Expand Up @@ -134,7 +135,9 @@ private AssemblyRegion loadNextAssemblyRegion() {
final SimpleInterval pileupInterval = new SimpleInterval(pileup);
final ReferenceContext pileupRefContext = new ReferenceContext(reference, pileupInterval);
final FeatureContext pileupFeatureContext = new FeatureContext(features, pileupInterval);
pendingAlignmentData.add(new AlignmentAndReferenceContext(pileup, pileupRefContext));
if (pendingAlignmentData!=null) {
pendingAlignmentData.add(new AlignmentAndReferenceContext(pileup, pileupRefContext));
}

final ActivityProfileState profile = evaluator.isActive(pileup, pileupRefContext, pileupFeatureContext);
activityProfile.add(profile);
Expand Down Expand Up @@ -215,6 +218,9 @@ private void fillNextAssemblyRegionWithReads( final AssemblyRegion region ) {
}

private void fillNextAssemblyRegionWithPileupData(final AssemblyRegion region){
if (pendingAlignmentData==null){
return;
}
final List<AlignmentAndReferenceContext> overlappingAlignmentData = new ArrayList<>();
final Queue<AlignmentAndReferenceContext> previousAlignmentData = new ArrayDeque<>();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package org.broadinstitute.hellbender.engine;

import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineException;
import org.broadinstitute.hellbender.engine.filters.CountingReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
Expand Down Expand Up @@ -185,7 +183,7 @@ public void traverse() {
* @param features FeatureManager
*/
private void processReadShard(MultiIntervalLocalReadShard shard, ReferenceDataSource reference, FeatureManager features ) {
final Iterator<AssemblyRegion> assemblyRegionIter = new AssemblyRegionIterator(shard, getHeaderForReads(), reference, features, assemblyRegionEvaluator(), assemblyRegionArgs);
final Iterator<AssemblyRegion> assemblyRegionIter = new AssemblyRegionIterator(shard, getHeaderForReads(), reference, features, assemblyRegionEvaluator(), assemblyRegionArgs, shouldTrackPileupsForAssemblyRegions());

// Call into the tool implementation to process each assembly region from this shard.
while ( assemblyRegionIter.hasNext() ) {
Expand Down Expand Up @@ -236,6 +234,12 @@ protected final void onShutdown() {
*/
public abstract AssemblyRegionEvaluator assemblyRegionEvaluator();

/**
* Allows implementing tools to decide whether pileups must be tracked and attached to assembly regions for later processing.
* This is configurable for now in order to avoid potential increases in the memory consumption of the variant calling machinery.
*
* @return true if pileups should be tracked and attached to assembly regions, false otherwise
*/
public abstract boolean shouldTrackPileupsForAssemblyRegions();

/**
* Process an individual AssemblyRegion. Must be implemented by tool authors.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ intervalShards, assemblyRegionEvaluatorSupplierBroadcast(ctx), shardingArgs, ass
} else {
return FindAssemblyRegionsSpark.getAssemblyRegionsFast(ctx, getReads(), getHeaderForReads(), sequenceDictionary, referenceFileName, features,
intervalShards, assemblyRegionEvaluatorSupplierBroadcast(ctx), shardingArgs, assemblyRegionArgs,
shuffle);
shuffle, false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,21 @@ public static JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegionsFast(
final Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast,
final AssemblyRegionReadShardArgumentCollection shardingArgs,
final AssemblyRegionArgumentCollection assemblyRegionArgs,
final boolean shuffle) {
final boolean shuffle,
final boolean trackPileups) {
JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, reads, GATKRead.class, sequenceDictionary, intervalShards, shardingArgs.readShardSize, shuffle);
Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
return shardedReads.mapPartitions(getAssemblyRegionsFunctionFast(referenceFileName, bFeatureManager, header,
assemblyRegionEvaluatorSupplierBroadcast, assemblyRegionArgs));
assemblyRegionEvaluatorSupplierBroadcast, assemblyRegionArgs, trackPileups));
}

private static FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegionWalkerContext> getAssemblyRegionsFunctionFast(
final String referenceFileName,
final Broadcast<FeatureManager> bFeatureManager,
final SAMFileHeader header,
final Broadcast<Supplier<AssemblyRegionEvaluator>> supplierBroadcast,
final AssemblyRegionArgumentCollection assemblyRegionArgs) {
final AssemblyRegionArgumentCollection assemblyRegionArgs,
final boolean trackPileups) {
return (FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegionWalkerContext>) shardedReadIterator -> {
final ReferenceDataSource reference = referenceFileName == null ? null : new ReferenceFileSource(IOUtils.getPath(SparkFiles.get(referenceFileName)));
final FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
Expand All @@ -90,7 +92,7 @@ private static FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegionWalkerCo
.map(downsampledShardedRead -> {
final Iterator<AssemblyRegion> assemblyRegionIter = new AssemblyRegionIterator(
new ShardToMultiIntervalShardAdapter<>(downsampledShardedRead),
header, reference, features, assemblyRegionEvaluator, assemblyRegionArgs);
header, reference, features, assemblyRegionEvaluator, assemblyRegionArgs, trackPileups);
return Utils.stream(assemblyRegionIter).map(assemblyRegion ->
new AssemblyRegionWalkerContext(assemblyRegion,
new ReferenceContext(reference, assemblyRegion.getPaddedSpan()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ public static void callVariantsWithHaplotypeCallerAndWriteOutput(
Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast = assemblyRegionEvaluatorSupplierBroadcast(ctx, hcArgs, assemblyRegionArgs, header, reference, annotations);
JavaRDD<AssemblyRegionWalkerContext> assemblyRegions = strict ?
FindAssemblyRegionsSpark.getAssemblyRegionsStrict(ctx, reads, header, sequenceDictionary, referenceFileName, null, intervalShards, assemblyRegionEvaluatorSupplierBroadcast, shardingArgs, assemblyRegionArgs, false) :
FindAssemblyRegionsSpark.getAssemblyRegionsFast(ctx, reads, header, sequenceDictionary, referenceFileName, null, intervalShards, assemblyRegionEvaluatorSupplierBroadcast, shardingArgs, assemblyRegionArgs, false);
FindAssemblyRegionsSpark.getAssemblyRegionsFast(ctx, reads, header, sequenceDictionary, referenceFileName, null, intervalShards, assemblyRegionEvaluatorSupplierBroadcast, shardingArgs, assemblyRegionArgs, false, hcArgs.pileupDetectionArgs.usePileupDetection);
processAssemblyRegions(assemblyRegions, ctx, header, reference, hcArgs, assemblyRegionArgs, output, annotations, logger, createOutputVariantIndex);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ public AssemblyRegionEvaluator assemblyRegionEvaluator() {
return (locusPileup, referenceContext, featureContext) -> new ActivityProfileState(new SimpleInterval(locusPileup), 1.0);
}

/**
* Pileup tracking is never requested by this tool.
*
* @return always false
*/
@Override
public boolean shouldTrackPileupsForAssemblyRegions() {
return false;
}

@Override
public void onTraversalStart() {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,11 @@ public AssemblyRegionEvaluator assemblyRegionEvaluator() {
return hcEngine;
}

/**
* Pileup tracking follows the pileup-detection setting held in {@code hcArgs}.
*
* @return the value of {@code hcArgs.pileupDetectionArgs.usePileupDetection}
*/
@Override
public boolean shouldTrackPileupsForAssemblyRegions() {
return hcArgs.pileupDetectionArgs.usePileupDetection;
}

@Override
public void onTraversalStart() {
if (hcArgs.emitReferenceConfidence == ReferenceConfidenceMode.GVCF && hcArgs.maxMnpDistance > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ protected ReadsDownsampler createDownsampler() {
/**
* Supplies the evaluator used for assembly regions.
*
* @return the {@code m2Engine} field
*/
@Override
public AssemblyRegionEvaluator assemblyRegionEvaluator() { return m2Engine; }

/**
* Pileup tracking follows the pileup-detection setting held in {@code MTAC}.
*
* @return the value of {@code MTAC.pileupDetectionArgs.usePileupDetection}
*/
@Override
public boolean shouldTrackPileupsForAssemblyRegions() {
return MTAC.pileupDetectionArgs.usePileupDetection;
}

@Override
public void onTraversalStart() {
VariantAnnotatorEngine annotatorEngine = new VariantAnnotatorEngine(makeVariantAnnotations(), null, Collections.emptyList(), false, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public void testRegionsHaveCorrectReadsAndSize( final String reads, final String
final CountingReadFilter combinedReadFilter = CountingReadFilter.fromList(readFilters, readsSource.getHeader());
readShard.setReadFilter(combinedReadFilter);

final AssemblyRegionIterator iter = new AssemblyRegionIterator(readShard, readsSource.getHeader(), refSource, null, evaluator, assemblyRegionArgs);
final AssemblyRegionIterator iter = new AssemblyRegionIterator(readShard, readsSource.getHeader(), refSource, null, evaluator, assemblyRegionArgs, false);

AssemblyRegion previousRegion = null;
while ( iter.hasNext() ) {
Expand Down

0 comments on commit a319cd9

Please sign in to comment.