From d5e19427435be5fa1f5ef938026849e5f35281e9 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Wed, 20 Nov 2024 10:38:20 -0800 Subject: [PATCH 01/35] delete hidden snakemake folder --- ...fRGF0YV9Qcm9jZXNzaW5nL3RyYXNoL21lcmdlL21hbmlmZXN0LnBhcnF1ZXQ= | 1 - 1 file changed, 1 deletion(-) delete mode 100644 src/br/data/preprocessing/image_preprocessing/.snakemake/incomplete/L2FsbGVuL2FpY3MvbW9kZWxpbmcvcml0dmlrL3Byb2plY3RzL3NlY29uZF9jbG9uZXMvVmFyaWFuY2VfRGF0YV9Qcm9jZXNzaW5nL3RyYXNoL21lcmdlL21hbmlmZXN0LnBhcnF1ZXQ= diff --git a/src/br/data/preprocessing/image_preprocessing/.snakemake/incomplete/L2FsbGVuL2FpY3MvbW9kZWxpbmcvcml0dmlrL3Byb2plY3RzL3NlY29uZF9jbG9uZXMvVmFyaWFuY2VfRGF0YV9Qcm9jZXNzaW5nL3RyYXNoL21lcmdlL21hbmlmZXN0LnBhcnF1ZXQ= b/src/br/data/preprocessing/image_preprocessing/.snakemake/incomplete/L2FsbGVuL2FpY3MvbW9kZWxpbmcvcml0dmlrL3Byb2plY3RzL3NlY29uZF9jbG9uZXMvVmFyaWFuY2VfRGF0YV9Qcm9jZXNzaW5nL3RyYXNoL21lcmdlL21hbmlmZXN0LnBhcnF1ZXQ= deleted file mode 100644 index 5d3401d..0000000 --- a/src/br/data/preprocessing/image_preprocessing/.snakemake/incomplete/L2FsbGVuL2FpY3MvbW9kZWxpbmcvcml0dmlrL3Byb2plY3RzL3NlY29uZF9jbG9uZXMvVmFyaWFuY2VfRGF0YV9Qcm9jZXNzaW5nL3RyYXNoL21lcmdlL21hbmlmZXN0LnBhcnF1ZXQ= +++ /dev/null @@ -1 +0,0 @@ -{"external_jobid": null} \ No newline at end of file From beb3110c11fcca12a04c5166407688323bb69f41 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Wed, 20 Nov 2024 11:20:19 -0800 Subject: [PATCH 02/35] change pcna result order --- configs/results/pcna.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/configs/results/pcna.yaml b/configs/results/pcna.yaml index 9c4f184..d8d6679 100644 --- a/configs/results/pcna.yaml +++ b/configs/results/pcna.yaml @@ -3,25 +3,25 @@ image_path: ./morphology_appropriate_representation_learning/preprocessed_data/p pc_path: ./morphology_appropriate_representation_learning/preprocessed_data/pcna/manifest.csv model_checkpoints: [ - "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Classical_pointcloud.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Rotation_invariant_pointcloud.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Classical_image.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Rotation_invariant_image.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Classical_pointcloud.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Rotation_invariant_pointcloud.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Rotation_invariant_pointcloud_jitter.ckpt", ] names: [ - "Classical_pointcloud", - "Rotation_invariant_pointcloud", "Classical_image", "Rotation_invariant_image", + "Classical_pointcloud", + "Rotation_invariant_pointcloud", "Rotation_invariant_pointcloud_jitter", ] data_paths: [ - "./configs/data/pcna/pc.yaml", - "./configs/data/pcna/pc_intensity.yaml", "./configs/data/pcna/image.yaml", "./configs/data/pcna/image.yaml", + "./configs/data/pcna/pc.yaml", + "./configs/data/pcna/pc_intensity.yaml", # "./src/br/configs/data/pcna/pc_intensity_jitter.yaml", "./configs/data/pcna/pc_intensity.yaml", ] From 9458d4ab6e1222b28cbc1c4eb89f8e2b35756fc1 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Wed, 20 Nov 2024 11:56:54 -0800 Subject: [PATCH 03/35] add setup function to prereq to handle SDF/pc differences --- src/br/analysis/prereq.py | 145 +++++++++++++++++++++++------------ src/br/models/load_models.py 
| 11 ++- 2 files changed, 103 insertions(+), 53 deletions(-) diff --git a/src/br/analysis/prereq.py b/src/br/analysis/prereq.py index 6871025..13d0370 100644 --- a/src/br/analysis/prereq.py +++ b/src/br/analysis/prereq.py @@ -1,16 +1,24 @@ # Free up cache +import argparse import gc +import os +import subprocess +import pandas as pd import torch +from br.models.compute_features import compute_features +from br.models.load_models import get_data_and_models +from br.models.save_embeddings import ( + get_pc_loss_chamfer, + save_embeddings, + save_emissions, +) +from br.models.utils import get_all_configs_per_dataset + gc.collect() torch.cuda.empty_cache() -import argparse -import os -import subprocess -from pathlib import Path - # Based on the utilization, set the GPU ID @@ -89,32 +97,6 @@ def config_gpu(): # Setting a GPU ID is crucial for the script to work well! -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import torch -import yaml -from hydra.utils import instantiate -from PIL import Image -from torch.utils.data import DataLoader, Dataset - -from br.features.archetype import AA_Fast -from br.features.plot import collect_outputs, plot, plot_stratified_pc -from br.features.reconstruction import stratified_latent_walk -from br.features.utils import ( - normalize_intensities_and_get_colormap, - normalize_intensities_and_get_colormap_apply, -) -from br.models.compute_features import compute_features, get_embeddings -from br.models.load_models import get_data_and_models -from br.models.save_embeddings import ( - get_pc_loss, - get_pc_loss_chamfer, - save_embeddings, - save_emissions, -) -from br.models.utils import get_all_configs_per_dataset - def main(args): # Set working directory and paths @@ -126,41 +108,94 @@ def main(args): debug = args.debug # Load data and models - data_list, all_models, run_names, model_sizes = get_data_and_models( + data_list, all_models, run_names, model_sizes, manifest = get_data_and_models( dataset_name, batch_size, results_path, debug ) # Save model sizes to CSV - gg = pd.DataFrame() - gg["model"] = run_names - gg["model_size"] = model_sizes - gg.to_csv(os.path.join(save_path, "model_sizes.csv")) + sizes_ = pd.DataFrame() + sizes_["model"] = run_names + sizes_["model_size"] = model_sizes + sizes_.to_csv(os.path.join(save_path, "model_sizes.csv")) - compute_embeddings() + save_embeddings_across_models(args, manifest, data_list, all_models, run_names) compute_relevant_features() -def compute_embeddings(): +def _setup_evaluation_params(manifest, run_names): + """Return evaluation params related to. + + 1. loss_eval_list - which loss to use for each model (Defaults to Chamfer loss) + 2. skew_scale - Hyperparameter associated with sampling of pointclouds from images + 3. sample_points_list - whether to sample pointclouds for each model + 4. eval_scaled_img - whether to scale the images for evaluation (specific to SDF models) + 5. 
eval_scaled_img_params - parameters like mesh paths, scale factors, pointcloud paths associated + with evaluating scaled images + """ + eval_scaled_img = [False] * len(run_names) + eval_scaled_img_params = [{}] * len(run_names) + + if "SDF" in "\t".join(run_names): + eval_scaled_img_resolution = 32 + gt_mesh_dir = manifest["mesh_folder"].iloc[0] + gt_sampled_pts_dir = manifest["pointcloud_folder"].iloc[0] + gt_scale_factor_dict_path = manifest["scale_factor"].iloc[0] + eval_scaled_img_params = [] + for name_ in run_names: + if "seg" in name_: + model_type = "seg" + elif "SDF" in name_: + model_type = "sdf" + elif "pointcloud" in name_: + model_type = "iae" + eval_scaled_img_params.append( + { + "eval_scaled_img_model_type": model_type, + "eval_scaled_img_resolution": eval_scaled_img_resolution, + "gt_mesh_dir": gt_mesh_dir, + "gt_scale_factor_dict_path": gt_scale_factor_dict_path, + "gt_sampled_pts_dir": gt_sampled_pts_dir, + "mesh_ext": "stl", + } + ) + loss_eval_list = [torch.nn.MSELoss(reduction="none")] * len(run_names) + sample_points_list = [False] * len(run_names) + skew_scale = None + else: + loss_eval_list = None + skew_scale = 100 + sample_points_list = [] + for name_ in run_names: + if "image" in name_: + sample_points_list.append(True) + else: + sample_points_list.append(False) + return eval_scaled_img, eval_scaled_img_params, loss_eval_list, sample_points_list, skew_scale + + +def save_embeddings_across_models(args, manifest, data_list, all_models, run_names): + """ + Save embeddings across models + """ # Compute embeddings and reconstructions for each model - debug = False splits_list = ["train", "val", "test"] - meta_key = "rule" - eval_scaled_img = [False] * 5 - eval_scaled_img_params = [{}] * 5 - loss_eval_list = None - # sample_points_list = [True, True, False, False, False] # This is also different for each of PCNA and Cellpack - RITVIK - sample_points_list = [False, False, True, True, False] - skew_scale = 100 + ( + eval_scaled_img, + eval_scaled_img_params, + loss_eval_list, + sample_points_list, + skew_scale, + ) = _setup_evaluation_params(manifest, run_names) save_embeddings( - save_path, + args.save_path, data_list, all_models, run_names, - debug, + args.debug, splits_list, device, - meta_key, + args.meta_key, loss_eval_list, sample_points_list, skew_scale, @@ -272,6 +307,18 @@ def compute_relevant_features(): parser.add_argument( "--results_path", type=str, required=True, help="Path to the results directory." 
) + parser.add_argument( + "--meta_key", + type=str, + required=True, + help="Metadata to add to the embeddings aside from CellId", + ) + parser.add_argument( + "--sdf", + type=bool, + required=True, + help="boolean indicating whether the experiments involve SDFs", + ) parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") parser.add_argument("--batch_size", type=int, default=2, help="Batch size for processing.") parser.add_argument("--debug", type=bool, default=True, help="Enable debug mode.") diff --git a/src/br/models/load_models.py b/src/br/models/load_models.py index e87653d..3c6c1de 100644 --- a/src/br/models/load_models.py +++ b/src/br/models/load_models.py @@ -1,3 +1,4 @@ +import pandas as pd import yaml from cyto_dl.models.utils.mlflow import get_config, load_model_from_checkpoint from hydra._internal.utils import _locate @@ -10,6 +11,7 @@ def load_model_from_path(dataset, results_path, strict=False, split="val", device="cuda:0"): MODEL_INFO = get_all_configs_per_dataset(results_path) models = MODEL_INFO[dataset] + model_manifest = pd.read_csv(models["orig_df"]) model_sizes = [] all_models = [] for j, ckpt_path in enumerate(models["model_checkpoints"]): @@ -30,13 +32,14 @@ def load_model_from_path(dataset, results_path, strict=False, split="val", devic ) model_sizes.append(config["model/params/total"]) - return all_models, models["names"], model_sizes + return all_models, models["names"], model_sizes, model_manifest def load_model_from_mlflow(dataset, results_path, split="val"): TRACKING_URI = "https://mlflow.a100.int.allencell.org" MODEL_INFO = get_all_configs_per_dataset(results_path) models = MODEL_INFO[dataset] + model_manifest = pd.read_csv(models["orig_df"]) model_sizes = [] all_models = [] for i in models["run_ids"]: @@ -51,12 +54,12 @@ def load_model_from_mlflow(dataset, results_path, split="val"): config = get_config(TRACKING_URI, i, "./tmp") model_sizes.append(config["model/params/total"]) - return all_models, models["names"], model_sizes + return all_models, models["names"], model_sizes, model_manifest def get_data_and_models(dataset_name, batch_size, results_path, debug=False): data_list = get_data(dataset_name, batch_size, results_path, debug) - all_models, run_names, model_sizes = load_model_from_path( + all_models, run_names, model_sizes, model_manifest = load_model_from_path( dataset_name, results_path ) # default list of models in load_models.py - return data_list, all_models, run_names, model_sizes + return data_list, all_models, run_names, model_sizes, model_manifest From 8069271d7dfccf233d702033e5849541109beb98 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Wed, 20 Nov 2024 14:04:19 -0800 Subject: [PATCH 04/35] add classes and labels to results config --- configs/results/cellpack.yaml | 2 ++ configs/results/npm1.yaml | 3 +++ configs/results/other_polymorphic.yaml | 2 ++ configs/results/other_punctate.yaml | 2 ++ configs/results/pcna.yaml | 2 ++ 5 files changed, 11 insertions(+) diff --git a/configs/results/cellpack.yaml b/configs/results/cellpack.yaml index ed588ed..59ef98f 100644 --- a/configs/results/cellpack.yaml +++ b/configs/results/cellpack.yaml @@ -25,3 +25,5 @@ data_paths: [ # "./src/br/configs/data/cellpack/pc_jitter.yaml", "./configs/data/cellpack/pc.yaml", ] +classification_label: ["rule"] +regression_label: diff --git a/configs/results/npm1.yaml b/configs/results/npm1.yaml index 97c2f9c..58f4d10 100644 --- a/configs/results/npm1.yaml +++ b/configs/results/npm1.yaml @@ -25,3 +25,6 @@ data_paths: 
"./configs/data/npm1/classical_image_sdf.yaml", "./configs/data/npm1/classical_image_seg.yaml", ] +classification_label: ["STR_connectivity_cc_thresh"] +regression_label: + ["mean_centroid_distances", "mean_nucleolus_volume", "mean_nucleolus_area"] diff --git a/configs/results/other_polymorphic.yaml b/configs/results/other_polymorphic.yaml index 2b4319f..cba8773 100644 --- a/configs/results/other_polymorphic.yaml +++ b/configs/results/other_polymorphic.yaml @@ -25,3 +25,5 @@ data_paths: "./configs/data/other_polymorphic/classical_image_sdf.yaml", "./configs/data/other_polymorphic/classical_image_seg.yaml", ] +classification_label: ["structure_name"] +regression_label: ["avg_dists", "mean_volume", "mean_surface_area"] diff --git a/configs/results/other_punctate.yaml b/configs/results/other_punctate.yaml index da1f528..2e5a109 100644 --- a/configs/results/other_punctate.yaml +++ b/configs/results/other_punctate.yaml @@ -25,3 +25,5 @@ data_paths: "./configs/data/other_punctate/pc_intensity.yaml", "./configs/data/other_punctate/pc_intensity_structurenorm.yaml", ] +classification_label: ["structure_name", "cell_stage"] +regression_label: diff --git a/configs/results/pcna.yaml b/configs/results/pcna.yaml index d8d6679..9e0e8f4 100644 --- a/configs/results/pcna.yaml +++ b/configs/results/pcna.yaml @@ -25,3 +25,5 @@ data_paths: [ # "./src/br/configs/data/pcna/pc_intensity_jitter.yaml", "./configs/data/pcna/pc_intensity.yaml", ] +classification_label: ["cell_stage_fine", "flag_comment"] +regression_label: From acf40559bdc3bd6bbbbf65dae96c0e15705e7751 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Wed, 20 Nov 2024 14:04:49 -0800 Subject: [PATCH 05/35] move to run_embeddings and run_features --- src/br/analysis/__init__.py | 0 src/br/analysis/analysis_utils.py | 164 +++++++++++++++++++++++++++++ src/br/analysis/run_embeddings.py | 114 +++++++++++++++++++++ src/br/analysis/run_features.py | 165 ++++++++++++++++++++++++++++++ src/br/features/classification.py | 5 +- src/br/features/regression.py | 2 +- src/br/models/compute_features.py | 5 +- src/br/models/load_models.py | 25 ++++- 8 files changed, 472 insertions(+), 8 deletions(-) create mode 100644 src/br/analysis/__init__.py create mode 100644 src/br/analysis/analysis_utils.py create mode 100644 src/br/analysis/run_embeddings.py create mode 100644 src/br/analysis/run_features.py diff --git a/src/br/analysis/__init__.py b/src/br/analysis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py new file mode 100644 index 0000000..318174b --- /dev/null +++ b/src/br/analysis/analysis_utils.py @@ -0,0 +1,164 @@ +import subprocess +import torch +from br.models.utils import get_all_configs_per_dataset + + +def get_gpu_info(): + # Run nvidia-smi command and get the output + cmd = [ + "nvidia-smi", + "--query-gpu=index,uuid,name,utilization.gpu", + "--format=csv,noheader,nounits", + ] + result = subprocess.run(cmd, capture_output=True, text=True) + return result.stdout.strip() + + +def check_mig(): + # Check if MIG is enabled + cmd = ["nvidia-smi", "-L"] + result = subprocess.run(cmd, capture_output=True, text=True) + return "MIG" in result.stdout + + +def get_mig_ids(): + # Get the MIG UUIDs + cmd = ["nvidia-smi", "-L"] + result = subprocess.run(cmd, capture_output=True, text=True) + mig_ids = [] + for line in result.stdout.splitlines(): + if "MIG" in line: + mig_id = line.split("(UUID: ")[-1].strip(")") + mig_ids.append(mig_id) + return mig_ids + + +def config_gpu(): + 
selected_gpu_id_or_uuid = "" + is_mig = check_mig() + + gpu_info = get_gpu_info() + lines = gpu_info.splitlines() + + for line in lines: + index, uuid, name, utilization = map(str.strip, line.split(",")) + + # If utilization is [N/A], treat it as less than 10 + if utilization == "[N/A]": + utilization = -1 # Assign a value less than 10 to simulate "idle" + else: + utilization = int(utilization) + + # Check if GPU utilization is under 10% (indicating it's idle) + if utilization < 10: + if is_mig: + mig_ids = get_mig_ids() + if mig_ids: + selected_gpu_id_or_uuid = mig_ids[0] # Select the first MIG ID + break # Exit the loop after finding the first MIG ID + else: + selected_gpu_id_or_uuid = uuid + print(f"Selected UUID is {selected_gpu_id_or_uuid}") + break + return selected_gpu_id_or_uuid + + +def _setup_evaluation_params(manifest, run_names): + """Return evaluation params related to. + + 1. loss_eval_list - which loss to use for each model (Defaults to Chamfer loss) + 2. skew_scale - Hyperparameter associated with sampling of pointclouds from images + 3. sample_points_list - whether to sample pointclouds for each model + 4. eval_scaled_img - whether to scale the images for evaluation (specific to SDF models) + 5. eval_scaled_img_params - parameters like mesh paths, scale factors, pointcloud paths associated + with evaluating scaled images + """ + eval_scaled_img = [False] * len(run_names) + eval_scaled_img_params = [{}] * len(run_names) + + if "SDF" in "\t".join(run_names): + eval_scaled_img_resolution = 32 + gt_mesh_dir = manifest["mesh_folder"].iloc[0] + gt_sampled_pts_dir = manifest["pointcloud_folder"].iloc[0] + gt_scale_factor_dict_path = manifest["scale_factor"].iloc[0] + eval_scaled_img_params = [] + for name_ in run_names: + if "seg" in name_: + model_type = "seg" + elif "SDF" in name_: + model_type = "sdf" + elif "pointcloud" in name_: + model_type = "iae" + eval_scaled_img_params.append( + { + "eval_scaled_img_model_type": model_type, + "eval_scaled_img_resolution": eval_scaled_img_resolution, + "gt_mesh_dir": gt_mesh_dir, + "gt_scale_factor_dict_path": gt_scale_factor_dict_path, + "gt_sampled_pts_dir": gt_sampled_pts_dir, + "mesh_ext": "stl", + } + ) + loss_eval_list = [torch.nn.MSELoss(reduction="none")] * len(run_names) + sample_points_list = [False] * len(run_names) + skew_scale = None + else: + loss_eval_list = None + skew_scale = 100 + sample_points_list = [] + for name_ in run_names: + if "image" in name_: + sample_points_list.append(True) + else: + sample_points_list.append(False) + return eval_scaled_img, eval_scaled_img_params, loss_eval_list, sample_points_list, skew_scale + + +def _setup_evolve_params(run_names, data_config_list, keys): + eval_meshed_img = [False] * len(run_names) + eval_meshed_img_model_type = [None] * len(run_names) + compute_evolve_dataloaders = False + if "SDF" in "\t".join(run_names): + compute_evolve_dataloaders = True + eval_meshed_img = [True] * len(run_names) + eval_meshed_img_model_type = [] + for name_ in run_names: + if "seg" in name_: + model_type = "seg" + elif "SDF" in name_: + model_type = "sdf" + elif "pointcloud" in name_: + model_type = "iae" + eval_meshed_img_model_type.append(model_type) + + evolve_params = { + 'modality_list_evolve': keys, + 'config_list_evolve': data_config_list, + 'num_evolve_samples': 40, + 'compute_evolve_dataloaders': compute_evolve_dataloaders, + 'eval_meshed_img': eval_meshed_img, + 'eval_meshed_img_model_type': eval_meshed_img_model_type, + "skew_scale": None, + "only_embedding": False, + "fit_pca": False, 
+ "pc_is_iae":False + } + return evolve_params + + +def _get_feature_params(results_path, dataset_name, manifest, keys, run_names): + DATA_LIST = get_all_configs_per_dataset(results_path) + data_config_list = DATA_LIST[dataset_name]["data_paths"] + class_label = DATA_LIST[dataset_name]['classification_label'] + regression_label = DATA_LIST[dataset_name]['regression_label'] + evolve_params = _setup_evolve_params(run_names, data_config_list, keys) + classification_params = {"class_labels": class_label, "df_feat": manifest} + rot_inv_params = {"squeeze_2d": False, "id": "cell_id", "max_batches": 4000} + regression_params = {"df_feat": manifest, "target_cols": regression_label, "feature_df_path": None} + compactness_params = { + "method": "mle", + "num_PCs": None, + "blobby_outlier_max_cc": None, + "check_duplicates": True, + } + return rot_inv_params, compactness_params, classification_params, evolve_params, regression_params \ No newline at end of file diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py new file mode 100644 index 0000000..127673f --- /dev/null +++ b/src/br/analysis/run_embeddings.py @@ -0,0 +1,114 @@ +# Free up cache +import argparse +import gc +import os +import torch +from br.models.load_models import get_data_and_models +from br.models.save_embeddings import ( + save_embeddings, +) +import sys +from br.analysis.analysis_utils import config_gpu, _setup_evaluation_params + + +def main(args): + # Free up cache + gc.collect() + torch.cuda.empty_cache() + + # Based on the utilization, set the GPU ID + # Setting a GPU ID is crucial for the script to work well! + selected_gpu_id_or_uuid = config_gpu() + + # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID + if selected_gpu_id_or_uuid: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = selected_gpu_id_or_uuid + print(f"CUDA_VISIBLE_DEVICES set to: {selected_gpu_id_or_uuid}") + else: + print("No suitable GPU or MIG ID found. Exiting...") + + # Set the device + device = "cuda:0" + + # Set working directory and paths + os.chdir(args.src_path) + + # Load data and models + data_list, all_models, run_names, model_sizes, manifest, _, _ = get_data_and_models( + args.dataset_name, args.batch_size, args.results_path, args.debug + ) + + # Load evaluation params + ( + eval_scaled_img, + eval_scaled_img_params, + loss_eval_list, + sample_points_list, + skew_scale, + ) = _setup_evaluation_params(manifest, run_names) + + # save embeddings for each model + save_embeddings( + args.save_path, + data_list, + all_models, + run_names, + args.debug, + ["train", "val", "test"], # splits to compute embeddings + device, + args.meta_key, + loss_eval_list, + sample_points_list, + skew_scale, + eval_scaled_img, + eval_scaled_img_params, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Script for computing embeddings") + parser.add_argument( + "--src_path", type=str, required=True, help="Path to the source directory." + ) + parser.add_argument( + "--save_path", type=str, required=True, help="Path to save the embeddings." + ) + parser.add_argument( + "--results_path", type=str, required=True, help="Path to the results directory." 
+ ) + parser.add_argument( + "--meta_key", + type=str, + required=True, + help="Metadata to add to the embeddings aside from CellId", + ) + parser.add_argument( + "--sdf", + type=bool, + required=True, + help="boolean indicating whether the experiments involve SDFs", + ) + parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") + parser.add_argument("--batch_size", type=int, default=2, help="Batch size for processing.") + parser.add_argument("--debug", type=bool, default=True, help="Enable debug mode.") + + args = parser.parse_args() + + # Validate that required paths are provided + if not args.src_path or not args.save_path or not args.results_path or not args.dataset_name: + print("Error: Required arguments are missing.") + sys.exit(1) + + main(args) + +""" +Example +os.chdir(r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/src/") +save_path = r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/src/test_cellpack_save_embeddings/" +results_path = r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/configs/results/" +dataset_name = "cellpack" +batch_size = 2 +debug = True + +""" diff --git a/src/br/analysis/run_features.py b/src/br/analysis/run_features.py new file mode 100644 index 0000000..8a43478 --- /dev/null +++ b/src/br/analysis/run_features.py @@ -0,0 +1,165 @@ +# Free up cache +import argparse +import gc +import os +import pandas as pd +import torch +from br.models.compute_features import compute_features +from br.models.load_models import get_data_and_models +from br.models.save_embeddings import ( + save_emissions, +) +import sys +from br.analysis.analysis_utils import config_gpu, _setup_evaluation_params + + +def main(args): + # Free up cache + gc.collect() + torch.cuda.empty_cache() + + # Based on the utilization, set the GPU ID + # Setting a GPU ID is crucial for the script to work well! + selected_gpu_id_or_uuid = config_gpu() + + # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID + if selected_gpu_id_or_uuid: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = selected_gpu_id_or_uuid + print(f"CUDA_VISIBLE_DEVICES set to: {selected_gpu_id_or_uuid}") + else: + print("No suitable GPU or MIG ID found. 
Exiting...") + + # Set the device + device = "cuda:0" + + # Set working directory and paths + os.chdir(args.src_path) + + # set batch size to 1 for emission stats/features + batch_size = 1 + + # Load data and models + data_list, all_models, run_names, model_sizes, manifest, keys, latent_dims = get_data_and_models( + args.dataset_name, batch_size, args.results_path, args.debug + ) + max_embed_dim = min(latent_dims) + + # Save model sizes to CSV + sizes_ = pd.DataFrame() + sizes_["model"] = run_names + sizes_["model_size"] = model_sizes + sizes_.to_csv(os.path.join(args.save_path, "model_sizes.csv")) + + # Load evaluation params + ( + eval_scaled_img, + eval_scaled_img_params, + loss_eval_list, + sample_points_list, + skew_scale, + ) = _setup_evaluation_params(manifest, run_names) + + # Save emission stats for each model + max_batches = 40 + save_emissions( + args.save_path, + data_list, + all_models, + run_names, + max_batches, + args.debug, + device, + loss_eval_list, + sample_points_list, + skew_scale, + eval_scaled_img, + eval_scaled_img_params, + ) + + # Compute multi-metric benchmarking params + rot_inv_params, compactness_params, classification_params, evolve_params, regression_params = _get_feature_params(results_path, dataset_name, manifest, keys, run_names) + + metric_list = [ + "Rotation Invariance Error", + "Evolution Energy", + "Reconstruction", + "Classification", + "Compactness", + ] + if len(regression_params['target_cols']) > 0: + metric_list.append('Regression') + + # Compute multi-metric benchmarking features + compute_features( + dataset=args.dataset_name, + results_path=args.results_path, + embeddings_path=args.embeddings_path, + save_folder=args.save_path, + data_list=data_list, + all_models=all_models, + run_names=run_names, + use_sample_points_list=sample_points_list, + keys=keys, + device=device, + max_embed_dim=max_embed_dim, + splits_list=["train", "val", "test"], + compute_embeds=False, + classification_params=classification_params, + regression_params=regression_params, + metric_list=metric_list, + loss_eval_list=loss_eval_list, + evolve_params=evolve_params, + rot_inv_params=rot_inv_params, + compactness_params=compactness_params, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Script for Benchmarking Representations") + parser.add_argument( + "--src_path", type=str, required=True, help="Path to the source directory." + ) + parser.add_argument( + "--save_path", type=str, required=True, help="Path to save the embeddings." + ) + parser.add_argument( + "--embeddings_path", type=str, required=True, help="Path to the saved embeddings." + ) + parser.add_argument( + "--results_path", type=str, required=True, help="Path to the results directory." 
+ ) + parser.add_argument( + "--meta_key", + type=str, + required=True, + help="Metadata to add to the embeddings aside from CellId", + ) + parser.add_argument( + "--sdf", + type=bool, + required=True, + help="boolean indicating whether the experiments involve SDFs", + ) + parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") + parser.add_argument("--debug", type=bool, default=True, help="Enable debug mode.") + + args = parser.parse_args() + + # Validate that required paths are provided + if not args.src_path or not args.save_path or not args.results_path or not args.dataset_name: + print("Error: Required arguments are missing.") + sys.exit(1) + + main(args) + +""" +Example +os.chdir(r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/src/") +save_path = r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/src/test_cellpack_save_embeddings/" +results_path = r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/configs/results/" +dataset_name = "cellpack" +batch_size = 2 +debug = True + +""" diff --git a/src/br/features/classification.py b/src/br/features/classification.py index 8e2ee5c..00abfce 100644 --- a/src/br/features/classification.py +++ b/src/br/features/classification.py @@ -8,7 +8,7 @@ from tqdm import tqdm -def get_classification_df(all_ret, target_col): +def get_classification_df(all_ret, target_col, df_feat=None): ret_dict5 = { "model": [], "top_1_acc": [], @@ -16,8 +16,11 @@ def get_classification_df(all_ret, target_col): "top_3_acc": [], "cv": [], } + for model in tqdm(all_ret["model"].unique(), total=len(all_ret["model"].unique())): this_mo = all_ret.loc[all_ret["model"] == model].reset_index(drop=True) + if df_feat and target_col not in this_mo.columns: + this_mo = this_mo.merge(df_feat, on="CellId") k1, k2, k3 = get_classification(this_mo, target_col) for i in range(len(k1)): ret_dict5["model"].append(model) diff --git a/src/br/features/regression.py b/src/br/features/regression.py index f505793..1e0440c 100644 --- a/src/br/features/regression.py +++ b/src/br/features/regression.py @@ -17,7 +17,7 @@ def get_regression_df(all_ret, target_cols, feature_df_path, df_feat=None): for target in target_cols: for model in tqdm(all_ret["model"].unique(), total=len(all_ret["model"].unique())): this_mo = all_ret.loc[all_ret["model"] == model].reset_index(drop=True) - if feature_df_path and target not in this_mo.columns: + if df_feat and target not in this_mo.columns: this_mo = this_mo.merge(df_feat, on="CellId") test_r2, test_mse = get_regression(this_mo, target) for i in range(len(test_r2)): diff --git a/src/br/models/compute_features.py b/src/br/models/compute_features.py index 0080228..e88438c 100644 --- a/src/br/models/compute_features.py +++ b/src/br/models/compute_features.py @@ -106,11 +106,11 @@ def compute_features( compute_embeds: bool = False, metric_list: list = METRIC_LIST, loss_eval_list: list = None, - classification_params: dict = {"class_labels": ["cell_stage_fine"]}, + classification_params: dict = {"class_labels": ["cell_stage_fine"], "df_feat": None}, regression_params: dict = { "feature_df_path": None, "target_cols": [], - "df_feat": [], + "df_feat": None, }, evolve_params: dict = { "modality_list_evolve": [], @@ -221,6 +221,7 @@ def compute_features( ret_dict_classification = get_classification_df( all_ret, target_col, + classification_params["df_feat"], ) ret_dict_classification.to_csv(path / Path(f"classification_{target_col}.csv")) diff --git a/src/br/models/load_models.py 
b/src/br/models/load_models.py index 3c6c1de..71a04b1 100644 --- a/src/br/models/load_models.py +++ b/src/br/models/load_models.py @@ -14,6 +14,8 @@ def load_model_from_path(dataset, results_path, strict=False, split="val", devic model_manifest = pd.read_csv(models["orig_df"]) model_sizes = [] all_models = [] + x_labels = [] + latent_dims = [] for j, ckpt_path in enumerate(models["model_checkpoints"]): if "model_paths" in models.keys(): config_path = models["model_paths"][j] @@ -22,6 +24,8 @@ def load_model_from_path(dataset, results_path, strict=False, split="val", devic with open(config_path) as stream: config = yaml.safe_load(stream) model_conf = config["model"] + x_label = model_conf["x_label"] + latent_dim = model_conf["latent_dim"] model_class = model_conf.pop("_target_") model_conf = instantiate(model_conf) model_class = _locate(model_class) @@ -31,8 +35,10 @@ def load_model_from_path(dataset, results_path, strict=False, split="val", devic ).eval() ) model_sizes.append(config["model/params/total"]) + x_labels.append(x_label) + latent_dims.append(latent_dim) - return all_models, models["names"], model_sizes, model_manifest + return all_models, models["names"], model_sizes, model_manifest, x_labels, latent_dims def load_model_from_mlflow(dataset, results_path, split="val"): @@ -42,6 +48,8 @@ def load_model_from_mlflow(dataset, results_path, split="val"): model_manifest = pd.read_csv(models["orig_df"]) model_sizes = [] all_models = [] + x_labels = [] + latent_dims = [] for i in models["run_ids"]: all_models.append( load_model_from_checkpoint( @@ -53,13 +61,22 @@ def load_model_from_mlflow(dataset, results_path, split="val"): ) config = get_config(TRACKING_URI, i, "./tmp") model_sizes.append(config["model/params/total"]) + x_labels.append(config["model"]["x_label"]) + latent_dims.append(config["model"]["latent_dim"]) - return all_models, models["names"], model_sizes, model_manifest + return all_models, models["names"], model_sizes, model_manifest, x_labels, latent_dims def get_data_and_models(dataset_name, batch_size, results_path, debug=False): data_list = get_data(dataset_name, batch_size, results_path, debug) - all_models, run_names, model_sizes, model_manifest = load_model_from_path( + ( + all_models, + run_names, + model_sizes, + model_manifest, + x_labels, + latent_dims, + ) = load_model_from_path( dataset_name, results_path ) # default list of models in load_models.py - return data_list, all_models, run_names, model_sizes, model_manifest + return data_list, all_models, run_names, model_sizes, model_manifest, x_labels, latent_dims From be3fd4511736b95a867a8896b3f93360d637ab2b Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Wed, 20 Nov 2024 14:05:28 -0800 Subject: [PATCH 06/35] remove hidden snakemake files --- README.md | 1 + subpackages/image_preprocessing/README.md | 2 ++ 2 files changed, 3 insertions(+) diff --git a/README.md b/README.md index a026a92..694e6b5 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ Code for training and benchmarking morphology appropriate representation learning methods. # Preprocessing + This README gives instructions for running our analysis against preprocessed image data. Because the preprocessing can take a long time, we are publishing preprocessed data along with the original movies. To do the preprocessing yourself, see the instructions in [subpackages/image_preprocessing/README.md](./subpackages/image_preprocessing/README.md). 
# Installation diff --git a/subpackages/image_preprocessing/README.md b/subpackages/image_preprocessing/README.md index 23f35a2..109ba62 100644 --- a/subpackages/image_preprocessing/README.md +++ b/subpackages/image_preprocessing/README.md @@ -5,11 +5,13 @@ Code for preprocessing 3D single cell images ## Installation Move to this `image_preprocessing` directory. + ```bash cd subpackages/image_preprocessing ``` Install dependencies. + ```bash conda create --name preprocessing_env python=3.10 conda activate preprocessing_env From 1689542fe947f737670e162ca8c60050f67413b7 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Wed, 20 Nov 2024 15:16:42 -0800 Subject: [PATCH 07/35] use cytodl config path for dataloading --- configs/results/cellpack.yaml | 14 +++++++------- configs/results/npm1.yaml | 10 +++++----- configs/results/other_polymorphic.yaml | 10 +++++----- configs/results/other_punctate.yaml | 10 +++++----- configs/results/pcna.yaml | 10 +++++----- src/br/analysis/run_embeddings.py | 17 +++++++++-------- src/br/analysis/run_features.py | 8 ++++---- src/br/data/get_datamodules.py | 5 +++++ 8 files changed, 45 insertions(+), 39 deletions(-) diff --git a/configs/results/cellpack.yaml b/configs/results/cellpack.yaml index 59ef98f..67cd66e 100644 --- a/configs/results/cellpack.yaml +++ b/configs/results/cellpack.yaml @@ -17,13 +17,13 @@ names: "Rotation_invariant_pointcloud", "Rotation_invariant_pointcloud_jitter", ] -data_paths: [ - "./configs/data/cellpack/image.yaml", - "./configs/data/cellpack/image.yaml", - "./configs/data/cellpack/pc.yaml", - "./configs/data/cellpack/pc.yaml", - # "./src/br/configs/data/cellpack/pc_jitter.yaml", - "./configs/data/cellpack/pc.yaml", +data_paths: + [ + "/data/cellpack/image.yaml", + "/data/cellpack/image.yaml", + "/data/cellpack/pc.yaml", + "/data/cellpack/pc.yaml", + "/data/cellpack/pc.yaml", ] classification_label: ["rule"] regression_label: diff --git a/configs/results/npm1.yaml b/configs/results/npm1.yaml index 58f4d10..d2c2dd0 100644 --- a/configs/results/npm1.yaml +++ b/configs/results/npm1.yaml @@ -19,11 +19,11 @@ names: ] data_paths: [ - "./configs/data/npm1/pc.yaml", - "./configs/data/npm1/so3_image_sdf.yaml", - "./configs/data/npm1/so3_image_seg.yaml", - "./configs/data/npm1/classical_image_sdf.yaml", - "./configs/data/npm1/classical_image_seg.yaml", + "/data/npm1/pc.yaml", + "/data/npm1/so3_image_sdf.yaml", + "/data/npm1/so3_image_seg.yaml", + "/data/npm1/classical_image_sdf.yaml", + "/data/npm1/classical_image_seg.yaml", ] classification_label: ["STR_connectivity_cc_thresh"] regression_label: diff --git a/configs/results/other_polymorphic.yaml b/configs/results/other_polymorphic.yaml index cba8773..76b9455 100644 --- a/configs/results/other_polymorphic.yaml +++ b/configs/results/other_polymorphic.yaml @@ -19,11 +19,11 @@ names: ] data_paths: [ - "./configs/data/other_polymorphic/pc.yaml", - "./configs/data/other_polymorphic/so3_image_sdf.yaml", - "./configs/data/other_polymorphic/so3_image_seg.yaml", - "./configs/data/other_polymorphic/classical_image_sdf.yaml", - "./configs/data/other_polymorphic/classical_image_seg.yaml", + "/other_polymorphic/pc.yaml", + "/other_polymorphic/so3_image_sdf.yaml", + "/other_polymorphic/so3_image_seg.yaml", + "/other_polymorphic/classical_image_sdf.yaml", + "/other_polymorphic/classical_image_seg.yaml", ] classification_label: ["structure_name"] regression_label: ["avg_dists", "mean_volume", "mean_surface_area"] diff --git a/configs/results/other_punctate.yaml b/configs/results/other_punctate.yaml index 
2e5a109..5b1772b 100644 --- a/configs/results/other_punctate.yaml +++ b/configs/results/other_punctate.yaml @@ -19,11 +19,11 @@ names: ] data_paths: [ - "./configs/data/other_punctate/image.yaml", - "./configs/data/other_punctate/image.yaml", - "./configs/data/other_punctate/pc.yaml", - "./configs/data/other_punctate/pc_intensity.yaml", - "./configs/data/other_punctate/pc_intensity_structurenorm.yaml", + "/data/other_punctate/image.yaml", + "/data/other_punctate/image.yaml", + "/data/other_punctate/pc.yaml", + "/data/other_punctate/pc_intensity.yaml", + "/data/other_punctate/pc_intensity_structurenorm.yaml", ] classification_label: ["structure_name", "cell_stage"] regression_label: diff --git a/configs/results/pcna.yaml b/configs/results/pcna.yaml index 9e0e8f4..e481b22 100644 --- a/configs/results/pcna.yaml +++ b/configs/results/pcna.yaml @@ -18,12 +18,12 @@ names: "Rotation_invariant_pointcloud_jitter", ] data_paths: [ - "./configs/data/pcna/image.yaml", - "./configs/data/pcna/image.yaml", - "./configs/data/pcna/pc.yaml", - "./configs/data/pcna/pc_intensity.yaml", + "/data/pcna/image.yaml", + "/data/pcna/image.yaml", + "/data/pcna/pc.yaml", + "/data/pcna/pc_intensity.yaml", # "./src/br/configs/data/pcna/pc_intensity_jitter.yaml", - "./configs/data/pcna/pc_intensity.yaml", + "/data/pcna/pc_intensity.yaml", ] classification_label: ["cell_stage_fine", "flag_comment"] regression_label: diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py index 127673f..6373ec4 100644 --- a/src/br/analysis/run_embeddings.py +++ b/src/br/analysis/run_embeddings.py @@ -31,12 +31,15 @@ def main(args): # Set the device device = "cuda:0" - # Set working directory and paths - os.chdir(args.src_path) + # # Set working directory and paths + # os.chdir(args.src_path) + + # Get config path from CYTODL_CONFIG_PATH + config_path = os.environ.get('CYTODL_CONFIG_PATH') # Load data and models data_list, all_models, run_names, model_sizes, manifest, _, _ = get_data_and_models( - args.dataset_name, args.batch_size, args.results_path, args.debug + args.dataset_name, args.batch_size, config_path + '/results/', args.debug ) # Load evaluation params @@ -74,13 +77,11 @@ def main(args): parser.add_argument( "--save_path", type=str, required=True, help="Path to save the embeddings." ) - parser.add_argument( - "--results_path", type=str, required=True, help="Path to the results directory." 
- ) parser.add_argument( "--meta_key", type=str, - required=True, + default=None, + required=False, help="Metadata to add to the embeddings aside from CellId", ) parser.add_argument( @@ -96,7 +97,7 @@ def main(args): args = parser.parse_args() # Validate that required paths are provided - if not args.src_path or not args.save_path or not args.results_path or not args.dataset_name: + if not args.src_path or not args.save_path or not args.dataset_name: print("Error: Required arguments are missing.") sys.exit(1) diff --git a/src/br/analysis/run_features.py b/src/br/analysis/run_features.py index 8a43478..3a45a75 100644 --- a/src/br/analysis/run_features.py +++ b/src/br/analysis/run_features.py @@ -39,9 +39,12 @@ def main(args): # set batch size to 1 for emission stats/features batch_size = 1 + # Get config path from CYTODL_CONFIG_PATH + config_path = os.environ.get('CYTODL_CONFIG_PATH') + # Load data and models data_list, all_models, run_names, model_sizes, manifest, keys, latent_dims = get_data_and_models( - args.dataset_name, batch_size, args.results_path, args.debug + args.dataset_name, batch_size, config_path + '/results/', args.debug ) max_embed_dim = min(latent_dims) @@ -126,9 +129,6 @@ def main(args): parser.add_argument( "--embeddings_path", type=str, required=True, help="Path to the saved embeddings." ) - parser.add_argument( - "--results_path", type=str, required=True, help="Path to the results directory." - ) parser.add_argument( "--meta_key", type=str, diff --git a/src/br/data/get_datamodules.py b/src/br/data/get_datamodules.py index ed34c36..c399325 100644 --- a/src/br/data/get_datamodules.py +++ b/src/br/data/get_datamodules.py @@ -9,8 +9,13 @@ def get_data(dataset_name, batch_size, results_path, debug=False): DATA_LIST = get_all_configs_per_dataset(results_path) config_list = DATA_LIST[dataset_name] + + # Get config path from CYTODL_CONFIG_PATH + cytodl_config_path = os.environ.get('CYTODL_CONFIG_PATH') + data = [] for config_path in config_list["data_paths"]: + config_path = cytodl_config_path + config_path with open(config_path) as stream: config = yaml.safe_load(stream) if batch_size: From c734d7f637e0a2c65b4a3439e7926a275292f801 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Wed, 20 Nov 2024 16:55:47 -0800 Subject: [PATCH 08/35] debug runs --- src/br/analysis/analysis_utils.py | 34 +++++++++----- src/br/analysis/run_embeddings.py | 39 +++++++--------- src/br/analysis/run_features.py | 77 +++++++++++++++++-------------- src/br/data/get_datamodules.py | 2 +- src/br/models/compute_features.py | 2 +- 5 files changed, 85 insertions(+), 69 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 318174b..523e695 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -1,5 +1,7 @@ import subprocess + import torch + from br.models.utils import get_all_configs_per_dataset @@ -132,16 +134,16 @@ def _setup_evolve_params(run_names, data_config_list, keys): eval_meshed_img_model_type.append(model_type) evolve_params = { - 'modality_list_evolve': keys, - 'config_list_evolve': data_config_list, - 'num_evolve_samples': 40, - 'compute_evolve_dataloaders': compute_evolve_dataloaders, - 'eval_meshed_img': eval_meshed_img, - 'eval_meshed_img_model_type': eval_meshed_img_model_type, + "modality_list_evolve": keys, + "config_list_evolve": data_config_list, + "num_evolve_samples": 40, + "compute_evolve_dataloaders": compute_evolve_dataloaders, + "eval_meshed_img": eval_meshed_img, + "eval_meshed_img_model_type": 
eval_meshed_img_model_type, "skew_scale": None, "only_embedding": False, "fit_pca": False, - "pc_is_iae":False + "pc_is_iae": False, } return evolve_params @@ -149,16 +151,26 @@ def _setup_evolve_params(run_names, data_config_list, keys): def _get_feature_params(results_path, dataset_name, manifest, keys, run_names): DATA_LIST = get_all_configs_per_dataset(results_path) data_config_list = DATA_LIST[dataset_name]["data_paths"] - class_label = DATA_LIST[dataset_name]['classification_label'] - regression_label = DATA_LIST[dataset_name]['regression_label'] + class_label = DATA_LIST[dataset_name]["classification_label"] + regression_label = DATA_LIST[dataset_name]["regression_label"] evolve_params = _setup_evolve_params(run_names, data_config_list, keys) classification_params = {"class_labels": class_label, "df_feat": manifest} rot_inv_params = {"squeeze_2d": False, "id": "cell_id", "max_batches": 4000} - regression_params = {"df_feat": manifest, "target_cols": regression_label, "feature_df_path": None} + regression_params = { + "df_feat": manifest, + "target_cols": regression_label, + "feature_df_path": None, + } compactness_params = { "method": "mle", "num_PCs": None, "blobby_outlier_max_cc": None, "check_duplicates": True, } - return rot_inv_params, compactness_params, classification_params, evolve_params, regression_params \ No newline at end of file + return ( + rot_inv_params, + compactness_params, + classification_params, + evolve_params, + regression_params, + ) diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py index 6373ec4..8a4ec39 100644 --- a/src/br/analysis/run_embeddings.py +++ b/src/br/analysis/run_embeddings.py @@ -2,13 +2,13 @@ import argparse import gc import os +import sys + import torch + +from br.analysis.analysis_utils import _setup_evaluation_params, config_gpu from br.models.load_models import get_data_and_models -from br.models.save_embeddings import ( - save_embeddings, -) -import sys -from br.analysis.analysis_utils import config_gpu, _setup_evaluation_params +from br.models.save_embeddings import save_embeddings def main(args): @@ -19,6 +19,10 @@ def main(args): # Based on the utilization, set the GPU ID # Setting a GPU ID is crucial for the script to work well! selected_gpu_id_or_uuid = config_gpu() + selected_gpu_id_or_uuid = "MIG-5c1d3311-7294-5551-9e4f-3535560f5f82" + import ipdb + + ipdb.set_trace() # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID if selected_gpu_id_or_uuid: @@ -35,11 +39,11 @@ def main(args): # os.chdir(args.src_path) # Get config path from CYTODL_CONFIG_PATH - config_path = os.environ.get('CYTODL_CONFIG_PATH') + config_path = os.environ.get("CYTODL_CONFIG_PATH") # Load data and models data_list, all_models, run_names, model_sizes, manifest, _, _ = get_data_and_models( - args.dataset_name, args.batch_size, config_path + '/results/', args.debug + args.dataset_name, args.batch_size, config_path + "/results/", args.debug ) # Load evaluation params @@ -58,7 +62,7 @@ def main(args): all_models, run_names, args.debug, - ["train", "val", "test"], # splits to compute embeddings + ["train", "val", "test"], # splits to compute embeddings device, args.meta_key, loss_eval_list, @@ -71,9 +75,6 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Script for computing embeddings") - parser.add_argument( - "--src_path", type=str, required=True, help="Path to the source directory." 
- ) parser.add_argument( "--save_path", type=str, required=True, help="Path to save the embeddings." ) @@ -97,19 +98,13 @@ def main(args): args = parser.parse_args() # Validate that required paths are provided - if not args.src_path or not args.save_path or not args.dataset_name: + if not args.save_path or not args.dataset_name: print("Error: Required arguments are missing.") sys.exit(1) main(args) -""" -Example -os.chdir(r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/src/") -save_path = r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/src/test_cellpack_save_embeddings/" -results_path = r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/configs/results/" -dataset_name = "cellpack" -batch_size = 2 -debug = True - -""" + """ + Example run: + python src/br/analysis/run_embeddings.py --save_path "./testing/" --sdf False --dataset_name "pcna" --batch_size 5 --debug True + """ diff --git a/src/br/analysis/run_features.py b/src/br/analysis/run_features.py index 3a45a75..1f29c46 100644 --- a/src/br/analysis/run_features.py +++ b/src/br/analysis/run_features.py @@ -2,15 +2,19 @@ import argparse import gc import os +import sys + import pandas as pd import torch + +from br.analysis.analysis_utils import ( + _get_feature_params, + _setup_evaluation_params, + config_gpu, +) from br.models.compute_features import compute_features from br.models.load_models import get_data_and_models -from br.models.save_embeddings import ( - save_emissions, -) -import sys -from br.analysis.analysis_utils import config_gpu, _setup_evaluation_params +from br.models.save_embeddings import save_emissions def main(args): @@ -21,6 +25,7 @@ def main(args): # Based on the utilization, set the GPU ID # Setting a GPU ID is crucial for the script to work well! 
selected_gpu_id_or_uuid = config_gpu() + selected_gpu_id_or_uuid = "MIG-5c1d3311-7294-5551-9e4f-3535560f5f82" # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID if selected_gpu_id_or_uuid: @@ -33,19 +38,22 @@ def main(args): # Set the device device = "cuda:0" - # Set working directory and paths - os.chdir(args.src_path) - # set batch size to 1 for emission stats/features batch_size = 1 # Get config path from CYTODL_CONFIG_PATH - config_path = os.environ.get('CYTODL_CONFIG_PATH') + config_path = os.environ.get("CYTODL_CONFIG_PATH") # Load data and models - data_list, all_models, run_names, model_sizes, manifest, keys, latent_dims = get_data_and_models( - args.dataset_name, batch_size, config_path + '/results/', args.debug - ) + ( + data_list, + all_models, + run_names, + model_sizes, + manifest, + keys, + latent_dims, + ) = get_data_and_models(args.dataset_name, batch_size, config_path + "/results/", args.debug) max_embed_dim = min(latent_dims) # Save model sizes to CSV @@ -64,6 +72,7 @@ def main(args): ) = _setup_evaluation_params(manifest, run_names) # Save emission stats for each model + args.debug = True max_batches = 40 save_emissions( args.save_path, @@ -81,7 +90,15 @@ def main(args): ) # Compute multi-metric benchmarking params - rot_inv_params, compactness_params, classification_params, evolve_params, regression_params = _get_feature_params(results_path, dataset_name, manifest, keys, run_names) + ( + rot_inv_params, + compactness_params, + classification_params, + evolve_params, + regression_params, + ) = _get_feature_params( + config_path + "/results/", args.dataset_name, manifest, keys, run_names + ) metric_list = [ "Rotation Invariance Error", @@ -90,13 +107,13 @@ def main(args): "Classification", "Compactness", ] - if len(regression_params['target_cols']) > 0: - metric_list.append('Regression') - + if regression_params["target_cols"]: + metric_list.append("Regression") + # Compute multi-metric benchmarking features compute_features( dataset=args.dataset_name, - results_path=args.results_path, + results_path=config_path + "/results/", embeddings_path=args.embeddings_path, save_folder=args.save_path, data_list=data_list, @@ -119,10 +136,7 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Script for Benchmarking Representations") - parser.add_argument( - "--src_path", type=str, required=True, help="Path to the source directory." - ) + parser = argparse.ArgumentParser(description="Script for computing features") parser.add_argument( "--save_path", type=str, required=True, help="Path to save the embeddings." 
) @@ -132,7 +146,8 @@ def main(args): parser.add_argument( "--meta_key", type=str, - required=True, + default=None, + required=False, help="Metadata to add to the embeddings aside from CellId", ) parser.add_argument( @@ -142,24 +157,18 @@ def main(args): help="boolean indicating whether the experiments involve SDFs", ) parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") - parser.add_argument("--debug", type=bool, default=True, help="Enable debug mode.") + parser.add_argument("--debug", type=bool, default=False, help="Enable debug mode.") args = parser.parse_args() # Validate that required paths are provided - if not args.src_path or not args.save_path or not args.results_path or not args.dataset_name: + if not args.embeddings_path or not args.save_path or not args.dataset_name: print("Error: Required arguments are missing.") sys.exit(1) main(args) -""" -Example -os.chdir(r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/src/") -save_path = r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/src/test_cellpack_save_embeddings/" -results_path = r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/configs/results/" -dataset_name = "cellpack" -batch_size = 2 -debug = True - -""" + """ + Example run: + python src/br/analysis/run_features.py --save_path "./testing/" --embeddings_path "/allen/aics/modeling/ritvik/projects/second_clones/benchmarking_representations/test_pcna_save_embeddings_revisit/" --sdf False --dataset_name "pcna" + """ diff --git a/src/br/data/get_datamodules.py b/src/br/data/get_datamodules.py index c399325..c302348 100644 --- a/src/br/data/get_datamodules.py +++ b/src/br/data/get_datamodules.py @@ -11,7 +11,7 @@ def get_data(dataset_name, batch_size, results_path, debug=False): config_list = DATA_LIST[dataset_name] # Get config path from CYTODL_CONFIG_PATH - cytodl_config_path = os.environ.get('CYTODL_CONFIG_PATH') + cytodl_config_path = os.environ.get("CYTODL_CONFIG_PATH") data = [] for config_path in config_list["data_paths"]: diff --git a/src/br/models/compute_features.py b/src/br/models/compute_features.py index e88438c..d57b607 100644 --- a/src/br/models/compute_features.py +++ b/src/br/models/compute_features.py @@ -166,7 +166,7 @@ def compute_features( if "Reconstruction" in metric_list: print("Getting reconstruction") - rec_df = all_ret[["model", "split", "loss"]].groupby(["model", "split"]).mean() + rec_df = all_ret.loc[all_ret['split'] == 'test'].reset_index(drop=True) rec_df.to_csv(path / "reconstruction.csv") metric_list.pop(metric_list.index("Reconstruction")) From cf470b3d74f100e59a159dcc1f130c102e526134 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Thu, 21 Nov 2024 13:03:57 -0800 Subject: [PATCH 09/35] test runs for embeddings and features pass for PCNA dataset --- src/br/analysis/analysis_utils.py | 10 +- src/br/analysis/fig2_cellpack.py | 180 ------------------------------ src/br/analysis/run_embeddings.py | 2 +- src/br/analysis/run_features.py | 17 ++- src/br/features/classification.py | 2 +- src/br/features/plot.py | 16 ++- src/br/features/regression.py | 2 +- src/br/models/.gitkeep | 0 src/br/models/utils.py | 6 +- 9 files changed, 30 insertions(+), 205 deletions(-) delete mode 100644 src/br/models/.gitkeep diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 523e695..d50658b 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -1,7 +1,6 @@ import subprocess - import torch - +import os from 
br.models.utils import get_all_configs_per_dataset @@ -119,9 +118,8 @@ def _setup_evaluation_params(manifest, run_names): def _setup_evolve_params(run_names, data_config_list, keys): eval_meshed_img = [False] * len(run_names) eval_meshed_img_model_type = [None] * len(run_names) - compute_evolve_dataloaders = False + compute_evolve_dataloaders = True if "SDF" in "\t".join(run_names): - compute_evolve_dataloaders = True eval_meshed_img = [True] * len(run_names) eval_meshed_img_model_type = [] for name_ in run_names: @@ -140,7 +138,7 @@ def _setup_evolve_params(run_names, data_config_list, keys): "compute_evolve_dataloaders": compute_evolve_dataloaders, "eval_meshed_img": eval_meshed_img, "eval_meshed_img_model_type": eval_meshed_img_model_type, - "skew_scale": None, + "skew_scale": 100, "only_embedding": False, "fit_pca": False, "pc_is_iae": False, @@ -151,6 +149,8 @@ def _setup_evolve_params(run_names, data_config_list, keys): def _get_feature_params(results_path, dataset_name, manifest, keys, run_names): DATA_LIST = get_all_configs_per_dataset(results_path) data_config_list = DATA_LIST[dataset_name]["data_paths"] + cytodl_config_path = os.environ.get("CYTODL_CONFIG_PATH") + data_config_list = [cytodl_config_path + i for i in data_config_list] class_label = DATA_LIST[dataset_name]["classification_label"] regression_label = DATA_LIST[dataset_name]["regression_label"] evolve_params = _setup_evolve_params(run_names, data_config_list, keys) diff --git a/src/br/analysis/fig2_cellpack.py b/src/br/analysis/fig2_cellpack.py index 1fb68cb..bb4d6fd 100644 --- a/src/br/analysis/fig2_cellpack.py +++ b/src/br/analysis/fig2_cellpack.py @@ -15,186 +15,6 @@ # %% # %load_ext autoreload # %autoreload 2 -import os - -os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 -os.environ["CUDA_VISIBLE_DEVICES"] = "MIG-ffdee303-0dd4-513d-b18c-beba028b49c7" -import os -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import torch -import yaml -from hydra.utils import instantiate -from PIL import Image -from torch.utils.data import DataLoader, Dataset - -from br.features.archetype import AA_Fast -from br.features.plot import collect_outputs, plot, plot_stratified_pc -from br.features.reconstruction import stratified_latent_walk -from br.features.utils import ( - normalize_intensities_and_get_colormap, - normalize_intensities_and_get_colormap_apply, -) -from br.models.compute_features import compute_features, get_embeddings -from br.models.load_models import get_data_and_models -from br.models.save_embeddings import ( - get_pc_loss, - get_pc_loss_chamfer, - save_embeddings, - save_emissions, -) -from br.models.utils import get_all_configs_per_dataset - -device = "cuda:0" - -# %% [markdown] -# # Load data and models - -# %% -# Set paths -os.chdir("../../benchmarking_representations/") -save_path = "./test_cellpack_save_embeddings/" - -# %% -# Get datamodules, models, runs, model sizes - -dataset_name = "cellpack" -batch_size = 2 -debug = True -results_path = "./configs/results/" -data_list, all_models, run_names, model_sizes = get_data_and_models( - dataset_name, batch_size, results_path, debug -) -gg = pd.DataFrame() -gg["model"] = run_names -gg["model_size"] = model_sizes -gg.to_csv(save_path + "model_sizes.csv") - -# %% [markdown] -# # Compute embeddings and emissions - -# %% -# Compute embeddings and reconstructions for each model - -debug = False -splits_list = ["train", "val", "test"] -meta_key = "rule" -eval_scaled_img = [False] * 5 
-eval_scaled_img_params = [{}] * 5 -loss_eval_list = None -sample_points_list = [True, True, False, False, False] -skew_scale = 100 -save_embeddings( - save_path, - data_list, - all_models, - run_names, - debug, - splits_list, - device, - meta_key, - loss_eval_list, - sample_points_list, - skew_scale, - eval_scaled_img, - eval_scaled_img_params, -) - -# %% -# Save emission stats for each model - -max_batches = 2 -save_emissions( - save_path, - data_list, - all_models, - run_names, - max_batches, - debug, - device, - loss_eval_list, - sample_points_list, - skew_scale, - eval_scaled_img, - eval_scaled_img_params, -) - -# %% [markdown] -# # Compute benchmarking features - -# %% -# Compute multi-metric benchmarking features - -keys = ["pcloud"] * 5 -max_embed_dim = 256 -DATA_LIST = get_all_configs_per_dataset(results_path) -data_config_list = DATA_LIST[dataset_name]["data_paths"] - -evolve_params = { - "modality_list_evolve": keys, - "config_list_evolve": data_config_list, - "num_evolve_samples": 40, - "compute_evolve_dataloaders": False, - "eval_meshed_img": [False] * 5, - "skew_scale": 100, - "eval_meshed_img_model_type": [None] * 5, - "only_embedding": False, - "fit_pca": False, -} - -loss_eval = get_pc_loss_chamfer() -loss_eval_list = [loss_eval] * 5 -use_sample_points_list = [True, True, False, False, False] - -classification_params = {"class_labels": ["rule"]} -rot_inv_params = {"squeeze_2d": False, "id": "cell_id", "max_batches": 4000} - -regression_params = {"df_feat": None, "target_cols": None, "feature_df_path": None} - -compactness_params = { - "method": "mle", - "num_PCs": None, - "blobby_outlier_max_cc": None, - "check_duplicates": True, -} - -splits_list = ["train", "val", "test"] -compute_embeds = False - -metric_list = [ - "Rotation Invariance Error", - "Evolution Energy", - "Reconstruction", - "Classification", - "Compactness", -] - - -compute_features( - dataset=dataset_name, - results_path=results_path, - embeddings_path=save_path, - save_folder=save_path, - data_list=data_list, - all_models=all_models, - run_names=run_names, - use_sample_points_list=use_sample_points_list, - keys=keys, - device=device, - max_embed_dim=max_embed_dim, - splits_list=splits_list, - compute_embeds=compute_embeds, - classification_params=classification_params, - regression_params=regression_params, - metric_list=metric_list, - loss_eval_list=loss_eval_list, - evolve_params=evolve_params, - rot_inv_params=rot_inv_params, - compactness_params=compactness_params, -) # %% [markdown] # # Polar plot viz diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py index 8a4ec39..380810d 100644 --- a/src/br/analysis/run_embeddings.py +++ b/src/br/analysis/run_embeddings.py @@ -106,5 +106,5 @@ def main(args): """ Example run: - python src/br/analysis/run_embeddings.py --save_path "./testing/" --sdf False --dataset_name "pcna" --batch_size 5 --debug True + python src/br/analysis/run_embeddings.py --save_path "./outputs/" --sdf False --dataset_name "pcna" --batch_size 5 --debug False """ diff --git a/src/br/analysis/run_features.py b/src/br/analysis/run_features.py index 1f29c46..dd05085 100644 --- a/src/br/analysis/run_features.py +++ b/src/br/analysis/run_features.py @@ -15,6 +15,7 @@ from br.models.compute_features import compute_features from br.models.load_models import get_data_and_models from br.models.save_embeddings import save_emissions +from br.features.plot import collect_outputs, plot def main(args): @@ -25,7 +26,6 @@ def main(args): # Based on the utilization, set the GPU 
ID # Setting a GPU ID is crucial for the script to work well! selected_gpu_id_or_uuid = config_gpu() - selected_gpu_id_or_uuid = "MIG-5c1d3311-7294-5551-9e4f-3535560f5f82" # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID if selected_gpu_id_or_uuid: @@ -72,7 +72,6 @@ def main(args): ) = _setup_evaluation_params(manifest, run_names) # Save emission stats for each model - args.debug = True max_batches = 40 save_emissions( args.save_path, @@ -134,6 +133,18 @@ def main(args): compactness_params=compactness_params, ) + # Polar plot visualization + # Load saved csvs + csvs = [i for i in os.listdir(args.save_path) if i.split('.')[-1] == 'csv'] + csvs = [i.split('.')[0] for i in csvs] + # Remove non metric related csvs + csvs = [i for i in csvs if i not in run_names and i not in keys] + # classification and regression metrics are unique to each dataset + unique_metrics = [i for i in csvs if "classification" in i or "regression" in i] + # Collect dataframe and make plots + df, df_non_agg = collect_outputs(args.save_path, "std", run_names, csvs) + plot(args.save_path, df, run_names, args.dataset_name, "std", unique_metrics) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Script for computing features") @@ -170,5 +181,5 @@ def main(args): """ Example run: - python src/br/analysis/run_features.py --save_path "./testing/" --embeddings_path "/allen/aics/modeling/ritvik/projects/second_clones/benchmarking_representations/test_pcna_save_embeddings_revisit/" --sdf False --dataset_name "pcna" + python src/br/analysis/run_features.py --save_path "./outputs/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --sdf False --dataset_name "pcna" """ diff --git a/src/br/features/classification.py b/src/br/features/classification.py index 00abfce..3eb04fd 100644 --- a/src/br/features/classification.py +++ b/src/br/features/classification.py @@ -19,7 +19,7 @@ def get_classification_df(all_ret, target_col, df_feat=None): for model in tqdm(all_ret["model"].unique(), total=len(all_ret["model"].unique())): this_mo = all_ret.loc[all_ret["model"] == model].reset_index(drop=True) - if df_feat and target_col not in this_mo.columns: + if df_feat is not None and target_col not in this_mo.columns: this_mo = this_mo.merge(df_feat, on="CellId") k1, k2, k3 = get_classification(this_mo, target_col) for i in range(len(k1)): diff --git a/src/br/features/plot.py b/src/br/features/plot.py index e436a5d..f9e7ea5 100644 --- a/src/br/features/plot.py +++ b/src/br/features/plot.py @@ -157,7 +157,6 @@ def plot( df, models, title, - colors_list=None, norm="std", unique_expressivity_metrics=None, ): @@ -192,7 +191,6 @@ def plot( # if colors_list is not None: # colors = pal.as_hex() # else: - colors = colors_list all_models = [] for i in models: @@ -203,11 +201,11 @@ def plot( this_model.append(val) all_models.append(this_model) if len(models) == 5: - colors = ["#636EFA", "#00CC96", "#AB63FA", "#FFA15A", "#EF553B"] + colors = ["#9CA2D6", "#6277D1", "#CF8D84", "#CE553B", "#2ED9FF"] elif len(models) == 4: - colors = ["#636EFA", "#00CC96", "#AB63FA", "#EF553B"] + colors = ["#9CA2D6", "#6277D1", "#CF8D84", "#CE553B"] elif len(models) == 2: - colors = ["#636EFA", "#EF553B"] + colors = ["#9CA2D6", "#6277D1"] else: pal = sns.color_palette("pastel") colors = pal.as_hex() @@ -238,8 +236,8 @@ def plot( title=go.layout.Title(text=f"{title}"), polar={"radialaxis": {"visible": True, "range": range_vals, "dtick": 2}}, showlegend=True, - margin=dict(l=170, r=150, 
t=120, b=80), - legend=dict(orientation="h", xanchor="center", x=1.2, y=1.5), + margin=dict(l=170, r=150, t=20, b=80), + legend=dict(orientation="h", xanchor="center", x=1.2, y=1.8), font=dict( family="Myriad Pro", size=20, # Set the font size here @@ -248,8 +246,8 @@ def plot( ), ) - fig.write_image(path / f"{title}.png", scale=2) - fig.write_image(path / f"{title}.pdf", scale=2) + fig.write_image(path / f"{title}.png", scale=3) + fig.write_image(path / f"{title}.pdf", scale=3) # fig.write_image(path / f"{title}.eps", scale=2) # fig.write_image(path / f"{title}.pdf") diff --git a/src/br/features/regression.py b/src/br/features/regression.py index 1e0440c..2504084 100644 --- a/src/br/features/regression.py +++ b/src/br/features/regression.py @@ -17,7 +17,7 @@ def get_regression_df(all_ret, target_cols, feature_df_path, df_feat=None): for target in target_cols: for model in tqdm(all_ret["model"].unique(), total=len(all_ret["model"].unique())): this_mo = all_ret.loc[all_ret["model"] == model].reset_index(drop=True) - if df_feat and target not in this_mo.columns: + if df_feat is not None and target not in this_mo.columns: this_mo = this_mo.merge(df_feat, on="CellId") test_r2, test_mse = get_regression(this_mo, target) for i in range(len(test_r2)): diff --git a/src/br/models/.gitkeep b/src/br/models/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/src/br/models/utils.py b/src/br/models/utils.py index f2c2c4d..2762219 100644 --- a/src/br/models/utils.py +++ b/src/br/models/utils.py @@ -124,11 +124,7 @@ def sample_points(orig, skew_scale): pcloud = [] for i in range(orig.shape[0]): raw = orig[i, 0] - try: - new_cents = _sample(raw, skew_scale) - except: - print("exception") - new_cents = _sample(raw, 100) + new_cents = _sample(raw, skew_scale) pcloud.append(new_cents) pcloud = np.stack(pcloud, axis=0) return torch.tensor(pcloud) From 1292b7f58bae2c5353532db49bb8d5f3c2e1d734 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Thu, 21 Nov 2024 13:05:25 -0800 Subject: [PATCH 10/35] remove pdb --- src/br/analysis/run_embeddings.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py index 380810d..2627dd0 100644 --- a/src/br/analysis/run_embeddings.py +++ b/src/br/analysis/run_embeddings.py @@ -19,10 +19,6 @@ def main(args): # Based on the utilization, set the GPU ID # Setting a GPU ID is crucial for the script to work well! 
selected_gpu_id_or_uuid = config_gpu() - selected_gpu_id_or_uuid = "MIG-5c1d3311-7294-5551-9e4f-3535560f5f82" - import ipdb - - ipdb.set_trace() # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID if selected_gpu_id_or_uuid: From 34447f6a1e58d7b8a90918492a8b22adb88165fc Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Thu, 21 Nov 2024 14:07:06 -0800 Subject: [PATCH 11/35] move gpu stuff to utils --- src/br/analysis/analysis_utils.py | 23 ++- src/br/analysis/fig2_cellpack.py | 333 +++++++++++++++--------------- src/br/analysis/run_embeddings.py | 27 +-- src/br/analysis/run_features.py | 29 +-- src/br/models/compute_features.py | 2 +- src/br/models/load_models.py | 38 ++-- 6 files changed, 214 insertions(+), 238 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index d50658b..2ae25fb 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -1,6 +1,9 @@ +import gc +import os import subprocess + import torch -import os + from br.models.utils import get_all_configs_per_dataset @@ -64,6 +67,24 @@ def config_gpu(): return selected_gpu_id_or_uuid +def _setup_gpu(): + # Free up cache + gc.collect() + torch.cuda.empty_cache() + + # Based on the utilization, set the GPU ID + # Setting a GPU ID is crucial for the script to work well! + selected_gpu_id_or_uuid = config_gpu() + + # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID + if selected_gpu_id_or_uuid: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = selected_gpu_id_or_uuid + print(f"CUDA_VISIBLE_DEVICES set to: {selected_gpu_id_or_uuid}") + else: + print("No suitable GPU or MIG ID found. Exiting...") + + def _setup_evaluation_params(manifest, run_names): """Return evaluation params related to. 
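[editor note] A minimal sketch of the call pattern the runner scripts in this series adopt after the _setup_gpu() refactor above; illustrative only, assuming the br package is installed and importable.

    import torch
    from br.analysis.analysis_utils import _setup_gpu

    # _setup_gpu() frees the CUDA cache, picks a free GPU/MIG ID via config_gpu(),
    # and exports CUDA_VISIBLE_DEVICES, so the selected device is addressed as "cuda:0".
    _setup_gpu()
    device = "cuda:0"
    x = torch.zeros(1, device=device)  # hypothetical tensor, just to confirm placement
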
diff --git a/src/br/analysis/fig2_cellpack.py b/src/br/analysis/fig2_cellpack.py index bb4d6fd..c2f90b2 100644 --- a/src/br/analysis/fig2_cellpack.py +++ b/src/br/analysis/fig2_cellpack.py @@ -1,178 +1,171 @@ -# --- -# jupyter: -# jupytext: -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.16.4 -# kernelspec: -# display_name: Python 3 (ipykernel) -# language: python -# name: python3 -# --- - -# %% -# %load_ext autoreload -# %autoreload 2 - -# %% [markdown] -# # Polar plot viz - -# %% -# Holistic viz of features - -model_order = [ - "Classical_image", - "Rotation_invariant_image", - "Classical_pointcloud", - "Rotation_invariant_pointcloud", -] -metric_list = [ - "reconstruction", - "emissions", - "classification_rule", - "compactness", - "evolution_energy", - "model_sizes", - "rotation_invariance_error", -] -norm = "std" -title = "cellpack_comparison" -colors_list = None -unique_expressivity_metrics = ["Classification_rule"] -df, df_non_agg = collect_outputs(save_path, norm, model_order, metric_list) -plot(save_path, df, model_order, title, colors_list, norm, unique_expressivity_metrics) - -# %% [markdown] -# # Latent walks - -# %% -# Load model and embeddings - -run_names = ["Rotation_invariant_pointcloud_jitter"] -DATASET_INFO = get_all_configs_per_dataset(results_path) -all_ret, df = get_embeddings(run_names, dataset_name, DATASET_INFO, save_path) -model = all_models[-1] - -# %% -# Params for viz -key = "pcloud" -stratify_key = "rule" -z_max = 0.3 -z_ind = 1 -flip = True -views = ["xy"] -xlim = [-20, 20] -ylim = [-20, 20] - -# %% -# Compute stratified latent walk - -this_save_path = Path(save_path) / Path("latent_walks") -this_save_path.mkdir(parents=True, exist_ok=True) - -stratified_latent_walk( - model, - device, - all_ret, - "pcloud", - 256, - 256, - 2, - this_save_path, - stratify_key, - latent_walk_range=[-2, 0, 2], - z_max=z_max, - z_ind=z_ind, +import os +import torch +from Pathlib import Path +import pandas as pd +from br.analysis.analysis_utils import _setup_gpu +from br.models.compute_features import get_embeddings +from br.models.load_models import _load_model_from_path +from br.models.utils import get_all_configs_per_dataset +from br.features.plot import plot_stratified_pc +from br.features.reconstruction import stratified_latent_walk, save_pcloud +import argparse +from br.features.utils import ( + normalize_intensities_and_get_colormap, + normalize_intensities_and_get_colormap_apply, ) -# %% -# Save reconstruction plots -items = os.listdir(this_save_path) -fnames = [i for i in items if i.split(".")[-1] == "csv"] -fnames = [i for i in fnames if i.split("_")[1] == "0"] -names = [i.split(".")[0] for i in fnames] -cm_name = "inferno" - -all_df = [] -for idx, _ in enumerate(fnames): - fname = fnames[idx] - df = pd.read_csv(f"{this_save_path}/{fname}", index_col=0) - df, cmap, vmin, vmax = normalize_intensities_and_get_colormap( - df, pcts=[5, 95], cm_name=cm_name - ) - df[stratify_key] = names[idx] - all_df.append(df) -df = pd.concat(all_df, axis=0).reset_index(drop=True) - -plot_stratified_pc(df, xlim, ylim, stratify_key, this_save_path, cmap, flip) - -# %% [markdown] -# # Archetype analysis - -# %% -# Fit 6 archetypes -this_ret = all_ret -labels = this_ret["rule"].values -matrix = this_ret[[i for i in this_ret.columns if "mu" in i]].values - -n_archetypes = 6 -aa = AA_Fast(n_archetypes, max_iter=1000, tol=1e-6).fit(matrix) -archetypes_df = pd.DataFrame(aa.Z, columns=[f"mu_{i}" for i in range(matrix.shape[1])]) - 
-# %% -# Save reconstructions - -this_save_path = Path(save_path) / Path("archetypes") -this_save_path.mkdir(parents=True, exist_ok=True) - -model = model.eval() -key = "pcloud" -all_xhat = [] -with torch.no_grad(): - for i in range(n_archetypes): - z_inf = torch.tensor(archetypes_df.iloc[i].values).unsqueeze(axis=0) - z_inf = z_inf.to(device) - z_inf = z_inf.float() - decoder = model.decoder[key] - xhat = decoder(z_inf) - xhat = xhat.detach().cpu().numpy() - xhat = save_pcloud(xhat[0], this_save_path, i, z_max, z_ind) - all_xhat.append(xhat) - - from br.features.plot import plot_pc_saved +from br.features.archetype import AA_Fast +import numpy as np + + +def main(args): + _setup_gpu() + device = "cuda:0" + + config_path = os.environ.get("CYTODL_CONFIG_PATH") + results_path = config_path + "/results/" + + run_name = "Rotation_invariant_pointcloud_jitter" + DATASET_INFO = get_all_configs_per_dataset(results_path) + models = DATASET_INFO[args.dataset_name] + checkpoints = models["model_checkpoints"] + checkpoints = [i for i in checkpoints if run_name in i] + assert len(checkpoints) == 1 + all_ret, df = get_embeddings([run_name], args.dataset_name, DATASET_INFO, args.save_path) + model, x_label, latent_dim, model_size = _load_model_from_path(checkpoints[0], False, device) + + # Compute stratified latent walk + key = "pcloud" + stratify_key = "rule" + z_max = 0.3 + z_ind = 1 + flip = True + views = ["xy"] + xlim = [-20, 20] + ylim = [-20, 20] + this_save_path = Path(args.save_path) / Path("latent_walks") + this_save_path.mkdir(parents=True, exist_ok=True) + + stratified_latent_walk( + model, + device, + all_ret, + "pcloud", + 256, + 256, + 2, + this_save_path, + stratify_key, + latent_walk_range=[-2, 0, 2], + z_max=z_max, + z_ind=z_ind, + ) -names = [str(i) for i in range(n_archetypes)] -key = "archetype" - -plot_pc_saved(this_save_path, names, key, flip, 0.5, views, xlim, ylim) - -# %% -# Save numpy arrays - -key = "archetype" -items = os.listdir(this_save_path) -fnames = [i for i in items if i.split(".")[-1] == "csv"] -names = [i.split(".")[0] for i in fnames] - -df = pd.DataFrame([]) -for idx, _ in enumerate(fnames): - fname = fnames[idx] - print(fname) - dft = pd.read_csv(f"{this_save_path}/{fname}", index_col=0) - dft[key] = names[idx] - df = pd.concat([df, dft], ignore_index=True) + # Save reconstruction plots + items = os.listdir(this_save_path) + fnames = [i for i in items if i.split(".")[-1] == "csv"] + fnames = [i for i in fnames if i.split("_")[1] == "0"] + names = [i.split(".")[0] for i in fnames] + cm_name = "inferno" + + all_df = [] + for idx, _ in enumerate(fnames): + fname = fnames[idx] + df = pd.read_csv(f"{this_save_path}/{fname}", index_col=0) + df, cmap, vmin, vmax = normalize_intensities_and_get_colormap( + df, pcts=[5, 95], cm_name=cm_name + ) + df[stratify_key] = names[idx] + all_df.append(df) + df = pd.concat(all_df, axis=0).reset_index(drop=True) + + plot_stratified_pc(df, xlim, ylim, stratify_key, this_save_path, cmap, flip) + + # Archetype analysis + # Fit 6 archetypes + this_ret = all_ret + matrix = this_ret[[i for i in this_ret.columns if "mu" in i]].values + + n_archetypes = 6 + aa = AA_Fast(n_archetypes, max_iter=1000, tol=1e-6).fit(matrix) + archetypes_df = pd.DataFrame(aa.Z, columns=[f"mu_{i}" for i in range(matrix.shape[1])]) + + this_save_path = Path(args.save_path) / Path("archetypes") + this_save_path.mkdir(parents=True, exist_ok=True) + + model = model.eval() + key = "pcloud" + all_xhat = [] + with torch.no_grad(): + for i in range(n_archetypes): + z_inf = 
torch.tensor(archetypes_df.iloc[i].values).unsqueeze(axis=0) + z_inf = z_inf.to(device) + z_inf = z_inf.float() + decoder = model.decoder[key] + xhat = decoder(z_inf) + xhat = xhat.detach().cpu().numpy() + xhat = save_pcloud(xhat[0], this_save_path, i, z_max, z_ind) + all_xhat.append(xhat) + + names = [str(i) for i in range(n_archetypes)] + key = "archetype" + + plot_pc_saved(this_save_path, names, key, flip, 0.5, views, xlim, ylim) + + # Save numpy arrays + key = "archetype" + items = os.listdir(this_save_path) + fnames = [i for i in items if i.split(".")[-1] == "csv"] + names = [i.split(".")[0] for i in fnames] + + df = pd.DataFrame([]) + for idx, _ in enumerate(fnames): + fname = fnames[idx] + dft = pd.read_csv(f"{this_save_path}/{fname}", index_col=0) + dft[key] = names[idx] + df = pd.concat([df, dft], ignore_index=True) + + archetypes = ["0", "1", "2", "3", "4", "5"] + + for arch in archetypes: + this_df = df.loc[df["archetype"] == arch].reset_index(drop=True) + np_arr = this_df[["x", "y", "z"]].values + np.save(this_save_path / Path(f"{arch}.npy"), np_arr) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Script for computing embeddings") + parser.add_argument( + "--save_path", type=str, required=True, help="Path to save the embeddings." + ) + parser.add_argument( + "--meta_key", + type=str, + default=None, + required=False, + help="Metadata to add to the embeddings aside from CellId", + ) + parser.add_argument( + "--sdf", + type=bool, + required=True, + help="boolean indicating whether the experiments involve SDFs", + ) + parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") + parser.add_argument("--batch_size", type=int, default=2, help="Batch size for processing.") + parser.add_argument("--debug", type=bool, default=True, help="Enable debug mode.") -archetypes = ["0", "1", "2", "3", "4", "5"] + args = parser.parse_args() -for arch in archetypes: - this_df = df.loc[df["archetype"] == arch].reset_index(drop=True) - np_arr = this_df[["x", "y", "z"]].values - print(np_arr.shape) - np.save(this_save_path / Path(f"{arch}.npy"), np_arr) + # Validate that required paths are provided + if not args.save_path or not args.dataset_name: + print("Error: Required arguments are missing.") + sys.exit(1) -# %% + main(args) -# %% + """ + Example run: + python src/br/analysis/run_embeddings.py --save_path "./outputs/" --sdf False --dataset_name "pcna" --batch_size 5 --debug False + """ diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py index 2627dd0..ec8daae 100644 --- a/src/br/analysis/run_embeddings.py +++ b/src/br/analysis/run_embeddings.py @@ -1,39 +1,18 @@ # Free up cache import argparse -import gc import os import sys - -import torch - -from br.analysis.analysis_utils import _setup_evaluation_params, config_gpu +from br.analysis.analysis_utils import _setup_evaluation_params, _setup_gpu from br.models.load_models import get_data_and_models from br.models.save_embeddings import save_embeddings def main(args): - # Free up cache - gc.collect() - torch.cuda.empty_cache() - - # Based on the utilization, set the GPU ID - # Setting a GPU ID is crucial for the script to work well! 
- selected_gpu_id_or_uuid = config_gpu() - # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID - if selected_gpu_id_or_uuid: - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - os.environ["CUDA_VISIBLE_DEVICES"] = selected_gpu_id_or_uuid - print(f"CUDA_VISIBLE_DEVICES set to: {selected_gpu_id_or_uuid}") - else: - print("No suitable GPU or MIG ID found. Exiting...") - - # Set the device + # Setup GPUs and set the device + _setup_gpu() device = "cuda:0" - # # Set working directory and paths - # os.chdir(args.src_path) - # Get config path from CYTODL_CONFIG_PATH config_path = os.environ.get("CYTODL_CONFIG_PATH") diff --git a/src/br/analysis/run_features.py b/src/br/analysis/run_features.py index dd05085..69a17ac 100644 --- a/src/br/analysis/run_features.py +++ b/src/br/analysis/run_features.py @@ -1,41 +1,24 @@ # Free up cache import argparse -import gc import os import sys import pandas as pd -import torch from br.analysis.analysis_utils import ( _get_feature_params, _setup_evaluation_params, - config_gpu, + _setup_gpu, ) +from br.features.plot import collect_outputs, plot from br.models.compute_features import compute_features from br.models.load_models import get_data_and_models from br.models.save_embeddings import save_emissions -from br.features.plot import collect_outputs, plot def main(args): - # Free up cache - gc.collect() - torch.cuda.empty_cache() - - # Based on the utilization, set the GPU ID - # Setting a GPU ID is crucial for the script to work well! - selected_gpu_id_or_uuid = config_gpu() - - # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID - if selected_gpu_id_or_uuid: - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - os.environ["CUDA_VISIBLE_DEVICES"] = selected_gpu_id_or_uuid - print(f"CUDA_VISIBLE_DEVICES set to: {selected_gpu_id_or_uuid}") - else: - print("No suitable GPU or MIG ID found. 
Exiting...") - - # Set the device + # Setup GPUs and set the device + _setup_gpu() device = "cuda:0" # set batch size to 1 for emission stats/features @@ -135,8 +118,8 @@ def main(args): # Polar plot visualization # Load saved csvs - csvs = [i for i in os.listdir(args.save_path) if i.split('.')[-1] == 'csv'] - csvs = [i.split('.')[0] for i in csvs] + csvs = [i for i in os.listdir(args.save_path) if i.split(".")[-1] == "csv"] + csvs = [i.split(".")[0] for i in csvs] # Remove non metric related csvs csvs = [i for i in csvs if i not in run_names and i not in keys] # classification and regression metrics are unique to each dataset diff --git a/src/br/models/compute_features.py b/src/br/models/compute_features.py index d57b607..0f4eeb7 100644 --- a/src/br/models/compute_features.py +++ b/src/br/models/compute_features.py @@ -166,7 +166,7 @@ def compute_features( if "Reconstruction" in metric_list: print("Getting reconstruction") - rec_df = all_ret.loc[all_ret['split'] == 'test'].reset_index(drop=True) + rec_df = all_ret.loc[all_ret["split"] == "test"].reset_index(drop=True) rec_df.to_csv(path / "reconstruction.csv") metric_list.pop(metric_list.index("Reconstruction")) diff --git a/src/br/models/load_models.py b/src/br/models/load_models.py index 71a04b1..2dbd0db 100644 --- a/src/br/models/load_models.py +++ b/src/br/models/load_models.py @@ -8,6 +8,22 @@ from br.models.utils import get_all_configs_per_dataset +def _load_model_from_path(ckpt_path, strict, device): + config_path = ckpt_path.split("ckpt")[0] + "yaml" + with open(config_path) as stream: + config = yaml.safe_load(stream) + model_conf = config["model"] + x_label = model_conf["x_label"] + latent_dim = model_conf["latent_dim"] + model_class = model_conf.pop("_target_") + model_conf = instantiate(model_conf) + model_class = _locate(model_class) + model_ = model_class.load_from_checkpoint( + ckpt_path, **model_conf, strict=strict, map_location=device + ).eval() + return model_, x_label, latent_dim, config["model/params/total"] + + def load_model_from_path(dataset, results_path, strict=False, split="val", device="cuda:0"): MODEL_INFO = get_all_configs_per_dataset(results_path) models = MODEL_INFO[dataset] @@ -17,27 +33,11 @@ def load_model_from_path(dataset, results_path, strict=False, split="val", devic x_labels = [] latent_dims = [] for j, ckpt_path in enumerate(models["model_checkpoints"]): - if "model_paths" in models.keys(): - config_path = models["model_paths"][j] - else: - config_path = ckpt_path.split("ckpt")[0] + "yaml" - with open(config_path) as stream: - config = yaml.safe_load(stream) - model_conf = config["model"] - x_label = model_conf["x_label"] - latent_dim = model_conf["latent_dim"] - model_class = model_conf.pop("_target_") - model_conf = instantiate(model_conf) - model_class = _locate(model_class) - all_models.append( - model_class.load_from_checkpoint( - ckpt_path, **model_conf, strict=strict, map_location=device - ).eval() - ) - model_sizes.append(config["model/params/total"]) + model, x_label, latent_dim, model_size = _load_model_from_path(ckpt_path, strict, device) + all_models.append(model) + model_sizes.append(model_size) x_labels.append(x_label) latent_dims.append(latent_dim) - return all_models, models["names"], model_sizes, model_manifest, x_labels, latent_dims From 512ccd284d246592cfd4710ce0970746c2f93c05 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Thu, 21 Nov 2024 15:53:30 -0800 Subject: [PATCH 12/35] move fig notebooks into single analysis script --- src/br/analysis/analysis_utils.py | 217 
+++++++++++++++++++++++++++ src/br/analysis/fig2_cellpack.py | 171 --------------------- src/br/analysis/punctate_analysis.py | 102 +++++++++++++ src/br/analysis/run_embeddings.py | 1 + src/br/features/plot.py | 2 +- 5 files changed, 321 insertions(+), 172 deletions(-) delete mode 100644 src/br/analysis/fig2_cellpack.py create mode 100644 src/br/analysis/punctate_analysis.py diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 2ae25fb..1c0d7db 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -1,9 +1,18 @@ import gc import os import subprocess +from pathlib import Path +import numpy as np +import pandas as pd import torch +from br.features.plot import plot_pc_saved, plot_stratified_pc +from br.features.reconstruction import save_pcloud +from br.features.utils import ( + normalize_intensities_and_get_colormap, + normalize_intensities_and_get_colormap_apply, +) from br.models.utils import get_all_configs_per_dataset @@ -75,6 +84,7 @@ def _setup_gpu(): # Based on the utilization, set the GPU ID # Setting a GPU ID is crucial for the script to work well! selected_gpu_id_or_uuid = config_gpu() + selected_gpu_id_or_uuid = 'MIG-ffdee303-0dd4-513d-b18c-beba028b49c7' # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID if selected_gpu_id_or_uuid: @@ -137,6 +147,9 @@ def _setup_evaluation_params(manifest, run_names): def _setup_evolve_params(run_names, data_config_list, keys): + """ + Set up dataloader parameters specific to the evolution energy metric + """ eval_meshed_img = [False] * len(run_names) eval_meshed_img_model_type = [None] * len(run_names) compute_evolve_dataloaders = True @@ -168,6 +181,14 @@ def _setup_evolve_params(run_names, data_config_list, keys): def _get_feature_params(results_path, dataset_name, manifest, keys, run_names): + """ + Get parameters associated with calculation of + 1. Rot invariance + 2. Compactness + 3. Classification + 4. Evolution/Interpolation distance + 5. Regression + """ DATA_LIST = get_all_configs_per_dataset(results_path) data_config_list = DATA_LIST[dataset_name]["data_paths"] cytodl_config_path = os.environ.get("CYTODL_CONFIG_PATH") @@ -195,3 +216,199 @@ def _get_feature_params(results_path, dataset_name, manifest, keys, run_names): evolve_params, regression_params, ) + + +def _dataset_specific_subsetting(all_ret, dataset_name): + """ + Subset each dataset for analysis. + E.g. For PCNA dataset, only look at interphase. 
+ Also specify dataset specific visualization params + - z_max (Max value of z at which to slice data) + - z_ind (Which index is Z - 1, 2, 3) + - views = ['xy'] (show xy projection) + - xlim, ylim = [-20, 20] (max scaling for visualization max projection) + - flip = True when Z and Y are swapped + """ + if dataset_name == "pcna": + interphase_stages = [ + "G1", + "earlyS", + "earlyS-midS", + "midS", + "midS-lateS", + "lateS", + "lateS-G2", + "G2", + ] + all_ret = all_ret.loc[all_ret["cell_stage_fine"].isin(interphase_stages)].reset_index( + drop=True + ) + stratify_key = "cell_stage_fine" + viz_params = {"z_max": 0.3, "z_ind": 2, "flip": False} + n_archetypes = 8 + elif dataset_name == "cellpack": + stratify_key = "rule" + viz_params = {"z_max": 0.3, "z_ind": 1, "flip": True} + n_archetypes = 6 + elif dataset_name == "other_punctate": + structs = ["NUP153", "SON", "HIST1H2BJ", "SMC1A", "CETN2", "SLC25A17", "RAB5A"] + all_ret = all_ret.loc[all_ret["structure_name"].isin(structs)].reset_index(drop=True) + stratify_key = "structure_name" + viz_params = {"z_max": None, "z_ind": 2, "flip": False} + n_archetypes = 7 + else: + raise ValueError("Dataset not in pre-configured list") + viz_params["views"] = ["xy"] + viz_params["xlim"] = [-20, 20] + viz_params["ylim"] = [-20, 20] + return all_ret, stratify_key, n_archetypes, viz_params + + +def _latent_walk_save_recons(this_save_path, stratify_key, viz_params): + """ + Visualize saved latent walks from csvs + this_save_path - folder where csvs are saved + stratify_key - metadata by which PCs are stratified (e.g. "rule") + viz_params - parameters associated with visualization (e.g. xlims, ylims) + """ + items = os.listdir(this_save_path) + fnames = [i for i in items if i.split(".")[-1] == "csv"] # get csvs + fnames = [i for i in fnames if i.split("_")[1] == "0"] # get 1st PC + names = [i.split(".")[0] for i in fnames] + cm_name = "YlGnBu" + all_df = [] + for idx, _ in enumerate(fnames): + fname = fnames[idx] + df = pd.read_csv(f"{this_save_path}/{fname}", index_col=0) + df, cmap, vmin, vmax = normalize_intensities_and_get_colormap( + df, pcts=[5, 95], cm_name=cm_name + ) + df[stratify_key] = names[idx] + all_df.append(df) + df = pd.concat(all_df, axis=0).reset_index(drop=True) + + plot_stratified_pc( + df, + viz_params["xlim"], + viz_params["ylim"], + stratify_key, + this_save_path, + cmap, + viz_params["flip"], + ) + + +def _archetypes_save_recons(model, archetypes_df, device, key, viz_params, this_save_path): + """ + Visualize saved archetypes from archetype matrix dataframe + """ + all_xhat = [] + with torch.no_grad(): + for i in range(len(archetypes_df)): + z_inf = torch.tensor(archetypes_df.iloc[i].values).unsqueeze(axis=0) + z_inf = z_inf.to(device) + z_inf = z_inf.float() + decoder = model.decoder[key] + xhat = decoder(z_inf) + xhat = xhat.detach().cpu().numpy() + xhat = save_pcloud( + xhat[0], this_save_path, i, viz_params["z_max"], viz_params["z_ind"] + ) + all_xhat.append(xhat) + + names = [str(i) for i in range(len(archetypes_df))] + key = "archetype" + plot_pc_saved( + this_save_path, + names, + key, + viz_params["flip"], + 0.5, + viz_params["views"], + viz_params["xlim"], + viz_params["ylim"], + ) + + # Save numpy arrays for mitsuba visualization + key = "archetype" + items = os.listdir(this_save_path) + fnames = [i for i in items if i.split(".")[-1] == "csv"] + names = [i.split(".")[0] for i in fnames] + + df = pd.DataFrame([]) + for idx, _ in enumerate(fnames): + fname = fnames[idx] + dft = pd.read_csv(f"{this_save_path}/{fname}", 
index_col=0) + dft[key] = names[idx] + df = pd.concat([df, dft], ignore_index=True) + + for arch in names: + this_df = df.loc[df["archetype"] == arch].reset_index(drop=True) + np_arr = this_df[["x", "y", "z"]].values + np.save(this_save_path / Path(f"{arch}.npy"), np_arr) + + +def _pseudo_time_analysis(model, all_ret, save_path, device, key, viz_params, bins=None): + """ + Psuedotime analysis for PCNA and NPM1 dataset + """ + if not bins: + # Pseudotime bins based on npm1 dataset from WTC-11 hIPS single cell image dataset + bins = [ + (247.407, 390.752), + (390.752, 533.383), + (533.383, 676.015), + (676.015, 818.646), + (818.646, 961.277), + ] + correct_bins = [] + for ind, row in all_ret.iterrows(): + this_bin = [] + for bin_ in bins: + if (row["volume_of_nucleus_um3"] > bin_[0]) and ( + row["volume_of_nucleus_um3"] <= bin_[1] + ): + this_bin.append(bin_) + if row["volume_of_nucleus_um3"] < bins[0][0]: + this_bin.append(bin_) + if row["volume_of_nucleus_um3"] > bins[4][1]: + this_bin.append(bin_) + assert len(this_bin) == 1 + correct_bins.append(this_bin[0]) + all_ret["vol_bins"] = correct_bins + all_ret["vol_bins_inds"] = pd.factorize(all_ret["vol_bins"])[0] + + # Save reconstructions per bin + this_save_path = Path(save_path) / Path("pseudo_time") + this_save_path.mkdir(parents=True, exist_ok=True) + + cols = [i for i in all_ret.columns if "mu" in i] + for ind, gr in all_ret.groupby(["vol_bins"]): + this_stage_df = gr.reset_index(drop=True) + this_stage_mu = this_stage_df[cols].values + mean_mu = this_stage_mu.mean(axis=0) + dist = (this_stage_mu - mean_mu) ** 2 + dist = np.sum(dist, axis=1) + z_inf = torch.tensor(mean_mu).unsqueeze(axis=0) + z_inf = z_inf.to(device) + z_inf = z_inf.float() + + decoder = model.decoder["pcloud"] + xhat = decoder(z_inf) + xhat = save_pcloud( + xhat[0], this_save_path, str(ind), viz_params["z_max"], viz_params["z_ind"] + ) + + names = os.listdir(this_save_path) + names = [i for i in names if i.split(".")[-1] == "csv"] + names = [i.split(".csv")[0] for i in names] + plot_pc_saved( + this_save_path, + names, + key, + viz_params["flip"], + 0.5, + viz_params["views"], + viz_params["xlim"], + viz_params["ylim"], + ) diff --git a/src/br/analysis/fig2_cellpack.py b/src/br/analysis/fig2_cellpack.py deleted file mode 100644 index c2f90b2..0000000 --- a/src/br/analysis/fig2_cellpack.py +++ /dev/null @@ -1,171 +0,0 @@ -import os -import torch -from Pathlib import Path -import pandas as pd -from br.analysis.analysis_utils import _setup_gpu -from br.models.compute_features import get_embeddings -from br.models.load_models import _load_model_from_path -from br.models.utils import get_all_configs_per_dataset -from br.features.plot import plot_stratified_pc -from br.features.reconstruction import stratified_latent_walk, save_pcloud -import argparse -from br.features.utils import ( - normalize_intensities_and_get_colormap, - normalize_intensities_and_get_colormap_apply, -) - -from br.features.plot import plot_pc_saved -from br.features.archetype import AA_Fast -import numpy as np - - -def main(args): - _setup_gpu() - device = "cuda:0" - - config_path = os.environ.get("CYTODL_CONFIG_PATH") - results_path = config_path + "/results/" - - run_name = "Rotation_invariant_pointcloud_jitter" - DATASET_INFO = get_all_configs_per_dataset(results_path) - models = DATASET_INFO[args.dataset_name] - checkpoints = models["model_checkpoints"] - checkpoints = [i for i in checkpoints if run_name in i] - assert len(checkpoints) == 1 - all_ret, df = get_embeddings([run_name], 
args.dataset_name, DATASET_INFO, args.save_path) - model, x_label, latent_dim, model_size = _load_model_from_path(checkpoints[0], False, device) - - # Compute stratified latent walk - key = "pcloud" - stratify_key = "rule" - z_max = 0.3 - z_ind = 1 - flip = True - views = ["xy"] - xlim = [-20, 20] - ylim = [-20, 20] - this_save_path = Path(args.save_path) / Path("latent_walks") - this_save_path.mkdir(parents=True, exist_ok=True) - - stratified_latent_walk( - model, - device, - all_ret, - "pcloud", - 256, - 256, - 2, - this_save_path, - stratify_key, - latent_walk_range=[-2, 0, 2], - z_max=z_max, - z_ind=z_ind, - ) - - # Save reconstruction plots - items = os.listdir(this_save_path) - fnames = [i for i in items if i.split(".")[-1] == "csv"] - fnames = [i for i in fnames if i.split("_")[1] == "0"] - names = [i.split(".")[0] for i in fnames] - cm_name = "inferno" - - all_df = [] - for idx, _ in enumerate(fnames): - fname = fnames[idx] - df = pd.read_csv(f"{this_save_path}/{fname}", index_col=0) - df, cmap, vmin, vmax = normalize_intensities_and_get_colormap( - df, pcts=[5, 95], cm_name=cm_name - ) - df[stratify_key] = names[idx] - all_df.append(df) - df = pd.concat(all_df, axis=0).reset_index(drop=True) - - plot_stratified_pc(df, xlim, ylim, stratify_key, this_save_path, cmap, flip) - - # Archetype analysis - # Fit 6 archetypes - this_ret = all_ret - matrix = this_ret[[i for i in this_ret.columns if "mu" in i]].values - - n_archetypes = 6 - aa = AA_Fast(n_archetypes, max_iter=1000, tol=1e-6).fit(matrix) - archetypes_df = pd.DataFrame(aa.Z, columns=[f"mu_{i}" for i in range(matrix.shape[1])]) - - this_save_path = Path(args.save_path) / Path("archetypes") - this_save_path.mkdir(parents=True, exist_ok=True) - - model = model.eval() - key = "pcloud" - all_xhat = [] - with torch.no_grad(): - for i in range(n_archetypes): - z_inf = torch.tensor(archetypes_df.iloc[i].values).unsqueeze(axis=0) - z_inf = z_inf.to(device) - z_inf = z_inf.float() - decoder = model.decoder[key] - xhat = decoder(z_inf) - xhat = xhat.detach().cpu().numpy() - xhat = save_pcloud(xhat[0], this_save_path, i, z_max, z_ind) - all_xhat.append(xhat) - - names = [str(i) for i in range(n_archetypes)] - key = "archetype" - - plot_pc_saved(this_save_path, names, key, flip, 0.5, views, xlim, ylim) - - # Save numpy arrays - key = "archetype" - items = os.listdir(this_save_path) - fnames = [i for i in items if i.split(".")[-1] == "csv"] - names = [i.split(".")[0] for i in fnames] - - df = pd.DataFrame([]) - for idx, _ in enumerate(fnames): - fname = fnames[idx] - dft = pd.read_csv(f"{this_save_path}/{fname}", index_col=0) - dft[key] = names[idx] - df = pd.concat([df, dft], ignore_index=True) - - archetypes = ["0", "1", "2", "3", "4", "5"] - - for arch in archetypes: - this_df = df.loc[df["archetype"] == arch].reset_index(drop=True) - np_arr = this_df[["x", "y", "z"]].values - np.save(this_save_path / Path(f"{arch}.npy"), np_arr) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Script for computing embeddings") - parser.add_argument( - "--save_path", type=str, required=True, help="Path to save the embeddings." 
- ) - parser.add_argument( - "--meta_key", - type=str, - default=None, - required=False, - help="Metadata to add to the embeddings aside from CellId", - ) - parser.add_argument( - "--sdf", - type=bool, - required=True, - help="boolean indicating whether the experiments involve SDFs", - ) - parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") - parser.add_argument("--batch_size", type=int, default=2, help="Batch size for processing.") - parser.add_argument("--debug", type=bool, default=True, help="Enable debug mode.") - - args = parser.parse_args() - - # Validate that required paths are provided - if not args.save_path or not args.dataset_name: - print("Error: Required arguments are missing.") - sys.exit(1) - - main(args) - - """ - Example run: - python src/br/analysis/run_embeddings.py --save_path "./outputs/" --sdf False --dataset_name "pcna" --batch_size 5 --debug False - """ diff --git a/src/br/analysis/punctate_analysis.py b/src/br/analysis/punctate_analysis.py new file mode 100644 index 0000000..5ad4c03 --- /dev/null +++ b/src/br/analysis/punctate_analysis.py @@ -0,0 +1,102 @@ +import os +from pathlib import Path +import pandas as pd +import sys +from br.analysis.analysis_utils import ( + _setup_gpu, + _latent_walk_save_recons, + _dataset_specific_subsetting, + _archetypes_save_recons, + _pseudo_time_analysis +) +from br.models.compute_features import get_embeddings +from br.models.load_models import _load_model_from_path +from br.models.utils import get_all_configs_per_dataset +from br.features.reconstruction import stratified_latent_walk +import argparse +from br.features.archetype import AA_Fast + + +def main(args): + _setup_gpu() + device = "cuda:0" + + config_path = os.environ.get("CYTODL_CONFIG_PATH") + results_path = config_path + "/results/" + + run_name = "Rotation_invariant_pointcloud_jitter" + DATASET_INFO = get_all_configs_per_dataset(results_path) + models = DATASET_INFO[args.dataset_name] + checkpoints = models["model_checkpoints"] + checkpoints = [i for i in checkpoints if run_name in i] + assert len(checkpoints) == 1 + all_ret, df = get_embeddings([run_name], args.dataset_name, DATASET_INFO, args.embeddings_path) + model, x_label, latent_dim, model_size = _load_model_from_path(checkpoints[0], False, device) + + all_ret, stratify_key, n_archetypes, viz_params = _dataset_specific_subsetting( + all_ret, args.dataset_name + ) + + # Compute stratified latent walk + key = "pcloud" + this_save_path = Path(args.save_path) / Path("latent_walks") + this_save_path.mkdir(parents=True, exist_ok=True) + + stratified_latent_walk( + model, + device, + all_ret, + "pcloud", + 256, + 256, + 2, + this_save_path, + stratify_key, + latent_walk_range=[-2, 0, 2], + z_max=viz_params['z_max'], + z_ind=viz_params['z_ind'], + ) + + # Save reconstruction plots + _latent_walk_save_recons(this_save_path, stratify_key, viz_params) + + # Archetype analysis + matrix = all_ret[[i for i in all_ret.columns if "mu" in i]].values + aa = AA_Fast(n_archetypes, max_iter=1000, tol=1e-6).fit(matrix) + archetypes_df = pd.DataFrame(aa.Z, columns=[f"mu_{i}" for i in range(matrix.shape[1])]) + + this_save_path = Path(args.save_path) / Path("archetypes") + this_save_path.mkdir(parents=True, exist_ok=True) + + _archetypes_save_recons( + model, archetypes_df, device, key, viz_params, this_save_path + ) + + # Pseudotime analysis + if "volume_of_nucleus_um3" in all_ret.columns: + _pseudo_time_analysis(model, all_ret, args.save_path, device, key, viz_params) + + +if __name__ == 
"__main__": + parser = argparse.ArgumentParser(description="Script for computing embeddings") + parser.add_argument( + "--save_path", type=str, required=True, help="Path to save the embeddings." + ) + parser.add_argument( + "--embeddings_path", type=str, required=True, help="Path to the saved embeddings." + ) + parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") + + args = parser.parse_args() + + # Validate that required paths are provided + if not args.save_path or not args.embeddings_path: + print("Error: Required arguments are missing.") + sys.exit(1) + + main(args) + + """ + Example run: + python src/br/analysis/punctate_analysis.py --save_path "./testing/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --dataset_name "pcna" + """ diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py index ec8daae..4b667a1 100644 --- a/src/br/analysis/run_embeddings.py +++ b/src/br/analysis/run_embeddings.py @@ -2,6 +2,7 @@ import argparse import os import sys + from br.analysis.analysis_utils import _setup_evaluation_params, _setup_gpu from br.models.load_models import get_data_and_models from br.models.save_embeddings import save_embeddings diff --git a/src/br/features/plot.py b/src/br/features/plot.py index f9e7ea5..b9fd7bf 100644 --- a/src/br/features/plot.py +++ b/src/br/features/plot.py @@ -283,7 +283,7 @@ def plot_pc_saved( if "inorm" not in df.columns: df, cmap, _, _ = normalize_intensities_and_get_colormap(df=df, pcts=[5, 95]) else: - cmap = "inferno" + cmap = "YlGnBu" for sub_key in df[key].unique(): df_sub = df.loc[df[key] == sub_key] From 178a826cb851b7833c6657cb4e7df7780192ca8a Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Thu, 21 Nov 2024 16:49:54 -0800 Subject: [PATCH 13/35] working analysis script for punctate structures --- src/br/analysis/analysis_utils.py | 117 ++++++++++++++++++++++----- src/br/analysis/punctate_analysis.py | 43 ++++++---- 2 files changed, 121 insertions(+), 39 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 1c0d7db..b4f80a3 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -3,9 +3,11 @@ import subprocess from pathlib import Path +import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch +import yaml from br.features.plot import plot_pc_saved, plot_stratified_pc from br.features.reconstruction import save_pcloud @@ -84,7 +86,7 @@ def _setup_gpu(): # Based on the utilization, set the GPU ID # Setting a GPU ID is crucial for the script to work well! 
selected_gpu_id_or_uuid = config_gpu() - selected_gpu_id_or_uuid = 'MIG-ffdee303-0dd4-513d-b18c-beba028b49c7' + selected_gpu_id_or_uuid = "MIG-ffdee303-0dd4-513d-b18c-beba028b49c7" # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID if selected_gpu_id_or_uuid: @@ -147,9 +149,7 @@ def _setup_evaluation_params(manifest, run_names): def _setup_evolve_params(run_names, data_config_list, keys): - """ - Set up dataloader parameters specific to the evolution energy metric - """ + """Set up dataloader parameters specific to the evolution energy metric.""" eval_meshed_img = [False] * len(run_names) eval_meshed_img_model_type = [None] * len(run_names) compute_evolve_dataloaders = True @@ -183,7 +183,7 @@ def _setup_evolve_params(run_names, data_config_list, keys): def _get_feature_params(results_path, dataset_name, manifest, keys, run_names): """ Get parameters associated with calculation of - 1. Rot invariance + 1. Rot invariance 2. Compactness 3. Classification 4. Evolution/Interpolation distance @@ -219,10 +219,9 @@ def _get_feature_params(results_path, dataset_name, manifest, keys, run_names): def _dataset_specific_subsetting(all_ret, dataset_name): - """ - Subset each dataset for analysis. - E.g. For PCNA dataset, only look at interphase. - Also specify dataset specific visualization params + """Subset each dataset for analysis. E.g. For PCNA dataset, only look at interphase. Also + specify dataset specific visualization params. + - z_max (Max value of z at which to slice data) - z_ind (Which index is Z - 1, 2, 3) - views = ['xy'] (show xy projection) @@ -252,9 +251,10 @@ def _dataset_specific_subsetting(all_ret, dataset_name): n_archetypes = 6 elif dataset_name == "other_punctate": structs = ["NUP153", "SON", "HIST1H2BJ", "SMC1A", "CETN2", "SLC25A17", "RAB5A"] - all_ret = all_ret.loc[all_ret["structure_name"].isin(structs)].reset_index(drop=True) + all_ret = all_ret.loc[all_ret["structure_name"].isin(structs)] + all_ret = all_ret.loc[all_ret["cell_stage"].isin(["M0"])].reset_index(drop=True) stratify_key = "structure_name" - viz_params = {"z_max": None, "z_ind": 2, "flip": False} + viz_params = {"z_max": None, "z_ind": 2, "flip": False, "structs": structs} n_archetypes = 7 else: raise ValueError("Dataset not in pre-configured list") @@ -264,17 +264,79 @@ def _dataset_specific_subsetting(all_ret, dataset_name): return all_ret, stratify_key, n_archetypes, viz_params -def _latent_walk_save_recons(this_save_path, stratify_key, viz_params): - """ - Visualize saved latent walks from csvs +def _viz_other_punctate(this_save_path, viz_params, stratify_key): + # Norms based on Viana 2023 + # norms used for model training + model_norms = "./src/br/data/preprocessing/pc_preprocessing/model_structnorms.yaml" + with open(model_norms) as stream: + model_norms = yaml.safe_load(stream) + + # norms used for viz + viz_norms = "./src/br/data/preprocessing/pc_preprocessing/viz_structnorms.yaml" + with open(viz_norms) as stream: + viz_norms = yaml.safe_load(stream) + + items = os.listdir(this_save_path) + for struct in viz_params["structs"]: + fnames = [i for i in items if i.split(".")[-1] == "csv"] + fnames = [i for i in fnames if i.split("_")[1] == "0"] + fnames = [i for i in fnames if i.split("_")[0] in [struct]] + names = [i.split(".")[0] for i in fnames] + + renorm = model_norms[struct] + this_viz_norm = viz_norms[struct] + use_vmin = this_viz_norm[0] + use_vmax = this_viz_norm[1] + + all_df = [] + for idx, _ in enumerate(fnames): + fname = fnames[idx] + df = 
pd.read_csv(f"{this_save_path}/{fname}", index_col=0) + df["s"] = df["s"] / 10 # scalar values were scaled by 10 during training + df["s"] = df["s"] * (renorm[1] - renorm[0]) + renorm[0] # use model norms + df[stratify_key] = names[idx] + all_df.append(df) + df = pd.concat(all_df, axis=0).reset_index(drop=True) + if struct in ["NUP153", "SON", "HIST1H2BJ", "SMC1A"]: + df = df.loc[df["z"] < 0.2].reset_index(drop=True) + df = normalize_intensities_and_get_colormap_apply(df, use_vmin, use_vmax) + cmap = plt.get_cmap("YlGnBu") + plot_stratified_pc( + df, + viz_params["xlim"], + viz_params["ylim"], + stratify_key, + this_save_path, + cmap, + viz_params["flip"], + ) + + for pc_bin in df[stratify_key].unique(): + this_df = df.loc[df[stratify_key] == pc_bin].reset_index(drop=True) + print(this_df.shape, struct, pc_bin) + np_arr = this_df[["x", "y", "z"]].values + colors = cmap(this_df["inorm"].values)[:, :3] + np_arr2 = colors + np_arr = np.concatenate([np_arr, np_arr2], axis=1) + np.save(this_save_path / Path(f"{stratify_key}_{pc_bin}.npy"), np_arr) + cmap = plt.get_cmap("YlGnBu") + + +def _latent_walk_save_recons(this_save_path, stratify_key, viz_params, dataset_name): + """Visualize saved latent walks from csvs. + this_save_path - folder where csvs are saved stratify_key - metadata by which PCs are stratified (e.g. "rule") viz_params - parameters associated with visualization (e.g. xlims, ylims) """ + if dataset_name == "other_punctate": + return _viz_other_punctate(this_save_path, viz_params, stratify_key) + items = os.listdir(this_save_path) - fnames = [i for i in items if i.split(".")[-1] == "csv"] # get csvs - fnames = [i for i in fnames if i.split("_")[1] == "0"] # get 1st PC + fnames = [i for i in items if i.split(".")[-1] == "csv"] # get csvs + fnames = [i for i in fnames if i.split("_")[1] == "0"] # get 1st PC names = [i.split(".")[0] for i in fnames] + cm_name = "YlGnBu" all_df = [] for idx, _ in enumerate(fnames): @@ -297,11 +359,24 @@ def _latent_walk_save_recons(this_save_path, stratify_key, viz_params): viz_params["flip"], ) + df, cmap, vmin, vmax = normalize_intensities_and_get_colormap( + df, pcts=[5, 95], cm_name="YlGnBu" + ) + + for idx, _ in enumerate(fnames): + fname = fnames[idx] + df = pd.read_csv(f"{this_save_path}/{fname}", index_col=0) + this_name = names[idx] + df = normalize_intensities_and_get_colormap_apply(df, vmin, vmax) + np_arr = df[["x", "y", "z"]].values + colors = cmap(df["inorm"].values)[:, :3] + np_arr2 = colors + np_arr = np.concatenate([np_arr, np_arr2], axis=1) + np.save(this_save_path / Path(f"{this_name}.npy"), np_arr) + def _archetypes_save_recons(model, archetypes_df, device, key, viz_params, this_save_path): - """ - Visualize saved archetypes from archetype matrix dataframe - """ + """Visualize saved archetypes from archetype matrix dataframe.""" all_xhat = [] with torch.no_grad(): for i in range(len(archetypes_df)): @@ -349,9 +424,7 @@ def _archetypes_save_recons(model, archetypes_df, device, key, viz_params, this_ def _pseudo_time_analysis(model, all_ret, save_path, device, key, viz_params, bins=None): - """ - Psuedotime analysis for PCNA and NPM1 dataset - """ + """Psuedotime analysis for PCNA and NPM1 dataset.""" if not bins: # Pseudotime bins based on npm1 dataset from WTC-11 hIPS single cell image dataset bins = [ diff --git a/src/br/analysis/punctate_analysis.py b/src/br/analysis/punctate_analysis.py index 5ad4c03..f640a35 100644 --- a/src/br/analysis/punctate_analysis.py +++ b/src/br/analysis/punctate_analysis.py @@ -1,20 +1,22 @@ +import 
argparse import os +import sys from pathlib import Path + import pandas as pd -import sys + from br.analysis.analysis_utils import ( - _setup_gpu, - _latent_walk_save_recons, + _archetypes_save_recons, _dataset_specific_subsetting, - _archetypes_save_recons, - _pseudo_time_analysis + _latent_walk_save_recons, + _pseudo_time_analysis, + _setup_gpu, ) +from br.features.archetype import AA_Fast +from br.features.reconstruction import stratified_latent_walk from br.models.compute_features import get_embeddings from br.models.load_models import _load_model_from_path from br.models.utils import get_all_configs_per_dataset -from br.features.reconstruction import stratified_latent_walk -import argparse -from br.features.archetype import AA_Fast def main(args): @@ -24,7 +26,7 @@ def main(args): config_path = os.environ.get("CYTODL_CONFIG_PATH") results_path = config_path + "/results/" - run_name = "Rotation_invariant_pointcloud_jitter" + run_name = args.run_name DATASET_INFO = get_all_configs_per_dataset(results_path) models = DATASET_INFO[args.dataset_name] checkpoints = models["model_checkpoints"] @@ -53,12 +55,12 @@ def main(args): this_save_path, stratify_key, latent_walk_range=[-2, 0, 2], - z_max=viz_params['z_max'], - z_ind=viz_params['z_ind'], + z_max=viz_params["z_max"], + z_ind=viz_params["z_ind"], ) # Save reconstruction plots - _latent_walk_save_recons(this_save_path, stratify_key, viz_params) + _latent_walk_save_recons(this_save_path, stratify_key, viz_params, args.dataset_name) # Archetype analysis matrix = all_ret[[i for i in all_ret.columns if "mu" in i]].values @@ -68,9 +70,7 @@ def main(args): this_save_path = Path(args.save_path) / Path("archetypes") this_save_path.mkdir(parents=True, exist_ok=True) - _archetypes_save_recons( - model, archetypes_df, device, key, viz_params, this_save_path - ) + _archetypes_save_recons(model, archetypes_df, device, key, viz_params, this_save_path) # Pseudotime analysis if "volume_of_nucleus_um3" in all_ret.columns: @@ -82,6 +82,7 @@ def main(args): parser.add_argument( "--save_path", type=str, required=True, help="Path to save the embeddings." ) + parser.add_argument("--run_name", type=str, required=True, help="Name of model") parser.add_argument( "--embeddings_path", type=str, required=True, help="Path to the saved embeddings." 
) @@ -97,6 +98,14 @@ def main(args): main(args) """ - Example run: - python src/br/analysis/punctate_analysis.py --save_path "./testing/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --dataset_name "pcna" + Example runs for each dataset: + + cellpack dataset + python src/br/analysis/punctate_analysis.py --save_path "./outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" --dataset_name "cellpack" --run_name "Rotation_invariant_pointcloud_jitter" + + pcna dataset + python src/br/analysis/punctate_analysis.py --save_path "./outputs_pcna/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --dataset_name "pcna" --run_name "Rotation_invariant_pointcloud_jitter" + + Other punctate structures dataset: + python src/br/analysis/punctate_analysis.py --save_path "./outputs_other_punctate/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_punctate/" --dataset_name "other_punctate" --run_name "Rotation_invariant_pointcloud_structurenorm" """ From 4faae2b5f120d28559a2c17cd68c1ae9f6bda662 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Thu, 21 Nov 2024 17:12:43 -0800 Subject: [PATCH 14/35] add sdf analysis utils and merge script --- src/br/analysis/analysis_utils.py | 65 +++++++++++++++++++ .../{punctate_analysis.py => run_analysis.py} | 56 ++++++++++------ 2 files changed, 101 insertions(+), 20 deletions(-) rename src/br/analysis/{punctate_analysis.py => run_analysis.py} (73%) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index b4f80a3..c806794 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -6,8 +6,11 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd +import pyvista as pv import torch import yaml +from sklearn.decomposition import PCA +from tqdm import tqdm from br.features.plot import plot_pc_saved, plot_stratified_pc from br.features.reconstruction import save_pcloud @@ -256,6 +259,16 @@ def _dataset_specific_subsetting(all_ret, dataset_name): stratify_key = "structure_name" viz_params = {"z_max": None, "z_ind": 2, "flip": False, "structs": structs} n_archetypes = 7 + elif dataset_name == 'npm1': + stratify_key = 'STR_connectivity_cc_thresh' + n_archetypes = 5 + viz_params = None + elif dataset_name == 'other_polymorphic': + stratify_key = 'structure_name' + structs = ["NPM1", "FBL", "LAMP1", "ST6GAL1"] + all_ret = all_ret.loc[all_ret["structure_name"].isin(structs)] + n_archetypes = 4 + viz_params = None else: raise ValueError("Dataset not in pre-configured list") viz_params["views"] = ["xy"] @@ -485,3 +498,55 @@ def _pseudo_time_analysis(model, all_ret, save_path, device, key, viz_params, bi viz_params["xlim"], viz_params["ylim"], ) + + +def _latent_walk_polymorphic(stratify_key, all_ret, x_label, this_save_path, latent_dim): + lw_dict = {stratify_key: [], "PC": [], "bin": [], "CellId": []} + for strat in all_ret[stratify_key].unique(): + this_sub_m = all_ret.loc[all_ret[stratify_key] == strat].reset_index(drop=True) + all_features = this_sub_m[[i for i in this_sub_m.columns if "mu" in i]].values + dim_size = latent_dim + pca = PCA(n_components=dim_size) + pca_features = pca.fit_transform(all_features) + pca_std_list = pca_features.std(axis=0) + for rank in [0, 1]: + latent_walk_range = [-2, 0, 2] + for value_index, value in enumerate( + tqdm(latent_walk_range, total=len(latent_walk_range)) + ): + z_inf = 
torch.zeros(1, dim_size) + z_inf[:, rank] += value * pca_std_list[rank] + z_inf = pca.inverse_transform(z_inf).numpy() + + dist = (all_features - z_inf) ** 2 + dist = np.sum(dist, axis=1) + closest_idx = np.argmin(dist) + closest_real_id = this_sub_m.iloc[closest_idx]["CellId"] + mesh = pv.read( + all_ret.loc[all_ret["CellId"] == closest_real_id]["mesh_path"].iloc[0] + ) + mesh.save(this_save_path / Path(f"{strat}_{rank}_{value_index}.ply")) + + lw_dict[stratify_key].append(strat) + lw_dict["PC"].append(rank) + lw_dict["bin"].append(value_index) + lw_dict["CellId"].append(closest_real_id) + lw_dict = pd.DataFrame(lw_dict) + lw_dict.to_csv(this_save_path / "latent_walk.csv") + + +def _archetypes_polymorphic(this_save_path, archetypes_df, all_ret, all_features): + arch_dict = {"CellId": [], "archetype": []} + for i in range(len(archetypes_df)): + this_mu = archetypes_df.iloc[i].values + dist = (all_features - this_mu) ** 2 + dist = np.sum(dist, axis=1) + closest_idx = np.argmin(dist) + closest_real_id = all_ret.iloc[closest_idx]["CellId"] + print(dist, closest_real_id) + mesh = pv.read(all_ret.loc[all_ret["CellId"] == closest_real_id]["mesh_path"].iloc[0]) + mesh.save(this_save_path / Path(f"{i}.ply")) + arch_dict["archetype"].append(i) + arch_dict["CellId"].append(closest_real_id) + arch_dict = pd.DataFrame(arch_dict) + arch_dict.to_csv(this_save_path / "archetypes.csv") diff --git a/src/br/analysis/punctate_analysis.py b/src/br/analysis/run_analysis.py similarity index 73% rename from src/br/analysis/punctate_analysis.py rename to src/br/analysis/run_analysis.py index f640a35..aed3db1 100644 --- a/src/br/analysis/punctate_analysis.py +++ b/src/br/analysis/run_analysis.py @@ -6,8 +6,10 @@ import pandas as pd from br.analysis.analysis_utils import ( + _archetypes_polymorphic, _archetypes_save_recons, _dataset_specific_subsetting, + _latent_walk_polymorphic, _latent_walk_save_recons, _pseudo_time_analysis, _setup_gpu, @@ -40,27 +42,30 @@ def main(args): ) # Compute stratified latent walk - key = "pcloud" + key = "pcloud" # all analysis on pointcloud models this_save_path = Path(args.save_path) / Path("latent_walks") this_save_path.mkdir(parents=True, exist_ok=True) - stratified_latent_walk( - model, - device, - all_ret, - "pcloud", - 256, - 256, - 2, - this_save_path, - stratify_key, - latent_walk_range=[-2, 0, 2], - z_max=viz_params["z_max"], - z_ind=viz_params["z_ind"], - ) - - # Save reconstruction plots - _latent_walk_save_recons(this_save_path, stratify_key, viz_params, args.dataset_name) + if args.sdf: + _latent_walk_polymorphic(stratify_key, all_ret, x_label, this_save_path, latent_dim) + else: + stratified_latent_walk( + model, + device, + all_ret, + "pcloud", + latent_dim, + latent_dim, + 2, + this_save_path, + stratify_key, + latent_walk_range=[-2, 0, 2], + z_max=viz_params["z_max"], + z_ind=viz_params["z_ind"], + ) + + # Save reconstruction plots + _latent_walk_save_recons(this_save_path, stratify_key, viz_params, args.dataset_name) # Archetype analysis matrix = all_ret[[i for i in all_ret.columns if "mu" in i]].values @@ -70,7 +75,10 @@ def main(args): this_save_path = Path(args.save_path) / Path("archetypes") this_save_path.mkdir(parents=True, exist_ok=True) - _archetypes_save_recons(model, archetypes_df, device, key, viz_params, this_save_path) + if args.sdf: + _archetypes_polymorphic(this_save_path, archetypes_df, all_ret, matrix) + else: + _archetypes_save_recons(model, archetypes_df, device, key, viz_params, this_save_path) # Pseudotime analysis if "volume_of_nucleus_um3" in 
all_ret.columns: @@ -87,7 +95,12 @@ def main(args): "--embeddings_path", type=str, required=True, help="Path to the saved embeddings." ) parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") - + parser.add_argument( + "--sdf", + type=bool, + required=True, + help="boolean indicating whether the model involves SDFs", + ) args = parser.parse_args() # Validate that required paths are provided @@ -108,4 +121,7 @@ def main(args): Other punctate structures dataset: python src/br/analysis/punctate_analysis.py --save_path "./outputs_other_punctate/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_punctate/" --dataset_name "other_punctate" --run_name "Rotation_invariant_pointcloud_structurenorm" + + npm1 dataset: + python src/br/analysis/punctate_analysis.py --save_path "./testing_npm/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1/" --dataset_name "npm1" --run_name "Rotation_invariant_pointcloud_SDF" """ From 5e9e07b0785e921fcd06a0130c3c3b29b4dda603 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Thu, 21 Nov 2024 17:21:47 -0800 Subject: [PATCH 15/35] fix nonetype error --- src/br/analysis/analysis_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index c806794..93fe2ca 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -262,13 +262,13 @@ def _dataset_specific_subsetting(all_ret, dataset_name): elif dataset_name == 'npm1': stratify_key = 'STR_connectivity_cc_thresh' n_archetypes = 5 - viz_params = None + viz_params = {} elif dataset_name == 'other_polymorphic': stratify_key = 'structure_name' structs = ["NPM1", "FBL", "LAMP1", "ST6GAL1"] all_ret = all_ret.loc[all_ret["structure_name"].isin(structs)] n_archetypes = 4 - viz_params = None + viz_params = {} else: raise ValueError("Dataset not in pre-configured list") viz_params["views"] = ["xy"] From eceefa49f0ec029b3889926ebde862b5d985aebd Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Thu, 21 Nov 2024 17:33:33 -0800 Subject: [PATCH 16/35] working sdf analysis runs --- src/br/analysis/analysis_utils.py | 10 +++++----- src/br/analysis/run_analysis.py | 13 ++++++++----- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 93fe2ca..0e518ec 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -502,6 +502,7 @@ def _pseudo_time_analysis(model, all_ret, save_path, device, key, viz_params, bi def _latent_walk_polymorphic(stratify_key, all_ret, x_label, this_save_path, latent_dim): lw_dict = {stratify_key: [], "PC": [], "bin": [], "CellId": []} + mesh_folder = all_ret['mesh_folder'].iloc[0] # mesh folder for strat in all_ret[stratify_key].unique(): this_sub_m = all_ret.loc[all_ret[stratify_key] == strat].reset_index(drop=True) all_features = this_sub_m[[i for i in this_sub_m.columns if "mu" in i]].values @@ -516,15 +517,13 @@ def _latent_walk_polymorphic(stratify_key, all_ret, x_label, this_save_path, lat ): z_inf = torch.zeros(1, dim_size) z_inf[:, rank] += value * pca_std_list[rank] - z_inf = pca.inverse_transform(z_inf).numpy() + z_inf = pca.inverse_transform(z_inf) dist = (all_features - z_inf) ** 2 dist = np.sum(dist, axis=1) closest_idx = np.argmin(dist) closest_real_id = this_sub_m.iloc[closest_idx]["CellId"] - mesh = pv.read( - all_ret.loc[all_ret["CellId"] == 
closest_real_id]["mesh_path"].iloc[0] - ) + mesh = pv.read(mesh_folder + str(closest_real_id) + '.stl') mesh.save(this_save_path / Path(f"{strat}_{rank}_{value_index}.ply")) lw_dict[stratify_key].append(strat) @@ -537,6 +536,7 @@ def _latent_walk_polymorphic(stratify_key, all_ret, x_label, this_save_path, lat def _archetypes_polymorphic(this_save_path, archetypes_df, all_ret, all_features): arch_dict = {"CellId": [], "archetype": []} + mesh_folder = all_ret['mesh_folder'].iloc[0] # mesh folder for i in range(len(archetypes_df)): this_mu = archetypes_df.iloc[i].values dist = (all_features - this_mu) ** 2 @@ -544,7 +544,7 @@ def _archetypes_polymorphic(this_save_path, archetypes_df, all_ret, all_features closest_idx = np.argmin(dist) closest_real_id = all_ret.iloc[closest_idx]["CellId"] print(dist, closest_real_id) - mesh = pv.read(all_ret.loc[all_ret["CellId"] == closest_real_id]["mesh_path"].iloc[0]) + mesh = pv.read(mesh_folder + str(closest_real_id) + '.stl') mesh.save(this_save_path / Path(f"{i}.ply")) arch_dict["archetype"].append(i) arch_dict["CellId"].append(closest_real_id) diff --git a/src/br/analysis/run_analysis.py b/src/br/analysis/run_analysis.py index aed3db1..bedb8e8 100644 --- a/src/br/analysis/run_analysis.py +++ b/src/br/analysis/run_analysis.py @@ -114,14 +114,17 @@ def main(args): Example runs for each dataset: cellpack dataset - python src/br/analysis/punctate_analysis.py --save_path "./outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" --dataset_name "cellpack" --run_name "Rotation_invariant_pointcloud_jitter" + python src/br/analysis/run_analysis.py --save_path "./outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" --dataset_name "cellpack" --run_name "Rotation_invariant_pointcloud_jitter" pcna dataset - python src/br/analysis/punctate_analysis.py --save_path "./outputs_pcna/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --dataset_name "pcna" --run_name "Rotation_invariant_pointcloud_jitter" + python src/br/analysis/run_analysis.py --save_path "./outputs_pcna/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --dataset_name "pcna" --run_name "Rotation_invariant_pointcloud_jitter" - Other punctate structures dataset: - python src/br/analysis/punctate_analysis.py --save_path "./outputs_other_punctate/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_punctate/" --dataset_name "other_punctate" --run_name "Rotation_invariant_pointcloud_structurenorm" + other punctate structures dataset: + python src/br/analysis/run_analysis.py --save_path "./outputs_other_punctate/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_punctate/" --dataset_name "other_punctate" --run_name "Rotation_invariant_pointcloud_structurenorm" npm1 dataset: - python src/br/analysis/punctate_analysis.py --save_path "./testing_npm/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1/" --dataset_name "npm1" --run_name "Rotation_invariant_pointcloud_SDF" + python src/br/analysis/run_analysis.py --save_path "./outputs_npm1/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1/" --dataset_name "npm1" --run_name "Rotation_invariant_pointcloud_SDF" --sdf True + + other polymorphic dataset: + python src/br/analysis/run_analysis.py --save_path 
"./outputs_other_polymorphic/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_polymorphic/" --dataset_name "other_polymorphic" --run_name "Rotation_invariant_pointcloud_SDF" --sdf True """ From cfd4eb75a6dd788fb7855f2ce5d5ec81936183c3 Mon Sep 17 00:00:00 2001 From: Fatwir Sheikh Mohammed <81345858+fatwir@users.noreply.github.com> Date: Fri, 22 Nov 2024 01:31:07 -0800 Subject: [PATCH 17/35] Update data paths in the results config for other_polymorphic.yaml --- configs/results/other_polymorphic.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/results/other_polymorphic.yaml b/configs/results/other_polymorphic.yaml index 76b9455..cccc773 100644 --- a/configs/results/other_polymorphic.yaml +++ b/configs/results/other_polymorphic.yaml @@ -19,11 +19,11 @@ names: ] data_paths: [ - "/other_polymorphic/pc.yaml", - "/other_polymorphic/so3_image_sdf.yaml", - "/other_polymorphic/so3_image_seg.yaml", - "/other_polymorphic/classical_image_sdf.yaml", - "/other_polymorphic/classical_image_seg.yaml", + "/data/other_polymorphic/pc.yaml", + "/data/other_polymorphic/so3_image_sdf.yaml", + "/data/other_polymorphic/so3_image_seg.yaml", + "/data/other_polymorphic/classical_image_sdf.yaml", + "/data/other_polymorphic/classical_image_seg.yaml", ] classification_label: ["structure_name"] regression_label: ["avg_dists", "mean_volume", "mean_surface_area"] From 18945d274148e36bcd2ddf60f5c7e25eaa25f4c6 Mon Sep 17 00:00:00 2001 From: Fatwir Sheikh Mohammed <81345858+fatwir@users.noreply.github.com> Date: Fri, 22 Nov 2024 01:41:39 -0800 Subject: [PATCH 18/35] Update data paths in the results config for npm1_perturb.yaml --- configs/results/npm1_perturb.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/results/npm1_perturb.yaml b/configs/results/npm1_perturb.yaml index 5a5a5ea..492e61b 100644 --- a/configs/results/npm1_perturb.yaml +++ b/configs/results/npm1_perturb.yaml @@ -19,9 +19,9 @@ names: ] data_paths: [ - "./configs/data/npm1_perturb/pc.yaml", - "./configs/data/npm1_perturb/classical_image_sdf.yaml", - "./configs/data/npm1_perturb/classical_image_seg.yaml", - "./configs/data/npm1_perturb/so3_image_sdf.yaml", - "./configs/data/npm1_perturb/so3_image_seg.yaml", + "/data/npm1_perturb/pc.yaml", + "/data/npm1_perturb/classical_image_sdf.yaml", + "/data/npm1_perturb/classical_image_seg.yaml", + "/data/npm1_perturb/so3_image_sdf.yaml", + "/data/npm1_perturb/so3_image_seg.yaml", ] From a8b4bb86e6216fca05b11c3e37349fdfe61c104f Mon Sep 17 00:00:00 2001 From: Fatwir Sheikh Mohammed <81345858+fatwir@users.noreply.github.com> Date: Fri, 22 Nov 2024 01:45:24 -0800 Subject: [PATCH 19/35] Update data paths in the results config for other_punctate.yaml --- configs/results/other_punctate.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/results/other_punctate.yaml b/configs/results/other_punctate.yaml index 5b1772b..9c6ff8a 100644 --- a/configs/results/other_punctate.yaml +++ b/configs/results/other_punctate.yaml @@ -22,7 +22,7 @@ data_paths: "/data/other_punctate/image.yaml", "/data/other_punctate/image.yaml", "/data/other_punctate/pc.yaml", - "/data/other_punctate/pc_intensity.yaml", + "/data/other_punctate/pc_intensity_structurenorm.yaml", "/data/other_punctate/pc_intensity_structurenorm.yaml", ] classification_label: ["structure_name", "cell_stage"] From 06228bbd31dbef3de1f7077f38905e819e3d8d03 Mon Sep 17 00:00:00 2001 From: Fatwir Sheikh Mohammed 
<81345858+fatwir@users.noreply.github.com> Date: Fri, 22 Nov 2024 10:32:34 -0800 Subject: [PATCH 20/35] Update other_punctate.yaml Reverted this change! --- configs/results/other_punctate.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/results/other_punctate.yaml b/configs/results/other_punctate.yaml index 9c6ff8a..5b1772b 100644 --- a/configs/results/other_punctate.yaml +++ b/configs/results/other_punctate.yaml @@ -22,7 +22,7 @@ data_paths: "/data/other_punctate/image.yaml", "/data/other_punctate/image.yaml", "/data/other_punctate/pc.yaml", - "/data/other_punctate/pc_intensity_structurenorm.yaml", + "/data/other_punctate/pc_intensity.yaml", "/data/other_punctate/pc_intensity_structurenorm.yaml", ] classification_label: ["structure_name", "cell_stage"] From 11ac09d76253318bcbe2b3af746ae8dc5be78d68 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Fri, 22 Nov 2024 10:46:48 -0800 Subject: [PATCH 21/35] add str2bool check --- src/br/analysis/analysis_utils.py | 16 +++++++++++++--- src/br/analysis/run_analysis.py | 7 +++---- src/br/analysis/run_embeddings.py | 5 ++--- src/br/analysis/run_features.py | 5 ++--- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 0e518ec..26adfe2 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -2,7 +2,6 @@ import os import subprocess from pathlib import Path - import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -11,7 +10,7 @@ import yaml from sklearn.decomposition import PCA from tqdm import tqdm - +import argparse from br.features.plot import plot_pc_saved, plot_stratified_pc from br.features.reconstruction import save_pcloud from br.features.utils import ( @@ -21,6 +20,17 @@ from br.models.utils import get_all_configs_per_dataset +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + def get_gpu_info(): # Run nvidia-smi command and get the output cmd = [ @@ -500,7 +510,7 @@ def _pseudo_time_analysis(model, all_ret, save_path, device, key, viz_params, bi ) -def _latent_walk_polymorphic(stratify_key, all_ret, x_label, this_save_path, latent_dim): +def _latent_walk_polymorphic(stratify_key, all_ret, this_save_path, latent_dim): lw_dict = {stratify_key: [], "PC": [], "bin": [], "CellId": []} mesh_folder = all_ret['mesh_folder'].iloc[0] # mesh folder for strat in all_ret[stratify_key].unique(): diff --git a/src/br/analysis/run_analysis.py b/src/br/analysis/run_analysis.py index bedb8e8..a94d7b5 100644 --- a/src/br/analysis/run_analysis.py +++ b/src/br/analysis/run_analysis.py @@ -2,9 +2,7 @@ import os import sys from pathlib import Path - import pandas as pd - from br.analysis.analysis_utils import ( _archetypes_polymorphic, _archetypes_save_recons, @@ -13,6 +11,7 @@ _latent_walk_save_recons, _pseudo_time_analysis, _setup_gpu, + str2bool, ) from br.features.archetype import AA_Fast from br.features.reconstruction import stratified_latent_walk @@ -47,7 +46,7 @@ def main(args): this_save_path.mkdir(parents=True, exist_ok=True) if args.sdf: - _latent_walk_polymorphic(stratify_key, all_ret, x_label, this_save_path, latent_dim) + _latent_walk_polymorphic(stratify_key, all_ret, this_save_path, latent_dim) else: stratified_latent_walk( model, @@ -97,7 +96,7 @@ def main(args): 
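# [Editor's note: hedged illustration, not part of the original patch.] The hunk below
# switches the --sdf flag in run_analysis.py from type=bool to the str2bool helper added
# above. With argparse, type=bool converts any non-empty string (including "False") to
# True, so a bool-typed flag cannot actually be turned off from the command line;
# str2bool parses the text explicitly. A minimal usage sketch, assuming str2bool is
# imported from br.analysis.analysis_utils:
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--sdf", type=str2bool, required=True)
#     args = parser.parse_args(["--sdf", "False"])
#     assert args.sdf is False  # with type=bool this would have parsed as True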
parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") parser.add_argument( "--sdf", - type=bool, + type=str2bool, required=True, help="boolean indicating whether the model involves SDFs", ) diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py index 4b667a1..4cc5d5b 100644 --- a/src/br/analysis/run_embeddings.py +++ b/src/br/analysis/run_embeddings.py @@ -2,8 +2,7 @@ import argparse import os import sys - -from br.analysis.analysis_utils import _setup_evaluation_params, _setup_gpu +from br.analysis.analysis_utils import _setup_evaluation_params, _setup_gpu, str2bool from br.models.load_models import get_data_and_models from br.models.save_embeddings import save_embeddings @@ -63,7 +62,7 @@ def main(args): ) parser.add_argument( "--sdf", - type=bool, + type=str2bool, required=True, help="boolean indicating whether the experiments involve SDFs", ) diff --git a/src/br/analysis/run_features.py b/src/br/analysis/run_features.py index 69a17ac..9656c01 100644 --- a/src/br/analysis/run_features.py +++ b/src/br/analysis/run_features.py @@ -2,13 +2,12 @@ import argparse import os import sys - import pandas as pd - from br.analysis.analysis_utils import ( _get_feature_params, _setup_evaluation_params, _setup_gpu, + str2bool, ) from br.features.plot import collect_outputs, plot from br.models.compute_features import compute_features @@ -146,7 +145,7 @@ def main(args): ) parser.add_argument( "--sdf", - type=bool, + type=str2bool, required=True, help="boolean indicating whether the experiments involve SDFs", ) From ad6c964e9f7605be7d0b117f6d141c6d3662b7bc Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Fri, 22 Nov 2024 10:48:16 -0800 Subject: [PATCH 22/35] rename npm1 experiment --- .../{npm1_variance => npm1}/image_sdf_classical.yaml | 0 configs/experiment/{npm1_variance => npm1}/image_sdf_so3.yaml | 0 .../{npm1_variance => npm1}/image_seg_classical.yaml | 0 configs/experiment/{npm1_variance => npm1}/image_seg_so3.yaml | 0 configs/experiment/{npm1_variance => npm1}/pc_implicit.yaml | 0 configs/results/other_punctate.yaml | 3 --- 6 files changed, 3 deletions(-) rename configs/experiment/{npm1_variance => npm1}/image_sdf_classical.yaml (100%) rename configs/experiment/{npm1_variance => npm1}/image_sdf_so3.yaml (100%) rename configs/experiment/{npm1_variance => npm1}/image_seg_classical.yaml (100%) rename configs/experiment/{npm1_variance => npm1}/image_seg_so3.yaml (100%) rename configs/experiment/{npm1_variance => npm1}/pc_implicit.yaml (100%) diff --git a/configs/experiment/npm1_variance/image_sdf_classical.yaml b/configs/experiment/npm1/image_sdf_classical.yaml similarity index 100% rename from configs/experiment/npm1_variance/image_sdf_classical.yaml rename to configs/experiment/npm1/image_sdf_classical.yaml diff --git a/configs/experiment/npm1_variance/image_sdf_so3.yaml b/configs/experiment/npm1/image_sdf_so3.yaml similarity index 100% rename from configs/experiment/npm1_variance/image_sdf_so3.yaml rename to configs/experiment/npm1/image_sdf_so3.yaml diff --git a/configs/experiment/npm1_variance/image_seg_classical.yaml b/configs/experiment/npm1/image_seg_classical.yaml similarity index 100% rename from configs/experiment/npm1_variance/image_seg_classical.yaml rename to configs/experiment/npm1/image_seg_classical.yaml diff --git a/configs/experiment/npm1_variance/image_seg_so3.yaml b/configs/experiment/npm1/image_seg_so3.yaml similarity index 100% rename from configs/experiment/npm1_variance/image_seg_so3.yaml rename to 
configs/experiment/npm1/image_seg_so3.yaml diff --git a/configs/experiment/npm1_variance/pc_implicit.yaml b/configs/experiment/npm1/pc_implicit.yaml similarity index 100% rename from configs/experiment/npm1_variance/pc_implicit.yaml rename to configs/experiment/npm1/pc_implicit.yaml diff --git a/configs/results/other_punctate.yaml b/configs/results/other_punctate.yaml index 5b1772b..3ea855b 100644 --- a/configs/results/other_punctate.yaml +++ b/configs/results/other_punctate.yaml @@ -6,7 +6,6 @@ model_checkpoints: "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Classical_image.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Rotation_invariant_image.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Classical_pointcloud.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Rotation_invariant_pointcloud.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Rotation_invariant_pointcloud_structurenorm.ckpt", ] names: @@ -14,7 +13,6 @@ names: "Classical_image", "Rotation_invariant_image", "Classical_pointcloud", - "Rotation_invariant_pointcloud", "Rotation_invariant_pointcloud_structurenorm", ] data_paths: @@ -22,7 +20,6 @@ data_paths: "/data/other_punctate/image.yaml", "/data/other_punctate/image.yaml", "/data/other_punctate/pc.yaml", - "/data/other_punctate/pc_intensity.yaml", "/data/other_punctate/pc_intensity_structurenorm.yaml", ] classification_label: ["structure_name", "cell_stage"] From 83712278e2bb6e419c0050442e090f4adaa741f2 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Fri, 22 Nov 2024 11:42:44 -0800 Subject: [PATCH 23/35] fix cellpack analysis error (no s key) --- src/br/analysis/analysis_utils.py | 39 +++++++++++++++++-------------- src/br/analysis/run_analysis.py | 10 ++++---- src/br/analysis/run_embeddings.py | 3 ++- src/br/analysis/run_features.py | 4 +++- 4 files changed, 33 insertions(+), 23 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 26adfe2..cf78c99 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -1,7 +1,9 @@ +import argparse import gc import os import subprocess from pathlib import Path + import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -10,7 +12,7 @@ import yaml from sklearn.decomposition import PCA from tqdm import tqdm -import argparse + from br.features.plot import plot_pc_saved, plot_stratified_pc from br.features.reconstruction import save_pcloud from br.features.utils import ( @@ -23,12 +25,12 @@ def str2bool(v): if isinstance(v, bool): return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): + if v.lower() in ("yes", "true", "t", "y", "1"): return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): + elif v.lower() in ("no", "false", "f", "n", "0"): return False else: - raise argparse.ArgumentTypeError('Boolean value expected.') + raise argparse.ArgumentTypeError("Boolean value expected.") def get_gpu_info(): @@ -269,12 +271,12 @@ def _dataset_specific_subsetting(all_ret, dataset_name): stratify_key = "structure_name" viz_params = {"z_max": None, "z_ind": 2, "flip": False, "structs": structs} n_archetypes = 7 - elif dataset_name == 'npm1': - stratify_key = 'STR_connectivity_cc_thresh' + elif dataset_name == "npm1": + stratify_key = "STR_connectivity_cc_thresh" n_archetypes = 5 viz_params = {} - elif dataset_name == 'other_polymorphic': - 
stratify_key = 'structure_name' + elif dataset_name == "other_polymorphic": + stratify_key = "structure_name" structs = ["NPM1", "FBL", "LAMP1", "ST6GAL1"] all_ret = all_ret.loc[all_ret["structure_name"].isin(structs)] n_archetypes = 4 @@ -390,11 +392,14 @@ def _latent_walk_save_recons(this_save_path, stratify_key, viz_params, dataset_n fname = fnames[idx] df = pd.read_csv(f"{this_save_path}/{fname}", index_col=0) this_name = names[idx] - df = normalize_intensities_and_get_colormap_apply(df, vmin, vmax) - np_arr = df[["x", "y", "z"]].values - colors = cmap(df["inorm"].values)[:, :3] - np_arr2 = colors - np_arr = np.concatenate([np_arr, np_arr2], axis=1) + if "s" in df.columns: + df = normalize_intensities_and_get_colormap_apply(df, vmin, vmax) + np_arr = df[["x", "y", "z"]].values + colors = cmap(df["inorm"].values)[:, :3] + np_arr2 = colors + np_arr = np.concatenate([np_arr, np_arr2], axis=1) + else: + np_arr = df[["x", "y", "z"]].values np.save(this_save_path / Path(f"{this_name}.npy"), np_arr) @@ -512,7 +517,7 @@ def _pseudo_time_analysis(model, all_ret, save_path, device, key, viz_params, bi def _latent_walk_polymorphic(stratify_key, all_ret, this_save_path, latent_dim): lw_dict = {stratify_key: [], "PC": [], "bin": [], "CellId": []} - mesh_folder = all_ret['mesh_folder'].iloc[0] # mesh folder + mesh_folder = all_ret["mesh_folder"].iloc[0] # mesh folder for strat in all_ret[stratify_key].unique(): this_sub_m = all_ret.loc[all_ret[stratify_key] == strat].reset_index(drop=True) all_features = this_sub_m[[i for i in this_sub_m.columns if "mu" in i]].values @@ -533,7 +538,7 @@ def _latent_walk_polymorphic(stratify_key, all_ret, this_save_path, latent_dim): dist = np.sum(dist, axis=1) closest_idx = np.argmin(dist) closest_real_id = this_sub_m.iloc[closest_idx]["CellId"] - mesh = pv.read(mesh_folder + str(closest_real_id) + '.stl') + mesh = pv.read(mesh_folder + str(closest_real_id) + ".stl") mesh.save(this_save_path / Path(f"{strat}_{rank}_{value_index}.ply")) lw_dict[stratify_key].append(strat) @@ -546,7 +551,7 @@ def _latent_walk_polymorphic(stratify_key, all_ret, this_save_path, latent_dim): def _archetypes_polymorphic(this_save_path, archetypes_df, all_ret, all_features): arch_dict = {"CellId": [], "archetype": []} - mesh_folder = all_ret['mesh_folder'].iloc[0] # mesh folder + mesh_folder = all_ret["mesh_folder"].iloc[0] # mesh folder for i in range(len(archetypes_df)): this_mu = archetypes_df.iloc[i].values dist = (all_features - this_mu) ** 2 @@ -554,7 +559,7 @@ def _archetypes_polymorphic(this_save_path, archetypes_df, all_ret, all_features closest_idx = np.argmin(dist) closest_real_id = all_ret.iloc[closest_idx]["CellId"] print(dist, closest_real_id) - mesh = pv.read(mesh_folder + str(closest_real_id) + '.stl') + mesh = pv.read(mesh_folder + str(closest_real_id) + ".stl") mesh.save(this_save_path / Path(f"{i}.ply")) arch_dict["archetype"].append(i) arch_dict["CellId"].append(closest_real_id) diff --git a/src/br/analysis/run_analysis.py b/src/br/analysis/run_analysis.py index a94d7b5..9a4887a 100644 --- a/src/br/analysis/run_analysis.py +++ b/src/br/analysis/run_analysis.py @@ -2,7 +2,9 @@ import os import sys from pathlib import Path + import pandas as pd + from br.analysis.analysis_utils import ( _archetypes_polymorphic, _archetypes_save_recons, @@ -41,7 +43,7 @@ def main(args): ) # Compute stratified latent walk - key = "pcloud" # all analysis on pointcloud models + key = "pcloud" # all analysis on pointcloud models this_save_path = Path(args.save_path) / Path("latent_walks") 
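# [Editor's note: hedged sketch, not part of the original patch.] For SDF datasets this
# script routes the latent walk through _latent_walk_polymorphic (added earlier in this
# series), which walks +/-2 standard deviations along the top principal components of the
# "mu" embeddings and saves the mesh of the closest real cell instead of decoding point
# clouds. The core idea, assuming `features` is an (N, latent_dim) numpy array of
# embeddings for one stratum:
#
#     import numpy as np
#     from sklearn.decomposition import PCA
#
#     pca = PCA(n_components=latent_dim).fit(features)
#     stds = pca.transform(features).std(axis=0)
#     z = np.zeros((1, latent_dim))
#     z[:, 0] = 2 * stds[0]                       # walk +2 SD along the first PC
#     target = pca.inverse_transform(z)           # map back to the latent space
#     closest_idx = np.argmin(((features - target) ** 2).sum(axis=1))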
this_save_path.mkdir(parents=True, exist_ok=True) @@ -113,13 +115,13 @@ def main(args): Example runs for each dataset: cellpack dataset - python src/br/analysis/run_analysis.py --save_path "./outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" --dataset_name "cellpack" --run_name "Rotation_invariant_pointcloud_jitter" + python src/br/analysis/run_analysis.py --save_path "./outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" --dataset_name "cellpack" --run_name "Rotation_invariant_pointcloud_jitter" --sdf False pcna dataset - python src/br/analysis/run_analysis.py --save_path "./outputs_pcna/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --dataset_name "pcna" --run_name "Rotation_invariant_pointcloud_jitter" + python src/br/analysis/run_analysis.py --save_path "./outputs_pcna/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --dataset_name "pcna" --run_name "Rotation_invariant_pointcloud_jitter" --sdf False other punctate structures dataset: - python src/br/analysis/run_analysis.py --save_path "./outputs_other_punctate/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_punctate/" --dataset_name "other_punctate" --run_name "Rotation_invariant_pointcloud_structurenorm" + python src/br/analysis/run_analysis.py --save_path "./outputs_other_punctate/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/other_punctate/" --dataset_name "other_punctate" --run_name "Rotation_invariant_pointcloud_structurenorm" --sdf False npm1 dataset: python src/br/analysis/run_analysis.py --save_path "./outputs_npm1/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1/" --dataset_name "npm1" --run_name "Rotation_invariant_pointcloud_SDF" --sdf True diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py index 4cc5d5b..09d6571 100644 --- a/src/br/analysis/run_embeddings.py +++ b/src/br/analysis/run_embeddings.py @@ -2,6 +2,7 @@ import argparse import os import sys + from br.analysis.analysis_utils import _setup_evaluation_params, _setup_gpu, str2bool from br.models.load_models import get_data_and_models from br.models.save_embeddings import save_embeddings @@ -68,7 +69,7 @@ def main(args): ) parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") parser.add_argument("--batch_size", type=int, default=2, help="Batch size for processing.") - parser.add_argument("--debug", type=bool, default=True, help="Enable debug mode.") + parser.add_argument("--debug", type=str2bool, default=True, help="Enable debug mode.") args = parser.parse_args() diff --git a/src/br/analysis/run_features.py b/src/br/analysis/run_features.py index 9656c01..77b5585 100644 --- a/src/br/analysis/run_features.py +++ b/src/br/analysis/run_features.py @@ -2,7 +2,9 @@ import argparse import os import sys + import pandas as pd + from br.analysis.analysis_utils import ( _get_feature_params, _setup_evaluation_params, @@ -150,7 +152,7 @@ def main(args): help="boolean indicating whether the experiments involve SDFs", ) parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") - parser.add_argument("--debug", type=bool, default=False, help="Enable debug mode.") + parser.add_argument("--debug", type=str2bool, default=False, help="Enable debug 
mode.") args = parser.parse_args() From 3d50c0d20917eefa7690ed4b9073cff748e933dd Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Fri, 22 Nov 2024 12:43:40 -0800 Subject: [PATCH 24/35] remove notebooks and make script for drugdata analysis --- src/br/analysis/run_drugdata_analysis.py | 66 ++ src/br/chandrasekaran_et_al/utils.py | 196 ++++- src/br/notebooks/fig2_cellpack.ipynb | 510 ------------ src/br/notebooks/fig3_pcna.ipynb | 768 ------------------ src/br/notebooks/fig4_other_punctate.ipynb | 496 ----------- src/br/notebooks/fig5_npm1_analysis.ipynb | 585 ------------- src/br/notebooks/fig6_other_polymorphic.ipynb | 495 ----------- src/br/notebooks/fig7_drugdata_analysis.ipynb | 515 ------------ 8 files changed, 251 insertions(+), 3380 deletions(-) create mode 100644 src/br/analysis/run_drugdata_analysis.py delete mode 100644 src/br/notebooks/fig2_cellpack.ipynb delete mode 100644 src/br/notebooks/fig3_pcna.ipynb delete mode 100644 src/br/notebooks/fig4_other_punctate.ipynb delete mode 100644 src/br/notebooks/fig5_npm1_analysis.ipynb delete mode 100644 src/br/notebooks/fig6_other_polymorphic.ipynb delete mode 100644 src/br/notebooks/fig7_drugdata_analysis.ipynb diff --git a/src/br/analysis/run_drugdata_analysis.py b/src/br/analysis/run_drugdata_analysis.py new file mode 100644 index 0000000..d403d06 --- /dev/null +++ b/src/br/analysis/run_drugdata_analysis.py @@ -0,0 +1,66 @@ +import os +from pathlib import Path +from br.models.compute_features import get_embeddings +from br.models.utils import get_all_configs_per_dataset +from br.chandrasekaran_et_al.utils import perturbation_detection, _plot +import sys +import argparse + + +def _get_featurecols(df): + """returna list of featuredata columns""" + return [c for c in df.columns if "mu" in c] + + +def _get_featuredata(df): + """return dataframe of just featuredata columns""" + return df[_get_featurecols(df)] + + +def main(args): + + config_path = os.environ.get("CYTODL_CONFIG_PATH") + results_path = config_path + "/results/" + + dataset_name = args.dataset_name + DATASET_INFO = get_all_configs_per_dataset(results_path) + dataset = DATASET_INFO[dataset_name] + run_names = dataset['names'] + + all_ret, df = get_embeddings( + run_names, args.dataset_name, DATASET_INFO, args.embeddings_path + ) + all_ret["well_position"] = "A0" # dummy + all_ret["Assay_Plate_Barcode"] = "Plate0" # dummy + + pert = perturbation_detection(all_ret, _get_featurecols, _get_featuredata) + + this_save_path = Path(args.save_path) + this_save_path.mkdir(parents=True, exist_ok=True) + _plot(pert, this_save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Script for computing perturbation detection metrics") + parser.add_argument( + "--save_path", type=str, required=True, help="Path to save the results." + ) + parser.add_argument( + "--embeddings_path", type=str, required=True, help="Path to the saved embeddings." 
+ ) + parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") + args = parser.parse_args() + + # Validate that required paths are provided + if not args.save_path or not args.embeddings_path: + print("Error: Required arguments are missing.") + sys.exit(1) + + main(args) + + """ + Example run: + + npm1_perturb dataset + python src/br/analysis/run_drugdata_analysis.py --save_path "./outputs_npm1_perturb/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1_perturb/" --dataset_name "npm1_perturb" + """ \ No newline at end of file diff --git a/src/br/chandrasekaran_et_al/utils.py b/src/br/chandrasekaran_et_al/utils.py index 2864fb2..07a7e33 100644 --- a/src/br/chandrasekaran_et_al/utils.py +++ b/src/br/chandrasekaran_et_al/utils.py @@ -1,7 +1,8 @@ import glob import itertools import os - +from pathlib import Path +import seaborn as sns import copairs.compute_np as backend import numpy as np import pandas as pd @@ -16,6 +17,189 @@ from copairs.matching import dict_to_dframe from sklearn.metrics import average_precision_score from sklearn.metrics.pairwise import cosine_similarity +from tqdm import tqdm +import matplotlib.pyplot as plt +import pycytominer +from br.chandrasekaran_et_al import utils + + +def perturbation_detection(all_ret, get_featurecols, get_featuredata): + cols = get_featurecols(all_ret) + replicate_feature = "Metadata_broad_sample" + batch_size = 100000 + null_size = 100000 + + all_rep = [] + for model in tqdm(all_ret["model"].unique(), total=len(all_ret["model"].unique())): + df_feats = all_ret.loc[all_ret["model"] == model].reset_index(drop=True) + df_feats["Metadata_ObjectNumber"] = df_feats["CellId"] + + all_normalized_df = [] + for plate in df_feats["Assay_Plate_Barcode"].unique(): + test = df_feats.loc[df_feats["Assay_Plate_Barcode"] == plate].reset_index(drop=True) + + normalized_df = pycytominer.normalize( + profiles=test, + features=cols, + meta_features=[ + "Assay_Plate_Barcode", + "well_position", + "condition_coarse", + "condition", + ], + method="standardize", + mad_robustize_epsilon=0, + samples="all", + ) + normalized_df = pycytominer.normalize( + profiles=normalized_df, + features=cols, + meta_features=[ + "Assay_Plate_Barcode", + "well_position", + "condition_coarse", + "condition", + ], + method="standardize", + samples="condition == 'DMSO (control)'", + ) + + all_normalized_df.append(normalized_df) + df_final = pd.concat(all_normalized_df, axis=0).reset_index(drop=True) + + vals = [] + for ind, row in df_final.iterrows(): + if row["condition"] == "DMSO (control)": + vals.append("negcon") + else: + vals.append(None) + + # more dummy cols + df_final["Metadata_control_type"] = vals + df_final["Metadata_broad_sample"] = df_final["condition"] + df_final["Cell_type"] = "hIPSc" + df_final["Perturbation"] = "compound" + df_final["Time"] = "1" + df_final["Metadata_target_list"] = "none" + df_final["target_list"] = "none" + df_final["Metadata_Plate"] = "Plate0" + + experiment_df = df_final + + replicability_map_df = pd.DataFrame() + replicability_fr_df = pd.DataFrame() + + replicate_feature = "Metadata_broad_sample" + for cell in experiment_df.Cell_type.unique(): + cell_df = experiment_df.query("Cell_type==@cell") + modality_1_perturbation = "compound" + modality_1_experiments_df = cell_df.query("Perturbation==@modality_1_perturbation") + for modality_1_timepoint in modality_1_experiments_df.Time.unique(): + modality_1_timepoint_df = modality_1_experiments_df.query(
"Time==@modality_1_timepoint" + ) + modality_1_df = pd.DataFrame() + for plate in modality_1_timepoint_df.Assay_Plate_Barcode.unique(): + data_df = df_final.loc[df_final["Assay_Plate_Barcode"].isin([plate])] + data_df = data_df.drop( + columns=["Metadata_target_list", "target_list"] + ).reset_index(drop=True) + modality_1_df = utils.concat_profiles(modality_1_df, data_df) + + # Set Metadata_broad_sample value to "DMSO" for DMSO wells + modality_1_df[replicate_feature].fillna("DMSO", inplace=True) + + # Remove empty wells + modality_1_df = remove_empty_wells(modality_1_df) + + modality_1_df["Metadata_negcon"] = np.where( + modality_1_df["Metadata_control_type"] == "negcon", 1, 0 + ) # Create dummy column + + pos_sameby = ["Metadata_broad_sample"] + pos_diffby = [] + neg_sameby = ["Metadata_Plate"] + neg_diffby = ["Metadata_negcon"] + + metadata_df = get_metadata(modality_1_df) + feature_df = get_featuredata(modality_1_df) + feature_values = feature_df.values + + result = run_pipeline( + metadata_df, + feature_values, + pos_sameby, + pos_diffby, + neg_sameby, + neg_diffby, + anti_match=False, + batch_size=batch_size, + null_size=null_size, + ) + result = result.query("Metadata_negcon==0").reset_index(drop=True) + + qthreshold = 0.001 + + replicability_map_df, replicability_fr_df = create_replicability_df( + replicability_map_df, + replicability_fr_df, + result, + pos_sameby, + qthreshold, + modality_1_perturbation, + cell, + modality_1_timepoint, + ) + replicability_map_df["model"] = model + all_rep.append(replicability_map_df) + + all_rep = pd.concat(all_rep, axis=0).reset_index(drop=True) + all_rep["metric"] = "Mean average precision" + all_rep["value"] = all_rep["mean_average_precision"] + return all_rep + + +def _plot(all_rep, save_path): + sns.set_context("talk") + sns.set(font_scale=1.7) + sns.set_style("white") + + test = all_rep.sort_values(by="q_value").reset_index(drop=True) + test["Drugs"] = test["Metadata_broad_sample"] + + x_order = ( + test.loc[test["model"] == "SO3_pointcloud_SDF"] + .sort_values(by="q_value")["Metadata_broad_sample"] + .values + ) + ordered_drugs = all_rep.groupby(['Metadata_broad_sample']).mean().sort_values(by='q_value').reset_index()['Metadata_broad_sample'] + x_order = ordered_drugs + + g = sns.catplot( + data=test, + x="Drugs", + y="q_value", + hue="model", + kind="point", + order=x_order, + hue_order=[ + "Classical_image_seg", + "Rotation_invariant_image_seg", + "Classical_image_SDF", + "Rotation_invariant_image_SDF", + "Rotation_invariant_pointcloud_SDF", + ], + palette=["#A6ACE0", "#6277DB", "#D9978E", "#D8553B", "#2ED9FF"], + aspect=2, + height=5, + ) + g.set_xticklabels(rotation=90) + plt.axhline(y=0.05, color="black") + this_path = Path(save_path) + Path(this_path).mkdir(parents=True, exist_ok=True) + g.savefig(this_path / "q_values.png", dpi=300, bbox_inches="tight") + g.savefig(this_path / "q_values.pdf", dpi=300, bbox_inches="tight") + test.to_csv(this_path / "q_values.csv") def load_data(exp, plate, filetype): @@ -31,21 +215,11 @@ def get_metacols(df): return [c for c in df.columns if c.startswith("Metadata_")] -def get_featurecols(df): - """returna list of featuredata columns.""" - return [c for c in df.columns if not c.startswith("Metadata")] - - def get_metadata(df): """return dataframe of just metadata columns.""" return df[get_metacols(df)] -def get_featuredata(df): - """return dataframe of just featuredata columns.""" - return df[get_featurecols(df)] - - def remove_negcon_and_empty_wells(df): """return dataframe of non-negative control 
wells.""" df = ( diff --git a/src/br/notebooks/fig2_cellpack.ipynb b/src/br/notebooks/fig2_cellpack.ipynb deleted file mode 100644 index 090afd9..0000000 --- a/src/br/notebooks/fig2_cellpack.ipynb +++ /dev/null @@ -1,510 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0cee08e3-a83f-4d43-863e-4c7b897fa4b6", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "import os\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" # see issue #152\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-ffdee303-0dd4-513d-b18c-beba028b49c7\"\n", - "import os\n", - "from pathlib import Path\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "import yaml\n", - "from hydra.utils import instantiate\n", - "from PIL import Image\n", - "from torch.utils.data import DataLoader, Dataset\n", - "\n", - "from br.features.archetype import AA_Fast\n", - "from br.features.plot import collect_outputs, plot, plot_stratified_pc\n", - "from br.features.reconstruction import stratified_latent_walk\n", - "from br.features.utils import (\n", - " normalize_intensities_and_get_colormap,\n", - " normalize_intensities_and_get_colormap_apply,\n", - ")\n", - "from br.models.compute_features import compute_features, get_embeddings\n", - "from br.models.load_models import get_data_and_models\n", - "from br.models.save_embeddings import (\n", - " get_pc_loss,\n", - " get_pc_loss_chamfer,\n", - " save_embeddings,\n", - " save_emissions,\n", - ")\n", - "from br.models.utils import get_all_configs_per_dataset\n", - "\n", - "device = \"cuda:0\"" - ] - }, - { - "cell_type": "markdown", - "id": "167ec165-8536-48a7-8c8f-5833a9f66a87", - "metadata": {}, - "source": [ - "# Load data and models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "42598204-3b2f-4f35-88d1-f3513b4b8498", - "metadata": {}, - "outputs": [], - "source": [ - "# Set paths\n", - "os.chdir(\"../../benchmarking_representations/\")\n", - "save_path = \"./test_cellpack_save_embeddings/\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "44cf3c8d-5b5f-4014-82e9-d10637265302", - "metadata": {}, - "outputs": [], - "source": [ - "# Get datamodules, models, runs, model sizes\n", - "\n", - "dataset_name = \"cellpack\"\n", - "batch_size = 2\n", - "debug = True\n", - "results_path = \"./configs/results/\"\n", - "data_list, all_models, run_names, model_sizes = get_data_and_models(\n", - " dataset_name, batch_size, results_path, debug\n", - ")\n", - "gg = pd.DataFrame()\n", - "gg[\"model\"] = run_names\n", - "gg[\"model_size\"] = model_sizes\n", - "gg.to_csv(save_path + \"model_sizes.csv\")" - ] - }, - { - "cell_type": "markdown", - "id": "63b85b0e-8957-4786-a3c8-d8c617117acf", - "metadata": {}, - "source": [ - "# Compute embeddings and emissions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "21da78a3-3254-4fba-8118-6039c84c825b", - "metadata": {}, - "outputs": [], - "source": [ - "# Compute embeddings and reconstructions for each model\n", - "\n", - "debug = False\n", - "splits_list = [\"train\", \"val\", \"test\"]\n", - "meta_key = \"rule\"\n", - "eval_scaled_img = [False] * 5\n", - "eval_scaled_img_params = [{}] * 5\n", - "loss_eval_list = None\n", - "sample_points_list = [True, True, False, False, False]\n", - "skew_scale = 100\n", - "save_embeddings(\n", - " save_path,\n", - " data_list,\n", - " all_models,\n", - " run_names,\n", - " 
debug,\n", - " splits_list,\n", - " device,\n", - " meta_key,\n", - " loss_eval_list,\n", - " sample_points_list,\n", - " skew_scale,\n", - " eval_scaled_img,\n", - " eval_scaled_img_params,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6e31d85d-b33c-4ee0-ac00-2987cd7c5e40", - "metadata": {}, - "outputs": [], - "source": [ - "# Save emission stats for each model\n", - "\n", - "max_batches = 2\n", - "save_emissions(\n", - " save_path,\n", - " data_list,\n", - " all_models,\n", - " run_names,\n", - " max_batches,\n", - " debug,\n", - " device,\n", - " loss_eval_list,\n", - " sample_points_list,\n", - " skew_scale,\n", - " eval_scaled_img,\n", - " eval_scaled_img_params,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "85533af6-d04d-424e-9e6a-e35e4470f703", - "metadata": {}, - "source": [ - "# Compute benchmarking features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e3afc433-362e-4fe2-810e-799670f7bebb", - "metadata": {}, - "outputs": [], - "source": [ - "# Compute multi-metric benchmarking features\n", - "\n", - "keys = [\"pcloud\"] * 5\n", - "max_embed_dim = 256\n", - "DATA_LIST = get_all_configs_per_dataset(results_path)\n", - "data_config_list = DATA_LIST[dataset_name][\"data_paths\"]\n", - "\n", - "evolve_params = {\n", - " \"modality_list_evolve\": keys,\n", - " \"config_list_evolve\": data_config_list,\n", - " \"num_evolve_samples\": 40,\n", - " \"compute_evolve_dataloaders\": False,\n", - " \"eval_meshed_img\": [False] * 5,\n", - " \"skew_scale\": 100,\n", - " \"eval_meshed_img_model_type\": [None] * 5,\n", - " \"only_embedding\": False,\n", - " \"fit_pca\": False,\n", - "}\n", - "\n", - "loss_eval = get_pc_loss_chamfer()\n", - "loss_eval_list = [loss_eval] * 5\n", - "use_sample_points_list = [True, True, False, False, False]\n", - "\n", - "classification_params = {\"class_labels\": [\"rule\"]}\n", - "rot_inv_params = {\"squeeze_2d\": False, \"id\": \"cell_id\", \"max_batches\": 4000}\n", - "\n", - "regression_params = {\"df_feat\": None, \"target_cols\": None, \"feature_df_path\": None}\n", - "\n", - "compactness_params = {\n", - " \"method\": \"mle\",\n", - " \"num_PCs\": None,\n", - " \"blobby_outlier_max_cc\": None,\n", - " \"check_duplicates\": True,\n", - "}\n", - "\n", - "splits_list = [\"train\", \"val\", \"test\"]\n", - "compute_embeds = False\n", - "\n", - "metric_list = [\n", - " \"Rotation Invariance Error\",\n", - " \"Evolution Energy\",\n", - " \"Reconstruction\",\n", - " \"Classification\",\n", - " \"Compactness\",\n", - "]\n", - "\n", - "\n", - "compute_features(\n", - " dataset=dataset_name,\n", - " results_path=results_path,\n", - " embeddings_path=save_path,\n", - " save_folder=save_path,\n", - " data_list=data_list,\n", - " all_models=all_models,\n", - " run_names=run_names,\n", - " use_sample_points_list=use_sample_points_list,\n", - " keys=keys,\n", - " device=device,\n", - " max_embed_dim=max_embed_dim,\n", - " splits_list=splits_list,\n", - " compute_embeds=compute_embeds,\n", - " classification_params=classification_params,\n", - " regression_params=regression_params,\n", - " metric_list=metric_list,\n", - " loss_eval_list=loss_eval_list,\n", - " evolve_params=evolve_params,\n", - " rot_inv_params=rot_inv_params,\n", - " compactness_params=compactness_params,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "c2a1c6ce-0a8e-40b2-a798-4688321ac9c0", - "metadata": {}, - "source": [ - "# Polar plot viz" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": 
"628d547d-6556-48c3-b1fb-cc6488cf11ff", - "metadata": {}, - "outputs": [], - "source": [ - "# Holistic viz of features\n", - "\n", - "model_order = [\n", - " \"Classical_image\",\n", - " \"Rotation_invariant_image\",\n", - " \"Classical_pointcloud\",\n", - " \"Rotation_invariant_pointcloud\",\n", - "]\n", - "metric_list = [\n", - " \"reconstruction\",\n", - " \"emissions\",\n", - " \"classification_rule\",\n", - " \"compactness\",\n", - " \"evolution_energy\",\n", - " \"model_sizes\",\n", - " \"rotation_invariance_error\",\n", - "]\n", - "norm = \"std\"\n", - "title = \"cellpack_comparison\"\n", - "colors_list = None\n", - "unique_expressivity_metrics = [\"Classification_rule\"]\n", - "df, df_non_agg = collect_outputs(save_path, norm, model_order, metric_list)\n", - "plot(save_path, df, model_order, title, colors_list, norm, unique_expressivity_metrics)" - ] - }, - { - "cell_type": "markdown", - "id": "da73f075-509b-4e21-a9d4-f62bd0512a6a", - "metadata": {}, - "source": [ - "# Latent walks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ea974272-d0c3-4bef-af00-ea8f1ce72c9e", - "metadata": {}, - "outputs": [], - "source": [ - "# Load model and embeddings\n", - "\n", - "run_names = [\"Rotation_invariant_pointcloud_jitter\"]\n", - "DATASET_INFO = get_all_configs_per_dataset(results_path)\n", - "all_ret, df = get_embeddings(run_names, dataset_name, DATASET_INFO, save_path)\n", - "model = all_models[-1]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "279ce410-20b8-4bc9-836d-ecf5b6a20c30", - "metadata": {}, - "outputs": [], - "source": [ - "# Params for viz\n", - "key = \"pcloud\"\n", - "stratify_key = \"rule\"\n", - "z_max = 0.3\n", - "z_ind = 1\n", - "flip = True\n", - "views = [\"xy\"]\n", - "xlim = [-20, 20]\n", - "ylim = [-20, 20]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "58ea8dc2-71b3-4fc0-820e-a0aa8912e9f5", - "metadata": {}, - "outputs": [], - "source": [ - "# Compute stratified latent walk\n", - "\n", - "this_save_path = Path(save_path) / Path(\"latent_walks\")\n", - "this_save_path.mkdir(parents=True, exist_ok=True)\n", - "\n", - "stratified_latent_walk(\n", - " model,\n", - " device,\n", - " all_ret,\n", - " \"pcloud\",\n", - " 256,\n", - " 256,\n", - " 2,\n", - " this_save_path,\n", - " stratify_key,\n", - " latent_walk_range=[-2, 0, 2],\n", - " z_max=z_max,\n", - " z_ind=z_ind,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aecc42c1-80c5-4064-984f-2d40962d2937", - "metadata": {}, - "outputs": [], - "source": [ - "# Save reconstruction plots\n", - "items = os.listdir(this_save_path)\n", - "fnames = [i for i in items if i.split(\".\")[-1] == \"csv\"]\n", - "fnames = [i for i in fnames if i.split(\"_\")[1] == \"0\"]\n", - "names = [i.split(\".\")[0] for i in fnames]\n", - "cm_name = \"inferno\"\n", - "\n", - "all_df = []\n", - "for idx, _ in enumerate(fnames):\n", - " fname = fnames[idx]\n", - " df = pd.read_csv(f\"{this_save_path}/{fname}\", index_col=0)\n", - " df, cmap, vmin, vmax = normalize_intensities_and_get_colormap(\n", - " df, pcts=[5, 95], cm_name=cm_name\n", - " )\n", - " df[stratify_key] = names[idx]\n", - " all_df.append(df)\n", - "df = pd.concat(all_df, axis=0).reset_index(drop=True)\n", - "\n", - "plot_stratified_pc(df, xlim, ylim, stratify_key, this_save_path, cmap, flip)" - ] - }, - { - "cell_type": "markdown", - "id": "02819e65-159b-489f-bcd4-9a47bac6b0f2", - "metadata": {}, - "source": [ - "# Archetype analysis" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "c1ac785d-0b57-4337-8af2-01b57df36b12", - "metadata": {}, - "outputs": [], - "source": [ - "# Fit 6 archetypes\n", - "this_ret = all_ret\n", - "labels = this_ret[\"rule\"].values\n", - "matrix = this_ret[[i for i in this_ret.columns if \"mu\" in i]].values\n", - "\n", - "n_archetypes = 6\n", - "aa = AA_Fast(n_archetypes, max_iter=1000, tol=1e-6).fit(matrix)\n", - "archetypes_df = pd.DataFrame(aa.Z, columns=[f\"mu_{i}\" for i in range(matrix.shape[1])])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3461c46c-11d2-444c-b380-7d0c04636fb0", - "metadata": {}, - "outputs": [], - "source": [ - "# Save reconstructions\n", - "\n", - "this_save_path = Path(save_path) / Path(\"archetypes\")\n", - "this_save_path.mkdir(parents=True, exist_ok=True)\n", - "\n", - "model = model.eval()\n", - "key = \"pcloud\"\n", - "all_xhat = []\n", - "with torch.no_grad():\n", - " for i in range(n_archetypes):\n", - " z_inf = torch.tensor(archetypes_df.iloc[i].values).unsqueeze(axis=0)\n", - " z_inf = z_inf.to(device)\n", - " z_inf = z_inf.float()\n", - " decoder = model.decoder[key]\n", - " xhat = decoder(z_inf)\n", - " xhat = xhat.detach().cpu().numpy()\n", - " xhat = save_pcloud(xhat[0], this_save_path, i, z_max, z_ind)\n", - " all_xhat.append(xhat)\n", - "\n", - "\n", - "from br.features.plot import plot_pc_saved\n", - "\n", - "names = [str(i) for i in range(n_archetypes)]\n", - "key = \"archetype\"\n", - "\n", - "plot_pc_saved(this_save_path, names, key, flip, 0.5, views, xlim, ylim)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33b487d9-235a-4da3-aa4b-8263c7c5e03a", - "metadata": {}, - "outputs": [], - "source": [ - "# Save numpy arrays\n", - "\n", - "key = \"archetype\"\n", - "items = os.listdir(this_save_path)\n", - "fnames = [i for i in items if i.split(\".\")[-1] == \"csv\"]\n", - "names = [i.split(\".\")[0] for i in fnames]\n", - "\n", - "df = pd.DataFrame([])\n", - "for idx, _ in enumerate(fnames):\n", - " fname = fnames[idx]\n", - " print(fname)\n", - " dft = pd.read_csv(f\"{this_save_path}/{fname}\", index_col=0)\n", - " dft[key] = names[idx]\n", - " df = pd.concat([df, dft], ignore_index=True)\n", - "\n", - "archetypes = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\"]\n", - "\n", - "for arch in archetypes:\n", - " this_df = df.loc[df[\"archetype\"] == arch].reset_index(drop=True)\n", - " np_arr = this_df[[\"x\", \"y\", \"z\"]].values\n", - " print(np_arr.shape)\n", - " np.save(this_save_path / Path(f\"{arch}.npy\"), np_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0951d6ad-4020-48a7-83ab-a81deaa01170", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3be8744e-e745-4717-9f41-754fd88ba98b", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/src/br/notebooks/fig3_pcna.ipynb b/src/br/notebooks/fig3_pcna.ipynb deleted file mode 100644 index ff79750..0000000 --- a/src/br/notebooks/fig3_pcna.ipynb +++ /dev/null @@ -1,768 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 
null, - "id": "d4f45efa-014c-4f41-a5eb-e778724f3bff", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "import os\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" # see issue #152\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-25a8cdbf-56c0-521b-b855-e8cd1f848fa1\"\n", - "import os\n", - "from pathlib import Path\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "import yaml\n", - "from hydra.utils import instantiate\n", - "from PIL import Image\n", - "from torch.utils.data import DataLoader, Dataset\n", - "\n", - "from br.features.archetype import AA_Fast\n", - "from br.features.plot import collect_outputs, plot, plot_stratified_pc\n", - "from br.features.reconstruction import stratified_latent_walk\n", - "from br.features.utils import (\n", - " normalize_intensities_and_get_colormap,\n", - " normalize_intensities_and_get_colormap_apply,\n", - ")\n", - "from br.models.compute_features import compute_features, get_embeddings\n", - "from br.models.load_models import get_data_and_models\n", - "from br.models.save_embeddings import (\n", - " get_pc_loss,\n", - " get_pc_loss_chamfer,\n", - " save_embeddings,\n", - " save_emissions,\n", - ")\n", - "from br.models.utils import get_all_configs_per_dataset\n", - "\n", - "device = \"cuda:0\"" - ] - }, - { - "cell_type": "markdown", - "id": "bd85876c-ceac-4bea-8f36-31e3fdbeaa7e", - "metadata": {}, - "source": [ - "# Load data and models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a68cf046-21e5-4fda-8817-1c94dab23028", - "metadata": {}, - "outputs": [], - "source": [ - "# Set paths\n", - "os.chdir(\"../../benchmarking_representations/\")\n", - "save_path = \"./test_pcna_embeddings/\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bc7e6aa3-9cd0-49ec-9bc5-ef9cf8a1930a", - "metadata": {}, - "outputs": [], - "source": [ - "# Get datamodules, models, runs, model sizes\n", - "\n", - "dataset_name = \"pcna\"\n", - "batch_size = 2\n", - "debug = False\n", - "results_path = \"./configs/results/\"\n", - "data_list, all_models, run_names, model_sizes = get_data_and_models(\n", - " dataset_name, batch_size, results_path, debug\n", - ")\n", - "\n", - "gg = pd.DataFrame()\n", - "gg[\"model\"] = run_names\n", - "gg[\"model_size\"] = model_sizes\n", - "gg.to_csv(save_path + \"model_sizes.csv\")" - ] - }, - { - "cell_type": "markdown", - "id": "c3293140-65a8-42ef-9c32-bf1ee2b28f40", - "metadata": {}, - "source": [ - "# Compute embeddings and emissions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e26890a8-02ff-46e7-b4b1-696ad9f3e17e", - "metadata": {}, - "outputs": [], - "source": [ - "# Compute embeddings and reconstructions for each model\n", - "\n", - "splits_list = [\"train\", \"val\", \"test\"]\n", - "meta_key = None\n", - "eval_scaled_img = [False] * 5\n", - "eval_scaled_img_params = [{}] * 5\n", - "loss_eval_list = None\n", - "sample_points_list = [False, False, True, True, False]\n", - "skew_scale = 100\n", - "save_embeddings(\n", - " save_path,\n", - " data_list,\n", - " all_models,\n", - " run_names,\n", - " debug,\n", - " splits_list,\n", - " device,\n", - " meta_key,\n", - " loss_eval_list,\n", - " sample_points_list,\n", - " skew_scale,\n", - " eval_scaled_img,\n", - " eval_scaled_img_params,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "220d6dae-3012-49b9-9ca2-4b88b66c13b9", - "metadata": 
{}, - "outputs": [], - "source": [ - "# Save emission stats for each model\n", - "\n", - "max_batches = 2\n", - "save_emissions(\n", - " save_path,\n", - " data_list,\n", - " all_models,\n", - " run_names,\n", - " max_batches,\n", - " debug,\n", - " device,\n", - " loss_eval_list,\n", - " sample_points_list,\n", - " skew_scale,\n", - " eval_scaled_img,\n", - " eval_scaled_img_params,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "5d2d7064-24f6-4d44-9de1-a54c410bcd2e", - "metadata": {}, - "source": [ - "# Compute benchmarking features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7947cc37-8de4-4351-9e21-89a888d6260b", - "metadata": {}, - "outputs": [], - "source": [ - "# Compute multi-metric benchmarking features\n", - "\n", - "keys = [\"pcloud\", \"pcloud\", \"image\", \"image\", \"pcloud\"]\n", - "max_embed_dim = 256\n", - "DATA_LIST = get_all_configs_per_dataset(results_path)\n", - "data_config_list = DATA_LIST[dataset_name][\"data_paths\"]\n", - "\n", - "evolve_params = {\n", - " \"modality_list_evolve\": keys,\n", - " \"config_list_evolve\": data_config_list,\n", - " \"num_evolve_samples\": 40,\n", - " \"compute_evolve_dataloaders\": False,\n", - " \"eval_meshed_img\": [False] * 5,\n", - " \"skew_scale\": 100,\n", - " \"eval_meshed_img_model_type\": [None] * 5,\n", - " \"only_embedding\": False,\n", - " \"fit_pca\": False,\n", - "}\n", - "\n", - "loss_eval = get_pc_loss_chamfer()\n", - "loss_eval_list = [loss_eval] * 5\n", - "use_sample_points_list = [False, False, True, True, False]\n", - "\n", - "classification_params = {\"class_labels\": [\"cell_stage_fine\", \"flag_comment\"]}\n", - "rot_inv_params = {\"squeeze_2d\": False, \"id\": \"cell_id\", \"max_batches\": 4000}\n", - "\n", - "regression_params = {\"df_feat\": None, \"target_cols\": None, \"feature_df_path\": None}\n", - "\n", - "compactness_params = {\n", - " \"method\": \"mle\",\n", - " \"num_PCs\": None,\n", - " \"blobby_outlier_max_cc\": None,\n", - " \"check_duplicates\": True,\n", - "}\n", - "\n", - "splits_list = [\"train\", \"val\", \"test\"]\n", - "compute_embeds = False\n", - "\n", - "metric_list = [\n", - " # \"Rotation Invariance Error\",\n", - " # \"Evolution Energy\",\n", - " # \"Reconstruction\",\n", - " \"Classification\",\n", - " # \"Compactness\",\n", - "]\n", - "\n", - "\n", - "compute_features(\n", - " dataset=dataset_name,\n", - " results_path=results_path,\n", - " embeddings_path=save_path,\n", - " save_folder=save_path,\n", - " data_list=data_list,\n", - " all_models=all_models,\n", - " run_names=run_names,\n", - " use_sample_points_list=use_sample_points_list,\n", - " keys=keys,\n", - " device=device,\n", - " max_embed_dim=max_embed_dim,\n", - " splits_list=splits_list,\n", - " compute_embeds=compute_embeds,\n", - " classification_params=classification_params,\n", - " regression_params=regression_params,\n", - " metric_list=metric_list,\n", - " loss_eval_list=loss_eval_list,\n", - " evolve_params=evolve_params,\n", - " rot_inv_params=rot_inv_params,\n", - " compactness_params=compactness_params,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "f731c9e7-d99c-4895-a1ca-1f9f5c046917", - "metadata": {}, - "source": [ - "# Polar plot viz" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5628c651-4748-405a-9918-46646ecedb91", - "metadata": {}, - "outputs": [], - "source": [ - "# Holistic viz of features\n", - "model_order = [\n", - " \"Classical_image\",\n", - " \"Rotation_invariant_image\",\n", - " \"Classical_pointcloud\",\n", - " 
\"Rotation_invariant_pointcloud\",\n", - "]\n", - "metric_list = [\n", - " \"reconstruction\",\n", - " \"emissions\",\n", - " \"classification_cell_stage_fine\",\n", - " \"classification_flag_comment\",\n", - " \"compactness\",\n", - " \"evolution_energy\",\n", - " \"model_sizes\",\n", - " \"rotation_invariance_error\",\n", - "]\n", - "norm = \"std\"\n", - "title = \"pcna_comparison\"\n", - "colors_list = None\n", - "unique_expressivity_metrics = [\"Classification_cell_stage_fine\", \"Classification_flag_comment\"]\n", - "df, df_non_agg = collect_outputs(save_path, norm, model_order, metric_list)\n", - "plot(save_path, df, model_order, title, colors_list, norm, unique_expressivity_metrics)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c835cb34-d339-4cc6-8b33-b459c37b5377", - "metadata": {}, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "import seaborn as sns\n", - "\n", - "sns.set(font_scale=5)\n", - "sns.set_style(\"white\")\n", - "g = sns.catplot(\n", - " data=df_non_agg,\n", - " x=\"model\",\n", - " y=\"value\",\n", - " col=\"variable\",\n", - " kind=\"bar\",\n", - " sharey=False,\n", - " sharex=True,\n", - " order=model_order,\n", - " col_wrap=5,\n", - " height=20,\n", - " aspect=1,\n", - ")\n", - "g.set_xticklabels(rotation=30)" - ] - }, - { - "cell_type": "markdown", - "id": "fe6f15f6-ade1-4155-ad87-7ff5bdc87bcd", - "metadata": {}, - "source": [ - "# Latent walks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cdedc788-d80a-4b55-bfcf-4b66f78f9fd5", - "metadata": {}, - "outputs": [], - "source": [ - "# Load model and embeddings\n", - "run_names = [\"Rotation_invariant_pointcloud_jitter\"]\n", - "DATASET_INFO = get_all_configs_per_dataset(results_path)\n", - "all_ret, df = get_embeddings(run_names, dataset_name, DATASET_INFO, save_path)\n", - "model = all_models[-1]\n", - "# Subset to interphase stages\n", - "interphase_stages = [\n", - " \"G1\",\n", - " \"earlyS\",\n", - " \"earlyS-midS\",\n", - " \"midS\",\n", - " \"midS-lateS\",\n", - " \"lateS\",\n", - " \"lateS-G2\",\n", - " \"G2\",\n", - "]\n", - "all_ret = all_ret.loc[all_ret[\"cell_stage_fine\"].isin(interphase_stages)].reset_index(drop=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16e8e0c6-1aa7-40b0-9f60-5942b919b9d7", - "metadata": {}, - "outputs": [], - "source": [ - "# Params for viz\n", - "key = \"pcloud\"\n", - "stratify_key = \"cell_stage_fine\"\n", - "z_max = 0.3\n", - "z_ind = 2\n", - "flip = False\n", - "views = [\"xy\"]\n", - "xlim = [-20, 20]\n", - "ylim = [-20, 20]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3b203254-8c02-416e-8103-d0d4e6d25db3", - "metadata": {}, - "outputs": [], - "source": [ - "# Compute stratified latent walk\n", - "\n", - "this_save_path = Path(save_path) / Path(\"latent_walks\")\n", - "this_save_path.mkdir(parents=True, exist_ok=True)\n", - "\n", - "stratified_latent_walk(\n", - " model,\n", - " device,\n", - " all_ret,\n", - " \"pcloud\",\n", - " 256,\n", - " 256,\n", - " 2,\n", - " this_save_path,\n", - " stratify_key,\n", - " latent_walk_range=[-2, 0, 2],\n", - " z_max=z_max,\n", - " z_ind=z_ind,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ea49968c-e6f7-4bc1-aad4-6d6594484341", - "metadata": {}, - "outputs": [], - "source": [ - "# Save reconstruction plots\n", - "\n", - "import os\n", - "\n", - "items = os.listdir(this_save_path)\n", - "fnames = [i for i in items if i.split(\".\")[-1] == \"csv\"]\n", - "fnames = [i for i in 
fnames if i.split(\"_\")[1] == \"0\"]\n", - "fnames = [i for i in fnames if i.split(\"_\")[0] in interphase_stages]\n", - "names = [i.split(\".\")[0] for i in fnames]\n", - "\n", - "all_df = []\n", - "for idx, _ in enumerate(fnames):\n", - " fname = fnames[idx]\n", - " df = pd.read_csv(f\"{this_save_path}/{fname}\", index_col=0)\n", - " # normalize per PC\n", - " df, cmap, vmin, vmax = normalize_intensities_and_get_colormap(\n", - " df, pcts=[5, 95], cm_name=\"YlGnBu\"\n", - " )\n", - " df[stratify_key] = names[idx]\n", - " all_df.append(df)\n", - "df = pd.concat(all_df, axis=0).reset_index(drop=True)\n", - "plot_stratified_pc(df, xlim, ylim, stratify_key, this_save_path, cmap, flip)\n", - "\n", - "# normalize across all PCs\n", - "df, cmap, vmin, vmax = normalize_intensities_and_get_colormap(df, pcts=[5, 95], cm_name=\"YlGnBu\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "abd2fff3-ab83-4780-b7a2-c983bc773922", - "metadata": {}, - "outputs": [], - "source": [ - "vmax" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b7c29c22-532c-47be-937a-91c7889de70a", - "metadata": {}, - "outputs": [], - "source": [ - "# save contrast adjusted reconstruction plots\n", - "\n", - "use_vmin = vmin\n", - "use_vmax = vmax\n", - "\n", - "for idx, _ in enumerate(fnames):\n", - " fname = fnames[idx]\n", - " df = pd.read_csv(f\"{this_save_path}/{fname}\", index_col=0)\n", - " df[key] = names[idx]\n", - " this_name = names[idx]\n", - " df = normalize_intensities_and_get_colormap_apply(df, use_vmin, use_vmax)\n", - " np_arr = df[[\"x\", \"y\", \"z\"]].values\n", - " colors = cmap(df[\"inorm\"].values)[:, :3]\n", - " np_arr2 = colors\n", - " np_arr = np.concatenate([np_arr, np_arr2], axis=1)\n", - " np.save(this_save_path / Path(f\"{this_name}.npy\"), np_arr)" - ] - }, - { - "cell_type": "markdown", - "id": "c4628a65-a0d2-40df-bf88-49ca1eecfc0b", - "metadata": {}, - "source": [ - "# Pseudo time" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a07a92d2-0b31-43b2-a603-7449f5f3c742", - "metadata": {}, - "outputs": [], - "source": [ - "# Compute pseudo time bins\n", - "\n", - "bins = [\n", - " (247.407, 390.752),\n", - " (390.752, 533.383),\n", - " (533.383, 676.015),\n", - " (676.015, 818.646),\n", - " (818.646, 961.277),\n", - "]\n", - "correct_bins = []\n", - "for ind, row in all_ret.iterrows():\n", - " this_bin = []\n", - " for bin_ in bins:\n", - " if (row[\"volume_of_nucleus_um3\"] > bin_[0]) and (row[\"volume_of_nucleus_um3\"] <= bin_[1]):\n", - " this_bin.append(bin_)\n", - " if row[\"volume_of_nucleus_um3\"] < bins[0][0]:\n", - " this_bin.append(bin_)\n", - " if row[\"volume_of_nucleus_um3\"] > bins[4][1]:\n", - " this_bin.append(bin_)\n", - " assert len(this_bin) == 1\n", - " correct_bins.append(this_bin[0])\n", - "all_ret[\"vol_bins\"] = correct_bins\n", - "import pandas as pd\n", - "\n", - "all_ret[\"vol_bins_inds\"] = pd.factorize(all_ret[\"vol_bins\"])[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "637f5739-bf2e-412b-a01d-7ce3f06f55b2", - "metadata": {}, - "outputs": [], - "source": [ - "all_ret = all_ret.groupby([\"vol_bins\"]).sample(n=75).reset_index(drop=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7f884eae-707f-49aa-8e33-c3d1884bb141", - "metadata": {}, - "outputs": [], - "source": [ - "all_ret[\"cell_stage_fine\"].value_counts()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5a725cc4-3fa1-4ea8-9887-b918236d06e9", - "metadata": {}, - 
"outputs": [], - "source": [ - "z_max = 0.2\n", - "z_ind = 2\n", - "use_vmin = 5.03\n", - "use_vmax = 10" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4cedf95e-2a1c-4e08-b42c-f94408271dfb", - "metadata": {}, - "outputs": [], - "source": [ - "# Save reconstructions per bin\n", - "\n", - "this_save_path = Path(save_path) / Path(\"pseudo_time_2\")\n", - "this_save_path.mkdir(parents=True, exist_ok=True)\n", - "\n", - "cols = [i for i in all_ret.columns if \"mu\" in i]\n", - "for ind, gr in all_ret.groupby([\"vol_bins\"]):\n", - " this_stage_df = gr.reset_index(drop=True)\n", - " this_stage_mu = this_stage_df[cols].values\n", - " mean_mu = this_stage_mu.mean(axis=0)\n", - " dist = (this_stage_mu - mean_mu) ** 2\n", - " dist = np.sum(dist, axis=1)\n", - " closest_idx = np.argmin(dist)\n", - " real_input = this_stage_df.iloc[closest_idx][\"CellId\"]\n", - "\n", - " z_inf = torch.tensor(mean_mu).unsqueeze(axis=0)\n", - " z_inf = z_inf.to(device)\n", - " z_inf = z_inf.float()\n", - "\n", - " decoder = model.decoder[\"pcloud\"]\n", - " xhat = decoder(z_inf)\n", - " xhat = save_pcloud(xhat[0], this_save_path, str(ind), z_max, z_ind)\n", - "\n", - "\n", - "names = os.listdir(this_save_path)\n", - "names = [i for i in names if i.split(\".\")[-1] == \"csv\"]\n", - "names = [i.split(\".csv\")[0] for i in names]\n", - "plot_pc_saved(this_save_path, names, key, flip, 0.5, views, xlim, ylim)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28a5e885-a992-4cb5-9f5a-602cbb02e84a", - "metadata": {}, - "outputs": [], - "source": [ - "this_save_path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5c4fa48d-f843-445c-9835-3cd52cf1416f", - "metadata": {}, - "outputs": [], - "source": [ - "# Save contrast adjusted recons\n", - "\n", - "items = os.listdir(this_save_path)\n", - "items = [this_save_path / Path(i) for i in items if i.split(\".\")[-1] == \"csv\"]\n", - "\n", - "all_df = []\n", - "for j, i in enumerate(items):\n", - " df = pd.read_csv(i)\n", - " df[\"cluster\"] = str(i).split(\"/\")[-1][:-4]\n", - " df = df.loc[df[\"z\"] < 0.4]\n", - " df = df.loc[df[\"z\"] > -0.4].reset_index(drop=True)\n", - " all_df.append(df)\n", - "df = pd.concat(all_df, axis=0).reset_index(drop=True)\n", - "\n", - "for clust in df[\"cluster\"].unique():\n", - " df_2 = df.loc[df[\"cluster\"] == clust].reset_index(drop=True)\n", - " df_2 = normalize_intensities_and_get_colormap_apply(df_2, vmin=use_vmin, vmax=use_vmax)\n", - " colors = cmap(df_2[\"inorm\"].values)[:, :3]\n", - " np_arr = df_2[[\"x\", \"y\", \"z\"]].values\n", - " np_arr2 = colors\n", - " np_arr = np.concatenate([np_arr, np_arr2], axis=1)\n", - " np.save(Path(this_save_path) / Path(f\"{clust}.npy\"), np_arr)" - ] - }, - { - "cell_type": "markdown", - "id": "743a930f-58f9-42f2-bfd0-7ef91cecd0a0", - "metadata": {}, - "source": [ - "# Archetype analysis" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2a313a25-b489-4e4b-a5ec-8f4ac5acdf53", - "metadata": {}, - "outputs": [], - "source": [ - "# Save 8 archetypes\n", - "this_ret = all_ret\n", - "labels = this_ret[\"cell_stage_fine\"].values\n", - "matrix = this_ret[[i for i in this_ret.columns if \"mu\" in i]].values\n", - "\n", - "n_archetypes = 8\n", - "aa = AA_Fast(n_archetypes, max_iter=1000, tol=1e-6).fit(matrix)\n", - "archetypes_df = pd.DataFrame(aa.Z, columns=[f\"mu_{i}\" for i in range(matrix.shape[1])])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c48fb301-0722-4dcc-9f0f-f346813ad2ec", - 
"metadata": {}, - "outputs": [], - "source": [ - "z_max = 0.2\n", - "z_ind = 2\n", - "use_vmin = 5.03\n", - "use_vmax = 10" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33be45de-1d30-4ff6-99af-dd348b835ee8", - "metadata": {}, - "outputs": [], - "source": [ - "# Save archetypes\n", - "this_save_path = Path(save_path) / Path(\"archetypes\")\n", - "this_save_path.mkdir(parents=True, exist_ok=True)\n", - "\n", - "model = model.eval()\n", - "key = \"pcloud\"\n", - "all_xhat = []\n", - "with torch.no_grad():\n", - " for i in range(n_archetypes):\n", - " z_inf = torch.tensor(archetypes_df.iloc[i].values).unsqueeze(axis=0)\n", - " z_inf = z_inf.to(device)\n", - " z_inf = z_inf.float()\n", - " decoder = model.decoder[key]\n", - " xhat = decoder(z_inf)\n", - " xhat = xhat.detach().cpu().numpy()\n", - " xhat = save_pcloud(xhat[0], this_save_path, i, z_max, z_ind)\n", - " print(xhat.shape)\n", - " all_xhat.append(xhat)\n", - "\n", - "names = [str(i) for i in range(n_archetypes)]\n", - "plot_pc_saved(this_save_path, names, key, flip, 0.5, views, xlim, ylim)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24984e0f-8ea1-4ab7-809d-3e5eead0f7f4", - "metadata": {}, - "outputs": [], - "source": [ - "key" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c9d5684e-cc36-482a-a5fc-2679597c5a7e", - "metadata": {}, - "outputs": [], - "source": [ - "# Save contrast adjusted numpy arrays\n", - "key = \"archetype\"\n", - "import os\n", - "\n", - "items = os.listdir(this_save_path)\n", - "fnames = [i for i in items if i.split(\".\")[-1] == \"csv\"]\n", - "names = [i.split(\".\")[0] for i in fnames]\n", - "\n", - "df = pd.DataFrame([])\n", - "for idx, _ in enumerate(fnames):\n", - " fname = fnames[idx]\n", - " print(fname)\n", - " dft = pd.read_csv(f\"{this_save_path}/{fname}\", index_col=0)\n", - " dft[key] = names[idx]\n", - " df = pd.concat([df, dft], ignore_index=True)\n", - "\n", - "archetypes = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\"]\n", - "\n", - "for arch in archetypes:\n", - " this_df = df.loc[df[\"archetype\"] == arch].reset_index(drop=True)\n", - " np_arr = this_df[[\"x\", \"y\", \"z\"]].values\n", - " this_df = normalize_intensities_and_get_colormap_apply(this_df, use_vmin, use_vmax)\n", - " colors = cmap(this_df[\"inorm\"].values)[:, :3]\n", - " np_arr2 = colors\n", - " np_arr = np.concatenate([np_arr, np_arr2], axis=1)\n", - " print(np_arr.shape)\n", - " np.save(this_save_path / Path(f\"{arch}.npy\"), np_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9f3de62f-aad6-424d-b76a-157ce8f69c32", - "metadata": {}, - "outputs": [], - "source": [ - "use_vmax" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "229c9924-95c2-4308-a2ee-dacab2d37ab6", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/src/br/notebooks/fig4_other_punctate.ipynb b/src/br/notebooks/fig4_other_punctate.ipynb deleted file mode 100644 index 3deade1..0000000 --- a/src/br/notebooks/fig4_other_punctate.ipynb +++ /dev/null @@ -1,496 +0,0 @@ -{ - "cells": [ - { - 
"cell_type": "code", - "execution_count": null, - "id": "387fbbc4-3f5e-4491-bb6b-71c3811300a0", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "import os\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" # see issue #152\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-ff70592b-6c77-5bde-832d-88d1e18cad50\"\n", - "import os\n", - "from pathlib import Path\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "import yaml\n", - "from hydra.utils import instantiate\n", - "from PIL import Image\n", - "from torch.utils.data import DataLoader, Dataset\n", - "\n", - "from br.features.archetype import AA_Fast\n", - "from br.features.plot import collect_outputs, plot, plot_stratified_pc\n", - "from br.features.reconstruction import stratified_latent_walk\n", - "from br.features.utils import (\n", - " normalize_intensities_and_get_colormap,\n", - " normalize_intensities_and_get_colormap_apply,\n", - ")\n", - "from br.models.compute_features import compute_features, get_embeddings\n", - "from br.models.load_models import get_data_and_models\n", - "from br.models.save_embeddings import (\n", - " get_pc_loss,\n", - " get_pc_loss_chamfer,\n", - " save_embeddings,\n", - " save_emissions,\n", - ")\n", - "from br.models.utils import get_all_configs_per_dataset\n", - "\n", - "device = \"cuda:0\"" - ] - }, - { - "cell_type": "markdown", - "id": "b19c1cd4-02b6-4e6e-abe5-c90de92ebbe6", - "metadata": {}, - "source": [ - "# Load data and models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d4372d26-2176-4e06-b47e-3d9d4ef80554", - "metadata": {}, - "outputs": [], - "source": [ - "# Set paths\n", - "\n", - "os.chdir(\"../../benchmarking_representations/\")\n", - "save_path = \"./test_var_punctate_embeddings/\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "53e0b793-8341-4789-9a71-ab14c667cf52", - "metadata": {}, - "outputs": [], - "source": [ - "# Util function\n", - "def get_data_and_models(dataset_name, batch_size, results_path, debug=False):\n", - " data_list = get_data(dataset_name, batch_size, results_path, debug)\n", - " all_models, run_names, model_sizes = load_model_from_path(\n", - " dataset_name, results_path\n", - " ) # default list of models in load_models.py\n", - " return data_list, all_models, run_names, model_sizes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c71312b5-cc24-4d47-a4be-471dddda0e18", - "metadata": {}, - "outputs": [], - "source": [ - "# Get datamodules, models, runs, model sizes\n", - "\n", - "dataset_name = \"other_punctate\"\n", - "batch_size = 2\n", - "debug = False\n", - "results_path = \"./configs/results/\"\n", - "data_list, all_models, run_names, model_sizes = get_data_and_models(\n", - " dataset_name, batch_size, results_path, debug\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28f69257-cb1c-4a96-8903-482d9b03e826", - "metadata": {}, - "outputs": [], - "source": [ - "gg = pd.DataFrame()\n", - "gg[\"model\"] = run_names\n", - "gg[\"model_size\"] = model_sizes\n", - "gg.to_csv(save_path + \"model_sizes.csv\")" - ] - }, - { - "cell_type": "markdown", - "id": "ede0eaca-ef18-4bf3-91f3-587eff53bce1", - "metadata": {}, - "source": [ - "# compute embeddings and emissions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33808edb-7c32-486e-8b2c-d930bea1a848", - "metadata": {}, - "outputs": [], - "source": 
[ - "# Compute embeddings and reconstructions for each model\n", - "\n", - "splits_list = [\"train\", \"val\", \"test\"]\n", - "meta_key = None\n", - "eval_scaled_img = [False] * 5\n", - "eval_scaled_img_params = [{}] * 5\n", - "loss_eval_list = None\n", - "sample_points_list = [True, True, False, False, False]\n", - "skew_scale = 100\n", - "save_embeddings(\n", - " save_path,\n", - " data_list,\n", - " all_models,\n", - " run_names,\n", - " debug,\n", - " splits_list,\n", - " device,\n", - " meta_key,\n", - " loss_eval_list,\n", - " sample_points_list,\n", - " skew_scale,\n", - " eval_scaled_img,\n", - " eval_scaled_img_params,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5a5a392c-2a2d-41ce-aacf-b1e1fb3c8156", - "metadata": {}, - "outputs": [], - "source": [ - "# Save emission stats for each model\n", - "\n", - "max_batches = 2\n", - "save_emissions(\n", - " save_path,\n", - " data_list,\n", - " all_models,\n", - " run_names,\n", - " max_batches,\n", - " debug,\n", - " device,\n", - " loss_eval_list,\n", - " sample_points_list,\n", - " skew_scale,\n", - " eval_scaled_img,\n", - " eval_scaled_img_params,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "955098f2-be5f-4731-8913-80e2e4caf25b", - "metadata": {}, - "source": [ - "# Compute benchmarking features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ba7d4798-08bf-4799-95d8-922d20afe003", - "metadata": {}, - "outputs": [], - "source": [ - "# Compute multi-metric benchmarking features\n", - "\n", - "keys = [\"image\", \"image\", \"pcloud\", \"pcloud\", \"pcloud\"]\n", - "max_embed_dim = 256\n", - "DATA_LIST = get_all_configs_per_dataset(results_path)\n", - "data_config_list = DATA_LIST[dataset_name][\"data_paths\"]\n", - "\n", - "evolve_params = {\n", - " \"modality_list_evolve\": keys,\n", - " \"config_list_evolve\": data_config_list,\n", - " \"num_evolve_samples\": 40,\n", - " \"compute_evolve_dataloaders\": False,\n", - " \"eval_meshed_img\": [False] * 5,\n", - " \"skew_scale\": 100,\n", - " \"eval_meshed_img_model_type\": [None] * 5,\n", - " \"only_embedding\": False,\n", - " \"fit_pca\": False,\n", - "}\n", - "\n", - "loss_eval = get_pc_loss_chamfer()\n", - "loss_eval_list = [loss_eval] * 5\n", - "use_sample_points_list = [True, True, False, False, False]\n", - "\n", - "classification_params = {\"class_labels\": [\"structure_name\", \"cell_stage\"]}\n", - "rot_inv_params = {\"squeeze_2d\": False, \"id\": \"cell_id\", \"max_batches\": 40}\n", - "\n", - "regression_params = {\"df_feat\": None, \"target_cols\": None, \"feature_df_path\": None}\n", - "\n", - "compactness_params = {\n", - " \"method\": \"mle\",\n", - " \"num_PCs\": None,\n", - " \"blobby_outlier_max_cc\": None,\n", - " \"check_duplicates\": True,\n", - "}\n", - "\n", - "splits_list = [\"train\", \"val\", \"test\"]\n", - "compute_embeds = False\n", - "\n", - "metric_list = [\n", - " \"Rotation Invariance Error\",\n", - " \"Evolution Energy\",\n", - " \"Reconstruction\",\n", - " \"Classification\",\n", - " \"Compactness\",\n", - "]\n", - "\n", - "\n", - "compute_features(\n", - " dataset=dataset_name,\n", - " results_path=results_path,\n", - " embeddings_path=save_path,\n", - " save_folder=save_path,\n", - " data_list=data_list,\n", - " all_models=all_models,\n", - " run_names=run_names,\n", - " use_sample_points_list=use_sample_points_list,\n", - " keys=keys,\n", - " device=device,\n", - " max_embed_dim=max_embed_dim,\n", - " splits_list=splits_list,\n", - " compute_embeds=compute_embeds,\n", - " 
classification_params=classification_params,\n", - " regression_params=regression_params,\n", - " metric_list=metric_list,\n", - " loss_eval_list=loss_eval_list,\n", - " evolve_params=evolve_params,\n", - " rot_inv_params=rot_inv_params,\n", - " compactness_params=compactness_params,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "bba83dc7-aba7-486f-bf62-037efea988ba", - "metadata": {}, - "source": [ - "# Polar plot viz" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b308ebae-ba3e-486e-8ec6-3d03d38e7a45", - "metadata": {}, - "outputs": [], - "source": [ - "# Holistic viz of features\n", - "model_order = [\n", - " \"Classical_image\",\n", - " \"Rotation_invariant_image\",\n", - " \"Classical_pointcloud\",\n", - " \"Rotation_invariant_pointcloud\",\n", - "]\n", - "metric_list = [\n", - " \"reconstruction\",\n", - " \"emissions\",\n", - " \"classification_cell_stage\",\n", - " \"classification_structure_name\",\n", - " \"compactness\",\n", - " \"evolution_energy\",\n", - " \"model_sizes\",\n", - " \"rotation_invariance_error\",\n", - "]\n", - "norm = \"std\"\n", - "title = \"variance_comparison\"\n", - "colors_list = None\n", - "unique_expressivity_metrics = [\"classification_cell_stage\", \"classification_structure_name\"]\n", - "df, df_non_agg = collect_outputs(save_path, norm, model_order, metric_list)\n", - "plot(save_path, df, model_order, title, colors_list, norm, unique_expressivity_metrics)" - ] - }, - { - "cell_type": "markdown", - "id": "311c92f3-8726-49d6-9d97-0465a333f545", - "metadata": {}, - "source": [ - "# latent walks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "93e09e0f-e675-40a2-9ae6-b18cbaafe41d", - "metadata": {}, - "outputs": [], - "source": [ - "# Load model and embeddings\n", - "run_names = [\"Rotation_invariant_pointcloud_structurenorm\"]\n", - "DATASET_INFO = get_all_configs_per_dataset(results_path)\n", - "all_ret, df = get_embeddings(run_names, dataset_name, DATASET_INFO, save_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aa155823-fbc0-486b-b103-2641c29aaef5", - "metadata": {}, - "outputs": [], - "source": [ - "all_ret = all_ret.merge(df[[\"CellId\", \"structure_name\", \"cell_stage\"]], on=\"CellId\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8bee006f-1a51-42b2-b5db-993eef9f0354", - "metadata": {}, - "outputs": [], - "source": [ - "structs = [\"NUP153\", \"SON\", \"HIST1H2BJ\", \"SMC1A\", \"CETN2\", \"SLC25A17\", \"RAB5A\"]\n", - "all_ret = all_ret.loc[all_ret[\"structure_name\"].isin(structs)].reset_index(drop=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "83de34ac-5697-4a25-8f7e-ec424ca80c67", - "metadata": {}, - "outputs": [], - "source": [ - "# Params for viz\n", - "key = \"pcloud\"\n", - "stratify_key = \"structure_name\"\n", - "z_max = None\n", - "z_ind = 2\n", - "flip = False\n", - "views = [\"xy\"]\n", - "xlim = [-20, 20]\n", - "ylim = [-20, 20]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "40b32aaf-7f4e-4966-a3dd-3a7bfde28475", - "metadata": {}, - "outputs": [], - "source": [ - "# Compute stratified latent walk\n", - "\n", - "this_save_path = Path(save_path) / Path(\"latent_walks\")\n", - "this_save_path.mkdir(parents=True, exist_ok=True)\n", - "\n", - "stratified_latent_walk(\n", - " model,\n", - " device,\n", - " all_ret,\n", - " \"pcloud\",\n", - " 256,\n", - " 256,\n", - " 2,\n", - " this_save_path,\n", - " stratify_key,\n", - " latent_walk_range=[-2, 0, 2],\n", - " 
z_max=z_max,\n", - " z_ind=z_ind,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6d56d1af-1b1a-4ae4-8225-271ac99dfce3", - "metadata": {}, - "outputs": [], - "source": [ - "# Save reconstruction plots\n", - "\n", - "viz_norms = {\n", - " \"CETN2\": [440, 800],\n", - " \"NUP153\": [420, 600],\n", - " \"SON\": [420, 1500],\n", - " \"SMC1A\": [450, 630],\n", - " \"RAB5A\": [420, 600],\n", - " \"SLC25A17\": [400, 515],\n", - " \"HIST1H2BJ\": [450, 2885],\n", - "}\n", - "import yaml\n", - "\n", - "# norms used for model training\n", - "model_norms = \"./src/br/data/preprocessing/pc_preprocessing/model_structnorms.yaml\"\n", - "with open(model_norms) as stream:\n", - " model_norms = yaml.safe_load(stream)\n", - "\n", - "# norms used for viz\n", - "viz_norms = \"./src/br/data/preprocessing/pc_preprocessing/viz_structnorms.yaml\"\n", - "with open(viz_norms) as stream:\n", - " viz_norms = yaml.safe_load(stream)\n", - "\n", - "import os\n", - "\n", - "items = os.listdir(this_save_path)\n", - "for struct in structs:\n", - " fnames = [i for i in items if i.split(\".\")[-1] == \"csv\"]\n", - " fnames = [i for i in fnames if i.split(\"_\")[1] == \"0\"]\n", - " fnames = [i for i in fnames if i.split(\"_\")[0] in [struct]]\n", - " names = [i.split(\".\")[0] for i in fnames]\n", - "\n", - " renorm = model_norms[struct]\n", - " this_viz_norm = viz_norms[struct]\n", - " use_vmin = this_viz_norm[0]\n", - " use_vmax = this_viz_norm[1]\n", - "\n", - " all_df = []\n", - " for idx, _ in enumerate(fnames):\n", - " fname = fnames[idx]\n", - " df = pd.read_csv(f\"{this_save_path}/{fname}\", index_col=0)\n", - " df[\"s\"] = df[\"s\"] / 10 # scalar values were scaled by 10 during training\n", - " df[\"s\"] = df[\"s\"] * (renorm[1] - renorm[0]) + renorm[0]\n", - " df[stratify_key] = names[idx]\n", - " all_df.append(df)\n", - " df = pd.concat(all_df, axis=0).reset_index(drop=True)\n", - " if struct in [\"NUP153\", \"SON\", \"HIST1H2BJ\", \"SMC1A\"]:\n", - " df = df.loc[df[\"z\"] < 0.2].reset_index(drop=True)\n", - " df = normalize_intensities_and_get_colormap_apply(df, use_vmin, use_vmax)\n", - " plot_stratified_pc(df, xlim, ylim, stratify_key, this_save_path, cmap, flip)\n", - "\n", - " for pc_bin in df[\"structure_name\"].unique():\n", - " this_df = df.loc[df[\"structure_name\"] == pc_bin].reset_index(drop=True)\n", - " print(this_df.shape, struct, pc_bin)\n", - " np_arr = this_df[[\"x\", \"y\", \"z\"]].values\n", - " colors = cmap(this_df[\"inorm\"].values)[:, :3]\n", - " np_arr2 = colors\n", - " np_arr = np.concatenate([np_arr, np_arr2], axis=1)\n", - " np.save(this_save_path / Path(f\"{stratify_key}_{pc_bin}.npy\"), np_arr)\n", - " cmap = plt.get_cmap(\"YlGnBu\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "abd43718-b69e-4188-a2a6-febe3b687c4f", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/src/br/notebooks/fig5_npm1_analysis.ipynb b/src/br/notebooks/fig5_npm1_analysis.ipynb deleted file mode 100644 index c4a1b73..0000000 --- a/src/br/notebooks/fig5_npm1_analysis.ipynb +++ /dev/null @@ -1,585 +0,0 @@ -{ - 
"cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "22c4a2e8-f280-463d-9140-68d0a8f8e63c", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "import os\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" # see issue #152\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-864c07c4-8eeb-5b23-8d57-eaeb942a9a0f\"\n", - "import os\n", - "from pathlib import Path\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "import yaml\n", - "from hydra.utils import instantiate\n", - "from PIL import Image\n", - "from torch.utils.data import DataLoader, Dataset\n", - "\n", - "from br.features.archetype import AA_Fast\n", - "from br.features.plot import collect_outputs, plot, plot_stratified_pc\n", - "from br.features.reconstruction import stratified_latent_walk\n", - "from br.features.utils import (\n", - " normalize_intensities_and_get_colormap,\n", - " normalize_intensities_and_get_colormap_apply,\n", - ")\n", - "from br.models.compute_features import compute_features, get_embeddings\n", - "from br.models.load_models import get_data_and_models\n", - "from br.models.save_embeddings import (\n", - " get_pc_loss,\n", - " get_pc_loss_chamfer,\n", - " save_embeddings,\n", - " save_emissions,\n", - ")\n", - "from br.models.utils import get_all_configs_per_dataset\n", - "\n", - "device = \"cuda:0\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10394787-c363-4803-80ce-2f72ce16df21", - "metadata": {}, - "outputs": [], - "source": [ - "os.chdir(\"../../benchmarking_representations/\")\n", - "save_path = \"./test_npm1_save_embeddings/\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ac2b6072-3314-420b-9746-f67b28fb8539", - "metadata": {}, - "outputs": [], - "source": [ - "dataset_name = \"npm1\"\n", - "batch_size = 2\n", - "debug = False\n", - "results_path = \"./configs/results/\"\n", - "data_list, all_models, run_names, model_sizes = get_data_and_models(\n", - " dataset_name, batch_size, results_path, debug\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "7428170f-624a-4f9c-84b4-0c6162a2759a", - "metadata": {}, - "source": [ - "# Compute embeddings and emissions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dc65e52e-0dbd-4651-bf92-e29d774ed2a6", - "metadata": {}, - "outputs": [], - "source": [ - "from br.models.save_embeddings import save_embeddings\n", - "\n", - "splits_list = [\"train\", \"val\", \"test\"]\n", - "meta_key = None\n", - "eval_scaled_img = [False] * 5\n", - "\n", - "gt_mesh_dir = MESH_DIR\n", - "gt_sampled_pts_dir = SAMPLE_DIR\n", - "gt_scale_factor_dict_path = SCALE_FACTOR_DIR\n", - "\n", - "eval_scaled_img_params = [\n", - " {\n", - " \"eval_scaled_img_model_type\": \"iae\",\n", - " \"eval_scaled_img_resolution\": 32,\n", - " \"gt_mesh_dir\": gt_mesh_dir,\n", - " \"gt_scale_factor_dict_path\": None,\n", - " \"gt_sampled_pts_dir\": gt_sampled_pts_dir,\n", - " \"mesh_ext\": \"stl\",\n", - " },\n", - " {\n", - " \"eval_scaled_img_model_type\": \"sdf\",\n", - " \"eval_scaled_img_resolution\": 32,\n", - " \"gt_mesh_dir\": gt_mesh_dir,\n", - " \"gt_scale_factor_dict_path\": gt_scale_factor_dict_path,\n", - " \"gt_sampled_pts_dir\": None,\n", - " \"mesh_ext\": \"stl\",\n", - " },\n", - " {\n", - " \"eval_scaled_img_model_type\": \"seg\",\n", - " \"eval_scaled_img_resolution\": 32,\n", - " \"gt_mesh_dir\": gt_mesh_dir,\n", - " 
\"gt_scale_factor_dict_path\": gt_scale_factor_dict_path,\n", - " \"gt_sampled_pts_dir\": None,\n", - " \"mesh_ext\": \"stl\",\n", - " },\n", - " {\n", - " \"eval_scaled_img_model_type\": \"sdf\",\n", - " \"eval_scaled_img_resolution\": 32,\n", - " \"gt_mesh_dir\": gt_mesh_dir,\n", - " \"gt_scale_factor_dict_path\": gt_scale_factor_dict_path,\n", - " \"gt_sampled_pts_dir\": None,\n", - " \"mesh_ext\": \"stl\",\n", - " },\n", - " {\n", - " \"eval_scaled_img_model_type\": \"seg\",\n", - " \"eval_scaled_img_resolution\": 32,\n", - " \"gt_mesh_dir\": gt_mesh_dir,\n", - " \"gt_scale_factor_dict_path\": gt_scale_factor_dict_path,\n", - " \"gt_sampled_pts_dir\": None,\n", - " \"mesh_ext\": \"stl\",\n", - " },\n", - "]\n", - "loss_eval_list = [torch.nn.MSELoss(reduction=\"none\")] * 5\n", - "sample_points_list = [False] * 5\n", - "skew_scale = None\n", - "save_embeddings(\n", - " save_path,\n", - " data_list,\n", - " all_models,\n", - " run_names,\n", - " debug,\n", - " splits_list,\n", - " device,\n", - " meta_key,\n", - " loss_eval_list,\n", - " sample_points_list,\n", - " skew_scale,\n", - " eval_scaled_img,\n", - " eval_scaled_img_params,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ce355be9-4fe8-40e3-95ba-8e86714a022e", - "metadata": {}, - "outputs": [], - "source": [ - "run_names" - ] - }, - { - "cell_type": "markdown", - "id": "a377891f-f19c-494a-84f1-5e26c96ecc05", - "metadata": {}, - "source": [ - "# Latent walks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0e928040-117f-44e6-b2c8-6bfc6ed3614d", - "metadata": {}, - "outputs": [], - "source": [ - "# Load model and embeddings\n", - "\n", - "run_names = [\"Rotation_invariant_pointcloud_SDF\"]\n", - "DATASET_INFO = get_all_configs_per_dataset(results_path)\n", - "all_ret, df = get_embeddings(run_names, dataset_name, DATASET_INFO, save_path)\n", - "model = all_models[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3b477940-d3e5-429b-a6bf-0e857843b723", - "metadata": {}, - "outputs": [], - "source": [ - "save_path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8c288e6f-0fd8-422b-baa2-c57c8a0178e2", - "metadata": {}, - "outputs": [], - "source": [ - "import pyvista as pv\n", - "from cyto_dl.image.transforms import RotationMask\n", - "from skimage.io import imread\n", - "from sklearn.decomposition import PCA\n", - "from tqdm import tqdm\n", - "\n", - "from br.data.utils import mesh_seg_model_output\n", - "from br.visualization.mitsuba_render_image import plot\n", - "\n", - "this_save_path = Path(save_path) / Path(\"latent_walks\")\n", - "this_save_path.mkdir(parents=True, exist_ok=True)\n", - "\n", - "lw_dict = {\"num_pieces\": [], \"PC\": [], \"bin\": [], \"CellId\": []}\n", - "for num_pieces in all_ret[\"STR_connectivity_cc_thresh\"].unique():\n", - " this_sub_m = all_ret.loc[all_ret[\"STR_connectivity_cc_thresh\"] == num_pieces].reset_index(\n", - " drop=True\n", - " )\n", - " all_features = this_sub_m[[i for i in this_sub_m.columns if \"mu\" in i]].values\n", - " latent_dim = 512\n", - " dim_size = latent_dim\n", - " x_label = \"pcloud\"\n", - " pca = PCA(n_components=dim_size)\n", - " pca_features = pca.fit_transform(all_features)\n", - " pca_std_list = pca_features.std(axis=0)\n", - " for rank in [0, 1]:\n", - " all_xhat = []\n", - " all_closest_real = []\n", - " all_closest_img = []\n", - " latent_walk_range = [-2, 0, 2]\n", - " for value_index, value in enumerate(tqdm(latent_walk_range, total=len(latent_walk_range))):\n", 
- " z_inf = torch.zeros(1, dim_size)\n", - " z_inf[:, rank] += value * pca_std_list[rank]\n", - " z_inf = pca.inverse_transform(z_inf).numpy()\n", - "\n", - " dist = (all_features - z_inf) ** 2\n", - " dist = np.sum(dist, axis=1)\n", - " closest_idx = np.argmin(dist)\n", - " closest_real_id = this_sub_m.iloc[closest_idx][\"CellId\"]\n", - " mesh = pv.read(\n", - " all_ret.loc[all_ret[\"CellId\"] == closest_real_id][\"mesh_path_noalign\"].iloc[0]\n", - " )\n", - " mesh.save(this_save_path / Path(f\"{num_pieces}_{rank}_{value_index}.ply\"))\n", - "\n", - " lw_dict[\"num_pieces\"].append(num_pieces)\n", - " lw_dict[\"PC\"].append(rank)\n", - " lw_dict[\"bin\"].append(value_index)\n", - " lw_dict[\"CellId\"].append(closest_real_id)\n", - "\n", - " # this_mesh_path = this_save_path / Path(f'{num_pieces}_{rank}_{value_index}.ply')\n", - " # this_mesh_path = './' + str(this_mesh_path)\n", - "\n", - " # mitsuba_save_path = this_save_path / Path('mitsuba')\n", - " # mitsuba_save_path.mkdir(parents=True, exist_ok=True)\n", - " # mitsuba_save_path = './' + str(mitsuba_save_path)\n", - " # name = f\"{num_pieces}_{rank}_{value_index}\"\n", - "\n", - " # plot(str(this_mesh_path), mitsuba_save_path, 120, None, None, name)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "83f292b3-f955-4b77-a0af-01ffae8d8c56", - "metadata": {}, - "outputs": [], - "source": [ - "lw_dict = pd.DataFrame(lw_dict)\n", - "lw_dict.to_csv(this_save_path / \"latent_walk.csv\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5cfe1459-5749-4773-9c93-1dc17062c2b1", - "metadata": {}, - "outputs": [], - "source": [ - "lw_dict" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7656abd5-b72e-4513-86db-ea7341540cab", - "metadata": {}, - "outputs": [], - "source": [ - "save_path = \"./test_npm1_save_embeddings/\"\n", - "this_save_path = Path(save_path) / Path(\"latent_walks\")\n", - "\n", - "# num_pieces = 4.0\n", - "num_pieces = \"2.0\"\n", - "rank = 0\n", - "bin_ = 0\n", - "this_mesh_path = this_save_path / Path(f\"{num_pieces}_{rank}_{bin_}.ply\")\n", - "this_mesh_path = \"./\" + str(this_mesh_path)\n", - "\n", - "save_path = this_save_path / Path(\"mitsuba\")\n", - "save_path.mkdir(parents=True, exist_ok=True)\n", - "save_path = \"./\" + str(save_path)\n", - "name = f\"{num_pieces}_{rank}_{bin_}\"\n", - "\n", - "\n", - "plot(str(this_mesh_path), save_path, 10, 0, None, name)" - ] - }, - { - "cell_type": "markdown", - "id": "fe57d425-25e7-4476-951a-afbbcf5ff0ca", - "metadata": {}, - "source": [ - "# Archetype" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b043c7d8-7e82-4371-8653-4a44d94e7907", - "metadata": {}, - "outputs": [], - "source": [ - "from br.features.archetype import AA_Fast\n", - "\n", - "n_archetypes = 5\n", - "matrix = all_ret[[i for i in all_ret.columns if \"mu\" in i]].values\n", - "aa = AA_Fast(n_archetypes, max_iter=1000, tol=1e-6).fit(matrix)\n", - "\n", - "import pandas as pd\n", - "\n", - "archetypes_df = pd.DataFrame(aa.Z, columns=[f\"mu_{i}\" for i in range(matrix.shape[1])])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb7affce-1f78-4813-b538-1ce47486f432", - "metadata": {}, - "outputs": [], - "source": [ - "this_save_path = Path(save_path) / Path(\"archetypes\")\n", - "this_save_path.mkdir(parents=True, exist_ok=True)\n", - "\n", - "arch_dict = {\"CellId\": [], \"archetype\": []}\n", - "all_features = matrix\n", - "for i in range(n_archetypes):\n", - " this_mu = 
archetypes_df.iloc[i].values\n", - " dist = (all_features - this_mu) ** 2\n", - " dist = np.sum(dist, axis=1)\n", - " closest_idx = np.argmin(dist)\n", - " closest_real_id = all_ret.iloc[closest_idx][\"CellId\"]\n", - " print(dist, closest_real_id)\n", - " mesh = pv.read(all_ret.loc[all_ret[\"CellId\"] == closest_real_id][\"mesh_path_noalign\"].iloc[0])\n", - " mesh.save(this_save_path / Path(f\"{i}.ply\"))\n", - " arch_dict[\"archetype\"].append(i)\n", - " arch_dict[\"CellId\"].append(closest_real_id)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4cb5e0e2-27ba-4d30-80a9-cbbcb297781f", - "metadata": {}, - "outputs": [], - "source": [ - "arch_dict = pd.DataFrame(arch_dict)\n", - "arch_dict.to_csv(this_save_path / \"archetypes.csv\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2b0035e2-6154-4e83-813a-a7776d8d46ab", - "metadata": {}, - "outputs": [], - "source": [ - "this_save_path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "42294cd1-bf48-4e7d-9304-41a897099b1b", - "metadata": {}, - "outputs": [], - "source": [ - "save_path = \"./test_npm1_save_embeddings/\"\n", - "this_save_path = Path(save_path) / Path(\"archetypes\")\n", - "\n", - "arch = \"4\"\n", - "this_mesh_path = this_save_path / Path(f\"{arch}.ply\")\n", - "this_mesh_path = \"./\" + str(this_mesh_path)\n", - "\n", - "save_path = this_save_path / Path(\"mitsuba\")\n", - "save_path.mkdir(parents=True, exist_ok=True)\n", - "save_path = \"./\" + str(save_path)\n", - "name = f\"{arch}\"\n", - "\n", - "\n", - "plot(str(this_mesh_path), save_path, 90, 0, None, name)" - ] - }, - { - "cell_type": "markdown", - "id": "10b04aec-802f-4a3f-841d-f5b8a7a9afc6", - "metadata": {}, - "source": [ - "# Pseudo time" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a15b8947-fc00-46b1-b13b-191d0660822b", - "metadata": {}, - "outputs": [], - "source": [ - "all_ret[\"volume_of_nucleus_um3\"] = all_ret[\"dna_shape_volume_lcc\"] * 0.108**3" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5acaf553-9edb-493d-af82-4805418ba75e", - "metadata": {}, - "outputs": [], - "source": [ - "feat = \"volume_of_nucleus_um3\"\n", - "upper = np.quantile(all_ret[feat], q=0.99)\n", - "lower = np.quantile(all_ret[feat], q=0.01)\n", - "\n", - "this = all_ret.loc[all_ret[feat] < upper]\n", - "this = this.loc[this[feat] > lower].reset_index(drop=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3fa15fee-896b-4980-a18f-33ad32f8bbfb", - "metadata": {}, - "outputs": [], - "source": [ - "this[\"vol_bins\"] = pd.cut(this[feat], bins=5)\n", - "this[\"vol_bins_ind\"] = pd.factorize(this[\"vol_bins\"])[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ba66e901-7d70-4c9a-b91b-df7f9094931c", - "metadata": {}, - "outputs": [], - "source": [ - "this[\"vol_bins\"].value_counts()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a5946e6f-52ec-4be3-abf1-da95fb90efdf", - "metadata": {}, - "outputs": [], - "source": [ - "this_save_path = Path(save_path) / Path(\"pseudo_time\")\n", - "this_save_path.mkdir(parents=True, exist_ok=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "35e59ca0-dbc1-4882-ba39-e92c2157060c", - "metadata": {}, - "outputs": [], - "source": [ - "all_features = this[[i for i in this.columns if \"mu\" in i]].values\n", - "\n", - "vol_dict = {\"vol_bin\": [], \"CellId\": []}\n", - "this[\"vol_bins\"] = this[\"vol_bins\"].astype(str)\n", - "for hh in 
this[\"vol_bins\"].unique():\n", - " this_ret = this.loc[this[\"vol_bins\"] == hh].reset_index(drop=True)\n", - "\n", - " this_mu = np.expand_dims(\n", - " this_ret[[i for i in this_ret.columns if \"mu\" in i]].mean(axis=0), axis=0\n", - " )\n", - " dist = (all_features - this_mu) ** 2\n", - " # dist = np.sum(dist, axis=1)\n", - " k = 1\n", - " # print(min(latent_dim, all_features.shape[0]))\n", - " inds = np.argpartition(dist.sum(axis=-1), k)[:k] # get 10 closest\n", - " closest_samples = this.iloc[inds].reset_index(drop=True)\n", - " for ind, row in closest_samples.iterrows():\n", - " # closest_real_id = this.iloc[closest_idx]['CellId']\n", - " closest_real_id = row[\"CellId\"]\n", - " print(\n", - " closest_idx,\n", - " this_ret[\"vol_bins\"].unique(),\n", - " all_features.shape,\n", - " this_ret.shape,\n", - " this_ret[\"dna_shape_volume_lcc\"].mean(),\n", - " closest_real_id,\n", - " )\n", - " mesh = pv.read(\n", - " all_ret.loc[all_ret[\"CellId\"] == closest_real_id][\"mesh_path_noalign\"].iloc[0]\n", - " )\n", - " mesh.save(this_save_path / Path(f\"{hh}_{ind}_{closest_real_id}.ply\"))\n", - "\n", - " vol_dict[\"vol_bin\"].append(hh)\n", - " vol_dict[\"CellId\"].append(closest_real_id)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "afd2c4a8-dc5a-4cff-a9d9-154c3793fa6f", - "metadata": {}, - "outputs": [], - "source": [ - "vol_dict = pd.DataFrame(vol_dict)\n", - "vol_dict.to_csv(this_save_path / \"pseudo_time.csv\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3faab1b2-a09e-49ae-8a0d-28057cf7e8ba", - "metadata": {}, - "outputs": [], - "source": [ - "save_path = \"./test_npm1_save_embeddings/\"\n", - "this_save_path = Path(save_path) / Path(\"pseudo_time\")\n", - "\n", - "pseu = \"(533.383, 676.015]_0_970952\"\n", - "this_mesh_path = this_save_path / Path(f\"{pseu}.ply\")\n", - "this_mesh_path = \"./\" + str(this_mesh_path)\n", - "\n", - "save_path = this_save_path / Path(\"mitsuba\")\n", - "save_path.mkdir(parents=True, exist_ok=True)\n", - "save_path = \"./\" + str(save_path)\n", - "name = f\"{pseu}\"\n", - "\n", - "\n", - "plot(str(this_mesh_path), save_path, 90, 90, None, name)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ca62a559-48a1-4273-ac5f-4e2fcd407e4a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/src/br/notebooks/fig6_other_polymorphic.ipynb b/src/br/notebooks/fig6_other_polymorphic.ipynb deleted file mode 100644 index fb8272a..0000000 --- a/src/br/notebooks/fig6_other_polymorphic.ipynb +++ /dev/null @@ -1,495 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "17d20bd5-43ee-4490-8811-6edbde5f86d1", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "import os\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" # see issue #152\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-864c07c4-8eeb-5b23-8d57-eaeb942a9a0f\"\n", - "import os\n", - "from pathlib import Path\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import 
pandas as pd\n", - "import torch\n", - "import yaml\n", - "from hydra.utils import instantiate\n", - "from PIL import Image\n", - "from torch.utils.data import DataLoader, Dataset\n", - "\n", - "from br.features.archetype import AA_Fast\n", - "from br.features.plot import collect_outputs, plot, plot_stratified_pc\n", - "from br.features.reconstruction import stratified_latent_walk\n", - "from br.features.utils import (\n", - " normalize_intensities_and_get_colormap,\n", - " normalize_intensities_and_get_colormap_apply,\n", - ")\n", - "from br.models.compute_features import compute_features, get_embeddings\n", - "from br.models.load_models import get_data_and_models\n", - "from br.models.save_embeddings import (\n", - " get_pc_loss,\n", - " get_pc_loss_chamfer,\n", - " save_embeddings,\n", - " save_emissions,\n", - ")\n", - "from br.models.utils import get_all_configs_per_dataset\n", - "\n", - "device = \"cuda:0\"" - ] - }, - { - "cell_type": "markdown", - "id": "b8b6a10b-38a9-4fc5-8dd0-03babdccee70", - "metadata": {}, - "source": [ - "# Load data and models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f05f5c60-cae8-41d7-9297-eef048e459b3", - "metadata": {}, - "outputs": [], - "source": [ - "os.chdir(\"../../benchmarking_representations/\")\n", - "save_path = \"./test_polymorphic_save_embeddings/\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0ad0f0f4-fa0e-412f-a6fe-ec88e49ffd43", - "metadata": {}, - "outputs": [], - "source": [ - "dataset_name = \"other_polymorphic\"\n", - "batch_size = 2\n", - "debug = False\n", - "results_path = \"./configs/results/\"\n", - "data_list, all_models, run_names, model_sizes = get_data_and_models(\n", - " dataset_name, batch_size, results_path, debug\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "0cb523c3-ef58-4557-8a31-e1eb095c95de", - "metadata": {}, - "source": [ - "# Compute embeddings and emissions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7088001f-97ae-4091-a591-3aad744c7ed2", - "metadata": {}, - "outputs": [], - "source": [ - "from br.models.save_embeddings import save_embeddings\n", - "\n", - "splits_list = [\"test\"]\n", - "splits_list = [\"train\", \"val\", \"test\"]\n", - "meta_key = None\n", - "eval_scaled_img = [False] * 5\n", - "\n", - "gt_mesh_dir = MESH_DIR\n", - "gt_sampled_pts_dir = SAMPLES_DIR\n", - "gt_scale_factor_dict_path = SCALE_FACTOR_DIR\n", - "\n", - "eval_scaled_img_params = [\n", - " {\n", - " \"eval_scaled_img_model_type\": \"iae\",\n", - " \"eval_scaled_img_resolution\": 32,\n", - " \"gt_mesh_dir\": gt_mesh_dir,\n", - " \"gt_scale_factor_dict_path\": None,\n", - " \"gt_sampled_pts_dir\": gt_sampled_pts_dir,\n", - " \"mesh_ext\": \"stl\",\n", - " },\n", - " {\n", - " \"eval_scaled_img_model_type\": \"sdf\",\n", - " \"eval_scaled_img_resolution\": 32,\n", - " \"gt_mesh_dir\": gt_mesh_dir,\n", - " \"gt_scale_factor_dict_path\": gt_scale_factor_dict_path,\n", - " \"gt_sampled_pts_dir\": None,\n", - " \"mesh_ext\": \"stl\",\n", - " },\n", - " {\n", - " \"eval_scaled_img_model_type\": \"seg\",\n", - " \"eval_scaled_img_resolution\": 32,\n", - " \"gt_mesh_dir\": gt_mesh_dir,\n", - " \"gt_scale_factor_dict_path\": gt_scale_factor_dict_path,\n", - " \"gt_sampled_pts_dir\": None,\n", - " \"mesh_ext\": \"stl\",\n", - " },\n", - " {\n", - " \"eval_scaled_img_model_type\": \"sdf\",\n", - " \"eval_scaled_img_resolution\": 32,\n", - " \"gt_mesh_dir\": gt_mesh_dir,\n", - " \"gt_scale_factor_dict_path\": gt_scale_factor_dict_path,\n", - " 
\"gt_sampled_pts_dir\": None,\n", - " \"mesh_ext\": \"stl\",\n", - " },\n", - " {\n", - " \"eval_scaled_img_model_type\": \"seg\",\n", - " \"eval_scaled_img_resolution\": 32,\n", - " \"gt_mesh_dir\": gt_mesh_dir,\n", - " \"gt_scale_factor_dict_path\": gt_scale_factor_dict_path,\n", - " \"gt_sampled_pts_dir\": None,\n", - " \"mesh_ext\": \"stl\",\n", - " },\n", - "]\n", - "loss_eval_list = [torch.nn.MSELoss(reduction=\"none\")] * 5\n", - "sample_points_list = [False] * 5\n", - "skew_scale = None\n", - "save_embeddings(\n", - " save_path,\n", - " data_list,\n", - " all_models,\n", - " run_names,\n", - " debug,\n", - " splits_list,\n", - " device,\n", - " meta_key,\n", - " loss_eval_list,\n", - " sample_points_list,\n", - " skew_scale,\n", - " eval_scaled_img,\n", - " eval_scaled_img_params,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "2927b452-7b3a-49cc-a8a7-cdd6d213dfac", - "metadata": {}, - "source": [ - "# Latent walks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4f28accf-950d-495c-8df7-b5dfb59936d2", - "metadata": {}, - "outputs": [], - "source": [ - "# Load model and embeddings\n", - "run_names = [\"Rotation_invariant_pointcloud_SDF\"]\n", - "DATASET_INFO = get_all_configs_per_dataset(results_path)\n", - "all_ret, df = get_embeddings(run_names, dataset_name, DATASET_INFO, save_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a80bba01-e9b1-49a2-8f6e-e9ea34e7c942", - "metadata": {}, - "outputs": [], - "source": [ - "save_path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b62f8cc7-ca76-4c19-96bf-350c2d23f3e9", - "metadata": {}, - "outputs": [], - "source": [ - "cols = [i for i in all_ret.columns if \"mu\" in i]\n", - "feat_cols = [i for i in all_ret.columns if \"str\" in i]\n", - "# feat_cols = ['mem_position_width', 'mem_position_height', 'mem_position_depth_lcc']\n", - "cols = cols + feat_cols\n", - "this_ret = all_ret.loc[all_ret[\"structure_name\"] == \"ST6GAL1\"].reset_index(drop=True)\n", - "pca = PCA(n_components=512)\n", - "pca_features = pca.fit_transform(this_ret[cols].values)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b477503d-8950-4ff2-b507-937c40520584", - "metadata": {}, - "outputs": [], - "source": [ - "for i in feat_cols:\n", - " corr = np.abs(np.corrcoef(pca_features[:, 1], this_ret[i].values)[0, 1])\n", - " if corr > 0.5:\n", - " print(i, corr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d537d96e-7013-41bc-a8b3-825e954b8104", - "metadata": {}, - "outputs": [], - "source": [ - "pca_features.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "58e24687-d019-4696-b86a-e865d7f9e95e", - "metadata": {}, - "outputs": [], - "source": [ - "all_ret[\"structure_name\"].value_counts()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f8a5582e-574f-41e8-9ea4-736b55869ad2", - "metadata": {}, - "outputs": [], - "source": [ - "[i for i in all_ret.columns if \"path\" in i]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "899a8340-29c8-4f35-a947-6385e8f00918", - "metadata": {}, - "outputs": [], - "source": [ - "import pyvista as pv\n", - "from cyto_dl.image.transforms import RotationMask\n", - "from skimage.io import imread\n", - "from sklearn.decomposition import PCA\n", - "from tqdm import tqdm\n", - "\n", - "from br.data.utils import mesh_seg_model_output\n", - "\n", - "this_save_path = Path(save_path) / Path(\"latent_walks\")\n", - 
"this_save_path.mkdir(parents=True, exist_ok=True)\n", - "\n", - "lw_dict = {\"structure_name\": [], \"PC\": [], \"bin\": [], \"CellId\": []}\n", - "\n", - "for struct in all_ret[\"structure_name\"].unique():\n", - " this_sub_m = all_ret.loc[all_ret[\"structure_name\"] == struct].reset_index(drop=True)\n", - " all_features = this_sub_m[[i for i in this_sub_m.columns if \"mu\" in i]].values\n", - " latent_dim = 512\n", - " dim_size = latent_dim\n", - " x_label = \"pcloud\"\n", - " pca = PCA(n_components=dim_size)\n", - " pca_features = pca.fit_transform(all_features)\n", - " pca_std_list = pca_features.std(axis=0)\n", - " for rank in [0, 1]:\n", - " all_xhat = []\n", - " all_closest_real = []\n", - " all_closest_img = []\n", - " latent_walk_range = [-2, 0, 2]\n", - " for value_index, value in enumerate(tqdm(latent_walk_range, total=len(latent_walk_range))):\n", - " z_inf = torch.zeros(1, dim_size)\n", - " z_inf[:, rank] += value * pca_std_list[rank]\n", - " z_inf = pca.inverse_transform(z_inf).numpy()\n", - "\n", - " dist = (all_features - z_inf) ** 2\n", - " dist = np.sum(dist, axis=1)\n", - " closest_idx = np.argmin(dist)\n", - " closest_real_id = this_sub_m.iloc[closest_idx][\"CellId\"]\n", - " print(closest_real_id, struct, rank, value_index)\n", - " mesh = pv.read(\n", - " all_ret.loc[all_ret[\"CellId\"] == closest_real_id][\"mesh_path_noalign\"].iloc[0]\n", - " )\n", - " mesh.save(this_save_path / Path(f\"{struct}_{rank}_{value_index}.ply\"))\n", - "\n", - " lw_dict[\"structure_name\"].append(struct)\n", - " lw_dict[\"PC\"].append(rank)\n", - " lw_dict[\"bin\"].append(value_index)\n", - " lw_dict[\"CellId\"].append(closest_real_id)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d1c2ee61-dd0f-4915-a29d-1c3b561a9e91", - "metadata": {}, - "outputs": [], - "source": [ - "this_save_path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "69a3f454-9a19-40bf-bfa2-be5012df68d4", - "metadata": {}, - "outputs": [], - "source": [ - "lw_dict = pd.DataFrame(lw_dict)\n", - "lw_dict.to_csv(this_save_path / \"latent_walk.csv\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3e646836-8373-4e3e-a763-58fe35f8d068", - "metadata": {}, - "outputs": [], - "source": [ - "lw_dict" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e763ab33-8099-488e-a6e2-906151fe4891", - "metadata": {}, - "outputs": [], - "source": [ - "# num_pieces = 4.0\n", - "struct = \"FBL\"\n", - "rank = 1\n", - "bin_ = 2\n", - "this_mesh_path = this_save_path / Path(f\"{struct}_{rank}_{bin_}.ply\")\n", - "this_mesh_path = \"./\" + str(this_mesh_path)\n", - "\n", - "mitsuba_save_path = this_save_path / Path(\"mitsuba\")\n", - "mitsuba_save_path.mkdir(parents=True, exist_ok=True)\n", - "mitsuba_save_path = \"./\" + str(mitsuba_save_path)\n", - "name = f\"{struct}_{rank}_{bin_}\"\n", - "\n", - "\n", - "plot(str(this_mesh_path), mitsuba_save_path, -130, 0, None, name)" - ] - }, - { - "cell_type": "markdown", - "id": "771c73fb-8be2-44ea-95b4-0bdd22a3a6dc", - "metadata": {}, - "source": [ - "# Archetype" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1b318bba-3e87-4554-8eba-be788650a579", - "metadata": {}, - "outputs": [], - "source": [ - "from br.features.archetype import AA_Fast\n", - "\n", - "n_archetypes = 4\n", - "matrix = all_ret[[i for i in all_ret.columns if \"mu\" in i]].values\n", - "aa = AA_Fast(n_archetypes, max_iter=1000, tol=1e-6).fit(matrix)\n", - "\n", - "import pandas as pd\n", - "\n", - "archetypes_df = 
pd.DataFrame(aa.Z, columns=[f\"mu_{i}\" for i in range(matrix.shape[1])])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6836a9b9-573f-42de-93eb-9c771b86d44d", - "metadata": {}, - "outputs": [], - "source": [ - "archetypes_df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b656c3ef-e0b3-45f5-9e34-1482e707a8b5", - "metadata": {}, - "outputs": [], - "source": [ - "this_save_path = Path(save_path) / Path(\"archetypes\")\n", - "this_save_path.mkdir(parents=True, exist_ok=True)\n", - "\n", - "arch_dict = {\"CellId\": [], \"archetype\": []}\n", - "\n", - "all_features = matrix\n", - "for i in range(n_archetypes):\n", - " this_mu = archetypes_df.iloc[i].values\n", - " dist = (all_features - this_mu) ** 2\n", - " dist = np.sum(dist, axis=1)\n", - " closest_idx = np.argmin(dist)\n", - " closest_real_id = all_ret.iloc[closest_idx][\"CellId\"]\n", - " mesh = pv.read(all_ret.loc[all_ret[\"CellId\"] == closest_real_id][\"mesh_path_noalign\"].iloc[0])\n", - " mesh.save(this_save_path / Path(f\"{i}.ply\"))\n", - " arch_dict[\"archetype\"].append(i)\n", - " arch_dict[\"CellId\"].append(closest_real_id)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ba64b1fc-d06a-441c-b6ae-e73f3f4e04a8", - "metadata": {}, - "outputs": [], - "source": [ - "arch_dict = pd.DataFrame(arch_dict)\n", - "arch_dict.to_csv(this_save_path / \"archetypes.csv\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dbf9332c-01fb-41e6-9022-80af44d2abaf", - "metadata": {}, - "outputs": [], - "source": [ - "from br.visualization.mitsuba_render_image import plot\n", - "\n", - "# num_pieces = 4.0\n", - "arch = \"3\"\n", - "this_mesh_path = this_save_path / Path(f\"{arch}.ply\")\n", - "this_mesh_path = \"./\" + str(this_mesh_path)\n", - "\n", - "mitsuba_save_path = this_save_path / Path(\"mitsuba\")\n", - "mitsuba_save_path.mkdir(parents=True, exist_ok=True)\n", - "mitsuba_save_path = \"./\" + str(mitsuba_save_path)\n", - "name = f\"{arch}\"\n", - "\n", - "\n", - "plot(str(this_mesh_path), mitsuba_save_path, 10, 0, None, name)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "693f1a36-6d2b-4ba3-b017-ef87512b9a6a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/src/br/notebooks/fig7_drugdata_analysis.ipynb b/src/br/notebooks/fig7_drugdata_analysis.ipynb deleted file mode 100644 index 56b0f59..0000000 --- a/src/br/notebooks/fig7_drugdata_analysis.ipynb +++ /dev/null @@ -1,515 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "980c3db8-9252-4e3b-970c-926551219ef1", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "import os\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" # see issue #152\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-0bb056ed-239d-5614-a667-fd108c1880cf\"\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "import yaml\n", - "from hydra.utils import instantiate\n", - "from PIL import Image\n", 
- "from torch.utils.data import DataLoader, Dataset\n", - "\n", - "from br.models.compute_features import get_embeddings\n", - "from br.models.load_models import get_data_and_models\n", - "from br.models.save_embeddings import get_pc_loss, save_embeddings\n", - "from br.models.utils import get_all_configs_per_dataset\n", - "\n", - "device = \"cuda:0\"" - ] - }, - { - "cell_type": "markdown", - "id": "83b5a0c6-586c-457b-955e-f344da74cc35", - "metadata": {}, - "source": [ - "# Load data and models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8d47ccb8-db91-4b8e-b7fb-9c112191eb27", - "metadata": {}, - "outputs": [], - "source": [ - "# Set paths\n", - "os.chdir(\"../../benchmarking_representations/\")\n", - "save_path = \"./test_npm1_perturb/\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "62dd54d3-6272-4ab5-9853-e1c027f96de2", - "metadata": {}, - "outputs": [], - "source": [ - "# Get datamodules, models, runs, model sizes\n", - "\n", - "dataset_name = \"npm1_perturb\"\n", - "batch_size = 2\n", - "debug = False\n", - "results_path = \"./configs/results/\"\n", - "data_list, all_models, run_names, model_sizes = get_data_and_models(\n", - " dataset_name, batch_size, results_path, debug\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "1feb7e95-5d04-4dd1-9141-48bf1a3b0297", - "metadata": {}, - "source": [ - "# Compute embeddings" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7e8f82b6-ffad-4cf7-942c-efce7bbea688", - "metadata": {}, - "outputs": [], - "source": [ - "# Save embeddings for each model\n", - "\n", - "splits_list = [\"train\", \"val\", \"test\"]\n", - "meta_key = None\n", - "eval_scaled_img = [False] * 5\n", - "eval_scaled_img_params = [{}] * 5\n", - "loss_eval_list = [None] * 5\n", - "sample_points_list = [False] * 5\n", - "skew_scale = None\n", - "save_embeddings(\n", - " save_path,\n", - " data_list,\n", - " all_models,\n", - " run_names,\n", - " debug,\n", - " splits_list,\n", - " device,\n", - " meta_key,\n", - " loss_eval_list,\n", - " sample_points_list,\n", - " skew_scale,\n", - " eval_scaled_img,\n", - " eval_scaled_img_params,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "08c4d765-9bc0-4245-b553-c5e21509fa5e", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "id": "5c60eeb0-d0b5-4b1f-90bc-3b6cc8a23890", - "metadata": {}, - "source": [ - "# Get embeddings" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "675b2ac8-5a7f-4059-ba3d-4309c44dc741", - "metadata": {}, - "outputs": [], - "source": [ - "# Load model and embeddings\n", - "DATASET_INFO = get_all_configs_per_dataset(results_path)\n", - "run_names = None\n", - "all_ret, orig = get_embeddings(run_names, dataset_name, DATASET_INFO, save_path)\n", - "all_ret[\"well_position\"] = \"A0\" # dummy\n", - "all_ret[\"Assay_Plate_Barcode\"] = \"Plate0\" # dummy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "45143b0e-9e74-4cfe-b177-936e6d5d549d", - "metadata": {}, - "outputs": [], - "source": [ - "all_ret[\"model\"].unique()" - ] - }, - { - "cell_type": "markdown", - "id": "819186dd-1c2d-4992-88ea-cb255bbfa526", - "metadata": {}, - "source": [ - "# mAP and fraction retrieved calculation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8eb47d75-ceb1-4bf1-8f91-600e77eedc2c", - "metadata": {}, - "outputs": [], - "source": [ - "# Compute mAP and fraction retrieved as described in Chandrasekaran 2024\n", - 
"import pandas as pd\n", - "from tqdm import tqdm\n", - "\n", - "from br.chandrasekaran_et_al import utils\n", - "\n", - "\n", - "def get_featurecols(df):\n", - " \"\"\"returna list of featuredata columns\"\"\"\n", - " return [c for c in df.columns if \"mu\" in c]\n", - "\n", - "\n", - "def get_featuredata(df):\n", - " \"\"\"return dataframe of just featuredata columns\"\"\"\n", - " return df[get_featurecols(df)]\n", - "\n", - "\n", - "cols = [i for i in all_ret.columns if \"mu\" in i]\n", - "\n", - "replicate_feature = \"Metadata_broad_sample\"\n", - "batch_size = 100000\n", - "null_size = 100000\n", - "\n", - "\n", - "all_rep = []\n", - "all_match = []\n", - "all_fr = []\n", - "for model in tqdm(all_ret[\"model\"].unique(), total=len(all_ret[\"model\"].unique())):\n", - " df_feats = all_ret.loc[all_ret[\"model\"] == model].reset_index(drop=True)\n", - " df_feats[\"Metadata_ObjectNumber\"] = df_feats[\"CellId\"]\n", - "\n", - " import pycytominer\n", - "\n", - " all_normalized_df = []\n", - " cols = [i for i in df_feats.columns if \"mu\" in i]\n", - " for plate in df_feats[\"Assay_Plate_Barcode\"].unique():\n", - " test = df_feats.loc[df_feats[\"Assay_Plate_Barcode\"] == plate].reset_index(drop=True)\n", - " # test = test.groupby(['condition_coarse']).mean().reset_index()\n", - " # test['Assay_Plate_Barcode'] = 'plate0'\n", - " # test['well_position'] = 'a0'\n", - " normalized_df = pycytominer.normalize(\n", - " profiles=test,\n", - " features=cols,\n", - " meta_features=[\n", - " \"Assay_Plate_Barcode\",\n", - " \"well_position\",\n", - " \"condition_coarse\",\n", - " \"condition\",\n", - " ],\n", - " method=\"standardize\",\n", - " mad_robustize_epsilon=0,\n", - " samples=\"all\",\n", - " )\n", - " normalized_df = pycytominer.normalize(\n", - " profiles=normalized_df,\n", - " features=cols,\n", - " meta_features=[\n", - " \"Assay_Plate_Barcode\",\n", - " \"well_position\",\n", - " \"condition_coarse\",\n", - " \"condition\",\n", - " ],\n", - " method=\"standardize\",\n", - " samples=\"condition == 'DMSO (control)'\",\n", - " )\n", - "\n", - " all_normalized_df.append(normalized_df)\n", - " df_final = pd.concat(all_normalized_df, axis=0).reset_index(drop=True)\n", - "\n", - " vals = []\n", - " for ind, row in df_final.iterrows():\n", - " if row[\"condition\"] == \"DMSO (control)\":\n", - " vals.append(\"negcon\")\n", - " else:\n", - " vals.append(None)\n", - "\n", - " # more dummy cols\n", - " df_final[\"Metadata_control_type\"] = vals\n", - " df_final[\"Metadata_broad_sample\"] = df_final[\"condition\"]\n", - " df_final[\"Cell_type\"] = \"hIPSc\"\n", - " df_final[\"Perturbation\"] = \"compound\"\n", - " df_final[\"Time\"] = \"1\"\n", - " df_final[\"Metadata_target_list\"] = \"none\"\n", - " df_final[\"target_list\"] = \"none\"\n", - " df_final[\"Metadata_Plate\"] = \"Plate0\"\n", - "\n", - " experiment_df = df_final\n", - "\n", - " replicability_map_df = pd.DataFrame()\n", - " replicability_fr_df = pd.DataFrame()\n", - " matching_map_df = pd.DataFrame()\n", - " matching_fr_df = pd.DataFrame()\n", - " gene_compound_matching_map_df = pd.DataFrame()\n", - " gene_compound_matching_fr_df = pd.DataFrame()\n", - "\n", - " replicate_feature = \"Metadata_broad_sample\"\n", - " for cell in experiment_df.Cell_type.unique():\n", - " cell_df = experiment_df.query(\"Cell_type==@cell\")\n", - " modality_1_perturbation = \"compound\"\n", - " modality_1_experiments_df = cell_df.query(\"Perturbation==@modality_1_perturbation\")\n", - " for modality_1_timepoint in 
modality_1_experiments_df.Time.unique():\n", - " modality_1_timepoint_df = modality_1_experiments_df.query(\n", - " \"Time==@modality_1_timepoint\"\n", - " )\n", - " modality_1_df = pd.DataFrame()\n", - " for plate in modality_1_timepoint_df.Assay_Plate_Barcode.unique():\n", - " data_df = df_final.loc[df_final[\"Assay_Plate_Barcode\"].isin([plate])]\n", - " data_df = data_df.drop(\n", - " columns=[\"Metadata_target_list\", \"target_list\"]\n", - " ).reset_index(drop=True)\n", - " # data_df = data_df.groupby(['pert_iname']).sample(n=10).reset_index(drop=True)\n", - " modality_1_df = utils.concat_profiles(modality_1_df, data_df)\n", - "\n", - " # Set Metadata_broad_sample value to \"DMSO\" for DMSO wells\n", - " modality_1_df[replicate_feature].fillna(\"DMSO\", inplace=True)\n", - " print(modality_1_df.shape)\n", - "\n", - " # Remove empty wells\n", - " modality_1_df = utils.remove_empty_wells(modality_1_df)\n", - "\n", - " # Description\n", - " description = f\"{modality_1_perturbation}_{cell}_{utils.time_point(modality_1_perturbation, modality_1_timepoint)}\"\n", - "\n", - " modality_1_df[\"Metadata_negcon\"] = np.where(\n", - " modality_1_df[\"Metadata_control_type\"] == \"negcon\", 1, 0\n", - " ) # Create dummy column\n", - "\n", - " pos_sameby = [\"Metadata_broad_sample\"]\n", - " pos_diffby = []\n", - " neg_sameby = [\"Metadata_Plate\"]\n", - " neg_diffby = [\"Metadata_negcon\"]\n", - "\n", - " metadata_df = utils.get_metadata(modality_1_df)\n", - " feature_df = get_featuredata(modality_1_df)\n", - " feature_values = feature_df.values\n", - "\n", - " result = utils.run_pipeline(\n", - " metadata_df,\n", - " feature_values,\n", - " pos_sameby,\n", - " pos_diffby,\n", - " neg_sameby,\n", - " neg_diffby,\n", - " anti_match=False,\n", - " batch_size=batch_size,\n", - " null_size=null_size,\n", - " )\n", - " result = result.query(\"Metadata_negcon==0\").reset_index(drop=True)\n", - "\n", - " qthreshold = 0.001\n", - "\n", - " replicability_map_df, replicability_fr_df = utils.create_replicability_df(\n", - " replicability_map_df,\n", - " replicability_fr_df,\n", - " result,\n", - " pos_sameby,\n", - " qthreshold,\n", - " modality_1_perturbation,\n", - " cell,\n", - " modality_1_timepoint,\n", - " )\n", - " replicability_map_df[\"model\"] = model\n", - " matching_map_df[\"model\"] = model\n", - " replicability_fr_df[\"model\"] = model\n", - " all_rep.append(replicability_map_df)\n", - " all_match.append(matching_map_df)\n", - " all_fr.append(replicability_fr_df)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4426f733-3a1b-460f-a19b-5425af0041e7", - "metadata": {}, - "outputs": [], - "source": [ - "all_rep = pd.concat(all_rep, axis=0).reset_index(drop=True)\n", - "all_fr = pd.concat(all_fr, axis=0).reset_index(drop=True)\n", - "\n", - "all_fr[\"metric\"] = \"Fraction retrieved\"\n", - "all_fr[\"value\"] = all_fr[\"fr\"]\n", - "all_rep[\"metric\"] = \"Mean average precision\"\n", - "all_rep[\"value\"] = all_rep[\"mean_average_precision\"]\n", - "metrics = pd.concat([all_fr, all_rep], axis=0).reset_index(drop=True)\n", - "\n", - "plot_df = metrics.loc[metrics[\"metric\"] == \"Fraction retrieved\"].reset_index(drop=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "53693e25-0ea6-4f88-ab2f-f564266bc245", - "metadata": {}, - "outputs": [], - "source": [ - "rep_dict = {\n", - " \"CNN_sdf_noalign_global\": \"Classical_image_SDF\",\n", - " \"CNN_sdf_SO3_global\": \"SO3_image_SDF\",\n", - " \"CNN_seg_noalign_global\": \"Classical_image_seg\",\n", - 
" \"CNN_seg_SO3_global\": \"SO3_image_seg\",\n", - " \"vn_so3\": \"SO3_pointcloud_SDF\",\n", - "}\n", - "all_rep[\"model\"] = all_rep[\"model\"].replace(rep_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a4aaa174-ae41-4a9f-a060-2a016f05e933", - "metadata": {}, - "outputs": [], - "source": [ - "ordered_drugs = (\n", - " all_rep.groupby([\"Metadata_broad_sample\"])\n", - " .mean()\n", - " .sort_values(by=\"q_value\")\n", - " .reset_index()[\"Metadata_broad_sample\"]\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3cd5546d-867d-4a65-a09c-f7240946f42c", - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "import seaborn as sns\n", - "\n", - "sns.set_context(\"talk\")\n", - "sns.set(font_scale=1.7)\n", - "sns.set_style(\"white\")\n", - "\n", - "test = all_rep.sort_values(by=\"q_value\").reset_index(drop=True)\n", - "test[\"Drugs\"] = test[\"Metadata_broad_sample\"]\n", - "\n", - "x_order = (\n", - " test.loc[test[\"model\"] == \"SO3_pointcloud_SDF\"]\n", - " .sort_values(by=\"q_value\")[\"Metadata_broad_sample\"]\n", - " .values\n", - ")\n", - "\n", - "x_order = ordered_drugs\n", - "\n", - "g = sns.catplot(\n", - " data=test,\n", - " x=\"Drugs\",\n", - " y=\"q_value\",\n", - " hue=\"model\",\n", - " kind=\"point\",\n", - " order=x_order,\n", - " hue_order=[\n", - " \"Classical_image_seg\",\n", - " \"SO3_image_seg\",\n", - " \"Classical_image_SDF\",\n", - " \"SO3_image_SDF\",\n", - " \"SO3_pointcloud_SDF\",\n", - " ],\n", - " palette=[\"#A6ACE0\", \"#6277DB\", \"#D9978E\", \"#D8553B\", \"#2ED9FF\"],\n", - " aspect=2,\n", - " height=5,\n", - ")\n", - "g.set_xticklabels(rotation=90)\n", - "plt.axhline(y=0.05, color=\"black\")\n", - "this_path = Path(save_path + \"drug_dataset\")\n", - "Path(this_path).mkdir(parents=True, exist_ok=True)\n", - "g.savefig(this_path / \"q_values.pdf\", dpi=300, bbox_inches=\"tight\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1f4e4142-b698-4f47-a17a-73a648191720", - "metadata": {}, - "outputs": [], - "source": [ - "df = pd.read_csv(\n", - " \"/allen/aics/modeling/ritvik/projects/aws_uploads/morphology_appropriate_representation_learning/cellPACK_single_cell_punctate_structure/reference_nuclear_shapes/manifest.csv\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "692979be-80cb-4d45-bcad-bc6546785178", - "metadata": {}, - "outputs": [], - "source": [ - "df[\"nucobj_path\"].iloc[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "60580552-ed9c-45fc-90fe-4853f38f0a3b", - "metadata": {}, - "outputs": [], - "source": [ - "df[\"nucobj_path\"] = df[\"nucobj_path\"].apply(\n", - " lambda x: x.replace(\"./morphology_appropriate_representation_learning\", \".\")\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3d863644-b8d5-44b2-9f13-9c348fd98256", - "metadata": {}, - "outputs": [], - "source": [ - "df.to_csv(\n", - " \"/allen/aics/modeling/ritvik/projects/aws_uploads/morphology_appropriate_representation_learning/cellPACK_single_cell_punctate_structure/reference_nuclear_shapes/manifest.csv\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b0d6f3e3-6f66-42aa-9e5c-95b85c3f8b90", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - 
"version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From d07414c3e504e49cb439d0cd768dd5fe2e76e411 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Fri, 22 Nov 2024 12:46:16 -0800 Subject: [PATCH 25/35] remvove selected gpu --- src/br/analysis/analysis_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index cf78c99..a6ce490 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -101,7 +101,6 @@ def _setup_gpu(): # Based on the utilization, set the GPU ID # Setting a GPU ID is crucial for the script to work well! selected_gpu_id_or_uuid = config_gpu() - selected_gpu_id_or_uuid = "MIG-ffdee303-0dd4-513d-b18c-beba028b49c7" # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID if selected_gpu_id_or_uuid: From 103b649b499b1bb49ee03aa0a421daff1ddd7b01 Mon Sep 17 00:00:00 2001 From: Fatwir Mohammed Date: Mon, 25 Nov 2024 13:07:02 -0800 Subject: [PATCH 26/35] Modified the code to get the GPU ID based on memory utilization --- src/br/analysis/analysis_utils.py | 76 +++++++++++++++++++++---------- 1 file changed, 53 insertions(+), 23 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index a6ce490..89468ec 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -37,10 +37,11 @@ def get_gpu_info(): # Run nvidia-smi command and get the output cmd = [ "nvidia-smi", - "--query-gpu=index,uuid,name,utilization.gpu", + "--query-gpu=index,uuid,name,memory.used,memory.total", "--format=csv,noheader,nounits", ] result = subprocess.run(cmd, capture_output=True, text=True) + # print(result) return result.stdout.strip() @@ -51,16 +52,48 @@ def check_mig(): return "MIG" in result.stdout -def get_mig_ids(): - # Get the MIG UUIDs - cmd = ["nvidia-smi", "-L"] - result = subprocess.run(cmd, capture_output=True, text=True) - mig_ids = [] - for line in result.stdout.splitlines(): - if "MIG" in line: - mig_id = line.split("(UUID: ")[-1].strip(")") - mig_ids.append(mig_id) - return mig_ids +def get_mig_ids(gpu_uuid): + try: + # Get the list of GPUs + output = subprocess.check_output(['nvidia-smi','--query-gpu=,index,uuid' ,'--format=csv,noheader']).decode('utf-8').strip().split('\n') + + # Find the index of the specified GPU UUID + gpu_index = -1 + for i, line in enumerate(output): + if gpu_uuid in line: + gpu_index = i + break + + if gpu_index == -1: + print(f"GPU UUID {gpu_uuid} not found.") + return [] + + # Now we need to get the MIG IDs for this GPU + mig_ids = [] + # Run nvidia-smi command to get detailed information including MIG IDs + detailed_output = subprocess.check_output(['nvidia-smi', '-L']).decode('utf-8').strip().split('\n') + + # Flag to determine if we are in the right GPU section + in_gpu_section = False + for line in detailed_output: + if f"GPU {gpu_index}:" in line: # Adjusted to check for the specific GPU section + in_gpu_section = True + elif "GPU" in line and in_gpu_section: # Encounter another GPU section + break + + # print(line) + + if in_gpu_section: + # Check for MIG devices + if "MIG" in line: + mig_id = line.split('(')[1].split(')')[0].split(' ')[-1] # Assuming format is '.... MIG (UUID) ...' 
+ mig_ids.append(mig_id.strip()) + + return mig_ids + + except subprocess.CalledProcessError as e: + print(f"An error occurred: {e}") + return [] def config_gpu(): @@ -71,18 +104,15 @@ def config_gpu(): lines = gpu_info.splitlines() for line in lines: - index, uuid, name, utilization = map(str.strip, line.split(",")) - - # If utilization is [N/A], treat it as less than 10 - if utilization == "[N/A]": - utilization = -1 # Assign a value less than 10 to simulate "idle" - else: - utilization = int(utilization) - - # Check if GPU utilization is under 10% (indicating it's idle) - if utilization < 10: + index, uuid, name, mem_used, mem_total = map(str.strip, line.split(",")) + utilization = float(mem_used)*100/float(mem_total) + + # Check if GPU utilization is under 20% (indicating it's idle) + if utilization < 20: + # print(uuid, utilization) if is_mig: - mig_ids = get_mig_ids() + mig_ids = get_mig_ids(uuid) + if mig_ids: selected_gpu_id_or_uuid = mig_ids[0] # Select the first MIG ID break # Exit the loop after finding the first MIG ID @@ -99,7 +129,7 @@ def _setup_gpu(): torch.cuda.empty_cache() # Based on the utilization, set the GPU ID - # Setting a GPU ID is crucial for the script to work well! + # Setting a GPU ID is crucial for the script to work! selected_gpu_id_or_uuid = config_gpu() # Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID From 4d42ed97b70a8c2256f81d8bb5fea791f3024d07 Mon Sep 17 00:00:00 2001 From: Fatwir Mohammed Date: Mon, 25 Nov 2024 13:32:32 -0800 Subject: [PATCH 27/35] Updated the paths in configs (merged from main) --- .../npm1_variance/image_sdf_classical.yaml | 54 ++++++++++++++++++ .../npm1_variance/image_sdf_so3.yaml | 55 +++++++++++++++++++ .../npm1_variance/image_seg_classical.yaml | 54 ++++++++++++++++++ .../npm1_variance/image_seg_so3.yaml | 55 +++++++++++++++++++ .../experiment/npm1_variance/pc_implicit.yaml | 54 ++++++++++++++++++ configs/logger/csv.yaml | 2 +- configs/results/cellpack.yaml | 16 +++--- configs/results/npm1.yaml | 13 ++--- configs/results/npm1_perturb.yaml | 10 ++-- configs/results/other_polymorphic.yaml | 12 ++-- configs/results/other_punctate.yaml | 13 +++-- configs/results/pcna.yaml | 20 +++---- 12 files changed, 311 insertions(+), 47 deletions(-) create mode 100644 configs/experiment/npm1_variance/image_sdf_classical.yaml create mode 100644 configs/experiment/npm1_variance/image_sdf_so3.yaml create mode 100644 configs/experiment/npm1_variance/image_seg_classical.yaml create mode 100644 configs/experiment/npm1_variance/image_seg_so3.yaml create mode 100644 configs/experiment/npm1_variance/pc_implicit.yaml diff --git a/configs/experiment/npm1_variance/image_sdf_classical.yaml b/configs/experiment/npm1_variance/image_sdf_classical.yaml new file mode 100644 index 0000000..360d017 --- /dev/null +++ b/configs/experiment/npm1_variance/image_sdf_classical.yaml @@ -0,0 +1,54 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: other_polymorphic/classical_image_sdf.yaml + - override /model: image/classical_sdf_35.yaml + - override /callbacks: default.yaml + - override /trainer: default.yaml + - override /logger: csv.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +experiment_name: npm1_variance +tags: ["equivariance"] + +seed: 42 + +data: + batch_size: 4 + +model: + x_label: image + +trainer: + check_val_every_n_epoch: 1 + min_epochs: 400 + 
max_epochs: 2000 + accelerator: gpu + devices: [0] + +callbacks: + early_stopping: + monitor: val/loss + + model_checkpoint: + dirpath: ./npm1_variance/ckpts + monitor: val/loss + save_top_k: 2 + every_n_epochs: 1 + +logger: + csv: + save_dir: ./npm1_variance + name: "classical_sdf" + prefix: + +##### ONLY USE WITH A100s +extras: + precision: + _target_: torch.set_float32_matmul_precision + precision: medium diff --git a/configs/experiment/npm1_variance/image_sdf_so3.yaml b/configs/experiment/npm1_variance/image_sdf_so3.yaml new file mode 100644 index 0000000..41d3805 --- /dev/null +++ b/configs/experiment/npm1_variance/image_sdf_so3.yaml @@ -0,0 +1,55 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: npm1/classical_image_sdf.yaml + - override /model: image/classical_sdf_35.yaml + - override /callbacks: default.yaml + - override /trainer: default.yaml + - override /logger: csv.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +experiment_name: npm1_variance +tags: ["equivariance"] + +seed: 42 + +data: + batch_size: 4 + +model: + x_label: image + group: so3 + +trainer: + check_val_every_n_epoch: 1 + min_epochs: 400 + max_epochs: 2000 + accelerator: gpu + devices: [0] + +callbacks: + early_stopping: + monitor: val/loss + + model_checkpoint: + dirpath: ./npm1_variance/ckpts + monitor: val/loss + save_top_k: 2 + every_n_epochs: 1 + +logger: + csv: + save_dir: ./npm1_variance + name: "so3_sdf" + prefix: + +##### ONLY USE WITH A100s +extras: + precision: + _target_: torch.set_float32_matmul_precision + precision: medium diff --git a/configs/experiment/npm1_variance/image_seg_classical.yaml b/configs/experiment/npm1_variance/image_seg_classical.yaml new file mode 100644 index 0000000..825db57 --- /dev/null +++ b/configs/experiment/npm1_variance/image_seg_classical.yaml @@ -0,0 +1,54 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: npm1/classical_image_seg.yaml + - override /model: image/classical_seg_35.yaml + - override /callbacks: default.yaml + - override /trainer: default.yaml + - override /logger: csv.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +experiment_name: npm1_variance +tags: ["equivariance"] + +seed: 42 + +data: + batch_size: 4 + +model: + x_label: image + +trainer: + check_val_every_n_epoch: 1 + min_epochs: 400 + max_epochs: 2000 + accelerator: gpu + devices: [0] + +callbacks: + early_stopping: + monitor: val/loss + + model_checkpoint: + dirpath: ./npm1_variance/ckpts + monitor: val/loss + save_top_k: 2 + every_n_epochs: 1 + +logger: + csv: + save_dir: ./npm1_variance + name: "classical_seg" + prefix: + +##### ONLY USE WITH A100s +extras: + precision: + _target_: torch.set_float32_matmul_precision + precision: medium diff --git a/configs/experiment/npm1_variance/image_seg_so3.yaml b/configs/experiment/npm1_variance/image_seg_so3.yaml new file mode 100644 index 0000000..aef8ada --- /dev/null +++ b/configs/experiment/npm1_variance/image_seg_so3.yaml @@ -0,0 +1,55 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: npm1/classical_image_seg.yaml + - override /model: image/classical_seg_35.yaml + - override /callbacks: 
default.yaml + - override /trainer: default.yaml + - override /logger: csv.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +experiment_name: npm1_variance +tags: ["equivariance"] + +seed: 42 + +data: + batch_size: 4 + +model: + x_label: image + group: so3 + +trainer: + check_val_every_n_epoch: 1 + min_epochs: 400 + max_epochs: 2000 + accelerator: gpu + devices: [0] + +callbacks: + early_stopping: + monitor: val/loss + + model_checkpoint: + dirpath: ./npm1_variance/ckpts + monitor: val/loss + save_top_k: 2 + every_n_epochs: 1 + +logger: + csv: + save_dir: ./npm1_variance + name: "so3_seg" + prefix: + +##### ONLY USE WITH A100s +extras: + precision: + _target_: torch.set_float32_matmul_precision + precision: medium diff --git a/configs/experiment/npm1_variance/pc_implicit.yaml b/configs/experiment/npm1_variance/pc_implicit.yaml new file mode 100644 index 0000000..fb4df04 --- /dev/null +++ b/configs/experiment/npm1_variance/pc_implicit.yaml @@ -0,0 +1,54 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: npm1/pc.yaml + - override /model: pc/implicit.yaml + - override /callbacks: default.yaml + - override /trainer: default.yaml + - override /logger: csv.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +experiment_name: npm1_variance +tags: ["equivariance"] + +seed: 42 + +data: + batch_size: 4 + +model: + x_label: pcloud + +trainer: + check_val_every_n_epoch: 1 + min_epochs: 400 + max_epochs: 2000 + accelerator: gpu + devices: [0] + +callbacks: + early_stopping: + monitor: val/loss + + model_checkpoint: + dirpath: ./npm1_variance/ckpts + monitor: val/loss + save_top_k: 2 + every_n_epochs: 1 + +logger: + csv: + save_dir: ./npm1_variance + name: "pc_implicit" + prefix: + +##### ONLY USE WITH A100s +extras: + precision: + _target_: torch.set_float32_matmul_precision + precision: medium diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml index fa028e9..c524e13 100644 --- a/configs/logger/csv.yaml +++ b/configs/logger/csv.yaml @@ -4,4 +4,4 @@ csv: _target_: lightning.pytorch.loggers.csv_logs.CSVLogger save_dir: "${paths.output_dir}" name: "csv/" - prefix: "" + prefix: "" \ No newline at end of file diff --git a/configs/results/cellpack.yaml b/configs/results/cellpack.yaml index 67cd66e..ed588ed 100644 --- a/configs/results/cellpack.yaml +++ b/configs/results/cellpack.yaml @@ -17,13 +17,11 @@ names: "Rotation_invariant_pointcloud", "Rotation_invariant_pointcloud_jitter", ] -data_paths: - [ - "/data/cellpack/image.yaml", - "/data/cellpack/image.yaml", - "/data/cellpack/pc.yaml", - "/data/cellpack/pc.yaml", - "/data/cellpack/pc.yaml", +data_paths: [ + "./configs/data/cellpack/image.yaml", + "./configs/data/cellpack/image.yaml", + "./configs/data/cellpack/pc.yaml", + "./configs/data/cellpack/pc.yaml", + # "./src/br/configs/data/cellpack/pc_jitter.yaml", + "./configs/data/cellpack/pc.yaml", ] -classification_label: ["rule"] -regression_label: diff --git a/configs/results/npm1.yaml b/configs/results/npm1.yaml index d2c2dd0..97c2f9c 100644 --- a/configs/results/npm1.yaml +++ b/configs/results/npm1.yaml @@ -19,12 +19,9 @@ names: ] data_paths: [ - "/data/npm1/pc.yaml", - "/data/npm1/so3_image_sdf.yaml", - "/data/npm1/so3_image_seg.yaml", - "/data/npm1/classical_image_sdf.yaml", - 
"/data/npm1/classical_image_seg.yaml", + "./configs/data/npm1/pc.yaml", + "./configs/data/npm1/so3_image_sdf.yaml", + "./configs/data/npm1/so3_image_seg.yaml", + "./configs/data/npm1/classical_image_sdf.yaml", + "./configs/data/npm1/classical_image_seg.yaml", ] -classification_label: ["STR_connectivity_cc_thresh"] -regression_label: - ["mean_centroid_distances", "mean_nucleolus_volume", "mean_nucleolus_area"] diff --git a/configs/results/npm1_perturb.yaml b/configs/results/npm1_perturb.yaml index 492e61b..5a5a5ea 100644 --- a/configs/results/npm1_perturb.yaml +++ b/configs/results/npm1_perturb.yaml @@ -19,9 +19,9 @@ names: ] data_paths: [ - "/data/npm1_perturb/pc.yaml", - "/data/npm1_perturb/classical_image_sdf.yaml", - "/data/npm1_perturb/classical_image_seg.yaml", - "/data/npm1_perturb/so3_image_sdf.yaml", - "/data/npm1_perturb/so3_image_seg.yaml", + "./configs/data/npm1_perturb/pc.yaml", + "./configs/data/npm1_perturb/classical_image_sdf.yaml", + "./configs/data/npm1_perturb/classical_image_seg.yaml", + "./configs/data/npm1_perturb/so3_image_sdf.yaml", + "./configs/data/npm1_perturb/so3_image_seg.yaml", ] diff --git a/configs/results/other_polymorphic.yaml b/configs/results/other_polymorphic.yaml index cccc773..2b4319f 100644 --- a/configs/results/other_polymorphic.yaml +++ b/configs/results/other_polymorphic.yaml @@ -19,11 +19,9 @@ names: ] data_paths: [ - "/data/other_polymorphic/pc.yaml", - "/data/other_polymorphic/so3_image_sdf.yaml", - "/data/other_polymorphic/so3_image_seg.yaml", - "/data/other_polymorphic/classical_image_sdf.yaml", - "/data/other_polymorphic/classical_image_seg.yaml", + "./configs/data/other_polymorphic/pc.yaml", + "./configs/data/other_polymorphic/so3_image_sdf.yaml", + "./configs/data/other_polymorphic/so3_image_seg.yaml", + "./configs/data/other_polymorphic/classical_image_sdf.yaml", + "./configs/data/other_polymorphic/classical_image_seg.yaml", ] -classification_label: ["structure_name"] -regression_label: ["avg_dists", "mean_volume", "mean_surface_area"] diff --git a/configs/results/other_punctate.yaml b/configs/results/other_punctate.yaml index 3ea855b..da1f528 100644 --- a/configs/results/other_punctate.yaml +++ b/configs/results/other_punctate.yaml @@ -6,6 +6,7 @@ model_checkpoints: "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Classical_image.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Rotation_invariant_image.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Classical_pointcloud.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Rotation_invariant_pointcloud.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Rotation_invariant_pointcloud_structurenorm.ckpt", ] names: @@ -13,14 +14,14 @@ names: "Classical_image", "Rotation_invariant_image", "Classical_pointcloud", + "Rotation_invariant_pointcloud", "Rotation_invariant_pointcloud_structurenorm", ] data_paths: [ - "/data/other_punctate/image.yaml", - "/data/other_punctate/image.yaml", - "/data/other_punctate/pc.yaml", - "/data/other_punctate/pc_intensity_structurenorm.yaml", + "./configs/data/other_punctate/image.yaml", + "./configs/data/other_punctate/image.yaml", + "./configs/data/other_punctate/pc.yaml", + "./configs/data/other_punctate/pc_intensity.yaml", + "./configs/data/other_punctate/pc_intensity_structurenorm.yaml", ] -classification_label: ["structure_name", "cell_stage"] -regression_label: diff 
--git a/configs/results/pcna.yaml b/configs/results/pcna.yaml index e481b22..9c4f184 100644 --- a/configs/results/pcna.yaml +++ b/configs/results/pcna.yaml @@ -3,27 +3,25 @@ image_path: ./morphology_appropriate_representation_learning/preprocessed_data/p pc_path: ./morphology_appropriate_representation_learning/preprocessed_data/pcna/manifest.csv model_checkpoints: [ - "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Classical_image.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Rotation_invariant_image.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Classical_pointcloud.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Rotation_invariant_pointcloud.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Classical_image.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Rotation_invariant_image.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Rotation_invariant_pointcloud_jitter.ckpt", ] names: [ - "Classical_image", - "Rotation_invariant_image", "Classical_pointcloud", "Rotation_invariant_pointcloud", + "Classical_image", + "Rotation_invariant_image", "Rotation_invariant_pointcloud_jitter", ] data_paths: [ - "/data/pcna/image.yaml", - "/data/pcna/image.yaml", - "/data/pcna/pc.yaml", - "/data/pcna/pc_intensity.yaml", + "./configs/data/pcna/pc.yaml", + "./configs/data/pcna/pc_intensity.yaml", + "./configs/data/pcna/image.yaml", + "./configs/data/pcna/image.yaml", # "./src/br/configs/data/pcna/pc_intensity_jitter.yaml", - "/data/pcna/pc_intensity.yaml", + "./configs/data/pcna/pc_intensity.yaml", ] -classification_label: ["cell_stage_fine", "flag_comment"] -regression_label: From 35e7977c33baf3d8df8a8844f7f6802f66abf7ba Mon Sep 17 00:00:00 2001 From: Fatwir Mohammed Date: Mon, 25 Nov 2024 13:41:59 -0800 Subject: [PATCH 28/35] Revert "Updated the paths in configs (merged from main)" This reverts commit 4d42ed97b70a8c2256f81d8bb5fea791f3024d07. 
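A note worth keeping in mind when reviewing this path change and its revert: in every configs/results/*.yaml file, model_checkpoints, names, and data_paths are parallel lists, so a reordering or path update has to be applied consistently across all three. The repository's actual entry point for consuming these files is get_data_and_models() in br.models.load_models; the snippet below is only a hedged illustration of that index alignment (the helper name and the direct yaml parsing are assumptions, not the project's API).

import yaml

def iter_results_config(path):
    # Illustrative helper only; the project itself loads these files via
    # get_data_and_models(dataset_name, batch_size, results_path, debug).
    with open(path) as f:
        cfg = yaml.safe_load(f)
    # The three lists are index-aligned: checkpoint i is labelled by names[i]
    # and evaluated with the data config at data_paths[i].
    return list(zip(cfg["model_checkpoints"], cfg["names"], cfg["data_paths"]))

for ckpt, name, data_cfg in iter_results_config("./configs/results/pcna.yaml"):
    print(name, "->", ckpt, "with", data_cfg)
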
--- .../npm1_variance/image_sdf_classical.yaml | 54 ------------------ .../npm1_variance/image_sdf_so3.yaml | 55 ------------------- .../npm1_variance/image_seg_classical.yaml | 54 ------------------ .../npm1_variance/image_seg_so3.yaml | 55 ------------------- .../experiment/npm1_variance/pc_implicit.yaml | 54 ------------------ configs/logger/csv.yaml | 2 +- configs/results/cellpack.yaml | 16 +++--- configs/results/npm1.yaml | 13 +++-- configs/results/npm1_perturb.yaml | 10 ++-- configs/results/other_polymorphic.yaml | 12 ++-- configs/results/other_punctate.yaml | 13 ++--- configs/results/pcna.yaml | 20 ++++--- 12 files changed, 47 insertions(+), 311 deletions(-) delete mode 100644 configs/experiment/npm1_variance/image_sdf_classical.yaml delete mode 100644 configs/experiment/npm1_variance/image_sdf_so3.yaml delete mode 100644 configs/experiment/npm1_variance/image_seg_classical.yaml delete mode 100644 configs/experiment/npm1_variance/image_seg_so3.yaml delete mode 100644 configs/experiment/npm1_variance/pc_implicit.yaml diff --git a/configs/experiment/npm1_variance/image_sdf_classical.yaml b/configs/experiment/npm1_variance/image_sdf_classical.yaml deleted file mode 100644 index 360d017..0000000 --- a/configs/experiment/npm1_variance/image_sdf_classical.yaml +++ /dev/null @@ -1,54 +0,0 @@ -# @package _global_ - -# to execute this experiment run: -# python train.py experiment=example - -defaults: - - override /data: other_polymorphic/classical_image_sdf.yaml - - override /model: image/classical_sdf_35.yaml - - override /callbacks: default.yaml - - override /trainer: default.yaml - - override /logger: csv.yaml - -# all parameters below will be merged with parameters from default configurations set above -# this allows you to overwrite only specified parameters - -experiment_name: npm1_variance -tags: ["equivariance"] - -seed: 42 - -data: - batch_size: 4 - -model: - x_label: image - -trainer: - check_val_every_n_epoch: 1 - min_epochs: 400 - max_epochs: 2000 - accelerator: gpu - devices: [0] - -callbacks: - early_stopping: - monitor: val/loss - - model_checkpoint: - dirpath: ./npm1_variance/ckpts - monitor: val/loss - save_top_k: 2 - every_n_epochs: 1 - -logger: - csv: - save_dir: ./npm1_variance - name: "classical_sdf" - prefix: - -##### ONLY USE WITH A100s -extras: - precision: - _target_: torch.set_float32_matmul_precision - precision: medium diff --git a/configs/experiment/npm1_variance/image_sdf_so3.yaml b/configs/experiment/npm1_variance/image_sdf_so3.yaml deleted file mode 100644 index 41d3805..0000000 --- a/configs/experiment/npm1_variance/image_sdf_so3.yaml +++ /dev/null @@ -1,55 +0,0 @@ -# @package _global_ - -# to execute this experiment run: -# python train.py experiment=example - -defaults: - - override /data: npm1/classical_image_sdf.yaml - - override /model: image/classical_sdf_35.yaml - - override /callbacks: default.yaml - - override /trainer: default.yaml - - override /logger: csv.yaml - -# all parameters below will be merged with parameters from default configurations set above -# this allows you to overwrite only specified parameters - -experiment_name: npm1_variance -tags: ["equivariance"] - -seed: 42 - -data: - batch_size: 4 - -model: - x_label: image - group: so3 - -trainer: - check_val_every_n_epoch: 1 - min_epochs: 400 - max_epochs: 2000 - accelerator: gpu - devices: [0] - -callbacks: - early_stopping: - monitor: val/loss - - model_checkpoint: - dirpath: ./npm1_variance/ckpts - monitor: val/loss - save_top_k: 2 - every_n_epochs: 1 - -logger: - csv: - save_dir: 
./npm1_variance - name: "so3_sdf" - prefix: - -##### ONLY USE WITH A100s -extras: - precision: - _target_: torch.set_float32_matmul_precision - precision: medium diff --git a/configs/experiment/npm1_variance/image_seg_classical.yaml b/configs/experiment/npm1_variance/image_seg_classical.yaml deleted file mode 100644 index 825db57..0000000 --- a/configs/experiment/npm1_variance/image_seg_classical.yaml +++ /dev/null @@ -1,54 +0,0 @@ -# @package _global_ - -# to execute this experiment run: -# python train.py experiment=example - -defaults: - - override /data: npm1/classical_image_seg.yaml - - override /model: image/classical_seg_35.yaml - - override /callbacks: default.yaml - - override /trainer: default.yaml - - override /logger: csv.yaml - -# all parameters below will be merged with parameters from default configurations set above -# this allows you to overwrite only specified parameters - -experiment_name: npm1_variance -tags: ["equivariance"] - -seed: 42 - -data: - batch_size: 4 - -model: - x_label: image - -trainer: - check_val_every_n_epoch: 1 - min_epochs: 400 - max_epochs: 2000 - accelerator: gpu - devices: [0] - -callbacks: - early_stopping: - monitor: val/loss - - model_checkpoint: - dirpath: ./npm1_variance/ckpts - monitor: val/loss - save_top_k: 2 - every_n_epochs: 1 - -logger: - csv: - save_dir: ./npm1_variance - name: "classical_seg" - prefix: - -##### ONLY USE WITH A100s -extras: - precision: - _target_: torch.set_float32_matmul_precision - precision: medium diff --git a/configs/experiment/npm1_variance/image_seg_so3.yaml b/configs/experiment/npm1_variance/image_seg_so3.yaml deleted file mode 100644 index aef8ada..0000000 --- a/configs/experiment/npm1_variance/image_seg_so3.yaml +++ /dev/null @@ -1,55 +0,0 @@ -# @package _global_ - -# to execute this experiment run: -# python train.py experiment=example - -defaults: - - override /data: npm1/classical_image_seg.yaml - - override /model: image/classical_seg_35.yaml - - override /callbacks: default.yaml - - override /trainer: default.yaml - - override /logger: csv.yaml - -# all parameters below will be merged with parameters from default configurations set above -# this allows you to overwrite only specified parameters - -experiment_name: npm1_variance -tags: ["equivariance"] - -seed: 42 - -data: - batch_size: 4 - -model: - x_label: image - group: so3 - -trainer: - check_val_every_n_epoch: 1 - min_epochs: 400 - max_epochs: 2000 - accelerator: gpu - devices: [0] - -callbacks: - early_stopping: - monitor: val/loss - - model_checkpoint: - dirpath: ./npm1_variance/ckpts - monitor: val/loss - save_top_k: 2 - every_n_epochs: 1 - -logger: - csv: - save_dir: ./npm1_variance - name: "so3_seg" - prefix: - -##### ONLY USE WITH A100s -extras: - precision: - _target_: torch.set_float32_matmul_precision - precision: medium diff --git a/configs/experiment/npm1_variance/pc_implicit.yaml b/configs/experiment/npm1_variance/pc_implicit.yaml deleted file mode 100644 index fb4df04..0000000 --- a/configs/experiment/npm1_variance/pc_implicit.yaml +++ /dev/null @@ -1,54 +0,0 @@ -# @package _global_ - -# to execute this experiment run: -# python train.py experiment=example - -defaults: - - override /data: npm1/pc.yaml - - override /model: pc/implicit.yaml - - override /callbacks: default.yaml - - override /trainer: default.yaml - - override /logger: csv.yaml - -# all parameters below will be merged with parameters from default configurations set above -# this allows you to overwrite only specified parameters - -experiment_name: npm1_variance -tags: 
["equivariance"] - -seed: 42 - -data: - batch_size: 4 - -model: - x_label: pcloud - -trainer: - check_val_every_n_epoch: 1 - min_epochs: 400 - max_epochs: 2000 - accelerator: gpu - devices: [0] - -callbacks: - early_stopping: - monitor: val/loss - - model_checkpoint: - dirpath: ./npm1_variance/ckpts - monitor: val/loss - save_top_k: 2 - every_n_epochs: 1 - -logger: - csv: - save_dir: ./npm1_variance - name: "pc_implicit" - prefix: - -##### ONLY USE WITH A100s -extras: - precision: - _target_: torch.set_float32_matmul_precision - precision: medium diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml index c524e13..fa028e9 100644 --- a/configs/logger/csv.yaml +++ b/configs/logger/csv.yaml @@ -4,4 +4,4 @@ csv: _target_: lightning.pytorch.loggers.csv_logs.CSVLogger save_dir: "${paths.output_dir}" name: "csv/" - prefix: "" \ No newline at end of file + prefix: "" diff --git a/configs/results/cellpack.yaml b/configs/results/cellpack.yaml index ed588ed..67cd66e 100644 --- a/configs/results/cellpack.yaml +++ b/configs/results/cellpack.yaml @@ -17,11 +17,13 @@ names: "Rotation_invariant_pointcloud", "Rotation_invariant_pointcloud_jitter", ] -data_paths: [ - "./configs/data/cellpack/image.yaml", - "./configs/data/cellpack/image.yaml", - "./configs/data/cellpack/pc.yaml", - "./configs/data/cellpack/pc.yaml", - # "./src/br/configs/data/cellpack/pc_jitter.yaml", - "./configs/data/cellpack/pc.yaml", +data_paths: + [ + "/data/cellpack/image.yaml", + "/data/cellpack/image.yaml", + "/data/cellpack/pc.yaml", + "/data/cellpack/pc.yaml", + "/data/cellpack/pc.yaml", ] +classification_label: ["rule"] +regression_label: diff --git a/configs/results/npm1.yaml b/configs/results/npm1.yaml index 97c2f9c..d2c2dd0 100644 --- a/configs/results/npm1.yaml +++ b/configs/results/npm1.yaml @@ -19,9 +19,12 @@ names: ] data_paths: [ - "./configs/data/npm1/pc.yaml", - "./configs/data/npm1/so3_image_sdf.yaml", - "./configs/data/npm1/so3_image_seg.yaml", - "./configs/data/npm1/classical_image_sdf.yaml", - "./configs/data/npm1/classical_image_seg.yaml", + "/data/npm1/pc.yaml", + "/data/npm1/so3_image_sdf.yaml", + "/data/npm1/so3_image_seg.yaml", + "/data/npm1/classical_image_sdf.yaml", + "/data/npm1/classical_image_seg.yaml", ] +classification_label: ["STR_connectivity_cc_thresh"] +regression_label: + ["mean_centroid_distances", "mean_nucleolus_volume", "mean_nucleolus_area"] diff --git a/configs/results/npm1_perturb.yaml b/configs/results/npm1_perturb.yaml index 5a5a5ea..492e61b 100644 --- a/configs/results/npm1_perturb.yaml +++ b/configs/results/npm1_perturb.yaml @@ -19,9 +19,9 @@ names: ] data_paths: [ - "./configs/data/npm1_perturb/pc.yaml", - "./configs/data/npm1_perturb/classical_image_sdf.yaml", - "./configs/data/npm1_perturb/classical_image_seg.yaml", - "./configs/data/npm1_perturb/so3_image_sdf.yaml", - "./configs/data/npm1_perturb/so3_image_seg.yaml", + "/data/npm1_perturb/pc.yaml", + "/data/npm1_perturb/classical_image_sdf.yaml", + "/data/npm1_perturb/classical_image_seg.yaml", + "/data/npm1_perturb/so3_image_sdf.yaml", + "/data/npm1_perturb/so3_image_seg.yaml", ] diff --git a/configs/results/other_polymorphic.yaml b/configs/results/other_polymorphic.yaml index 2b4319f..cccc773 100644 --- a/configs/results/other_polymorphic.yaml +++ b/configs/results/other_polymorphic.yaml @@ -19,9 +19,11 @@ names: ] data_paths: [ - "./configs/data/other_polymorphic/pc.yaml", - "./configs/data/other_polymorphic/so3_image_sdf.yaml", - "./configs/data/other_polymorphic/so3_image_seg.yaml", - 
"./configs/data/other_polymorphic/classical_image_sdf.yaml", - "./configs/data/other_polymorphic/classical_image_seg.yaml", + "/data/other_polymorphic/pc.yaml", + "/data/other_polymorphic/so3_image_sdf.yaml", + "/data/other_polymorphic/so3_image_seg.yaml", + "/data/other_polymorphic/classical_image_sdf.yaml", + "/data/other_polymorphic/classical_image_seg.yaml", ] +classification_label: ["structure_name"] +regression_label: ["avg_dists", "mean_volume", "mean_surface_area"] diff --git a/configs/results/other_punctate.yaml b/configs/results/other_punctate.yaml index da1f528..3ea855b 100644 --- a/configs/results/other_punctate.yaml +++ b/configs/results/other_punctate.yaml @@ -6,7 +6,6 @@ model_checkpoints: "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Classical_image.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Rotation_invariant_image.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Classical_pointcloud.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Rotation_invariant_pointcloud.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/other_punctate/Rotation_invariant_pointcloud_structurenorm.ckpt", ] names: @@ -14,14 +13,14 @@ names: "Classical_image", "Rotation_invariant_image", "Classical_pointcloud", - "Rotation_invariant_pointcloud", "Rotation_invariant_pointcloud_structurenorm", ] data_paths: [ - "./configs/data/other_punctate/image.yaml", - "./configs/data/other_punctate/image.yaml", - "./configs/data/other_punctate/pc.yaml", - "./configs/data/other_punctate/pc_intensity.yaml", - "./configs/data/other_punctate/pc_intensity_structurenorm.yaml", + "/data/other_punctate/image.yaml", + "/data/other_punctate/image.yaml", + "/data/other_punctate/pc.yaml", + "/data/other_punctate/pc_intensity_structurenorm.yaml", ] +classification_label: ["structure_name", "cell_stage"] +regression_label: diff --git a/configs/results/pcna.yaml b/configs/results/pcna.yaml index 9c4f184..e481b22 100644 --- a/configs/results/pcna.yaml +++ b/configs/results/pcna.yaml @@ -3,25 +3,27 @@ image_path: ./morphology_appropriate_representation_learning/preprocessed_data/p pc_path: ./morphology_appropriate_representation_learning/preprocessed_data/pcna/manifest.csv model_checkpoints: [ - "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Classical_pointcloud.ckpt", - "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Rotation_invariant_pointcloud.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Classical_image.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Rotation_invariant_image.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Classical_pointcloud.ckpt", + "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Rotation_invariant_pointcloud.ckpt", "./morphology_appropriate_representation_learning/model_checkpoints/pcna/Rotation_invariant_pointcloud_jitter.ckpt", ] names: [ - "Classical_pointcloud", - "Rotation_invariant_pointcloud", "Classical_image", "Rotation_invariant_image", + "Classical_pointcloud", + "Rotation_invariant_pointcloud", "Rotation_invariant_pointcloud_jitter", ] data_paths: [ - "./configs/data/pcna/pc.yaml", - "./configs/data/pcna/pc_intensity.yaml", - "./configs/data/pcna/image.yaml", - "./configs/data/pcna/image.yaml", + "/data/pcna/image.yaml", + 
"/data/pcna/image.yaml", + "/data/pcna/pc.yaml", + "/data/pcna/pc_intensity.yaml", # "./src/br/configs/data/pcna/pc_intensity_jitter.yaml", - "./configs/data/pcna/pc_intensity.yaml", + "/data/pcna/pc_intensity.yaml", ] +classification_label: ["cell_stage_fine", "flag_comment"] +regression_label: From 6ec5819534b14184e13da3d211688c629e29437a Mon Sep 17 00:00:00 2001 From: Fatwir Mohammed Date: Mon, 25 Nov 2024 16:49:46 -0800 Subject: [PATCH 29/35] Modified the compute_evolve_dataloaders function to make this work for the cellpack dataset --- src/br/analysis/analysis_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 89468ec..2b1f5c2 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -196,7 +196,7 @@ def _setup_evolve_params(run_names, data_config_list, keys): """Set up dataloader parameters specific to the evolution energy metric.""" eval_meshed_img = [False] * len(run_names) eval_meshed_img_model_type = [None] * len(run_names) - compute_evolve_dataloaders = True + compute_evolve_dataloaders = dataset_name != "cellpack" if "SDF" in "\t".join(run_names): eval_meshed_img = [True] * len(run_names) eval_meshed_img_model_type = [] From dbb20b480026918b08d521a486b455d6451d2969 Mon Sep 17 00:00:00 2001 From: Fatwir Mohammed Date: Mon, 25 Nov 2024 16:56:41 -0800 Subject: [PATCH 30/35] Revert "Modified the compute_evolve_dataloaders function to make this work for the cellpack dataset" This reverts commit 6ec5819534b14184e13da3d211688c629e29437a. --- src/br/analysis/analysis_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 2b1f5c2..89468ec 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -196,7 +196,7 @@ def _setup_evolve_params(run_names, data_config_list, keys): """Set up dataloader parameters specific to the evolution energy metric.""" eval_meshed_img = [False] * len(run_names) eval_meshed_img_model_type = [None] * len(run_names) - compute_evolve_dataloaders = dataset_name != "cellpack" + compute_evolve_dataloaders = True if "SDF" in "\t".join(run_names): eval_meshed_img = [True] * len(run_names) eval_meshed_img_model_type = [] From d08ac3539c30fd41b5e50dbc8c16d515f6dcd364 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Tue, 26 Nov 2024 11:05:10 -0800 Subject: [PATCH 31/35] remove leading underscores --- src/br/analysis/analysis_utils.py | 24 +-- src/br/analysis/prereq.py | 344 ------------------------------ src/br/analysis/run_analysis.py | 28 +-- src/br/analysis/run_embeddings.py | 6 +- src/br/analysis/run_features.py | 12 +- 5 files changed, 35 insertions(+), 379 deletions(-) delete mode 100644 src/br/analysis/prereq.py diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 89468ec..adf7826 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -123,7 +123,7 @@ def config_gpu(): return selected_gpu_id_or_uuid -def _setup_gpu(): +def setup_gpu(): # Free up cache gc.collect() torch.cuda.empty_cache() @@ -141,7 +141,7 @@ def _setup_gpu(): print("No suitable GPU or MIG ID found. Exiting...") -def _setup_evaluation_params(manifest, run_names): +def setup_evaluation_params(manifest, run_names): """Return evaluation params related to. 1. 
loss_eval_list - which loss to use for each model (Defaults to Chamfer loss) @@ -192,7 +192,7 @@ def _setup_evaluation_params(manifest, run_names): return eval_scaled_img, eval_scaled_img_params, loss_eval_list, sample_points_list, skew_scale -def _setup_evolve_params(run_names, data_config_list, keys): +def setup_evolve_params(run_names, data_config_list, keys): """Set up dataloader parameters specific to the evolution energy metric.""" eval_meshed_img = [False] * len(run_names) eval_meshed_img_model_type = [None] * len(run_names) @@ -224,7 +224,7 @@ def _setup_evolve_params(run_names, data_config_list, keys): return evolve_params -def _get_feature_params(results_path, dataset_name, manifest, keys, run_names): +def get_feature_params(results_path, dataset_name, manifest, keys, run_names): """ Get parameters associated with calculation of 1. Rot invariance @@ -262,7 +262,7 @@ def _get_feature_params(results_path, dataset_name, manifest, keys, run_names): ) -def _dataset_specific_subsetting(all_ret, dataset_name): +def dataset_specific_subsetting(all_ret, dataset_name): """Subset each dataset for analysis. E.g. For PCNA dataset, only look at interphase. Also specify dataset specific visualization params. @@ -318,7 +318,7 @@ def _dataset_specific_subsetting(all_ret, dataset_name): return all_ret, stratify_key, n_archetypes, viz_params -def _viz_other_punctate(this_save_path, viz_params, stratify_key): +def viz_other_punctate(this_save_path, viz_params, stratify_key): # Norms based on Viana 2023 # norms used for model training model_norms = "./src/br/data/preprocessing/pc_preprocessing/model_structnorms.yaml" @@ -376,7 +376,7 @@ def _viz_other_punctate(this_save_path, viz_params, stratify_key): cmap = plt.get_cmap("YlGnBu") -def _latent_walk_save_recons(this_save_path, stratify_key, viz_params, dataset_name): +def latent_walk_save_recons(this_save_path, stratify_key, viz_params, dataset_name): """Visualize saved latent walks from csvs. this_save_path - folder where csvs are saved @@ -384,7 +384,7 @@ def _latent_walk_save_recons(this_save_path, stratify_key, viz_params, dataset_n viz_params - parameters associated with visualization (e.g. 
xlims, ylims) """ if dataset_name == "other_punctate": - return _viz_other_punctate(this_save_path, viz_params, stratify_key) + return viz_other_punctate(this_save_path, viz_params, stratify_key) items = os.listdir(this_save_path) fnames = [i for i in items if i.split(".")[-1] == "csv"] # get csvs @@ -432,7 +432,7 @@ def _latent_walk_save_recons(this_save_path, stratify_key, viz_params, dataset_n np.save(this_save_path / Path(f"{this_name}.npy"), np_arr) -def _archetypes_save_recons(model, archetypes_df, device, key, viz_params, this_save_path): +def archetypes_save_recons(model, archetypes_df, device, key, viz_params, this_save_path): """Visualize saved archetypes from archetype matrix dataframe.""" all_xhat = [] with torch.no_grad(): @@ -480,7 +480,7 @@ def _archetypes_save_recons(model, archetypes_df, device, key, viz_params, this_ np.save(this_save_path / Path(f"{arch}.npy"), np_arr) -def _pseudo_time_analysis(model, all_ret, save_path, device, key, viz_params, bins=None): +def pseudo_time_analysis(model, all_ret, save_path, device, key, viz_params, bins=None): """Psuedotime analysis for PCNA and NPM1 dataset.""" if not bins: # Pseudotime bins based on npm1 dataset from WTC-11 hIPS single cell image dataset @@ -544,7 +544,7 @@ def _pseudo_time_analysis(model, all_ret, save_path, device, key, viz_params, bi ) -def _latent_walk_polymorphic(stratify_key, all_ret, this_save_path, latent_dim): +def latent_walk_polymorphic(stratify_key, all_ret, this_save_path, latent_dim): lw_dict = {stratify_key: [], "PC": [], "bin": [], "CellId": []} mesh_folder = all_ret["mesh_folder"].iloc[0] # mesh folder for strat in all_ret[stratify_key].unique(): @@ -578,7 +578,7 @@ def _latent_walk_polymorphic(stratify_key, all_ret, this_save_path, latent_dim): lw_dict.to_csv(this_save_path / "latent_walk.csv") -def _archetypes_polymorphic(this_save_path, archetypes_df, all_ret, all_features): +def archetypes_polymorphic(this_save_path, archetypes_df, all_ret, all_features): arch_dict = {"CellId": [], "archetype": []} mesh_folder = all_ret["mesh_folder"].iloc[0] # mesh folder for i in range(len(archetypes_df)): diff --git a/src/br/analysis/prereq.py b/src/br/analysis/prereq.py deleted file mode 100644 index 13d0370..0000000 --- a/src/br/analysis/prereq.py +++ /dev/null @@ -1,344 +0,0 @@ -# Free up cache -import argparse -import gc -import os -import subprocess - -import pandas as pd -import torch - -from br.models.compute_features import compute_features -from br.models.load_models import get_data_and_models -from br.models.save_embeddings import ( - get_pc_loss_chamfer, - save_embeddings, - save_emissions, -) -from br.models.utils import get_all_configs_per_dataset - -gc.collect() -torch.cuda.empty_cache() - -# Based on the utilization, set the GPU ID - - -def get_gpu_info(): - # Run nvidia-smi command and get the output - cmd = [ - "nvidia-smi", - "--query-gpu=index,uuid,name,utilization.gpu", - "--format=csv,noheader,nounits", - ] - result = subprocess.run(cmd, capture_output=True, text=True) - return result.stdout.strip() - - -def check_mig(): - # Check if MIG is enabled - cmd = ["nvidia-smi", "-L"] - result = subprocess.run(cmd, capture_output=True, text=True) - return "MIG" in result.stdout - - -def get_mig_ids(): - # Get the MIG UUIDs - cmd = ["nvidia-smi", "-L"] - result = subprocess.run(cmd, capture_output=True, text=True) - mig_ids = [] - for line in result.stdout.splitlines(): - if "MIG" in line: - mig_id = line.split("(UUID: ")[-1].strip(")") - mig_ids.append(mig_id) - return mig_ids - - -def 
config_gpu(): - selected_gpu_id_or_uuid = "" - is_mig = check_mig() - - gpu_info = get_gpu_info() - lines = gpu_info.splitlines() - - for line in lines: - index, uuid, name, utilization = map(str.strip, line.split(",")) - - # If utilization is [N/A], treat it as less than 10 - if utilization == "[N/A]": - utilization = -1 # Assign a value less than 10 to simulate "idle" - else: - utilization = int(utilization) - - # Check if GPU utilization is under 10% (indicating it's idle) - if utilization < 10: - if is_mig: - mig_ids = get_mig_ids() - if mig_ids: - selected_gpu_id_or_uuid = mig_ids[0] # Select the first MIG ID - break # Exit the loop after finding the first MIG ID - else: - selected_gpu_id_or_uuid = uuid - print(f"Selected UUID is {selected_gpu_id_or_uuid}") - break - return selected_gpu_id_or_uuid - - -selected_gpu_id_or_uuid = config_gpu() - -# Set the CUDA_VISIBLE_DEVICES environment variable using the selected ID -if selected_gpu_id_or_uuid: - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - os.environ["CUDA_VISIBLE_DEVICES"] = selected_gpu_id_or_uuid - print(f"CUDA_VISIBLE_DEVICES set to: {selected_gpu_id_or_uuid}") -else: - print("No suitable GPU or MIG ID found. Exiting...") - -# Set the device -device = "cuda:0" - -# Setting a GPU ID is crucial for the script to work well! - - -def main(args): - # Set working directory and paths - os.chdir(args.src_path) - save_path = args.save_path - results_path = args.results_path - dataset_name = args.dataset_name - batch_size = args.batch_size - debug = args.debug - - # Load data and models - data_list, all_models, run_names, model_sizes, manifest = get_data_and_models( - dataset_name, batch_size, results_path, debug - ) - - # Save model sizes to CSV - sizes_ = pd.DataFrame() - sizes_["model"] = run_names - sizes_["model_size"] = model_sizes - sizes_.to_csv(os.path.join(save_path, "model_sizes.csv")) - - save_embeddings_across_models(args, manifest, data_list, all_models, run_names) - compute_relevant_features() - - -def _setup_evaluation_params(manifest, run_names): - """Return evaluation params related to. - - 1. loss_eval_list - which loss to use for each model (Defaults to Chamfer loss) - 2. skew_scale - Hyperparameter associated with sampling of pointclouds from images - 3. sample_points_list - whether to sample pointclouds for each model - 4. eval_scaled_img - whether to scale the images for evaluation (specific to SDF models) - 5. 
eval_scaled_img_params - parameters like mesh paths, scale factors, pointcloud paths associated - with evaluating scaled images - """ - eval_scaled_img = [False] * len(run_names) - eval_scaled_img_params = [{}] * len(run_names) - - if "SDF" in "\t".join(run_names): - eval_scaled_img_resolution = 32 - gt_mesh_dir = manifest["mesh_folder"].iloc[0] - gt_sampled_pts_dir = manifest["pointcloud_folder"].iloc[0] - gt_scale_factor_dict_path = manifest["scale_factor"].iloc[0] - eval_scaled_img_params = [] - for name_ in run_names: - if "seg" in name_: - model_type = "seg" - elif "SDF" in name_: - model_type = "sdf" - elif "pointcloud" in name_: - model_type = "iae" - eval_scaled_img_params.append( - { - "eval_scaled_img_model_type": model_type, - "eval_scaled_img_resolution": eval_scaled_img_resolution, - "gt_mesh_dir": gt_mesh_dir, - "gt_scale_factor_dict_path": gt_scale_factor_dict_path, - "gt_sampled_pts_dir": gt_sampled_pts_dir, - "mesh_ext": "stl", - } - ) - loss_eval_list = [torch.nn.MSELoss(reduction="none")] * len(run_names) - sample_points_list = [False] * len(run_names) - skew_scale = None - else: - loss_eval_list = None - skew_scale = 100 - sample_points_list = [] - for name_ in run_names: - if "image" in name_: - sample_points_list.append(True) - else: - sample_points_list.append(False) - return eval_scaled_img, eval_scaled_img_params, loss_eval_list, sample_points_list, skew_scale - - -def save_embeddings_across_models(args, manifest, data_list, all_models, run_names): - """ - Save embeddings across models - """ - # Compute embeddings and reconstructions for each model - splits_list = ["train", "val", "test"] - ( - eval_scaled_img, - eval_scaled_img_params, - loss_eval_list, - sample_points_list, - skew_scale, - ) = _setup_evaluation_params(manifest, run_names) - - save_embeddings( - args.save_path, - data_list, - all_models, - run_names, - args.debug, - splits_list, - device, - args.meta_key, - loss_eval_list, - sample_points_list, - skew_scale, - eval_scaled_img, - eval_scaled_img_params, - ) - - -def compute_relevant_features(): - - batch_size = 1 - data_list, all_models, run_names, model_sizes = get_data_and_models( - dataset_name, batch_size, results_path, debug - ) - - # Save emission stats for each model - max_batches = 40 - save_emissions( - save_path, - data_list, - all_models, - run_names, - max_batches, - debug, - device, - loss_eval_list, - sample_points_list, - skew_scale, - eval_scaled_img, - eval_scaled_img_params, - ) - - # Compute multi-metric benchmarking features - keys = ["pcloud"] * 5 - max_embed_dim = 256 - DATA_LIST = get_all_configs_per_dataset(results_path) - data_config_list = DATA_LIST[dataset_name]["data_paths"] - - evolve_params = { - "modality_list_evolve": keys, - "config_list_evolve": data_config_list, - "num_evolve_samples": 40, - "compute_evolve_dataloaders": False, - "eval_meshed_img": [False] * 5, - "skew_scale": 100, - "eval_meshed_img_model_type": [None] * 5, - "only_embedding": False, - "fit_pca": False, - } - - loss_eval = get_pc_loss_chamfer() - loss_eval_list = [loss_eval] * 5 - # use_sample_points_list = [True, True, False, False, False] # This again is different . 
RITVIK - use_sample_points_list = [False, False, True, True, False] - - classification_params = {"class_labels": ["rule"]} - rot_inv_params = {"squeeze_2d": False, "id": "cell_id", "max_batches": 4000} - regression_params = {"df_feat": None, "target_cols": None, "feature_df_path": None} - compactness_params = { - "method": "mle", - "num_PCs": None, - "blobby_outlier_max_cc": None, - "check_duplicates": True, - } - - splits_list = ["train", "val", "test"] - compute_embeds = False - - metric_list = [ - "Rotation Invariance Error", - "Evolution Energy", - "Reconstruction", - "Classification", - "Compactness", - ] # Different again - - compute_features( - dataset=dataset_name, - results_path=results_path, - embeddings_path=save_path, - save_folder=save_path, - data_list=data_list, - all_models=all_models, - run_names=run_names, - use_sample_points_list=use_sample_points_list, - keys=keys, - device=device, - max_embed_dim=max_embed_dim, - splits_list=splits_list, - compute_embeds=compute_embeds, - classification_params=classification_params, - regression_params=regression_params, - metric_list=metric_list, - loss_eval_list=loss_eval_list, - evolve_params=evolve_params, - rot_inv_params=rot_inv_params, - compactness_params=compactness_params, - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Script for Benchmarking Representations") - parser.add_argument( - "--src_path", type=str, required=True, help="Path to the source directory." - ) - parser.add_argument( - "--save_path", type=str, required=True, help="Path to save the embeddings." - ) - parser.add_argument( - "--results_path", type=str, required=True, help="Path to the results directory." - ) - parser.add_argument( - "--meta_key", - type=str, - required=True, - help="Metadata to add to the embeddings aside from CellId", - ) - parser.add_argument( - "--sdf", - type=bool, - required=True, - help="boolean indicating whether the experiments involve SDFs", - ) - parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") - parser.add_argument("--batch_size", type=int, default=2, help="Batch size for processing.") - parser.add_argument("--debug", type=bool, default=True, help="Enable debug mode.") - - args = parser.parse_args() - - # Validate that required paths are provided - if not args.src_path or not args.save_path or not args.results_path or not args.dataset_name: - print("Error: Required arguments are missing.") - sys.exit(1) - - main(args) - -""" -Example -os.chdir(r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/src/") -save_path = r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/src/test_cellpack_save_embeddings/" -results_path = r"/allen/aics/assay-dev/users/Fatwir/benchmarking_representations/configs/results/" -dataset_name = "cellpack" -batch_size = 2 -debug = True - -""" diff --git a/src/br/analysis/run_analysis.py b/src/br/analysis/run_analysis.py index 9a4887a..c018690 100644 --- a/src/br/analysis/run_analysis.py +++ b/src/br/analysis/run_analysis.py @@ -6,13 +6,13 @@ import pandas as pd from br.analysis.analysis_utils import ( - _archetypes_polymorphic, - _archetypes_save_recons, - _dataset_specific_subsetting, - _latent_walk_polymorphic, - _latent_walk_save_recons, - _pseudo_time_analysis, - _setup_gpu, + archetypes_polymorphic, + archetypes_save_recons, + dataset_specific_subsetting, + latent_walk_polymorphic, + latent_walk_save_recons, + pseudo_time_analysis, + setup_gpu, str2bool, ) from br.features.archetype import AA_Fast @@ 
-23,7 +23,7 @@ def main(args): - _setup_gpu() + setup_gpu() device = "cuda:0" config_path = os.environ.get("CYTODL_CONFIG_PATH") @@ -38,7 +38,7 @@ def main(args): all_ret, df = get_embeddings([run_name], args.dataset_name, DATASET_INFO, args.embeddings_path) model, x_label, latent_dim, model_size = _load_model_from_path(checkpoints[0], False, device) - all_ret, stratify_key, n_archetypes, viz_params = _dataset_specific_subsetting( + all_ret, stratify_key, n_archetypes, viz_params = dataset_specific_subsetting( all_ret, args.dataset_name ) @@ -48,7 +48,7 @@ def main(args): this_save_path.mkdir(parents=True, exist_ok=True) if args.sdf: - _latent_walk_polymorphic(stratify_key, all_ret, this_save_path, latent_dim) + latent_walk_polymorphic(stratify_key, all_ret, this_save_path, latent_dim) else: stratified_latent_walk( model, @@ -66,7 +66,7 @@ def main(args): ) # Save reconstruction plots - _latent_walk_save_recons(this_save_path, stratify_key, viz_params, args.dataset_name) + latent_walk_save_recons(this_save_path, stratify_key, viz_params, args.dataset_name) # Archetype analysis matrix = all_ret[[i for i in all_ret.columns if "mu" in i]].values @@ -77,13 +77,13 @@ def main(args): this_save_path.mkdir(parents=True, exist_ok=True) if args.sdf: - _archetypes_polymorphic(this_save_path, archetypes_df, all_ret, matrix) + archetypes_polymorphic(this_save_path, archetypes_df, all_ret, matrix) else: - _archetypes_save_recons(model, archetypes_df, device, key, viz_params, this_save_path) + archetypes_save_recons(model, archetypes_df, device, key, viz_params, this_save_path) # Pseudotime analysis if "volume_of_nucleus_um3" in all_ret.columns: - _pseudo_time_analysis(model, all_ret, args.save_path, device, key, viz_params) + pseudo_time_analysis(model, all_ret, args.save_path, device, key, viz_params) if __name__ == "__main__": diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py index 09d6571..f7e38c5 100644 --- a/src/br/analysis/run_embeddings.py +++ b/src/br/analysis/run_embeddings.py @@ -3,7 +3,7 @@ import os import sys -from br.analysis.analysis_utils import _setup_evaluation_params, _setup_gpu, str2bool +from br.analysis.analysis_utils import setup_evaluation_params, setup_gpu, str2bool from br.models.load_models import get_data_and_models from br.models.save_embeddings import save_embeddings @@ -11,7 +11,7 @@ def main(args): # Setup GPUs and set the device - _setup_gpu() + setup_gpu() device = "cuda:0" # Get config path from CYTODL_CONFIG_PATH @@ -29,7 +29,7 @@ def main(args): loss_eval_list, sample_points_list, skew_scale, - ) = _setup_evaluation_params(manifest, run_names) + ) = setup_evaluation_params(manifest, run_names) # save embeddings for each model save_embeddings( diff --git a/src/br/analysis/run_features.py b/src/br/analysis/run_features.py index 77b5585..b82a634 100644 --- a/src/br/analysis/run_features.py +++ b/src/br/analysis/run_features.py @@ -6,9 +6,9 @@ import pandas as pd from br.analysis.analysis_utils import ( - _get_feature_params, - _setup_evaluation_params, - _setup_gpu, + get_feature_params, + setup_evaluation_params, + setup_gpu, str2bool, ) from br.features.plot import collect_outputs, plot @@ -19,7 +19,7 @@ def main(args): # Setup GPUs and set the device - _setup_gpu() + setup_gpu() device = "cuda:0" # set batch size to 1 for emission stats/features @@ -53,7 +53,7 @@ def main(args): loss_eval_list, sample_points_list, skew_scale, - ) = _setup_evaluation_params(manifest, run_names) + ) = setup_evaluation_params(manifest, run_names) # Save 
emission stats for each model max_batches = 40 @@ -79,7 +79,7 @@ def main(args): classification_params, evolve_params, regression_params, - ) = _get_feature_params( + ) = get_feature_params( config_path + "/results/", args.dataset_name, manifest, keys, run_names ) From 34bc7b9f7255a6357caac306c3391d1bc9414978 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Tue, 26 Nov 2024 12:03:40 -0800 Subject: [PATCH 32/35] create dir if doesnt exist + fix cellpack evolve bug --- src/br/analysis/analysis_utils.py | 2 +- src/br/analysis/run_embeddings.py | 5 ++++- src/br/analysis/run_features.py | 9 +++++++-- src/br/features/evolve.py | 32 +++++-------------------------- 4 files changed, 17 insertions(+), 31 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index adf7826..abb2b8c 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -239,7 +239,7 @@ def get_feature_params(results_path, dataset_name, manifest, keys, run_names): data_config_list = [cytodl_config_path + i for i in data_config_list] class_label = DATA_LIST[dataset_name]["classification_label"] regression_label = DATA_LIST[dataset_name]["regression_label"] - evolve_params = _setup_evolve_params(run_names, data_config_list, keys) + evolve_params = setup_evolve_params(run_names, data_config_list, keys) classification_params = {"class_labels": class_label, "df_feat": manifest} rot_inv_params = {"squeeze_2d": False, "id": "cell_id", "max_batches": 4000} regression_params = { diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py index f7e38c5..8bafecd 100644 --- a/src/br/analysis/run_embeddings.py +++ b/src/br/analysis/run_embeddings.py @@ -2,7 +2,7 @@ import argparse import os import sys - +from pathlib import Path from br.analysis.analysis_utils import setup_evaluation_params, setup_gpu, str2bool from br.models.load_models import get_data_and_models from br.models.save_embeddings import save_embeddings @@ -31,6 +31,9 @@ def main(args): skew_scale, ) = setup_evaluation_params(manifest, run_names) + # make save path directory + Path(args.save_path).mkdir(parents=True, exist_ok=True) + # save embeddings for each model save_embeddings( args.save_path, diff --git a/src/br/analysis/run_features.py b/src/br/analysis/run_features.py index b82a634..faa4c1e 100644 --- a/src/br/analysis/run_features.py +++ b/src/br/analysis/run_features.py @@ -2,9 +2,8 @@ import argparse import os import sys - +from pathlib import Path import pandas as pd - from br.analysis.analysis_utils import ( get_feature_params, setup_evaluation_params, @@ -40,6 +39,9 @@ def main(args): ) = get_data_and_models(args.dataset_name, batch_size, config_path + "/results/", args.debug) max_embed_dim = min(latent_dims) + # make save path directory + Path(args.save_path).mkdir(parents=True, exist_ok=True) + # Save model sizes to CSV sizes_ = pd.DataFrame() sizes_["model"] = run_names @@ -93,6 +95,7 @@ def main(args): if regression_params["target_cols"]: metric_list.append("Regression") + # Compute multi-metric benchmarking features compute_features( dataset=args.dataset_name, @@ -166,4 +169,6 @@ def main(args): """ Example run: python src/br/analysis/run_features.py --save_path "./outputs/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/pcna" --sdf False --dataset_name "pcna" + + python src/br/analysis/run_features.py --save_path "/outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" 
--sdf False --dataset_name "cellpack" --debug False """ diff --git a/src/br/features/evolve.py b/src/br/features/evolve.py index 03b07ed..27b3fa5 100644 --- a/src/br/features/evolve.py +++ b/src/br/features/evolve.py @@ -54,6 +54,8 @@ def update_config(config_path, data, configs, save_path, suffix): config = yaml.safe_load(stream) if "pointcloudutils.datamodules.ShapenetDataModule" == config["_target_"]: config["dataset_folder"] = str(save_path / "iae") + elif "pointcloudutils.datamodules.CellPackDataModule" == config["_target_"]: + pass else: config["path"] = str(save_path / f"{suffix}.csv") config["batch_size"] = 2 @@ -64,6 +66,9 @@ def update_config(config_path, data, configs, save_path, suffix): def make_csv(pc_path, image_path, num_samples, save_path, key="CellId", pc_is_iae=False): + if not pc_path and not image_path: + return + if pc_path.split(".")[-1] == "csv": pc_df = pd.read_csv(pc_path) else: @@ -119,33 +124,6 @@ def make_csv(pc_path, image_path, num_samples, save_path, key="CellId", pc_is_ia image_df.to_csv(save_path / "image.csv") -# def get_pc_configs(dataset_name): -# folder = get_config_folders(dataset_name) -# config_list = [ -# f"../data/configs/{folder}/pointcloud_3.yaml", -# f"../data/configs/{folder}/pointcloud_4.yaml", -# ] -# return config_list - - -# def get_config_folders(dataset_name): -# if dataset_name == "cellpainting": -# folder = "inference_cellpainting_configs" -# elif dataset_name == "variance": -# folder = "inference_variance_data_configs" -# elif dataset_name == "pcna": -# folder = "inference_pcna_data_configs" -# return folder - - -# def get_image_configs(dataset_name): -# folder = get_config_folders(dataset_name) -# config_list = [ -# f"../data/configs/{folder}/image_full.yaml", -# ] -# return config_list - - def get_dataloaders(save_path, config_list_evolve, modality_list): data = [] configs = [] From c35a1ba028f59a2ef20fdcd6e0b6cc85cd34bcc5 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Tue, 26 Nov 2024 12:03:56 -0800 Subject: [PATCH 33/35] add image path and pc path for polymorphic data --- configs/results/other_polymorphic.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/results/other_polymorphic.yaml b/configs/results/other_polymorphic.yaml index cccc773..09cfe61 100644 --- a/configs/results/other_polymorphic.yaml +++ b/configs/results/other_polymorphic.yaml @@ -1,6 +1,6 @@ orig_df: ./morphology_appropriate_representation_learning/preprocessed_data/other_polymorphic/manifest.csv -image_path: -pc_path: +image_path: ./morphology_appropriate_representation_learning/preprocessed_data/other_polymorphic/manifest.csv +pc_path: ./morphology_appropriate_representation_learning/preprocessed_data/other_polymorphic/manifest.csv model_checkpoints: [ "./morphology_appropriate_representation_learning/model_checkpoints/other_polymorphic/Rotation_invariant_pointcloud_SDF.ckpt", From 77c2986266c8aa1a36b8816d6be9236b318473d0 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Tue, 26 Nov 2024 12:32:32 -0800 Subject: [PATCH 34/35] remove stray prints --- src/br/analysis/run_features.py | 1 + src/br/features/outlier_compactness.py | 5 ----- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/br/analysis/run_features.py b/src/br/analysis/run_features.py index faa4c1e..a694c3d 100644 --- a/src/br/analysis/run_features.py +++ b/src/br/analysis/run_features.py @@ -126,6 +126,7 @@ def main(args): csvs = [i.split(".")[0] for i in csvs] # Remove non metric related csvs csvs = [i for i in csvs if i not in run_names and i not in 
keys]
+    csvs = [i for i in csvs if i not in ['image', 'pcloud']]
     # classification and regression metrics are unique to each dataset
     unique_metrics = [i for i in csvs if "classification" in i or "regression" in i]
     # Collect dataframe and make plots
diff --git a/src/br/features/outlier_compactness.py b/src/br/features/outlier_compactness.py
index 7092bbd..70d1d55 100644
--- a/src/br/features/outlier_compactness.py
+++ b/src/br/features/outlier_compactness.py
@@ -108,7 +108,6 @@ def compactness(this_mo, num_PCs, max_embed_dim, method):
     max_embed_dim."""
     cols = [i for i in this_mo.columns if "mu" in i]
     this_feats = this_mo[cols].iloc[:, :max_embed_dim].dropna(axis=1).values
-    print(this_feats.shape)
     if method == "pca":
         _, _, val = compute_PCA_expl_var(this_feats, num_PCs)
         val = [val]
@@ -127,7 +126,6 @@ def outlier_detection(this_mo, outlier_label=0, blobby_outlier_max_cc=None):
         )
     else:
         if "flag_comment" in this_mo.columns:
-            print("Outlier column is flag comment")
             this_mo1 = this_mo.loc[
                 this_mo["flag_comment"].isin(
                     ["cell appears dead or dying", "no EGFP fluorescence"]
@@ -142,21 +140,18 @@ def outlier_detection(this_mo, outlier_label=0, blobby_outlier_max_cc=None):
             this_mo2["outlier"] = "No"
             this_mo = pd.concat([this_mo1, this_mo2], axis=0).reset_index(drop=True)
         elif "Anomaly" in this_mo.columns:
-            print("Outlier column is Anamoly")
             this_mo1 = this_mo.loc[~this_mo["Anomaly"].isin(["none"])]
             this_mo1["outlier"] = "Yes"
             this_mo2 = this_mo.loc[this_mo["Anomaly"].isin(["none"])]
             this_mo2["outlier"] = "No"
             this_mo = pd.concat([this_mo1, this_mo2], axis=0).reset_index(drop=True)
         elif "cell_stage" in this_mo.columns:
-            print("Outlier column is cell stage")
             this_mo1 = this_mo.loc[~this_mo["cell_stage"].isin(["M0"])]
             this_mo1["outlier"] = "Yes"
             this_mo2 = this_mo.loc[this_mo["cell_stage"].isin(["M0"])]
             this_mo2["outlier"] = "No"
             this_mo = pd.concat([this_mo1, this_mo2], axis=0).reset_index(drop=True)
         elif "outlier" not in this_mo.columns:
-            print("Outlier column is outlier")
             return 0

     if this_mo["outlier"].isna().any():

From a4dc4e9f4743199d220eff065c654448bd0a68 Mon Sep 17 00:00:00 2001
From: Ritvik Vasan
Date: Tue, 26 Nov 2024 12:32:56 -0800
Subject: [PATCH 35/35] add new scripts to docs

---
 docs/USAGE.md | 50 +++++++++++++++++++++++++++++++++-----------------
 1 file changed, 33 insertions(+), 17 deletions(-)

diff --git a/docs/USAGE.md b/docs/USAGE.md
index e604a1c..1d0a47c 100644
--- a/docs/USAGE.md
+++ b/docs/USAGE.md
@@ -62,10 +62,10 @@
-Training these models can take weeks. We've published our trained models so you don't have to. Skip to the next section if you'd like to just use our models.
+Training these models can take days. We've published our trained models so you don't have to. Skip to the next section if you'd like to just use our models.

 1. Create a single cell manifest (e.g. csv, parquet) for each dataset with a column corresponding to final processed paths, and create a split column corresponding to train/test/validation split.
-2. Update the final single cell dataset path (`SINGLE_CELL_DATASET_PATH`) and the column in the manifest for appropriate input modality (`SDF_COLUMN`/`SEG_COLUMN`/`POINTCLOUD_COLUMN`/`IMAGE_COLUMN`) in each datamodule yaml files. e.g. for PCNA data these yaml files are located here -
+2.
Update the final single cell dataset path (`SINGLE_CELL_DATASET_PATH`) and the column in the manifest for appropriate input modality (`SDF_COLUMN`/`SEG_COLUMN`/`POINTCLOUD_COLUMN`/`IMAGE_COLUMN`) in each [datamodule file](../configs/data/). e.g. for PCNA data these yaml files are located here - ``` └── configs @@ -77,14 +77,16 @@ Training these models can take weeks. We've published our trained models so you          └── pc_intensity_jitter.yaml <- Datamodule for PCNA point clouds with intensity and jitter ``` -3. Train models using cyto_dl. Ensure to run the training scripts from the folder where the repo was cloned (and where all the data was downloaded). Experiment configs for point cloud and image models are located here - +3. Train models using cyto_dl. Ensure to run the training scripts from the folder where the repo was cloned (and where all the data was downloaded). [Experiment configs](../configs/experiment/) for point cloud and image models for the cellpack dataset are located here - ``` └── configs    └── experiment       └── cellpack -          ├── image_equiv.yaml <- Rotation invariant image model experiment -          └── pc_equiv.yaml <- Rotation invariant point cloud model experiment +          ├── image_classical.yaml <- Classical image model experiment +          ├── image_so3.yaml <- Rotation invariant image model experiment +          └── pc_classical.yaml <- Classical point cloud model experiment +          └── pc_so3.yaml <- Rotation invariant point cloud model experiment ``` Here is an example of training a rotation invariant point cloud model @@ -99,6 +101,14 @@ Override parts of the experiment config via command line or manually in the conf python src/br/models/train.py experiment=cellpack/pc_so3 model=pc/classical_earthmovers_sphere ++csv.save_dir=[SAVE_DIR] ``` +4. To compute embeddings from the trained models, update the data paths in the [datamodule files](../configs/data/) and run + +``` +python src/br/analysis/run_embeddings.py --save_path "./outputs/" --sdf False --dataset_name "pcna" --batch_size 5 --debug False +``` + +where dataset_name corresponds to a [result config](../configs/results/). The sdf argument should be set to True for experiments involving SDFs like the [npm1 dataset](../configs/results/npm1.yaml) and [other polymorphic dataset](../configs/experiment/other_polymorphic/). + ## Steps to download pre-trained models and pre-computed embeddings 1. To skip model training, download pre-trained models @@ -110,7 +120,7 @@ python src/br/models/train.py experiment=cellpack/pc_so3 model=pc/classical_eart * [WTC-11 hIPSc single cell image dataset v1 polymorphic structures](https://open.quiltdata.com/b/allencell/tree/aics/morphology_appropriate_representation_learning/model_checkpoints/other_polymorphic/) * [Nucleolar drug perturbation dataset](https://open.quiltdata.com/b/allencell/tree/aics/morphology_appropriate_representation_learning/model_checkpoints/npm1_perturb/) -2. Download pre-computed embeddings +2. To skip computing embeddings, download pre-computed embeddings * [cellPACK synthetic dataset](https://open.quiltdata.com/b/allencell/tree/aics/morphology_appropriate_representation_learning/model_embeddings/cellpack/) * [DNA replication foci dataset](https://open.quiltdata.com/b/allencell/tree/aics/morphology_appropriate_representation_learning/model_embeddings/pcna/) @@ -121,16 +131,22 @@ python src/br/models/train.py experiment=cellpack/pc_so3 model=pc/classical_eart ## Steps to run benchmarking analysis -1. 
Run analysis for each dataset separately via jupyter notebooks
+1. To compute benchmarking features from the embeddings and trained models, run
+
+```
+python src/br/analysis/run_features.py --save_path "/outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" --sdf False --dataset_name "cellpack" --debug False
+```
+where dataset_name corresponds to a [result config](../configs/results/).
+
+2. To run analysis like latent walks and archetype analysis on the embeddings and trained models, run
+
+```
+python src/br/analysis/run_analysis.py --save_path "./outputs_cellpack/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/cellpack" --dataset_name "cellpack" --run_name "Rotation_invariant_pointcloud_jitter" --sdf False
+```
+
+3. To run drug perturbation analysis, run
+
+```
+python src/br/analysis/run_drugdata_analysis.py --save_path "./outputs_npm1_perturb/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1_perturb/" --dataset_name "npm1_perturb"
 ```
-└── src
- └── br
-    └── notebooks
-       ├── fig2_cellpack.ipynb <- Reproduce Fig 2 cellPACK synthetic data results
-       ├── fig3_pcna.ipynb <- Reproduce Fig 3 PCNA data results
-       ├── fig4_other_punctate.ipynb <- Reproduce Fig 4 other puntate structure data results
-       ├── fig5_npm1.ipynb <- Reproduce Fig 5 npm1 data results
-       ├── fig6_other_polymorphic.ipynb <- Reproduce Fig 6 other polymorphic data results
-       └── fig7_drug_data.ipynb <- Reproduce Fig 7 drug data results
-```
\ No newline at end of file
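
A note on the boolean flags in the commands documented above (`--sdf False`, `--debug False`): with plain argparse, `type=bool` treats any non-empty string, including "False", as truthy, which is presumably why the analysis scripts in these patches import a `str2bool` helper from `br.analysis.analysis_utils`. The repository's exact implementation is not shown here; the sketch below is a typical converter of this kind, included only as an illustration under that assumption.

```
# Illustrative sketch only; the repository's own str2bool may differ.
import argparse


def str2bool(value):
    """Convert strings like 'True'/'false'/'1' into a bool for argparse flags."""
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if value.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


parser = argparse.ArgumentParser()
# type=str2bool makes "--sdf False" parse to False; plain type=bool would not.
parser.add_argument("--sdf", type=str2bool, default=False)
print(parser.parse_args(["--sdf", "False"]).sdf)  # -> False
```

With a converter like this, `--sdf True` and `--sdf False` behave the way the usage examples above expect.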