Skip to content

Commit

Permalink
Improved memory requirements of CollectReadCounts.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelklee committed Feb 25, 2019
1 parent 226f6d7 commit a43ff73
Showing 1 changed file with 33 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package org.broadinstitute.hellbender.tools.copynumber;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Multiset;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.util.Locatable;
import htsjdk.samtools.util.OverlapDetector;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
Expand All @@ -34,9 +35,7 @@
import org.broadinstitute.hellbender.utils.read.GATKRead;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -124,16 +123,21 @@ enum Format {
*/
private SampleLocatableMetadata metadata;

private List<SimpleInterval> intervals;

private String currentContig = null;
private boolean isCoordinateSorted;

/**
* Overlap detector used to determine when read starts overlap with input intervals.
*/
private CachedOverlapDetector<SimpleInterval> intervalCachedOverlapDetector;

private Multiset<SimpleInterval> intervalMultiset = HashMultiset.create();
private Multiset<SimpleInterval> intervalMultiset;

@Override
public List<ReadFilter> getDefaultReadFilters() {
final List<ReadFilter> filters = new ArrayList<>(super.getDefaultReadFilters());
final List<ReadFilter> filters = new ArrayList<>();
filters.add(ReadFilterLibrary.MAPPED);
filters.add(ReadFilterLibrary.NON_ZERO_REFERENCE_LENGTH_ALIGNMENT);
filters.add(ReadFilterLibrary.NOT_DUPLICATE);
Expand All @@ -158,19 +162,31 @@ public void onTraversalStart() {

CopyNumberArgumentValidationUtils.validateIntervalArgumentCollection(intervalArgumentCollection);

logger.info("Initializing and validating intervals...");
final List<SimpleInterval> intervals = intervalArgumentCollection.getIntervals(sequenceDictionary);
intervalCachedOverlapDetector = new CachedOverlapDetector<>(intervals);
intervals = intervalArgumentCollection.getIntervals(sequenceDictionary);
intervalMultiset = HashMultiset.create(intervals.size());

//verify again that intervals do not overlap
Utils.validateArg(intervals.stream().noneMatch(i -> intervalCachedOverlapDetector.overlapDetector.getOverlaps(i).size() > 1),
"Input intervals may not be overlapping.");
final SAMFileHeader.SortOrder sortOrder = getHeaderForReads().getSortOrder();
isCoordinateSorted = sortOrder == SAMFileHeader.SortOrder.coordinate;
if (!isCoordinateSorted) {
//if reads are not sorted, create an OverlapDetector covering all intervals;
//we currently require that reads are indexed (and hence sorted), so this is only to protect against future code changes
logger.warn("Reads are not coordinate sorted; sorting reads before running this tool may reduce memory requirements.");
intervalCachedOverlapDetector = new CachedOverlapDetector<>(intervals);
}

logger.info("Collecting read counts...");
}

@Override
public void apply(GATKRead read, ReferenceContext referenceContext, FeatureContext featureContext) {
if (isCoordinateSorted && (currentContig == null || !read.getContig().equals(currentContig))) {
//if reads are sorted and we are on a new contig, create an OverlapDetector covering the contig
currentContig = read.getContig();
final List<SimpleInterval> intervalsOnCurrentContig = intervals.stream()
.filter(i -> i.getContig().equals(currentContig))
.collect(Collectors.toList());
intervalCachedOverlapDetector = new CachedOverlapDetector<>(intervalsOnCurrentContig);
}
final SimpleInterval overlappingInterval = intervalCachedOverlapDetector.getOverlap(
new SimpleInterval(read.getContig(), read.getStart(), read.getStart()));

Expand All @@ -186,9 +202,9 @@ public Object onTraversalSuccess() {
logger.info("Writing read counts to " + outputCountsFile);
final SimpleCountCollection readCounts = new SimpleCountCollection(
metadata,
intervalCachedOverlapDetector.overlapDetector.getAll().stream()
ImmutableList.copyOf(intervals.stream() //making this an ImmutableList avoids a defensive copy in SimpleCountCollection
.map(i -> new SimpleCount(i, intervalMultiset.count(i)))
.collect(Collectors.toList()));
.iterator()));

if (format == Format.HDF5) {
readCounts.writeHDF5(outputCountsFile);
Expand All @@ -210,6 +226,9 @@ private final class CachedOverlapDetector<T extends Locatable> {
CachedOverlapDetector(final List<T> intervals) {
Utils.nonEmpty(intervals);
this.overlapDetector = OverlapDetector.create(intervals);
//double check that intervals do not overlap
Utils.validateArg(intervals.stream().noneMatch(i -> overlapDetector.getOverlaps(i).size() > 1),
"Input intervals may not be overlapping.");
cachedResult = intervals.get(0);
}

Expand Down

0 comments on commit a43ff73

Please sign in to comment.