diff --git a/RELEASE_NOTES.rst b/RELEASE_NOTES.rst index aa6c651a..946d3cb6 100644 --- a/RELEASE_NOTES.rst +++ b/RELEASE_NOTES.rst @@ -12,7 +12,7 @@ Functionality Changes: Misc: -* mergeutils: function GetMinHarmonizedRecords was transformed into GetRecordComparabilityAndIncrement, which allows the caller +* mergeutils: function GetMinHarmonizedRecords was transformed into GetIncrementAndComparability, which allows the caller to define custom predicate that decides whether records are comparable. 4.0.2 diff --git a/trtools/compareSTR/compareSTR.py b/trtools/compareSTR/compareSTR.py index 2cc56288..77fef1ff 100644 --- a/trtools/compareSTR/compareSTR.py +++ b/trtools/compareSTR/compareSTR.py @@ -899,8 +899,8 @@ def main(args): range(len(current_records))] # increments contains information about which record should be # skipped in next iteration - increment, comparable = mergeutils.GetRecordComparabilityAndIncrement(harmonized_records, chroms, - handle_overlaps) + increment, comparable = mergeutils.GetIncrementAndComparability(harmonized_records, chroms, + handle_overlaps) if args.verbose: mergeutils.DebugPrintRecordLocations(current_records, increment) if mergeutils.CheckMin(increment): return 1 diff --git a/trtools/mergeSTR/mergeSTR.py b/trtools/mergeSTR/mergeSTR.py index 7e7bf7d7..142e7e92 100644 --- a/trtools/mergeSTR/mergeSTR.py +++ b/trtools/mergeSTR/mergeSTR.py @@ -624,12 +624,17 @@ def main(args: Any) -> int: ", e.g.: bcftools reheader -f hg19.fa.fai -o myvcf-readher.vcf.gz myvcf.vcf.gz") return 1 harmonized_records = HarmonizeIfNotNone(current_records, vcftype) - is_min = mergeutils.GetMinHarmonizedRecords(harmonized_records, chroms) - if args.verbose: mergeutils.DebugPrintRecordLocations(current_records, is_min) - if mergeutils.CheckMin(is_min): return 1 - MergeRecords(vcfreaders, vcftype, num_samples, harmonized_records, is_min, vcfw, useinfo, + + # mergeSTR doesnt provide custom comparability handler. By default, only the increment is necessary to decide + # which records should be merged during single iteration. This is because the merge is based on the position + # of the records. If this behaviour changes in the future, custom mergability handler will have to be created. + increment, _ = mergeutils.GetIncrementAndComparability(harmonized_records, chroms) + + if args.verbose: mergeutils.DebugPrintRecordLocations(current_records, increment) + if mergeutils.CheckMin(increment): return 1 + MergeRecords(vcfreaders, vcftype, num_samples, harmonized_records, increment, vcfw, useinfo, useformat, format_type) - current_records = mergeutils.GetNextRecords(vcfreaders, current_records, is_min) + current_records = mergeutils.GetNextRecords(vcfreaders, current_records, increment) done = mergeutils.DoneReading(current_records) return 0 diff --git a/trtools/utils/mergeutils.py b/trtools/utils/mergeutils.py index 817d0028..92bcfb3b 100644 --- a/trtools/utils/mergeutils.py +++ b/trtools/utils/mergeutils.py @@ -16,6 +16,7 @@ CYVCF_RECORD = cyvcf2.Variant CYVCF_READER = cyvcf2.VCF +COMPARABILITY_CALLBACK = Callable[[List[Optional[trh.TRRecord]], List[int], int], Union[bool, List[bool]]] def LoadReaders(vcffiles: List[str], region: Optional[str] = None) -> List[CYVCF_READER]: @@ -248,13 +249,18 @@ def GetMinRecords(record_list: List[Optional[trh.TRRecord]], chroms: List[str]) return [CheckPos(r, chroms[min_chrom], min_pos) for r in record_list] +def default_callback(records: List[trh.TRRecord], chrom_order: List[int], min_chrom_index: int) -> bool: + return True + + +def GetIncrementAndComparability(record_list: List[Optional[trh.TRRecord]], + chroms: List[str], + overlap_callback: COMPARABILITY_CALLBACK = default_callback) \ + -> Tuple[List[bool], Union[bool, List[bool]]]: -def GetRecordComparabilityAndIncrement(record_list: List[Optional[trh.TRRecord]], - chroms: List[str], - overlap_callback: Callable[[List[Optional[trh.TRRecord]], List[int], int], bool]) \ - -> Tuple[List[bool], bool]: r"""Get list that says which records should be skipped in the next - iteration, and whether they are all comparable with each other + iteration (increment), and whether they are all comparable / mergable + The value of increment elements is determined by the (harmonized) position of corresponding records Parameters @@ -265,7 +271,7 @@ def GetRecordComparabilityAndIncrement(record_list: List[Optional[trh.TRRecord]] chroms : list of str Ordered list of all chromosomes - overlap_callback: Callable[[List[Optional[trh.TRRecord]], List[int], int], bool] + overlap_callback: Callable[[List[Optional[trh.TRRecord]], List[int], int], Union[bool, List[bool]] Function that calculates whether the records are comparable Returns @@ -273,7 +279,7 @@ def GetRecordComparabilityAndIncrement(record_list: List[Optional[trh.TRRecord]] increment : list of bool List or bools, where items are set to True when the record at the index of the item should be skipped during VCF file comparison. - comparable: bool + comparable: bool or list of bool Value, that determines whether current records are comparable / mergable, depending on the callback """ chrom_order = [np.inf if r is None else chroms.index(r.chrom) for r in record_list] diff --git a/trtools/utils/tests/test_mergeutils.py b/trtools/utils/tests/test_mergeutils.py index ceda5684..b36c985f 100644 --- a/trtools/utils/tests/test_mergeutils.py +++ b/trtools/utils/tests/test_mergeutils.py @@ -110,27 +110,27 @@ def comp_callback_false(x, y, z): pair = [DummyHarmonizedRecord("chr1", 20), DummyHarmonizedRecord("chr1", 20)] - assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_true) == ([True, True], True) + assert mergeutils.GetIncrementAndComparability(pair, chromosomes, comp_callback_true) == ([True, True], True) # these two test cases show that second result of GetRecordComparabilityAndIncrement is # entirely dependant on the callback pair = [DummyHarmonizedRecord("chr1", 21), DummyHarmonizedRecord("chr1", 20)] - assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_false) == ([False, True], False) + assert mergeutils.GetIncrementAndComparability(pair, chromosomes, comp_callback_false) == ([False, True], False) pair = [DummyHarmonizedRecord("chr1", 21), DummyHarmonizedRecord("chr1", 20)] - assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_true) == ([False, True], True) + assert mergeutils.GetIncrementAndComparability(pair, chromosomes, comp_callback_true) == ([False, True], True) pair = [DummyHarmonizedRecord("chr2", 20), DummyHarmonizedRecord("chr1", 20)] - assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_false) == ([False, True], False) + assert mergeutils.GetIncrementAndComparability(pair, chromosomes, comp_callback_false) == ([False, True], False) pair = [DummyHarmonizedRecord("chr1", 20), DummyHarmonizedRecord("chr1", 21)] - assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_true) == ([True, False], True) + assert mergeutils.GetIncrementAndComparability(pair, chromosomes, comp_callback_true) == ([True, False], True) pair = [None, None] - assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_false) == ([False, False], False) + assert mergeutils.GetIncrementAndComparability(pair, chromosomes, comp_callback_false) == ([False, False], False) pair = [DummyHarmonizedRecord("chr1", 20), None] - assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_false) == ([True, False], False) + assert mergeutils.GetIncrementAndComparability(pair, chromosomes, comp_callback_false) == ([True, False], False) pair = [None, DummyHarmonizedRecord("chr1", 20)] - assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_false) == ([False, True], False) + assert mergeutils.GetIncrementAndComparability(pair, chromosomes, comp_callback_false) == ([False, True], False)