Skip to content

Commit

Permalink
Merge pull request #74 from AllenCell/cellprofiler_barplots
Browse files Browse the repository at this point in the history
add cellprofiler and no finetuning model to drug analysis + box plots
  • Loading branch information
ritvikvasan authored Dec 9, 2024
2 parents b38b1f2 + 8dd8bf6 commit 7f9374a
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 46 deletions.
21 changes: 12 additions & 9 deletions configs/results/npm1_perturb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,28 @@ image_path:
pc_path:
model_checkpoints:
[
"./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Rotation_invariant_pointcloud_SDF.ckpt",
"./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Classical_image_SDF.ckpt",
"./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Classical_image_seg.ckpt",
"./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Rotation_invariant_image_SDF.ckpt",
"./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Rotation_invariant_image_seg.ckpt",
"./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Classical_image_SDF.ckpt",
"./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Rotation_invariant_image_SDF.ckpt",
"./morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/Rotation_invariant_pointcloud_SDF.ckpt",
"./morphology_appropriate_representation_learning/model_checkpoints/npm1/Rotation_invariant_pointcloud_SDF.ckpt",
]
names:
[
"Rotation_invariant_pointcloud_SDF",
"Classical_image_SDF",
"Classical_image_seg",
"Rotation_invariant_image_SDF",
"Rotation_invariant_image_seg",
"Classical_image_SDF",
"Rotation_invariant_image_SDF",
"Rotation_invariant_pointcloud_SDF",
"Rotation_invariant_pointcloud_SDF_no_finetuning",
]
data_paths:
[
"/data/npm1_perturb/pc.yaml",
"/data/npm1_perturb/classical_image_sdf.yaml",
"/data/npm1_perturb/classical_image_seg.yaml",
"/data/npm1_perturb/so3_image_sdf.yaml",
"/data/npm1_perturb/so3_image_seg.yaml",
"/data/npm1_perturb/classical_image_sdf.yaml",
"/data/npm1_perturb/so3_image_sdf.yaml",
"/data/npm1_perturb/pc.yaml",
"/data/npm1_perturb/pc.yaml",
]
19 changes: 14 additions & 5 deletions src/br/analysis/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,19 @@ def setup_evaluation_params(manifest, run_names):
eval_scaled_img_params = [{}] * len(run_names)

if "SDF" in "\t".join(run_names):
loss_eval_list = [torch.nn.MSELoss(reduction="none")] * len(run_names)
sample_points_list = [False] * len(run_names)
skew_scale = None

if "mesh_folder" not in manifest.columns:
return (
eval_scaled_img,
eval_scaled_img_params,
loss_eval_list,
sample_points_list,
skew_scale,
)

eval_scaled_img_resolution = 32
gt_mesh_dir = manifest["mesh_folder"].iloc[0]
gt_sampled_pts_dir = manifest["pointcloud_folder"].iloc[0]
Expand All @@ -192,9 +205,6 @@ def setup_evaluation_params(manifest, run_names):
"mesh_ext": "stl",
}
)
loss_eval_list = [torch.nn.MSELoss(reduction="none")] * len(run_names)
sample_points_list = [False] * len(run_names)
skew_scale = None
else:
loss_eval_list = None
skew_scale = 100
Expand Down Expand Up @@ -670,7 +680,7 @@ def generate_reconstructions(all_models, data_list, run_names, keys, test_ids, d
this_save_path_input = Path(save_path) / Path(this_run_name) / Path("input")
this_save_path_input.mkdir(parents=True, exist_ok=True)
np.save(this_save_path_input / Path(f"{cell_id}.npy"), input)

if canonical is not None:
this_save_path_canon = (
Path(save_path) / Path(this_run_name) / Path("canonical")
Expand Down Expand Up @@ -715,7 +725,6 @@ def _plot_image(input, recon, recon_canonical, dataset_name):
recon = recon.T
recon_canonical = recon_canonical.T

i = 2
fig, (ax, ax1, ax2) = plt.subplots(1, 3, figsize=(8, 4))
ax.imshow(input, cmap="gray_r")
ax1.imshow(recon, cmap="gray_r")
Expand Down
25 changes: 19 additions & 6 deletions src/br/analysis/run_drugdata_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,26 @@
import sys
from pathlib import Path

import pandas as pd

from br.chandrasekaran_et_al.utils import _plot, perturbation_detection
from br.models.compute_features import get_embeddings
from br.models.utils import get_all_configs_per_dataset


def _get_featurecols(df):
"""returna list of featuredata columns."""
return [c for c in df.columns if "mu" in c]
return [c for c in df.columns if "mu_" in c]


def _get_featurecols_cellprofiler(df):
    """Return the list of columns in ``df`` treated as CellProfiler features.

    A column is selected when its name contains any of the case-sensitive
    substrings ``"Mean"``, ``"StDev"`` or ``"Median"``. The result preserves
    the order of ``df.columns``.

    Non-string column labels (e.g. integer labels) are skipped instead of
    raising ``TypeError`` on the substring test.
    """
    markers = ("Mean", "StDev", "Median")
    return [
        col
        for col in df.columns
        if isinstance(col, str) and any(marker in col for marker in markers)
    ]


def _get_featuredata(df):
def _get_featuredata(df, get_featurecols_fn):
"""return dataframe of just featuredata columns."""
return df[_get_featurecols(df)]
return df[get_featurecols_fn(df)]


def main(args):
Expand All @@ -26,17 +33,23 @@ def main(args):
dataset_name = args.dataset_name
DATASET_INFO = get_all_configs_per_dataset(results_path)
dataset = DATASET_INFO[dataset_name]
run_names = dataset["names"]
run_names = dataset["names"] + ["cellprofiler"]

all_ret, df = get_embeddings(run_names, args.dataset_name, DATASET_INFO, args.embeddings_path)
all_ret["well_position"] = "A0" # dummy
all_ret["Assay_Plate_Barcode"] = "Plate0" # dummy

pert = perturbation_detection(all_ret, _get_featurecols, _get_featuredata)
df_1 = all_ret.loc[~all_ret["model"].isin(["cellprofiler"])].reset_index(drop=True)
pert1 = perturbation_detection(df_1, _get_featurecols, _get_featuredata)

df_2 = all_ret.loc[all_ret["model"].isin(["cellprofiler"])].reset_index(drop=True)
pert2 = perturbation_detection(df_2, _get_featurecols_cellprofiler, _get_featuredata)

pert = pd.concat([pert1, pert2], axis=0).reset_index(drop=True)

this_save_path = Path(args.save_path)
this_save_path.mkdir(parents=True, exist_ok=True)
_plot(pert, this_save_path)
_plot(pert, this_save_path, run_names)


if __name__ == "__main__":
Expand Down
69 changes: 43 additions & 26 deletions src/br/chandrasekaran_et_al/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def perturbation_detection(all_ret, get_featurecols, get_featuredata):
replicate_feature = "Metadata_broad_sample"
batch_size = 100000
null_size = 100000

all_rep = []
for model in tqdm(all_ret["model"].unique(), total=len(all_ret["model"].unique())):
df_feats = all_ret.loc[all_ret["model"] == model].reset_index(drop=True)
Expand Down Expand Up @@ -124,7 +123,7 @@ def perturbation_detection(all_ret, get_featurecols, get_featuredata):
neg_diffby = ["Metadata_negcon"]

metadata_df = get_metadata(modality_1_df)
feature_df = get_featuredata(modality_1_df)
feature_df = get_featuredata(modality_1_df, get_featurecols)
feature_values = feature_df.values

result = run_pipeline(
Expand Down Expand Up @@ -152,6 +151,7 @@ def perturbation_detection(all_ret, get_featurecols, get_featuredata):
cell,
modality_1_timepoint,
)

replicability_map_df["model"] = model
all_rep.append(replicability_map_df)

Expand All @@ -161,47 +161,64 @@ def perturbation_detection(all_ret, get_featurecols, get_featuredata):
return all_rep


def _plot(all_rep, save_path):
def _plot(all_rep, save_path, run_names):
sns.set_context("talk")
sns.set(font_scale=1.7)
sns.set_style("white")

test = all_rep.sort_values(by="q_value").reset_index(drop=True)
test["Drugs"] = test["Metadata_broad_sample"]

x_order = (
test.loc[test["model"] == "SO3_pointcloud_SDF"]
.sort_values(by="q_value")["Metadata_broad_sample"]
.values
)
ordered_drugs = (
all_rep.groupby(["Metadata_broad_sample"])
.mean()
.sort_values(by="q_value")
.reset_index()["Metadata_broad_sample"]
)
all_rep["Drugs"] = all_rep["Metadata_broad_sample"]
map_ = {
"Actinomyocin D 0.5ug per mL": "Actinomyocin D",
"Jasplakinolide 50 nM (E5)": "Jasplakinolide",
"Paclitaxel 5uM (E2)": "Paclitaxel",
"Staurosporine 1uM (E8)": "Staurosporine",
"Nocodazole 0.1uM (E4)": "Nocodazole",
"Roscovitine 10uM (E9)": "Roscovitine 10uM",
"Torin 1uM": "Torin",
"Rapamycin 1uM (E7)": "Rapamycin",
"H89 10uM (E3)": "H89",
"Monensin 1.1uM": "Monensin",
"Rotenone 0.5uM (E6)": "Rotenone",
"Roscovitine 5uM (E10)": "Roscovitine 5uM",
"BIX 1uM": "BIX",
"Bafilomycin A1 0.1uM": "Bafilomycin A1",
"Latrunculin A1 0.1uM": "Latrunculin A1",
"Chloroquin 40uM": "Chloroquin",
"Brefeldin 5uM": "Brefeldin",
}
all_rep["Drugs"] = all_rep["Drugs"].replace(map_)

tmp_ = all_rep.loc[
all_rep["model"].isin(
[
"Classical_image_SDF",
"Rotation_invariant_image_SDF",
"Rotation_invariant_pointcloud_SDF",
]
)
]
ordered_drugs = tmp_.groupby(["Drugs"]).mean().sort_values(by="q_value").reset_index()["Drugs"]
x_order = ordered_drugs

test = all_rep.sort_values(by="q_value").reset_index(drop=True)

g = sns.catplot(
data=test,
x="Drugs",
y="q_value",
kind="bar",
hue="model",
kind="point",
order=x_order,
hue_order=[
"Classical_image_seg",
"Rotation_invariant_image_seg",
"Classical_image_SDF",
"Rotation_invariant_image_SDF",
"Rotation_invariant_pointcloud_SDF",
],
palette=["#A6ACE0", "#6277DB", "#D9978E", "#D8553B", "#2ED9FF"],
hue_order=run_names,
palette=["#A6ACE0", "#6277DB", "#D9978E", "#D8553B", "#2ED9FF", "#91db57", "#db57d3"],
aspect=2,
height=5,
dodge=True,
)
g.set_xticklabels(rotation=90)
g.set(ylim=(0, 0.1))
plt.axhline(y=0.05, color="black")
g.set(ylabel='q value')
this_path = Path(save_path)
Path(this_path).mkdir(parents=True, exist_ok=True)
g.savefig(this_path / "q_values.png", dpi=300, bbox_inches="tight")
Expand Down

0 comments on commit 7f9374a

Please sign in to comment.