diff --git a/src/DAJIN2/core/consensus/clust_formatter.py b/src/DAJIN2/core/consensus/clust_formatter.py index 7a510d70..c221cab2 100644 --- a/src/DAJIN2/core/consensus/clust_formatter.py +++ b/src/DAJIN2/core/consensus/clust_formatter.py @@ -3,8 +3,11 @@ from pathlib import Path from itertools import groupby +from sklearn.cluster import MiniBatchKMeans + from DAJIN2.utils import io from DAJIN2.core.preprocess.mutation_extractor import summarize_indels, extract_mutation_loci +from DAJIN2.core.consensus.similarity_searcher import cache_selected_control_by_similarity def subset_clust(clust_sample: list[dict], num: int = 1000) -> list[dict]: @@ -15,41 +18,68 @@ def subset_clust(clust_sample: list[dict], num: int = 1000) -> list[dict]: return clust_subset_sample -def cache_mutation_loci(ARGS, clust_sample: list[dict]) -> None: - tempdir, sample_name, control_name, fasta_alleles = ( - ARGS.tempdir, - ARGS.sample_name, - ARGS.control_name, - ARGS.fasta_alleles, - ) +########################################################### +# cache_mutation_loci +########################################################### + + +def get_thresholds(path_indels_normalized_sample, path_indels_normalized_control) -> dict[str, float]: + indels_normalized_sample = io.load_pickle(path_indels_normalized_sample) + indels_normalized_control = io.load_pickle(path_indels_normalized_control) + thresholds = dict() + for mut in {"+", "-", "*"}: + values_sample = indels_normalized_sample[mut] + values_control = indels_normalized_control[mut] + values_subtract = values_sample - values_control + kmeans = MiniBatchKMeans(n_clusters=2, random_state=0).fit(values_subtract.reshape(-1, 1)) + threshold = kmeans.cluster_centers_.mean() + thresholds[mut] = max(threshold, 0.05) + return thresholds + + +def cache_normalized_indels(ARGS, path_midsv_sample: Path) -> None: + allele, label, *_ = path_midsv_sample.stem.split("_") + sequence = ARGS.fasta_alleles[allele] + + if Path(ARGS.tempdir, ARGS.control_name, "midsv", f"{allele}.json").exists(): + path_midsv_control = Path(ARGS.tempdir, ARGS.control_name, "midsv", f"{allele}.json") + else: + path_midsv_control = Path(ARGS.tempdir, ARGS.control_name, "midsv", f"{allele}_{ARGS.sample_name}.json") + + cache_selected_control_by_similarity(path_midsv_control, path_midsv_sample, path_midsv_sample.parent) + path_midsv_filtered_control = Path(path_midsv_sample.parent, f"{allele}_{label}_control.jsonl") + + _, indels_normalized_sample = summarize_indels(path_midsv_sample, sequence) + _, indels_normalized_control = summarize_indels(path_midsv_filtered_control, sequence) + + path_consensus = Path(ARGS.tempdir, ARGS.sample_name, "consensus") + io.save_pickle(indels_normalized_sample, Path(path_consensus, f"{allele}_{label}_normalized_sample.pickle")) + io.save_pickle(indels_normalized_control, Path(path_consensus, f"{allele}_{label}_normalized_control.pickle")) + + +def cache_mutation_loci(ARGS, clust_sample: list[dict]) -> None: # Separate clusters by label and cache them - path_consensus = Path(tempdir, sample_name, "consensus") clust_sample.sort(key=lambda x: [x["ALLELE"], x["LABEL"]]) + path_consensus = Path(ARGS.tempdir, ARGS.sample_name, "consensus") for (allele, label), group in groupby(clust_sample, key=lambda x: [x["ALLELE"], x["LABEL"]]): - io.write_jsonl(group, Path(path_consensus, f"clust_{allele}_{label}.jsonl")) + io.write_jsonl(group, Path(path_consensus, f"{allele}_{label}_sample.jsonl")) # Cache normalized indels counts - for path_clust in path_consensus.glob("clust_*.jsonl"): - _, allele, label = path_clust.stem.split("_") - sequence = fasta_alleles[allele] - _, indels_normalized = summarize_indels(path_clust, sequence) - io.save_pickle(indels_normalized, Path(path_consensus, f"clust_{allele}_{label}_normalized.pickle")) + for path_midsv_sample in path_consensus.glob("*_sample.jsonl"): + cache_normalized_indels(ARGS, path_midsv_sample) # Extract and cache mutation loci - path_mutation_control = Path(tempdir, control_name, "mutation_loci") - for path_indels_normalized_sample in path_consensus.glob("clust_*normalized.pickle"): - _, allele, label, _ = path_indels_normalized_sample.stem.split("_") - - sequence = fasta_alleles[allele] + for path_indels_normalized_sample in path_consensus.glob("*_normalized_sample.pickle"): + allele, label, *_ = path_indels_normalized_sample.stem.split("_") + path_indels_normalized_control = Path(path_consensus, f"{allele}_{label}_normalized_control.pickle") + sequence = ARGS.fasta_alleles[allele] + path_knockin = Path(ARGS.tempdir, ARGS.sample_name, "knockin_loci", f"{allele}.pickle") - file_name = f"{allele}_{sample_name}_normalized.pickle" - if not Path(path_mutation_control, file_name).exists(): - file_name = f"{allele}_normalized.pickle" - path_indels_normalized_control = Path(path_mutation_control, file_name) + thresholds = get_thresholds(path_indels_normalized_sample, path_indels_normalized_control) - path_knockin = Path(tempdir, sample_name, "knockin_loci", f"{allele}.pickle") mutation_loci = extract_mutation_loci( - sequence, path_indels_normalized_sample, path_indels_normalized_control, path_knockin + sequence, path_indels_normalized_sample, path_indels_normalized_control, path_knockin, thresholds ) - io.save_pickle(mutation_loci, Path(path_consensus, f"clust_{allele}_{label}_mutation_loci.pickle")) + + io.save_pickle(mutation_loci, Path(path_consensus, f"{allele}_{label}_mutation_loci.pickle")) diff --git a/src/DAJIN2/core/preprocess/mutation_extractor.py b/src/DAJIN2/core/preprocess/mutation_extractor.py index 2db85d93..f073e228 100644 --- a/src/DAJIN2/core/preprocess/mutation_extractor.py +++ b/src/DAJIN2/core/preprocess/mutation_extractor.py @@ -16,9 +16,7 @@ import numpy as np -# from scipy import stats -# from scipy.spatial import distance -from sklearn.cluster import KMeans +from sklearn.cluster import MiniBatchKMeans from DAJIN2.utils import io from DAJIN2.core.preprocess import homopolymer_handler @@ -81,78 +79,12 @@ def split_kmer(indels: dict[str, np.array], kmer: int = 11) -> dict[str, np.arra return results -########################################################### -# Using Cosine similarity and T-test to extract dissimilar Loci -########################################################### - - -# def calculate_cosine_similarities(values_sample: list[float], values_control: list[float]) -> list[float]: -# """ -# Calculate cosine similarities between sample and control values. - -# Due to the behavior of distance.cosine, when dealing with zero-vectors, -# it doesn't return the expected cosine distance of 1. For example, distance.cosine([0,0,0], [1,2,3]) returns 0. -# To handle this, a small value (1e-10) is added to each element of the vector to prevent them from being zero-vectors. -# This ensures the correct behavior without significantly affecting the cosine similarity calculation. -# """ -# return [1 - distance.cosine(x + 1e-10, y + 1e-10) for x, y in zip(values_sample, values_control)] - - -# def perform_statistics(values_sample: list[float], values_control: list[float]) -> list[float]: -# """ -# Perform statistics between sample and control values. -# """ -# return [1 if np.array_equal(x, y) else stats.wilcoxon(x, y)[1] for x, y in zip(values_sample, values_control)] - - -# def find_dissimilar_indices(cossims: list[float], pvalues: list[float], pvalues_deleted: list[float]) -> set[int]: -# """Identify indices that are dissimilar based on cosine similarities and statistics p-values.""" -# return { -# i -# for i, (cossim, pval, pval_del) in enumerate(zip(cossims, pvalues, pvalues_deleted)) -# if (cossim >= 0.8 and pval < 0.05 and pval_del > 0.05) or cossim < 0.8 -# } - - -# def extract_dissimilar_loci( -# indels_normalized_sample: dict[str, list[float]], indels_normalized_control: dict[str, list[float]] -# ) -> dict[str, set[int]]: -# """ -# Compare Sample and Control to identify dissimilar loci. - -# Loci that do not closely resemble the reference in both mean and variance, indicating statistically significant differences, are detected as dissimilar loci. -# """ -# results = {} -# for mut in {"+", "-", "*"}: -# kmer = 11 - -# indels_kmer_sample = split_kmer(indels_normalized_sample, kmer=kmer) -# indels_kmer_control = split_kmer(indels_normalized_control, kmer=kmer) - -# values_sample = indels_kmer_sample[mut] -# values_control = indels_kmer_control[mut] - -# values_deleted_sample = [np.delete(v, kmer // 2) for v in values_sample] -# values_deleted_control = [np.delete(v, kmer // 2) for v in values_control] - -# cossims = calculate_cosine_similarities(values_sample, values_control) -# pvalues = perform_statistics(values_sample, values_control) -# pvalues_deleted = perform_statistics(values_deleted_sample, values_deleted_control) - -# results[mut] = find_dissimilar_indices(cossims, pvalues, pvalues_deleted) - -# return results - - ########################################################### # Extract Anomalous Loci -# The function `extract_dissimilar_loci` overlooks the mutation rate within each kmer. -# As a result, we encounter numerous false positives, especially in kmers with an extremely low mutation rate. -# It's essential to account for the mutation rate across the entire sequence. ########################################################### -def detect_anomalies(values_subtract: np.array) -> list[int]: +def detect_anomalies(values_subtract: np.array) -> set[int]: """ Detect anomalies and return indices of outliers. @@ -163,22 +95,30 @@ def detect_anomalies(values_subtract: np.array) -> list[int]: This function returns the indices of the class with the higher mean of values_subtract values, as this class is considered to be the true anomalies. """ - kmeans = KMeans(n_clusters=2, random_state=0) - _ = kmeans.fit_predict(values_subtract) + values_subtract_reshaped = values_subtract.reshape(-1, 1) + kmeans = MiniBatchKMeans(n_clusters=2, random_state=0) + _ = kmeans.fit_predict(values_subtract_reshaped) threshold = kmeans.cluster_centers_.mean() - return [i for i, v in enumerate(values_subtract) if v > threshold] + return {i for i, v in enumerate(values_subtract_reshaped) if v > threshold} -def extract_anomal_loci(indels_normalized_sample, indels_normalized_control) -> dict[str, set[int]]: +def extract_anomal_loci( + indels_normalized_sample, + indels_normalized_control, + thresholds: dict[str, float], +) -> dict[str, set[int]]: results = dict() for mut in {"+", "-", "*"}: values_sample = indels_normalized_sample[mut] values_control = indels_normalized_control[mut] values_subtract = values_sample - values_control - """"When the result of subtraction is 0.05 (%) or less, ignore it as 0""" - values_subtract = np.where(values_subtract <= 0.05, 0, values_subtract) - idx_outliers = detect_anomalies(values_subtract.reshape(-1, 1)) - results[mut] = set(idx_outliers) + """" + When the result of subtraction is threshold (%) or less, ignore it as 0 + """ + threshold = thresholds[mut] + values_subtract = np.where(values_subtract <= threshold, 0, values_subtract) + idx_outliers = detect_anomalies(values_subtract) + results[mut] = idx_outliers return results @@ -322,23 +262,24 @@ def cache_indels_count(ARGS, is_control: bool = False, is_insertion: bool = Fals def extract_mutation_loci( - sequence: str, path_indels_normalized_sample: Path, path_indels_normalized_control: Path, path_knockin: Path + sequence: str, + path_indels_normalized_sample: Path, + path_indels_normalized_control: Path, + path_knockin: Path, + thresholds: dict[str, float] = {"*": 0.05, "-": 0.05, "+": 0.05}, ) -> list[set[str]]: indels_normalized_sample = io.load_pickle(path_indels_normalized_sample) indels_normalized_control = io.load_pickle(path_indels_normalized_control) - indels_normalized_control = minimize_mutation_counts(indels_normalized_control, indels_normalized_sample) """Extract candidate mutation loci""" - # dissimilar_loci = extract_dissimilar_loci(indels_normalized_sample, indels_normalized_control) - anomal_loci = extract_anomal_loci(indels_normalized_sample, indels_normalized_control) - # candidate_loci = merge_loci(dissimilar_loci, anomal_loci) - candidate_loci = anomal_loci + indels_normalized_minimize_control = minimize_mutation_counts(indels_normalized_control, indels_normalized_sample) + anomal_loci = extract_anomal_loci(indels_normalized_sample, indels_normalized_minimize_control, thresholds) """Extract error loci in homopolymer regions""" errors_in_homopolymer = homopolymer_handler.extract_errors( - sequence, indels_normalized_sample, indels_normalized_control, candidate_loci + sequence, indels_normalized_sample, indels_normalized_control, anomal_loci ) - mutation_loci = discard_errors_in_homopolymer(candidate_loci, errors_in_homopolymer) + mutation_loci = discard_errors_in_homopolymer(anomal_loci, errors_in_homopolymer) """Merge all mutations and knockin loci""" if path_knockin.exists():