diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2Engine.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2Engine.java index eb1c5275c80..5dd0a617cb7 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2Engine.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2Engine.java @@ -9,8 +9,11 @@ import htsjdk.variant.vcf.VCFHeader; import htsjdk.variant.vcf.VCFHeaderLine; import htsjdk.variant.vcf.VCFStandardHeaderLines; +import it.unimi.dsi.fastutil.bytes.ByteArrayList; +import it.unimi.dsi.fastutil.doubles.DoubleArrayList; import org.apache.commons.lang3.mutable.MutableLong; import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.math3.special.Gamma; import org.apache.commons.math3.util.FastMath; import org.apache.logging.log4j.LogManager; @@ -51,6 +54,7 @@ import java.nio.file.Path; import java.util.*; import java.util.stream.Collectors; +import java.util.stream.IntStream; /** * Created by davidben on 9/15/16. @@ -105,6 +109,9 @@ public final class Mutect2Engine implements AssemblyRegionEvaluator { private final Optional f1R2CountsCollector; + private PileupQualBuffer tumorPileupQualBuffer = new PileupQualBuffer(); + private PileupQualBuffer normalPileupQualBuffer = new PileupQualBuffer(); + /** * Create and initialize a new HaplotypeCallerEngine given a collection of HaplotypeCaller arguments, a reads header, * and a reference file @@ -358,6 +365,9 @@ public ActivityProfileState isActive(final AlignmentContext context, final Refer if ( forceCallingAllelesPresent && features.getValues(MTAC.alleles, ref).stream().anyMatch(vc -> MTAC.forceCallFiltered || vc.isNotFiltered())) { return new ActivityProfileState(ref.getInterval(), 1.0); } + if (ref.getStart() == 2916255) { + int f = 90; + } final byte refBase = ref.getBase(); final SimpleInterval refInterval = ref.getInterval(); @@ -372,18 +382,22 @@ public ActivityProfileState isActive(final AlignmentContext context, final Refer } final ReadPileup tumorPileup = pileup.makeFilteredPileup(pe -> isTumorSample(ReadUtils.getSampleName(pe.getRead(), header))); f1R2CountsCollector.ifPresent(collector -> collector.process(tumorPileup, ref)); - final List tumorAltQuals = altQuals(tumorPileup, refBase, MTAC.pcrSnvQual); - final double tumorLogOdds = logLikelihoodRatio(tumorPileup.size() - tumorAltQuals.size(), tumorAltQuals); + tumorPileupQualBuffer.accumulateQuals(tumorPileup, refBase, MTAC.pcrSnvQual); + final Pair bestTumorAltAllele = tumorPileupQualBuffer.likeliestIndexAndQuals(); + final double tumorLogOdds = logLikelihoodRatio(tumorPileup.size() - bestTumorAltAllele.getRight().size(), bestTumorAltAllele.getRight()); if (tumorLogOdds < MTAC.getInitialLogOdds()) { return new ActivityProfileState(refInterval, 0.0); } else if (hasNormal() && !MTAC.genotypeGermlineSites) { final ReadPileup normalPileup = pileup.makeFilteredPileup(pe -> isNormalSample(ReadUtils.getSampleName(pe.getRead(), header))); - final List normalAltQuals = altQuals(normalPileup, refBase, MTAC.pcrSnvQual); - final int normalAltCount = normalAltQuals.size(); - final double normalQualSum = normalAltQuals.stream().mapToDouble(Byte::doubleValue).sum(); - if (normalAltCount > normalPileup.size() * MAX_ALT_FRACTION_IN_NORMAL && normalQualSum > MAX_NORMAL_QUAL_SUM) { - return new ActivityProfileState(refInterval, 0.0); + normalPileupQualBuffer.accumulateQuals(normalPileup, refBase, MTAC.pcrSnvQual); + final Pair bestNormalAltAllele = normalPileupQualBuffer.likeliestIndexAndQuals(); + if (bestNormalAltAllele.getLeft() == bestNormalAltAllele.getLeft()) { + final int normalAltCount = bestNormalAltAllele.getRight().size(); + final double normalQualSum = normalPileupQualBuffer.qualSum(bestNormalAltAllele.getLeft()); + if (normalAltCount > normalPileup.size() * MAX_ALT_FRACTION_IN_NORMAL && normalQualSum > MAX_NORMAL_QUAL_SUM) { + return new ActivityProfileState(refInterval, 0.0); + } } } else if (!MTAC.genotypeGermlineSites) { final List germline = features.getValues(MTAC.germlineResource, refInterval); @@ -458,33 +472,10 @@ private static byte indelQual(final int indelLength) { return (byte) Math.min(INDEL_START_QUAL + (indelLength - 1) * INDEL_CONTINUATION_QUAL, Byte.MAX_VALUE); } - private static List altQuals(final ReadPileup pileup, final byte refBase, final int pcrErrorQual) { - final List result = new ArrayList<>(); - final int position = pileup.getLocation().getStart(); - - for (final PileupElement pe : pileup) { - final int indelLength = getCurrentOrFollowingIndelLength(pe); - if (indelLength > 0) { - result.add(indelQual(indelLength)); - } else if (isNextToUsefulSoftClip(pe)) { - result.add(indelQual(1)); - } else if (pe.getBase() != refBase && pe.getQual() > MINIMUM_BASE_QUALITY) { - final GATKRead read = pe.getRead(); - final int mateStart = (!read.isProperlyPaired() || read.mateIsUnmapped()) ? Integer.MAX_VALUE : read.getMateStart(); - final boolean overlapsMate = mateStart <= position && position < mateStart + read.getLength(); - result.add(overlapsMate ? (byte) FastMath.min(pe.getQual(), pcrErrorQual/2) : pe.getQual()); - } - } - - return result; - } - public static double logLikelihoodRatio(final int refCount, final List altQuals) { return logLikelihoodRatio(refCount, altQuals, 1); } - - /** * this implements the isActive() algorithm described in docs/mutect/mutect.pdf * the multiplicative factor is for the special case where we pass a singleton list @@ -552,4 +543,81 @@ private void checkSampleInBamHeader(final String sample) { private String decodeSampleNameIfNecessary(final String name) { return samplesList.asListOfSamples().contains(name) ? name : IOUtils.urlDecode(name); } + + /** + * A resuable container class to accumulate qualities for each type of SNV and indels (all indels combined) + */ + private static class PileupQualBuffer { + private static final int OTHER_SUBSTITUTION = 4; + private static final int INDEL = 5; + + // our pileup likelihoods models assume that the qual corresponds to the probability that a ref base is misread + // as the *particular* alt base, whereas the qual actually means the probability of *any* substitution error. + // since there are three possible substitutions for each ref base we must divide the error probability by three + // which corresponds to adding 10*log10(3) = 4.77 ~ 5 to the qual. + private static final int ONE_THIRD_QUAL_CORRECTION = 5; + + // indices 0-3 are A,C,G,T; 4 is other substitution (just in case it's some exotic protocol); 5 is indel + private List buffers = IntStream.range(0,6).mapToObj(n -> new ByteArrayList()).collect(Collectors.toList()); + + public PileupQualBuffer() { } + + public void accumulateQuals(final ReadPileup pileup, final byte refBase, final int pcrErrorQual) { + clear(); + final int position = pileup.getLocation().getStart(); + + for (final PileupElement pe : pileup) { + final int indelLength = getCurrentOrFollowingIndelLength(pe); + if (indelLength > 0) { + accumulateIndel(indelQual(indelLength)); + } else if (isNextToUsefulSoftClip(pe)) { + accumulateIndel(indelQual(1)); + } else if (pe.getBase() != refBase && pe.getQual() > MINIMUM_BASE_QUALITY) { + final GATKRead read = pe.getRead(); + final int mateStart = (!read.isProperlyPaired() || read.mateIsUnmapped()) ? Integer.MAX_VALUE : read.getMateStart(); + final boolean overlapsMate = mateStart <= position && position < mateStart + read.getLength(); + accumulateSubstitution(pe.getBase(), overlapsMate ? (byte) FastMath.min(pe.getQual(), pcrErrorQual/2) : pe.getQual()); + } + } + } + + public Pair likeliestIndexAndQuals() { + int bestIndex = 0; + long bestSum = 0; + for (int n = 0; n < buffers.size(); n++) { + final long sum = qualSum(n); + if (sum > bestSum) { + bestSum = sum; + bestIndex = n; + } + } + return ImmutablePair.of(bestIndex, buffers.get(bestIndex)); + } + + private void accumulateSubstitution(final byte base, final byte qual) { + final int index = BaseUtils.simpleBaseToBaseIndex(base); + if (index == -1) { // -1 is the hard-coded value for non-simple bases in BaseUtils + buffers.get(OTHER_SUBSTITUTION).add(qual); + } else { + buffers.get(index).add((byte) FastMath.min(qual + ONE_THIRD_QUAL_CORRECTION, QualityUtils.MAX_QUAL)); + } + } + + private void accumulateIndel(final byte qual) { + buffers.get(INDEL).add(qual); + } + + private void clear() { + buffers.forEach(ByteArrayList::clear); + } + + public long qualSum(final int index) { + final ByteArrayList list = buffers.get(index); + long result = 0; + for (int n = 0; n < list.size(); n++) { + result += list.getByte(n); + } + return result; + } + } } diff --git a/src/test/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2EngineUnitTest.java b/src/test/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2EngineUnitTest.java index 93934b39faf..d8a910ba3d6 100644 --- a/src/test/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2EngineUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2EngineUnitTest.java @@ -2,6 +2,7 @@ import org.apache.commons.math.special.Beta; import org.broadinstitute.hellbender.GATKBaseTest; +import org.broadinstitute.hellbender.exceptions.GATKException; import org.broadinstitute.hellbender.utils.QualityUtils; import org.testng.Assert; import org.testng.annotations.DataProvider; @@ -17,6 +18,10 @@ public class Mutect2EngineUnitTest extends GATKBaseTest { * Test the active region log likelihood ratio to lowest order. From the Mutect docs the no-variation likelihood is * prod_alt(eps_j), where eps_j is the error probability of read j and alt means "over all alt reads". * + * Note that, as in the docs, this assumes the approximation that ref reads have infinite quality -- that is, we + * don't try to squeeze out the last bit of extra variant likelihood by accounting for the chance that some alt reads + * were actually misread as ref reads. + * * The likelihood for a variant with allele fraction f is (1-f)^N_ref * prod_alt [f(1-eps_j) + (1-f)eps_j]. * * To leading order in epsilon the log likelihood ratio is @@ -62,4 +67,82 @@ public Object[][] getLeadingOrderData() { }; } + /** + * We can also (again using the perfect ref reads approximation) exactly calculate the integral when the alt count is small. + * Assuming a constant error rate eps for the alt reads we have a variant likelihood + * int_(0 to 1) (1-f)^N_ref * [f(1-eps_j) + (1-f)eps_j]^N_alt df + * + * We can expand the binomial raised to the N_alt power explicitly for small N_alt to obtain the sum of N_alt terms + * in the integrand, each of which is the normalization constant of a Beta distribution with integer shape parameters. + * @param numRef + * @param errorRate + */ + @Test(dataProvider = "fewAltData") + public void testSmallNumAltExact(final int numRef, final double errorRate) { + + for (final int numAlt : new int[] {0, 1, 2}) { + final double calculated = Mutect2Engine.logLikelihoodRatio(numRef, numAlt, errorRate); + final double expected; + switch (numAlt) { + case 0: + expected = Math.log(1.0 / (numRef + 1));; + break; + case 1: + expected = Math.log((errorRate / (numRef + 2)) + (1 - errorRate) /((numRef + 1) * (numRef + 2))) + - Math.log(errorRate); + break; + case 2: + expected = Math.log( (errorRate*errorRate/(numRef + 3)) + (2*errorRate*(1-errorRate)/((numRef+2)*(numRef+3))) + + (2*(1-errorRate)*(1-errorRate)/((numRef+1)*(numRef+2)*(numRef+3)))) - 2*Math.log(errorRate); + break; + default: + throw new GATKException.ShouldNeverReachHereException("Didn't write this test case"); + + } + + // we don't really care about high accuracy if things are obvious: + if (expected < - 3.0) { + Assert.assertTrue(calculated < -1.0); + continue; + } + + final double precision; + if (expected < -2) { + precision = 2.0; + } else if (expected < 0) { + precision = 1.0; + } else if (expected < 1) { + precision = 0.5; + } else { + precision = 0.25; + } + + Assert.assertEquals(calculated, expected, precision); + } + + + + } + + @DataProvider(name = "fewAltData") + public Object[][] getFewAltData() { + return new Object[][] { + { 1, 0.0001}, + { 5, 0.0001}, + { 10, 0.0001}, + { 100, 0.0001}, + { 1000, 0.0001}, + { 1, 0.001}, + { 5, 0.001}, + { 10, 0.001}, + { 100, 0.001}, + { 1000, 0.001}, + { 1, 0.01}, + { 5, 0.01}, + { 10, 0.01}, + { 100, 0.01}, + { 1000, 0.01}, + }; + } + } \ No newline at end of file diff --git a/src/test/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2IntegrationTest.java b/src/test/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2IntegrationTest.java index 283130d55eb..02c9c7286f6 100644 --- a/src/test/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2IntegrationTest.java +++ b/src/test/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2IntegrationTest.java @@ -162,7 +162,8 @@ public void testDreamTumorNormal(final File tumor, final Optional normal, args -> runOrientationFilter ? args.add(M2FiltersArgumentCollection.ARTIFACT_PRIOR_TABLE_NAME, orientationModel) : args); final File concordanceSummary = createTempFile("concordance", ".txt"); - runConcordance(truth, filteredVcf,concordanceSummary, CHROMOSOME_20, mask); + final File truePositivesFalseNegatives = createTempFile("tpfn", ".vcf"); + runConcordance(truth, filteredVcf,concordanceSummary, CHROMOSOME_20, mask, Optional.of (truePositivesFalseNegatives)); final List summaryRecords = new ConcordanceSummaryRecord.Reader(concordanceSummary.toPath()).toList(); summaryRecords.forEach(rec -> { @@ -216,7 +217,7 @@ public void testTwoDreamTumorSamples(final List tumors, final List n runFilterMutectCalls(unfilteredVcf, filteredVcf, b37Reference); final File concordanceSummary = createTempFile("concordance", ".txt"); - runConcordance(truth, filteredVcf, concordanceSummary, CHROMOSOME_20, mask); + runConcordance(truth, filteredVcf, concordanceSummary, CHROMOSOME_20, mask, Optional.empty()); final List summaryRecords = new ConcordanceSummaryRecord.Reader(concordanceSummary.toPath()).toList(); summaryRecords.forEach(rec -> { @@ -882,13 +883,15 @@ final private void runFilterMutectCalls(final File unfilteredVcf, final File fil runCommandLine(argsWithAdditions, FilterMutectCalls.class.getSimpleName()); } - private void runConcordance(final File truth, final File eval, final File summary, final String interval, final File mask) { + private void runConcordance(final File truth, final File eval, final File summary, final String interval, final File mask, final Optional truePositivesFalseNegatives) { final ArgumentsBuilder concordanceArgs = new ArgumentsBuilder() .add(Concordance.TRUTH_VARIANTS_LONG_NAME, truth) .add(Concordance.EVAL_VARIANTS_LONG_NAME, eval) .addInterval(new SimpleInterval(interval)) .add(IntervalArgumentCollection.EXCLUDE_INTERVALS_LONG_NAME, mask) .add(Concordance.SUMMARY_LONG_NAME, summary); + + truePositivesFalseNegatives.ifPresent(file -> concordanceArgs.add(Concordance.TRUE_POSITIVES_AND_FALSE_NEGATIVES_SHORT_NAME, file)); runCommandLine(concordanceArgs, Concordance.class.getSimpleName()); }