From fb7074cab9d9e4e3d293cb5487a3525a5faf06fd Mon Sep 17 00:00:00 2001
From: akikuno
Date: Thu, 18 Jan 2024 13:31:17 +0900
Subject: [PATCH] Updated
 `preprocess.insertions_to_fasta.merge_similar_insertions`

Previously, clustering was done with MiniBatchKMeans, but this method
over-clustered when only highly similar insertion sequences were present.
This patch adopts the same strategy as `extract_enriched_insertions`: the
pairwise distance scores are mixed with uniformly distributed random scores
before clustering with MeanShift, so a tight group of near-identical
insertions is no longer split into spurious clusters.
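As a rough illustration of the new strategy (a self-contained sketch with
made-up toy sequences, not code from this patch):

    import numpy as np
    from rapidfuzz import process
    from rapidfuzz.distance import DamerauLevenshtein
    from sklearn.cluster import MeanShift

    # Toy insertion sequences: three near-identical reads and one outlier.
    seqs = ["ACGTACGT", "ACGTACGA", "ACGTACGT", "TTTTCCCC"]
    query = seqs[0]
    _, scores, _ = zip(*process.extract_iter(query, seqs, scorer=DamerauLevenshtein.normalized_distance))

    # Pad the observed distances with uniform random values in [0, 1) so that
    # MeanShift sees a full-range background; a clump of near-zero distances
    # then collapses into one cluster instead of being over-split.
    scores = list(scores)
    rng = np.random.default_rng(1)
    scores.extend(rng.random(max(100, len(scores))))

    labels = MeanShift().fit(np.array(scores).reshape(-1, 1)).labels_
    print(labels[: len(seqs)])  # labels for the real sequences only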
---
 .../core/preprocess/insertions_to_fasta.py | 57 ++++++++-----------
 1 file changed, 23 insertions(+), 34 deletions(-)

diff --git a/src/DAJIN2/core/preprocess/insertions_to_fasta.py b/src/DAJIN2/core/preprocess/insertions_to_fasta.py
index f17a0a2b..dd2fae77 100644
--- a/src/DAJIN2/core/preprocess/insertions_to_fasta.py
+++ b/src/DAJIN2/core/preprocess/insertions_to_fasta.py
@@ -6,10 +6,9 @@
 from typing import Generator
 
 import numpy as np
-from sklearn import metrics
 from rapidfuzz import process
 from rapidfuzz.distance import DamerauLevenshtein
-from sklearn.cluster import MeanShift, MiniBatchKMeans
+from sklearn.cluster import MeanShift
 
 from DAJIN2.utils import io, config
 from DAJIN2.utils.cssplits_handler import call_sequence
@@ -18,6 +17,12 @@
 
 import cstag
 
+
+def remove_non_alphabets(s):
+    # Using list comprehension to keep only alphabet characters
+    return "".join([char for char in s if char.isalpha()])
+
+
 ###########################################################
 # Detect insertion sequences
 ###########################################################
@@ -36,11 +41,6 @@ def extract_all_insertions(midsv_sample: Generator, mutation_loci: dict) -> dict
     return dict(insertion_index_sequences_control)
 
 
-def remove_non_alphabets(s):
-    # Using list comprehension to keep only alphabet characters
-    return "".join([char for char in s if char.isalpha()])
-
-
 def extract_enriched_insertions(
     insertions_sample: dict, insertions_control: dict, coverage_sample: int
 ) -> dict[int, dict[str, int]]:
@@ -52,7 +52,13 @@
         seq_all = [remove_non_alphabets(seq) for seq in ins_sample + ins_control]
         query = seq_all[0]
         _, scores, _ = zip(*process.extract_iter(query, seq_all, scorer=DamerauLevenshtein.normalized_distance))
+
+        # Add random values from 0 to 1
+        scores = list(scores)
+        rng = np.random.default_rng(1)
+        scores.extend(rng.random(max(100, len(scores))))
         labels_all = MeanShift().fit(np.array(scores).reshape(-1, 1)).labels_.tolist()
+
         labels_sample = labels_all[: len(ins_sample)]
         labels_control = labels_all[len(ins_sample) :]
         labels_count_sample = dict(Counter(labels_sample))
@@ -119,30 +125,15 @@ def _group_insertions(insertions, index_grouped) -> dict[dict[str, int]]:
     return insertions_grouped
 
 
-def _get_normalized_scores(query: str, choices: list[str]) -> np.ndarray:
-    seqs, scores, _ = zip(*process.extract_iter(query, choices, scorer=DamerauLevenshtein.normalized_distance))
-    scores_np = np.array(scores).reshape(-1, 1)
-    normalized_scores = (scores_np - scores_np.min()) / (scores_np.max() - scores_np.min())
-    counts = np.array([s.count("|") for s in seqs]).reshape(-1, 1)
-    X = np.concatenate([normalized_scores, counts], axis=1)
-    return X
-
-
-def _optimize_labels(X: np.array) -> list[int]:
-    sample_size = X.shape[0]
-    labels_prev = list(range(sample_size))
-    for i in range(1, sample_size):
-        np.random.seed(seed=1)
-        labels_current = MiniBatchKMeans(n_clusters=i, random_state=1, n_init="auto").fit_predict(X).tolist()
-        silhuette = metrics.silhouette_score(X, labels_current) if i > 1 else 0
-        mutual_info = metrics.adjusted_mutual_info_score(labels_prev, labels_current)
-        # print(i, Counter(labels_current), round(silhuette, 2), round(mutual_info, 2) )  # ! DEBUG
-        if i == 2 and silhuette < 0:
-            return [1] * len(labels_current)
-        if 0.9 < silhuette or 0.9 < mutual_info:
-            return labels_current
-        labels_prev = labels_current
-    return labels_current
+def _optimize_labels(insertion: dict[str, int]) -> list[int]:
+    seq_all = [remove_non_alphabets(seq) for seq in insertion]
+    query = seq_all[0]
+    _, scores, _ = zip(*process.extract_iter(query, seq_all, scorer=DamerauLevenshtein.normalized_distance))
+    scores = list(scores)
+    rng = np.random.default_rng(1)
+    scores.extend(rng.random(max(100, len(scores))))
+    labels_all = MeanShift().fit(np.array(scores).reshape(-1, 1)).labels_.tolist()
+    return labels_all[: len(insertion)]
 
 
 def _get_merged_insertion(insertion: dict[str, int], labels: np.ndarray) -> dict[frozenset, int]:
@@ -169,9 +160,7 @@ def merge_similar_insertions(insertions, mutation_loci) -> dict[dict[frozenset[s
             key, val = list(insertion.items())[0]
             insertions_merged[idx][frozenset([key])] = val
             continue
-        query = list(insertion.keys())[0]
-        X = _get_normalized_scores(query, list(insertion.keys()))
-        labels = _optimize_labels(X)
+        labels = _optimize_labels(insertion)
         insertions_merged[idx] = _get_merged_insertion(insertion, labels)
     return insertions_merged
 
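Note for reviewers: a quick hand-check of the new `_optimize_labels`
(a hypothetical snippet, not part of the patch; the exact labels depend on
the bandwidth MeanShift estimates from the data):

    from DAJIN2.core.preprocess.insertions_to_fasta import _optimize_labels

    # Two near-identical insertions plus one distant outlier (toy counts).
    insertion = {"ACGTACGTAA": 12, "ACGTACGTAT": 9, "GGGGCCCCGG": 3}
    labels = _optimize_labels(insertion)
    print(labels)  # the near-identical pair is expected to share one label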