Skip to content

Commit

Permalink
Update extract_unique_insertions to merge highly similar extracted insertion sequences.
Browse files Browse the repository at this point in the history
  • Loading branch information
akikuno committed May 28, 2024
1 parent 582f746 commit 50fe99f
Showing 1 changed file with 31 additions and 15 deletions.
46 changes: 31 additions & 15 deletions src/DAJIN2/core/preprocess/insertions_to_fasta.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ def clustering_insertions(cssplits_insertion: list[str]) -> list[int]:
_, distances, _ = zip(*process.extract_iter(query, seq_all, scorer=DamerauLevenshtein.normalized_distance))

insertion_lengths = [[len(c) for c in cs.split(",")] for cs in cssplits_insertion]

scores = [s + [d] for s, d in zip(insertion_lengths, distances)]

scores = [length + [distance] for length, distance in zip(insertion_lengths, distances)]


return MeanShift(bin_seeding=True).fit_predict(np.array(scores)).tolist()

Expand Down Expand Up @@ -162,12 +163,14 @@ def get_merged_insertion(insertion: dict[str, int], labels: np.ndarray) -> dict[
return insertion_merged


def remove_minor_groups(insertions_merged: dict[tuple[int], dict[tuple[str], int]], coverage: int, threshold: float = 0.5) -> dict[tuple[int], dict[tuple[str], int]]:
def remove_minor_groups(insertions_merged: dict[tuple[int], dict[tuple[str], int]], coverage: int, percentage: float = 0.5) -> dict[tuple[int], dict[tuple[str], int]]:
"""Remove minor groups with less than {percentage} % coverage or less than 5 reads."""
threshold = max(5, int(coverage * percentage // 100))
for _, ins in insertions_merged.items():
# Create a list of elements to delete
to_delete = []
for seq, count in ins.items():
if count < coverage * threshold // 100:
if count < threshold:
to_delete.append(seq)

# Delete the collected elements
Expand All @@ -178,7 +181,7 @@ def remove_minor_groups(insertions_merged: dict[tuple[int], dict[tuple[str], int


def merge_similar_insertions(
insertions: dict[int, dict[str, int]], mutation_loci: list[set[str]], coverage: int, threshold: float = 0.5
insertions: dict[int, dict[str, int]], mutation_loci: list[set[str]], coverage: int, percentage: float = 0.5
) -> dict[tuple[int], dict[tuple[str], int]]:
index_grouped = group_index_by_consecutive_insertions(mutation_loci)
insertions_grouped = group_insertions(insertions, index_grouped)
Expand All @@ -192,7 +195,7 @@ def merge_similar_insertions(
labels = clustering_insertions(insertion)
insertions_merged[idx] = get_merged_insertion(insertion, labels)

return remove_minor_groups(insertions_merged, coverage, threshold)
return remove_minor_groups(insertions_merged, coverage, percentage)


###########################################################
Expand Down Expand Up @@ -293,10 +296,11 @@ def filter_minor_label(
labels: list[int],
insertions_scores_sequences: list[tuple[list[int], str]],
coverage: int,
threshold: float = 0.5,
percentage: float = 0.5,
) -> tuple[list[int], list[str]]:
threshold = max(5, int(coverage * percentage // 100))
labels_, counts_ = np.unique(labels, return_counts=True)
minor_labels = {label for label, count in zip(labels_, counts_) if count < coverage * threshold // 100}
minor_labels = {label for label, count in zip(labels_, counts_) if count < threshold}
index_minor_labels = {i for i, label in enumerate(labels) if label in minor_labels}
labels_filtered = [label for i, label in enumerate(labels) if i not in index_minor_labels]
score_seq_filterd = [
Expand Down Expand Up @@ -476,18 +480,30 @@ def generate_fasta(cstag_insertions: dict[str, str]) -> dict[str, str]:
return fasta_insertions


def extract_unique_insertions(fasta_insertions: dict[str, str], FASTA_ALLELES: dict[str, str]) -> dict[str, str]:
    """
    Extract unique insertion alleles if they are dissimilar to the FASTA_ALLELES input by the user.

    "Unique insertion alleles" are defined as sequences that differ by more than 10 bases
    (Damerau-Levenshtein distance) from every sequence in FASTA_ALLELES. Among the remaining
    alleles, highly similar ones (distance < 10 to an earlier-seen allele) are merged by
    keeping only the first occurrence.

    Args:
        fasta_insertions: Candidate insertion alleles, name -> sequence.
        FASTA_ALLELES: User-provided reference alleles, name -> sequence.

    Returns:
        Filtered dict of insertion alleles, name -> sequence (possibly empty).
    """
    # Step 1: keep only insertions dissimilar (>10 edits) to every user-provided allele.
    fasta_insertions_unique: dict[str, str] = {}
    for query_key, query_seq in fasta_insertions.items():
        _, distances, _ = zip(*process.extract_iter(query_seq, FASTA_ALLELES.values(), scorer=DamerauLevenshtein.distance))
        if all(d > 10 for d in distances):
            fasta_insertions_unique[query_key] = query_seq

    # Step 2: merge highly similar extracted insertions among themselves.
    # NOTE: the distances come back in dict-iteration order, so map each enumeration
    # index back to its key before recording deletions (indices are not dict keys).
    keys = list(fasta_insertions_unique)
    to_delete: set[str] = set()
    for self_idx, key in enumerate(keys):
        if key in to_delete:
            # Already marked as a duplicate of an earlier allele; don't let it
            # pull in further deletions of its own.
            continue
        seq = fasta_insertions_unique[key]
        _, distances, _ = zip(*process.extract_iter(seq, fasta_insertions_unique.values(), scorer=DamerauLevenshtein.distance))
        # Flag every *other* allele within 10 edits of this one; skip the
        # self-comparison (distance 0 at self_idx).
        to_delete |= {keys[i] for i, d in enumerate(distances) if d < 10 and i != self_idx}

    return {k: v for k, v in fasta_insertions_unique.items() if k not in to_delete}


def update_labels(d: dict, FASTA_ALLELES: dict) -> dict:
Expand Down Expand Up @@ -546,7 +562,7 @@ def generate_insertion_fasta(TEMPDIR, SAMPLE_NAME, CONTROL_NAME, FASTA_ALLELES)

coverage: int = io.count_newlines(PATH_SAMPLE)

insertions_merged = merge_similar_insertions(insertions, MUTATION_LOCI, coverage, threshold=0.5)
insertions_merged = merge_similar_insertions(insertions, MUTATION_LOCI, coverage, percentage=0.5)
if all(True if v == {} else False for v in insertions_merged.values()):
""""If all the insertion alleles were minor, return None"""
return None
Expand All @@ -556,7 +572,7 @@ def generate_insertion_fasta(TEMPDIR, SAMPLE_NAME, CONTROL_NAME, FASTA_ALLELES)
cssplits_insertion = [cssplit for _, cssplit in insertions_scores_sequences]
labels = clustering_insertions(cssplits_insertion)
labels_filtered, insertion_scores_sequences_filtered = filter_minor_label(
labels, insertions_scores_sequences, coverage, threshold=0.5
labels, insertions_scores_sequences, coverage, percentage=0.5
)

# Consensus calling
Expand All @@ -576,7 +592,7 @@ def generate_insertion_fasta(TEMPDIR, SAMPLE_NAME, CONTROL_NAME, FASTA_ALLELES)
index_of_insertions = extract_index_of_insertions(insertions, insertions_merged)
cstag_insertions = generate_cstag(consensus_of_insertions, index_of_insertions, FASTA_ALLELES["control"])
fasta_insertions = generate_fasta(cstag_insertions)
fasta_insertions_unique = extract_unique_insertions(FASTA_ALLELES, fasta_insertions)
fasta_insertions_unique = extract_unique_insertions(fasta_insertions, FASTA_ALLELES)

if fasta_insertions_unique == dict():
remove_temporal_files(TEMPDIR, SAMPLE_NAME)
Expand Down

0 comments on commit 50fe99f

Please sign in to comment.