Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CNNScoreVariants out of beta #5548

Merged
merged 1 commit into from
Jan 29, 2019
Merged

CNNScoreVariants out of beta #5548

merged 1 commit into from
Jan 29, 2019

Conversation

lucidtronix
Copy link
Contributor

Add PEP8 python style with type hints and use model directories instead of separate arguments for config and weights.

@codecov-io
Copy link

codecov-io commented Dec 21, 2018

Codecov Report

Merging #5548 into master will increase coverage by 35.717%.
The diff coverage is 96.032%.

@@               Coverage Diff                @@
##              master     #5548        +/-   ##
================================================
+ Coverage     36.838%   72.555%   +35.717%     
- Complexity     17409     26182      +8773     
================================================
  Files           1934      1934                
  Lines         145691    145752        +61     
  Branches       16103     16106         +3     
================================================
+ Hits           53670    105751     +52081     
+ Misses         87181     34866     -52315     
- Partials        4840      5135       +295
Impacted Files Coverage Δ Complexity Δ
...der/tools/walkers/vqsr/CNNVariantPipelineTest.java 100% <100%> (ø) 8 <1> (ø) ⬇️
.../walkers/vqsr/CNNScoreVariantsIntegrationTest.java 100% <100%> (ø) 13 <7> (+3) ⬆️
...ellbender/tools/walkers/vqsr/CNNScoreVariants.java 80.444% <89.583%> (+6.736%) 45 <15> (+4) ⬆️
...nder/tools/copynumber/utils/TagGermlineEvents.java 0% <0%> (-100%) 0% <0%> (-3%)
...r/tools/spark/pathseq/PSBwaArgumentCollection.java 0% <0%> (-100%) 0% <0%> (-1%)
...ender/tools/readersplitters/ReadGroupSplitter.java 0% <0%> (-100%) 0% <0%> (-3%)
...ools/funcotator/filtrationRules/ClinVarFilter.java 0% <0%> (-100%) 0% <0%> (-5%)
...ls/walkers/varianteval/stratifications/Sample.java 0% <0%> (-100%) 0% <0%> (-4%)
...nes/metrics/QualityYieldMetricsCollectorSpark.java 0% <0%> (-100%) 0% <0%> (-7%)
...lkers/varianteval/util/SortableJexlVCMatchExp.java 0% <0%> (-100%) 0% <0%> (-2%)
... and 1381 more

Copy link
Collaborator

@cmnbroad cmnbroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checkpointing part one of this review - I still have some more comments to come on the python code and test code but am saving what I have to so far so as to not hold things up. The type hints are a big improvement towards readability though!

@@ -90,11 +95,9 @@
* -inference-batch-size 2 \
* -transfer-batch-size 2 \
* -tensor-type read-tensor \
* -architecture path/to/my_model.json \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Examples above still use inference/batch size of 2. These (size) arguments are @Advanced, so the basic javadoc examples shouldn't refer to them.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, the javadoc should have an example using the simplest, all-default-args case, and we should have a corresponding test case for that (I think there is one that is very close).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The javadoc should explain that the tool is intended to be used for single-sample only, and since we only warn on multiple samples, something about what results to expect when using multiple samples.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK the issue is that the best batch size for 1D is different from the bast batch size for 2D. I'll automatically set them to different defaults when they are not supplied on the command line.

The first example is the simplest possible way to run the tool.

Added a comment about single-sample.


@Argument(fullName = "weights", shortName = "weights", doc = "Keras model HD5 file with neural net weights.", optional = true)
private String weights;
@Argument(fullName = "model-dir", shortName = "model", doc = "Directory containing Neural Net architecture and configuration json file", optional = true)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add something that says "if not supplied the default model is used" or some such.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


@Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Name of the tensors to generate, reference for 1D reference tensors and read_tensor for 2D tensors.", optional = true)
private TensorType tensorType = TensorType.reference;

@Argument(fullName = "annotation-set", shortName = "annotation-set", doc = "Name of the set of annotations to use", optional = true)
private String annotationSet = DEFAULT_ANNOTATION_SET;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As it stands, nobody will know what to do with this argument. The other valid values either need to be documented (perhaps with a ClpEnum), and tests added, or else this arg should be removed, or at least @Hidden. If we do document any of the other sets, we'll need to add a test for them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed it and now the annotation set is controlled only by java.

@Argument(fullName = "weights", shortName = "weights", doc = "Keras model HD5 file with neural net weights.", optional = true)
private String weights;
@Argument(fullName = "model-dir", shortName = "model", doc = "Directory containing Neural Net architecture and configuration json file", optional = true)
private String modelDir;

@Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Name of the tensors to generate, reference for 1D reference tensors and read_tensor for 2D tensors.", optional = true)
private TensorType tensorType = TensorType.reference;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we keeping 1d as the default ? If we decide to change to 2d, then the examples in the javadoc above will have to change to reflect that

@@ -209,7 +217,7 @@
return new String[]{"Inference batch size must be less than or equal to transfer batch size."};
}

if (weights == null && architecture == null){
if (modelDir == null){
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the tensor-type test below be unconditional (model== null is irrelevant ?).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now yes, but this way we can support new tensor types without requiring them to have default model

curBatchSize,
inferenceBatchSize,
tensorType,
annotationSet,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The python code is based on the name of the (predefined) annotation passed here, but the java code serializes the set of annotations defined by the annotationKeys list from the command line. These two arguments are overlapping/redundant, and need to be consolidated or kept in sync somehow. It would be very easy to construct a command line that causes these to be out of sync.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

def score_and_write_batch(args, model, file_out, batch_size, python_batch_size, tensor_dir):
'''Score a batch of variants with a CNN model. Write tab delimited temp file with scores.
def score_and_write_batch(model: keras.Model,
file_out: TextIO,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When this code is called by the Java code, the python statement retrieves the values model and file_out from the global namespace, since they're originally stored there when start_session_get_args_and_model is called. Its a bit weird to have to pass them in each time like this.

One option would be to create an instance of an inference wrapper class (see what I did for the FIFO code in tool.py), and have a function for init which creates it and stores it in the global namespace, and score and close functions that delegate to the global instance. I'm not sure how pythonic it is, but all of the code called by java would be in a single module, and there would only be a single variable in the global namespace, with no call redundancies.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, thanks for adding the type hints - they definitely improve readability...

"CNNScoreVariantsWorkflow.bam_file": "/home/travis/build/broadinstitute/gatk/src/test/resources/large/VQSR/g94982_chr20_1m_10m_bamout.bam",
"CNNScoreVariantsWorkflow.bam_file_index": "/home/travis/build/broadinstitute/gatk/src/test/resources/large/VQSR/g94982_chr20_1m_10m_bamout.bai",
"CNNScoreVariantsWorkflow.bam_file": "/home/travis/build/broadinstitute/gatk/src/test/resources/large/VQSR/g94982_b37_chr20_1m_895_bamout.bam",
"CNNScoreVariantsWorkflow.bam_file_index": "/home/travis/build/broadinstitute/gatk/src/test/resources/large/VQSR/g94982_b37_chr20_1m_895_bamout.bai",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do any tests for these exist ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, these are tested in the cnn_variant cromwell job.

@Argument(fullName = "window-size", shortName = "window-size", doc = "Neural Net input window size", minValue = 0, optional = true)
private int windowSize = 128;

@Argument(fullName = "read-limit", shortName = "read-limit", doc = "Maximum number of reads to encode in a tensor, for 2D models only.", minValue = 0, optional = true)
private int readLimit = 128;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a better name we can use for this arg/variable that matches the other tools that do this (maybe downsample-reads or something) ?

@@ -452,9 +465,13 @@ private String getVariantInfoString(final VariantContext variant) {

private void executePythonCommand() {
final String pythonCommand = String.format(
"vqsr_cnn.score_and_write_batch(args, model, tempFile, %d, %d, '%s')",
"vqsr_cnn.score_and_write_batch(model, tempFile, %d, %d, '%s', '%s', %d, %d, '%s')",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comments in the python code.

Copy link
Collaborator

@cmnbroad cmnbroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One additional comment - we need to think about how to version the models.

@sooheelee
Copy link
Contributor

@lucidtronix @cmnbroad, I see for v4.0.12.0, CNNScoreVariants falls under the EXPERIMENTAL Tool label. When you say the tool will come out of beta, do you mean there will be a change in this label or something else? I'm writing a document that links to the CNN workflow and need to be clear on the status of the workflow. Thanks.

@cmnbroad
Copy link
Collaborator

cmnbroad commented Jan 9, 2019

@sooheelee Right now the tool is marked @Experimental. The goal is to get this tool into production status for 4.1 , with no @Experimental or @Beta tags.

@sooheelee
Copy link
Contributor

Thanks for clarifying @cmnbroad. Have to say skipping @Beta and going directly from experimental to production is unusual. Congratulations.

@lucidtronix
Copy link
Contributor Author

@cmnbroad back to you

Copy link
Collaborator

@cmnbroad cmnbroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, we're getting closer. There is still a fair amount of code cleanup that should be done longer term, especially on the Python side - but I commented on what I think is the minimum for now, keeping just to the inference code. Especially where there are hardcoded magic numbers and values that have to be kept in sync between Java and Python - those should be moved to constants on both sides and commented to make that relationship explicit. Back to @lucidtronix.

@@ -116,30 +118,34 @@
" If you have an older (pre-1.6) version of TensorFlow installed that does not require AVX you may attempt to re-run the tool with the %s argument to bypass this check.\n" +
" Note that such configurations are not officially supported.";

private static final String DEFAULT_ANNOTATION_SET = "best_practices";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unused now, and can be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

transferBatchSize = Math.max(transferBatchSize, MAX_BATCH_SIZE_1D);
inferenceBatchSize = Math.max(inferenceBatchSize, MAX_BATCH_SIZE_1D);
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also require that the transfer size is an integral multiple of the inference size ? It would probably work without that, but would be inefficient, since the python size would wind up doing some smaller batches. If you think thats rigt, the doc for those args should be updated to mention that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this will only impact one batch of python inference I think we wont notice the inefficiency

@@ -273,6 +293,7 @@ public void onTraversalStart() {
pythonExecutor.sendSynchronousCommand("import vqsr_cnn" + NL);

scoreKey = getScoreKeyAndCheckModelAndReadsHarmony();
annotationSetString = this.annotationKeys.toString().replace(" ", "").replace("[", "").replace("]", "");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest the more idiomatic:

Suggested change
annotationSetString = this.annotationKeys.toString().replace(" ", "").replace("[", "").replace("]", "");
annotationSetString = annotationKeys.stream().collect(Collectors.joining(","));

Also, the "joining" arg should be in a constant, with a comment saying that the value has to be kept in sync with the corresponding Python constant that is used to parse these lines. The corresponding constant will have to be added to the python code, maybe in defines.py.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also do the same thing (add symbolic constants) for the hardcoded values like "\t" used by GATKReadToString, getVariantDataString, etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thats is much better, thanks.

Added constants for the comma, tab, equals and semi-colon. Also simplified getVariantInfoString fxn, which was needlessly sending single valued annotations as if they were lists.

tensorType,
annotationSet,
windowSize,
readLimit,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR is #5594, once its reviewed we'll add code to call it to this tool.

}

@Test(groups = {"python"})
public void testInferenceWithWeightsOnly() throws IOException{
public void testInferenceWithWeightOverride() throws IOException {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this test differ from the testInference test ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets rename this to testInferenceWithModelOverride.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It didn't, but I removed the model overide from testInference so now it does and renamed it.

'D':[0.333,0,0.333,0.334], 'X':[0.25,0.25,0.25,0.25], 'N':[0.25,0.25,0.25,0.25]
'K': [0, 0, 0.5, 0.5], 'M': [0.5, 0.5, 0, 0], 'R': [0.5, 0, 0, 0.5], 'Y': [0, 0.5, 0.5, 0], 'S': [0, 0.5, 0, 0.5],
'W': [0.5, 0, 0.5, 0], 'B': [0, 0.333, 0.333, 0.334], 'V': [0.333, 0.333, 0, 0.334], 'H': [0.333, 0.333, 0.334, 0],
'D': [0.333, 0, 0.333, 0.334], 'X': [0.25, 0.25, 0.25, 0.25], 'N': [0.25, 0.25, 0.25, 0.25]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to add comments describing what the values in the dictionaries mean, and rename AMBIGUITY_CODES to reflect that use.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added more comments explaining, but since these are defined by IUPAC as ambiguity codes I want to keep the the dictionary name as that.

'D':[0.333,0,0.333,0.334], 'X':[0.25,0.25,0.25,0.25], 'N':[0.25,0.25,0.25,0.25]
'K': [0, 0, 0.5, 0.5], 'M': [0.5, 0.5, 0, 0], 'R': [0.5, 0, 0, 0.5], 'Y': [0, 0.5, 0.5, 0], 'S': [0, 0.5, 0, 0.5],
'W': [0.5, 0, 0.5, 0], 'B': [0, 0.333, 0.333, 0.334], 'V': [0.333, 0.333, 0, 0.334], 'H': [0.333, 0.333, 0.334, 0],
'D': [0.333, 0, 0.333, 0.334], 'X': [0.25, 0.25, 0.25, 0.25], 'N': [0.25, 0.25, 0.25, 0.25]
}


# Annotation sets
ANNOTATIONS = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be better named ANNOTATIONS_SETS.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


CODE2CIGAR = 'MIDNSHP=XB'
CIGAR2CODE = dict([y, x] for x, y in enumerate(CODE2CIGAR))
CIGAR_CODE = {'M':0, 'I':1, 'D':2, 'N':3, 'S':4}
CIGAR_CODE = {'M': 0, 'I': 1, 'D': 2, 'N': 3, 'S': 4}
CIGAR_REGEX = re.compile("(\d+)([MIDNSHP=XB])")

SKIP_CHAR = '~'
INDEL_CHAR = '*'
SEPARATOR_CHAR = '\t'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SEPARATOR_CHAR is, I think, intended to be the FIFO separator char, and would be better named to reflect that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -59,40 +72,40 @@ def score_and_write_batch(args, model, file_out, batch_size, python_batch_size,

variant_data.append(fifo_data[0] + '\t' + fifo_data[1] + '\t' + fifo_data[2] + '\t' + fifo_data[3])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The index values should be replaced with symbolic constants with the name of the field, and comments added saying these need to kept in sync with the code n the Java side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

variant_types.append(fifo_data[6].strip())

fidx = 7 # 7 Because above we parsed: contig pos ref alt reference_string annotation variant_type
if args.tensor_name in defines.TENSOR_MAPS_2D and len(fifo_data) > fidx:
fidx = 7 # 7 Because above we parsed: contig pos ref alt reference_string annotation variant_type
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a symbolic constant

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@lucidtronix
Copy link
Contributor Author

@cmnbroad Thanks for the review, back to you!

@cmnbroad
Copy link
Collaborator

Looks like we're mostly there, except for adding a model version, which @lucidtronix is working on, and java downsampling, which probably won't make it.

Copy link
Collaborator

@cmnbroad cmnbroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @lucidtronix. We should hold off merging until tests pass on this branch (current failures are unrelated) and master is back up (currently master is failing too). Then if we get a chance we can add Java downsampling if #5594 gets approved.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants