From 3a79ce15617351d791de8a913bd48200c9bb1fa3 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Fri, 6 Dec 2024 16:35:47 -0800 Subject: [PATCH 1/2] add cellprofiler and no finetuning model to drug analysis + box plots --- configs/results/npm1_perturb.yaml | 21 +++++++++++--------- src/br/analysis/analysis_utils.py | 19 +++++++++++++----- src/br/analysis/run_drugdata_analysis.py | 25 ++++++++++++++++++------ src/br/chandrasekaran_et_al/utils.py | 21 ++++++++------------ 4 files changed, 53 insertions(+), 33 deletions(-) diff --git a/configs/results/npm1_perturb.yaml b/configs/results/npm1_perturb.yaml index 492e61b..5e73ec0 100644 --- a/configs/results/npm1_perturb.yaml +++ b/configs/results/npm1_perturb.yaml @@ -3,25 +3,28 @@ image_path: pc_path: model_checkpoints: [ - "./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Rotation_invariant_pointcloud_SDF.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Classical_image_SDF.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Classical_image_seg.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Rotation_invariant_image_SDF.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Rotation_invariant_image_seg.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Classical_image_SDF.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Rotation_invariant_image_SDF.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Rotation_invariant_pointcloud_SDF.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/npm1/Rotation_invariant_pointcloud_SDF.ckpt", ] names: [ - "Rotation_invariant_pointcloud_SDF", - "Classical_image_SDF", "Classical_image_seg", - "Rotation_invariant_image_SDF", "Rotation_invariant_image_seg", + "Classical_image_SDF", + "Rotation_invariant_image_SDF", + "Rotation_invariant_pointcloud_SDF", + "Rotation_invariant_pointcloud_SDF_no_finetuning", ] data_paths: [ - "/data/npm1_perturb/pc.yaml", - "/data/npm1_perturb/classical_image_sdf.yaml", "/data/npm1_perturb/classical_image_seg.yaml", - "/data/npm1_perturb/so3_image_sdf.yaml", "/data/npm1_perturb/so3_image_seg.yaml", + "/data/npm1_perturb/classical_image_sdf.yaml", + "/data/npm1_perturb/so3_image_sdf.yaml", + "/data/npm1_perturb/pc.yaml", + "/data/npm1_perturb/pc.yaml", ] diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index c4bb50a..7d448b6 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -168,6 +168,19 @@ def setup_evaluation_params(manifest, run_names): eval_scaled_img_params = [{}] * len(run_names) if "SDF" in "\t".join(run_names): + loss_eval_list = [torch.nn.MSELoss(reduction="none")] * len(run_names) + sample_points_list = [False] * len(run_names) + skew_scale = None + + if "mesh_folder" not in manifest.columns: + return ( + eval_scaled_img, + eval_scaled_img_params, + loss_eval_list, + sample_points_list, + skew_scale, + ) + eval_scaled_img_resolution = 32 gt_mesh_dir = manifest["mesh_folder"].iloc[0] gt_sampled_pts_dir = manifest["pointcloud_folder"].iloc[0] @@ -190,9 +203,6 @@ def setup_evaluation_params(manifest, run_names): "mesh_ext": "stl", } ) - loss_eval_list = [torch.nn.MSELoss(reduction="none")] * len(run_names) - sample_points_list = [False] * len(run_names) - skew_scale = None else: loss_eval_list = None skew_scale = 100 @@ -668,7 +678,7 @@ def generate_reconstructions(all_models, data_list, run_names, keys, test_ids, d this_save_path_input = Path(save_path) / Path(this_run_name) / Path("input") this_save_path_input.mkdir(parents=True, exist_ok=True) np.save(this_save_path_input / Path(f"{cell_id}.npy"), input) - + if canonical is not None: this_save_path_canon = ( Path(save_path) / Path(this_run_name) / Path("canonical") @@ -713,7 +723,6 @@ def _plot_image(input, recon, recon_canonical, dataset_name): recon = recon.T recon_canonical = recon_canonical.T - i = 2 fig, (ax, ax1, ax2) = plt.subplots(1, 3, figsize=(8, 4)) ax.imshow(input, cmap="gray_r") ax1.imshow(recon, cmap="gray_r") diff --git a/src/br/analysis/run_drugdata_analysis.py b/src/br/analysis/run_drugdata_analysis.py index fbc18a1..f942035 100644 --- a/src/br/analysis/run_drugdata_analysis.py +++ b/src/br/analysis/run_drugdata_analysis.py @@ -3,6 +3,8 @@ import sys from pathlib import Path +import pandas as pd + from br.chandrasekaran_et_al.utils import _plot, perturbation_detection from br.models.compute_features import get_embeddings from br.models.utils import get_all_configs_per_dataset @@ -10,12 +12,17 @@ def _get_featurecols(df): """returna list of featuredata columns.""" - return [c for c in df.columns if "mu" in c] + return [c for c in df.columns if "mu_" in c] + + +def _get_featurecols_cellprofiler(df): + """returna list of featuredata columns.""" + return [i for i in df.columns if "Mean" in i or "StDev" in i or "Median" in i] -def _get_featuredata(df): +def _get_featuredata(df, get_featurecols_fn): """return dataframe of just featuredata columns.""" - return df[_get_featurecols(df)] + return df[get_featurecols_fn(df)] def main(args): @@ -26,17 +33,23 @@ def main(args): dataset_name = args.dataset_name DATASET_INFO = get_all_configs_per_dataset(results_path) dataset = DATASET_INFO[dataset_name] - run_names = dataset["names"] + run_names = dataset["names"] + ["cellprofiler"] all_ret, df = get_embeddings(run_names, args.dataset_name, DATASET_INFO, args.embeddings_path) all_ret["well_position"] = "A0" # dummy all_ret["Assay_Plate_Barcode"] = "Plate0" # dummy - pert = perturbation_detection(all_ret, _get_featurecols, _get_featuredata) + df_1 = all_ret.loc[~all_ret["model"].isin(["cellprofiler"])].reset_index(drop=True) + pert1 = perturbation_detection(df_1, _get_featurecols, _get_featuredata) + + df_2 = all_ret.loc[all_ret["model"].isin(["cellprofiler"])].reset_index(drop=True) + pert2 = perturbation_detection(df_2, _get_featurecols_cellprofiler, _get_featuredata) + + pert = pd.concat([pert1, pert2], axis=0).reset_index(drop=True) this_save_path = Path(args.save_path) this_save_path.mkdir(parents=True, exist_ok=True) - _plot(pert, this_save_path) + _plot(pert, this_save_path, run_names) if __name__ == "__main__": diff --git a/src/br/chandrasekaran_et_al/utils.py b/src/br/chandrasekaran_et_al/utils.py index 12d887d..f8feb2e 100644 --- a/src/br/chandrasekaran_et_al/utils.py +++ b/src/br/chandrasekaran_et_al/utils.py @@ -25,12 +25,11 @@ from br.chandrasekaran_et_al import utils -def perturbation_detection(all_ret, get_featurecols, get_featuredata): +def perturbation_detection(all_ret, get_featurecols, get_featuredata, fit_pca=False): cols = get_featurecols(all_ret) replicate_feature = "Metadata_broad_sample" batch_size = 100000 null_size = 100000 - all_rep = [] for model in tqdm(all_ret["model"].unique(), total=len(all_ret["model"].unique())): df_feats = all_ret.loc[all_ret["model"] == model].reset_index(drop=True) @@ -124,7 +123,7 @@ def perturbation_detection(all_ret, get_featurecols, get_featuredata): neg_diffby = ["Metadata_negcon"] metadata_df = get_metadata(modality_1_df) - feature_df = get_featuredata(modality_1_df) + feature_df = get_featuredata(modality_1_df, get_featurecols) feature_values = feature_df.values result = run_pipeline( @@ -152,6 +151,7 @@ def perturbation_detection(all_ret, get_featurecols, get_featuredata): cell, modality_1_timepoint, ) + replicability_map_df["model"] = model all_rep.append(replicability_map_df) @@ -161,7 +161,7 @@ def perturbation_detection(all_ret, get_featurecols, get_featuredata): return all_rep -def _plot(all_rep, save_path): +def _plot(all_rep, save_path, run_names): sns.set_context("talk") sns.set(font_scale=1.7) sns.set_style("white") @@ -186,19 +186,14 @@ def _plot(all_rep, save_path): data=test, x="Drugs", y="q_value", + kind="bar", hue="model", - kind="point", order=x_order, - hue_order=[ - "Classical_image_seg", - "Rotation_invariant_image_seg", - "Classical_image_SDF", - "Rotation_invariant_image_SDF", - "Rotation_invariant_pointcloud_SDF", - ], - palette=["#A6ACE0", "#6277DB", "#D9978E", "#D8553B", "#2ED9FF"], + hue_order=run_names, + palette=["#A6ACE0", "#6277DB", "#D9978E", "#D8553B", "#2ED9FF", "#91db57", "#db57d3"], aspect=2, height=5, + dodge=True, ) g.set_xticklabels(rotation=90) plt.axhline(y=0.05, color="black") From 61f6af15bfee47858cab23238c8335257dc18488 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Mon, 9 Dec 2024 15:33:18 -0800 Subject: [PATCH 2/2] adjust plots to fix labels and ylim --- src/br/chandrasekaran_et_al/utils.py | 52 ++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/src/br/chandrasekaran_et_al/utils.py b/src/br/chandrasekaran_et_al/utils.py index f8feb2e..683986f 100644 --- a/src/br/chandrasekaran_et_al/utils.py +++ b/src/br/chandrasekaran_et_al/utils.py @@ -25,7 +25,7 @@ from br.chandrasekaran_et_al import utils -def perturbation_detection(all_ret, get_featurecols, get_featuredata, fit_pca=False): +def perturbation_detection(all_ret, get_featurecols, get_featuredata): cols = get_featurecols(all_ret) replicate_feature = "Metadata_broad_sample" batch_size = 100000 @@ -166,22 +166,42 @@ def _plot(all_rep, save_path, run_names): sns.set(font_scale=1.7) sns.set_style("white") - test = all_rep.sort_values(by="q_value").reset_index(drop=True) - test["Drugs"] = test["Metadata_broad_sample"] - - x_order = ( - test.loc[test["model"] == "SO3_pointcloud_SDF"] - .sort_values(by="q_value")["Metadata_broad_sample"] - .values - ) - ordered_drugs = ( - all_rep.groupby(["Metadata_broad_sample"]) - .mean() - .sort_values(by="q_value") - .reset_index()["Metadata_broad_sample"] - ) + all_rep["Drugs"] = all_rep["Metadata_broad_sample"] + map_ = { + "Actinomyocin D 0.5ug per mL": "Actinomyocin D", + "Jasplakinolide 50 nM (E5)": "Jasplakinolide", + "Paclitaxel 5uM (E2)": "Paclitaxel", + "Staurosporine 1uM (E8)": "Staurosporine", + "Nocodazole 0.1uM (E4)": "Nocodazole", + "Roscovitine 10uM (E9)": "Roscovitine 10uM", + "Torin 1uM": "Torin", + "Rapamycin 1uM (E7)": "Rapamycin", + "H89 10uM (E3)": "H89", + "Monensin 1.1uM": "Monensin", + "Rotenone 0.5uM (E6)": "Rotenone", + "Roscovitine 5uM (E10)": "Roscovitine 5uM", + "BIX 1uM": "BIX", + "Bafilomycin A1 0.1uM": "Bafilomycin A1", + "Latrunculin A1 0.1uM": "Latrunculin A1", + "Chloroquin 40uM": "Chloroquin", + "Brefeldin 5uM": "Brefeldin", + } + all_rep["Drugs"] = all_rep["Drugs"].replace(map_) + + tmp_ = all_rep.loc[ + all_rep["model"].isin( + [ + "Classical_image_SDF", + "Rotation_invariant_image_SDF", + "Rotation_invariant_pointcloud_SDF", + ] + ) + ] + ordered_drugs = tmp_.groupby(["Drugs"]).mean().sort_values(by="q_value").reset_index()["Drugs"] x_order = ordered_drugs + test = all_rep.sort_values(by="q_value").reset_index(drop=True) + g = sns.catplot( data=test, x="Drugs", @@ -196,7 +216,9 @@ def _plot(all_rep, save_path, run_names): dodge=True, ) g.set_xticklabels(rotation=90) + g.set(ylim=(0, 0.1)) plt.axhline(y=0.05, color="black") + g.set(ylabel='q value') this_path = Path(save_path) Path(this_path).mkdir(parents=True, exist_ok=True) g.savefig(this_path / "q_values.png", dpi=300, bbox_inches="tight")