From 80c7afce7c82ef7e60f0807808671ad7b955d6f7 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Fri, 1 Nov 2024 17:23:34 +0000 Subject: [PATCH] allow duplicated matching to the best match --- bean/mapping/GuideEditCounter.py | 40 ++++++++++++---- bean/mapping/_supporting_fn.py | 80 +++++++++++++++++++++++++++++++- bean/mapping/utils.py | 16 +++++++ tests/test_count.py | 2 +- 4 files changed, 127 insertions(+), 11 deletions(-) diff --git a/bean/mapping/GuideEditCounter.py b/bean/mapping/GuideEditCounter.py index 2d7b4ad..e1a1f70 100755 --- a/bean/mapping/GuideEditCounter.py +++ b/bean/mapping/GuideEditCounter.py @@ -1,4 +1,4 @@ -from typing import Tuple, Optional, Sequence, Union +from typing import Tuple, Optional, Sequence, Union, Iterable import gzip import logging from copy import deepcopy @@ -30,6 +30,7 @@ def tqdm(iterable, **kwargs): _write_alignment_matrix, _write_paired_end_reads, revcomp, + find_closest_sequence_index ) logging.basicConfig( @@ -126,12 +127,15 @@ def __init__(self, **kwargs): - kwargs["gstart_reporter"] - self.screen.guides["guide_len"].max() ) + self.map_duplicated_to_best = kwargs["map_duplicated_to_best"] + self.map_duplicated_hamming_threshold = kwargs["map_duplicated_hamming_threshold"] self.count_guide_edits = kwargs["count_guide_edits"] if self.count_guide_edits: self.screen.uns["guide_edit_counts"] = {} self.count_reporter_edits = ( kwargs["count_reporter"] or kwargs["count_guide_reporter_alleles"] ) + self.mask_barcode = kwargs["mask_barcode"] if self.count_reporter_edits: self.screen.uns["edit_counts"] = {} self.gstart_reporter = kwargs["gstart_reporter"] @@ -584,6 +588,7 @@ def _get_guide_counts_bcmatch_semimatch( semimatch, R2_start, ) = self._match_read_to_sgRNA_bcmatch_semimatch(R1_seq, R2_seq) + mapped = False if len(bc_match) == 0: if len(semimatch) == 0: # no guide match if self.keep_intermediate: @@ -597,19 +602,26 @@ def _get_guide_counts_bcmatch_semimatch( r1, r2, outfile_R1_dup_wo_bc, outfile_R2_dup_wo_bc ) self.duplicate_match_wo_barcode += 1 + if self.map_duplicated_to_best: + best_matched_guide_idx = self._find_closest_sequence(R1_seq, R2_seq, semimatch, self.map_duplicated_hamming_threshold) + if best_matched_guide_idx is not None: + self.screen.layers[semimatch_layer][best_matched_guide_idx, 0] += 1 else: # guide match with no barcode match matched_guide_idx = semimatch[0] self.screen.layers[semimatch_layer][matched_guide_idx, 0] += 1 if self.count_guide_edits: guide_allele, _ = self._count_guide_edits(matched_guide_idx, r1) self.semimatch += 1 - elif len(bc_match) >= 2: # duplicate mapping if self.keep_intermediate: _write_paired_end_reads(r1, r2, outfile_R1_dup, outfile_R2_dup) - self.duplicate_match += 1 - + if self.map_duplicated_to_best: + best_matched_guide_idx = self._find_closest_sequence(R1_seq, R2_seq, bc_match, self.map_duplicated_hamming_threshold) + if best_matched_guide_idx is not None: + mapped = True else: # unique barcode match + mapped = True + if mapped: matched_guide_idx = bc_match[0] self.screen.layers[bcmatch_layer][matched_guide_idx, 0] += 1 self.bcmatch += 1 @@ -732,7 +744,7 @@ def get_barcode(self, R1_seq, R2_seq): R2_seq[barcode_start_idx : (barcode_start_idx + self.guide_bc_len)] ) - def _match_read_to_sgRNA_bcmatch_semimatch(self, R1_seq: str, R2_seq: str): + def _match_read_to_sgRNA_bcmatch_semimatch(self, R1_seq: str, R2_seq: str) -> Tuple[int, int, int]: # This should be adjusted for each experimental recipes. bc_start_idx, guide_barcode = self.get_barcode(R1_seq, R2_seq) bc_match_idx = np.array([]) @@ -745,17 +757,27 @@ def _match_read_to_sgRNA_bcmatch_semimatch(self, R1_seq: str, R2_seq: str): _seq_match = np.where( self.mask_sequence(seq) == self.screen.guides.masked_sequence )[0] - _bc_match = np.where( - self.mask_sequence(guide_barcode) == self.screen.guides.masked_barcode - )[0] + if self.mask_barcode: + _bc_match = np.where( + self.mask_sequence(guide_barcode) == self.screen.guides.masked_barcode + )[0] + else: + _bc_match = np.where( + guide_barcode == self.screen.guides.barcode + )[0] bc_match_idx = np.append( bc_match_idx, np.intersect1d(_seq_match, _bc_match) ) semimatch_idx = np.append(semimatch_idx, _seq_match) - return bc_match_idx.astype(int), semimatch_idx.astype(int), bc_start_idx + def _find_closest_sequence(self, R1_seq, R2_seq, ref_guide_indices: Sequence[int], hamming_distance_threshold: int = 0.1,) -> int: + guide_lens = self.screen.guides.iloc[ref_guide_indices].map(len) + query_guide_seqs = [self.get_guide_seq(R1_seq, R2_seq, guide_len) for guide_len in guide_lens] + closest_idx = find_closest_sequence_index(query_guide_seqs, self.screen.guides.sequence.iloc[ref_guide_indices].values.tolist(), hamming_distance_threshold, match_score, mismatch_penalty, allowed_substitution_penalties = {(k, v):0.2 for k, v in self.target_base_edits}) + return ref_guide_indices[closest_idx] + def _get_guide_position_seq_of_read(self, seq): guide_start_idx = self._get_guide_start_idx(seq) if guide_start_idx == -1: diff --git a/bean/mapping/_supporting_fn.py b/bean/mapping/_supporting_fn.py index 62c0f79..fb35681 100755 --- a/bean/mapping/_supporting_fn.py +++ b/bean/mapping/_supporting_fn.py @@ -1,4 +1,4 @@ -from typing import List, Union, Dict, Optional, Tuple +from typing import List, Union, Dict, Optional, Tuple, Iterable import subprocess as sb import numpy as np import pandas as pd @@ -313,3 +313,81 @@ def _multiindex_dict_to_df(input_dict, key_column_names, value_column_name): inplace=True, ) return df + + +def hamming_distance( + seq1: str, + seq2: str, + match_score: float = 0, + mismatch_penalty: float = 1, + allowed_substitutions_penalties: Dict[Tuple[str, str], float] = { + ("A", "G"): 0.2, + ("T", "C"): 0.2, + }, +): + """ + Calculates the Hamming distance between two DNA sequences with different penalties. + + Args: + seq1 (str): The first DNA sequence. + seq2 (str): The second DNA sequence. + match_score (int): Score for a match (default: 0). + mismatch_penalty (int): Penalty for a mismatch (default: 1). + + Returns: + int: The Hamming distance. + """ + + if len(seq1) != len(seq2): + raise ValueError("Sequences must be of equal length.") + + distance = 0 + for i in range(len(seq1)): + if seq1[i] == seq2[i]: + distance += match_score + elif (seq1[i], seq2[i]) in allowed_substitutions_penalties: + distance += allowed_substitutions_penalties[(seq1[i], seq2[i])] + else: + distance += mismatch_penalty + return distance + + +def find_closest_sequence_index( + query_seqs: Iterable[str], + ref_seqs: Iterable[str], + hamming_distance_threshold: int = 0.1, + match_score: float = 0, + mismatch_penalty: float = 1, + allowed_substitutions_penalties: Dict[Tuple[str, str], float] = { + ("A", "G"): 0.2, + ("T", "C"): 0.2, + }, +) -> int: + """ + Find the closest sequence to a query sequence. + + Args: + query_seq (str): The query sequence. + ref_seqs (Iterable[str]): The reference sequences. + hamming_distance_threshold (int, optional): The maximum allowed hamming distance. Defaults to 3. + match_score (int, optional): Score for a match (default: 0). Defaults to 0. + mismatch_penalty (int, optional): Penalty for a mismatch (default: 1). Defaults to 1. + allowed_substitutions_penalties (Dict[Tuple[str, str], float], optional): Allowed substitutions and penalties. Defaults to [("A", "G"):0.2, ("T", "C"):0.2]. + + Returns: + int: The closest sequence pair's index in the input ref_seqs list. + """ + min_distance = float("inf") + closest_seq_index = None + for i, (query_seq, ref_seq) in enumerate(zip(query_seqs, ref_seqs)): + distance = hamming_distance( + query_seq, + ref_seq, + match_score, + mismatch_penalty, + allowed_substitutions_penalties, + ) / len(query_seq) + if distance < min_distance and distance <= hamming_distance_threshold: + min_distance = distance + closest_seq_index = i + return closest_seq_index diff --git a/bean/mapping/utils.py b/bean/mapping/utils.py index 8f220ee..a8796a6 100755 --- a/bean/mapping/utils.py +++ b/bean/mapping/utils.py @@ -244,6 +244,22 @@ def _get_input_parser(parser=None): help="count the matched allele of guide and reporter edit", action="store_true", ) + parser.add_argument( + "--map-duplicated-to-best", + help="When found duplicated mapping allowing for intended edits, map them to the best-matching reads", + action="store_true", + ) + parser.add_argument( + "--map-duplicated-hamming-threshold", + help="When found duplicated mapping allowing for intended edits, map them to the best-matching reads only when hamming distance is less than 5*value*guide_length for intended edit", + type=float, + default=0.1, + ) + parser.add_argument( + "--mask-barcode", + help="Allow intended base edit in the barcode sequence.", + action="store_true", + ) parser.add_argument( "--tiling", help="Specify that the guide library is tiling library without 'n guides per target' design", diff --git a/tests/test_count.py b/tests/test_count.py index 472d451..fae1c94 100755 --- a/tests/test_count.py +++ b/tests/test_count.py @@ -72,7 +72,7 @@ def test_count_samples_dual(): @pytest.mark.order(107) def test_count_samples_bcstart(): - cmd = "bean count-samples -i tests/data/sample_list.csv -b A -f tests/data/test_guide_info.csv -o tests/test_res/var2/ -r --barcode-start-seq=GGAA" + cmd = "bean count-samples -i tests/data/sample_list.csv -b A -f tests/data/test_guide_info.csv -o tests/test_res/var2/ -r --barcode-start-seq=GGAA --map-duplicated-to-best" try: subprocess.check_output( cmd,