Skip to content

Commit

Permalink
Improve adaptive pruner to be smarter about adjacent chains with errors
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbenjamin committed Apr 30, 2020
1 parent adc85cb commit 5806cef
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class HaplotypeCallerReadThreadingAssemblerArgumentCollection extends Rea
public ReadThreadingAssembler makeReadThreadingAssembler() {
final ReadThreadingAssembler assemblyEngine = new ReadThreadingAssembler(maxNumHaplotypesInPopulation, kmerSizes,
dontIncreaseKmerSizesForCycles, allowNonUniqueKmersInRef, numPruningSamples, useAdaptivePruning ? 0 : minPruneFactor,
useAdaptivePruning, initialErrorRateForPruning, pruningLogOddsThreshold, maxUnprunedVariants, useLinkedDeBruijnGraph);
useAdaptivePruning, initialErrorRateForPruning, pruningLogOddsThreshold, pruningSeedingLogOddsThreshold, maxUnprunedVariants, useLinkedDeBruijnGraph);
assemblyEngine.setDebugGraphTransformations(debugGraphTransformations);
assemblyEngine.setRecoverDanglingBranches(!doNotRecoverDanglingBranches);
assemblyEngine.setRecoverAllDanglingBranches(recoverAllDanglingBranches);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class MutectReadThreadingAssemblerArgumentCollection extends ReadThreadin
public ReadThreadingAssembler makeReadThreadingAssembler() {
final ReadThreadingAssembler assemblyEngine = new ReadThreadingAssembler(maxNumHaplotypesInPopulation, kmerSizes,
dontIncreaseKmerSizesForCycles, allowNonUniqueKmersInRef, numPruningSamples, disableAdaptivePruning ? minPruneFactor : 0,
!disableAdaptivePruning, initialErrorRateForPruning, pruningLogOddsThreshold, maxUnprunedVariants, useLinkedDeBruijnGraph);
!disableAdaptivePruning, initialErrorRateForPruning, pruningLogOddsThreshold, pruningSeedingLogOddsThreshold, maxUnprunedVariants, useLinkedDeBruijnGraph);
assemblyEngine.setDebugGraphTransformations(debugGraphTransformations);
assemblyEngine.setRecoverDanglingBranches(true);
assemblyEngine.setRecoverAllDanglingBranches(recoverAllDanglingBranches);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public abstract class ReadThreadingAssemblerArgumentCollection implements Serial
private static final long serialVersionUID = 1L;

public static final double DEFAULT_PRUNING_LOG_ODDS_THRESHOLD = MathUtils.log10ToLog(1.0);
public static final double DEFAULT_PRUNING_SEEDING_LOG_ODDS_THRESHOLD = MathUtils.log10ToLog(2.0);

public static final String ERROR_CORRECT_READS_LONG_NAME = "error-correct-reads";
public static final String PILEUP_ERROR_CORRECTION_LOG_ODDS_NAME = "error-correction-log-odds";
Expand Down Expand Up @@ -118,6 +119,13 @@ public abstract class ReadThreadingAssemblerArgumentCollection implements Serial
@Argument(fullName="pruning-lod-threshold", doc = "Ln likelihood ratio threshold for adaptive pruning algorithm", optional = true)
public double pruningLogOddsThreshold = DEFAULT_PRUNING_LOG_ODDS_THRESHOLD;

/**
* Log-10 likelihood ratio threshold for adaptive pruning algorithm.
*/
@Advanced
@Argument(fullName="pruning-seeding-lod-threshold", doc = "Ln likelihood ratio threshold for seeding subgraph of good variation in adaptive pruning algorithm", optional = true)
public double pruningSeedingLogOddsThreshold = DEFAULT_PRUNING_SEEDING_LOG_ODDS_THRESHOLD;

/**
* The maximum number of variants in graph the adaptive pruner will allow
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
package org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs;

import org.apache.commons.math3.util.FastMath;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.broadinstitute.hellbender.tools.walkers.mutect.Mutect2Engine;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

import java.util.*;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;

public class AdaptiveChainPruner<V extends BaseVertex, E extends BaseEdge> extends ChainPruner<V,E> {
private final double initialErrorProbability;
private final double logOddsThreshold;
private final double seedingLogOddsThreshold; // threshold for seeding subgraph of good vertices
private final int maxUnprunedVariants;

public AdaptiveChainPruner(final double initialErrorProbability, final double logOddsThreshold, final int maxUnprunedVariants) {
public AdaptiveChainPruner(final double initialErrorProbability, final double logOddsThreshold, final double seedingLogOddsThreshold, final int maxUnprunedVariants) {
ParamUtils.isPositive(initialErrorProbability, "Must have positive error probability");
this.initialErrorProbability = initialErrorProbability;
this.logOddsThreshold = logOddsThreshold;
this.seedingLogOddsThreshold = seedingLogOddsThreshold;
this.maxUnprunedVariants = maxUnprunedVariants;
}

Expand All @@ -29,41 +33,107 @@ protected Collection<Path<V,E>> chainsToRemove(final List<Path<V, E>> chains) {

final BaseGraph<V,E> graph = chains.get(0).getGraph();



Collection<Path<V,E>> probableErrorChains = likelyErrorChains(chains, graph, initialErrorProbability);
final int errorCount = probableErrorChains.stream().mapToInt(c -> c.getLastEdge().getMultiplicity()).sum();
final int totalBases = chains.stream().mapToInt(c -> c.getEdges().stream().mapToInt(E::getMultiplicity).sum()).sum();
final double errorRate = (double) errorCount / totalBases;

return likelyErrorChains(chains, graph, errorRate);
return likelyErrorChains(chains, graph, errorRate).stream().filter(c -> !c.getEdges().stream().anyMatch(BaseEdge::isRef)).collect(Collectors.toList());
}

private Collection<Path<V,E>> likelyErrorChains(final List<Path<V, E>> chains, final BaseGraph<V,E> graph, final double errorRate) {
final Map<Path<V,E>, Double> chainLogOdds = chains.stream()
/////// NEW

// pre-compute the left and right log odds of each chain
final Map<Path<V,E>, Pair<Double, Double>> chainLogOdds = chains.stream()
.collect(Collectors.toMap(c -> c, c-> chainLogOdds(c, graph, errorRate)));

final Set<Path<V,E>> result = new HashSet<>(chains.size());

chainLogOdds.forEach((chain, lod) -> {
if (lod < logOddsThreshold) {
result.add(chain);

// compute correspondence of vertices to incident chains with log odds above the seeding and extending thresholds
final Multimap<V, Path<V,E>> vertexToSeedableChains = ArrayListMultimap.create();
final Multimap<V, Path<V,E>> vertexToGoodIncomingChains = ArrayListMultimap.create();
final Multimap<V, Path<V,E>> vertexToGoodOutgoingChains = ArrayListMultimap.create();

for (final Path<V,E> chain : chains) {
if (chainLogOdds.get(chain).getRight() >= logOddsThreshold) {
vertexToGoodIncomingChains.put(chain.getLastVertex(), chain);
}

if (chainLogOdds.get(chain).getLeft() >= logOddsThreshold) {
vertexToGoodOutgoingChains.put(chain.getFirstVertex(), chain);
}

// seed-worthy chains must pass the more stringent seeding log odds threshold on both sides
// in addition to that, we only seed from vertices with multiple such chains incoming or outgoing (see below)
if (chainLogOdds.get(chain).getRight() >= seedingLogOddsThreshold && chainLogOdds.get(chain).getLeft() >= seedingLogOddsThreshold) {
vertexToSeedableChains.put(chain.getFirstVertex(), chain);
vertexToSeedableChains.put(chain.getLastVertex(), chain);
}
}

// find a subset of good vertices from which to grow the subgraph of good chains.
final Queue<V> verticesToProcess = new ArrayDeque<>();
final Path<V,E> maxWeightChain = getMaxWeightChain(chains);
verticesToProcess.add(maxWeightChain.getFirstVertex());
verticesToProcess.add(maxWeightChain.getLastVertex());

// look for vertices with two incoming or two outgoing chains (plus one outgoing or incoming for a total of 3 or more) with good log odds to seed the subgraph of good vertices
// the logic here is that a high-multiplicity error chain A that branches into a second error chain B and a continuation-of-the-original-error chain A'
// may have a high log odds for A'. However, only in the case of true variation will multiple branches leaving the same vertex have good log odds.
vertexToSeedableChains.keySet().stream()
.filter(v -> vertexToSeedableChains.get(v).size() > 2)
.forEach(verticesToProcess::add);

final Set<V> processedVertices = new HashSet<>();
final Set<Path<V,E>> goodChains = new HashSet<>();

// starting from the high-confidence seed vertices, grow the "good" subgraph along chains with above-threshold log odds,
// discovering good chains as we go.
while (!verticesToProcess.isEmpty()) {
final V vertex = verticesToProcess.poll();
processedVertices.add(vertex);
for (final Path<V,E> outgoingChain : vertexToGoodOutgoingChains.get(vertex)) {
goodChains.add(outgoingChain);
if(!processedVertices.contains(outgoingChain.getLastVertex())) {
verticesToProcess.add(outgoingChain.getLastVertex());
}
}
});

chains.stream().filter(c -> isChainPossibleVariant(c, graph))
.sorted(Comparator.comparingDouble((ToDoubleFunction<Path<V, E>>) chainLogOdds::get)
.reversed().thenComparingInt(Path::length))
for (final Path<V,E> incomingChain : vertexToGoodIncomingChains.get(vertex)) {
goodChains.add(incomingChain);
if(!processedVertices.contains(incomingChain.getFirstVertex())) {
verticesToProcess.add(incomingChain.getFirstVertex());
}
}
}

final Set<Path<V,E>> errorChains = chains.stream().filter(c -> !goodChains.contains(c)).collect(Collectors.toSet());

// add non-error chains to error chains if maximum number of variants has been exceeded
chains.stream()
.filter(c -> !errorChains.contains(c))
.filter(c -> isChainPossibleVariant(c, graph))
.sorted(Comparator.comparingDouble(c -> Math.min(chainLogOdds.get(c).getLeft(), chainLogOdds.get(c).getLeft())).reversed())
.skip(maxUnprunedVariants)
.forEach(result::add);
.forEach(errorChains::add);

return result;
return errorChains;

}

private double chainLogOdds(final Path<V,E> chain, final BaseGraph<V,E> graph, final double errorRate) {
if (chain.getEdges().stream().anyMatch(E::isRef)) {
return Double.POSITIVE_INFINITY;
}
// find the chain containing the edge of greatest weight, taking care to break ties deterministically
private Path<V, E> getMaxWeightChain(final Collection<Path<V, E>> chains) {
return chains.stream()
.max(Comparator.comparingInt((Path<V, E> chain) -> chain.getEdges().stream().mapToInt(BaseEdge::getMultiplicity).max().orElse(0))
.thenComparingInt(Path::length)
.thenComparingInt(Path::hashCode)).get();
}

// left and right chain log odds
private Pair<Double, Double> chainLogOdds(final Path<V,E> chain, final BaseGraph<V,E> graph, final double errorRate) {
final int leftTotalMultiplicity = MathUtils.sumIntFunction(graph.outgoingEdgesOf(chain.getFirstVertex()), E::getMultiplicity);
final int rightTotalMultiplicity = MathUtils.sumIntFunction(graph.incomingEdgesOf(chain.getLastVertex()), E::getMultiplicity);

Expand All @@ -75,7 +145,7 @@ private double chainLogOdds(final Path<V,E> chain, final BaseGraph<V,E> graph, f
final double rightLogOdds = graph.isSink(chain.getLastVertex()) ? 0.0 :
Mutect2Engine.logLikelihoodRatio(rightTotalMultiplicity - rightMultiplicity, rightMultiplicity, errorRate);

return FastMath.max(leftLogOdds, rightLogOdds);
return ImmutablePair.of(leftLogOdds, rightLogOdds);
}

private boolean isChainPossibleVariant(final Path<V,E> chain, final BaseGraph<V,E> graph) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.AssemblyResult;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.AssemblyResultSet;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.NearbyKmerErrorCorrector;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.ReadErrorCorrector;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.*;
import org.broadinstitute.hellbender.utils.Histogram;
Expand Down Expand Up @@ -75,7 +74,7 @@ public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler
final boolean dontIncreaseKmerSizesForCycles, final boolean allowNonUniqueKmersInRef,
final int numPruningSamples, final int pruneFactor, final boolean useAdaptivePruning,
final double initialErrorRateForPruning, final double pruningLogOddsThreshold,
final int maxUnprunedVariants, final boolean useLinkedDebruijnGraphs) {
final double pruningSeedingLogOddsThreshold, final int maxUnprunedVariants, final boolean useLinkedDebruijnGraphs) {
Utils.validateArg( maxAllowedPathsForReadThreadingAssembler >= 1, "numBestHaplotypesPerGraph should be >= 1 but got " + maxAllowedPathsForReadThreadingAssembler);
this.kmerSizes = new ArrayList<>(kmerSizes);
kmerSizes.sort(Integer::compareTo);
Expand All @@ -88,14 +87,14 @@ public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler
logger.error("JunctionTreeLinkedDeBruijnGraph is enabled.\n This is an experimental assembly graph mode that has not been fully validated\n\n");
}

chainPruner = useAdaptivePruning ? new AdaptiveChainPruner<>(initialErrorRateForPruning, pruningLogOddsThreshold, maxUnprunedVariants) :
chainPruner = useAdaptivePruning ? new AdaptiveChainPruner<>(initialErrorRateForPruning, pruningLogOddsThreshold, pruningSeedingLogOddsThreshold, maxUnprunedVariants) :
new LowWeightChainPruner<>(pruneFactor);
numBestHaplotypesPerGraph = maxAllowedPathsForReadThreadingAssembler;
}

@VisibleForTesting
ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler, final List<Integer> kmerSizes, final int pruneFactor) {
this(maxAllowedPathsForReadThreadingAssembler, kmerSizes, true, true, 1, pruneFactor, false, 0.001, 2, Integer.MAX_VALUE, false);
this(maxAllowedPathsForReadThreadingAssembler, kmerSizes, true, true, 1, pruneFactor, false, 0.001, 2, 2, Integer.MAX_VALUE, false);
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public void testAdaptivePruning(final int kmerSize, final byte[] ref, final byte
// note: these are the steps in ReadThreadingAssembler::createGraph
graph.buildGraphIfNecessary();

final ChainPruner<MultiDeBruijnVertex, MultiSampleEdge> pruner = new AdaptiveChainPruner<>(0.001, logOddsThreshold, 50);
final ChainPruner<MultiDeBruijnVertex, MultiSampleEdge> pruner = new AdaptiveChainPruner<>(0.001, logOddsThreshold, 2.0, 50);
pruner.pruneLowWeightChains(graph);

final SmithWatermanAligner aligner = SmithWatermanJavaAligner.getInstance();
Expand All @@ -175,6 +175,47 @@ public void testAdaptivePruning(final int kmerSize, final byte[] ref, final byte
Assert.assertTrue(bestPaths.size() < 15);
}

// test that in graph with good path A -> B -> C and bad edges A -> D -> C and D -> B that the adjacency of bad edges --
// such that when bad edges meet the multiplicities do not indicate an error - does not harm pruning.
// we test with and without a true variant path A -> E -> C
@Test
public void testAdaptivePruningWithAdjacentBadEdges() {
final int goodMultiplicity = 1000;
final int variantMultiplicity = 50;
final int badMultiplicity = 5;

final SeqVertex source = new SeqVertex("source");
final SeqVertex sink = new SeqVertex("sink");
final SeqVertex A = new SeqVertex("A");
final SeqVertex B = new SeqVertex("B");
final SeqVertex C = new SeqVertex("C");
final SeqVertex D = new SeqVertex("D");
final SeqVertex E = new SeqVertex("E");


for (boolean variantPresent : new boolean[] {false, true}) {
final SeqGraph graph = new SeqGraph(20);

graph.addVertices(source, A, B, C, D, sink);
graph.addEdges(() -> new BaseEdge(true, goodMultiplicity), source, A, B, C, sink);
graph.addEdges(() -> new BaseEdge(false, badMultiplicity), A, D, C);
graph.addEdges(() -> new BaseEdge(false, badMultiplicity), D, B);

if (variantPresent) {
graph.addVertices(E);
graph.addEdges(() -> new BaseEdge(false, variantMultiplicity), A, E, C);
}

final ChainPruner<SeqVertex, BaseEdge> pruner = new AdaptiveChainPruner<>(0.01, 2.0, 2.0, 50);
pruner.pruneLowWeightChains(graph);

Assert.assertFalse(graph.containsVertex(D));
if (variantPresent) {
Assert.assertTrue(graph.containsVertex(E));
}
}
}

@DataProvider(name = "chainPrunerData")
public Object[][] getChainPrunerData() {
final RandomGenerator rng = RandomGeneratorFactory.createRandomGenerator(new Random(9));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ public void testMitochondria() {
"chrM:750-750 A*, [G]");
Assert.assertTrue(variantKeys.containsAll(expectedKeys));

Assert.assertEquals(variants.get(0).getAttributeAsInt(GATKVCFConstants.ORIGINAL_CONTIG_MISMATCH_KEY, 0), 1672);
Assert.assertEquals(variants.get(0).getAttributeAsInt(GATKVCFConstants.ORIGINAL_CONTIG_MISMATCH_KEY, 0), 1664);
}

@DataProvider(name = "vcfsForFiltering")
Expand Down

0 comments on commit 5806cef

Please sign in to comment.