diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/ExtractVariantAnnotations.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/ExtractVariantAnnotations.java index dc98d99072e..46fcac10bd5 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/ExtractVariantAnnotations.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/ExtractVariantAnnotations.java @@ -295,21 +295,31 @@ protected void nthPassApply(final VariantContext variant, variant, featureContext, unlabeledDataReservoir != null); final boolean isVariantExtracted = !metadata.isEmpty(); if (isVariantExtracted) { - final boolean isUnlabeled = metadata.stream().map(Triple::getRight).allMatch(Set::isEmpty); - if (!isUnlabeled) { - addExtractedVariantToData(data, variant, metadata); - writeExtractedVariantToVCF(variant, metadata); - } else { - // Algorithm R for reservoir sampling: https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm - if (unlabeledIndex < maximumNumberOfUnlabeledVariants) { - addExtractedVariantToData(unlabeledDataReservoir, variant, metadata); - } else { - final int j = rng.nextInt(unlabeledIndex); - if (j < maximumNumberOfUnlabeledVariants) { - setExtractedVariantInData(unlabeledDataReservoir, variant, metadata, j); + // metadata may contain a mix of labeled and unlabeled alleles (e.g., when extracting unlabeled variants in allele-specific mode) + // which we separate here accordingly + final List, VariantType, TreeSet>> labeledMetadata = metadata.stream() + .filter(m -> !m.getRight().isEmpty()) + .collect(Collectors.toList()); + if (!labeledMetadata.isEmpty()) { + addExtractedVariantToData(data, variant, labeledMetadata); + writeExtractedVariantToVCF(variant, labeledMetadata); + } + if (unlabeledDataReservoir != null) { + final List, VariantType, TreeSet>> unlabeledMetadata = metadata.stream() + .filter(m -> m.getRight().isEmpty()) + .collect(Collectors.toList()); + if (!unlabeledMetadata.isEmpty()) { + // Algorithm R for reservoir sampling: https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm + if (unlabeledIndex < maximumNumberOfUnlabeledVariants) { + addExtractedVariantToData(unlabeledDataReservoir, variant, unlabeledMetadata); + } else { + final int j = rng.nextInt(unlabeledIndex); + if (j < maximumNumberOfUnlabeledVariants) { + setExtractedVariantInData(unlabeledDataReservoir, variant, unlabeledMetadata, j); + } } + unlabeledIndex++; } - unlabeledIndex++; } } } @@ -359,7 +369,7 @@ private void writeUnlabeledAnnotationsToHDF5() { logger.info(String.format("Extracted unlabeled annotations for %d variants of type %s.", unlabeledDataReservoir.getVariantTypeFlat().stream().mapToInt(t -> t == variantType ? 1 : 0).sum(), variantType)); } - logger.info(String.format("Extracted unlabeled annotations for %s total variants.", unlabeledDataReservoir.size())); + logger.info(String.format("Extracted unlabeled annotations for %s total records.", unlabeledDataReservoir.size())); logger.info("Writing unlabeled annotations..."); // TODO coordinate sort diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/LabeledVariantAnnotationsWalker.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/LabeledVariantAnnotationsWalker.java index b90c2f91a64..11aa1680f08 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/LabeledVariantAnnotationsWalker.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/LabeledVariantAnnotationsWalker.java @@ -276,7 +276,7 @@ void writeAnnotationsToHDF5() { logger.info(String.format("Extracted annotations for %d variants labeled as %s.", data.isLabelFlat(label).stream().mapToInt(b -> b ? 1 : 0).sum(), label)); } - logger.info(String.format("Extracted annotations for %s total variants.", data.size())); + logger.info(String.format("Extracted annotations for %s total records.", data.size())); logger.info("Writing annotations..."); data.writeHDF5(outputAnnotationsFile, omitAllelesInHDF5); @@ -357,10 +357,8 @@ private TreeSet findMatchingResourceLabels(final VariantContext vc, for (final FeatureInput resource : resources) { final List resourceVCs = featureContext.getValues(resource, featureContext.getInterval().getStart()); for (final VariantContext resourceVC : resourceVCs) { - if (useASAnnotations && !doAllelesMatch(vc.getReference(), altAllele, resourceVC)) { - continue; - } - if (isMatchingVariant(vc, resourceVC, !doNotTrustAllPolymorphic, resourceMatchingStrategy)) { + if ((!useASAnnotations && isMatchingVariant(vc, resourceVC, !doNotTrustAllPolymorphic, resourceMatchingStrategy)) || + (useASAnnotations && doAllelesMatch(vc.getReference(), altAllele, resourceVC))) { resource.getTagAttributes().entrySet().stream() .filter(e -> e.getValue().equals("true")) .map(Map.Entry::getKey)