From 3d868331ac20d2468fadedd39f32fc54e67ce8ac Mon Sep 17 00:00:00 2001 From: ashkan98 Date: Sat, 6 Jan 2024 03:54:16 +0100 Subject: [PATCH] Right calculation of precision and recall --- scripts/init.py | 186 +++++++++++++++++++----------------------------- 1 file changed, 73 insertions(+), 113 deletions(-) diff --git a/scripts/init.py b/scripts/init.py index a4e5abd..cb7caf6 100644 --- a/scripts/init.py +++ b/scripts/init.py @@ -507,6 +507,8 @@ def main(): help='optional: set threshold for number of samples sequenced from a lineage') parser.add_argument('-sig', '--signature', required=False, action='store_true', help='optional: for given (set of) lineage return the set of signature mutations') + parser.add_argument('-cs', '--collapse_sublineages', required=False, action='store_true', + help='optional: will collapse sublineages to reduce the size of "other lineages" which not exceeds sample size cut-off iteratively merging sublineages to higher level lineages until sample cut-off is reached ignore sublineages, converting lineage to lineage-complexes') parser.add_argument('-is', '--ignore_sublineages', required=False, action='store_true', help='optional: will ignore sublineages, to compute signature mutations for lineage-complexes') parser.add_argument('-out', '--output', metavar='', required=False, @@ -628,6 +630,10 @@ def main(): print("Compute mutation frequency ...") + if args.collapse_sublineages: + df_dna_aa_profile["lineage_decompressed"] = [aliasor.uncompress(x) for x in df_dna_aa_profile["lineage"]] + + if args.matrix or args.signature: print("Matrix will be created on the fly..") @@ -757,10 +763,10 @@ def sort_gene_mutations(item): #fair number of cases where zero samples are correctly assigned, and all are assigned to one specific other lineage #-> distance matrix of pairwise distances between all C^2 profiles - pairwise_distances = pdist(binary_c2_profiles.values, metric='hamming') - distance_matrix = squareform(pairwise_distances) - pairwise_distance_c2_df = pd.DataFrame(distance_matrix, columns=dict_lineage_char_mutations.keys(), index=dict_lineage_char_mutations.keys()) - pairwise_distance_c2_df.to_csv(f"output/frequency_verification/{date_range}/pairwise_distance_c2_{date_range}.csv") + #pairwise_distances = pdist(binary_c2_profiles.values, metric='hamming') + #distance_matrix = squareform(pairwise_distances) + #pairwise_distance_c2_df = pd.DataFrame(distance_matrix, columns=dict_lineage_char_mutations.keys(), index=dict_lineage_char_mutations.keys()) + #pairwise_distance_c2_df.to_csv(f"output/frequency_verification/{date_range}/pairwise_distance_c2_{date_range}.csv") multi_class_confusion_matrix = pd.DataFrame(0, columns=dict_lineage_char_mutations.keys(), index=dict_lineage_char_mutations.keys()) @@ -771,23 +777,6 @@ def calculate_confusion_matrix(row, binary_c2_profiles, sorted_all_unique_char_m binary_sample_profiles = binary_sample_profiles[sorted_all_unique_char_mutations] predicted_lineage = find_min_hamming_distance(binary_sample_profiles, binary_c2_profiles) return actual_lineage, predicted_lineage - - ''' - def calculate_precision(confusion_matrix): - true_positives = confusion_matrix.values.diagonal().sum() # Get diagonal values - false_positives = confusion_matrix.values.sum() - true_positives # Sum of non-diagonal values in columns - precision = true_positives / (true_positives + false_positives) - return true_positives, precision - ''' - - def calculate_recall(confusion_matrix, num_test_samples, true_positives): - false_negatives = 0 - for true_lineage in confusion_matrix.index: - if true_lineage in list(num_test_samples.keys()): - occurrences = num_test_samples[true_lineage] - false_negatives += occurrences - multi_class_confusion_matrix.loc[true_lineage].sum() - recall = true_positives / (true_positives + false_negatives) - return recall # Get number of available CPU cores num_cores = 32 # Change this value based on your CPU capacity @@ -798,60 +787,49 @@ def calculate_recall(confusion_matrix, num_test_samples, true_positives): for _, row in df_mutation_profile_without_others.iterrows() ) - ''' for actual_lineage, predicted_lineage in results: if actual_lineage in multi_class_confusion_matrix.index and predicted_lineage in multi_class_confusion_matrix.columns: multi_class_confusion_matrix.loc[actual_lineage, predicted_lineage] += 1 #multi_class_confusion_matrix.to_csv(f"output/frequency_verification/{date_range}/multiclass_confusion_matrix_{args.cut_off_frequency}_{date_range}.csv") - ''' + #multiclass precision and recall def calculate_precision_recall(confusion_matrix): precision_values = [] recall_values = [] - for pangolin, predictions in confusion_matrix.iterrows(): - true_positives = predictions[pangolin] - false_positives = predictions.values.sum() - true_positives - precision = true_positives / (true_positives + false_positives) - false_negatives = dict_filter_num_lineage[pangolin] - predictions.values.sum() - recall = true_positives / (true_positives + false_negatives) + for pangolin, predictions in confusion_matrix.iterrows(): + precision = predictions[pangolin] / predictions.values.sum() + recall = predictions[pangolin] / dict_filter_num_lineage[pangolin] precision_values.append(float(precision)), recall_values.append(float(recall)) - return precision_values, recall_values + return precision_values, recall_values - #precision_values, recall_values = calculate_precision_recall(multi_class_confusion_matrix) - #avg_precision, avg_recall = sum(precision_values) / len(precision_values), sum(recall_values) / len(recall_values) - #print(avg_precision, avg_recall) + precision_values, recall_values = calculate_precision_recall(multi_class_confusion_matrix) + avg_precision, avg_recall = sum(precision_values) / len(precision_values), sum(recall_values) / len(recall_values) + avg_precision, avg_recall = math.ceil(avg_precision * 1000) / 1000, math.ceil(avg_recall * 1000) / 1000 + print("Precision: ", avg_precision, "Recall: ", avg_recall) #Random Classifier + ''' def generate_predictions(lineages): predictions = [] possible_values = list(set(lineages)) # Unique values in the 'lineages' column possible_values.append('no assignment') for _ in range(len(lineages)): - random_values = [random.choice(possible_values) for _ in range(1000)] # Select 1000 random values + random_values = [random.choice(possible_values) for _ in range(10000)] # Select 1000 random values avg_prediction = max(set(random_values), key=random_values.count) # Select most frequent value predictions.append(avg_prediction) return predictions - # Add 'predicted_lineages' column to the dataframe df_mutation_profile_without_others['predicted_lineages'] = generate_predictions(df_mutation_profile_without_others['lineage']) random_classifier = pd.DataFrame(0, columns=dict_lineage_char_mutations.keys(), index=dict_lineage_char_mutations.keys()) for idx, rows in df_mutation_profile_without_others.iterrows(): - random_classifier.loc[rows['lineage'], rows['predicted_lineages']] += 1 + rc_actual_lineage, rc_predicted_lineage = rows['lineage'], rows['predicted_lineages'] + if rc_actual_lineage in multi_class_confusion_matrix.index and rc_predicted_lineage in multi_class_confusion_matrix.columns: + random_classifier.loc[rc_actual_lineage, rc_predicted_lineage] += 1 + rc_precision_values, rc_recall_values = calculate_precision_recall(random_classifier) rc_avg_precision, rc_avg_recall = sum(rc_precision_values) / len(rc_precision_values), sum(rc_recall_values) / len(rc_recall_values) print("Precision:", rc_avg_precision, "Recall: ", rc_avg_recall) - #print("Precision: ", calculate_precision(multi_class_confusion_matrix)) - ''' - plt.figure(figsize=(8, 6)) - plt.plot(recall_values, precision_values, marker='o', linestyle='-') - plt.xlabel('Recall') - plt.ylabel('Precision') - plt.title('Precision-Recall trade-off curve') - plt.grid(True) - plt.ylim(0.0, 1.0) - #plt.savefig(f"output/frequency_verification/{date_range}/precision-recall-curve_{args.cut_off_frequency}_{date_range}.png") - plt.show() ''' if args.cross_validation: @@ -862,80 +840,62 @@ def generate_predictions(lineages): counter=0 for train, test in kf.split(df_mutation_profile_without_others): train_df, test_df = df_mutation_profile_without_others.iloc[train].reset_index(drop=True), df_mutation_profile_without_others.iloc[test].reset_index(drop=True) + _, dict_test_num_lineage = init_num_lineages('lineage', test_df) - df_test_dna_aa_profile, dict_test_num_lineage = init_num_lineages('lineage', test_df) - - all_characteristic_mutations = [item for sublist in train_df[f"{args.mutation_level}_profile"].to_list() for item in sublist] - all_unique_char_mutations = sorted(list(dict.fromkeys(all_characteristic_mutations))) - filtered_N_all_char_mutations = [item for item in all_unique_char_mutations if 'N' not in item] - filtered_N_all_char_mutations = [element for element in filtered_N_all_char_mutations if element.startswith('del') or (re.compile(r'^[a-zA-Z]+\d+[a-zA-Z]+$')).match(element)] - sorted_all_unique_char_mutations = sorted(filtered_N_all_char_mutations, key=custom_sort_key) - - df_dna_aa_profile, dict_num_lineage = init_num_lineages('lineage', train_df) - df_dna_aa_profile["dna_profile"] = df_dna_aa_profile["dna_profile"].apply(lambda x: ' '.join(x)) - dict_num_lineage = dict(sorted(dict_num_lineage.items())) - sorted_lineage_list = list(dict_num_lineage.keys()) - - df_lineage_mutation_frequency = lineage_mutation_frequency(f"{args.mutation_level}_profile", df_dna_aa_profile, sorted_lineage_list, dict_num_lineage) - df_lineage_mutation_frequency = df_lineage_mutation_frequency.fillna(0)[(df_lineage_mutation_frequency.fillna(0) >= args.cut_off_frequency).any(axis=1)] - - dict_lineage_char_mutations = {} - for lineage in df_lineage_mutation_frequency.columns: - char_mutations = df_lineage_mutation_frequency.index[df_lineage_mutation_frequency[lineage] >= args.cut_off_frequency].tolist() - sorted_char_mutations = sorted(char_mutations, key=custom_sort_key) - dict_lineage_char_mutations[lineage] = sorted_char_mutations + all_train_characteristic_mutations = [item for sublist in train_df[f"{args.mutation_level}_profile"].to_list() for item in sublist] + all_train_unique_char_mutations = sorted(list(dict.fromkeys(all_train_characteristic_mutations))) + filtered_N_train_char_mutations = [item for item in all_train_unique_char_mutations if 'N' not in item] + filtered_N_train_char_mutations = [element for element in filtered_N_train_char_mutations if element.startswith('del') or (re.compile(r'^[a-zA-Z]+\d+[a-zA-Z]+$')).match(element)] + sorted_train_unique_char_mutations = sorted(filtered_N_train_char_mutations, key=custom_sort_key) + + df_dna_aa_train_profile, dict_num_train_lineage = init_num_lineages('lineage', train_df) + df_dna_aa_train_profile[f"{args.mutation_level}_profile"] = df_dna_aa_train_profile[f"{args.mutation_level}_profile"].apply(lambda x: ' '.join(x)) + dict_num_train_lineage = dict(sorted(dict_num_train_lineage.items())) + sorted_train_lineage_list = list(dict_num_train_lineage.keys()) + + df_train_lineage_mutation_frequency = lineage_mutation_frequency(f"{args.mutation_level}_profile", df_dna_aa_train_profile, sorted_train_lineage_list, dict_num_train_lineage) + df_train_lineage_mutation_frequency = df_train_lineage_mutation_frequency.fillna(0)[(df_train_lineage_mutation_frequency.fillna(0) >= args.cut_off_frequency).any(axis=1)] + + dict_train_lineage_char_mutations = {} + for train_lineage in df_train_lineage_mutation_frequency.columns: + train_char_mutations = df_train_lineage_mutation_frequency.index[df_train_lineage_mutation_frequency[train_lineage] >= args.cut_off_frequency].tolist() + sorted_train_char_mutations = sorted(train_char_mutations, key=custom_sort_key) + dict_train_lineage_char_mutations[train_lineage] = sorted_train_char_mutations - binary_c2_profiles = pd.DataFrame(0, columns=sorted_all_unique_char_mutations, index=dict_lineage_char_mutations.keys()) - for lineage, mutations in dict_lineage_char_mutations.items(): - binary_c2_profiles.loc[lineage, mutations] = 1 + binary_train_c2_profiles = pd.DataFrame(0, columns=sorted_train_unique_char_mutations, index=dict_train_lineage_char_mutations.keys()) + for train_lineage, train_mutations in dict_train_lineage_char_mutations.items(): + binary_train_c2_profiles.loc[train_lineage, train_mutations] = 1 - multi_class_confusion_matrix = pd.DataFrame(0, columns=dict_lineage_char_mutations.keys(), index=dict_lineage_char_mutations.keys()) - results = Parallel(n_jobs=num_cores)( - delayed(calculate_confusion_matrix)(row, binary_c2_profiles, sorted_all_unique_char_mutations) + multi_class_train_test_confusion_matrix = pd.DataFrame(0, columns=dict_train_lineage_char_mutations.keys(), index=dict_train_lineage_char_mutations.keys()) + test_results = Parallel(n_jobs=num_cores)( + delayed(calculate_confusion_matrix)(row, binary_train_c2_profiles, sorted_train_unique_char_mutations) for _, row in test_df.iterrows() ) - for actual_lineage, predicted_lineage in results: - if actual_lineage in multi_class_confusion_matrix.index and predicted_lineage in multi_class_confusion_matrix.columns: - multi_class_confusion_matrix.loc[actual_lineage, predicted_lineage] += 1 - + for actual_train_lineage, predicted_test_lineage in test_results: + if actual_train_lineage in multi_class_train_test_confusion_matrix.index and predicted_test_lineage in multi_class_train_test_confusion_matrix.columns: + multi_class_train_test_confusion_matrix.loc[actual_train_lineage, predicted_test_lineage] += 1 + counter+=1 - precision_values = [] - recall_values = [] - for pangolin, predictions in multi_class_confusion_matrix.iterrows(): - true_positives = predictions[pangolin] - false_positives = predictions.values.sum() - true_positives - if true_positives + false_positives == 0: - precision = 0 + test_precision_values = [] + test_recall_values = [] + for train_pangolin, test_predictions in multi_class_train_test_confusion_matrix.iterrows(): + if test_predictions.values.sum() == 0: + test_precision = 0 else: - precision = true_positives / (true_positives + false_positives) - - false_negatives = dict_filter_num_lineage[pangolin] - predictions.values.sum() - if true_positives + false_negatives == 0: - precision = 0 # Define what precision should be if denominator is zero - else: - recall = true_positives / (true_positives + false_negatives) - - precision_values.append(float(precision)), recall_values.append(float(recall)) - avg_precision, avg_recall = sum(precision_values) / len(precision_values), sum(recall_values) / len(recall_values) - - print(counter, "Precision:", avg_precision, "Recall: ", avg_recall) - - #print(multi_class_confusion_matrix.values.sum()) - ''' + test_precision = test_predictions[train_pangolin] / test_predictions.values.sum() + test_precision_values.append(float(test_precision)) + + if train_pangolin in dict_test_num_lineage: + if dict_test_num_lineage[train_pangolin] == 0: + test_recall = 0 # Define what precision should be if denominator is zero + else: + test_recall = test_predictions[train_pangolin] / dict_test_num_lineage[train_pangolin] + test_recall_values.append(float(test_recall)) + avg_test_precision, avg_test_recall = sum(test_precision_values) / len(test_precision_values), sum(test_recall_values) / len(test_recall_values) + avg_test_precision, avg_test_recall = math.ceil(avg_test_precision * 1000) / 1000, math.ceil(avg_test_recall * 1000) / 1000 + print(f"{counter}. fold", "Precision:", avg_test_precision, "Recall: ", avg_test_recall) - #print(f"Confusion matrix nr.{counter}:", "\n", multi_class_confusion_matrix) - #true_positives, precision = calculate_precision(multi_class_confusion_matrix) - multi_class_confusion_matrix.to_csv(f"output/frequency_verification/{date_range}/10-fold_confusion_matrix_{args.cut_off_frequency}_{counter}_{date_range}.csv") - plt.figure(figsize=(8, 6)) - plt.plot(recall_values, precision_values, marker='o', linestyle='-') - plt.xlabel('Recall') - plt.ylabel('Precision') - plt.title('Precision-Recall Curve') - plt.grid(True) - plt.ylim(0.0, 1.0) - plt.savefig(f"output/frequency_verification/{date_range}/10-fold_precision-recall_{args.cut_off_frequency}_{counter}_{date_range}.png") - ''' quit() else: #parent-child relationship @@ -1375,7 +1335,7 @@ def generate_predictions(lineages): with open(multi_fasta, "w") as f: SeqIO.write(records, f, "fasta") - print(f"Merged consensus file is created in {multi_fasta} and contains {num_of_lineages}") + print(f"Merged consensus file is created in {multi_fasta} and contains {dict_num_lineage}") if args.consensus_check: print("consensus seq check if the mutations are right")