Skip to content

Commit

Permalink
cnn variant update models validate scores cleanup training
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Oct 11, 2018
1 parent ce669d1 commit 65d0edd
Show file tree
Hide file tree
Showing 26 changed files with 896 additions and 327 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ final barclayVersion = System.getProperty('barclay.version','2.1.0')
final sparkVersion = System.getProperty('spark.version', '2.2.0')
final hadoopVersion = System.getProperty('hadoop.version', '2.8.2')
final hadoopBamVersion = System.getProperty('hadoopBam.version','7.10.0')
final tensorflowVersion = System.getProperty('tensorflow.version','1.4.0')
final tensorflowVersion = System.getProperty('tensorflow.version','1.9.0')
final genomicsdbVersion = System.getProperty('genomicsdb.version','0.10.0-proto-3.0.0-beta-1+bdce8be25b873')
final testNGVersion = '6.11'
// Using the shaded version to avoid conflicts between its protobuf dependency
Expand Down
5 changes: 2 additions & 3 deletions scripts/gatkcondaenv.yml.template
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies:
- h5py==2.7.1
- html5lib==0.9999999
- joblib==0.11
- keras==2.1.4
- keras==2.2.0
- markdown==2.6.9
- matplotlib==2.1.0
- numpy==1.13.3
Expand All @@ -44,8 +44,7 @@ dependencies:
- scipy==1.0.0
- six==1.11.0
- $tensorFlowDependency
- tensorflow-tensorboard==0.4.0rc3
- theano==0.9.0
- tqdm==4.19.4
- werkzeug==0.12.2
- gatkPythonPackageArchive.zip
- gatkPythonPackageArchive.zip
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package org.broadinstitute.hellbender.tools.walkers.validation;

import java.nio.file.Paths;
import java.io.IOException;

import org.apache.commons.collections4.Predicate;

import htsjdk.variant.variantcontext.VariantContext;

import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;

import org.broadinstitute.hellbender.engine.ReadsContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.engine.AbstractConcordanceWalker;

import picard.cmdline.programgroups.VariantEvaluationProgramGroup;

/**
* Compare INFO field values between two VCFs or compare two different INFO fields from one VCF.
* We only evaluate sites that are in both VCFs.
* Although we use the arguments eval and truth, we only compare the scores, we do not determine the correct score.
* Either VCF can be used as eval or truth, or the same VCF can be used for both.
* Differences greater than the epsilon argument will trigger a warning.
*
* <h3>Compare the CNN_2D info fields for the same sites from two different VCFs:</h3>
*
* <pre>
* gatk EvaluateInfoFieldConcordance \
* -eval a.vcf \
* -truth another.vcf \
* -S summary.txt \
* -eval-info-key CNN_2D \
* -truth-info-key CNN_2D \
* -epsilon 0.01
* </pre>
*
* <h3>Compare the CNN_2D info field with the CNN_1D field from the same sites in one VCF:</h3>
*
* <pre>
* gatk EvaluateInfoFieldConcordance \
* -eval my.vcf \
* -truth my.vcf \
* -S summary.txt \
* -eval-info-key CNN_2D \
* -truth-info-key CNN_1D \
* -epsilon 0.01
* </pre>
*/
@CommandLineProgramProperties(
summary=EvaluateInfoFieldConcordance.USAGE_SUMMARY,
oneLineSummary=EvaluateInfoFieldConcordance.USAGE_ONE_LINE_SUMMARY,
programGroup=VariantEvaluationProgramGroup.class)
@DocumentedFeature
@BetaFeature
public class EvaluateInfoFieldConcordance extends AbstractConcordanceWalker {
static final String USAGE_ONE_LINE_SUMMARY = "Evaluate concordance of info fields in an input VCF against a validated truth VCF";
static final String USAGE_SUMMARY = "This tool evaluates info fields from an input VCF against a VCF that has been validated and is considered to represent ground truth.\n";
public static final String SUMMARY_LONG_NAME = "summary";
public static final String SUMMARY_SHORT_NAME = "S";

@Argument(doc="A table of summary statistics (true positives, sensitivity, etc.)", fullName=SUMMARY_LONG_NAME, shortName=SUMMARY_SHORT_NAME)
protected String summary;

@Argument(fullName="eval-info-key", shortName="eval-info-key", doc="Info key from eval vcf")
protected String evalInfoKey;

@Argument(fullName="truth-info-key", shortName="truth-info-key", doc="Info key from truth vcf")
protected String truthInfoKey;

@Advanced
@Argument(fullName="warn-big-differences",
shortName="warn-big-differences",
doc="If set differences in the info key values greater than epsilon will trigger warnings.",
optional=true)
protected boolean warnBigDifferences = false;

@Advanced
@Argument(fullName="epsilon", shortName="epsilon", doc="Difference tolerance", optional=true)
protected double epsilon = 0.1;

private int snpCount = 0;
private int indelCount = 0;

private double snpSumDelta = 0.0;
private double snpSumDeltaSquared = 0.0;
private double indelSumDelta = 0.0;
private double indelSumDeltaSquared = 0.0;

@Override
public void onTraversalStart() {
if(getEvalHeader().getInfoHeaderLine(evalInfoKey) == null){
throw new UserException("Missing key:"+evalInfoKey+" in Eval VCF:"+evalVariantsFile);
}

if(getTruthHeader().getInfoHeaderLine(truthInfoKey) == null){
throw new UserException("Missing key:"+truthInfoKey+" in Truth VCF:"+truthVariantsFile);
}
}

@Override
protected void apply(AbstractConcordanceWalker.TruthVersusEval truthVersusEval, ReadsContext readsContext, ReferenceContext refContext) {
ConcordanceState concordanceState = truthVersusEval.getConcordance();
switch (concordanceState) {
case TRUE_POSITIVE: {
if(truthVersusEval.getEval().isSNP()){
snpCount++;
} else if (truthVersusEval.getEval().isIndel()) {
indelCount++;
}
this.infoDifference(truthVersusEval.getEval(), truthVersusEval.getTruth());
break;
}
case FALSE_POSITIVE:
case FALSE_NEGATIVE:
case FILTERED_TRUE_NEGATIVE:
case FILTERED_FALSE_NEGATIVE: {
break;
}
default: {
throw new IllegalStateException("Unexpected ConcordanceState: " + concordanceState.toString());
}
}
}

private void infoDifference(final VariantContext eval, final VariantContext truth) {
if(eval.hasAttribute(this.evalInfoKey) && truth.hasAttribute(truthInfoKey)) {
final double evalVal = Double.valueOf((String) eval.getAttribute(this.evalInfoKey));
final double truthVal = Double.valueOf((String) truth.getAttribute(this.truthInfoKey));
final double delta = evalVal - truthVal;
final double deltaSquared = delta * delta;
if (eval.isSNP()) {
this.snpSumDelta += Math.sqrt(deltaSquared);
this.snpSumDeltaSquared += deltaSquared;
} else if (eval.isIndel()) {
this.indelSumDelta += Math.sqrt(deltaSquared);
this.indelSumDeltaSquared += deltaSquared;
}
if (warnBigDifferences && Math.abs(delta) > this.epsilon) {
this.logger.warn(String.format("Difference (%f) greater than epsilon (%f) at %s:%d %s:", delta, this.epsilon, eval.getContig(), eval.getStart(), eval.getAlleles().toString()));
this.logger.warn(String.format("\t\tTruth info: " + truth.getAttributes().toString()));
this.logger.warn(String.format("\t\tEval info: " + eval.getAttributes().toString()));
}
}
}

@Override
public Object onTraversalSuccess() {
final double snpMean = this.snpSumDelta / snpCount;
final double snpVariance = (this.snpSumDeltaSquared - this.snpSumDelta * this.snpSumDelta / snpCount) / snpCount;
final double snpStd = Math.sqrt(snpVariance);
final double indelMean = this.indelSumDelta / indelCount;
final double indelVariance = (this.indelSumDeltaSquared - this.indelSumDelta * this.indelSumDelta / indelCount) / indelCount;
final double indelStd = Math.sqrt(indelVariance);

this.logger.info(String.format("SNP average delta %f and standard deviation: %f", snpMean, snpStd));
this.logger.info(String.format("INDEL average delta %f and standard deviation: %f", indelMean, indelStd));

try (final InfoConcordanceRecord.InfoConcordanceWriter
concordanceWriter = InfoConcordanceRecord.getWriter(Paths.get(this.summary))){
concordanceWriter.writeRecord(new InfoConcordanceRecord(VariantContext.Type.SNP, this.evalInfoKey, this.truthInfoKey, snpMean, snpStd));
concordanceWriter.writeRecord(new InfoConcordanceRecord(VariantContext.Type.INDEL, this.evalInfoKey, this.truthInfoKey, indelMean, indelStd));
} catch (IOException e) {
throw new UserException("Encountered an IO exception writing the concordance summary table", e);
}

return "SUCCESS";
}

@Override
protected boolean areVariantsAtSameLocusConcordant(VariantContext truth, VariantContext eval) {
final boolean sameRefAllele = truth.getReference().equals(eval.getReference());
final boolean containsAltAllele = eval.getAlternateAlleles().contains(truth.getAlternateAllele(0));
return sameRefAllele && containsAltAllele;
}

@Override
protected Predicate<VariantContext> makeTruthVariantFilter() {
return vc -> !vc.isFiltered() && !vc.isSymbolicOrSV();
}

@Override
protected Predicate<VariantContext> makeEvalVariantFilter() {
return vc -> !vc.isFiltered() && !vc.isSymbolicOrSV();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package org.broadinstitute.hellbender.tools.walkers.validation;

import htsjdk.variant.variantcontext.VariantContext;

import java.io.IOException;
import java.nio.file.Path;

import org.broadinstitute.hellbender.exceptions.UserException;

import org.broadinstitute.hellbender.utils.tsv.DataLine;
import org.broadinstitute.hellbender.utils.tsv.TableColumnCollection;
import org.broadinstitute.hellbender.utils.tsv.TableWriter;
import org.broadinstitute.hellbender.utils.tsv.TableReader;

/**
* Keeps track of concordance between two info fields.
*/
public class InfoConcordanceRecord {
private static final String VARIANT_TYPE_COLUMN_NAME = "type";
private static final String EVAL_INFO_KEY = "eval_info_key";
private static final String TRUE_INFO_KEY = "true_info_key";
private static final String MEAN_DIFFERENCE = "mean_difference";
private static final String STD_DIFFERENCE = "std_difference";
private static final String[] INFO_CONCORDANCE_COLUMN_HEADER =
{VARIANT_TYPE_COLUMN_NAME, EVAL_INFO_KEY, TRUE_INFO_KEY, MEAN_DIFFERENCE, STD_DIFFERENCE};
final VariantContext.Type type;
private final String evalKey;
private final String trueKey;
private final double mean;
private final double std;

/**
* Record keeps track of concordance between values from INFO-field keys of a VCF.
*
* @param type SNP or INDEL
* @param evalKey The INFO field key from the eval VCF
* @param trueKey The INFO field key from the truth VCF
* @param mean The mean of the differences in values for these INFO fields.
* @param std The standard deviation of the differences in values for these INFO fields.
*/
public InfoConcordanceRecord(VariantContext.Type type, String evalKey, String trueKey, double mean, double std) {
this.type = type;
this.evalKey = evalKey;
this.trueKey = trueKey;
this.mean = mean;
this.std = std;
}

/**
*
* @return Variant type (e.g. SNP or INDEL)
*/
public VariantContext.Type getVariantType() {
return this.type;
}

/**
*
* @return The mean of the differences between two INFO fields
*/
public double getMean() {
return this.mean;
}

/**
*
* @return The Standard Deviation of the differences between two INFO fields
*/
public double getStd() {
return this.std;
}

/**
*
* @return The INFO field for the eval VCF
*/
public String getEvalKey() {
return this.evalKey;
}

/**
*
* @return The INFO field for the truth VCF
*/
public String getTrueKey() {
return this.trueKey;
}

/**
* Get a table writer
* @param outputTable A Path where the output table will be written
* @return A Table writer for INFO field concordances
*/
public static InfoConcordanceWriter getWriter(Path outputTable) {
try {
InfoConcordanceWriter writer = new InfoConcordanceWriter(outputTable);
return writer;
}
catch (IOException e) {
throw new UserException(String.format("Encountered an IO exception while writing from %s.", outputTable), e);
}
}

/**
* Table writing class for InfoConcordanceRecords
*/
public static class InfoConcordanceWriter extends TableWriter<InfoConcordanceRecord> {
private InfoConcordanceWriter(Path output) throws IOException {
super(output.toFile(), new TableColumnCollection(INFO_CONCORDANCE_COLUMN_HEADER));
}

@Override
protected void composeLine(InfoConcordanceRecord record, DataLine dataLine) {
dataLine.set(VARIANT_TYPE_COLUMN_NAME, record.getVariantType().toString())
.set(EVAL_INFO_KEY, record.getEvalKey())
.set(TRUE_INFO_KEY, record.getTrueKey())
.set(MEAN_DIFFERENCE, record.getMean())
.set(STD_DIFFERENCE, record.getStd());
}
}

/**
* Table reading class for InfoConcordanceRecords
*/
public static class InfoConcordanceReader extends TableReader<InfoConcordanceRecord> {
public InfoConcordanceReader(Path summary) throws IOException {
super(summary.toFile());
}

@Override
protected InfoConcordanceRecord createRecord(DataLine dataLine) {
VariantContext.Type type = VariantContext.Type.valueOf(dataLine.get(VARIANT_TYPE_COLUMN_NAME));
String evalKey = dataLine.get(EVAL_INFO_KEY);
String trueKey = dataLine.get(TRUE_INFO_KEY);
double mean = Double.parseDouble(dataLine.get(MEAN_DIFFERENCE));
double std = Double.parseDouble(dataLine.get(STD_DIFFERENCE));
return new InfoConcordanceRecord(type, evalKey, trueKey, mean, std);
}
}
}
Loading

0 comments on commit 65d0edd

Please sign in to comment.