Reworked the keys for MarkDuplicatesSpark to be sufficient for grouping on their own.
jamesemery authored Jun 14, 2018
1 parent 51874bc commit ba62c2a
Showing 15 changed files with 339 additions and 186 deletions.
@@ -10,6 +10,7 @@
import org.bdgenomics.adam.serialization.ADAMKryoRegistrator;
import org.broadinstitute.hellbender.tools.spark.transforms.markduplicates.MarkDuplicatesSparkUtils;
import org.broadinstitute.hellbender.utils.read.SAMRecordToGATKReadAdapter;
import org.broadinstitute.hellbender.utils.read.markduplicates.ReadsKey;
import org.broadinstitute.hellbender.utils.read.markduplicates.sparkrecords.*;

import java.util.Collections;
@@ -89,5 +90,8 @@ private void registerGATKClasses(Kryo kryo) {
kryo.register(Pair.class, new Pair.Serializer());
kryo.register(Passthrough.class, new FieldSerializer(kryo, Passthrough.class));
kryo.register(MarkDuplicatesSparkUtils.IndexPair.class, new FieldSerializer(kryo, MarkDuplicatesSparkUtils.IndexPair.class));
kryo.register(ReadsKey.class, new FieldSerializer(kryo, ReadsKey.class));
kryo.register(ReadsKey.KeyForFragment.class, new FieldSerializer(kryo, ReadsKey.KeyForFragment.class));
kryo.register(ReadsKey.KeyForPair.class, new FieldSerializer(kryo, ReadsKey.KeyForPair.class));
}
}
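FieldSerializer, used for the three new registrations above, reflects over declared fields and needs a no-arg constructor, which constrains ReadsKey and its subclasses to be plain value classes. A minimal sketch of the shape such a grouping key could take — the field names and layout here are assumptions for illustration, not the actual ReadsKey implementation:

// Hypothetical value-class key: equals/hashCode over every field is what makes
// a key "sufficient for grouping on its own", with no regrouping pass needed.
public final class FragmentKeySketch {
    private int referenceIndex;    // contig of the read (assumed field)
    private int strandedStart;     // strand-adjusted unclipped start (assumed field)
    private boolean reverseStrand; // orientation (assumed field)
    private byte libraryIndex;     // compact index from constructLibraryIndex (assumed field)

    public FragmentKeySketch() { } // required by FieldSerializer

    @Override
    public boolean equals(final Object o) {
        if (!(o instanceof FragmentKeySketch)) { return false; }
        final FragmentKeySketch k = (FragmentKeySketch) o;
        return referenceIndex == k.referenceIndex && strandedStart == k.strandedStart
                && reverseStrand == k.reverseStrand && libraryIndex == k.libraryIndex;
    }

    @Override
    public int hashCode() {
        return java.util.Objects.hash(referenceIndex, strandedStart, reverseStrand, libraryIndex);
    }
}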
@@ -97,10 +97,11 @@ static JavaPairRDD<IndexPair<String>, Integer> transformToDuplicateNames(final S
final JavaPairRDD<String, Iterable<IndexPair<GATKRead>>> keyedReads = getReadsGroupedByName(header, mappedReads, numReducers);

final Broadcast<Map<String, Short>> headerReadGroupIndexMap = JavaSparkContext.fromSparkContext(reads.context()).broadcast( getHeaderReadGroupIndexMap(header));
final Broadcast<Map<String, Byte>> libraryIndex = JavaSparkContext.fromSparkContext(reads.context()).broadcast( constructLibraryIndex(header));

// Place all the reads into a single RDD of MarkDuplicatesSparkRecord objects
final JavaPairRDD<Integer, MarkDuplicatesSparkRecord> pairedEnds = keyedReads.flatMapToPair(keyedRead -> {
final List<Tuple2<Integer, MarkDuplicatesSparkRecord>> out = Lists.newArrayList();
final JavaPairRDD<ReadsKey, MarkDuplicatesSparkRecord> pairedEnds = keyedReads.flatMapToPair(keyedRead -> {
final List<Tuple2<ReadsKey, MarkDuplicatesSparkRecord>> out = Lists.newArrayList();
final IndexPair<?>[] hadNonPrimaryRead = {null};

final List<IndexPair<GATKRead>> primaryReads = Utils.stream(keyedRead._2())
@@ -110,8 +111,8 @@ static JavaPairRDD<IndexPair<String>, Integer> transformToDuplicateNames(final S
final GATKRead read = readWithIndex.getValue();
if (!(read.isSecondaryAlignment()||read.isSupplementaryAlignment())) {
PairedEnds fragment = (ReadUtils.readHasMappedMate(read)) ?
MarkDuplicatesSparkRecord.newEmptyFragment(read, header) :
MarkDuplicatesSparkRecord.newFragment(read, header, readWithIndex.getIndex(), scoringStrategy);
MarkDuplicatesSparkRecord.newEmptyFragment(read, header, libraryIndex.getValue()) :
MarkDuplicatesSparkRecord.newFragment(read, header, readWithIndex.getIndex(), scoringStrategy, libraryIndex.getValue());

out.add(new Tuple2<>(fragment.key(), fragment));
} else {
@@ -143,7 +144,7 @@ static JavaPairRDD<IndexPair<String>, Integer> transformToDuplicateNames(final S
if (mappedPair.size()==2) {
final GATKRead firstRead = mappedPair.get(0).getValue();
final IndexPair<GATKRead> secondRead = mappedPair.get(1);
final Pair pair = MarkDuplicatesSparkRecord.newPair(firstRead, secondRead.getValue(), header, secondRead.getIndex(), scoringStrategy);
final Pair pair = MarkDuplicatesSparkRecord.newPair(firstRead, secondRead.getValue(), header, secondRead.getIndex(), scoringStrategy, libraryIndex.getValue());
// Validate and add the read group to the pair
final Short readGroup = headerReadGroupIndexMap.getValue().get(firstRead.getReadGroup());
if (readGroup != null) {
@@ -167,16 +168,35 @@ static JavaPairRDD<IndexPair<String>, Integer> transformToDuplicateNames(final S
return out.iterator();
});

final JavaPairRDD<Integer, Iterable<MarkDuplicatesSparkRecord>> keyedPairs = pairedEnds.groupByKey(); //TODO evaluate replacing this with a smart aggregate by key.
final JavaPairRDD<ReadsKey, Iterable<MarkDuplicatesSparkRecord>> keyedPairs = pairedEnds.groupByKey(); //TODO evaluate replacing this with a smart aggregate by key.

return markDuplicateRecords(keyedPairs, finder);
}
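// The TODO above contemplates replacing groupByKey with a smarter aggregate.
// A hedged sketch of one option, combineByKey building per-key lists map-side
// (illustrative only; whether this actually beats groupByKey here is exactly
// the open question the TODO raises):
//
//     final JavaPairRDD<ReadsKey, List<MarkDuplicatesSparkRecord>> aggregated =
//             pairedEnds.combineByKey(
//                     record -> new ArrayList<>(Collections.singletonList(record)),
//                     (list, record) -> { list.add(record); return list; },
//                     (left, right) -> { left.addAll(right); return left; });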

/**
* Method which generates a map from the library names tagged in the header's read groups to byte indexes, so libraries can be serialized as compact indexes to save space
*/
public static Map<String, Byte> constructLibraryIndex(final SAMFileHeader header) {
final List<String> discoveredLibraries = header.getReadGroups().stream()
.map(r -> { final String library = r.getLibrary();
return library == null ? LibraryIdGenerator.UNKNOWN_LIBRARY : library; })
.distinct()
.collect(Collectors.toList());
if (discoveredLibraries.size() > 255) {
throw new GATKException("Detected too many read libraries among read groups header, currently MarkDuplicatesSpark only supports up to 256 unique readgroup libraries but " + discoveredLibraries.size() + " were found");
}
final Iterator<Byte> iterator = IntStream.range(0, discoveredLibraries.size()).boxed().map(Integer::byteValue).iterator();
return Maps.uniqueIndex(iterator, idx -> discoveredLibraries.get(idx));
}
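// Worked example of the index above (hypothetical header): read groups tagged
// with libraries "libA" and "libB" produce a two-entry index.
//
//     final SAMFileHeader exampleHeader = new SAMFileHeader();
//     final SAMReadGroupRecord rg1 = new SAMReadGroupRecord("rg1");
//     rg1.setLibrary("libA");
//     final SAMReadGroupRecord rg2 = new SAMReadGroupRecord("rg2");
//     rg2.setLibrary("libB");
//     exampleHeader.setReadGroups(Arrays.asList(rg1, rg2));
//     constructLibraryIndex(exampleHeader); // => { "libA" -> (byte) 0, "libB" -> (byte) 1 }
//
// A read group with no LB tag falls back to LibraryIdGenerator.UNKNOWN_LIBRARY,
// and the byte index, not the library String, is what each ReadsKey carries.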

/**
* Method which generates a map from the read group IDs in the header to short indexes so they can be serialized compactly
*/
private static Map<String, Short> getHeaderReadGroupIndexMap(final SAMFileHeader header) {
final List<SAMReadGroupRecord> readGroups = header.getReadGroups();
if (readGroups.size() > 65535) {
throw new GATKException("Detected too many read groups in the header, currently MarkDuplicatesSpark only supports up to 65535 unique readgroup IDs but " + readGroups.size() + " were found");
}
final Iterator<Short> iterator = IntStream.range(0, readGroups.size()).boxed().map(Integer::shortValue).iterator();
return Maps.uniqueIndex(iterator, idx -> readGroups.get(idx).getId() );
}
@@ -284,64 +304,43 @@ protected Tuple2<K, Iterable<V>> computeNext() {
* - Collects the results and returns an iterator
*/
@SuppressWarnings("unchecked")
private static JavaPairRDD<IndexPair<String>, Integer> markDuplicateRecords(final JavaPairRDD<Integer, Iterable<MarkDuplicatesSparkRecord>> keyedPairs,
private static JavaPairRDD<IndexPair<String>, Integer> markDuplicateRecords(final JavaPairRDD<ReadsKey, Iterable<MarkDuplicatesSparkRecord>> keyedPairs,
final OpticalDuplicateFinder finder) {
return keyedPairs.flatMapToPair(keyedPair -> {
Iterable<MarkDuplicatesSparkRecord> pairGroups = keyedPair._2();

final List<Tuple2<IndexPair<String>, Integer>> nonDuplicates = Lists.newArrayList();
final Map<MarkDuplicatesSparkRecord.Type, List<MarkDuplicatesSparkRecord>> stratifiedByType = splitByType(pairGroups);

// Each key corresponds to either fragments or paired ends, not a mixture of both.
final List<MarkDuplicatesSparkRecord> emptyFragments = stratifiedByType.get(MarkDuplicatesSparkRecord.Type.EMPTY_FRAGMENT);
final List<MarkDuplicatesSparkRecord> fragments = stratifiedByType.get(MarkDuplicatesSparkRecord.Type.FRAGMENT);
final List<Pair> pairs = (List<Pair>)(List)stratifiedByType.get(MarkDuplicatesSparkRecord.Type.PAIR);
final List<MarkDuplicatesSparkRecord> passthroughs = stratifiedByType.get(MarkDuplicatesSparkRecord.Type.PASSTHROUGH);

//empty MarkDuplicatesSparkRecords signify that a pair has a mate somewhere else
// If there are any non-fragment placeholders at this site, mark everything as duplicates, otherwise compute the best score
if (Utils.isNonEmpty(fragments) && !Utils.isNonEmpty(emptyFragments)) {
final Tuple2<IndexPair<String>, Integer> bestFragment = handleFragments(fragments);
nonDuplicates.add(bestFragment);
}

//since we grouped by a non-unique hash code for efficiency we need to regroup by the actual criteria
//todo this should use library and contig as well probably
//todo these should all be one traversal over the records)
final Collection<List<MarkDuplicatesSparkRecord>> groups = Utils.stream(pairGroups)
.collect(Collectors.groupingBy(MarkDuplicatesSparkUtils::getGroupKey)).values();

for (List<MarkDuplicatesSparkRecord> duplicateGroup : groups) {
final Map<MarkDuplicatesSparkRecord.Type, List<MarkDuplicatesSparkRecord>> stratifiedByType = splitByType(duplicateGroup);

// Each key corresponds to either fragments or paired ends, not a mixture of both.
final List<MarkDuplicatesSparkRecord> emptyFragments = stratifiedByType.get(MarkDuplicatesSparkRecord.Type.EMPTY_FRAGMENT);
final List<MarkDuplicatesSparkRecord> fragments = stratifiedByType.get(MarkDuplicatesSparkRecord.Type.FRAGMENT);
final List<Pair> pairs = (List<Pair>) (List)(stratifiedByType.get(MarkDuplicatesSparkRecord.Type.PAIR));
final List<MarkDuplicatesSparkRecord> passthroughs = stratifiedByType.get(MarkDuplicatesSparkRecord.Type.PASSTHROUGH);

//empty MarkDuplicatesSparkRecord signify that a pair has a mate somewhere else
// If there are any non-fragment placeholders at this site, mark everything as duplicates, otherwise compute the best score
if (Utils.isNonEmpty(fragments) && !Utils.isNonEmpty(emptyFragments)) {
final Tuple2<IndexPair<String>, Integer> bestFragment = handleFragments(fragments);
nonDuplicates.add(bestFragment);
}

if (Utils.isNonEmpty(pairs)) {
nonDuplicates.add(handlePairs(pairs, finder));
}
if (Utils.isNonEmpty(pairs)) {
nonDuplicates.add(handlePairs(pairs, finder));
}

if (Utils.isNonEmpty(passthroughs)) {
nonDuplicates.addAll(handlePassthroughs(passthroughs));
}
if (Utils.isNonEmpty(passthroughs)) {
nonDuplicates.addAll(handlePassthroughs(passthroughs));
}

return nonDuplicates.iterator();
});
}

// Note, this uses bitshift operators in order to perform only a single groupBy operation for all the merged data
private static long getGroupKey(MarkDuplicatesSparkRecord record) {
if ( record.getClass()==Passthrough.class) {
return -1;
} else {
final PairedEnds pairedEnds = (PairedEnds) record;
return ((((long) pairedEnds.getUnclippedStartPosition()) << 32) |
(pairedEnds.getFirstRefIndex() << 16) |
pairedEnds.getOrientationForPCRDuplicates() );
}
}
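// Worked example of the packing retired above: unclippedStart = 1000,
// firstRefIndex = 2, orientation = 1 packs to
//     (1000L << 32) | (2 << 16) | 1  ==  0x000003E800020001L
// refIndex and orientation share the low 32 bits unmasked, and the TODO above
// notes that library (and fuller contig handling) never made it into the key,
// so distinct duplicate sets could collide on one group key; a ReadsKey that
// is sufficient for grouping on its own removes this whole second pass.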

/**
* split MarkDuplicatesSparkRecord into groups by their type
*/
private static Map<MarkDuplicatesSparkRecord.Type, List<MarkDuplicatesSparkRecord>> splitByType(List<MarkDuplicatesSparkRecord> duplicateGroup) {
private static Map<MarkDuplicatesSparkRecord.Type, List<MarkDuplicatesSparkRecord>> splitByType(Iterable<MarkDuplicatesSparkRecord> duplicateGroup) {
final EnumMap<MarkDuplicatesSparkRecord.Type, List<MarkDuplicatesSparkRecord>> byType = new EnumMap<>(MarkDuplicatesSparkRecord.Type.class);
for(MarkDuplicatesSparkRecord pair: duplicateGroup) {
byType.compute(pair.getType(), (key, value) -> {
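// The remainder of this method is collapsed in this view; the compute callback
// presumably finishes along these lines (a sketch, not the verbatim source):
//     if (value == null) {
//         value = new ArrayList<>();
//     }
//     value.add(pair);
//     return value;
// });
// ... followed by `return byType;`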
@@ -497,7 +496,7 @@ private PairedEndsCoordinateComparator() { }

@Override
public int compare( PairedEnds first, PairedEnds second ) {
int result = compareCoordinates(first, second);
int result = Integer.compare(first.getFirstStartPosition(), second.getFirstStartPosition());
if ( result != 0 ) {
return result;
}
@@ -512,24 +511,5 @@ public int compare( PairedEnds first, PairedEnds second ) {
}
return result;
}

public static int compareCoordinates(final PairedEnds first, final PairedEnds second ) {
final int firstRefIndex = first.getFirstRefIndex();
final int secondRefIndex = second.getFirstRefIndex();

if ( firstRefIndex == -1 ) {
return (secondRefIndex == -1 ? 0 : 1);
}
else if ( secondRefIndex == -1 ) {
return -1;
}

final int refIndexDifference = firstRefIndex - secondRefIndex;
if ( refIndexDifference != 0 ) {
return refIndexDifference;
}

return Integer.compare(first.getFirstStartPosition(), second.getFirstStartPosition());
}
}
}
@@ -4,6 +4,7 @@
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMReadGroupRecord;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -13,6 +14,8 @@
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.tools.spark.transforms.markduplicates.MarkDuplicatesSparkUtils;
import org.broadinstitute.hellbender.utils.read.markduplicates.LibraryIdGenerator;
import picard.cmdline.programgroups.DiagnosticsAndQCProgramGroup;
import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
@@ -134,19 +137,21 @@ protected void runTool(final JavaSparkContext ctx) {
}
System.out.println("first and second: " + firstDupesCount + "," + secondDupesCount);

Broadcast<Map<String, Byte>> libraryIndex = ctx.broadcast(MarkDuplicatesSparkUtils.constructLibraryIndex(getHeaderForReads()));

Broadcast<SAMFileHeader> bHeader = ctx.broadcast(getHeaderForReads());
// Group the reads of each BAM by MarkDuplicates key, then pair up the reads for each BAM.
JavaPairRDD<Integer, GATKRead> firstKeyed = firstReads.mapToPair(read -> new Tuple2<>(ReadsKey.hashKeyForFragment(
JavaPairRDD<ReadsKey, GATKRead> firstKeyed = firstReads.mapToPair(read -> new Tuple2<>(ReadsKey.getKeyForFragment(
ReadUtils.getStrandedUnclippedStart(read),
read.isReverseStrand(),
ReadUtils.getReferenceIndex(read,bHeader.getValue()),
ReadUtils.getLibrary(read, bHeader.getValue())), read));
JavaPairRDD<Integer, GATKRead> secondKeyed = secondReads.mapToPair(read -> new Tuple2<>(ReadsKey.hashKeyForFragment(
libraryIndex.getValue().get(ReadUtils.getLibrary(read, bHeader.getValue(), LibraryIdGenerator.UNKNOWN_LIBRARY))), read));
JavaPairRDD<ReadsKey, GATKRead> secondKeyed = secondReads.mapToPair(read -> new Tuple2<>(ReadsKey.getKeyForFragment(
ReadUtils.getStrandedUnclippedStart(read),
read.isReverseStrand(),
ReadUtils.getReferenceIndex(read,bHeader.getValue()),
ReadUtils.getLibrary(read, bHeader.getValue())), read));
JavaPairRDD<Integer, Tuple2<Iterable<GATKRead>, Iterable<GATKRead>>> cogroup = firstKeyed.cogroup(secondKeyed, getRecommendedNumReducers());
libraryIndex.getValue().get(ReadUtils.getLibrary(read, bHeader.getValue(), LibraryIdGenerator.UNKNOWN_LIBRARY))), read));
JavaPairRDD<ReadsKey, Tuple2<Iterable<GATKRead>, Iterable<GATKRead>>> cogroup = firstKeyed.cogroup(secondKeyed, getRecommendedNumReducers());


// Produces an RDD of MatchTypes, e.g., EQUAL, DIFFERENT_REPRESENTATIVE_READ, etc. per MarkDuplicates key,
@@ -159,8 +164,6 @@ protected void runTool(final JavaSparkContext ctx) {

return getDupes(iFirstReads, iSecondReads, header);
});

// TODO: We should also produce examples of reads that don't match to make debugging easier (#1263).
Map<MatchType, Integer> tagCountMap = tagged.mapToPair(v1 ->
new Tuple2<>(v1, 1)).reduceByKey((v1, v2) -> v1 + v2).collectAsMap();
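// Classic count-by-value: each MatchType maps to (tag, 1) and reduceByKey sums,
// e.g. [EQUAL, EQUAL, DIFFERENT_REPRESENTATIVE_READ] collects to
// { EQUAL: 2, DIFFERENT_REPRESENTATIVE_READ: 1 }.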

@@ -232,7 +235,9 @@ static MatchType getDupes(Iterable<GATKRead> f, Iterable<GATKRead> s, SAMFileHea
static JavaRDD<GATKRead> filteredReads(JavaRDD<GATKRead> initialReads, String fileName) {
// We only need to compare duplicates that are "primary" (i.e., the primary mapped read).
return initialReads.map((Function<GATKRead, GATKRead>) v1 -> {
String rg = v1.getReadGroup();
v1.clearAttributes();
v1.setReadGroup(rg);
return v1;
}).filter(v1 -> {
if (ReadUtils.isNonPrimary(v1) && v1.isDuplicate()) {
@@ -31,6 +31,7 @@
import org.broadinstitute.hellbender.utils.BaseUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.read.markduplicates.LibraryIdGenerator;
import org.broadinstitute.hellbender.utils.recalibration.EventType;

/**
@@ -274,6 +275,21 @@ public static String getLibrary( final GATKRead read, final SAMFileHeader header
return readGroup != null ? readGroup.getLibrary() : null;
}

/**
* Returns the library associated with the provided read's read group,
* or the specified default if no library is found.
*
* @param read read whose library to retrieve
* @param header SAM header containing read groups
* @param defaultLibrary value to return when the read has no read group or its library is unset
* @return the library for the provided read's read group as a String,
* or the default value if no library could be determined
*/
public static String getLibrary( final GATKRead read, final SAMFileHeader header, final String defaultLibrary) {
final SAMReadGroupRecord readGroup = getSAMReadGroupRecord(read, header);
final String library = readGroup != null ? readGroup.getLibrary() : null;
return library == null ? defaultLibrary : library;
}
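// Hypothetical usage mirroring the MarkDuplicatesSpark call sites above: the
// placeholder default means the broadcast library index lookup never sees null.
//     final String lib = ReadUtils.getLibrary(read, header, LibraryIdGenerator.UNKNOWN_LIBRARY);
//     final byte libIdx = libraryIndex.getValue().get(lib);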

/**
* Returns the sample name associated with the provided read's read group.
*