diff --git a/build.gradle b/build.gradle
index 4a5e333db15..f6b5084549b 100644
--- a/build.gradle
+++ b/build.gradle
@@ -63,7 +63,7 @@ final barclayVersion = System.getProperty('barclay.version','2.1.0')
final sparkVersion = System.getProperty('spark.version', '2.2.0')
final hadoopVersion = System.getProperty('hadoop.version', '2.8.2')
final hadoopBamVersion = System.getProperty('hadoopBam.version','7.10.0')
-final tensorflowVersion = System.getProperty('tensorflow.version','1.4.0')
+final tensorflowVersion = System.getProperty('tensorflow.version','1.9.0')
final genomicsdbVersion = System.getProperty('genomicsdb.version','0.10.0-proto-3.0.0-beta-1+bdce8be25b873')
final testNGVersion = '6.11'
// Using the shaded version to avoid conflicts between its protobuf dependency
diff --git a/scripts/gatkcondaenv.yml.template b/scripts/gatkcondaenv.yml.template
index 287756a0e9e..54454a722d3 100644
--- a/scripts/gatkcondaenv.yml.template
+++ b/scripts/gatkcondaenv.yml.template
@@ -26,7 +26,7 @@ dependencies:
- h5py==2.7.1
- html5lib==0.9999999
- joblib==0.11
- - keras==2.1.4
+ - keras==2.2.0
- markdown==2.6.9
- matplotlib==2.1.0
- numpy==1.13.3
@@ -44,8 +44,7 @@ dependencies:
- scipy==1.0.0
- six==1.11.0
- $tensorFlowDependency
- - tensorflow-tensorboard==0.4.0rc3
- theano==0.9.0
- tqdm==4.19.4
- werkzeug==0.12.2
- - gatkPythonPackageArchive.zip
+ - gatkPythonPackageArchive.zip
\ No newline at end of file
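
The dependency bumps above move TensorFlow from 1.4.0 to 1.9.0 and Keras from 2.1.4 to 2.2.0 in lockstep, and drop the standalone tensorflow-tensorboard pin, presumably superseded by the tensorboard package that newer TensorFlow wheels pull in themselves. A minimal sanity check for the rebuilt conda environment, assuming only the pins in this diff (not an official GATK check):

    import tensorflow as tf
    import keras

    # Versions should match the pins in build.gradle and gatkcondaenv.yml.template.
    assert tf.__version__.startswith('1.9'), tf.__version__
    assert keras.__version__.startswith('2.2'), keras.__version__
    print('Keras backend:', keras.backend.backend())  # expected: tensorflow
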
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/validation/EvaluateInfoFieldConcordance.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/validation/EvaluateInfoFieldConcordance.java
new file mode 100644
index 00000000000..b255f160877
--- /dev/null
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/validation/EvaluateInfoFieldConcordance.java
@@ -0,0 +1,191 @@
+package org.broadinstitute.hellbender.tools.walkers.validation;
+
+import java.nio.file.Paths;
+import java.io.IOException;
+
+import org.apache.commons.collections4.Predicate;
+
+import htsjdk.variant.variantcontext.VariantContext;
+
+import org.broadinstitute.barclay.argparser.Advanced;
+import org.broadinstitute.barclay.argparser.Argument;
+import org.broadinstitute.barclay.argparser.BetaFeature;
+import org.broadinstitute.barclay.help.DocumentedFeature;
+import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
+
+import org.broadinstitute.hellbender.engine.ReadsContext;
+import org.broadinstitute.hellbender.engine.ReferenceContext;
+import org.broadinstitute.hellbender.exceptions.UserException;
+import org.broadinstitute.hellbender.engine.AbstractConcordanceWalker;
+
+import picard.cmdline.programgroups.VariantEvaluationProgramGroup;
+
+/**
+ * Compare INFO field values between two VCFs or compare two different INFO fields from one VCF.
+ * We only evaluate sites that are in both VCFs.
+ * Although the arguments are named eval and truth, the tool only compares the two scores; it does not determine which one is correct.
+ * Either VCF can be used as eval or truth, or the same VCF can be used for both.
+ * Differences greater than the epsilon argument trigger a warning when -warn-big-differences is set.
+ *
+ *
+ * Compare the CNN_2D info fields for the same sites from two different VCFs:
+ *
+ *
+ * gatk EvaluateInfoFieldConcordance \
+ * -eval a.vcf \
+ * -truth another.vcf \
+ * -S summary.txt \
+ * -eval-info-key CNN_2D \
+ * -truth-info-key CNN_2D \
+ * -epsilon 0.01
+ *
+ *
+ * Compare the CNN_2D info field with the CNN_1D field from the same sites in one VCF:
+ *
+ *
+ * gatk EvaluateInfoFieldConcordance \
+ * -eval my.vcf \
+ * -truth my.vcf \
+ * -S summary.txt \
+ * -eval-info-key CNN_2D \
+ * -truth-info-key CNN_1D \
+ * -epsilon 0.01
+ *
+ */
+@CommandLineProgramProperties(
+ summary=EvaluateInfoFieldConcordance.USAGE_SUMMARY,
+ oneLineSummary=EvaluateInfoFieldConcordance.USAGE_ONE_LINE_SUMMARY,
+ programGroup=VariantEvaluationProgramGroup.class)
+@DocumentedFeature
+@BetaFeature
+public class EvaluateInfoFieldConcordance extends AbstractConcordanceWalker {
+ static final String USAGE_ONE_LINE_SUMMARY = "Evaluate concordance of info fields in an input VCF against a validated truth VCF";
+ static final String USAGE_SUMMARY = "This tool evaluates info fields from an input VCF against a VCF that has been validated and is considered to represent ground truth.\n";
+ public static final String SUMMARY_LONG_NAME = "summary";
+ public static final String SUMMARY_SHORT_NAME = "S";
+
+ @Argument(doc="A table of summary statistics (true positives, sensitivity, etc.)", fullName=SUMMARY_LONG_NAME, shortName=SUMMARY_SHORT_NAME)
+ protected String summary;
+
+ @Argument(fullName="eval-info-key", shortName="eval-info-key", doc="Info key from eval vcf")
+ protected String evalInfoKey;
+
+ @Argument(fullName="truth-info-key", shortName="truth-info-key", doc="Info key from truth vcf")
+ protected String truthInfoKey;
+
+ @Advanced
+ @Argument(fullName="warn-big-differences",
+ shortName="warn-big-differences",
+ doc="If set differences in the info key values greater than epsilon will trigger warnings.",
+ optional=true)
+ protected boolean warnBigDifferences = false;
+
+ @Advanced
+ @Argument(fullName="epsilon", shortName="epsilon", doc="Difference tolerance", optional=true)
+ protected double epsilon = 0.1;
+
+ private int snpCount = 0;
+ private int indelCount = 0;
+
+ private double snpSumDelta = 0.0;
+ private double snpSumDeltaSquared = 0.0;
+ private double indelSumDelta = 0.0;
+ private double indelSumDeltaSquared = 0.0;
+
+ @Override
+ public void onTraversalStart() {
+ if(getEvalHeader().getInfoHeaderLine(evalInfoKey) == null){
+ throw new UserException("Missing key:"+evalInfoKey+" in Eval VCF:"+evalVariantsFile);
+ }
+
+ if(getTruthHeader().getInfoHeaderLine(truthInfoKey) == null){
+ throw new UserException("Missing key:"+truthInfoKey+" in Truth VCF:"+truthVariantsFile);
+ }
+ }
+
+ @Override
+ protected void apply(AbstractConcordanceWalker.TruthVersusEval truthVersusEval, ReadsContext readsContext, ReferenceContext refContext) {
+ ConcordanceState concordanceState = truthVersusEval.getConcordance();
+ switch (concordanceState) {
+ case TRUE_POSITIVE: {
+ if(truthVersusEval.getEval().isSNP()){
+ snpCount++;
+ } else if (truthVersusEval.getEval().isIndel()) {
+ indelCount++;
+ }
+ this.infoDifference(truthVersusEval.getEval(), truthVersusEval.getTruth());
+ break;
+ }
+ case FALSE_POSITIVE:
+ case FALSE_NEGATIVE:
+ case FILTERED_TRUE_NEGATIVE:
+ case FILTERED_FALSE_NEGATIVE: {
+ break;
+ }
+ default: {
+ throw new IllegalStateException("Unexpected ConcordanceState: " + concordanceState.toString());
+ }
+ }
+ }
+
+ private void infoDifference(final VariantContext eval, final VariantContext truth) {
+ if(eval.hasAttribute(this.evalInfoKey) && truth.hasAttribute(truthInfoKey)) {
+ final double evalVal = Double.valueOf((String) eval.getAttribute(this.evalInfoKey));
+ final double truthVal = Double.valueOf((String) truth.getAttribute(this.truthInfoKey));
+ final double delta = evalVal - truthVal;
+ final double deltaSquared = delta * delta;
+ if (eval.isSNP()) {
+ this.snpSumDelta += Math.sqrt(deltaSquared);
+ this.snpSumDeltaSquared += deltaSquared;
+ } else if (eval.isIndel()) {
+ this.indelSumDelta += Math.sqrt(deltaSquared);
+ this.indelSumDeltaSquared += deltaSquared;
+ }
+ if (warnBigDifferences && Math.abs(delta) > this.epsilon) {
+ this.logger.warn(String.format("Difference (%f) greater than epsilon (%f) at %s:%d %s:", delta, this.epsilon, eval.getContig(), eval.getStart(), eval.getAlleles().toString()));
+ this.logger.warn(String.format("\t\tTruth info: " + truth.getAttributes().toString()));
+ this.logger.warn(String.format("\t\tEval info: " + eval.getAttributes().toString()));
+ }
+ }
+ }
+
+ @Override
+ public Object onTraversalSuccess() {
+ final double snpMean = this.snpSumDelta / snpCount;
+ final double snpVariance = (this.snpSumDeltaSquared - this.snpSumDelta * this.snpSumDelta / snpCount) / snpCount;
+ final double snpStd = Math.sqrt(snpVariance);
+ final double indelMean = this.indelSumDelta / indelCount;
+ final double indelVariance = (this.indelSumDeltaSquared - this.indelSumDelta * this.indelSumDelta / indelCount) / indelCount;
+ final double indelStd = Math.sqrt(indelVariance);
+
+ this.logger.info(String.format("SNP average delta %f and standard deviation: %f", snpMean, snpStd));
+ this.logger.info(String.format("INDEL average delta %f and standard deviation: %f", indelMean, indelStd));
+
+ try (final InfoConcordanceRecord.InfoConcordanceWriter
+ concordanceWriter = InfoConcordanceRecord.getWriter(Paths.get(this.summary))){
+ concordanceWriter.writeRecord(new InfoConcordanceRecord(VariantContext.Type.SNP, this.evalInfoKey, this.truthInfoKey, snpMean, snpStd));
+ concordanceWriter.writeRecord(new InfoConcordanceRecord(VariantContext.Type.INDEL, this.evalInfoKey, this.truthInfoKey, indelMean, indelStd));
+ } catch (IOException e) {
+ throw new UserException("Encountered an IO exception writing the concordance summary table", e);
+ }
+
+ return "SUCCESS";
+ }
+
+ @Override
+ protected boolean areVariantsAtSameLocusConcordant(VariantContext truth, VariantContext eval) {
+ final boolean sameRefAllele = truth.getReference().equals(eval.getReference());
+ final boolean containsAltAllele = eval.getAlternateAlleles().contains(truth.getAlternateAllele(0));
+ return sameRefAllele && containsAltAllele;
+ }
+
+ @Override
+ protected Predicate<VariantContext> makeTruthVariantFilter() {
+ return vc -> !vc.isFiltered() && !vc.isSymbolicOrSV();
+ }
+
+ @Override
+ protected Predicate<VariantContext> makeEvalVariantFilter() {
+ return vc -> !vc.isFiltered() && !vc.isSymbolicOrSV();
+ }
+
+}
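
onTraversalSuccess() above derives its summary statistics from two running sums per variant type: the sum of absolute differences and the sum of squared differences, giving the mean and population standard deviation of |delta| in a single pass (note that when a type has no true positives, the divisions yield NaN). A minimal NumPy re-derivation of the same formulas, as an illustrative helper rather than part of the tool:

    import numpy as np

    def summarize(deltas):
        # Mirrors the snpSumDelta / snpSumDeltaSquared bookkeeping above.
        abs_deltas = np.abs(np.asarray(deltas, dtype=float))
        n = len(abs_deltas)
        sum_delta = abs_deltas.sum()
        sum_delta_squared = (abs_deltas ** 2).sum()
        mean = sum_delta / n
        variance = (sum_delta_squared - sum_delta * sum_delta / n) / n
        return mean, np.sqrt(variance)

    print(summarize([0.01, -0.02, 0.005]))  # mean and std of absolute deltas
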
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/validation/InfoConcordanceRecord.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/validation/InfoConcordanceRecord.java
new file mode 100644
index 00000000000..3b704cd4d06
--- /dev/null
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/validation/InfoConcordanceRecord.java
@@ -0,0 +1,140 @@
+package org.broadinstitute.hellbender.tools.walkers.validation;
+
+import htsjdk.variant.variantcontext.VariantContext;
+
+import java.io.IOException;
+import java.nio.file.Path;
+
+import org.broadinstitute.hellbender.exceptions.UserException;
+
+import org.broadinstitute.hellbender.utils.tsv.DataLine;
+import org.broadinstitute.hellbender.utils.tsv.TableColumnCollection;
+import org.broadinstitute.hellbender.utils.tsv.TableWriter;
+import org.broadinstitute.hellbender.utils.tsv.TableReader;
+
+/**
+ * Keeps track of concordance between two info fields.
+ */
+public class InfoConcordanceRecord {
+ private static final String VARIANT_TYPE_COLUMN_NAME = "type";
+ private static final String EVAL_INFO_KEY = "eval_info_key";
+ private static final String TRUE_INFO_KEY = "true_info_key";
+ private static final String MEAN_DIFFERENCE = "mean_difference";
+ private static final String STD_DIFFERENCE = "std_difference";
+ private static final String[] INFO_CONCORDANCE_COLUMN_HEADER =
+ {VARIANT_TYPE_COLUMN_NAME, EVAL_INFO_KEY, TRUE_INFO_KEY, MEAN_DIFFERENCE, STD_DIFFERENCE};
+ final VariantContext.Type type;
+ private final String evalKey;
+ private final String trueKey;
+ private final double mean;
+ private final double std;
+
+ /**
+ * Record keeps track of concordance between values from INFO-field keys of a VCF.
+ *
+ * @param type SNP or INDEL
+ * @param evalKey The INFO field key from the eval VCF
+ * @param trueKey The INFO field key from the truth VCF
+ * @param mean The mean of the differences in values for these INFO fields.
+ * @param std The standard deviation of the differences in values for these INFO fields.
+ */
+ public InfoConcordanceRecord(VariantContext.Type type, String evalKey, String trueKey, double mean, double std) {
+ this.type = type;
+ this.evalKey = evalKey;
+ this.trueKey = trueKey;
+ this.mean = mean;
+ this.std = std;
+ }
+
+ /**
+ *
+ * @return Variant type (e.g. SNP or INDEL)
+ */
+ public VariantContext.Type getVariantType() {
+ return this.type;
+ }
+
+ /**
+ *
+ * @return The mean of the differences between two INFO fields
+ */
+ public double getMean() {
+ return this.mean;
+ }
+
+ /**
+ *
+ * @return The Standard Deviation of the differences between two INFO fields
+ */
+ public double getStd() {
+ return this.std;
+ }
+
+ /**
+ *
+ * @return The INFO field for the eval VCF
+ */
+ public String getEvalKey() {
+ return this.evalKey;
+ }
+
+ /**
+ *
+ * @return The INFO field for the truth VCF
+ */
+ public String getTrueKey() {
+ return this.trueKey;
+ }
+
+ /**
+ * Get a table writer
+ * @param outputTable A Path where the output table will be written
+ * @return A Table writer for INFO field concordances
+ */
+ public static InfoConcordanceWriter getWriter(Path outputTable) {
+ try {
+ InfoConcordanceWriter writer = new InfoConcordanceWriter(outputTable);
+ return writer;
+ }
+ catch (IOException e) {
+ throw new UserException(String.format("Encountered an IO exception while writing from %s.", outputTable), e);
+ }
+ }
+
+ /**
+ * Table writing class for InfoConcordanceRecords
+ */
+ public static class InfoConcordanceWriter extends TableWriter<InfoConcordanceRecord> {
+ private InfoConcordanceWriter(Path output) throws IOException {
+ super(output.toFile(), new TableColumnCollection(INFO_CONCORDANCE_COLUMN_HEADER));
+ }
+
+ @Override
+ protected void composeLine(InfoConcordanceRecord record, DataLine dataLine) {
+ dataLine.set(VARIANT_TYPE_COLUMN_NAME, record.getVariantType().toString())
+ .set(EVAL_INFO_KEY, record.getEvalKey())
+ .set(TRUE_INFO_KEY, record.getTrueKey())
+ .set(MEAN_DIFFERENCE, record.getMean())
+ .set(STD_DIFFERENCE, record.getStd());
+ }
+ }
+
+ /**
+ * Table reading class for InfoConcordanceRecords
+ */
+ public static class InfoConcordanceReader extends TableReader<InfoConcordanceRecord> {
+ public InfoConcordanceReader(Path summary) throws IOException {
+ super(summary.toFile());
+ }
+
+ @Override
+ protected InfoConcordanceRecord createRecord(DataLine dataLine) {
+ VariantContext.Type type = VariantContext.Type.valueOf(dataLine.get(VARIANT_TYPE_COLUMN_NAME));
+ String evalKey = dataLine.get(EVAL_INFO_KEY);
+ String trueKey = dataLine.get(TRUE_INFO_KEY);
+ double mean = Double.parseDouble(dataLine.get(MEAN_DIFFERENCE));
+ double std = Double.parseDouble(dataLine.get(STD_DIFFERENCE));
+ return new InfoConcordanceRecord(type, evalKey, trueKey, mean, std);
+ }
+ }
+}
\ No newline at end of file
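
For reference, InfoConcordanceWriter emits one tab-separated record per variant type under the five-column header defined by INFO_CONCORDANCE_COLUMN_HEADER, so a summary file should look roughly like this (values hypothetical):

    type    eval_info_key    true_info_key    mean_difference    std_difference
    SNP     CNN_2D           CNN_2D           0.001250           0.004371
    INDEL   CNN_2D           CNN_2D           0.003800           0.009125
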
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java
index be4e0c34ebb..1a114aea907 100644
--- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java
@@ -24,7 +24,6 @@
import java.io.*;
import java.util.*;
-import java.util.stream.StreamSupport;
/**
@@ -92,15 +91,15 @@
* -weights path/to/my_weights.hd5
*
*/
-@DocumentedFeature
@ExperimentalFeature
+@DocumentedFeature
@CommandLineProgramProperties(
summary = CNNScoreVariants.USAGE_SUMMARY,
oneLineSummary = CNNScoreVariants.USAGE_ONE_LINE_SUMMARY,
programGroup = VariantFilteringProgramGroup.class
)
-public class CNNScoreVariants extends VariantWalker {
+public class CNNScoreVariants extends TwoPassVariantWalker {
private final static String NL = String.format("%n");
static final String USAGE_ONE_LINE_SUMMARY = "Apply a Convolutional Neural Net to filter annotated variants";
static final String USAGE_SUMMARY = "Annotate a VCF with scores from a Convolutional Neural Network (CNN)." +
@@ -116,7 +115,6 @@ public class CNNScoreVariants extends VariantWalker {
private static final int ALT_INDEX = 3;
private static final int KEY_INDEX = 4;
private static final int FIFO_STRING_INITIAL_CAPACITY = 1024;
- private static final int MAX_READ_BATCH = 4098;
@Argument(fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME,
shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME,
@@ -138,6 +136,10 @@ public class CNNScoreVariants extends VariantWalker {
@Argument(fullName = "filter-symbolic-and-sv", shortName = "filter-symbolic-and-sv", doc = "If set will filter symbolic and and structural variants from the input VCF", optional = true)
private boolean filterSymbolicAndSV = false;
+ @Advanced
+ @Argument(fullName="info-annotation-keys", shortName="info-annotation-keys", doc="The VCF info fields to send to python.", optional=true)
+ private List<String> annotationKeys = new ArrayList<>(Arrays.asList("MQ", "DP", "SOR", "FS", "QD", "MQRankSum", "ReadPosRankSum"));
+
@Advanced
@Argument(fullName = "inference-batch-size", shortName = "inference-batch-size", doc = "Size of batches for python to do inference on.", minValue = 1, maxValue = 4096, optional = true)
private int inferenceBatchSize = 256;
@@ -180,9 +182,11 @@ public class CNNScoreVariants extends VariantWalker {
private int windowEnd = windowSize / 2;
private int windowStart = windowSize / 2;
private boolean waitforBatchCompletion = false;
- private File scoreFile;
+ private File scoreFile;
private String scoreKey;
+ private Scanner scoreScan;
+ private VariantContextWriter vcfWriter;
private static String resourcePathReadTensor = Resource.LARGE_RUNTIME_RESOURCES_PATH + "/cnn_score_variants/small_2d.json";
private static String resourcePathReferenceTensor = Resource.LARGE_RUNTIME_RESOURCES_PATH + "/cnn_score_variants/1d_cnn_mix_train_full_bn.json";
@@ -228,15 +232,14 @@ public List<ReadFilter> getDefaultReadFilters() {
@Override
public void onTraversalStart() {
- scoreKey = getScoreKeyAndCheckModelAndReadsHarmony();
- if (architecture == null && weights == null) {
- setArchitectureAndWeightsFromResources();
+ if (getHeaderForVariants().getGenotypeSamples().size() > 1) {
+ logger.warn("CNNScoreVariants is a single sample tool, but the input VCF has more than 1 sample.");
}
// Start the Python process and initialize a stream writer for streaming data to the Python code
pythonExecutor.start(Collections.emptyList(), enableJournal, pythonProfileResults);
-
pythonExecutor.initStreamWriter(AsynchronousStreamWriter.stringSerializer);
+
batchList = new ArrayList<>(transferBatchSize);
// Execute Python code to open our output file, where it will write the contents of everything it reads
@@ -248,26 +251,11 @@ public void onTraversalStart() {
} else {
logger.info("Saving temp file from python:" + scoreFile.getAbsolutePath());
}
-
- pythonExecutor.sendSynchronousCommand("from keras import backend" + NL);
- pythonExecutor.sendSynchronousCommand(String.format("backend.set_session(backend.tf.Session(config=backend.tf.ConfigProto(intra_op_parallelism_threads=%d, inter_op_parallelism_threads=%d)))" + NL, intraOpThreads, interOpThreads));
-
pythonExecutor.sendSynchronousCommand(String.format("tempFile = open('%s', 'w+')" + NL, scoreFile.getAbsolutePath()));
pythonExecutor.sendSynchronousCommand("import vqsr_cnn" + NL);
- String getArgsAndModel;
- if (weights != null && architecture != null) {
- getArgsAndModel = String.format("args, model = vqsr_cnn.args_and_model_from_semantics('%s', weights_hd5='%s')", architecture, weights) + NL;
- logger.info("Using key:" + scoreKey + " for CNN architecture:" + architecture + " and weights:" + weights);
- } else if (architecture == null) {
- getArgsAndModel = String.format("args, model = vqsr_cnn.args_and_model_from_semantics(None, weights_hd5='%s', tensor_type='%s')", weights, tensorType.name()) + NL;
- logger.info("Using key:" + scoreKey + " for CNN weights:" + weights);
- } else {
- getArgsAndModel = String.format("args, model = vqsr_cnn.args_and_model_from_semantics('%s')", architecture) + NL;
- logger.info("Using key:" + scoreKey + " for CNN architecture:" + architecture);
- }
- pythonExecutor.sendSynchronousCommand(getArgsAndModel);
-
+ scoreKey = getScoreKeyAndCheckModelAndReadsHarmony();
+ initializePythonArgsAndModel();
} catch (IOException e) {
throw new GATKException("Error when creating temp file and initializing python executor.", e);
}
@@ -275,7 +263,7 @@ public void onTraversalStart() {
}
@Override
- public void apply(final VariantContext variant, final ReadsContext readsContext, final ReferenceContext referenceContext, final FeatureContext featureContext) {
+ public void firstPassApply(final VariantContext variant, final ReadsContext readsContext, final ReferenceContext referenceContext, final FeatureContext featureContext) {
referenceContext.setWindow(windowStart, windowEnd);
if (tensorType.isReadsRequired()) {
transferReadsToPythonViaFifo(variant, readsContext, referenceContext);
@@ -285,6 +273,64 @@ public void apply(final VariantContext variant, final ReadsContext readsContext,
sendBatchIfReady();
}
+ @Override
+ public void afterFirstPass() {
+ if (waitforBatchCompletion) {
+ pythonExecutor.waitForPreviousBatchCompletion();
+ }
+ if (curBatchSize > 0) {
+ executePythonCommand();
+ pythonExecutor.waitForPreviousBatchCompletion();
+ }
+
+ pythonExecutor.sendSynchronousCommand("tempFile.close()" + NL);
+ pythonExecutor.terminate();
+
+ try {
+ scoreScan = new Scanner(scoreFile);
+ vcfWriter = createVCFWriter(new File(outputFile));
+ scoreScan.useDelimiter("\\n");
+ writeVCFHeader(vcfWriter);
+ } catch (IOException e) {
+ throw new GATKException("Error when trying to temporary score file scanner.", e);
+ }
+
+ }
+
+ @Override
+ protected void secondPassApply(VariantContext variant, ReadsContext readsContext, ReferenceContext referenceContext, FeatureContext featureContext) {
+ String sv = scoreScan.nextLine();
+ String[] scoredVariant = sv.split("\\t");
+
+ if (variant.getContig().equals(scoredVariant[CONTIG_INDEX])
+ && Integer.toString(variant.getStart()).equals(scoredVariant[POS_INDEX])
+ && variant.getReference().getBaseString().equals(scoredVariant[REF_INDEX])
+ && variant.getAlternateAlleles().toString().equals(scoredVariant[ALT_INDEX])) {
+
+ final VariantContextBuilder builder = new VariantContextBuilder(variant);
+ if (scoredVariant.length > KEY_INDEX) {
+ builder.attribute(scoreKey, scoredVariant[KEY_INDEX]);
+ }
+ vcfWriter.add(builder.make());
+
+ } else {
+ String errorMsg = "Score file out of sync with original VCF. Score file has:" + sv;
+ errorMsg += "\n But VCF has:" + variant.toStringWithoutGenotypes();
+ throw new GATKException(errorMsg);
+ }
+ }
+
+ @Override
+ public void closeTool() {
+ logger.info("Done scoring variants with CNN.");
+ if (vcfWriter != null) {
+ vcfWriter.close();
+ }
+ if (scoreScan != null){
+ scoreScan.close();
+ }
+ }
+
private void transferToPythonViaFifo(final VariantContext variant, final ReferenceContext referenceContext) {
try {
final String outDat = String.format("%s\t%s\t%s\t%s\n",
@@ -329,6 +375,7 @@ private void transferReadsToPythonViaFifo(final VariantContext variant, final Re
if (!readIt.hasNext()) {
logger.warn("No reads at contig:" + variant.getContig() + " site:" + String.valueOf(variant.getStart()));
}
+
while (readIt.hasNext()) {
sb.append(GATKReadToString(readIt.next()));
}
@@ -374,29 +421,15 @@ private String getVariantDataString(final VariantContext variant) {
private String getVariantInfoString(final VariantContext variant) {
// Create a string that will easily be parsed as a python dictionary
- String varInfo = "";
- for (final String attributeKey : variant.getAttributes().keySet()) {
- varInfo += attributeKey + "=" + variant.getAttribute(attributeKey).toString().replace(" ", "").replace("[", "").replace("]", "") + ";";
- }
- return varInfo;
- }
-
- @Override
- public Object onTraversalSuccess() {
- if (waitforBatchCompletion) {
- pythonExecutor.waitForPreviousBatchCompletion();
- }
- if (curBatchSize > 0) {
- executePythonCommand();
- pythonExecutor.waitForPreviousBatchCompletion();
+ StringBuilder sb = new StringBuilder(FIFO_STRING_INITIAL_CAPACITY);
+ for (final String attributeKey : annotationKeys) {
+ if (variant.hasAttribute(attributeKey)) {
+ sb.append(attributeKey);
+ sb.append("=");
+ sb.append(variant.getAttribute(attributeKey).toString().replace(" ", "").replace("[", "").replace("]", "")).append(";");
+ }
}
-
- pythonExecutor.sendSynchronousCommand("tempFile.close()" + NL);
- pythonExecutor.terminate();
-
- writeOutputVCFWithScores();
-
- return true;
+ return sb.toString();
}
private void executePythonCommand() {
@@ -408,42 +441,6 @@ private void executePythonCommand() {
pythonExecutor.startBatchWrite(pythonCommand, batchList);
}
-
- private void writeOutputVCFWithScores() {
- try (final Scanner scoreScan = new Scanner(scoreFile);
- final VariantContextWriter vcfWriter = createVCFWriter(new File(outputFile))) {
- scoreScan.useDelimiter("\\n");
- writeVCFHeader(vcfWriter);
- final VariantFilter variantfilter = makeVariantFilter();
-
- // Annotate each variant in the input stream, as in variantWalkerBase.traverse()
- StreamSupport.stream(getSpliteratorForDrivingVariants(), false)
- .filter(variantfilter)
- .forEach(variant -> {
- String sv = scoreScan.nextLine();
- String[] scoredVariant = sv.split("\\t");
- if (variant.getContig().equals(scoredVariant[CONTIG_INDEX])
- && Integer.toString(variant.getStart()).equals(scoredVariant[POS_INDEX])
- && variant.getReference().getBaseString().equals(scoredVariant[REF_INDEX])
- && variant.getAlternateAlleles().toString().equals(scoredVariant[ALT_INDEX])) {
- final VariantContextBuilder builder = new VariantContextBuilder(variant);
- if (scoredVariant.length > KEY_INDEX) {
- builder.attribute(scoreKey, scoredVariant[KEY_INDEX]);
- }
- vcfWriter.add(builder.make());
- } else {
- String errorMsg = "Score file out of sync with original VCF. Score file has:" + sv;
- errorMsg += "\n But VCF has:" + variant.toStringWithoutGenotypes();
- throw new GATKException(errorMsg);
- }
- });
-
- } catch (IOException e) {
- throw new GATKException("Error when trying to write annotated VCF.", e);
- }
-
- }
-
private void writeVCFHeader(VariantContextWriter vcfWriter) {
// setup the header fields
final VCFHeader inputHeader = getHeaderForVariants();
@@ -471,20 +468,33 @@ private String getScoreKeyAndCheckModelAndReadsHarmony() {
}
}
- private void setArchitectureAndWeightsFromResources() {
- if (tensorType.equals(TensorType.read_tensor)) {
- architecture = IOUtils.writeTempResourceFromPath(resourcePathReadTensor, null).getAbsolutePath();
- weights = IOUtils.writeTempResourceFromPath(
- resourcePathReadTensor.replace(".json", ".hd5"),
- null).getAbsolutePath();
- } else if (tensorType.equals(TensorType.reference)) {
- architecture = IOUtils.writeTempResourceFromPath(resourcePathReferenceTensor, null).getAbsolutePath();
- weights = IOUtils.writeTempResourceFromPath(
- resourcePathReferenceTensor.replace(".json", ".hd5"), null).getAbsolutePath();
+ private void initializePythonArgsAndModel(){
+ if (weights == null && architecture == null) {
+ if (tensorType.equals(TensorType.read_tensor)) {
+ architecture = IOUtils.writeTempResourceFromPath(resourcePathReadTensor, null).getAbsolutePath();
+ weights = IOUtils.writeTempResourceFromPath(
+ resourcePathReadTensor.replace(".json", ".hd5"),
+ null).getAbsolutePath();
+ } else if (tensorType.equals(TensorType.reference)) {
+ architecture = IOUtils.writeTempResourceFromPath(resourcePathReferenceTensor, null).getAbsolutePath();
+ weights = IOUtils.writeTempResourceFromPath(
+ resourcePathReferenceTensor.replace(".json", ".hd5"), null).getAbsolutePath();
+ } else {
+ throw new GATKException("No default architecture for tensor type:" + tensorType.name());
+ }
+ }
+
+ String getArgsAndModel;
+ if (weights != null && architecture != null) {
+ getArgsAndModel = String.format("args, model = vqsr_cnn.start_session_get_args_and_model(%d, %d, '%s', weights_hd5='%s')", intraOpThreads, interOpThreads, architecture, weights) + NL;
+ logger.info("Using key:" + scoreKey + " for CNN architecture:" + architecture + " and weights:" + weights);
+ } else if (architecture == null) {
+ getArgsAndModel = String.format("args, model = vqsr_cnn.start_session_get_args_and_model(%d, %d, None, weights_hd5='%s', tensor_type='%s')", intraOpThreads, interOpThreads, weights, tensorType.name()) + NL;
+ logger.info("Using key:" + scoreKey + " for CNN weights:" + weights);
} else {
- throw new GATKException("No default architecture for tensor type:" + tensorType.name());
+ getArgsAndModel = String.format("args, model = vqsr_cnn.start_session_get_args_and_model(%d, %d, '%s')", intraOpThreads, interOpThreads, architecture) + NL;
+ logger.info("Using key:" + scoreKey + " for CNN architecture:" + architecture);
}
+ pythonExecutor.sendSynchronousCommand(getArgsAndModel);
}
-
}
-
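
The move from VariantWalker to TwoPassVariantWalker splits scoring in two: the first pass streams variant (and, for 2D models, read) data to the Python process over a FIFO; afterFirstPass() flushes the last batch, shuts Python down, and opens the temporary score file; the second pass walks the same variants in order and zips them against the score lines. A sketch of that pairing in Python, with a hypothetical variant record type (it mirrors secondPassApply above, and is not GATK code):

    def annotate_with_scores(variants, score_lines, score_key):
        for variant, line in zip(variants, score_lines):
            fields = line.rstrip('\n').split('\t')
            contig, pos, ref, alts = fields[:4]
            # The driving VCF and the score file must stay in lockstep.
            if (variant.contig, str(variant.start), variant.ref, str(variant.alts)) \
                    != (contig, pos, ref, alts):
                raise RuntimeError('Score file out of sync with original VCF: ' + line)
            if len(fields) > 4:  # KEY_INDEX == 4 above
                variant.info[score_key] = fields[4]
            yield variant
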
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantTrain.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantTrain.java
index 57127447a0c..bde150764d0 100644
--- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantTrain.java
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantTrain.java
@@ -13,6 +13,7 @@
import java.util.Arrays;
import java.util.List;
+
/**
* Train a Convolutional Neural Network (CNN) for filtering variants.
* This tool requires training data generated by {@link CNNVariantWriteTensors}.
@@ -37,17 +38,17 @@
* Train a 1D CNN on Reference Tensors
*
* gatk CNNVariantTrain \
- * --tensor-type reference \
- * --input-tensor-dir my_tensor_folder \
- * --model-name my_1d_model
+ * -tensor-type reference \
+ * -input-tensor-dir my_tensor_folder \
+ * -model-name my_1d_model
*
*
* Train a 2D CNN on Read Tensors
*
* gatk CNNVariantTrain \
- * --input-tensor-dir my_tensor_folder \
- * --tensor-type read-tensor \
- * --model-name my_2d_model
+ * -input-tensor-dir my_tensor_folder \
+ * -tensor-type read-tensor \
+ * -model-name my_2d_model
*
*
*/
@@ -66,7 +67,7 @@ public class CNNVariantTrain extends CommandLineProgram {
@Argument(fullName = "output-dir", shortName = "output-dir", doc = "Directory where models will be saved, defaults to current working directory.", optional = true)
private String outputDir = "./";
- @Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Name of the tensors to generate, reference for 1D reference tensors and read_tensor for 2D tensors.", optional = true)
+ @Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Type of tensors to use as input reference for 1D reference tensors and read_tensor for 2D tensors.", optional = true)
private TensorType tensorType = TensorType.reference;
@Argument(fullName = "model-name", shortName = "model-name", doc = "Name of the model to be trained.", optional = true)
@@ -84,6 +85,42 @@ public class CNNVariantTrain extends CommandLineProgram {
@Argument(fullName = "image-dir", shortName = "image-dir", doc = "Path where plots and figures are saved.", optional = true)
private String imageDir;
+ @Argument(fullName = "conv-width", shortName = "conv-width", doc = "Width of convolution kernels", optional = true)
+ private int convWidth = 5;
+
+ @Argument(fullName = "conv-height", shortName = "conv-height", doc = "Height of convolution kernels", optional = true)
+ private int convHeight = 5;
+
+ @Argument(fullName = "conv-dropout", shortName = "conv-dropout", doc = "Dropout rate in convolution layers", optional = true)
+ private float convDropout = 0.0f;
+
+ @Argument(fullName = "conv-batch-normalize", shortName = "conv-batch-normalize", doc = "Batch normalize convolution layers", optional = true)
+ private boolean convBatchNormalize = false;
+
+ @Argument(fullName = "conv-layers", shortName = "conv-layers", doc = "List of number of filters to use in each convolutional layer", optional = true)
+ private List<Integer> convLayers = new ArrayList<>();
+
+ @Argument(fullName = "padding", shortName = "padding", doc = "Padding for convolution layers, valid or same", optional = true)
+ private String padding = "valid";
+
+ @Argument(fullName = "spatial-dropout", shortName = "spatial-dropout", doc = "Spatial dropout on convolution layers", optional = true)
+ private boolean spatialDropout = false;
+
+ @Argument(fullName = "fc-layers", shortName = "fc-layers", doc = "List of number of filters to use in each fully-connected layer", optional = true)
+ private List<Integer> fcLayers = new ArrayList<>();
+
+ @Argument(fullName = "fc-dropout", shortName = "fc-dropout", doc = "Dropout rate in fully-connected layers", optional = true)
+ private float fcDropout = 0.0f;
+
+ @Argument(fullName = "fc-batch-normalize", shortName = "fc-batch-normalize", doc = "Batch normalize fully-connected layers", optional = true)
+ private boolean fcBatchNormalize = false;
+
+ @Argument(fullName = "annotation-units", shortName = "annotation-units", doc = "Number of units connected to the annotation input layer", optional = true)
+ private int annotationUnits = 16;
+
+ @Argument(fullName = "annotation-shortcut", shortName = "annotation-shortcut", doc = "Shortcut connections on the annotation layers.", optional = true)
+ private boolean annotationShortcut = false;
+
@Advanced
@Argument(fullName = "channels-last", shortName = "channels-last", doc = "Store the channels in the last axis of tensors, tensorflow->true, theano->false", optional = true)
private boolean channelsLast = true;
@@ -109,11 +146,18 @@ protected Object doWork() {
"--output_dir", outputDir,
"--tensor_name", tensorType.name(),
"--annotation_set", annotationSet,
+ "--conv_width", Integer.toString(convWidth),
+ "--conv_height", Integer.toString(convHeight),
+ "--conv_dropout", Float.toString(convDropout),
+ "--padding", padding,
+ "--fc_dropout", Float.toString(fcDropout),
+ "--annotation_units", Integer.toString(annotationUnits),
"--epochs", Integer.toString(epochs),
"--training_steps", Integer.toString(trainingSteps),
"--validation_steps", Integer.toString(validationSteps),
"--id", modelName));
+ // Add boolean arguments
if(channelsLast){
arguments.add("--channels_last");
} else {
@@ -124,12 +168,45 @@ protected Object doWork() {
arguments.addAll(Arrays.asList("--image_dir", imageDir));
}
- if (tensorType == TensorType.reference) {
- arguments.addAll(Arrays.asList("--mode", "train_on_reference_tensors_and_annotations"));
- } else if (tensorType == TensorType.read_tensor) {
- arguments.addAll(Arrays.asList("--mode", "train_small_model_on_read_tensors_and_annotations"));
- } else {
- throw new GATKException("Unknown tensor mapping mode:"+ tensorType.name());
+ if (convLayers.size() == 0 && fcLayers.size() == 0){
+ if (tensorType == TensorType.reference) {
+ arguments.addAll(Arrays.asList("--mode", "train_default_1d_model"));
+ } else if (tensorType == TensorType.read_tensor) {
+ arguments.addAll(Arrays.asList("--mode", "train_default_2d_model"));
+ } else {
+ throw new GATKException("Unknown tensor mapping mode:"+ tensorType.name());
+ }
+ } else { // Command line specified custom architecture
+ if(convBatchNormalize){
+ arguments.add("--conv_batch_normalize");
+ }
+ if(fcBatchNormalize){
+ arguments.add("--fc_batch_normalize");
+ }
+ if(spatialDropout){
+ arguments.add("--spatial_dropout");
+ }
+ if(annotationShortcut){
+ arguments.add("--annotation_shortcut");
+ }
+
+ // Add list arguments
+ arguments.add("--conv_layers");
+ for(Integer cl : convLayers){
+ arguments.add(Integer.toString(cl));
+ }
+ arguments.add("--fc_layers");
+ for(Integer fl : fcLayers){
+ arguments.add(Integer.toString(fl));
+ }
+
+ if (tensorType == TensorType.reference) {
+ arguments.addAll(Arrays.asList("--mode", "train_args_model_on_reference_and_annotations"));
+ } else if (tensorType == TensorType.read_tensor) {
+ arguments.addAll(Arrays.asList("--mode", "train_args_model_on_read_tensors_and_annotations"));
+ } else {
+ throw new GATKException("Unknown tensor mapping mode:"+ tensorType.name());
+ }
}
logger.info("Args are:"+ Arrays.toString(arguments.toArray()));
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantWriteTensors.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantWriteTensors.java
index fd1cd1fc390..be7ba2b60be 100644
--- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantWriteTensors.java
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantWriteTensors.java
@@ -47,10 +47,10 @@
* gatk CNNVariantWriteTensors \
* -R reference.fasta \
* -V input.vcf.gz \
- * --truth-vcf platinum-genomes.vcf \
- * --truth-bed platinum-confident-region.bed \
- * --tensor-type reference \
- * --output-tensor-dir my-tensor-folder
+ * -truth-vcf platinum-genomes.vcf \
+ * -truth-bed platinum-confident-region.bed \
+ * -tensor-type reference \
+ * -output-tensor-dir my-tensor-folder
*
*
* Write Read Tensors
@@ -58,11 +58,11 @@
* gatk CNNVariantWriteTensors \
* -R reference.fasta \
* -V input.vcf.gz \
- * --truth-vcf platinum-genomes.vcf \
- * --truth-bed platinum-confident-region.bed \
- * --tensor-type read_tensor \
- * --bam-file input.bam \
- * --output-tensor-dir my-tensor-folder
+ * -truth-vcf platinum-genomes.vcf \
+ * -truth-bed platinum-confident-region.bed \
+ * -tensor-type read_tensor \
+ * -bam-file input.bam \
+ * -output-tensor-dir my-tensor-folder
*
*
*/
@@ -100,6 +100,12 @@ public class CNNVariantWriteTensors extends CommandLineProgram {
@Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Name of the tensors to generate.")
private TensorType tensorType = TensorType.reference;
+ @Argument(fullName = "downsample-snps", shortName = "downsample-snps", doc = "Fraction of SNPs to write tensors for.", optional = true)
+ private float downsampleSnps = 0.05f;
+
+ @Argument(fullName = "downsample-indels", shortName = "downsample-indels", doc = "Fraction of INDELs to write tensors for.", optional = true)
+ private float downsampleIndels = 0.5f;
+
@Advanced
@Argument(fullName = "channels-last", shortName = "channels-last", doc = "Store the channels in the last axis of tensors, tensorflow->true, theano->false", optional = true)
private boolean channelsLast = true;
@@ -131,6 +137,8 @@ protected Object doWork() {
"--tensor_name", tensorType.name(),
"--annotation_set", annotationSet,
"--samples", Integer.toString(maxTensors),
+ "--downsample_snps", Float.toString(downsampleSnps),
+ "--downsample_indels", Float.toString(downsampleIndels),
"--data_dir", outputTensorsDir));
if(channelsLast){
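
The new downsampling fractions keep tensor writing tractable on whole-genome training sets, where SNPs vastly outnumber INDELs; keeping 5% of SNPs and 50% of INDELs roughly rebalances the two classes. A sketch of fraction-based downsampling as the Python side can be expected to apply it (hypothetical helper):

    import random

    def keep_variant(is_snp, downsample_snps=0.05, downsample_indels=0.5):
        # Bernoulli downsampling by variant type (illustrative only).
        fraction = downsample_snps if is_snp else downsample_indels
        return random.random() < fraction
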
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/FilterVariantTranches.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/FilterVariantTranches.java
index 1788d7790cb..8f83b5b2482 100644
--- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/FilterVariantTranches.java
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/FilterVariantTranches.java
@@ -176,9 +176,11 @@ public void afterFirstPass() {
@Override
protected void secondPassApply(VariantContext variant, ReadsContext readsContext, ReferenceContext referenceContext, FeatureContext featureContext) {
final VariantContextBuilder builder = new VariantContextBuilder(variant);
+
if (removeOldFilters) {
builder.unfiltered();
}
+
if (variant.hasAttribute(infoKey)) {
final double score = Double.parseDouble((String) variant.getAttribute(infoKey));
if (variant.isSNP() && isTrancheFiltered(score, snpCutoffs)) {
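
For context, secondPassApply parses the CNN score back out of the INFO field and checks it against the per-type tranche cutoffs computed in the first pass. A sketch of that check, under the assumption that a variant is filtered when its score falls below a cutoff (hypothetical helper; the real cutoff bookkeeping lives in FilterVariantTranches):

    def is_tranche_filtered(score, cutoffs):
        # Illustrative only: filtered when the score falls below any tranche cutoff.
        return any(score < cutoff for cutoff in cutoffs)

    snp_cutoffs = [0.95, 2.5]  # hypothetical cutoffs for two tranches
    print(is_tranche_filtered(1.3, snp_cutoffs))  # True
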
diff --git a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/__init__.py b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/__init__.py
index 41ed768c90b..ece27f8dc1e 100644
--- a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/__init__.py
+++ b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/__init__.py
@@ -1,5 +1,6 @@
-from .vqsr_cnn.models import build_read_tensor_2d_and_annotations_model, build_tiny_2d_annotation_model, build_reference_annotation_model
-from .vqsr_cnn.models import args_and_model_from_semantics, train_model_from_generators, build_small_2d_annotation_model
+from .vqsr_cnn.models import build_2d_annotation_model_from_args, build_1d_annotation_model_from_args
+from .vqsr_cnn.models import build_default_1d_annotation_model, build_default_2d_annotation_model
+from .vqsr_cnn.models import start_session_get_args_and_model, train_model_from_generators
from .vqsr_cnn.tensor_maps import get_tensor_channel_map_from_args, tensor_shape_from_args
from .vqsr_cnn.arguments import parse_args, weight_path_from_args, annotations_from_args
from .vqsr_cnn.inference import score_and_write_batch
diff --git a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/__init__.py b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/__init__.py
index 247587826d9..c89424a56ad 100644
--- a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/__init__.py
+++ b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/__init__.py
@@ -1,6 +1,6 @@
-from .models import build_read_tensor_2d_and_annotations_model, build_tiny_2d_annotation_model
-from .models import args_and_model_from_semantics, build_small_2d_annotation_model
-from .models import build_reference_annotation_model, train_model_from_generators
+from .models import build_2d_annotation_model_from_args, build_1d_annotation_model_from_args
+from .models import build_default_1d_annotation_model, build_default_2d_annotation_model
+from .models import start_session_get_args_and_model, train_model_from_generators
from .tensor_maps import get_tensor_channel_map_from_args, tensor_shape_from_args
from .arguments import parse_args, weight_path_from_args, annotations_from_args
from .inference import score_and_write_batch
diff --git a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/arguments.py b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/arguments.py
index 6ab1bf15f33..f985626269c 100644
--- a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/arguments.py
+++ b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/arguments.py
@@ -31,7 +31,6 @@ def parse_args():
help='Key which maps to an input symbol to index mapping.')
parser.add_argument('--input_symbols', help='Dict mapping input symbols to their index within input tensors, '
+ 'initialised via input_symbols_set argument')
-
parser.add_argument('--batch_size', default=32, type=int,
help='Mini batch size for stochastic gradient descent algorithms.')
parser.add_argument('--read_limit', default=128, type=int,
@@ -73,6 +72,7 @@ def parse_args():
help='Whether to skip positive examples when writing tensors.')
parser.add_argument('--chrom', help='Chromosome to load for parallel tensor writing.')
+
# I/O files and directories: vcfs, bams, beds, hd5, fasta
parser.add_argument('--output_dir', default='./', help='Directory to write models or other data out.')
parser.add_argument('--image_dir', default=None, help='Directory to write images and plots to.')
@@ -111,6 +111,32 @@ def parse_args():
parser.add_argument('--tensor_board', default=False, action='store_true',
help='Add the tensor board callback.')
+ # Architecture defining arguments
+ parser.add_argument('--conv_width', default=5, type=int, help='Width of convolutional kernels.')
+ parser.add_argument('--conv_height', default=5, type=int, help='Height of convolutional kernels.')
+ parser.add_argument('--conv_dropout', default=0.0, type=float,
+ help='Dropout rate in convolutional layers.')
+ parser.add_argument('--conv_batch_normalize', default=False, action='store_true',
+ help='Batch normalize convolutional layers.')
+ parser.add_argument('--conv_layers', nargs='+', default=[128, 96, 64, 48], type=int,
+ help='List of sizes for each convolutional filter layer.')
+ parser.add_argument('--padding', default='valid', choices=['valid', 'same'],
+ help='Valid or same border padding for convolutional layers.')
+ parser.add_argument('--spatial_dropout', default=False, action='store_true',
+ help='Spatial dropout on the convolutional layers.')
+ parser.add_argument('--max_pools', nargs='+', default=[], type=int,
+ help='List of max-pooling layers.')
+ parser.add_argument('--fc_layers', nargs='+', default=[32], type=int,
+ help='List of sizes for each fully connected layer.')
+ parser.add_argument('--fc_dropout', default=0.0, type=float,
+ help='Dropout rate in fully connected layers.')
+ parser.add_argument('--fc_batch_normalize', default=False, action='store_true',
+ help='Batch normalize fully connected layers.')
+ parser.add_argument('--annotation_units', default=16, type=int,
+ help='Number of units connected to the annotation input layer.')
+ parser.add_argument('--annotation_shortcut', default=False, action='store_true',
+ help='Shortcut connections on the annotations.')
+
# Evaluation related arguments
parser.add_argument('--score_keys', nargs='+', default=['VQSLOD'],
help='List of variant score keys for performance comparisons.')
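
On the Python side the layer lists use argparse's nargs='+', so values are space-separated after a single flag (which is why the Java wrapper above appends --conv_layers once and then each size). A minimal demonstration, assuming only the flags shown in this hunk:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--conv_layers', nargs='+', default=[128, 96, 64, 48], type=int)
    parser.add_argument('--fc_layers', nargs='+', default=[32], type=int)

    args = parser.parse_args('--conv_layers 64 48 32 --fc_layers 16'.split())
    print(args.conv_layers, args.fc_layers)  # [64, 48, 32] [16]
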
diff --git a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/inference.py b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/inference.py
index 6c43b7c184f..a99baf1f384 100644
--- a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/inference.py
+++ b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/inference.py
@@ -1,5 +1,6 @@
# Imports
import os
+import math
import h5py
import numpy as np
from collections import Counter, defaultdict, namedtuple
@@ -52,7 +53,6 @@ def score_and_write_batch(args, model, file_out, batch_size, python_batch_size,
variant_types = []
variant_data = []
read_batch = []
-
for _ in range(batch_size):
fifo_line = tool.readDataFIFO()
fifo_data = fifo_line.split(defines.SEPARATOR_CHAR)
@@ -117,7 +117,6 @@ def reference_string_to_tensor(reference):
break
else:
raise ValueError('Error! Unknown code:', b)
-
return dna_data
@@ -126,9 +125,8 @@ def annotation_string_to_tensor(args, annotation_string):
name_val_arrays = [p.split('=') for p in name_val_pairs]
annotation_map = {str(p[0]).strip() : p[1] for p in name_val_arrays if len(p) > 1}
annotation_data = np.zeros(( len(defines.ANNOTATIONS[args.annotation_set]),))
-
for i,a in enumerate(defines.ANNOTATIONS[args.annotation_set]):
- if a in annotation_map:
+ if a in annotation_map and not math.isnan(float(annotation_map[a])):
annotation_data[i] = annotation_map[a]
return annotation_data
@@ -434,3 +432,14 @@ def _write_tensor_to_hd5(args, tensor, annotations, contig, pos, variant_type):
with h5py.File(tensor_path, 'w') as hf:
hf.create_dataset(args.tensor_name, data=tensor, compression='gzip')
hf.create_dataset(args.annotation_set, data=annotations, compression='gzip')
+
+def clear_session():
+ try:
+ K.clear_session()
+ K.get_session().close()
+ cfg = K.tf.ConfigProto()
+ cfg.gpu_options.allow_growth = True
+ K.set_session(K.tf.Session(config=cfg))
+ except AttributeError:
+ print('Could not clear session. Maybe you are using Theano backend?')
+
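
The isnan() guard added to annotation_string_to_tensor matters because float('nan') parses without error, so a bare membership test would happily write NaN into the annotation tensor and poison downstream arithmetic. A short demonstration of the guard's effect (illustrative values):

    import math

    annotation_map = {'MQ': 'nan', 'DP': '30'}
    clean = {k: float(v) for k, v in annotation_map.items()
             if not math.isnan(float(v))}
    print(clean)  # {'DP': 30.0}
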
diff --git a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/models.py b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/models.py
index 4b5e3900486..b42f8fd5cc0 100644
--- a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/models.py
+++ b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/models.py
@@ -9,7 +9,7 @@
from keras.models import Model, load_model
from keras.layers.convolutional import Conv1D, Conv2D, MaxPooling1D, MaxPooling2D
from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
-from keras.layers import Input, Dense, Dropout, BatchNormalization, SpatialDropout2D, Activation, Flatten
+from keras.layers import Input, Dense, Dropout, BatchNormalization, SpatialDropout1D, SpatialDropout2D, Activation, Flatten, AlphaDropout
from . import plots
from . import defines
@@ -17,10 +17,19 @@
from . import tensor_maps
+def start_session_get_args_and_model(intra_ops, inter_ops, semantics_json, weights_hd5=None, tensor_type=None):
+ K.clear_session()
+ K.get_session().close()
+ cfg = K.tf.ConfigProto(intra_op_parallelism_threads=intra_ops, inter_op_parallelism_threads=inter_ops)
+ cfg.gpu_options.allow_growth = True
+ K.set_session(K.tf.Session(config=cfg))
+ return args_and_model_from_semantics(semantics_json, weights_hd5, tensor_type)
+
+
def args_and_model_from_semantics(semantics_json, weights_hd5=None, tensor_type=None):
args = arguments.parse_args()
- if semantics_json is not None:
+ if semantics_json is not None and os.path.exists(semantics_json):
model = set_args_and_get_model_from_semantics(args, semantics_json, weights_hd5)
else:
model = load_model(weights_hd5, custom_objects=get_metric_dict(args.labels))
@@ -83,162 +92,70 @@ def set_args_and_get_model_from_semantics(args, semantics_json, weights_hd5=None
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~ Models ~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-def build_reference_annotation_model(args):
- '''Build Reference 1d CNN model for classifying variants with skip connected annotations.
-
- Convolutions followed by dense connection, concatenated with annotations.
- Dynamically sets input channels based on args via tensor_maps.get_tensor_channel_map_from_args(args)
- Uses the functional API.
- Prints out model summary.
-
- Arguments
- args.tensor_name: The name of the tensor mapping which data goes to which channels
- args.annotation_set: The variant annotation set, perhaps from a HaplotypeCaller VCF.
- args.labels: The output labels (e.g. SNP, NOT_SNP, INDEL, NOT_INDEL)
-
- Returns
- The keras model
- '''
- if K.image_data_format() == 'channels_last':
- channel_axis = -1
- else:
- channel_axis = 1
-
- channel_map = tensor_maps.get_tensor_channel_map_from_args(args)
- reference = Input(shape=(args.window_size, len(channel_map)), name=args.tensor_name)
- conv_width = 12
- conv_dropout = 0.1
- fc_dropout = 0.2
- x = Conv1D(filters=256, kernel_size=conv_width, activation="relu", kernel_initializer='he_normal')(reference)
- x = Conv1D(filters=256, kernel_size=conv_width, activation="relu", kernel_initializer='he_normal')(x)
- x = Dropout(conv_dropout)(x)
- x = Conv1D(filters=128, kernel_size=conv_width, activation="relu", kernel_initializer='he_normal')(x)
- x = Dropout(conv_dropout)(x)
- x = Flatten()(x)
-
- annotations = Input(shape=(len(args.annotations),), name=args.annotation_set)
- annos_normed = BatchNormalization(axis=channel_axis)(annotations)
- annos_normed_x = Dense(units=40, kernel_initializer='normal', activation='relu')(annos_normed)
-
- x = layers.concatenate([x, annos_normed_x], axis=channel_axis)
- x = Dense(units=40, kernel_initializer='normal', activation='relu')(x)
- x = Dropout(fc_dropout)(x)
- x = layers.concatenate([x, annos_normed], axis=channel_axis)
-
- prob_output = Dense(units=len(args.labels), kernel_initializer='glorot_normal', activation='softmax')(x)
-
- model = Model(inputs=[reference, annotations], outputs=[prob_output])
-
- adamo = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, clipnorm=1.)
-
- model.compile(optimizer=adamo, loss='categorical_crossentropy', metrics=get_metrics(args.labels))
- model.summary()
-
- if os.path.exists(args.weights_hd5):
- model.load_weights(args.weights_hd5, by_name=True)
- print('Loaded model weights from:', args.weights_hd5)
-
- return model
-
-
-
-def build_read_tensor_2d_and_annotations_model(args):
- '''Build Read Tensor 2d CNN model with variant annotations mixed in for classifying variants.
-
- 2d Convolutions followed by dense connection mixed with annotation values.
- Dynamically sets input channels based on args via defines.total_input_channels_from_args(args)
- Uses the functional API. Supports theano or tensorflow channel ordering via K.image_data_format().
- Prints out model summary.
-
- Arguments
- args.window_size: Length in base-pairs of sequence centered at the variant to use as input.
- args.labels: The output labels (e.g. SNP, NOT_SNP, INDEL, NOT_INDEL)
-
- Returns
- The keras model
- '''
- in_channels = tensor_maps.total_input_channels_from_args(args)
-
- if K.image_data_format() == 'channels_last':
- in_shape = (args.read_limit, args.window_size, in_channels)
- concat_axis = -1
- else:
- in_shape = (in_channels, args.read_limit, args.window_size)
- concat_axis = 1
-
- read_tensor = Input(shape=in_shape, name=args.tensor_name)
-
- read_conv_width = 16
- conv_dropout = 0.2
- fc_dropout = 0.3
- x = Conv2D(216, (read_conv_width, 1), padding='valid', activation="relu")(read_tensor)
- x = Conv2D(160, (1, read_conv_width), padding='valid', activation="relu")(x)
- x = Conv2D(128, (read_conv_width, 1), padding='valid', activation="relu")(x)
- x = MaxPooling2D((2,1))(x)
- x = Conv2D(96, (1, read_conv_width), padding='valid', activation="relu")(x)
- x = MaxPooling2D((2,1))(x)
- x = Dropout(conv_dropout)(x)
- x = Conv2D(64, (read_conv_width, 1), padding='valid', activation="relu")(x)
- x = MaxPooling2D((2,1))(x)
- x = Dropout(conv_dropout)(x)
-
- x = Flatten()(x)
-
- # Mix the variant annotations in
- annotations = Input(shape=(len(args.annotations),), name=args.annotation_set)
- annotations_bn = BatchNormalization(axis=1)(annotations)
- alt_input_mlp = Dense(units=16, kernel_initializer='glorot_normal', activation='relu')(annotations_bn)
- x = layers.concatenate([x, alt_input_mlp], axis=concat_axis)
-
- x = Dense(units=32, kernel_initializer='glorot_normal', activation='relu')(x)
- x = layers.concatenate([x, annotations_bn], axis=concat_axis)
- x = Dropout(fc_dropout)(x)
-
- prob_output = Dense(units=len(args.labels), kernel_initializer='glorot_normal', activation='softmax')(x)
-
- model = Model(inputs=[read_tensor, annotations], outputs=[prob_output])
-
- adamo = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, clipnorm=1.)
- model.compile(loss='categorical_crossentropy', optimizer=adamo, metrics=get_metrics(args.labels))
-
- model.summary()
-
- if os.path.exists(args.weights_hd5):
- model.load_weights(args.weights_hd5, by_name=True)
- print('Loaded model weights from:', args.weights_hd5)
-
- return model
-
-
-def build_tiny_2d_annotation_model(args):
+def build_default_1d_annotation_model(args):
+ return build_reference_annotation_1d_model_from_args(args,
+ conv_width=7,
+ conv_layers=[256, 216, 128, 64, 32],
+ conv_dropout=0.1,
+ conv_batch_normalize=True,
+ spatial_dropout=True,
+ max_pools=[],
+ padding='same',
+ annotation_units=64,
+ annotation_shortcut=True,
+ fc_layers=[64, 64],
+ fc_dropout=0.2,
+ annotation_batch_normalize=True,
+ fc_batch_normalize=False)
+
+
+def build_1d_annotation_model_from_args(args):
+ return build_reference_annotation_1d_model_from_args(args,
+ conv_width=args.conv_width,
+ conv_layers=args.conv_layers,
+ conv_dropout=args.conv_dropout,
+ conv_batch_normalize=args.conv_batch_normalize,
+ spatial_dropout=args.spatial_dropout,
+ max_pools=args.max_pools,
+ padding=args.padding,
+ annotation_units=args.annotation_units,
+ annotation_shortcut=args.annotation_shortcut,
+ fc_layers=args.fc_layers,
+ fc_dropout=args.fc_dropout,
+ fc_batch_normalize=args.fc_batch_normalize)
+
+
+def build_2d_annotation_model_from_args(args):
return read_tensor_2d_annotation_model_from_args(args,
- conv_width = 11,
- conv_height = 5,
- conv_layers = [32, 32],
- conv_dropout = 0.0,
- spatial_dropout = False,
- max_pools = [(2,1),(8,1)],
- padding='valid',
- annotation_units = 10,
- annotation_shortcut = False,
- fc_layers = [16],
- fc_dropout = 0.0)
-
-
-def build_small_2d_annotation_model(args):
+ conv_width = args.conv_width,
+ conv_height = args.conv_height,
+ conv_layers = args.conv_layers,
+ conv_dropout = args.conv_dropout,
+ conv_batch_normalize = args.conv_batch_normalize,
+ spatial_dropout = args.spatial_dropout,
+ max_pools = args.max_pools,
+ padding = args.padding,
+ annotation_units = args.annotation_units,
+ annotation_shortcut = args.annotation_shortcut,
+ fc_layers = args.fc_layers,
+ fc_dropout = args.fc_dropout,
+ fc_batch_normalize = args.fc_batch_normalize)
+
+
+def build_default_2d_annotation_model(args):
return read_tensor_2d_annotation_model_from_args(args,
conv_width = 25,
conv_height = 25,
conv_layers = [64, 48, 32, 24],
- conv_dropout = 0.0,
+ conv_dropout = 0.1,
conv_batch_normalize = False,
- spatial_dropout = False,
+ spatial_dropout = True,
max_pools = [(3,1),(3,1)],
padding='valid',
annotation_units = 64,
annotation_shortcut = False,
fc_layers = [24],
- fc_dropout = 0.0,
+ fc_dropout = 0.3,
fc_batch_normalize = False)
@@ -302,11 +219,11 @@ def read_tensor_2d_annotation_model_from_args(args,
cur_kernel = (conv_width, conv_height)
if conv_batch_normalize:
- x = Conv2D(f, cur_kernel, activation='linear', padding=padding, kernel_initializer=kernel_initializer)(x)
+ x = Conv2D(int(f), cur_kernel, activation='linear', padding=padding, kernel_initializer=kernel_initializer)(x)
x = BatchNormalization(axis=concat_axis)(x)
x = Activation('relu')(x)
else:
- x = Conv2D(f, cur_kernel, activation='relu', padding=padding, kernel_initializer=kernel_initializer)(x)
+ x = Conv2D(int(f), cur_kernel, activation='relu', padding=padding, kernel_initializer=kernel_initializer)(x)
if conv_dropout > 0 and spatial_dropout:
x = SpatialDropout2D(conv_dropout)(x)
@@ -386,6 +303,103 @@ def read_tensor_2d_annotation_model_from_args(args,
return model
+
+def build_reference_annotation_1d_model_from_args(args,
+ conv_width = 6,
+ conv_layers = [128, 128, 128, 128],
+ conv_dropout = 0.0,
+ conv_batch_normalize = False,
+ spatial_dropout = True,
+ max_pools = [],
+ padding='valid',
+ activation = 'relu',
+ annotation_units = 16,
+ annotation_shortcut = False,
+ annotation_batch_normalize = True,
+ fc_layers = [64],
+ fc_dropout = 0.0,
+ fc_batch_normalize = False,
+ fc_initializer = 'glorot_normal',
+ kernel_initializer = 'glorot_normal',
+ alpha_dropout = False
+ ):
+ '''Build a 1D reference CNN model for classifying variants.
+
+ The architecture is specified by the keyword parameters.
+ Dynamically sets input channels based on args via tensor_maps.total_input_channels_from_args(args).
+ Uses the Keras functional API and prints a model summary.
+
+ Arguments
+ args.annotations: The variant annotations, perhaps from a HaplotypeCaller VCF.
+ args.labels: The output labels (e.g. SNP, NOT_SNP, INDEL, NOT_INDEL)
+
+ Returns
+ The compiled Keras model
+ '''
+ in_channels = tensor_maps.total_input_channels_from_args(args)
+ concat_axis = -1
+ x = reference = Input(shape=(args.window_size, in_channels), name=args.tensor_name)
+
+ max_pool_diff = len(conv_layers)-len(max_pools)
+ for i,c in enumerate(conv_layers):
+
+ if conv_batch_normalize:
+ x = Conv1D(filters=c, kernel_size=conv_width, activation='linear', padding=padding, kernel_initializer=kernel_initializer)(x)
+ x = BatchNormalization(axis=concat_axis)(x)
+ x = Activation(activation)(x)
+ else:
+ x = Conv1D(filters=c, kernel_size=conv_width, activation=activation, padding=padding, kernel_initializer=kernel_initializer)(x)
+
+ if conv_dropout > 0 and alpha_dropout:
+ x = AlphaDropout(conv_dropout)(x)
+ elif conv_dropout > 0 and spatial_dropout:
+ x = SpatialDropout1D(conv_dropout)(x)
+ elif conv_dropout > 0:
+ x = Dropout(conv_dropout)(x)
+
+ if i >= max_pool_diff:
+ x = MaxPooling1D(max_pools[i-max_pool_diff])(x)
+
+ f = Flatten()(x)
+
+ annotations = annotations_in = Input(shape=(len(args.annotations),), name=args.annotation_set)
+ if annotation_batch_normalize:
+ annotations_in = BatchNormalization(axis=concat_axis)(annotations_in)
+ annotation_mlp = Dense(units=annotation_units, kernel_initializer=fc_initializer, activation=activation)(annotations_in)
+
+ x = layers.concatenate([f, annotation_mlp], axis=1)
+ for fc in fc_layers:
+ if fc_batch_normalize:
+ x = Dense(units=fc, activation='linear', kernel_initializer=fc_initializer)(x)
+ x = BatchNormalization(axis=1)(x)
+ x = Activation(activation)(x)
+ else:
+ x = Dense(units=fc, activation=activation, kernel_initializer=fc_initializer)(x)
+
+ if fc_dropout > 0 and alpha_dropout:
+ x = AlphaDropout(fc_dropout)(x)
+ elif fc_dropout > 0:
+ x = Dropout(fc_dropout)(x)
+
+ if annotation_shortcut:
+ x = layers.concatenate([x, annotations_in], axis=1)
+
+ prob_output = Dense(units=len(args.labels), activation='softmax', name='softmax_predictions')(x)
+
+ model = Model(inputs=[reference, annotations], outputs=[prob_output])
+
+ adam = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, clipnorm=1.)
+ model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=get_metrics(args.labels))
+ model.summary()
+
+ if os.path.exists(args.weights_hd5):
+ model.load_weights(args.weights_hd5, by_name=True)
+ print('Loaded model weights from:', args.weights_hd5)
+
+ return model
+
+
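
Structurally, the builder above is the standard two-input Keras functional pattern: a 1D convolutional tower over the reference window is flattened and concatenated with a small MLP over the variant annotations, and a softmax head produces per-label probabilities. A stripped-down sketch of that pattern with made-up sizes (a 128-base window with 15 channels, 7 annotations, 4 labels), free of the argument plumbing:

    from keras.layers import Input, Conv1D, Flatten, Dense, concatenate
    from keras.models import Model

    reference = Input(shape=(128, 15), name='reference')    # window x channels
    x = Conv1D(filters=128, kernel_size=7, activation='relu', padding='same')(reference)
    x = Conv1D(filters=64, kernel_size=7, activation='relu', padding='same')(x)
    x = Flatten()(x)

    annotations = Input(shape=(7,), name='annotations')     # per-variant annotations
    a = Dense(units=64, activation='relu')(annotations)

    merged = concatenate([x, a], axis=1)
    merged = Dense(units=64, activation='relu')(merged)
    probs = Dense(units=4, activation='softmax', name='softmax_predictions')(merged)

    model = Model(inputs=[reference, annotations], outputs=[probs])
    model.compile(optimizer='adam', loss='categorical_crossentropy')
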
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~ Optimizing ~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -411,16 +425,18 @@ def train_model_from_generators(args, model, generate_train, generate_valid, sav
'''
if not os.path.exists(os.path.dirname(save_weight_hd5)):
os.makedirs(os.path.dirname(save_weight_hd5))
+ serialize_model_semantics(args, save_weight_hd5)
history = model.fit_generator(generate_train,
steps_per_epoch=args.training_steps, epochs=args.epochs, verbose=1,
validation_steps=args.validation_steps, validation_data=generate_valid,
callbacks=get_callbacks(args, save_weight_hd5))
+ print('Training complete, model weights saved at: %s' % save_weight_hd5)
if args.image_dir:
plots.plot_metric_history(history, plots.weight_path_to_title(save_weight_hd5), prefix=args.image_dir)
- serialize_model_semantics(args, save_weight_hd5)
- print('Model weights saved at: %s' % save_weight_hd5)
+
+
return model
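
train_model_from_generators consumes generate_train and generate_valid as infinite generators of (inputs, labels) batches. For a multi-input model, one common form is a dict keyed by the input layer names (the project's own generators may yield lists instead; Keras accepts both). A toy generator with random data, reusing the two-input model from the sketch above:

    import numpy as np

    def random_batches(batch_size=32):
        # Infinite generator of ({input_name: array}, one_hot_labels) pairs,
        # the shape fit_generator expects; sizes match the sketch above.
        while True:
            refs = np.random.random((batch_size, 128, 15))
            annos = np.random.random((batch_size, 7))
            labels = np.eye(4)[np.random.randint(0, 4, size=batch_size)]
            yield {'reference': refs, 'annotations': annos}, labels

    model.fit_generator(random_batches(), steps_per_epoch=10, epochs=2,
                        validation_data=random_batches(), validation_steps=2)
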
diff --git a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/tensor_maps.py b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/tensor_maps.py
index 3ecef9ab2bc..ee9fc749d7a 100644
--- a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/tensor_maps.py
+++ b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/tensor_maps.py
@@ -6,7 +6,7 @@
def get_tensor_channel_map_from_args(args):
'''Return tensor mapping dict given args.tensor_name'''
- if not args.tensor_name:
+ if args.tensor_name is None:
return None
if 'read_tensor' == args.tensor_name:
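
The change from `not args.tensor_name` to `args.tensor_name is None` here is more than cosmetic: `not` fires on every falsy value, so an empty-string tensor name would previously have been treated as absent. A quick illustration:

    tensor_name = ''             # set, but empty
    print(not tensor_name)       # True  -- the old guard would return None here
    print(tensor_name is None)   # False -- the new guard only matches a real None
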
diff --git a/src/main/resources/large/cnn_score_variants/small_2d.hd5 b/src/main/resources/large/cnn_score_variants/small_2d.hd5
index 5c007f29b26..deb36d22e04 100644
--- a/src/main/resources/large/cnn_score_variants/small_2d.hd5
+++ b/src/main/resources/large/cnn_score_variants/small_2d.hd5
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:acd9efef4700826a8244a550c40b42bf95dbeafeac5a05bab0ec8d0f353bb80c
-size 6410504
+oid sha256:6f663a2fdbcde0addc5cb755f7af5d4c19bed92dccfd20e25b2acf2bc8c2ca7c
+size 2163096
diff --git a/src/main/resources/large/cnn_score_variants/small_2d.json b/src/main/resources/large/cnn_score_variants/small_2d.json
index 648565cb199..c35cfbdfcae 100644
--- a/src/main/resources/large/cnn_score_variants/small_2d.json
+++ b/src/main/resources/large/cnn_score_variants/small_2d.json
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:226c0344e590051c67a61873733bfafef3db5a0b8e95015b0b688a542199142b
-size 720
+oid sha256:e38e09cfe7b7ffbc80dce4972bc9c382148520147d46738a3f6f3235b2d876c6
+size 758
diff --git a/src/main/resources/org/broadinstitute/hellbender/tools/walkers/vqsr/training.py b/src/main/resources/org/broadinstitute/hellbender/tools/walkers/vqsr/training.py
index 4fe90ac70fd..574a50ae97e 100644
--- a/src/main/resources/org/broadinstitute/hellbender/tools/walkers/vqsr/training.py
+++ b/src/main/resources/org/broadinstitute/hellbender/tools/walkers/vqsr/training.py
@@ -1,10 +1,8 @@
# Imports
import os
-import sys
import vcf
import math
import h5py
-import time
import pysam
import vqsr_cnn
import numpy as np
@@ -14,20 +12,21 @@
# Keras Imports
import keras.backend as K
+
def run():
args = vqsr_cnn.parse_args()
if 'write_reference_and_annotation_tensors' == args.mode:
write_reference_and_annotation_tensors(args)
elif 'write_read_and_annotation_tensors' == args.mode:
write_read_and_annotation_tensors(args)
- elif 'train_on_reference_tensors_and_annotations' == args.mode:
- train_on_reference_tensors_and_annotations(args)
- elif 'train_on_read_tensors_and_annotations' == args.mode:
- train_on_read_tensors_and_annotations(args)
- elif 'train_tiny_model_on_read_tensors_and_annotations' == args.mode:
- train_tiny_model_on_read_tensors_and_annotations(args)
- elif 'train_small_model_on_read_tensors_and_annotations' == args.mode:
- train_small_model_on_read_tensors_and_annotations(args)
+ elif 'train_default_1d_model' == args.mode:
+ train_default_1d_model(args)
+ elif 'train_default_2d_model' == args.mode:
+ train_default_2d_model(args)
+ elif 'train_args_model_on_read_tensors_and_annotations' == args.mode:
+ train_args_model_on_read_tensors_and_annotations(args)
+ elif 'train_args_model_on_reference_and_annotations' == args.mode:
+ train_args_model_on_read_tensors_and_annotations(args)  # note: currently reuses the read-tensor trainer
else:
raise ValueError('Unknown training mode:', args.mode)
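
The growing elif chain in run() could equally be written as a dispatch table, keeping mode names and trainers in one place. A sketch, not part of the patch, assuming the handler functions above are in scope:

    MODES = {
        'write_reference_and_annotation_tensors': write_reference_and_annotation_tensors,
        'write_read_and_annotation_tensors': write_read_and_annotation_tensors,
        'train_default_1d_model': train_default_1d_model,
        'train_default_2d_model': train_default_2d_model,
        'train_args_model_on_read_tensors_and_annotations': train_args_model_on_read_tensors_and_annotations,
        # reference mode currently reuses the read-tensor trainer (see above)
        'train_args_model_on_reference_and_annotations': train_args_model_on_read_tensors_and_annotations,
    }

    def run():
        args = vqsr_cnn.parse_args()
        try:
            MODES[args.mode](args)
        except KeyError:
            raise ValueError('Unknown training mode:', args.mode)
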
@@ -37,21 +36,15 @@ def write_reference_and_annotation_tensors(args, include_dna=True, include_annot
raise ValueError('Unknown tensor name:', args.tensor_name, '1d maps must be in:', str(vqsr_cnn.TENSOR_MAPS_1D))
record_dict = SeqIO.to_dict(SeqIO.parse(args.reference_fasta, "fasta"))
- if os.path.splitext(args.input_vcf)[-1].lower() == '.gz':
- vcf_reader = vcf.Reader(open(args.input_vcf, 'rb'))
- else:
- vcf_reader = vcf.Reader(open(args.input_vcf, 'r'))
- if os.path.splitext(args.train_vcf)[-1].lower() == '.gz':
- vcf_ram = vcf.Reader(open(args.train_vcf, 'rb'))
- else:
- vcf_ram = vcf.Reader(open(args.train_vcf, 'r'))
+ vcf_reader = get_vcf_reader(args.input_vcf)
+ vcf_ram = get_vcf_reader(args.train_vcf)
bed_dict = bed_file_to_dict(args.bed_file)
stats = Counter()
if args.chrom:
- variants = vcf_reader.fetch(args.chrom, args.start_pos, args.end_pos)
+ variants = vcf_reader.fetch(args.chrom, args.start_pos, args.end_pos)
else:
variants = vcf_reader
@@ -106,7 +99,6 @@ def write_reference_and_annotation_tensors(args, include_dna=True, include_annot
print(k, ' has:', stats[k])
-
def write_read_and_annotation_tensors(args, include_annotations=True, pileup=False):
'''Create tensors structured as tensor map of reads organized by labels in the data directory.
@@ -134,8 +126,8 @@ def write_read_and_annotation_tensors(args, include_annotations=True, pileup=Fal
samfile = pysam.AlignmentFile(args.bam_file, "rb")
bed_dict = bed_file_to_dict(args.bed_file)
record_dict = SeqIO.to_dict(SeqIO.parse(args.reference_fasta, "fasta"))
- vcf_reader = vcf.Reader(open(args.input_vcf, 'r'))
- vcf_ram = vcf.Reader(open(args.train_vcf, 'rb'))
+ vcf_reader = get_vcf_reader(args.input_vcf)
+ vcf_ram = get_vcf_reader(args.train_vcf)
if args.chrom:
variants = vcf_reader.fetch(args.chrom, args.start_pos, args.end_pos)
@@ -205,7 +197,7 @@ def write_read_and_annotation_tensors(args, include_annotations=True, pileup=Fal
print('Done generating tensors. Last variant:', str(variant), 'from vcf:', args.input_vcf)
-def train_on_reference_tensors_and_annotations(args):
+def train_default_1d_model(args):
'''Train a 1D Convolution plus reference tracks and MLP Annotation architecture.
Arguments:
@@ -223,7 +215,7 @@ def train_on_reference_tensors_and_annotations(args):
generate_valid = dna_annotation_generator(args, valid_paths)
weight_path = vqsr_cnn.weight_path_from_args(args)
- model = vqsr_cnn.build_reference_annotation_model(args)
+ model = vqsr_cnn.build_default_1d_annotation_model(args)
model = vqsr_cnn.train_model_from_generators(args, model, generate_train, generate_valid, weight_path)
test = load_dna_annotations_positions_from_class_dirs(args, test_paths, per_class_max=args.samples)
@@ -231,8 +223,7 @@ def train_on_reference_tensors_and_annotations(args):
vqsr_cnn.plot_roc_per_class(model, [test[0], test[1]], test[2], args.labels, args.id, prefix=args.image_dir)
-
-def train_on_read_tensors_and_annotations(args):
+def train_default_2d_model(args):
'''Trains a reference, read, and annotation CNN architecture on tensors at the supplied data directory.
This architecture looks at reads, read flags, reference sequence, and variant annotations.
@@ -251,7 +242,7 @@ def train_on_read_tensors_and_annotations(args):
generate_valid = tensor_generator_from_label_dirs_and_args(args, valid_paths)
weight_path = vqsr_cnn.weight_path_from_args(args)
- model = vqsr_cnn.build_read_tensor_2d_and_annotations_model(args)
+ model = vqsr_cnn.build_default_2d_annotation_model(args)
model = vqsr_cnn.train_model_from_generators(args, model, generate_train, generate_valid, weight_path)
test = load_tensors_and_annotations_from_class_dirs(args, test_paths, per_class_max=args.samples)
@@ -260,7 +251,7 @@ def train_on_read_tensors_and_annotations(args):
prefix=args.image_dir, batch_size=args.batch_size)
-def train_tiny_model_on_read_tensors_and_annotations(args):
+def train_args_model_on_read_tensors_and_annotations(args):
'''Trains a reference, read, and annotation CNN architecture on tensors at the supplied data directory.
This architecture looks at reads, read flags, reference sequence, and variant annotations.
@@ -279,7 +270,7 @@ def train_tiny_model_on_read_tensors_and_annotations(args):
generate_valid = tensor_generator_from_label_dirs_and_args(args, valid_paths)
weight_path = vqsr_cnn.weight_path_from_args(args)
- model = vqsr_cnn.build_tiny_2d_annotation_model(args)
+ model = vqsr_cnn.build_2d_annotation_model_from_args(args)
model = vqsr_cnn.train_model_from_generators(args, model, generate_train, generate_valid, weight_path)
test = load_tensors_and_annotations_from_class_dirs(args, test_paths, per_class_max=args.samples)
@@ -712,6 +703,7 @@ def get_true_label(allele, variant, bed_dict, truth_vcf, stats):
NOT_INDEL if variant is indel and not in truth vcf
'''
in_bed = in_bed_file(bed_dict, variant.CHROM, variant.POS)
+
if allele_in_vcf(allele, variant, truth_vcf) and in_bed:
class_prefix = ''
elif in_bed:
@@ -815,9 +807,14 @@ def bed_file_to_dict(bed_file):
def in_bed_file(bed_dict, contig, pos):
- # Exclusive
+
+ if contig not in bed_dict:
+ return False
+
lows = bed_dict[contig][0]
ups = bed_dict[contig][1]
+
+ # Half-open interval [low, up): the end coordinate is excluded
return np.any((lows <= pos) & (pos < ups))
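
Here lows and ups are parallel numpy arrays of interval starts and ends for the contig, so the membership test is vectorized over every BED record at once, and the half-open convention excludes a position equal to an interval's end:

    import numpy as np

    lows = np.array([10, 100])   # interval starts (inclusive)
    ups  = np.array([20, 200])   # interval ends (exclusive)

    for pos in (9, 10, 19, 20):
        print(pos, np.any((lows <= pos) & (pos < ups)))
    # 9 False, 10 True, 19 True, 20 False: 20 falls on an exclusive end
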
@@ -832,7 +829,14 @@ def allele_in_vcf(allele, variant, vcf_ram):
Returns
variant if it is found otherwise None
'''
- variants = vcf_ram.fetch(variant.CHROM, variant.POS-1, variant.POS)
+ if variant.CHROM not in vcf_ram.contigs:
+ return None
+
+ try:
+ variants = vcf_ram.fetch(variant.CHROM, variant.POS-1, variant.POS)
+ except ValueError:
+ print('Could not fetch', variant.CHROM, variant.POS, 'from truth VCF')
+ return None
for v in variants:
if v.CHROM == variant.CHROM and v.POS == variant.POS and allele in v.ALT:
@@ -1145,6 +1149,12 @@ def plain_name(full_name):
name = os.path.basename(full_name)
return name.split('.')[0]
+def get_vcf_reader(my_vcf):
+ if os.path.splitext(my_vcf)[-1].lower() == '.gz':
+ return vcf.Reader(open(my_vcf, 'rb'))
+ else:
+ return vcf.Reader(open(my_vcf, 'r'))
+
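
The new helper folds the repeated gz-or-plain open from the tensor writers into one place; PyVCF expects compressed files to be opened in binary mode so it can handle decompression itself. Usage is then uniform (the file name here is illustrative):

    # Works for plain or bgzipped VCFs alike; fetch() additionally needs a
    # tabix index next to the compressed file.
    reader = get_vcf_reader('calls.vcf.gz')
    for record in reader.fetch('20', 1000000, 1100000):
        print(record.CHROM, record.POS, record.ALT)
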
# Back to the top!
if "__main__" == __name__:
diff --git a/src/test/java/org/broadinstitute/hellbender/tools/walkers/validation/EvaluateInfoFieldConcordanceIntegrationTest.java b/src/test/java/org/broadinstitute/hellbender/tools/walkers/validation/EvaluateInfoFieldConcordanceIntegrationTest.java
new file mode 100644
index 00000000000..75a247f14bf
--- /dev/null
+++ b/src/test/java/org/broadinstitute/hellbender/tools/walkers/validation/EvaluateInfoFieldConcordanceIntegrationTest.java
@@ -0,0 +1,65 @@
+package org.broadinstitute.hellbender.tools.walkers.validation;
+
+import htsjdk.variant.variantcontext.VariantContext;
+import org.broadinstitute.hellbender.CommandLineProgramTest;
+import org.broadinstitute.hellbender.engine.AbstractConcordanceWalker;
+import org.broadinstitute.hellbender.testutils.ArgumentsBuilder;
+import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;
+import org.testng.Assert;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+import java.nio.file.Path;
+
+public class EvaluateInfoFieldConcordanceIntegrationTest extends CommandLineProgramTest {
+ final double epsilon = 1e-3;
+
+ @Test(dataProvider= "infoConcordanceDataProvider")
+ public void testInfoConcordanceFromProvider(String inputVcf1, String inputVcf2, String evalKey, String truthKey,
+ double snpMean, double snpSTD,
+ double indelMean, double indelSTD) throws Exception {
+ final Path summary = createTempPath("summary", ".txt");
+ final ArgumentsBuilder argsBuilder = new ArgumentsBuilder();
+ argsBuilder.addArgument(AbstractConcordanceWalker.EVAL_VARIANTS_SHORT_NAME, inputVcf1)
+ .addArgument(AbstractConcordanceWalker.TRUTH_VARIANTS_LONG_NAME, inputVcf2)
+ .addArgument("eval-info-key", evalKey)
+ .addArgument("truth-info-key", truthKey)
+ .addArgument(EvaluateInfoFieldConcordance.SUMMARY_LONG_NAME, summary.toString());
+ runCommandLine(argsBuilder);
+
+ try(InfoConcordanceRecord.InfoConcordanceReader
+ reader = new InfoConcordanceRecord.InfoConcordanceReader(summary)) {
+ InfoConcordanceRecord snpRecord = reader.readRecord();
+ InfoConcordanceRecord indelRecord = reader.readRecord();
+
+ Assert.assertEquals(snpRecord.getVariantType(), VariantContext.Type.SNP);
+ Assert.assertEquals(indelRecord.getVariantType(), VariantContext.Type.INDEL);
+
+ Assert.assertEquals(snpRecord.getMean(), snpMean, epsilon);
+ Assert.assertEquals(snpRecord.getStd(), snpSTD, epsilon);
+ Assert.assertEquals(indelRecord.getMean(), indelMean, epsilon);
+ Assert.assertEquals(indelRecord.getStd(), indelSTD, epsilon);
+ }
+ }
+
+ @DataProvider
+ public Object[][] infoConcordanceDataProvider() {
+ return new Object[][]{
+ new Object[]{
+ largeFileTestDir + "VQSR/expected/chr20_tiny_tf_python_gpu2.vcf",
+ largeFileTestDir + "VQSR/expected/chr20_tiny_tf_python_gpu2.vcf",
+ GATKVCFConstants.CNN_2D_KEY,
+ "NOVA_HISEQ_MIX_SMALL",
+ 0.108878, 0.229415, 0.067024, 0.142705 // numbers verified by manual inspection
+
+ },
+ new Object[]{
+ largeFileTestDir + "VQSR/expected/chr20_tiny_tf_python_cpu.vcf",
+ largeFileTestDir + "VQSR/expected/chr20_tiny_th_python_gpu.vcf",
+ GATKVCFConstants.CNN_1D_KEY,
+ "NOVA_HISEQ_MIX_1D_RAB",
+ 0.000256, 0.000136, 0.000240, 0.000153 // numbers verified by manual inspection
+ }
+ };
+ }
+}
diff --git a/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_cpu.vcf b/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_cpu.vcf
new file mode 100644
index 00000000000..e4941b9bd9b
--- /dev/null
+++ b/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_cpu.vcf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4e21b9c31936d59a5b4a52756254d932b9b63db418a918567538e70215c52e9
+size 171577
diff --git a/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_gpu2.vcf b/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_gpu2.vcf
new file mode 100644
index 00000000000..b14fb5b7bf7
--- /dev/null
+++ b/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_gpu2.vcf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6338d032989dc6b9168bc94ee1da85c666ee1de995c3859046a65b3dba610350
+size 177316
diff --git a/src/test/resources/large/VQSR/expected/chr20_tiny_th_python_cpu.vcf b/src/test/resources/large/VQSR/expected/chr20_tiny_th_python_cpu.vcf
new file mode 100755
index 00000000000..f8c92ca64b7
--- /dev/null
+++ b/src/test/resources/large/VQSR/expected/chr20_tiny_th_python_cpu.vcf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f20647108452f717b84c7403d510d919fbbd83c7ce2ddf5545534fe6e8ee08f
+size 155325
diff --git a/src/test/resources/large/VQSR/expected/chr20_tiny_th_python_gpu.vcf b/src/test/resources/large/VQSR/expected/chr20_tiny_th_python_gpu.vcf
new file mode 100644
index 00000000000..333e6644636
--- /dev/null
+++ b/src/test/resources/large/VQSR/expected/chr20_tiny_th_python_gpu.vcf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ebffb2b960005d08830e3b879350ec1a75f1121a3be1d0205a79cbbe24cd94e4
+size 169754
diff --git a/src/test/resources/large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf b/src/test/resources/large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf
index 438dc0b4f67..8de9528ce1e 100644
--- a/src/test/resources/large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf
+++ b/src/test/resources/large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:8f4a4ccf2ebbaed80ee40313f0f50be4287b947fc1e331159de871f7f0856c7c
-size 150076
+oid sha256:6fe2f366e19080280fc8666b9fdc9f98e1143aa71b3ee27bd8002ccd6d4055d3
+size 150048
diff --git a/src/test/resources/large/VQSR/expected/nn_outy2d.vcf b/src/test/resources/large/VQSR/expected/nn_outy2d.vcf
new file mode 100644
index 00000000000..073956299af
--- /dev/null
+++ b/src/test/resources/large/VQSR/expected/nn_outy2d.vcf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b58efd72c6b0af998ff7c8450f091119a46a3c0f7d4760e082709770f776d46
+size 163170
diff --git a/src/test/resources/large/VQSR/small_2d.hd5 b/src/test/resources/large/VQSR/small_2d.hd5
index 5c007f29b26..deb36d22e04 100644
--- a/src/test/resources/large/VQSR/small_2d.hd5
+++ b/src/test/resources/large/VQSR/small_2d.hd5
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:acd9efef4700826a8244a550c40b42bf95dbeafeac5a05bab0ec8d0f353bb80c
-size 6410504
+oid sha256:6f663a2fdbcde0addc5cb755f7af5d4c19bed92dccfd20e25b2acf2bc8c2ca7c
+size 2163096
diff --git a/src/test/resources/large/VQSR/small_2d.json b/src/test/resources/large/VQSR/small_2d.json
index 648565cb199..c35cfbdfcae 100644
--- a/src/test/resources/large/VQSR/small_2d.json
+++ b/src/test/resources/large/VQSR/small_2d.json
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:226c0344e590051c67a61873733bfafef3db5a0b8e95015b0b688a542199142b
-size 720
+oid sha256:e38e09cfe7b7ffbc80dce4972bc9c382148520147d46738a3f6f3235b2d876c6
+size 758