From a43ff7337797d4f16189ba6bb1adcf0d2a113715 Mon Sep 17 00:00:00 2001 From: Samuel Lee Date: Fri, 22 Feb 2019 10:49:21 -0500 Subject: [PATCH] Improved memory requirements of CollectReadCounts. --- .../tools/copynumber/CollectReadCounts.java | 47 +++++++++++++------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/src/main/java/org/broadinstitute/hellbender/tools/copynumber/CollectReadCounts.java b/src/main/java/org/broadinstitute/hellbender/tools/copynumber/CollectReadCounts.java index 42209b78bb3..23d500771f3 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/copynumber/CollectReadCounts.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/copynumber/CollectReadCounts.java @@ -1,14 +1,15 @@ package org.broadinstitute.hellbender.tools.copynumber; import com.google.common.collect.HashMultiset; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Multiset; +import htsjdk.samtools.SAMFileHeader; import htsjdk.samtools.SAMSequenceDictionary; import htsjdk.samtools.util.Locatable; import htsjdk.samtools.util.OverlapDetector; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.broadinstitute.barclay.argparser.Argument; -import org.broadinstitute.barclay.argparser.BetaFeature; import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; import org.broadinstitute.barclay.help.DocumentedFeature; import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; @@ -34,9 +35,7 @@ import org.broadinstitute.hellbender.utils.read.GATKRead; import java.io.File; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; /** @@ -124,16 +123,21 @@ enum Format { */ private SampleLocatableMetadata metadata; + private List intervals; + + private String currentContig = null; + private boolean isCoordinateSorted; + /** * Overlap detector used to determine when read starts overlap with input intervals. */ private CachedOverlapDetector intervalCachedOverlapDetector; - private Multiset intervalMultiset = HashMultiset.create(); + private Multiset intervalMultiset; @Override public List getDefaultReadFilters() { - final List filters = new ArrayList<>(super.getDefaultReadFilters()); + final List filters = new ArrayList<>(); filters.add(ReadFilterLibrary.MAPPED); filters.add(ReadFilterLibrary.NON_ZERO_REFERENCE_LENGTH_ALIGNMENT); filters.add(ReadFilterLibrary.NOT_DUPLICATE); @@ -158,19 +162,31 @@ public void onTraversalStart() { CopyNumberArgumentValidationUtils.validateIntervalArgumentCollection(intervalArgumentCollection); - logger.info("Initializing and validating intervals..."); - final List intervals = intervalArgumentCollection.getIntervals(sequenceDictionary); - intervalCachedOverlapDetector = new CachedOverlapDetector<>(intervals); + intervals = intervalArgumentCollection.getIntervals(sequenceDictionary); + intervalMultiset = HashMultiset.create(intervals.size()); - //verify again that intervals do not overlap - Utils.validateArg(intervals.stream().noneMatch(i -> intervalCachedOverlapDetector.overlapDetector.getOverlaps(i).size() > 1), - "Input intervals may not be overlapping."); + final SAMFileHeader.SortOrder sortOrder = getHeaderForReads().getSortOrder(); + isCoordinateSorted = sortOrder == SAMFileHeader.SortOrder.coordinate; + if (!isCoordinateSorted) { + //if reads are not sorted, create an OverlapDetector covering all intervals; + //we currently require that reads are indexed (and hence sorted), so this is only to protect against future code changes + logger.warn("Reads are not coordinate sorted; sorting reads before running this tool may reduce memory requirements."); + intervalCachedOverlapDetector = new CachedOverlapDetector<>(intervals); + } logger.info("Collecting read counts..."); } @Override public void apply(GATKRead read, ReferenceContext referenceContext, FeatureContext featureContext) { + if (isCoordinateSorted && (currentContig == null || !read.getContig().equals(currentContig))) { + //if reads are sorted and we are on a new contig, create an OverlapDetector covering the contig + currentContig = read.getContig(); + final List intervalsOnCurrentContig = intervals.stream() + .filter(i -> i.getContig().equals(currentContig)) + .collect(Collectors.toList()); + intervalCachedOverlapDetector = new CachedOverlapDetector<>(intervalsOnCurrentContig); + } final SimpleInterval overlappingInterval = intervalCachedOverlapDetector.getOverlap( new SimpleInterval(read.getContig(), read.getStart(), read.getStart())); @@ -186,9 +202,9 @@ public Object onTraversalSuccess() { logger.info("Writing read counts to " + outputCountsFile); final SimpleCountCollection readCounts = new SimpleCountCollection( metadata, - intervalCachedOverlapDetector.overlapDetector.getAll().stream() + ImmutableList.copyOf(intervals.stream() //making this an ImmutableList avoids a defensive copy in SimpleCountCollection .map(i -> new SimpleCount(i, intervalMultiset.count(i))) - .collect(Collectors.toList())); + .iterator())); if (format == Format.HDF5) { readCounts.writeHDF5(outputCountsFile); @@ -210,6 +226,9 @@ private final class CachedOverlapDetector { CachedOverlapDetector(final List intervals) { Utils.nonEmpty(intervals); this.overlapDetector = OverlapDetector.create(intervals); + //double check that intervals do not overlap + Utils.validateArg(intervals.stream().noneMatch(i -> overlapDetector.getOverlaps(i).size() > 1), + "Input intervals may not be overlapping."); cachedResult = intervals.get(0); }