Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved memory requirements of CollectReadCounts. #5715

Merged
merged 2 commits into from
Feb 27, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package org.broadinstitute.hellbender.tools.copynumber;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Multiset;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.util.Locatable;
import htsjdk.samtools.util.OverlapDetector;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
Expand All @@ -34,9 +34,7 @@
import org.broadinstitute.hellbender.utils.read.GATKRead;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -124,12 +122,16 @@ enum Format {
*/
private SampleLocatableMetadata metadata;

private List<SimpleInterval> intervals;

private String currentContig = null;

/**
* Overlap detector used to determine when read starts overlap with input intervals.
*/
private CachedOverlapDetector<SimpleInterval> intervalCachedOverlapDetector;

private Multiset<SimpleInterval> intervalMultiset = HashMultiset.create();
private Multiset<SimpleInterval> intervalMultiset;

@Override
public List<ReadFilter> getDefaultReadFilters() {
Expand Down Expand Up @@ -158,19 +160,22 @@ public void onTraversalStart() {

CopyNumberArgumentValidationUtils.validateIntervalArgumentCollection(intervalArgumentCollection);

logger.info("Initializing and validating intervals...");
final List<SimpleInterval> intervals = intervalArgumentCollection.getIntervals(sequenceDictionary);
intervalCachedOverlapDetector = new CachedOverlapDetector<>(intervals);

//verify again that intervals do not overlap
Utils.validateArg(intervals.stream().noneMatch(i -> intervalCachedOverlapDetector.overlapDetector.getOverlaps(i).size() > 1),
"Input intervals may not be overlapping.");
intervals = intervalArgumentCollection.getIntervals(sequenceDictionary);
intervalMultiset = HashMultiset.create(intervals.size());

logger.info("Collecting read counts...");
}

@Override
public void apply(GATKRead read, ReferenceContext referenceContext, FeatureContext featureContext) {
if (currentContig == null || !read.getContig().equals(currentContig)) {
//if we are on a new contig, create an OverlapDetector covering the contig
currentContig = read.getContig();
final List<SimpleInterval> intervalsOnCurrentContig = intervals.stream()
.filter(i -> i.getContig().equals(currentContig))
.collect(Collectors.toList());
intervalCachedOverlapDetector = new CachedOverlapDetector<>(intervalsOnCurrentContig);
}
final SimpleInterval overlappingInterval = intervalCachedOverlapDetector.getOverlap(
new SimpleInterval(read.getContig(), read.getStart(), read.getStart()));

Expand All @@ -186,9 +191,9 @@ public Object onTraversalSuccess() {
logger.info("Writing read counts to " + outputCountsFile);
final SimpleCountCollection readCounts = new SimpleCountCollection(
metadata,
intervalCachedOverlapDetector.overlapDetector.getAll().stream()
ImmutableList.copyOf(intervals.stream() //making this an ImmutableList avoids a defensive copy in SimpleCountCollection
.map(i -> new SimpleCount(i, intervalMultiset.count(i)))
.collect(Collectors.toList()));
.iterator()));

if (format == Format.HDF5) {
readCounts.writeHDF5(outputCountsFile);
Expand All @@ -210,6 +215,9 @@ private final class CachedOverlapDetector<T extends Locatable> {
CachedOverlapDetector(final List<T> intervals) {
Utils.nonEmpty(intervals);
this.overlapDetector = OverlapDetector.create(intervals);
//double check that intervals do not overlap
Utils.validateArg(intervals.stream().noneMatch(i -> overlapDetector.getOverlaps(i).size() > 1),
"Input intervals may not be overlapping.");
cachedResult = intervals.get(0);
}

Expand Down