From 4904fc85db967b889cd7301de2d3e61635f8412b Mon Sep 17 00:00:00 2001 From: David Benjamin Date: Sun, 22 Mar 2020 22:27:31 -0400 Subject: [PATCH] Improve adaptive pruner to be smarter about adjacent chains with errors --- ...dThreadingAssemblerArgumentCollection.java | 2 +- ...dThreadingAssemblerArgumentCollection.java | 2 +- ...dThreadingAssemblerArgumentCollection.java | 8 ++ .../graphs/AdaptiveChainPruner.java | 108 ++++++++++++++---- .../readthreading/ReadThreadingAssembler.java | 7 +- .../graphs/ChainPrunerUnitTest.java | 43 ++++++- .../mutect/Mutect2IntegrationTest.java | 4 +- 7 files changed, 145 insertions(+), 29 deletions(-) diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/HaplotypeCallerReadThreadingAssemblerArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/HaplotypeCallerReadThreadingAssemblerArgumentCollection.java index 39d236f6930..fcd892aa392 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/HaplotypeCallerReadThreadingAssemblerArgumentCollection.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/HaplotypeCallerReadThreadingAssemblerArgumentCollection.java @@ -36,7 +36,7 @@ public class HaplotypeCallerReadThreadingAssemblerArgumentCollection extends Rea public ReadThreadingAssembler makeReadThreadingAssembler() { final ReadThreadingAssembler assemblyEngine = new ReadThreadingAssembler(maxNumHaplotypesInPopulation, kmerSizes, dontIncreaseKmerSizesForCycles, allowNonUniqueKmersInRef, numPruningSamples, useAdaptivePruning ? 0 : minPruneFactor, - useAdaptivePruning, initialErrorRateForPruning, pruningLogOddsThreshold, maxUnprunedVariants, useLinkedDeBruijnGraph); + useAdaptivePruning, initialErrorRateForPruning, pruningLogOddsThreshold, pruningSeedingLogOddsThreshold, maxUnprunedVariants, useLinkedDeBruijnGraph); assemblyEngine.setDebugGraphTransformations(debugGraphTransformations); assemblyEngine.setRecoverDanglingBranches(!doNotRecoverDanglingBranches); assemblyEngine.setRecoverAllDanglingBranches(recoverAllDanglingBranches); diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/MutectReadThreadingAssemblerArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/MutectReadThreadingAssemblerArgumentCollection.java index 0a2af77c104..cb1d361c227 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/MutectReadThreadingAssemblerArgumentCollection.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/MutectReadThreadingAssemblerArgumentCollection.java @@ -22,7 +22,7 @@ public class MutectReadThreadingAssemblerArgumentCollection extends ReadThreadin public ReadThreadingAssembler makeReadThreadingAssembler() { final ReadThreadingAssembler assemblyEngine = new ReadThreadingAssembler(maxNumHaplotypesInPopulation, kmerSizes, dontIncreaseKmerSizesForCycles, allowNonUniqueKmersInRef, numPruningSamples, disableAdaptivePruning ? minPruneFactor : 0, - !disableAdaptivePruning, initialErrorRateForPruning, pruningLogOddsThreshold, maxUnprunedVariants, useLinkedDeBruijnGraph); + !disableAdaptivePruning, initialErrorRateForPruning, pruningLogOddsThreshold, pruningSeedingLogOddsThreshold, maxUnprunedVariants, useLinkedDeBruijnGraph); assemblyEngine.setDebugGraphTransformations(debugGraphTransformations); assemblyEngine.setRecoverDanglingBranches(true); assemblyEngine.setRecoverAllDanglingBranches(recoverAllDanglingBranches); diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/ReadThreadingAssemblerArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/ReadThreadingAssemblerArgumentCollection.java index e2a4c7bea6a..0216072a649 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/ReadThreadingAssemblerArgumentCollection.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/ReadThreadingAssemblerArgumentCollection.java @@ -18,6 +18,7 @@ public abstract class ReadThreadingAssemblerArgumentCollection implements Serial private static final long serialVersionUID = 1L; public static final double DEFAULT_PRUNING_LOG_ODDS_THRESHOLD = MathUtils.log10ToLog(1.0); + public static final double DEFAULT_PRUNING_SEEDING_LOG_ODDS_THRESHOLD = MathUtils.log10ToLog(2.0); public static final String ERROR_CORRECT_READS_LONG_NAME = "error-correct-reads"; public static final String PILEUP_ERROR_CORRECTION_LOG_ODDS_NAME = "error-correction-log-odds"; @@ -118,6 +119,13 @@ public abstract class ReadThreadingAssemblerArgumentCollection implements Serial @Argument(fullName="pruning-lod-threshold", doc = "Ln likelihood ratio threshold for adaptive pruning algorithm", optional = true) public double pruningLogOddsThreshold = DEFAULT_PRUNING_LOG_ODDS_THRESHOLD; + /** + * Log-10 likelihood ratio threshold for adaptive pruning algorithm. + */ + @Advanced + @Argument(fullName="pruning-seeding-lod-threshold", doc = "Ln likelihood ratio threshold for seeding subgraph of good variation in adaptive pruning algorithm", optional = true) + public double pruningSeedingLogOddsThreshold = DEFAULT_PRUNING_SEEDING_LOG_ODDS_THRESHOLD; + /** * The maximum number of variants in graph the adaptive pruner will allow */ diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/graphs/AdaptiveChainPruner.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/graphs/AdaptiveChainPruner.java index 090f9774aab..cfa8606e714 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/graphs/AdaptiveChainPruner.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/graphs/AdaptiveChainPruner.java @@ -1,23 +1,27 @@ package org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs; -import org.apache.commons.math3.util.FastMath; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.Multimap; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.broadinstitute.hellbender.tools.walkers.mutect.Mutect2Engine; import org.broadinstitute.hellbender.utils.MathUtils; import org.broadinstitute.hellbender.utils.param.ParamUtils; import java.util.*; -import java.util.function.ToDoubleFunction; import java.util.stream.Collectors; public class AdaptiveChainPruner extends ChainPruner { private final double initialErrorProbability; private final double logOddsThreshold; + private final double seedingLogOddsThreshold; // threshold for seeding subgraph of good vertices private final int maxUnprunedVariants; - public AdaptiveChainPruner(final double initialErrorProbability, final double logOddsThreshold, final int maxUnprunedVariants) { + public AdaptiveChainPruner(final double initialErrorProbability, final double logOddsThreshold, final double seedingLogOddsThreshold, final int maxUnprunedVariants) { ParamUtils.isPositive(initialErrorProbability, "Must have positive error probability"); this.initialErrorProbability = initialErrorProbability; this.logOddsThreshold = logOddsThreshold; + this.seedingLogOddsThreshold = seedingLogOddsThreshold; this.maxUnprunedVariants = maxUnprunedVariants; } @@ -29,41 +33,105 @@ protected Collection> chainsToRemove(final List> chains) { final BaseGraph graph = chains.get(0).getGraph(); + + Collection> probableErrorChains = likelyErrorChains(chains, graph, initialErrorProbability); final int errorCount = probableErrorChains.stream().mapToInt(c -> c.getLastEdge().getMultiplicity()).sum(); final int totalBases = chains.stream().mapToInt(c -> c.getEdges().stream().mapToInt(E::getMultiplicity).sum()).sum(); final double errorRate = (double) errorCount / totalBases; - return likelyErrorChains(chains, graph, errorRate); + return likelyErrorChains(chains, graph, errorRate).stream().filter(c -> !c.getEdges().stream().anyMatch(BaseEdge::isRef)).collect(Collectors.toList()); } private Collection> likelyErrorChains(final List> chains, final BaseGraph graph, final double errorRate) { - final Map, Double> chainLogOdds = chains.stream() + /////// NEW + + // pre-compute the left and right log odds of each chain + final Map, Pair> chainLogOdds = chains.stream() .collect(Collectors.toMap(c -> c, c-> chainLogOdds(c, graph, errorRate))); - final Set> result = new HashSet<>(chains.size()); + // compute correspondence of vertices to incident chains with log odds above the seeding and extending thresholds + final Multimap> vertexToSeedableChains = ArrayListMultimap.create(); + final Multimap> vertexToGoodIncomingChains = ArrayListMultimap.create(); + final Multimap> vertexToGoodOutgoingChains = ArrayListMultimap.create(); + + for (final Path chain : chains) { + if (chainLogOdds.get(chain).getRight() >= logOddsThreshold) { + vertexToGoodIncomingChains.put(chain.getLastVertex(), chain); + } + + if (chainLogOdds.get(chain).getLeft() >= logOddsThreshold) { + vertexToGoodOutgoingChains.put(chain.getFirstVertex(), chain); + } + + // seed-worthy chains must pass the more stringent seeding log odds threshold on both sides + // in addition to that, we only seed from vertices with multiple such chains incoming or outgoing (see below) + if (chainLogOdds.get(chain).getRight() >= seedingLogOddsThreshold && chainLogOdds.get(chain).getLeft() >= seedingLogOddsThreshold) { + vertexToSeedableChains.put(chain.getFirstVertex(), chain); + vertexToSeedableChains.put(chain.getLastVertex(), chain); + } + } + + // find a subset of good vertices from which to grow the subgraph of good chains. + final Queue verticesToProcess = new ArrayDeque<>(); + final Path maxWeightChain = getMaxWeightChain(chains); + verticesToProcess.add(maxWeightChain.getFirstVertex()); + verticesToProcess.add(maxWeightChain.getLastVertex()); + + // look for vertices with two incoming or two outgoing chains (plus one outgoing or incoming for a total of 3 or more) with good log odds to seed the subgraph of good vertices + // the logic here is that a high-multiplicity error chain A that branches into a second error chain B and a continuation-of-the-original-error chain A' + // may have a high log odds for A'. However, only in the case of true variation will multiple branches leaving the same vertex have good log odds. + vertexToSeedableChains.keySet().stream() + .filter(v -> vertexToSeedableChains.get(v).size() > 2) + .forEach(verticesToProcess::add); + + final Set processedVertices = new HashSet<>(); + final Set> goodChains = new HashSet<>(); + + // starting from the high-confidence seed vertices, grow the "good" subgraph along chains with above-threshold log odds, + // discovering good chains as we go. + while (!verticesToProcess.isEmpty()) { + final V vertex = verticesToProcess.poll(); + processedVertices.add(vertex); + for (final Path outgoingChain : vertexToGoodOutgoingChains.get(vertex)) { + goodChains.add(outgoingChain); + if(!processedVertices.contains(outgoingChain.getLastVertex())) { + verticesToProcess.add(outgoingChain.getLastVertex()); + } + } - chainLogOdds.forEach((chain, lod) -> { - if (lod < logOddsThreshold) { - result.add(chain); + for (final Path incomingChain : vertexToGoodIncomingChains.get(vertex)) { + goodChains.add(incomingChain); + if(!processedVertices.contains(incomingChain.getFirstVertex())) { + verticesToProcess.add(incomingChain.getFirstVertex()); + } } - }); + } + + final Set> errorChains = chains.stream().filter(c -> !goodChains.contains(c)).collect(Collectors.toSet()); - chains.stream().filter(c -> isChainPossibleVariant(c, graph)) - .sorted(Comparator.comparingDouble((ToDoubleFunction>) chainLogOdds::get) - .reversed().thenComparingInt(Path::length)) + // add non-error chains to error chains if maximum number of variants has been exceeded + chains.stream() + .filter(c -> !errorChains.contains(c)) + .filter(c -> isChainPossibleVariant(c, graph)) + .sorted(Comparator.comparingDouble(c -> Math.min(chainLogOdds.get(c).getLeft(), chainLogOdds.get(c).getLeft())).reversed()) .skip(maxUnprunedVariants) - .forEach(result::add); + .forEach(errorChains::add); - return result; + return errorChains; } - private double chainLogOdds(final Path chain, final BaseGraph graph, final double errorRate) { - if (chain.getEdges().stream().anyMatch(E::isRef)) { - return Double.POSITIVE_INFINITY; - } + // find the chain containing the edge of greatest weight, taking care to break ties deterministically + private Path getMaxWeightChain(final Collection> chains) { + return chains.stream() + .max(Comparator.comparingInt((Path chain) -> chain.getEdges().stream().mapToInt(BaseEdge::getMultiplicity).max().orElse(0)) + .thenComparingInt(Path::length) + .thenComparingInt(Path::hashCode)).get(); + } + // left and right chain log odds + private Pair chainLogOdds(final Path chain, final BaseGraph graph, final double errorRate) { final int leftTotalMultiplicity = MathUtils.sumIntFunction(graph.outgoingEdgesOf(chain.getFirstVertex()), E::getMultiplicity); final int rightTotalMultiplicity = MathUtils.sumIntFunction(graph.incomingEdgesOf(chain.getLastVertex()), E::getMultiplicity); @@ -75,7 +143,7 @@ private double chainLogOdds(final Path chain, final BaseGraph graph, f final double rightLogOdds = graph.isSink(chain.getLastVertex()) ? 0.0 : Mutect2Engine.logLikelihoodRatio(rightTotalMultiplicity - rightMultiplicity, rightMultiplicity, errorRate); - return FastMath.max(leftLogOdds, rightLogOdds); + return ImmutablePair.of(leftLogOdds, rightLogOdds); } private boolean isChainPossibleVariant(final Path chain, final BaseGraph graph) { diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/readthreading/ReadThreadingAssembler.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/readthreading/ReadThreadingAssembler.java index 362eb33872f..1f402f2d11e 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/readthreading/ReadThreadingAssembler.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/readthreading/ReadThreadingAssembler.java @@ -12,7 +12,6 @@ import org.broadinstitute.hellbender.exceptions.UserException; import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.AssemblyResult; import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.AssemblyResultSet; -import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.NearbyKmerErrorCorrector; import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.ReadErrorCorrector; import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.*; import org.broadinstitute.hellbender.utils.Histogram; @@ -75,7 +74,7 @@ public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler final boolean dontIncreaseKmerSizesForCycles, final boolean allowNonUniqueKmersInRef, final int numPruningSamples, final int pruneFactor, final boolean useAdaptivePruning, final double initialErrorRateForPruning, final double pruningLogOddsThreshold, - final int maxUnprunedVariants, final boolean useLinkedDebruijnGraphs) { + final double pruningSeedingLogOddsThreshold, final int maxUnprunedVariants, final boolean useLinkedDebruijnGraphs) { Utils.validateArg( maxAllowedPathsForReadThreadingAssembler >= 1, "numBestHaplotypesPerGraph should be >= 1 but got " + maxAllowedPathsForReadThreadingAssembler); this.kmerSizes = new ArrayList<>(kmerSizes); kmerSizes.sort(Integer::compareTo); @@ -88,14 +87,14 @@ public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler logger.error("JunctionTreeLinkedDeBruijnGraph is enabled.\n This is an experimental assembly graph mode that has not been fully validated\n\n"); } - chainPruner = useAdaptivePruning ? new AdaptiveChainPruner<>(initialErrorRateForPruning, pruningLogOddsThreshold, maxUnprunedVariants) : + chainPruner = useAdaptivePruning ? new AdaptiveChainPruner<>(initialErrorRateForPruning, pruningLogOddsThreshold, pruningSeedingLogOddsThreshold, maxUnprunedVariants) : new LowWeightChainPruner<>(pruneFactor); numBestHaplotypesPerGraph = maxAllowedPathsForReadThreadingAssembler; } @VisibleForTesting ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler, final List kmerSizes, final int pruneFactor) { - this(maxAllowedPathsForReadThreadingAssembler, kmerSizes, true, true, 1, pruneFactor, false, 0.001, 2, Integer.MAX_VALUE, false); + this(maxAllowedPathsForReadThreadingAssembler, kmerSizes, true, true, 1, pruneFactor, false, 0.001, 2, 2, Integer.MAX_VALUE, false); } @VisibleForTesting diff --git a/src/test/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/graphs/ChainPrunerUnitTest.java b/src/test/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/graphs/ChainPrunerUnitTest.java index 22a87e96e35..4b367a28bb0 100644 --- a/src/test/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/graphs/ChainPrunerUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/graphs/ChainPrunerUnitTest.java @@ -148,7 +148,7 @@ public void testAdaptivePruning(final int kmerSize, final byte[] ref, final byte // note: these are the steps in ReadThreadingAssembler::createGraph graph.buildGraphIfNecessary(); - final ChainPruner pruner = new AdaptiveChainPruner<>(0.001, logOddsThreshold, 50); + final ChainPruner pruner = new AdaptiveChainPruner<>(0.001, logOddsThreshold, 2.0, 50); pruner.pruneLowWeightChains(graph); final SmithWatermanAligner aligner = SmithWatermanJavaAligner.getInstance(); @@ -175,6 +175,47 @@ public void testAdaptivePruning(final int kmerSize, final byte[] ref, final byte Assert.assertTrue(bestPaths.size() < 15); } + // test that in graph with good path A -> B -> C and bad edges A -> D -> C and D -> B that the adjacency of bad edges -- + // such that when bad edges meet the multiplicities do not indicate an error - does not harm pruning. + // we test with and without a true variant path A -> E -> C + @Test + public void testAdaptivePruningWithAdjacentBadEdges() { + final int goodMultiplicity = 1000; + final int variantMultiplicity = 50; + final int badMultiplicity = 5; + + final SeqVertex source = new SeqVertex("source"); + final SeqVertex sink = new SeqVertex("sink"); + final SeqVertex A = new SeqVertex("A"); + final SeqVertex B = new SeqVertex("B"); + final SeqVertex C = new SeqVertex("C"); + final SeqVertex D = new SeqVertex("D"); + final SeqVertex E = new SeqVertex("E"); + + + for (boolean variantPresent : new boolean[] {false, true}) { + final SeqGraph graph = new SeqGraph(20); + + graph.addVertices(source, A, B, C, D, sink); + graph.addEdges(() -> new BaseEdge(true, goodMultiplicity), source, A, B, C, sink); + graph.addEdges(() -> new BaseEdge(false, badMultiplicity), A, D, C); + graph.addEdges(() -> new BaseEdge(false, badMultiplicity), D, B); + + if (variantPresent) { + graph.addVertices(E); + graph.addEdges(() -> new BaseEdge(false, variantMultiplicity), A, E, C); + } + + final ChainPruner pruner = new AdaptiveChainPruner<>(0.01, 2.0, 2.0, 50); + pruner.pruneLowWeightChains(graph); + + Assert.assertFalse(graph.containsVertex(D)); + if (variantPresent) { + Assert.assertTrue(graph.containsVertex(E)); + } + } + } + @DataProvider(name = "chainPrunerData") public Object[][] getChainPrunerData() { final RandomGenerator rng = RandomGeneratorFactory.createRandomGenerator(new Random(9)); diff --git a/src/test/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2IntegrationTest.java b/src/test/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2IntegrationTest.java index 283130d55eb..745e862314e 100644 --- a/src/test/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2IntegrationTest.java +++ b/src/test/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2IntegrationTest.java @@ -507,7 +507,7 @@ public void testMitochondria() { "chrM:750-750 A*, [G]"); Assert.assertTrue(variantKeys.containsAll(expectedKeys)); - Assert.assertEquals(variants.get(0).getAttributeAsInt(GATKVCFConstants.ORIGINAL_CONTIG_MISMATCH_KEY, 0), 1672); + Assert.assertEquals(variants.get(0).getAttributeAsInt(GATKVCFConstants.ORIGINAL_CONTIG_MISMATCH_KEY, 0), 1664); } @DataProvider(name = "vcfsForFiltering") @@ -703,7 +703,7 @@ public void testMitochondrialRefConf() { //ref blocks will be dependent on TLOD band values "chrM:218-218 A*, []", "chrM:264-266 C*, []", - "chrM:479-483 A*, []", + "chrM:475-483 A*, []", "chrM:488-492 T*, []"); //ref block boundaries aren't particularly stable, so try a few and make sure we check at least one