From 7927feb0bb4f3091537aaebabd60a441456a3413 Mon Sep 17 00:00:00 2001
From: akikuno
Date: Thu, 8 Feb 2024 06:35:28 +0900
Subject: [PATCH] Update `insertions_to_fasta.py`.

+ Modified the approach to reduce nondeterminism: instead of iterating a `set` or `frozenset`, reads are now subset from a `list` or `tuple` with a seeded `random.sample()`.
+ Refactored `call_consensus_insertion_sequence`.
+ Fixed a bug in `extract_score_and_sequence` to ensure correct appending of scores for `insertions_merged_subset`.
---
 .../core/preprocess/insertions_to_fasta.py | 298 ++++++++++++------
 1 file changed, 198 insertions(+), 100 deletions(-)

diff --git a/src/DAJIN2/core/preprocess/insertions_to_fasta.py b/src/DAJIN2/core/preprocess/insertions_to_fasta.py
index ff62a112..24415377 100644
--- a/src/DAJIN2/core/preprocess/insertions_to_fasta.py
+++ b/src/DAJIN2/core/preprocess/insertions_to_fasta.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import random
 from pathlib import Path
 from itertools import groupby
 from collections import defaultdict, Counter
@@ -11,18 +12,24 @@
 from sklearn.cluster import MeanShift
 
 from DAJIN2.utils import io, config
-from DAJIN2.utils.cssplits_handler import call_sequence
+
+from DAJIN2.utils.cssplits_handler import convert_cssplits_to_cstag
 
 config.set_warnings_ignore()
 
 import cstag
 
 
-def remove_non_alphabets(s):
-    # Using list comprehension to keep only alphabet characters
+def remove_non_alphabets(s: str) -> str:
+    """Convert a cssplits string to a plain DNA sequence."""
     return "".join([char for char in s if char.isalpha()])
 
 
+###########################################################
+# Cluster insertion alleles
+###########################################################
+
+
 def clustering_insertions(insertions_cssplit: list[str]) -> list[int]:
     seq_all = [remove_non_alphabets(seq) for seq in insertions_cssplit]
     query = seq_all[0]
@@ -41,7 +48,7 @@
 ###########################################################
 
 
-def extract_all_insertions(midsv_sample: Generator, mutation_loci: dict) -> dict[int, list[str]]:
+def extract_all_insertions(midsv_sample: Generator, mutation_loci: list[set[str]]) -> dict[int, list[str]]:
     """To extract insertion sequences of **10 base pairs or more** at each index."""
     insertion_index_sequences_control = defaultdict(list)
     for m_sample in midsv_sample:
@@ -96,7 +103,9 @@
     return enriched_insertions
 
 
-def extract_insertions(path_sample: Path, path_control: Path, mutation_loci: dict) -> dict[int, dict[str, int]]:
+def extract_insertions(
+    path_sample: Path, path_control: Path, mutation_loci: list[set[str]]
+) -> dict[int, dict[str, int]]:
     insertions_sample = extract_all_insertions(io.read_jsonl(path_sample), mutation_loci)
     insertions_control = extract_all_insertions(io.read_jsonl(path_control), mutation_loci)
     coverage_sample = io.count_newlines(path_sample)
@@ -109,7 +118,7 @@
 ###########################################################
 
 
-def _group_index_by_consecutive_insertions(mutation_loci: dict[str, set[int]]) -> list[tuple[int]]:
+def group_index_by_consecutive_insertions(mutation_loci: list[set[str]]) -> list[tuple[int]]:
     index = sorted(i for i, m in enumerate(mutation_loci) if "+" in m)
     index_grouped = []
     for _, group in groupby(enumerate(index), lambda i_x: i_x[0] - i_x[1]):
@@ -118,7 +127,9 @@
     return index_grouped
 
 
-def _group_insertions(insertions, index_grouped) -> dict[dict[str, int]]:
+def group_insertions(
+    insertions: dict[int, dict[str, int]], index_grouped: list[tuple[int]]
+) -> dict[tuple[int], dict[str, int]]:
     """
     Insertion bases in consecutive insertion base positions are grouped together in mutation_loci,
     as the insertion site may shift by a few bases.
@@ -129,64 +140,127 @@
         if i not in insertions:
             continue
         insertions_grouped[idx].update(insertions[i])
-    return insertions_grouped
+    return dict(insertions_grouped)
 
 
-def _get_merged_insertion(insertion: dict[str, int], labels: np.ndarray) -> dict[frozenset, int]:
-    insertion_label = [(label, {key: val}) for label, (key, val) in zip(labels, insertion.items())]
+def get_merged_insertion(insertion: dict[str, int], labels: np.ndarray) -> dict[tuple[str], int]:
+    insertion_label = [(label, {seq: count}) for label, (seq, count) in zip(labels, insertion.items())]
     insertion_label.sort(key=lambda x: x[0])
+
     insertion_merged = dict()
     for _, group in groupby(insertion_label, key=lambda x: x[0]):
         group = [g[1] for g in group]
-        seq, count = set(), 0
-        for g in group:
-            key, val = list(g.items())[0]
-            seq.add(key)
-            count += val
-        insertion_merged[frozenset(seq)] = count
+        sequences, counts = set(), 0
+        for seq_count in group:
+            seq, count = next(iter(seq_count.items()))
+            sequences.add(seq)
+            counts += count
+        insertion_merged[tuple(sorted(sequences))] = counts
+
     return insertion_merged
 
 
-def merge_similar_insertions(insertions, mutation_loci) -> dict[dict[frozenset[str], int]]:
-    index_grouped = _group_index_by_consecutive_insertions(mutation_loci)
-    insertions_grouped = _group_insertions(insertions, index_grouped)
-    insertions_merged = defaultdict(dict)
+def merge_similar_insertions(
+    insertions: dict[int, dict[str, int]], mutation_loci: list[set[str]]
+) -> dict[tuple[int], dict[tuple[str], int]]:
+    index_grouped = group_index_by_consecutive_insertions(mutation_loci)
+    insertions_grouped = group_insertions(insertions, index_grouped)
+    insertions_merged = dict()
    for idx, insertion in insertions_grouped.items():
         if len(insertion) == 1:
-            key, val = list(insertion.items())[0]
-            insertions_merged[idx][frozenset([key])] = val
+            seq, count = next(iter(insertion.items()))
+            insertions_merged[idx] = {tuple([seq]): count}
             continue
         labels = clustering_insertions(insertion)
-        insertions_merged[idx] = _get_merged_insertion(insertion, labels)
+        insertions_merged[idx] = get_merged_insertion(insertion, labels)
+
     return insertions_merged
 
 
 ###########################################################
-# Cluster insertion alleles
+# Pre- and post-processing of clustering insertion alleles
 ###########################################################
 
 
-def extract_score_and_sequence(path_sample, insertions_merged) -> list[tuple[list[int], str]]:
+def flatten_keys_to_set(dict_keys: list[tuple[int]]) -> set[int]:
+    flattened_nums = set()
+    for key_tuple in dict_keys:
+        for num in key_tuple:
+            flattened_nums.add(num)
+    return flattened_nums
+
+
+def subset_insertions(
+    insertions_merged: dict[tuple[int], dict[tuple[str], int]], num_subset: int = 100
+) -> dict[tuple[int], dict[tuple[str], int]]:
+    """
+    If the number of sequences exceeds `num_subset`, limit it to `num_subset` elements with fixed (seeded) random sampling.
+    """
+    random.seed(1)
+    insertions_merged_subset = defaultdict(dict)
+    for idx_grouped, insertions in insertions_merged.items():
+        for cs_insertion, counts in insertions.items():
+            if len(cs_insertion) > num_subset:
+                sequences = random.sample(cs_insertion, num_subset)
+            else:
+                sequences = cs_insertion
+            sequences = tuple(remove_non_alphabets(seq) for seq in sequences)
+            insertions_merged_subset[idx_grouped][sequences] = counts
+
+    return dict(insertions_merged_subset)
+
+
+def extract_score_and_sequence(
+    path_sample: Path, insertions_merged: dict[tuple[int], dict[tuple[str], int]]
+) -> list[tuple[list[int], str]]:
     scores = []
     sequences = []
+
+    insertions_merged_subset = subset_insertions(insertions_merged, num_subset=100)
+    set_keys = flatten_keys_to_set(insertions_merged.keys())
+
+    cache_score = defaultdict(int)
     for m in io.read_jsonl(path_sample):
         cssplits = m["CSSPLIT"].split(",")
+        if not any(True for i in set_keys if cssplits[i].startswith("+")):
+            continue
+
         score = [0] * len(insertions_merged)
         sequence = ["N"] * len(insertions_merged)
         for i, (idx_grouped, insertions) in enumerate(insertions_merged.items()):
             for idx in idx_grouped:
                 if not cssplits[idx].startswith("+"):
                     continue
+
                 for seqs, count in insertions.items():
-                    _, distances, _ = zip(
-                        *process.extract_iter(cssplits[idx], seqs, scorer=DamerauLevenshtein.normalized_distance)
-                    )
-                    if any(True for d in distances if d < 0.1):
+                    if cssplits[idx] in seqs:
                         score[i] = count
                         sequence[i] = cssplits[idx]
+                        continue
+
+                if (idx, cssplits[idx]) in cache_score:
+                    score[i] = cache_score[(idx, cssplits[idx])]
+                    sequence[i] = cssplits[idx]
+                    continue
+
+                flag_break = False
+                for seqs, count in insertions_merged_subset[idx_grouped].items():
+                    for _, distance, _ in process.extract_iter(
+                        remove_non_alphabets(cssplits[idx]), seqs, scorer=DamerauLevenshtein.normalized_distance
+                    ):
+                        if distance < 0.1:
+                            score[i] = count
+                            sequence[i] = cssplits[idx]
+                            flag_break = True
+                            break
+                    if flag_break:
+                        cache_score[(idx, cssplits[idx])] = count
+                        break
+
         if any(score):
             scores.append(score)
             sequences.append(",".join(sequence))
+
     return [(score, sequence) for score, sequence in zip(scores, sequences)]
 
 
@@ -204,24 +278,6 @@ def filter_minor_label(
     return labels_filtered, score_seq_filterd
 
 
-def update_labels(d: dict, FASTA_ALLELES: dict) -> dict:
-    """
-    Update labels to avoid duplicating user-specified alleles
-    (insertion1 -> insertion01 -> insertion001...)
-    """
-    user_defined_alleles = set(FASTA_ALLELES)
-    d_values = list(d.values())
-    len_d = len(d_values)
-    digits_up = 0
-    while True:
-        digits = len(str(len_d)) + digits_up
-        d_updated = {f"insertion{i+1:0{digits}}": seq for i, seq in enumerate(d_values)}
-        if user_defined_alleles.isdisjoint(set(d_updated)):
-            break
-        digits_up += 1
-    return d_updated
-
-
 ###########################################################
 # Call consensus
 ###########################################################
@@ -233,28 +289,57 @@ def subset_sequences(sequences, labels, num=1000) -> list[dict]:
     """
     sequences_subset = []
     tmp_sequences = []
+    random.seed(1)
     for sequence, label in zip(sequences, labels):
         tmp_sequences.append({"CSSPLIT": sequence, "LABEL": label})
     tmp_sequences.sort(key=lambda x: x["LABEL"])
     for _, group in groupby(tmp_sequences, key=lambda x: x["LABEL"]):
-        sequences_subset.extend(list(group)[:num])
+        group = list(group)
+        if len(group) > num:
+            sequences_subset += random.sample(group, num)
+        else:
+            sequences_subset += group
     return sequences_subset
 
 
-def _call_percentage(cssplits: list[list[str]]) -> list[dict[str, float]]:
-    """call position weight matrix in defferent loci."""
-    coverage = len(cssplits)
+def call_consensus_insertion_sequence(cssplits: list[list[str]]) -> str:
+    consensus_insertion = []
     cssplits_transposed = (list(cs) for cs in zip(*cssplits))
-    cons_percentage = []
     for cs_transposed in cssplits_transposed:
+        if all(True if cs == "N" else False for cs in cs_transposed):
+            consensus_insertion.append("N")
+            continue
+
+        count_N = sum(1 for cs in cs_transposed if cs == "N")
+
         count_cs = defaultdict(float)
         for cs in cs_transposed:
-            count_cs[cs] += 1 / coverage * 100
-        cons_percentage.append(dict(count_cs))
-    return cons_percentage
+            if cs == "N":
+                continue
+            count_cs[len(cs)] += 1
+
+        consensus_length = max(count_cs, key=count_cs.get)
+        count_consensus = sum(1 for cs in cs_transposed if len(cs) == consensus_length)
+
+        if count_N > count_consensus:
+            consensus_insertion.append("N")
+            continue
+
+        cs_insertion = [cs for cs in cs_transposed if cs != "N" and len(cs) == consensus_length]
+
+        cs_insertion_transposed = (list(cs) for cs in zip(*(cs.split("|") for cs in cs_insertion)))
+        cs_insertion_consensus = []
+        for cs_ins in cs_insertion_transposed:
+            cs_ins_most_common = Counter(cs_ins).most_common()[0][0]
+            cs_insertion_consensus.append(cs_ins_most_common)
+
+        cs_insertion_consensus = "|".join(cs_insertion_consensus)
+        consensus_insertion.append(cs_insertion_consensus)
+
+    return ",".join(consensus_insertion)
 
 
-def _remove_all_n(cons_sequence: dict[int, str]) -> dict[int, str]:
+def remove_all_n(cons_sequence: dict[int, str]) -> dict[int, str]:
     cons_sequence_removed = dict()
     for label, seq in cons_sequence.items():
         if all(True if s == "N" else False for s in seq.split(",")):
@@ -263,19 +348,36 @@
     return cons_sequence_removed
 
 
-def call_consensus_of_insertion(insertion_sequences_subset: list[dict]) -> dict[int, str]:
+def update_labels(d: dict, FASTA_ALLELES: dict) -> dict:
+    """
+    Update labels to avoid duplicating user-specified alleles
+    (insertion1 -> insertion01 -> insertion001...)
+    """
+    user_defined_alleles = set(FASTA_ALLELES)
+    d_values = list(d.values())
+    len_d = len(d_values)
+    digits_up = 0
+    while True:
+        digits = len(str(len_d)) + digits_up
+        d_updated = {f"insertion{i+1:0{digits}}": seq for i, seq in enumerate(d_values)}
+        if user_defined_alleles.isdisjoint(set(d_updated)):
+            break
+        digits_up += 1
+    return d_updated
+
+
+def call_consensus_of_insertion(insertion_sequences_subset: list[dict], FASTA_ALLELES: dict) -> dict[str, str]:
     cons_sequence = dict()
     insertion_sequences_subset.sort(key=lambda x: x["LABEL"])
     for label, group in groupby(insertion_sequences_subset, key=lambda x: x["LABEL"]):
         cssplits = [cs["CSSPLIT"].split(",") for cs in group]
-        cons_per = _call_percentage(cssplits)
-        cons_seq = call_sequence(cons_per, sep=",")
-        cons_sequence[label] = cons_seq
-    return _remove_all_n(cons_sequence)
+        cons_sequence[label] = call_consensus_insertion_sequence(cssplits)
+
+    return update_labels(remove_all_n(cons_sequence), FASTA_ALLELES)
 
 
 def extract_index_of_insertions(
-    insertions: dict[int, dict[str, int]], insertions_merged: dict[dict[frozenset[str], int]]
+    insertions: dict[int, dict[str, int]], insertions_merged: dict[tuple[int], dict[tuple[str], int]]
 ) -> list[int]:
     """`insertions_merged` contains multiple surrounding indices for a single insertion allele.
     Among them, select the one index where the insertion allele is most frequent."""
     index_of_insertions = []
@@ -292,85 +394,81 @@
 ###########################################################
-# generate and save fasta
+# generate cstag and FASTA
 ###########################################################
 
 
-def generate_fasta(cons_sequence, index_of_insertions, sequence) -> dict[int, str]:
-    fasta_insertions = dict()
-    for label, cons_seq in cons_sequence.items():
+def generate_cstag(
+    consensus_of_insertions: dict[str, str], index_of_insertions: list[int], sequence: str
+) -> dict[str, str]:
+    cstag_insertions = dict()
+    for label, cons_seq in consensus_of_insertions.items():
         cons_seq = cons_seq.split(",")
         list_sequence = list(sequence)
         for idx, seq in zip(index_of_insertions, cons_seq):
             if seq == "N":
                 continue
-            list_sequence[idx] = seq
-        fasta_insertions[label] = "".join(list_sequence)
-    return fasta_insertions
+            list_sequence[idx] = convert_cssplits_to_cstag([seq])
+        cstag_insertions[label] = "cs:Z:=" + "".join(list_sequence)
+    return cstag_insertions
 
 
-def save_fasta(TEMPDIR: Path | str, SAMPLE_NAME: str, fasta_insertions: dict) -> None:
-    for header, seq in fasta_insertions.items():
-        Path(TEMPDIR, SAMPLE_NAME, "fasta", f"{header}.fasta").write_text(f">{header}\n{seq}")
+def generate_fasta(cstag_insertions: dict[str, str]) -> dict[str, str]:
+    fasta_insertions = dict()
+    for label, cs_tag in cstag_insertions.items():
+        fasta_insertions[label] = cstag.to_sequence(cs_tag)
+    return fasta_insertions
 
 
 ###########################################################
-# generate and save as HTML and PDF
+# Save cstag (HTML) and fasta
 ###########################################################
 
 
-def generate_cstag(cons_sequence, index_of_insertions, sequence) -> dict[int, str]:
-    cstag_insertions = dict()
-    for label, cons_seq in cons_sequence.items():
-        cons_seq = cons_seq.split(",")
-        list_sequence = list(sequence)
-        for idx, seq in zip(index_of_insertions, cons_seq):
-            if seq == "N":
-                continue
-            seq = seq.lower()
-            list_sequence[idx] = f"+{seq}="
-        cstag_insertions[label] = "cs:Z:=" + "".join(list_sequence)
-    return cstag_insertions
-
-
-def save_html(TEMPDIR: Path, SAMPLE_NAME: str, cstag_insertions: dict) -> None:
+def save_html(TEMPDIR: Path, SAMPLE_NAME: str, cstag_insertions: dict[str, str]) -> None:
     for header, cs_tag in cstag_insertions.items():
         html = cstag.to_html(cs_tag, f"{SAMPLE_NAME} {header}")
         Path(TEMPDIR, "report", "HTML", SAMPLE_NAME, f"{header}.html").write_text(html)
 
 
+def save_fasta(TEMPDIR: Path | str, SAMPLE_NAME: str, fasta_insertions: dict) -> None:
+    for header, seq in fasta_insertions.items():
+        Path(TEMPDIR, SAMPLE_NAME, "fasta", f"{header}.fasta").write_text(f">{header}\n{seq}\n")
+
+
 ###########################################################
 # main
 ###########################################################
 
 
 def generate_insertion_fasta(TEMPDIR, SAMPLE_NAME, CONTROL_NAME, FASTA_ALLELES) -> None:
-    path_sample = Path(TEMPDIR, SAMPLE_NAME, "midsv", "control.json")
-    path_control = Path(TEMPDIR, CONTROL_NAME, "midsv", "control.json")
-    sequence = FASTA_ALLELES["control"]
-    mutation_loci = io.load_pickle(Path(TEMPDIR, SAMPLE_NAME, "mutation_loci", "control.pickle"))
-    insertions = extract_insertions(path_sample, path_control, mutation_loci)
+    PATH_SAMPLE = Path(TEMPDIR, SAMPLE_NAME, "midsv", "control.json")
+    PATH_CONTROL = Path(TEMPDIR, CONTROL_NAME, "midsv", "control.json")
+    SEQUENCE = FASTA_ALLELES["control"]
+    MUTATION_LOCI = io.load_pickle(Path(TEMPDIR, SAMPLE_NAME, "mutation_loci", "control.pickle"))
+
+    insertions = extract_insertions(PATH_SAMPLE, PATH_CONTROL, MUTATION_LOCI)
     if insertions == dict():
         return None
-    insertions_merged = merge_similar_insertions(insertions, mutation_loci)
-    insertions_scores_sequences = extract_score_and_sequence(path_sample, insertions_merged)
+
+    insertions_merged = merge_similar_insertions(insertions, MUTATION_LOCI)
+    insertions_scores_sequences = extract_score_and_sequence(PATH_SAMPLE, insertions_merged)
     labels = clustering_insertions([cssplit for _, cssplit in insertions_scores_sequences])
     labels_filtered, insertion_scores_sequences_filtered = filter_minor_label(
-        path_sample, labels, insertions_scores_sequences, threshold=0.5
+        PATH_SAMPLE, labels, insertions_scores_sequences, threshold=0.5
     )
+
     insertion_sequences_subset = subset_sequences(
         [seq for _, seq in insertion_scores_sequences_filtered], labels_filtered, num=1000
     )
-    consensus_of_insertions = call_consensus_of_insertion(insertion_sequences_subset)
+
+    consensus_of_insertions = call_consensus_of_insertion(insertion_sequences_subset, FASTA_ALLELES)
     if consensus_of_insertions == dict():
         """
         If there is no insertion sequence, return None
         It is possible when all insertion sequence annotated as `N` that is filtered out
         """
         return None
-    consensus_of_insertions = update_labels(consensus_of_insertions, FASTA_ALLELES)
     index_of_insertions = extract_index_of_insertions(insertions, insertions_merged)
-    fasta_insertions = generate_fasta(consensus_of_insertions, index_of_insertions, sequence)
-    cstag_insertions = generate_cstag(consensus_of_insertions, index_of_insertions, sequence)
+    cstag_insertions = generate_cstag(consensus_of_insertions, index_of_insertions, SEQUENCE)
+    fasta_insertions = generate_fasta(cstag_insertions)
+
     save_fasta(TEMPDIR, SAMPLE_NAME, fasta_insertions)
     save_html(TEMPDIR, SAMPLE_NAME, cstag_insertions)
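
`group_index_by_consecutive_insertions` relies on a classic `groupby` idiom: consecutive positions keep a constant difference between their enumeration index and their value, so they fall into the same group. A minimal standalone sketch (the index values are made up):

```python
# Minimal sketch of the consecutive-grouping idiom used in
# group_index_by_consecutive_insertions(). The index values are hypothetical.
from itertools import groupby

index = [3, 4, 5, 10, 11, 20]  # pre-sorted insertion loci

index_grouped = [
    tuple(x for _, x in group)
    for _, group in groupby(enumerate(index), lambda i_x: i_x[0] - i_x[1])
]
print(index_grouped)  # [(3, 4, 5), (10, 11), (20,)]
```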
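On the determinism change: iterating a `set` or `frozenset` of strings can differ between interpreter runs because string hashing is salted (see `PYTHONHASHSEED`), so any subset built from set iteration order is not reproducible. Sampling from an ordered container with a fixed seed, as `subset_insertions` and `subset_sequences` now do, always yields the same subset. A minimal sketch, not part of the patch, with hypothetical read values:

```python
# Minimal sketch (not part of the patch): deterministic subsetting.
import random

def subset_reads(reads: tuple[str, ...], num_subset: int = 3) -> tuple[str, ...]:
    """Same input always yields the same subset, run after run."""
    random.seed(1)  # fixed seed, as in subset_insertions()/subset_sequences()
    if len(reads) > num_subset:
        return tuple(random.sample(reads, num_subset))
    return reads

reads = ("AAAT", "CCCG", "GGGA", "TTTC", "ACGT")  # hypothetical reads
print(subset_reads(reads))  # identical output on every run
# Iterating set(reads) instead could produce a different order per run,
# because str hashes are randomized unless PYTHONHASHSEED is fixed.
```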
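The fuzzy-match criterion in `extract_score_and_sequence` counts a read as supporting a merged insertion when any subsampled cluster member lies within 10% normalized Damerau-Levenshtein distance of the read's insertion. A self-contained sketch of that criterion (the sequences are made up; the rapidfuzz calls mirror the ones in the patch):

```python
# Minimal sketch of the distance-threshold check (sequences are made up).
from rapidfuzz import process
from rapidfuzz.distance import DamerauLevenshtein

query = "ACGTACGTACGT"                         # insertion seen in one read
candidates = ["ACGTACGTACGA", "TTTTTTTTTTTT"]  # subsampled cluster members

# process.extract_iter yields (choice, distance, index) per candidate;
# one substitution over 12 bases gives 1/12 ~= 0.083, under the 0.1 cutoff.
is_supporting = any(
    distance < 0.1
    for _, distance, _ in process.extract_iter(
        query, candidates, scorer=DamerauLevenshtein.normalized_distance
    )
)
print(is_supporting)  # True
```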
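Finally, `update_labels` widens the zero-padding of generated allele names until none of them collides with a user-supplied allele. A usage sketch (the function body is copied from the patch; the inputs are hypothetical):

```python
# Usage sketch for update_labels(); body copied from the patch above.
def update_labels(d: dict, FASTA_ALLELES: dict) -> dict:
    user_defined_alleles = set(FASTA_ALLELES)
    d_values = list(d.values())
    len_d = len(d_values)
    digits_up = 0
    while True:
        digits = len(str(len_d)) + digits_up
        d_updated = {f"insertion{i+1:0{digits}}": seq for i, seq in enumerate(d_values)}
        if user_defined_alleles.isdisjoint(set(d_updated)):
            break
        digits_up += 1
    return d_updated

# "insertion1" is already taken by the user, so padding widens to two digits.
FASTA_ALLELES = {"control": "ACGT", "insertion1": "ACGTA"}  # hypothetical input
print(update_labels({0: "seqA", 1: "seqB"}, FASTA_ALLELES))
# {'insertion01': 'seqA', 'insertion02': 'seqB'}
```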