Commit
Added preprocess.insertions_to_fasta.clustering_insertions: Combined the clustering methods used in `extract_enriched_insertions` and `merge_similar_insertions` into a common function.
akikuno committed Jan 18, 2024
1 parent fb7074c commit 6d7ff79
Showing 1 changed file with 35 additions and 47 deletions.
82 changes: 35 additions & 47 deletions src/DAJIN2/core/preprocess/insertions_to_fasta.py
@@ -23,6 +23,19 @@ def remove_non_alphabets(s):
     return "".join([char for char in s if char.isalpha()])
 
 
+def clustering_insertions(insertions_cssplit: list[str]) -> list[int]:
+    seq_all = [remove_non_alphabets(seq) for seq in insertions_cssplit]
+    query = seq_all[0]
+    _, distances, _ = zip(*process.extract_iter(query, seq_all, scorer=DamerauLevenshtein.normalized_distance))
+
+    # Add random values from 0 to 1
+    distances = list(distances)
+    rng = np.random.default_rng(1)
+    distances.extend(rng.random(max(100, len(seq_all))))
+
+    return MeanShift().fit(np.array(distances).reshape(-1, 1)).labels_.tolist()[: len(seq_all)]
+
+
 ###########################################################
 # Detect insertion sequences
 ###########################################################
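The new helper computes each insertion's normalized Damerau-Levenshtein distance to the first sequence, pads the distance vector with at least 100 random values in [0, 1) so that MeanShift can estimate a stable bandwidth even when only a few reads are given, and then slices the padding labels off again. A minimal usage sketch (toy cs-split strings, not real DAJIN2 data; assumes the module's numpy, rapidfuzz, and scikit-learn imports):

# Illustrative only: similar sequences tend to receive the same label.
insertions = ["+A|+C|+G", "+A|+C|+T", "+T|+T|+T|+T|+T|+T|+T|+T"]
labels = clustering_insertions(insertions)
assert len(labels) == len(insertions)  # the padding labels are sliced off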
@@ -49,18 +62,11 @@ def extract_enriched_insertions(
     for i in insertions_sample:
         ins_sample = insertions_sample[i]
         ins_control = insertions_control.get(i, [])
-        seq_all = [remove_non_alphabets(seq) for seq in ins_sample + ins_control]
-        query = seq_all[0]
-        _, scores, _ = zip(*process.extract_iter(query, seq_all, scorer=DamerauLevenshtein.normalized_distance))
-
-        # Add random values from 0 to 1
-        scores = list(scores)
-        rng = np.random.default_rng(1)
-        scores.extend(rng.random(max(100, len(scores))))
-        labels_all = MeanShift().fit(np.array(scores).reshape(-1, 1)).labels_.tolist()
 
+        labels_all = clustering_insertions(ins_sample + ins_control)
+
         labels_sample = labels_all[: len(ins_sample)]
-        labels_control = labels_all[len(ins_sample) :]
+        labels_control = labels_all[len(ins_sample) : len(ins_sample) + len(ins_control)]
 
         labels_count_sample = dict(Counter(labels_sample))
         labels_count_control = dict(Counter(labels_control))
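Sample and control insertions are now clustered jointly by the shared helper, so each cluster's read counts can be compared across the two groups; the slices recover the per-group labels. The explicit upper bound on the control slice is defensive, since `clustering_insertions` already truncates the labels belonging to its random padding. A small bookkeeping sketch (hypothetical inputs):

ins_sample = ["+A|+C", "+A|+G"]  # hypothetical sample insertions
ins_control = ["+T|+T"]          # hypothetical control insertions
labels_all = clustering_insertions(ins_sample + ins_control)
assert len(labels_all) == len(ins_sample) + len(ins_control)
labels_sample = labels_all[: len(ins_sample)]  # first block: sample reads
labels_control = labels_all[len(ins_sample) : len(ins_sample) + len(ins_control)]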

@@ -77,7 +83,8 @@ def extract_enriched_insertions(
                 to_delete.add(label)
 
         for label in to_delete:
-            del labels_count_sample[label]
+            if label in labels_count_sample:
+                del labels_count_sample[label]
 
         if labels_count_sample == dict():
             continue
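The added membership test fixes a potential KeyError: a label collected into `to_delete` is not guaranteed to be a key of `labels_count_sample` (for example, a cluster populated only by control reads). A minimal reproduction of the old failure mode (hypothetical counts):

labels_count_sample = {0: 12}  # cluster 1 has no sample reads
to_delete = {0, 1}             # cluster 1 was flagged anyway
for label in to_delete:
    if label in labels_count_sample:  # without this guard: KeyError: 1
        del labels_count_sample[label]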
@@ -125,17 +132,6 @@ def _group_insertions(insertions, index_grouped) -> dict[dict[str, int]]:
     return insertions_grouped
 
 
-def _optimize_labels(insertion: dict[str, int]) -> list[int]:
-    seq_all = [remove_non_alphabets(seq) for seq in insertion]
-    query = seq_all[0]
-    _, scores, _ = zip(*process.extract_iter(query, seq_all, scorer=DamerauLevenshtein.normalized_distance))
-    scores = list(scores)
-    rng = np.random.default_rng(1)
-    scores.extend(rng.random(max(100, len(scores))))
-    labels_all = MeanShift().fit(np.array(scores).reshape(-1, 1)).labels_.tolist()
-    return labels_all[: len(insertion)]
-
-
 def _get_merged_insertion(insertion: dict[str, int], labels: np.ndarray) -> dict[frozenset, int]:
     insertion_label = [(label, {key: val}) for label, (key, val) in zip(labels, insertion.items())]
     insertion_label.sort(key=lambda x: x[0])
@@ -160,7 +156,7 @@ def merge_similar_insertions(insertions, mutation_loci) -> dict[dict[frozenset[s
             key, val = list(insertion.items())[0]
             insertions_merged[idx][frozenset([key])] = val
             continue
-        labels = _optimize_labels(insertion)
+        labels = clustering_insertions(insertion)
         insertions_merged[idx] = _get_merged_insertion(insertion, labels)
     return insertions_merged

@@ -174,34 +170,26 @@ def extract_score_and_sequence(path_sample, insertions_merged) -> list[tuple[lis
     scores = []
     sequences = []
     for m in io.read_jsonl(path_sample):
-        score = defaultdict(int)
-        seq = defaultdict(lambda: "N")
         cssplits = m["CSSPLIT"].split(",")
-        for idx_grouped, ins in insertions_merged.items():
-            score[idx_grouped] = 0
-            seq[idx_grouped] = "N"
+        score = [0] * len(insertions_merged)
+        sequence = ["N"] * len(insertions_merged)
+        for i, (idx_grouped, insertions) in enumerate(insertions_merged.items()):
             for idx in idx_grouped:
-                for seqs, value in ins.items():
-                    if cssplits[idx] in seqs:
-                        score[idx_grouped] = value
-                        seq[idx_grouped] = cssplits[idx]
-        if any(score.values()):
-            scores.append([s for _, s in sorted(score.items())])
-            sequences.append(",".join(s for _, s in sorted(seq.items())))
+                if not cssplits[idx].startswith("+"):
+                    continue
+                for seqs, count in insertions.items():
+                    _, distances, _ = zip(
+                        *process.extract_iter(cssplits[idx], seqs, scorer=DamerauLevenshtein.normalized_distance)
+                    )
+                    if any(True for d in distances if d < 0.1):
+                        score[i] = count
+                        sequence[i] = cssplits[idx]
+        if any(score):
+            scores.append(score)
+            sequences.append(",".join(sequence))
     return [(score, sequence) for score, sequence in zip(scores, sequences)]
 
 
-def clustering_insertions(scores) -> list[int]:
-    X = np.array(scores)
-    if X.max() - X.min() == 0:
-        return [1] * len(X)
-    X = (X - X.min()) / (X.max() - X.min())
-    np.random.seed(seed=1)
-    clustering = MeanShift().fit(X)
-    labels = clustering.labels_
-    return labels.tolist()
-
-
 def filter_minor_label(
     path_sample: str, labels: list[int], insertions_scores_sequences, threshold: float = 0.5
 ) -> tuple(list[int], list[str]):
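The rewritten loop no longer requires a read's cs tag to match a stored sequence exactly: a read is assigned to a merged insertion group when its normalized Damerau-Levenshtein distance to any member of the group is below 0.1 (the diff's `any(True for d in distances if d < 0.1)` is equivalent to `any(d < 0.1 for d in distances)`). The matching rule in isolation (toy strings):

from rapidfuzz import process
from rapidfuzz.distance import DamerauLevenshtein

observed = "+A|+C|+G"                        # hypothetical cs-split token
group = frozenset(["+A|+C|+G", "+A|+C|+T"])  # hypothetical merged-group key
_, distances, _ = zip(*process.extract_iter(observed, group, scorer=DamerauLevenshtein.normalized_distance))
matched = any(d < 0.1 for d in distances)    # True: an exact member exists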
@@ -361,7 +349,7 @@ def generate_insertion_fasta(TEMPDIR, SAMPLE_NAME, CONTROL_NAME, FASTA_ALLELES)
         return None
     insertions_merged = merge_similar_insertions(insertions, mutation_loci)
     insertions_scores_sequences = extract_score_and_sequence(path_sample, insertions_merged)
-    labels = clustering_insertions([score for score, _ in insertions_scores_sequences])
+    labels = clustering_insertions([cssplit for _, cssplit in insertions_scores_sequences])
     labels_filtered, insertion_scores_sequences_filtered = filter_minor_label(
         path_sample, labels, insertions_scores_sequences, threshold=0.5
     )
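With the score-based `clustering_insertions(scores)` deleted above, this call site now passes the comma-joined insertion sequences themselves to the new string-based helper, so reads are grouped by sequence similarity rather than by their score vectors. Sketch of the changed call (hypothetical tuples):

# Each element of insertions_scores_sequences is (score_list, cssplit_string).
insertions_scores_sequences = [([3, 0], "+A|+C|+G,N"), ([2, 0], "+A|+C|+T,N")]
labels = clustering_insertions([cssplit for _, cssplit in insertions_scores_sequences])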
