Updated mutation_extractor.py:
- Changed consensus.clust_formatter.get_thresholds to define the thresholds for ignoring mutations dynamically instead of using fixed values.
- Removed code that was previously commented out.
akikuno committed Dec 22, 2023
1 parent 98a8a45 commit 2249d16
Showing 2 changed files with 83 additions and 112 deletions.
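
The core change: instead of a fixed 0.05 cutoff, `get_thresholds` now clusters the per-locus sample-minus-control mutation rates into two groups with `MiniBatchKMeans` and takes the midpoint of the two cluster centers as the threshold, floored at 0.05. A minimal, self-contained sketch of that idea (the data here is synthetic, not DAJIN2 output):

import numpy as np
from sklearn.cluster import MiniBatchKMeans

rng = np.random.default_rng(0)
# Most loci differ from the control only by noise; a few carry real signal.
values_subtract = np.concatenate([rng.normal(0.01, 0.005, 95), rng.normal(0.4, 0.05, 5)])

kmeans = MiniBatchKMeans(n_clusters=2, random_state=0).fit(values_subtract.reshape(-1, 1))
threshold = max(kmeans.cluster_centers_.mean(), 0.05)  # floor keeps the old fixed cutoff as a minimum
print(round(threshold, 3))  # ~0.2, midway between the noise and signal clusters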
82 changes: 56 additions & 26 deletions src/DAJIN2/core/consensus/clust_formatter.py
@@ -3,8 +3,11 @@
from pathlib import Path
from itertools import groupby

from sklearn.cluster import MiniBatchKMeans

from DAJIN2.utils import io
from DAJIN2.core.preprocess.mutation_extractor import summarize_indels, extract_mutation_loci
from DAJIN2.core.consensus.similarity_searcher import cache_selected_control_by_similarity


def subset_clust(clust_sample: list[dict], num: int = 1000) -> list[dict]:
@@ -15,41 +18,68 @@ def subset_clust(clust_sample: list[dict], num: int = 1000) -> list[dict]:
return clust_subset_sample


def cache_mutation_loci(ARGS, clust_sample: list[dict]) -> None:
tempdir, sample_name, control_name, fasta_alleles = (
ARGS.tempdir,
ARGS.sample_name,
ARGS.control_name,
ARGS.fasta_alleles,
)
###########################################################
# cache_mutation_loci
###########################################################


def get_thresholds(path_indels_normalized_sample: Path, path_indels_normalized_control: Path) -> dict[str, float]:
indels_normalized_sample = io.load_pickle(path_indels_normalized_sample)
indels_normalized_control = io.load_pickle(path_indels_normalized_control)
thresholds = dict()
for mut in {"+", "-", "*"}:
values_sample = indels_normalized_sample[mut]
values_control = indels_normalized_control[mut]
values_subtract = values_sample - values_control
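# Two-cluster split of the per-locus differences: one cluster captures background
# noise, the other true signal; the midpoint of the centers separates them.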
kmeans = MiniBatchKMeans(n_clusters=2, random_state=0).fit(values_subtract.reshape(-1, 1))
threshold = kmeans.cluster_centers_.mean()
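# Never let the data-driven threshold fall below the previous fixed cutoff of 0.05.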
thresholds[mut] = max(threshold, 0.05)
return thresholds


def cache_normalized_indels(ARGS, path_midsv_sample: Path) -> None:
allele, label, *_ = path_midsv_sample.stem.split("_")
sequence = ARGS.fasta_alleles[allele]

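# Prefer the per-allele control MIDSV cache; fall back to the sample-specific one.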
if Path(ARGS.tempdir, ARGS.control_name, "midsv", f"{allele}.json").exists():
path_midsv_control = Path(ARGS.tempdir, ARGS.control_name, "midsv", f"{allele}.json")
else:
path_midsv_control = Path(ARGS.tempdir, ARGS.control_name, "midsv", f"{allele}_{ARGS.sample_name}.json")

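# Keep only control reads similar to this cluster's reads before summarizing indels.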
cache_selected_control_by_similarity(path_midsv_control, path_midsv_sample, path_midsv_sample.parent)

path_midsv_filtered_control = Path(path_midsv_sample.parent, f"{allele}_{label}_control.jsonl")

_, indels_normalized_sample = summarize_indels(path_midsv_sample, sequence)
_, indels_normalized_control = summarize_indels(path_midsv_filtered_control, sequence)

path_consensus = Path(ARGS.tempdir, ARGS.sample_name, "consensus")
io.save_pickle(indels_normalized_sample, Path(path_consensus, f"{allele}_{label}_normalized_sample.pickle"))
io.save_pickle(indels_normalized_control, Path(path_consensus, f"{allele}_{label}_normalized_control.pickle"))


def cache_mutation_loci(ARGS, clust_sample: list[dict]) -> None:
# Separate clusters by label and cache them
path_consensus = Path(tempdir, sample_name, "consensus")
clust_sample.sort(key=lambda x: [x["ALLELE"], x["LABEL"]])
path_consensus = Path(ARGS.tempdir, ARGS.sample_name, "consensus")
for (allele, label), group in groupby(clust_sample, key=lambda x: [x["ALLELE"], x["LABEL"]]):
io.write_jsonl(group, Path(path_consensus, f"clust_{allele}_{label}.jsonl"))
io.write_jsonl(group, Path(path_consensus, f"{allele}_{label}_sample.jsonl"))

# Cache normalized indels counts
for path_clust in path_consensus.glob("clust_*.jsonl"):
_, allele, label = path_clust.stem.split("_")
sequence = fasta_alleles[allele]
_, indels_normalized = summarize_indels(path_clust, sequence)
io.save_pickle(indels_normalized, Path(path_consensus, f"clust_{allele}_{label}_normalized.pickle"))
for path_midsv_sample in path_consensus.glob("*_sample.jsonl"):
cache_normalized_indels(ARGS, path_midsv_sample)

# Extract and cache mutation loci
path_mutation_control = Path(tempdir, control_name, "mutation_loci")
for path_indels_normalized_sample in path_consensus.glob("clust_*normalized.pickle"):
_, allele, label, _ = path_indels_normalized_sample.stem.split("_")

sequence = fasta_alleles[allele]
for path_indels_normalized_sample in path_consensus.glob("*_normalized_sample.pickle"):
allele, label, *_ = path_indels_normalized_sample.stem.split("_")
path_indels_normalized_control = Path(path_consensus, f"{allele}_{label}_normalized_control.pickle")
sequence = ARGS.fasta_alleles[allele]
path_knockin = Path(ARGS.tempdir, ARGS.sample_name, "knockin_loci", f"{allele}.pickle")

file_name = f"{allele}_{sample_name}_normalized.pickle"
if not Path(path_mutation_control, file_name).exists():
file_name = f"{allele}_normalized.pickle"
path_indels_normalized_control = Path(path_mutation_control, file_name)
thresholds = get_thresholds(path_indels_normalized_sample, path_indels_normalized_control)

path_knockin = Path(tempdir, sample_name, "knockin_loci", f"{allele}.pickle")
mutation_loci = extract_mutation_loci(
sequence, path_indels_normalized_sample, path_indels_normalized_control, path_knockin
sequence, path_indels_normalized_sample, path_indels_normalized_control, path_knockin, thresholds
)
io.save_pickle(mutation_loci, Path(path_consensus, f"clust_{allele}_{label}_mutation_loci.pickle"))

io.save_pickle(mutation_loci, Path(path_consensus, f"{allele}_{label}_mutation_loci.pickle"))
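
For orientation, the glob-driven flow above hinges on a file-stem convention: `cache_normalized_indels` writes `{allele}_{label}_normalized_sample.pickle` next to its `_control` counterpart, and `cache_mutation_loci` recovers the allele and label by splitting the stem. A small sketch of just that convention (the allele name and path are hypothetical):

from pathlib import Path

p = Path("consensus/albino_2_normalized_sample.pickle")  # hypothetical cached file
allele, label, *_ = p.stem.split("_")                    # -> "albino", "2"
path_control = p.with_name(f"{allele}_{label}_normalized_control.pickle")
print(allele, label, path_control.name)

Note that this parsing assumes allele names themselves contain no underscores.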
113 changes: 27 additions & 86 deletions src/DAJIN2/core/preprocess/mutation_extractor.py
@@ -16,9 +16,7 @@


import numpy as np
# from scipy import stats
# from scipy.spatial import distance
from sklearn.cluster import KMeans
from sklearn.cluster import MiniBatchKMeans

from DAJIN2.utils import io
from DAJIN2.core.preprocess import homopolymer_handler
@@ -81,78 +79,12 @@ def split_kmer(indels: dict[str, np.array], kmer: int = 11) -> dict[str, np.array]:
return results


###########################################################
# Using Cosine similarity and T-test to extract dissimilar Loci
###########################################################


# def calculate_cosine_similarities(values_sample: list[float], values_control: list[float]) -> list[float]:
# """
# Calculate cosine similarities between sample and control values.

# Due to the behavior of distance.cosine, when dealing with zero-vectors,
# it doesn't return the expected cosine distance of 1. For example, distance.cosine([0,0,0], [1,2,3]) returns 0.
# To handle this, a small value (1e-10) is added to each element of the vector to prevent them from being zero-vectors.
# This ensures the correct behavior without significantly affecting the cosine similarity calculation.
# """
# return [1 - distance.cosine(x + 1e-10, y + 1e-10) for x, y in zip(values_sample, values_control)]


# def perform_statistics(values_sample: list[float], values_control: list[float]) -> list[float]:
# """
# Perform statistics between sample and control values.
# """
# return [1 if np.array_equal(x, y) else stats.wilcoxon(x, y)[1] for x, y in zip(values_sample, values_control)]


# def find_dissimilar_indices(cossims: list[float], pvalues: list[float], pvalues_deleted: list[float]) -> set[int]:
# """Identify indices that are dissimilar based on cosine similarities and statistics p-values."""
# return {
# i
# for i, (cossim, pval, pval_del) in enumerate(zip(cossims, pvalues, pvalues_deleted))
# if (cossim >= 0.8 and pval < 0.05 and pval_del > 0.05) or cossim < 0.8
# }


# def extract_dissimilar_loci(
# indels_normalized_sample: dict[str, list[float]], indels_normalized_control: dict[str, list[float]]
# ) -> dict[str, set[int]]:
# """
# Compare Sample and Control to identify dissimilar loci.

# Loci that do not closely resemble the reference in both mean and variance, indicating statistically significant differences, are detected as dissimilar loci.
# """
# results = {}
# for mut in {"+", "-", "*"}:
# kmer = 11

# indels_kmer_sample = split_kmer(indels_normalized_sample, kmer=kmer)
# indels_kmer_control = split_kmer(indels_normalized_control, kmer=kmer)

# values_sample = indels_kmer_sample[mut]
# values_control = indels_kmer_control[mut]

# values_deleted_sample = [np.delete(v, kmer // 2) for v in values_sample]
# values_deleted_control = [np.delete(v, kmer // 2) for v in values_control]

# cossims = calculate_cosine_similarities(values_sample, values_control)
# pvalues = perform_statistics(values_sample, values_control)
# pvalues_deleted = perform_statistics(values_deleted_sample, values_deleted_control)

# results[mut] = find_dissimilar_indices(cossims, pvalues, pvalues_deleted)

# return results


###########################################################
# Extract Anomalous Loci
# The function `extract_dissimilar_loci` overlooks the mutation rate within each kmer.
# As a result, we encounter numerous false positives, especially in kmers with an extremely low mutation rate.
# It's essential to account for the mutation rate across the entire sequence.
###########################################################


def detect_anomalies(values_subtract: np.array) -> list[int]:
def detect_anomalies(values_subtract: np.array) -> set[int]:
"""
Detect anomalies and return indices of outliers.
@@ -163,22 +95,30 @@ def detect_anomalies(values_subtract: np.array) -> list[int]:
This function returns the indices of the cluster whose mean of values_subtract is
higher, as that cluster is considered to contain the true anomalies.
"""
kmeans = KMeans(n_clusters=2, random_state=0)
_ = kmeans.fit_predict(values_subtract)
values_subtract_reshaped = values_subtract.reshape(-1, 1)
kmeans = MiniBatchKMeans(n_clusters=2, random_state=0)
_ = kmeans.fit_predict(values_subtract_reshaped)
threshold = kmeans.cluster_centers_.mean()
return [i for i, v in enumerate(values_subtract) if v > threshold]
return {i for i, v in enumerate(values_subtract_reshaped) if v > threshold}


def extract_anomal_loci(indels_normalized_sample, indels_normalized_control) -> dict[str, set[int]]:
def extract_anomal_loci(
indels_normalized_sample,
indels_normalized_control,
thresholds: dict[str, float],
) -> dict[str, set[int]]:
results = dict()
for mut in {"+", "-", "*"}:
values_sample = indels_normalized_sample[mut]
values_control = indels_normalized_control[mut]
values_subtract = values_sample - values_control
""""When the result of subtraction is 0.05 (%) or less, ignore it as 0"""
values_subtract = np.where(values_subtract <= 0.05, 0, values_subtract)
idx_outliers = detect_anomalies(values_subtract.reshape(-1, 1))
results[mut] = set(idx_outliers)
""""
When the result of subtraction is threshold (%) or less, ignore it as 0
"""
threshold = thresholds[mut]
values_subtract = np.where(values_subtract <= threshold, 0, values_subtract)
idx_outliers = detect_anomalies(values_subtract)
results[mut] = idx_outliers
return results


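Taken together, thresholding plus two-cluster splitting behaves roughly as below. This is a toy run with invented arrays, re-stating `detect_anomalies` from this diff so the snippet is self-contained; a fixed 0.05 stands in for the dynamic threshold:

import numpy as np
from sklearn.cluster import MiniBatchKMeans

def detect_anomalies(values_subtract: np.ndarray) -> set[int]:
    # 2-means on the per-locus differences; loci above the midpoint of the
    # two cluster centers are reported as anomalous.
    kmeans = MiniBatchKMeans(n_clusters=2, random_state=0).fit(values_subtract.reshape(-1, 1))
    threshold = kmeans.cluster_centers_.mean()
    return {i for i, v in enumerate(values_subtract) if v > threshold}

sample = np.array([0.02, 0.03, 0.50, 0.02, 0.45])
control = np.array([0.02, 0.02, 0.02, 0.02, 0.02])
diff = sample - control
diff = np.where(diff <= 0.05, 0, diff)  # fixed 0.05 here; the real code uses thresholds[mut]
print(detect_anomalies(diff))           # {2, 4}: the two loci with a real excess over control
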
@@ -322,23 +262,24 @@ def cache_indels_count(ARGS, is_control: bool = False, is_insertion: bool = False) -> None:


def extract_mutation_loci(
sequence: str, path_indels_normalized_sample: Path, path_indels_normalized_control: Path, path_knockin: Path
sequence: str,
path_indels_normalized_sample: Path,
path_indels_normalized_control: Path,
path_knockin: Path,
thresholds: dict[str, float] = {"*": 0.05, "-": 0.05, "+": 0.05},
) -> list[set[str]]:
indels_normalized_sample = io.load_pickle(path_indels_normalized_sample)
indels_normalized_control = io.load_pickle(path_indels_normalized_control)
indels_normalized_control = minimize_mutation_counts(indels_normalized_control, indels_normalized_sample)

"""Extract candidate mutation loci"""
# dissimilar_loci = extract_dissimilar_loci(indels_normalized_sample, indels_normalized_control)
anomal_loci = extract_anomal_loci(indels_normalized_sample, indels_normalized_control)
# candidate_loci = merge_loci(dissimilar_loci, anomal_loci)
candidate_loci = anomal_loci
indels_normalized_minimize_control = minimize_mutation_counts(indels_normalized_control, indels_normalized_sample)
anomal_loci = extract_anomal_loci(indels_normalized_sample, indels_normalized_minimize_control, thresholds)

"""Extract error loci in homopolymer regions"""
errors_in_homopolymer = homopolymer_handler.extract_errors(
sequence, indels_normalized_sample, indels_normalized_control, candidate_loci
sequence, indels_normalized_sample, indels_normalized_control, anomal_loci
)
mutation_loci = discard_errors_in_homopolymer(candidate_loci, errors_in_homopolymer)
mutation_loci = discard_errors_in_homopolymer(anomal_loci, errors_in_homopolymer)

"""Merge all mutations and knockin loci"""
if path_knockin.exists():
