Skip to content

Commit

Permalink
+ Update preprocess.insertions_to_fasta.py which detects unintended…
Browse files Browse the repository at this point in the history
… insertion alleles.

  + `clustering_insertions`: To accelerate MeanShift clustering, set `bin_seeding=True`. Additionally, because clustering decoys without variation becomes extremely slow, we have switched to using decoys that include slight variations.
  + `extract_unique_insertions`: Within `unintended insertion alleles`, alleles similar to the `intended allele` provided by the user are now excluded.
    + The similarity is defined as there being differences of more than 10 bases
  • Loading branch information
akikuno committed May 1, 2024
1 parent 524a195 commit d8bbd9f
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions src/DAJIN2/core/preprocess/insertions_to_fasta.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rapidfuzz import process
from rapidfuzz.distance import DamerauLevenshtein

from sklearn.cluster import HDBSCAN
from sklearn.cluster import MeanShift

from DAJIN2.core.preprocess.mapping import to_sam
from DAJIN2.utils import io, config
Expand All @@ -32,17 +32,23 @@ def remove_non_alphabets(cssplits: str) -> str:
###########################################################


def clustering_insertions(insertions_cssplit: list[str], n_random: int = 1000) -> list[int]:
def clustering_insertions(insertions_cssplit: list[str], n_decoy: int = 1000) -> list[int]:
seq_all = [remove_non_alphabets(seq) for seq in insertions_cssplit]
query = seq_all[0]
_, distances, _ = zip(*process.extract_iter(query, seq_all, scorer=DamerauLevenshtein.normalized_distance))

# Add random values from 0 to 1 to prevent clustering due to minor differences.
# By adding upper (1) and lower (0) limits, we prevent errors where minor differences are clustered (e.g., 0.1 and 0.2 becoming separate clusters).
distances = list(distances)

# As MeanShift becomes extremely slow with values that have no variation, we add appropriate variation.

rng = np.random.default_rng(1)
distances.extend(rng.random(n_random))
distances.extend(rng.uniform(0.0, 0.001, n_decoy // 2).tolist())
distances.extend(rng.uniform(0.999, 1.0, 500).tolist())

return HDBSCAN().fit(np.array(distances).reshape(-1, 1)).labels_.tolist()
# Currently, MeanShift is the preferred algorithm. Other clustering methods like HDBSCAN, OPTICS, and Birch tend to produce overly fine clusters, even though they operate faster than MeanShift.

return MeanShift(bin_seeding=True).fit_predict(np.array(distances).reshape(-1, 1)).tolist()[:len(seq_all)]


###########################################################
Expand Down Expand Up @@ -75,19 +81,11 @@ def extract_enriched_insertions(
labels_all = clustering_insertions(ins_sample + ins_control)
labels_sample = labels_all[: len(ins_sample)]
labels_control = labels_all[len(ins_sample) : len(ins_sample) + len(ins_control)]
labels_random = labels_all[len(ins_sample) + len(ins_control) :]

labels_count_sample = dict(Counter(labels_sample))
labels_count_control = dict(Counter(labels_control))
labels_count_random = dict(Counter(labels_random))

to_delete = set()
# To remove labels containing a high proportion of randoms (5% or more, or 5 or more reads).
threshold_random = max(5, int(len(labels_random) * 0.05))
for label, count_random in labels_count_random.items():
if count_random > threshold_random:
to_delete.add(label)

# To remove labels containing a high proportion of controls (5% or more, or 5 or more reads).
threshold_control = max(5, int(len(labels_control) * 0.05))
for label, count_control in labels_count_control.items():
Expand Down Expand Up @@ -197,7 +195,7 @@ def merge_similar_insertions(
seq, count = next(iter(insertion.items()))
insertions_merged[idx] = {tuple([seq]): count}
continue
labels = clustering_insertions(insertion, n_random=0)
labels = clustering_insertions(insertion, n_decoy=1000)
insertions_merged[idx] = get_merged_insertion(insertion, labels)

return remove_minor_groups(insertions_merged, coverage, threshold)
Expand Down Expand Up @@ -502,11 +500,14 @@ def generate_fasta(cstag_insertions: dict[str, str]) -> dict[str, str]:


def extract_unique_insertions(FASTA_ALLELES: dict[str, str], fasta_insertions: dict[str, str]) -> dict[str, str]:
"""Extract unique insertion alleles if they are dissimilar to the FASTA_ALLELES input by the user"""
"""
Extract unique insertion alleles if they are dissimilar to the FASTA_ALLELES input by the user.
"Unique insertion alleles" are defined as sequences that have a difference of more than 10 bases compared to the sequences in FASTA_ALLELES
"""
to_keep = []
for query_key, query_seq in fasta_insertions.items():
_, distances, _ = zip(*process.extract_iter(query_seq, FASTA_ALLELES.values(), scorer=DamerauLevenshtein.normalized_distance))
if all(d > 0.1 for d in distances):
_, distances, _ = zip(*process.extract_iter(query_seq, FASTA_ALLELES.values(), scorer=DamerauLevenshtein.distance))
if all(d > 10 for d in distances):
to_keep.append(query_key)

return {key: fasta_insertions[key] for key in to_keep if key in fasta_insertions}
Expand Down Expand Up @@ -555,7 +556,7 @@ def generate_insertion_fasta(TEMPDIR, SAMPLE_NAME, CONTROL_NAME, FASTA_ALLELES)
# Clustering similar insertion alleles
insertions_scores_sequences = extract_score_and_sequence(PATH_SAMPLE, insertions_merged)
cssplits_insertion = [cssplit for _, cssplit in insertions_scores_sequences]
labels = clustering_insertions(cssplits_insertion, n_random=1000)[: len(cssplits_insertion)]
labels = clustering_insertions(cssplits_insertion, n_decoy=1000)
labels_filtered, insertion_scores_sequences_filtered = filter_minor_label(
labels, insertions_scores_sequences, coverage, threshold=0.5
)
Expand Down

0 comments on commit d8bbd9f

Please sign in to comment.