Skip to content

Commit

Permalink
Adaptive pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbenjamin committed Dec 5, 2018
1 parent d1a6ebc commit 25da765
Show file tree
Hide file tree
Showing 15 changed files with 509 additions and 147 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ public static ReadLikelihoodCalculationEngine createLikelihoodCalculationEngine(
public static ReadThreadingAssembler createReadThreadingAssembler(final AssemblyBasedCallerArgumentCollection args) {
final ReadThreadingAssemblerArgumentCollection rtaac = args.assemblerArgs;
final ReadThreadingAssembler assemblyEngine = new ReadThreadingAssembler(rtaac.maxNumHaplotypesInPopulation, rtaac.kmerSizes,
rtaac.dontIncreaseKmerSizesForCycles, rtaac.allowNonUniqueKmersInRef, rtaac.numPruningSamples, rtaac.minPruneFactor);
rtaac.dontIncreaseKmerSizesForCycles, rtaac.allowNonUniqueKmersInRef, rtaac.numPruningSamples, rtaac.minPruneFactor,
rtaac.useAdaptivePruning, rtaac.initialErrorRateForPruning, rtaac.pruningLog10OddsThreshold, rtaac.maxUnprunedVariants);
assemblyEngine.setErrorCorrectKmers(rtaac.errorCorrectKmers);
assemblyEngine.setDebug(args.debug);
assemblyEngine.setDebugGraphTransformations(rtaac.debugGraphTransformations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,37 @@ public final class ReadThreadingAssemblerArgumentCollection implements Serializa
@Argument(fullName="min-pruning", doc = "Minimum support to not prune paths in the graph", optional = true)
public int minPruneFactor = 2;

/**
* A single edge multiplicity cutoff for pruning doesn't work in samples with variable depths, for example exomes
* and RNA. This parameter activates a probabilistic algorithm for pruning the assembly graph that considers the
* likelihood that each chain in the graph comes from real variation.
*/
@Advanced
@Argument(fullName="adaptive-pruning", doc = "Use an adaptive algorithm for pruning paths in the graph", optional = true)
public boolean useAdaptivePruning = false;

/**
* Initial base error rate guess for the probabilistic adaptive pruning model. Results are not very sensitive to this
* parameter because it is only a starting point from which the algorithm discovers the true error rate.
*/
@Advanced
@Argument(fullName="adaptive-pruning-initial-error-rate", doc = "Initial base error rate estimate for adaptive pruning", optional = true)
public double initialErrorRateForPruning = 0.001;

/**
* Log-10 likelihood ratio threshold for the adaptive pruning algorithm. Chains whose log odds fall below this
* threshold are pruned.
*/
@Advanced
@Argument(fullName="pruning-lod-threshold", doc = "Log-10 likelihood ratio threshold for adaptive pruning algorithm", optional = true)
public double pruningLog10OddsThreshold = 1.0;

/**
* The maximum number of possible-variant chains the adaptive pruner will leave in the graph. When more are found,
* the chains with the lowest log odds are pruned.
*/
@Advanced
@Argument(fullName="max-unpruned-variants", doc = "Maximum number of variants in graph the adaptive pruner will allow", optional = true)
public int maxUnprunedVariants = 100;

@Hidden
@Argument(fullName="debug-graph-transformations", doc="Write DOT formatted graph files out of the assembler for only this graph size", optional = true)
public boolean debugGraphTransformations = false;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;
import org.broadinstitute.hellbender.tools.walkers.mutect.Mutect2Engine;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * A {@link ChainPruner} that selects chains for removal from a likelihood-ratio score of each chain instead of a
 * fixed multiplicity cutoff.
 *
 * <p>Chains are scored twice: first with the initial error probability, then with an error rate re-estimated from
 * the multiplicities of the chains flagged in the first pass.  A chain is removed when its log odds fall below the
 * threshold, or when it is a possible variant ranked beyond the {@code maxUnprunedVariants} best-scoring ones.
 * Chains that contain a reference edge are never removed.</p>
 */
public class AdaptiveChainPruner<V extends BaseVertex, E extends BaseEdge> extends ChainPruner<V,E> {
    // starting guess for the error rate, used only for the first scoring pass
    private final double initialErrorProbability;

    // compared directly with the output of Mutect2Engine.lnLikelihoodRatio
    // NOTE(review): this is presumably in natural-log units -- confirm that callers convert from log10
    private final double logOddsThreshold;

    // number of best-scoring possible-variant chains that are exempt from the overflow pruning step
    private final int maxUnprunedVariants;

    /**
     * @param initialErrorProbability error probability used for the first scoring pass; must be positive
     * @param logOddsThreshold        chains scoring below this value are pruned
     * @param maxUnprunedVariants     maximum number of possible-variant chains kept by the overflow step
     */
    public AdaptiveChainPruner(final double initialErrorProbability, final double logOddsThreshold, final int maxUnprunedVariants) {
        ParamUtils.isPositive(initialErrorProbability, "Must have positive error probability");
        this.initialErrorProbability = initialErrorProbability;
        this.logOddsThreshold = logOddsThreshold;
        this.maxUnprunedVariants = maxUnprunedVariants;
    }

    /**
     * Chooses the chains to remove using the two-pass scheme described on the class.
     *
     * @param chains all chains of a single graph
     * @return the chains to remove; empty if {@code chains} is empty
     */
    @Override
    protected Collection<Path<V,E>> chainsToRemove(final List<Path<V, E>> chains) {
        if (chains.isEmpty()) {
            return Collections.emptyList();
        }

        // every chain is taken to belong to the same graph as the first one
        final BaseGraph<V,E> graph = chains.get(0).getGraph();

        // first pass: flag likely errors using the initial guess, then use them to re-estimate the error rate
        final Collection<Path<V,E>> probableErrorChains = likelyErrorChains(chains, graph, initialErrorProbability);
        final int errorCount = probableErrorChains.stream().mapToInt(c -> c.getLastEdge().getMultiplicity()).sum();
        final int totalBases = chains.stream().mapToInt(c -> c.getEdges().stream().mapToInt(E::getMultiplicity).sum()).sum();

        // with no edge support at all the ratio would be 0/0 = NaN, so keep the initial guess in that case
        final double errorRate = totalBases == 0 ? initialErrorProbability : (double) errorCount / totalBases;

        // second pass: final decision with the re-estimated error rate
        return likelyErrorChains(chains, graph, errorRate);
    }

    /**
     * Collects the chains that look like errors at the given error rate: chains whose log odds are below the
     * threshold, plus possible-variant chains ranked after the {@code maxUnprunedVariants} best-scoring ones.
     */
    private Collection<Path<V,E>> likelyErrorChains(final List<Path<V, E>> chains, final BaseGraph<V,E> graph, final double errorRate) {
        final Map<Path<V,E>, Double> chainLogOdds = chains.stream()
                .collect(Collectors.toMap(c -> c, c -> chainLogOdds(c, graph, errorRate)));

        final Set<Path<V,E>> result = new HashSet<>(chains.size());

        chainLogOdds.forEach((chain, lod) -> {
            if (lod < logOddsThreshold) {
                result.add(chain);
            }
        });

        // Rank the possible variants from best to worst score and drop everything past the cap.  Reference chains
        // still occupy a rank, but are filtered out after the skip so that the cap can never remove them.
        chains.stream().filter(c -> isChainPossibleVariant(c, graph))
                .sorted(Comparator.comparingDouble(chainLogOdds::get).reversed())
                .skip(maxUnprunedVariants)
                .filter(c -> !containsRefEdge(c))
                .forEach(result::add);

        return result;
    }

    /**
     * Scores a chain as the larger of the likelihood ratios computed at its two ends.  Chains with a reference edge
     * score positive infinity; an end that is a graph source (left) or sink (right) contributes 0.
     */
    private double chainLogOdds(final Path<V,E> chain, final BaseGraph<V,E> graph, final double errorRate) {
        if (containsRefEdge(chain)) {
            return Double.POSITIVE_INFINITY;
        }

        // total multiplicity leaving the chain's first vertex and entering its last vertex
        final int leftTotalMultiplicity = MathUtils.sumIntFunction(graph.outgoingEdgesOf(chain.getFirstVertex()), E::getMultiplicity);
        final int rightTotalMultiplicity = MathUtils.sumIntFunction(graph.incomingEdgesOf(chain.getLastVertex()), E::getMultiplicity);

        // multiplicity of the chain's own first and last edges
        final int leftMultiplicity = chain.getEdges().get(0).getMultiplicity();
        final int rightMultiplicity = chain.getLastEdge().getMultiplicity();

        // at each end, the chain's edge plays the role of the alt count and the competing edges the ref count
        final double leftLogOdds = graph.isSource(chain.getFirstVertex()) ? 0.0 :
                Mutect2Engine.lnLikelihoodRatio(leftTotalMultiplicity - leftMultiplicity, leftMultiplicity, errorRate);
        final double rightLogOdds = graph.isSink(chain.getLastVertex()) ? 0.0 :
                Mutect2Engine.lnLikelihoodRatio(rightTotalMultiplicity - rightMultiplicity, rightMultiplicity, errorRate);

        return FastMath.max(leftLogOdds, rightLogOdds);
    }

    /**
     * Is the chain a possible variant, i.e. does its first edge carry at most half (by integer division) of the
     * multiplicity leaving its first vertex, or its last edge at most half of the multiplicity entering its last
     * vertex?
     */
    private boolean isChainPossibleVariant(final Path<V,E> chain, final BaseGraph<V,E> graph) {
        final int leftTotalMultiplicity = MathUtils.sumIntFunction(graph.outgoingEdgesOf(chain.getFirstVertex()), E::getMultiplicity);
        final int rightTotalMultiplicity = MathUtils.sumIntFunction(graph.incomingEdgesOf(chain.getLastVertex()), E::getMultiplicity);

        final int leftMultiplicity = chain.getEdges().get(0).getMultiplicity();
        final int rightMultiplicity = chain.getLastEdge().getMultiplicity();

        return leftMultiplicity <= leftTotalMultiplicity / 2 || rightMultiplicity <= rightTotalMultiplicity / 2;
    }

    // does any edge of the chain belong to the reference path?
    private boolean containsRefEdge(final Path<V,E> chain) {
        return chain.getEdges().stream().anyMatch(E::isRef);
    }
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs;

import com.google.common.annotations.VisibleForTesting;

import java.util.*;


Expand All @@ -8,12 +10,13 @@ public ChainPruner() { }

public void pruneLowWeightChains(final BaseGraph<V,E> graph) {
final List<Path<V, E>> chains = findAllChains(graph);
final List<Path<V, E>> chainsToRemove = chainsToRemove(chains);
final Collection<Path<V, E>> chainsToRemove = chainsToRemove(chains);
chainsToRemove.forEach(c -> graph.removeAllEdges(c.getEdges()));
graph.removeSingletonOrphanVertices();
}

private List<Path<V, E>> findAllChains(BaseGraph<V, E> graph) {
@VisibleForTesting
List<Path<V, E>> findAllChains(BaseGraph<V, E> graph) {
final Deque<V> chainStarts = new LinkedList<>(graph.getSources());
final List<Path<V,E>> chains = new LinkedList<>();
final Set<V> alreadySeen = new HashSet<>(chainStarts);
Expand Down Expand Up @@ -56,5 +59,5 @@ private Path<V,E> findChain(final E startEdge, final BaseGraph<V, E> graph) {
return new Path<V, E>(edges, lastVertex, graph);
}

protected abstract List<Path<V,E>> chainsToRemove(final List<Path<V, E>> chains);
protected abstract Collection<Path<V,E>> chainsToRemove(final List<Path<V, E>> chains);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public LowWeightChainPruner(final int pruneFactor) {
}

@Override
protected List<Path<V,E>> chainsToRemove(final List<Path<V, E>> chains) {
protected Collection<Path<V,E>> chainsToRemove(final List<Path<V, E>> chains) {
return chains.stream().filter(this::needsPruning).collect(Collectors.toList());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.AssemblyResultSet;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.ReadErrorCorrector;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.*;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.haplotype.Haplotype;
Expand Down Expand Up @@ -59,7 +60,7 @@ public final class ReadThreadingAssembler {

protected byte minBaseQualityToUseInAssembly = DEFAULT_MIN_BASE_QUALITY_TO_USE;
private int pruneFactor;
private final LowWeightChainPruner<MultiDeBruijnVertex, MultiSampleEdge> lowWeightChainPruner;
private final ChainPruner<MultiDeBruijnVertex, MultiSampleEdge> chainPruner;

protected boolean errorCorrectKmers = false;

Expand All @@ -68,20 +69,23 @@ public final class ReadThreadingAssembler {

public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler, final List<Integer> kmerSizes,
final boolean dontIncreaseKmerSizesForCycles, final boolean allowNonUniqueKmersInRef,
final int numPruningSamples, final int pruneFactor) {
final int numPruningSamples, final int pruneFactor, final boolean useAdaptivePruning,
final double initialErrorRateForPruning, final double pruningLog10OddsThreshold,
final int maxUnprunedVariants) {
Utils.validateArg( maxAllowedPathsForReadThreadingAssembler >= 1, "numBestHaplotypesPerGraph should be >= 1 but got " + maxAllowedPathsForReadThreadingAssembler);
this.kmerSizes = kmerSizes;
this.dontIncreaseKmerSizesForCycles = dontIncreaseKmerSizesForCycles;
this.allowNonUniqueKmersInRef = allowNonUniqueKmersInRef;
this.numPruningSamples = numPruningSamples;
this.pruneFactor = pruneFactor;
lowWeightChainPruner = new LowWeightChainPruner<>(pruneFactor);
chainPruner = useAdaptivePruning ? new AdaptiveChainPruner<>(initialErrorRateForPruning, MathUtils.log10ToLog(pruningLog10OddsThreshold), maxUnprunedVariants) :
new LowWeightChainPruner<>(pruneFactor);
numBestHaplotypesPerGraph = maxAllowedPathsForReadThreadingAssembler;
}

@VisibleForTesting
ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler, final List<Integer> kmerSizes, final int pruneFactor) {
this(maxAllowedPathsForReadThreadingAssembler, kmerSizes, true, true, 1, pruneFactor);
this(maxAllowedPathsForReadThreadingAssembler, kmerSizes, true, true, 1, pruneFactor, false, 0.001, 2, Integer.MAX_VALUE);
}

@VisibleForTesting
Expand Down Expand Up @@ -492,7 +496,7 @@ private AssemblyResult getAssemblyResult(final Haplotype refHaplotype, final int
// prune low-support chains: either chains where all edges have multiplicity < pruneFactor or, when adaptive pruning is enabled, chains that are likely errors. This must occur
// before recoverDanglingTails in the graph, so that we don't spend a ton of time recovering
// tails that we'll ultimately just trim away anyway, as the dangling tail edges have weight of 1
lowWeightChainPruner.pruneLowWeightChains(rtgraph);
chainPruner.pruneLowWeightChains(rtgraph);

// look at all chains in the graph that terminate in a non-ref node (dangling sources and sinks) and see if
// we can recover them by merging some N bases from the chain back into the reference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ private void resetToInitialState() {
* @param sequence a non-null sequence
* @param isRef is this the reference sequence?
*/
final void addSequence(final byte[] sequence, final boolean isRef) {
@VisibleForTesting
public final void addSequence(final byte[] sequence, final boolean isRef) {
addSequence("anonymous", sequence, isRef);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.broadinstitute.hellbender.tools.walkers.mutect;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.util.Locatable;
import htsjdk.variant.variantcontext.VariantContext;
Expand Down Expand Up @@ -274,7 +275,7 @@ public ActivityProfileState isActive(final AlignmentContext context, final Refer
final ReadPileup pileup = context.getBasePileup();
final ReadPileup tumorPileup = pileup.getPileupForSample(tumorSample, header);
final List<Byte> tumorAltQuals = altQuals(tumorPileup, refBase, MTAC.initialPCRErrorQual);
final double tumorLog10Odds = MathUtils.logToLog10(lnLikelihoodRatio(tumorPileup.size()-tumorAltQuals.size(), tumorAltQuals));
final double tumorLog10Odds = MathUtils.logToLog10(lnLikelihoodRatio(tumorPileup.size() - tumorAltQuals.size(), tumorAltQuals));

if (tumorLog10Odds < MTAC.getInitialLod()) {
return new ActivityProfileState(refInterval, 0.0);
Expand Down Expand Up @@ -349,10 +350,17 @@ private static List<Byte> altQuals(final ReadPileup pileup, final byte refBase,
return result;
}

// the isActive() likelihood ratio described in docs/mutect/mutect.pdf, with every alt qual counted
// exactly once (repeat factor of 1)
private static double lnLikelihoodRatio(final int refCount, final List<Byte> altQuals) {
    final double noRepeats = 1;
    return lnLikelihoodRatio(refCount, altQuals, noRepeats);
}

// this implements the isActive() algorithm described in docs/mutect/mutect.pdf
// the multiplicative factor is for the special case where we pass a singleton list
// of alt quals and want to duplicate that alt qual over multiple reads
@VisibleForTesting
static double lnLikelihoodRatio(final int refCount, final List<Byte> altQuals, final double repeatFactor) {
final double beta = refCount + 1;
final double alpha = altQuals.size() + 1;
final double alpha = repeatFactor * altQuals.size() + 1;
final double digammaAlpha = Gamma.digamma(alpha);
final double digammaBeta = Gamma.digamma(beta);
final double digammaAlphaPlusBeta = Gamma.digamma(alpha + beta);
Expand All @@ -361,21 +369,26 @@ private static double lnLikelihoodRatio(final int refCount, final List<Byte> alt
final double lnTau = digammaAlpha - digammaAlphaPlusBeta;
final double tau = FastMath.exp(lnTau);



final double betaEntropy = Beta.logBeta(alpha, beta) - (alpha - 1)*digammaAlpha - (beta-1)*digammaBeta + (alpha + beta - 2)*digammaAlphaPlusBeta;

final double result = betaEntropy + refCount * lnRho + altQuals.stream().mapToDouble(qual -> {
double readSum = 0;
for (final byte qual : altQuals) {
final double epsilon = QualityUtils.qualToErrorProb(qual);
final double gamma = rho * epsilon / (rho * epsilon + tau * (1-epsilon));
final double bernoulliEntropy = -gamma * FastMath.log(gamma) - (1-gamma)*FastMath.log1p(-gamma);
final double lnEpsilon = MathUtils.log10ToLog(QualityUtils.qualToErrorProbLog10(qual));
final double lnOneMinusEpsilon = MathUtils.log10ToLog(QualityUtils.qualToProbLog10(qual));
return gamma * (lnRho + lnEpsilon) + (1-gamma)*(lnTau + lnOneMinusEpsilon) - lnEpsilon + bernoulliEntropy;
}).sum();
readSum += gamma * (lnRho + lnEpsilon) + (1-gamma)*(lnTau + lnOneMinusEpsilon) - lnEpsilon + bernoulliEntropy;
}

return result;
return betaEntropy + refCount * lnRho + readSum * repeatFactor;

}

// constant-error-probability variant of the list-based method: the error probability is converted to a
// base quality, and that single quality is passed as a singleton list with altCount as the repeat factor
public static double lnLikelihoodRatio(final int refCount, final int altCount, final double errorProbability) {
    final List<Byte> repeatedQual = Collections.singletonList(QualityUtils.errorProbToQual(errorProbability));
    return lnLikelihoodRatio(refCount, repeatedQual, altCount);
}

// check that we're next to a soft clip that is not due to a read that got out of sync and ended in a bunch of BQ2's
Expand Down
16 changes: 16 additions & 0 deletions src/main/java/org/broadinstitute/hellbender/utils/MathUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,22 @@ public static double sum(final double[] arr, final int start, final int stop) {
return result;
}

/**
 * Sums the values obtained by applying a double-valued function to every element of a collection.
 *
 * @param collection the elements to sum over; an empty collection yields 0
 * @param function   maps an element to the value to add; may be declared on any supertype of the element type
 * @param <E>        the element type
 * @return the sum of {@code function} over {@code collection}, accumulated in iteration order
 */
public static <E> double sumDoubleFunction(final Collection<? extends E> collection, final ToDoubleFunction<? super E> function) {
    // a plain left-to-right accumulation, so the rounding matches a hand-written loop
    double result = 0;
    for (final E e : collection) {
        result += function.applyAsDouble(e);
    }
    return result;
}

/**
 * Sums the values obtained by applying an int-valued function to every element of a collection.
 *
 * @param collection the elements to sum over; an empty collection yields 0
 * @param function   maps an element to the value to add; may be declared on any supertype of the element type
 * @param <E>        the element type
 * @return the sum of {@code function} over {@code collection}, using ordinary {@code int} arithmetic
 */
public static <E> int sumIntFunction(final Collection<? extends E> collection, final ToIntFunction<? super E> function) {
    int result = 0;
    for (final E e : collection) {
        result += function.applyAsInt(e);
    }
    return result;
}

/**
* Compares double values for equality (within 1e-6), or inequality.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs;

import static org.testng.Assert.*;

// Placeholder: this class does not declare any test methods yet.
public class AdaptiveChainPrunerUnitTest {

}
Loading

0 comments on commit 25da765

Please sign in to comment.