Commit

Updated preprocess.insertions_to_fasta.merge_similar_insertions

Previously, clustering was done using MiniBatchKMeans, but this method produced excessive clusters when only highly similar insertion sequences existed. Therefore, a strategy similar to `extract_enriched_insertions` was adopted, changing the algorithm to one that mixes uniformly distributed random scores into the distance scores before clustering.
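A minimal, self-contained sketch of the strategy described above (not part of the commit; the helper name `cluster_insertions` and the toy sequences are illustrative only): normalized Damerau-Levenshtein distances to a query sequence are padded with uniformly distributed random values before MeanShift clustering, so a tight block of near-identical sequences stays in one cluster instead of being split.

import numpy as np
from rapidfuzz import process
from rapidfuzz.distance import DamerauLevenshtein
from sklearn.cluster import MeanShift


def cluster_insertions(sequences: list[str]) -> list[int]:
    """Cluster insertion sequences by their normalized edit distance to the first sequence."""
    query = sequences[0]
    # Normalized Damerau-Levenshtein distance of every sequence to the query (0.0 = identical).
    _, scores, _ = zip(*process.extract_iter(query, sequences, scorer=DamerauLevenshtein.normalized_distance))
    scores = list(scores)
    # Pad the scores with uniform random values in [0, 1) so MeanShift sees a spread-out
    # background and does not split a tight block of near-identical scores into many clusters.
    rng = np.random.default_rng(1)
    scores.extend(rng.random(max(100, len(scores))))
    labels = MeanShift().fit(np.array(scores).reshape(-1, 1)).labels_.tolist()
    # Keep only the labels of the real sequences; the random padding is discarded.
    return labels[: len(sequences)]


sequences = ["ACGTACGT", "ACGTACGA", "ACGTACGT", "TTTTCCCCGGGGAAAA"]
print(cluster_insertions(sequences))  # the three near-identical sequences are expected to share one label

The same padding trick appears twice in the diff below: once added to `extract_enriched_insertions` and once in the rewritten `_optimize_labels`.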
akikuno committed Jan 18, 2024
1 parent ce86359 commit fb7074c
Showing 1 changed file with 23 additions and 34 deletions.
57 changes: 23 additions & 34 deletions src/DAJIN2/core/preprocess/insertions_to_fasta.py
@@ -6,10 +6,9 @@
 from typing import Generator
 
 import numpy as np
-from sklearn import metrics
 from rapidfuzz import process
 from rapidfuzz.distance import DamerauLevenshtein
-from sklearn.cluster import MeanShift, MiniBatchKMeans
+from sklearn.cluster import MeanShift
 
 from DAJIN2.utils import io, config
 from DAJIN2.utils.cssplits_handler import call_sequence
@@ -18,6 +17,12 @@
 
 import cstag
 
+
+def remove_non_alphabets(s):
+    # Using list comprehension to keep only alphabet characters
+    return "".join([char for char in s if char.isalpha()])
+
+
 ###########################################################
 # Detect insertion sequences
 ###########################################################
@@ -36,11 +41,6 @@ def extract_all_insertions(midsv_sample: Generator, mutation_loci: dict) -> dict
     return dict(insertion_index_sequences_control)
 
 
-def remove_non_alphabets(s):
-    # Using list comprehension to keep only alphabet characters
-    return "".join([char for char in s if char.isalpha()])
-
-
 def extract_enriched_insertions(
     insertions_sample: dict, insertions_control: dict, coverage_sample: int
 ) -> dict[int, dict[str, int]]:
@@ -52,7 +52,13 @@ def extract_enriched_insertions(
         seq_all = [remove_non_alphabets(seq) for seq in ins_sample + ins_control]
         query = seq_all[0]
         _, scores, _ = zip(*process.extract_iter(query, seq_all, scorer=DamerauLevenshtein.normalized_distance))
+
+        # Add random values from 0 to 1
+        scores = list(scores)
+        rng = np.random.default_rng(1)
+        scores.extend(rng.random(max(100, len(scores))))
         labels_all = MeanShift().fit(np.array(scores).reshape(-1, 1)).labels_.tolist()
+
        labels_sample = labels_all[: len(ins_sample)]
         labels_control = labels_all[len(ins_sample) :]
         labels_count_sample = dict(Counter(labels_sample))
@@ -119,30 +125,15 @@ def _group_insertions(insertions, index_grouped) -> dict[dict[str, int]]:
     return insertions_grouped
 
 
-def _get_normalized_scores(query: str, choices: list[str]) -> np.ndarray:
-    seqs, scores, _ = zip(*process.extract_iter(query, choices, scorer=DamerauLevenshtein.normalized_distance))
-    scores_np = np.array(scores).reshape(-1, 1)
-    normalized_scores = (scores_np - scores_np.min()) / (scores_np.max() - scores_np.min())
-    counts = np.array([s.count("|") for s in seqs]).reshape(-1, 1)
-    X = np.concatenate([normalized_scores, counts], axis=1)
-    return X
-
-
-def _optimize_labels(X: np.array) -> list[int]:
-    sample_size = X.shape[0]
-    labels_prev = list(range(sample_size))
-    for i in range(1, sample_size):
-        np.random.seed(seed=1)
-        labels_current = MiniBatchKMeans(n_clusters=i, random_state=1, n_init="auto").fit_predict(X).tolist()
-        silhuette = metrics.silhouette_score(X, labels_current) if i > 1 else 0
-        mutual_info = metrics.adjusted_mutual_info_score(labels_prev, labels_current)
-        # print(i, Counter(labels_current), round(silhuette, 2), round(mutual_info, 2) ) # ! DEBUG
-        if i == 2 and silhuette < 0:
-            return [1] * len(labels_current)
-        if 0.9 < silhuette or 0.9 < mutual_info:
-            return labels_current
-        labels_prev = labels_current
-    return labels_current
+def _optimize_labels(insertion: dict[str, int]) -> list[int]:
+    seq_all = [remove_non_alphabets(seq) for seq in insertion]
+    query = seq_all[0]
+    _, scores, _ = zip(*process.extract_iter(query, seq_all, scorer=DamerauLevenshtein.normalized_distance))
+    scores = list(scores)
+    rng = np.random.default_rng(1)
+    scores.extend(rng.random(max(100, len(scores))))
+    labels_all = MeanShift().fit(np.array(scores).reshape(-1, 1)).labels_.tolist()
+    return labels_all[: len(insertion)]
 
 
 def _get_merged_insertion(insertion: dict[str, int], labels: np.ndarray) -> dict[frozenset, int]:
@@ -169,9 +160,7 @@ def merge_similar_insertions(insertions, mutation_loci) -> dict[dict[frozenset[s
             key, val = list(insertion.items())[0]
             insertions_merged[idx][frozenset([key])] = val
             continue
-        query = list(insertion.keys())[0]
-        X = _get_normalized_scores(query, list(insertion.keys()))
-        labels = _optimize_labels(X)
+        labels = _optimize_labels(insertion)
         insertions_merged[idx] = _get_merged_insertion(insertion, labels)
     return insertions_merged
 