Sort output from SVClusterEngine and fix no-call genotype ploidy bug in JointGermlineCNVSegmentation #7779

Merged
7 commits merged on May 16, 2022
@@ -173,6 +173,16 @@ public static <T extends SVCallRecord> Comparator<T> getCallComparator(final SAM
return (o1, o2) -> compareCalls(o1, o2, dictionary);
}

/**
* Gets a comparator for objects implementing {@link SVLocatable}.
* @param dictionary sequence dictionary pertaining to both records (not validated)
* @param <T> record class
* @return comparator
*/
public static <T extends SVLocatable> Comparator<T> getSVLocatableComparator(final SAMSequenceDictionary dictionary) {
return (o1, o2) -> compareSVLocatables(o1, o2, dictionary);
}

/**
* Compares two objects based on start and end positions.
*/
@@ -1,6 +1,11 @@
package org.broadinstitute.hellbender.tools.sv.cluster;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.BoundType;
import com.google.common.collect.SortedMultiset;
import com.google.common.collect.TreeMultiset;
import htsjdk.samtools.SAMSequenceDictionary;
import org.broadinstitute.hellbender.tools.sv.SVCallRecordUtils;
import org.broadinstitute.hellbender.tools.sv.SVLocatable;
import org.broadinstitute.hellbender.utils.Utils;

@@ -36,13 +41,14 @@ public enum CLUSTERING_TYPE {
private final SVClusterLinkage<T> linkage;
private Map<Integer, Cluster> idToClusterMap; // Active clusters
private final Map<Integer, T> idToItemMap; // Active items
private final List<T> outputBuffer;
protected final CLUSTERING_TYPE clusteringType;
private final ItemSortingBuffer buffer;
private final Comparator<T> itemComparator;

private String currentContig;
private int nextItemId;
private int nextClusterId;
private int lastStart;
private Integer minActiveStartingPosition;
private Integer minActiveStartingPositionItemId;

/**
@@ -51,39 +57,37 @@ public enum CLUSTERING_TYPE {
*/
public SVClusterEngine(final CLUSTERING_TYPE clusteringType,
final SVCollapser<T> collapser,
final SVClusterLinkage<T> linkage) {
final SVClusterLinkage<T> linkage,
final SAMSequenceDictionary dictionary) {
this.clusteringType = clusteringType;
this.collapser = Utils.nonNull(collapser);
this.linkage = Utils.nonNull(linkage);
idToClusterMap = new HashMap<>();
outputBuffer = new ArrayList<>();
currentContig = null;
idToItemMap = new HashMap<>();
itemComparator = SVCallRecordUtils.getSVLocatableComparator(dictionary);
buffer = new ItemSortingBuffer();
nextItemId = 0;
nextClusterId = 0;
lastStart = 0;
minActiveStartingPosition = null;
minActiveStartingPositionItemId = null;

}


/**
* Flushes all active clusters, adding them to the output buffer. Results from the output buffer are then copied out
* and the buffer is cleared. This should be called between contigs to save memory.
*/
public final List<T> forceFlushAndGetOutput() {
public final List<T> forceFlush() {
flushClusters();
return getOutput();
return buffer.forceFlush();
}

/**
* Gets any available finalized clusters.
*/
public final List<T> getOutput() {
final List<T> output = new ArrayList<>(outputBuffer);
outputBuffer.clear();
return output;
public final List<T> flush() {
return buffer.flush();
}

@VisibleForTesting
@@ -96,15 +100,17 @@ public SVClusterLinkage<T> getLinkage() {
return linkage;
}

public Integer getMinActiveStartingPosition() {
return minActiveStartingPosition;
public T getMinActiveStartingPositionItem() {
Utils.validate(minActiveStartingPositionItemId == null || idToItemMap.containsKey(minActiveStartingPositionItemId),
"Unregistered item id " + minActiveStartingPositionItemId);
return idToItemMap.get(minActiveStartingPositionItemId);
}

/**
* Returns true if there are any active or finalized clusters.
*/
public final boolean isEmpty() {
return idToClusterMap.isEmpty() && outputBuffer.isEmpty();
return idToClusterMap.isEmpty() && buffer.isEmpty();
}

/**
@@ -130,8 +136,7 @@ private final int registerItem(final T item) {
lastStart = item.getPositionA();
final int itemId = nextItemId++;
idToItemMap.put(itemId, item);
if (minActiveStartingPosition == null || item.getPositionA() < minActiveStartingPosition) {
minActiveStartingPosition = item.getPositionA();
if (minActiveStartingPositionItemId == null || item.getPositionA() < getMinActiveStartingPositionItem().getPositionA()) {
minActiveStartingPositionItemId = itemId;
}
return itemId;
@@ -263,7 +268,7 @@ private final void processCluster(final int clusterIndex) {
final Cluster cluster = getCluster(clusterIndex);
idToClusterMap.remove(clusterIndex);
final List<Integer> clusterItemIds = cluster.getItemIds();
outputBuffer.add(collapser.collapse(clusterItemIds.stream().map(idToItemMap::get).collect(Collectors.toList())));
buffer.add(collapser.collapse(clusterItemIds.stream().map(idToItemMap::get).collect(Collectors.toList())));
// Clean up item id map
if (clusterItemIds.size() == 1) {
// Singletons won't be present in any other clusters
@@ -290,13 +295,13 @@ private final void processCluster(final int clusterIndex) {
* Scans active items for the current min active starting position.
*/
private final void findAndSetMinActiveStart() {
minActiveStartingPosition = null;
minActiveStartingPositionItemId = null;
T minActiveStartingPositionItem = null;
for (final Integer itemId : idToItemMap.keySet()) {
final T item = idToItemMap.get(itemId);
if (minActiveStartingPosition == null || item.getPositionA() < minActiveStartingPosition) {
minActiveStartingPosition = item.getPositionA();
if (minActiveStartingPositionItemId == null || itemComparator.compare(item, minActiveStartingPositionItem) < 0) {
minActiveStartingPositionItemId = itemId;
minActiveStartingPositionItem = idToItemMap.get(itemId);
}
}
}
@@ -320,7 +325,6 @@ private final void flushClusters() {
processCluster(clusterId);
}
idToItemMap.clear();
minActiveStartingPosition = null;
minActiveStartingPositionItemId = null;
nextItemId = 0;
nextClusterId = 0;
@@ -416,4 +420,50 @@ public int hashCode() {
return Objects.hash(itemIds);
}
}

private final class ItemSortingBuffer {
private SortedMultiset<T> buffer;

public ItemSortingBuffer() {
Utils.nonNull(itemComparator);
this.buffer = TreeMultiset.create(itemComparator);
}

public void add(final T record) {
buffer.add(record);
}

/**
* Returns any records that can be safely flushed based on the current minimum starting position
* of items still being actively clustered.
*/
public List<T> flush() {
if (buffer.isEmpty()) {
return Collections.emptyList();
}
final T minActiveStartItem = getMinActiveStartingPositionItem();
if (minActiveStartItem == null) {
return forceFlush();
}
final SortedMultiset<T> finalizedRecordView = buffer.headMultiset(minActiveStartItem, BoundType.CLOSED);
final ArrayList<T> finalizedRecords = new ArrayList<>(finalizedRecordView);
// Clearing a view of the buffer also clears the items from the buffer itself
finalizedRecordView.clear();
return finalizedRecords;
}

/**
* Returns all buffered records, regardless of any active clusters. To be used only when certain that no
* active clusters can be clustered with any future inputs.
*/
public List<T> forceFlush() {
Contributor: It looks like this is for special cases, like the tests below, and the task end. Is that true? Regardless, please add some javadoc, especially for differentiating use cases with the above method.

Contributor Author: Ok, added a comment. This is to be used only by forceFlush() in the engine itself, which is only called when we're certain that none of the currently active clusters can change; yes, that is usually when reaching the end of a contig (or file).

final List<T> result = new ArrayList<>(buffer);
buffer.clear();
return result;
}

public boolean isEmpty() {
return buffer.isEmpty();
}
}
}
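As background for the new ItemSortingBuffer above, here is a minimal, self-contained sketch (not code from this PR) of the Guava TreeMultiset technique it relies on: buffered records stay sorted, flush() drains only the head of the buffer up to the minimum start of items still being actively clustered, and clearing a headMultiset view removes those elements from the backing buffer. The class name and the integer records below are illustrative stand-ins for SVCallRecord start positions.

import com.google.common.collect.BoundType;
import com.google.common.collect.SortedMultiset;
import com.google.common.collect.TreeMultiset;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class SortedFlushSketch {
    public static void main(String[] args) {
        // Finalized records accumulate in sorted order (here, just their start positions).
        final SortedMultiset<Integer> buffer = TreeMultiset.create(Comparator.<Integer>naturalOrder());
        buffer.add(300);
        buffer.add(100);
        buffer.add(200);

        // If the minimum start among still-active items is 250, everything at or before it
        // can be emitted without ever breaking the sort order of the output stream.
        final int minActiveStart = 250;
        final SortedMultiset<Integer> finalizedView = buffer.headMultiset(minActiveStart, BoundType.CLOSED);
        final List<Integer> flushed = new ArrayList<>(finalizedView);
        // Clearing the view also clears those elements from the backing buffer.
        finalizedView.clear();

        System.out.println(flushed); // [100, 200]
        System.out.println(buffer);  // [300]
    }
}

This roughly mirrors ItemSortingBuffer.flush(): records at or before the minimum active starting position cannot be affected by anything still being clustered, so they are safe to emit without re-sorting downstream.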
@@ -26,7 +26,7 @@ public static SVClusterEngine<SVCallRecord> createCanonical(final SVClusterEngin
linkage.setDepthOnlyParams(depthParameters);
linkage.setMixedParams(mixedParameters);
linkage.setEvidenceParams(pesrParameters);
return new SVClusterEngine<>(type, new CanonicalSVCollapser(reference, altAlleleSummaryStrategy, breakpointSummaryStrategy, insertionLengthSummaryStrategy), linkage);
return new SVClusterEngine<>(type, new CanonicalSVCollapser(reference, altAlleleSummaryStrategy, breakpointSummaryStrategy, insertionLengthSummaryStrategy), linkage, dictionary);
}

public static SVClusterEngine<SVCallRecord> createCNVDefragmenter(final SAMSequenceDictionary dictionary,
@@ -36,7 +36,7 @@ public static SVClusterEngine<SVCallRecord> createCNVDefragmenter(final SAMSeque
final double minSampleOverlap) {
final SVClusterLinkage<SVCallRecord> linkage = new CNVLinkage(dictionary, paddingFraction, minSampleOverlap);
final SVCollapser<SVCallRecord> collapser = new CanonicalSVCollapser(reference, altAlleleSummaryStrategy, CanonicalSVCollapser.BreakpointSummaryStrategy.MIN_START_MAX_END, CanonicalSVCollapser.InsertionLengthSummaryStrategy.MEDIAN);
return new SVClusterEngine<>(SVClusterEngine.CLUSTERING_TYPE.SINGLE_LINKAGE, collapser, linkage);
return new SVClusterEngine<>(SVClusterEngine.CLUSTERING_TYPE.SINGLE_LINKAGE, collapser, linkage, dictionary);
}

public static SVClusterEngine<SVCallRecord> createBinnedCNVDefragmenter(final SAMSequenceDictionary dictionary,
@@ -47,6 +47,6 @@ public static SVClusterEngine<SVCallRecord> createBinnedCNVDefragmenter(final SA
final List<GenomeLoc> coverageIntervals) {
final SVClusterLinkage<SVCallRecord> linkage = new BinnedCNVLinkage(dictionary, paddingFraction, minSampleOverlap, coverageIntervals);
final SVCollapser<SVCallRecord> collapser = new CanonicalSVCollapser(reference, altAlleleSummaryStrategy, CanonicalSVCollapser.BreakpointSummaryStrategy.MIN_START_MAX_END, CanonicalSVCollapser.InsertionLengthSummaryStrategy.MEDIAN);
return new SVClusterEngine<>(SVClusterEngine.CLUSTERING_TYPE.SINGLE_LINKAGE, collapser, linkage);
return new SVClusterEngine<>(SVClusterEngine.CLUSTERING_TYPE.SINGLE_LINKAGE, collapser, linkage, dictionary);
}
}
@@ -147,7 +147,7 @@ public int getEndPosition() {
doc="Cluster events whose endpoints are within this distance of each other", optional=true)
public int clusterWindow = CanonicalSVLinkage.DEFAULT_DEPTH_ONLY_PARAMS.getWindow();

@Argument(fullName = MODEL_CALL_INTERVALS_LONG_NAME, doc = "gCNV model intervals created with the FilterIntervals tool.")
@Argument(fullName = MODEL_CALL_INTERVALS_LONG_NAME, doc = "gCNV model intervals created with the FilterIntervals tool.", optional=true)
private GATKPath modelCallIntervalList = null;

@Argument(fullName = BREAKPOINT_SUMMARY_STRATEGY_LONG_NAME, doc = "Strategy to use for choosing a representative value for a breakpoint cluster.", optional = true)
@@ -204,6 +204,7 @@ public void onTraversalStart() {
//dictionary will not be null because this tool requiresReference()

final GenomeLocParser parser = new GenomeLocParser(this.dictionary);

setIntervals(parser);

final ClusteringParameters clusterArgs = ClusteringParameters.createDepthParameters(clusterIntervalOverlap, clusterWindow, CLUSTER_SAMPLE_OVERLAP_FRACTION);
@@ -267,12 +268,12 @@ private VariantContextWriter getVCFWriter() {
*/
@Override
public void apply(final List<VariantContext> variantContexts, final ReferenceContext referenceContext, final List<ReadsContext> readsContexts) {
if (currentContig == null) {
currentContig = variantContexts.get(0).getContig(); //variantContexts should have identical start, so choose 0th arbitrarily
} else if (!variantContexts.get(0).getContig().equals(currentContig)) {
//variantContexts should have identical start, so choose 0th arbitrarily
final String variantContig = variantContexts.get(0).getContig();
if (currentContig != null && !variantContig.equals(currentContig)) {
processClusters();
currentContig = variantContexts.get(0).getContig();
}
currentContig = variantContig;
for (final VariantContext vc : variantContexts) {
final SVCallRecord record = createDepthOnlyFromGCNVWithOriginalGenotypes(vc, minQS, allosomalContigs, refAutosomalCopyNumber, sampleDB);
if (record != null) {
@@ -291,13 +292,16 @@ public Object onTraversalSuccess() {
return null;
}

/**
* Force-flushes the defragmenter, adds the resulting calls to the clustering engine, and flushes the clustering
* engine. Since we need to check for variant overlap and reset genotypes, only flush clustering when we hit a
* new contig.
*/
private void processClusters() {
if (!defragmenter.isEmpty()) {
final List<SVCallRecord> defragmentedCalls = defragmenter.forceFlushAndGetOutput();
defragmentedCalls.stream().forEachOrdered(clusterEngine::add);
}
final List<SVCallRecord> defragmentedCalls = defragmenter.forceFlush();
defragmentedCalls.stream().forEachOrdered(clusterEngine::add);
//Jack and Isaac cluster first and then defragment
final List<SVCallRecord> clusteredCalls = clusterEngine.forceFlushAndGetOutput();
final List<SVCallRecord> clusteredCalls = clusterEngine.forceFlush();
write(clusteredCalls);
}

@@ -315,12 +319,10 @@ private VariantContext buildAndSanitizeRecord(final SVCallRecord record) {
}

private void write(final List<SVCallRecord> calls) {
final List<VariantContext> sortedCalls = calls.stream()
.sorted(Comparator.comparing(c -> new SimpleInterval(c.getContigA(), c.getPositionA(), c.getPositionB()), //VCs have to be sorted by end as well
IntervalUtils.getDictionaryOrderComparator(dictionary)))
final List<VariantContext> sanitizedRecords = calls.stream()
.map(this::buildAndSanitizeRecord)
.collect(Collectors.toList());
final Iterator<VariantContext> it = sortedCalls.iterator();
final Iterator<VariantContext> it = sanitizedRecords.iterator();
ArrayList<VariantContext> overlappingVCs = new ArrayList<>(calls.size());
if (!it.hasNext()) {
return;
@@ -680,16 +682,28 @@ private static Genotype prepareGenotype(final Genotype g, final Allele refAllele
return builder.make();
}

/**
* "Fills" genotype alleles so that it has the correct ploidy
* @param builder new alleles will be set for this builder
* @param g non-ref alleles will be carried over from this genotype
* @param ploidy desired ploidy for the new genotype
* @param refAllele desired ref allele for new genotype
*/
private static void correctGenotypePloidy(final GenotypeBuilder builder, final Genotype g, final int ploidy,
Contributor: I know I dropped the ball, but can you add some comments here? This is for overlapping events, right?

Contributor Author: Done. This just modifies the genotypes of the input variants so that the ploidies are consistent with the ped file and also irons out any no-call/null GTs.
final Allele refAllele) {
final ArrayList<Allele> alleles = new ArrayList<>(g.getAlleles());
Utils.validate(alleles.size() <= ploidy, "Encountered genotype with ploidy " + ploidy + " but " +
alleles.size() + " alleles.");
while (alleles.size() < ploidy) {
alleles.add(refAllele);
if (g.getAlleles().size() == 1 && g.getAllele(0).isNoCall()) {
// Special case to force interpretation of a single no-call allele as a possible null GT
builder.alleles(Collections.nCopies(ploidy, Allele.NO_CALL));
} else {
final ArrayList<Allele> alleles = new ArrayList<>(g.getAlleles());
Utils.validate(alleles.size() <= ploidy, "Encountered genotype with ploidy " + ploidy +
" but " + alleles.size() + " alleles.");
while (alleles.size() < ploidy) {
alleles.add(refAllele);
}
alleles.trimToSize();
builder.alleles(alleles);
}
alleles.trimToSize();
builder.alleles(alleles);
}

private static void addExpectedCopyNumber(final GenotypeBuilder g, final int ploidy) {
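To make the no-call handling discussed in the review thread above concrete, here is a hedged, standalone sketch of the described behavior (the helper name and sample data are hypothetical, and the tool's surrounding prepareGenotype plumbing is not reproduced): a genotype consisting of a single no-call allele is expanded to ploidy-many no-calls, while any other genotype is padded with reference alleles up to the requested ploidy.

import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.Genotype;
import htsjdk.variant.variantcontext.GenotypeBuilder;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class PloidyCorrectionSketch {
    // Hypothetical re-statement of the ploidy correction logic, for illustration only.
    static Genotype correctPloidy(final Genotype g, final int ploidy, final Allele refAllele) {
        final GenotypeBuilder builder = new GenotypeBuilder(g);
        if (g.getAlleles().size() == 1 && g.getAllele(0).isNoCall()) {
            // A lone no-call is interpreted as a null genotype and expanded to the full ploidy.
            builder.alleles(Collections.nCopies(ploidy, Allele.NO_CALL));
        } else {
            // Otherwise, pad with reference alleles until the ploidy matches.
            final List<Allele> alleles = new ArrayList<>(g.getAlleles());
            while (alleles.size() < ploidy) {
                alleles.add(refAllele);
            }
            builder.alleles(alleles);
        }
        return builder.make();
    }

    public static void main(String[] args) {
        final Allele ref = Allele.create("N", true);
        final Allele del = Allele.create("<DEL>", false);

        final Genotype noCall = new GenotypeBuilder("sampleA", Collections.singletonList(Allele.NO_CALL)).make();
        final Genotype singleAlt = new GenotypeBuilder("sampleB", Collections.singletonList(del)).make();

        System.out.println(correctPloidy(noCall, 2, ref).getAlleles());    // two no-call alleles
        System.out.println(correctPloidy(singleAlt, 2, ref).getAlleles()); // the DEL allele plus one ref allele
    }
}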