From 0f16eab38f63e246283ecd774d3627e7d3564061 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Mon, 2 Dec 2024 12:39:55 -0800 Subject: [PATCH 1/9] add supplemental reconstruction generation code for SDFs --- src/br/analysis/analysis_utils.py | 313 ++++++++++++++++++++++- src/br/analysis/run_drugdata_analysis.py | 25 +- src/br/analysis/run_embeddings.py | 3 +- src/br/analysis/run_features.py | 11 +- src/br/analysis/save_reconstructions.py | 103 ++++++++ src/br/chandrasekaran_et_al/utils.py | 15 +- 6 files changed, 438 insertions(+), 32 deletions(-) create mode 100644 src/br/analysis/save_reconstructions.py diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index abb2b8c..7b232d5 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -4,22 +4,27 @@ import subprocess from pathlib import Path +import matplotlib.colors as mcolors import matplotlib.pyplot as plt +import mesh_to_sdf import numpy as np import pandas as pd import pyvista as pv import torch +import trimesh import yaml +from aicsimageio import AICSImage from sklearn.decomposition import PCA from tqdm import tqdm +from br.data.utils import get_iae_reconstruction_3d_grid from br.features.plot import plot_pc_saved, plot_stratified_pc from br.features.reconstruction import save_pcloud from br.features.utils import ( normalize_intensities_and_get_colormap, normalize_intensities_and_get_colormap_apply, ) -from br.models.utils import get_all_configs_per_dataset +from br.models.utils import get_all_configs_per_dataset, move def str2bool(v): @@ -55,7 +60,14 @@ def check_mig(): def get_mig_ids(gpu_uuid): try: # Get the list of GPUs - output = subprocess.check_output(['nvidia-smi','--query-gpu=,index,uuid' ,'--format=csv,noheader']).decode('utf-8').strip().split('\n') + output = ( + subprocess.check_output( + ["nvidia-smi", "--query-gpu=,index,uuid", "--format=csv,noheader"] + ) + .decode("utf-8") + .strip() + .split("\n") + ) # Find the index of the specified GPU UUID gpu_index = -1 @@ -71,7 +83,9 @@ def get_mig_ids(gpu_uuid): # Now we need to get the MIG IDs for this GPU mig_ids = [] # Run nvidia-smi command to get detailed information including MIG IDs - detailed_output = subprocess.check_output(['nvidia-smi', '-L']).decode('utf-8').strip().split('\n') + detailed_output = ( + subprocess.check_output(["nvidia-smi", "-L"]).decode("utf-8").strip().split("\n") + ) # Flag to determine if we are in the right GPU section in_gpu_section = False @@ -82,11 +96,13 @@ def get_mig_ids(gpu_uuid): break # print(line) - + if in_gpu_section: # Check for MIG devices if "MIG" in line: - mig_id = line.split('(')[1].split(')')[0].split(' ')[-1] # Assuming format is '.... MIG (UUID) ...' + mig_id = ( + line.split("(")[1].split(")")[0].split(" ")[-1] + ) # Assuming format is '.... MIG (UUID) ...' mig_ids.append(mig_id.strip()) return mig_ids @@ -105,14 +121,14 @@ def config_gpu(): for line in lines: index, uuid, name, mem_used, mem_total = map(str.strip, line.split(",")) - utilization = float(mem_used)*100/float(mem_total) - + utilization = float(mem_used) * 100 / float(mem_total) + # Check if GPU utilization is under 20% (indicating it's idle) if utilization < 20: # print(uuid, utilization) if is_mig: mig_ids = get_mig_ids(uuid) - + if mig_ids: selected_gpu_id_or_uuid = mig_ids[0] # Select the first MIG ID break # Exit the loop after finding the first MIG ID @@ -594,3 +610,284 @@ def archetypes_polymorphic(this_save_path, archetypes_df, all_ret, all_features) arch_dict["CellId"].append(closest_real_id) arch_dict = pd.DataFrame(arch_dict) arch_dict.to_csv(this_save_path / "archetypes.csv") + + +def generate_reconstructions(all_models, data_list, run_names, keys, test_ids, device, save_path): + with torch.no_grad(): + for j, model in enumerate(all_models): + this_data = data_list[j] + this_run_name = run_names[j] + this_key = keys[j] + for batch in this_data.test_dataloader(): + if not isinstance(batch["cell_id"], list): + if isinstance(batch["cell_id"], torch.Tensor): + cell_id = str(batch["cell_id"].item()) + else: + cell_id = str(batch["cell_id"]) + else: + cell_id = str(batch["cell_id"][0]) + if cell_id in test_ids: + input = batch[this_key].detach().cpu().numpy().squeeze() + if "pointcloud_SDF" in this_run_name: + eval_scaled_img_resolution = 32 + uni_sample_points = get_iae_reconstruction_3d_grid( + bb_min=-0.5, + bb_max=0.5, + resolution=eval_scaled_img_resolution, + padding=0.1, + ) + uni_sample_points = uni_sample_points.unsqueeze(0).repeat( + batch[this_key].shape[0], 1, 1 + ) + batch["points"] = uni_sample_points + xhat, z, z_params = model( + move(batch, device), decode=True, inference=True, return_params=True + ) + recon = xhat[this_key].detach().cpu().numpy().squeeze() + recon = recon.reshape( + eval_scaled_img_resolution, + eval_scaled_img_resolution, + eval_scaled_img_resolution, + ) + else: + z = model.encode(move(batch, device)) + xhat = model.decode(z, return_canonical=True) + recon = xhat[this_key].detach().cpu().numpy().squeeze() + canonical = xhat["canonical"].detach().cpu().numpy().squeeze() + + this_save_path = Path(save_path) / Path(this_run_name) + this_save_path.mkdir(parents=True, exist_ok=True) + np.save(this_save_path / Path(f"{cell_id}.npy"), recon) + + this_save_path_input = Path(save_path) / Path(this_run_name) / Path("input") + this_save_path_input.mkdir(parents=True, exist_ok=True) + np.save(this_save_path_input / Path(f"{cell_id}.npy"), input) + + this_save_path_canon = ( + Path(save_path) / Path(this_run_name) / Path("canonical") + ) + this_save_path_canon.mkdir(parents=True, exist_ok=True) + np.save(this_save_path_canon / Path(f"{cell_id}.npy"), canonical) + + +def save_supplemental_figure_punctate_reconstructions( + df, test_ids, run_names, reconstructions_path +): + def slice_(img, slices=None, z_ind=0): + if not slices: + return img.max(z_ind) + mid_z = int(img.shape[0] / 2) + if z_ind == 0: + img = img[mid_z - slices : mid_z + slices].max(0) + if z_ind == 2: + img = img[:, :, mid_z - slices : mid_z + slices].max(2) + return img + + for i, c in enumerate(test_ids): + row_index = i + recons = [] + for m in run_names: + input_path = reconstructions_path + f"{m}/input/{c}.npy" + input = np.load(input_path).squeeze() + + recon_path = reconstructions_path + f"{m}/{c}.npy" + recon = np.load(recon_path).squeeze() + + recon_path = reconstructions_path + f"{m}/canonical/{c}.npy" + recon_canonical = np.load(recon_path).squeeze() + + num_slice = 8 + z_ind = 0 + + input = slice_(input, num_slice, z_ind) + recon = slice_(recon, num_slice, z_ind) + recon_canonical = slice_(recon_canonical, num_slice, 2) + + i = 2 + fig, (ax, ax1, ax2) = plt.subplots(1, 3, figsize=(8, 4)) + # ax.imshow(this[:, :, :].max(i).T, origin='lower', cmap='gray_r') + # ax1.imshow(this2[:, :, :].max(i).T, origin='lower', cmap='gray_r') + # ax2.imshow(this3[:, :, :].max(i).T, origin='lower', cmap='gray_r') + ax.imshow(input, cmap="gray_r") + ax1.imshow(recon, cmap="gray_r") + ax2.imshow(recon_canonical, cmap="gray_r") + + ax.set_xticks([]) + ax.set_yticks([]) + ax2.set_xticks([]) + ax2.set_yticks([]) + ax1.set_xticks([]) + ax1.set_yticks([]) + # max_size = 192 + max_size = recon_canonical.shape[1] + ax.set_aspect("equal", adjustable="box") + ax1.set_aspect("equal", adjustable="box") + ax2.set_aspect("equal", adjustable="box") + + # max_size = 6 + ax.set_ylim([0, max_size]) + ax1.set_ylim([0, max_size]) + ax2.set_ylim([0, max_size]) + ax.set_xlim([0, max_size]) + ax1.set_xlim([0, max_size]) + ax2.set_xlim([0, max_size]) + + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.spines["left"].set_visible(False) + + ax1.spines["top"].set_visible(False) + ax1.spines["right"].set_visible(False) + ax1.spines["bottom"].set_visible(False) + ax1.spines["left"].set_visible(False) + + ax2.spines["top"].set_visible(False) + ax2.spines["right"].set_visible(False) + ax2.spines["bottom"].set_visible(False) + ax2.spines["left"].set_visible(False) + + ax.set_title("Input") + ax1.set_title("Reconstruction") + ax2.set_title("Canonical Reconstruction") + fig.subplots_adjust(wspace=0, hspace=0) + fig.savefig( + reconstructions_path + f"sample_recons_{c}_{m}.pdf", bbox_inches="tight", dpi=300 + ) + + +def save_supplemental_figure_sdf_reconstructions(df, test_ids, reconstructions_path): + import pyvista as pv + + pv.start_xvfb() + gt_test_sdfs = [] + gt_test_segs = [] + for tid in test_ids: + path = df[df["CellId"] == int(tid)]["sdf_path"].values[0] + sdf = np.load(path) + path = df[df["CellId"] == int(tid)]["seg_path"].values[0] + seg = np.load(path) + gt_test_sdfs.append(sdf) + gt_test_segs.append(seg) + + gt_orig_struc = [] + gt_orig_cell = [] + gt_orig_nuc = [] + for tid in test_ids: + path = df[df["CellId"] == int(tid)]["crop_seg_masked"].values[0] + seg = AICSImage(path).data.squeeze() + path = df[df["CellId"] == int(tid)]["crop_seg"].values[0] + img = AICSImage(path).data.squeeze() + gt_orig_struc.append(seg) + gt_orig_nuc.append(img[0]) + gt_orig_cell.append(img[1]) + + eval_scaled_img_resolution = 32 + mid_slice_ = int(eval_scaled_img_resolution / 2) + uni_sample_points_grid = get_iae_reconstruction_3d_grid( + bb_min=-0.5, bb_max=0.5, resolution=eval_scaled_img_resolution, padding=0.1 + ) + gt_test_sdfs_iae = [] + mesh_folder = df["mesh_folder"].iloc[0] + for tid in test_ids: + path = mesh_folder + str(tid) + ".stl" + mesh = trimesh.load(path) + bbox = mesh.bounding_box.bounds + loc = (bbox[0] + bbox[1]) / 2 + scale_factor = (bbox[1] - bbox[0]).max() + mesh = mesh.apply_translation(-loc) + mesh = mesh.apply_scale(1 / scale_factor) + sdf_vals = mesh_to_sdf.mesh_to_sdf(mesh, query_points=uni_sample_points_grid.numpy()) + gt_test_sdfs_iae.append(sdf_vals) + + cmap_inverted = mcolors.ListedColormap(np.flipud(plt.cm.gray(np.arange(256)))) + + model_order = [ + "Classical_image_seg", + "Rotation_invariant_image_seg", + "Classical_image_SDF", + "Rotation_invariant_image_SDF", + "Rotation_invariant_pointcloud_SDF", + ] + + for split, split_ids, gt_segs, gt_sdfs, gt_test_i_sdfs in [ + ("test", test_ids, gt_test_segs, gt_test_sdfs, gt_test_sdfs_iae) + ]: + + num_rows = len(split_ids) + num_columns = len(model_order) + 4 + fig, axs = plt.subplots(num_rows, num_columns, figsize=(num_columns * 5, num_rows * 5)) + + for i, c in enumerate(split_ids): + gt_seg = gt_segs[i] + gt_sdf = np.clip(gt_sdfs[i], -2, 2) + gt_sdf_i = gt_test_i_sdfs[i].reshape( + eval_scaled_img_resolution, eval_scaled_img_resolution, eval_scaled_img_resolution + ) + row_index = i + recons = [] + for m in model_order: + recon_path = reconstructions_path + f"{m}/{c}.npy" + recon = np.load(recon_path).squeeze() + + if "SDF" in m or "vn" in m: + mid_z = recon.shape[0] // 2 + if ("Rotation" in m) and (m != "Rotation_invariant_pointcloud_SDF"): + z_slice = recon[mid_z, :, :].T + from scipy.ndimage import rotate + + z_slice = rotate(z_slice, angle=-135, cval=2) + z_slice = z_slice[14:-14, 14:-14] + else: + z_slice = recon[:, :, mid_z].T + else: + z_slice = recon.max(0) + recons.append(z_slice) + + struc_seg = gt_orig_struc[i] + cell_seg = gt_orig_cell[i] + nuc_seg = gt_orig_nuc[i] + + axs[row_index, 0].imshow( + cell_seg.max(0), cmap=cmap_inverted, origin="lower", alpha=0.5 + ) + axs[row_index, 0].imshow(nuc_seg.max(0), cmap=cmap_inverted, origin="lower", alpha=0.5) + axs[row_index, 0].imshow( + struc_seg.max(0), cmap=cmap_inverted, origin="lower", alpha=0.5 + ) + axs[row_index, 0].axis("off") + axs[row_index, 0].set_title("") # (f'GT Seg CellId {c}') + + axs[row_index, 1].imshow(gt_seg.max(0), cmap=cmap_inverted, origin="lower") + axs[row_index, 1].axis("off") + axs[row_index, 1].set_title("") # (f'GT Seg CellId {c}') + + for i, img in enumerate(recons[:2]): + axs[row_index, i + 2].imshow(img, cmap=cmap_inverted, origin="lower") + axs[row_index, i + 2].axis("off") + axs[row_index, i + 2].set_title("") # run_to_displ_name[model_order[i]]) + + axs[row_index, 4].imshow( + gt_sdf[:, :, mid_slice_].T, cmap="seismic", origin="lower", vmin=-2, vmax=2 + ) + axs[row_index, 4].axis("off") + axs[row_index, 4].set_title("") # (f'GT SDF CellId {c}') + + for i, img in enumerate(recons[2:4]): + axs[row_index, i + 5].imshow(img, cmap="seismic", origin="lower", vmin=-2, vmax=2) + axs[row_index, i + 5].axis("off") + axs[row_index, i + 5].set_title("") # run_to_displ_name[model_order[i]]) + + axs[row_index, 7].imshow( + gt_sdf_i[:, :, mid_slice_].T.clip(-0.5, 0.5), cmap="seismic", origin="lower" + ) + axs[row_index, 7].axis("off") + axs[row_index, 7].set_title("") # (f'GT SDF CellId {c}') + + axs[row_index, 8].imshow(recons[-1].clip(-0.5, 0.5), cmap="seismic", origin="lower") + axs[row_index, 8].axis("off") + axs[row_index, 8].set_title("") # (f'GT SDF CellId {c}') + + plt.tight_layout() + plt.savefig(reconstructions_path + "sample_recons.png", dpi=300, bbox_inches="tight") + plt.savefig(reconstructions_path + "sample_recons.pdf", dpi=300, bbox_inches="tight") diff --git a/src/br/analysis/run_drugdata_analysis.py b/src/br/analysis/run_drugdata_analysis.py index d403d06..fbc18a1 100644 --- a/src/br/analysis/run_drugdata_analysis.py +++ b/src/br/analysis/run_drugdata_analysis.py @@ -1,19 +1,20 @@ +import argparse import os +import sys from pathlib import Path + +from br.chandrasekaran_et_al.utils import _plot, perturbation_detection from br.models.compute_features import get_embeddings from br.models.utils import get_all_configs_per_dataset -from br.chandrasekaran_et_al.utils import perturbation_detection, _plot -import sys -import argparse def _get_featurecols(df): - """returna list of featuredata columns""" + """returna list of featuredata columns.""" return [c for c in df.columns if "mu" in c] def _get_featuredata(df): - """return dataframe of just featuredata columns""" + """return dataframe of just featuredata columns.""" return df[_get_featurecols(df)] @@ -25,11 +26,9 @@ def main(args): dataset_name = args.dataset_name DATASET_INFO = get_all_configs_per_dataset(results_path) dataset = DATASET_INFO[dataset_name] - run_names = dataset['names'] + run_names = dataset["names"] - all_ret, df = get_embeddings( - run_names, args.dataset_name, DATASET_INFO, args.embeddings_path - ) + all_ret, df = get_embeddings(run_names, args.dataset_name, DATASET_INFO, args.embeddings_path) all_ret["well_position"] = "A0" # dummy all_ret["Assay_Plate_Barcode"] = "Plate0" # dummy @@ -41,10 +40,10 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Script for computing perturbation detection metrics") - parser.add_argument( - "--save_path", type=str, required=True, help="Path to save the results." + parser = argparse.ArgumentParser( + description="Script for computing perturbation detection metrics" ) + parser.add_argument("--save_path", type=str, required=True, help="Path to save the results.") parser.add_argument( "--embeddings_path", type=str, required=True, help="Path to the saved embeddings." ) @@ -63,4 +62,4 @@ def main(args): cellpack dataset python src/br/analysis/run_drugdata_analysis.py --save_path "./outputs_npm1_perturb/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1_perturb/" --dataset_name "npm1_perturb" - """ \ No newline at end of file + """ diff --git a/src/br/analysis/run_embeddings.py b/src/br/analysis/run_embeddings.py index 8bafecd..673ab6b 100644 --- a/src/br/analysis/run_embeddings.py +++ b/src/br/analysis/run_embeddings.py @@ -3,6 +3,7 @@ import os import sys from pathlib import Path + from br.analysis.analysis_utils import setup_evaluation_params, setup_gpu, str2bool from br.models.load_models import get_data_and_models from br.models.save_embeddings import save_embeddings @@ -31,7 +32,7 @@ def main(args): skew_scale, ) = setup_evaluation_params(manifest, run_names) - # make save path directory + # make save path directory Path(args.save_path).mkdir(parents=True, exist_ok=True) # save embeddings for each model diff --git a/src/br/analysis/run_features.py b/src/br/analysis/run_features.py index a694c3d..144b993 100644 --- a/src/br/analysis/run_features.py +++ b/src/br/analysis/run_features.py @@ -3,7 +3,9 @@ import os import sys from pathlib import Path + import pandas as pd + from br.analysis.analysis_utils import ( get_feature_params, setup_evaluation_params, @@ -39,7 +41,7 @@ def main(args): ) = get_data_and_models(args.dataset_name, batch_size, config_path + "/results/", args.debug) max_embed_dim = min(latent_dims) - # make save path directory + # make save path directory Path(args.save_path).mkdir(parents=True, exist_ok=True) # Save model sizes to CSV @@ -81,9 +83,7 @@ def main(args): classification_params, evolve_params, regression_params, - ) = get_feature_params( - config_path + "/results/", args.dataset_name, manifest, keys, run_names - ) + ) = get_feature_params(config_path + "/results/", args.dataset_name, manifest, keys, run_names) metric_list = [ "Rotation Invariance Error", @@ -95,7 +95,6 @@ def main(args): if regression_params["target_cols"]: metric_list.append("Regression") - # Compute multi-metric benchmarking features compute_features( dataset=args.dataset_name, @@ -126,7 +125,7 @@ def main(args): csvs = [i.split(".")[0] for i in csvs] # Remove non metric related csvs csvs = [i for i in csvs if i not in run_names and i not in keys] - csvs = [i for i in csvs if i not in ['image', 'pcloud']] + csvs = [i for i in csvs if i not in ["image", "pcloud"]] # classification and regression metrics are unique to each dataset unique_metrics = [i for i in csvs if "classification" in i or "regression" in i] # Collect dataframe and make plots diff --git a/src/br/analysis/save_reconstructions.py b/src/br/analysis/save_reconstructions.py new file mode 100644 index 0000000..5289d3d --- /dev/null +++ b/src/br/analysis/save_reconstructions.py @@ -0,0 +1,103 @@ +# Free up cache +import argparse +import os +import sys +from pathlib import Path +from br.analysis.analysis_utils import ( + setup_gpu, + str2bool, + generate_reconstructions, + save_supplemental_figure_sdf_reconstructions, + save_supplemental_figure_punctate_reconstructions +) +from br.models.load_models import get_data_and_models + +test_ids_per_dataset_ = {'cellpack': ['9c1ff213-4e9e-4b73-a942-3baf9d37a50f'], + 'pcna': ['7624cd5b-715a-478e-9648-3bac4a73abe8', + '80d40c5e-65bf-43b0-8dea-b697c421ea78', + '6a3ab51f-fa68-4fe1-a13b-2b2461ed71b4', + 'aabbbca4-6c35-4f3d-9467-7d573482f236', + 'd23de56e-bacf-4ec8-8e18-39822fea777b', + 'c382794f-5baf-4b17-8574-62dccbbbaefc', + '50b52c3e-4756-4684-a281-0141525ded9f', + '8713eea5-da72-4644-96fe-ba8340edb67d'], + 'other_punctate': ['721646', '873680', '994027', '490385', '451974', '811336', '835431'], + 'npm1': ['964798', '661110', '644401', '967887', '703621'], + 'other_polymorphic': ['691110', '723687', '816468', '800894'], + } + + +def main(args): + # Setup GPUs and set the device + setup_gpu() + device = "cuda:0" + + # set batch size to 1 for emission stats/features + batch_size = 1 + + # Get config path from CYTODL_CONFIG_PATH + config_path = os.environ.get("CYTODL_CONFIG_PATH") + + test_ids = args.test_ids + if not test_ids: + test_ids = test_ids_per_dataset_[args.dataset_name] + + # Load data and models + ( + data_list, + all_models, + run_names, + model_sizes, + manifest, + keys, + latent_dims, + ) = get_data_and_models(args.dataset_name, batch_size, config_path + "/results/", args.debug) + + # make save path directory + Path(args.save_path).mkdir(parents=True, exist_ok=True) + + if args.generate_reconstructions: + generate_reconstructions(all_models, data_list, run_names, keys, test_ids, device, args.save_path) + + if args.sdf: + save_supplemental_figure_sdf_reconstructions(manifest, test_ids, args.save_path) + else: + save_supplemental_figure_punctate_reconstructions(manifest, test_ids, run_names, args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Script for computing features") + parser.add_argument( + "--save_path", type=str, required=True, help="Path to save the embeddings." + ) + parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") + parser.add_argument("--debug", type=str2bool, default=False, help="Enable debug mode.") + parser.add_argument("--sdf", type=str2bool, default=True, help="Whether the experiments involve SDFs") + parser.add_argument("--test_ids", default=False, nargs='+', help="List of test set cellids to reconstruct") + parser.add_argument( + "--generate_reconstructions", type=str2bool, default=False, help="Whether to skip generating reconstructions" + ) + + + args = parser.parse_args() + + # Validate that required paths are provided + if not args.save_path or not args.dataset_name: + print("Error: Required arguments are missing.") + sys.exit(1) + + main(args) + + """ + Example run: + + PCNA dataset + python src/br/analysis/save_reconstructions.py --save_path "./outputs_pcna/reconstructions/" --dataset_name "pcna" --generate_reconstructions True --sdf False + NPM1 dataset + python src/br/analysis/save_reconstructions.py --save_path "./outputs_npm1/reconstructions/" --dataset_name "npm1" --test_ids 964798 661110 644401 967887 703621 --generate_reconstructions True + + Other polymorphic dataset + python src/br/analysis/save_reconstructions.py --save_path "./outputs_other_polymorphic/reconstructions/" --dataset_name "other_polymorphic" --test_ids 691110 723687 816468 800894 --generate_reconstructions True + + python src/br/analysis/save_reconstructions.py --save_path "./outputs_npm1/reconstructions/" --dataset_name "npm1" --test_ids 964798 661110 644401 967887 703621 --generate_reconstructions False + """ diff --git a/src/br/chandrasekaran_et_al/utils.py b/src/br/chandrasekaran_et_al/utils.py index 07a7e33..12d887d 100644 --- a/src/br/chandrasekaran_et_al/utils.py +++ b/src/br/chandrasekaran_et_al/utils.py @@ -2,10 +2,13 @@ import itertools import os from pathlib import Path -import seaborn as sns + import copairs.compute_np as backend +import matplotlib.pyplot as plt import numpy as np import pandas as pd +import pycytominer +import seaborn as sns from copairs.compute import cosine_indexed from copairs.map import ( aggregate, @@ -18,8 +21,7 @@ from sklearn.metrics import average_precision_score from sklearn.metrics.pairwise import cosine_similarity from tqdm import tqdm -import matplotlib.pyplot as plt -import pycytominer + from br.chandrasekaran_et_al import utils @@ -172,7 +174,12 @@ def _plot(all_rep, save_path): .sort_values(by="q_value")["Metadata_broad_sample"] .values ) - ordered_drugs = all_rep.groupby(['Metadata_broad_sample']).mean().sort_values(by='q_value').reset_index()['Metadata_broad_sample'] + ordered_drugs = ( + all_rep.groupby(["Metadata_broad_sample"]) + .mean() + .sort_values(by="q_value") + .reset_index()["Metadata_broad_sample"] + ) x_order = ordered_drugs g = sns.catplot( From 5af4ad7394532c14a5f4e660dc443f4903679022 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Mon, 2 Dec 2024 12:54:12 -0800 Subject: [PATCH 2/9] run pre-commit --- src/br/analysis/analysis_utils.py | 148 +++++++++++++++--------- src/br/analysis/save_reconstructions.py | 60 ++++++---- 2 files changed, 130 insertions(+), 78 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 7b232d5..e9feafc 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -649,6 +649,16 @@ def generate_reconstructions(all_models, data_list, run_names, keys, test_ids, d eval_scaled_img_resolution, eval_scaled_img_resolution, ) + elif ("pointcloud" in this_run_name) and ("SDF" not in this_run_name): + batch = move(batch, device) + z, z_params = model.get_embeddings(batch, inference=True) + xhat = model.decode_embeddings( + z_params, batch, decode=True, return_canonical=True + ) + recon = xhat[this_key].detach().cpu().numpy().squeeze() + canonical = recon + if "canonical" in xhat.keys(): + canonical = xhat["canonical"].detach().cpu().numpy().squeeze() else: z = model.encode(move(batch, device)) xhat = model.decode(z, return_canonical=True) @@ -683,6 +693,84 @@ def slice_(img, slices=None, z_ind=0): img = img[:, :, mid_z - slices : mid_z + slices].max(2) return img + def _plot_image(input, recon, recon_canonical): + num_slice = 8 + z_ind = 0 + + input = slice_(input, num_slice, z_ind) + recon = slice_(recon, num_slice, z_ind) + recon_canonical = slice_(recon_canonical, num_slice, 2) + + i = 2 + fig, (ax, ax1, ax2) = plt.subplots(1, 3, figsize=(8, 4)) + # ax.imshow(this[:, :, :].max(i).T, origin='lower', cmap='gray_r') + # ax1.imshow(this2[:, :, :].max(i).T, origin='lower', cmap='gray_r') + # ax2.imshow(this3[:, :, :].max(i).T, origin='lower', cmap='gray_r') + ax.imshow(input, cmap="gray_r") + ax1.imshow(recon, cmap="gray_r") + ax2.imshow(recon_canonical, cmap="gray_r") + + ax.set_xticks([]) + ax.set_yticks([]) + ax2.set_xticks([]) + ax2.set_yticks([]) + ax1.set_xticks([]) + ax1.set_yticks([]) + # max_size = 192 + max_size = recon_canonical.shape[1] + ax.set_aspect("equal", adjustable="box") + ax1.set_aspect("equal", adjustable="box") + ax2.set_aspect("equal", adjustable="box") + + # max_size = 6 + ax.set_ylim([0, max_size]) + ax1.set_ylim([0, max_size]) + ax2.set_ylim([0, max_size]) + ax.set_xlim([0, max_size]) + ax1.set_xlim([0, max_size]) + ax2.set_xlim([0, max_size]) + + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.spines["left"].set_visible(False) + + ax1.spines["top"].set_visible(False) + ax1.spines["right"].set_visible(False) + ax1.spines["bottom"].set_visible(False) + ax1.spines["left"].set_visible(False) + + ax2.spines["top"].set_visible(False) + ax2.spines["right"].set_visible(False) + ax2.spines["bottom"].set_visible(False) + ax2.spines["left"].set_visible(False) + + ax.set_title("Input") + ax1.set_title("Reconstruction") + ax2.set_title("Canonical Reconstruction") + fig.subplots_adjust(wspace=0, hspace=0) + return fig + + def _plot_pc(input, recon, recon_canonical): + max_z = 200 + max_size = 10 + z_ind = 1 + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + for this_p in [input, recon, recon_canonical]: + this_p = this_p[np.where(this_p[:, z_ind] < max_z)[0]] + this_p = this_p[np.where(this_p[:, z_ind] > -max_z)[0]] + axes[i].scatter(this_p[:, 2], this_p[:, 1], c="gray", s=1) + axes[i].spines["top"].set_visible(False) + axes[i].spines["right"].set_visible(False) + axes[i].spines["bottom"].set_visible(False) + axes[i].spines["left"].set_visible(False) + axes[i].set_aspect("equal", adjustable="box") + axes[i].set_ylim([-max_size, max_size]) + axes[i].set_xlim([-max_size, max_size]) + axes[i].set_yticks([]) + axes[i].set_xticks([]) + return fig + for i, c in enumerate(test_ids): row_index = i recons = [] @@ -696,61 +784,11 @@ def slice_(img, slices=None, z_ind=0): recon_path = reconstructions_path + f"{m}/canonical/{c}.npy" recon_canonical = np.load(recon_path).squeeze() - num_slice = 8 - z_ind = 0 - - input = slice_(input, num_slice, z_ind) - recon = slice_(recon, num_slice, z_ind) - recon_canonical = slice_(recon_canonical, num_slice, 2) - - i = 2 - fig, (ax, ax1, ax2) = plt.subplots(1, 3, figsize=(8, 4)) - # ax.imshow(this[:, :, :].max(i).T, origin='lower', cmap='gray_r') - # ax1.imshow(this2[:, :, :].max(i).T, origin='lower', cmap='gray_r') - # ax2.imshow(this3[:, :, :].max(i).T, origin='lower', cmap='gray_r') - ax.imshow(input, cmap="gray_r") - ax1.imshow(recon, cmap="gray_r") - ax2.imshow(recon_canonical, cmap="gray_r") - - ax.set_xticks([]) - ax.set_yticks([]) - ax2.set_xticks([]) - ax2.set_yticks([]) - ax1.set_xticks([]) - ax1.set_yticks([]) - # max_size = 192 - max_size = recon_canonical.shape[1] - ax.set_aspect("equal", adjustable="box") - ax1.set_aspect("equal", adjustable="box") - ax2.set_aspect("equal", adjustable="box") - - # max_size = 6 - ax.set_ylim([0, max_size]) - ax1.set_ylim([0, max_size]) - ax2.set_ylim([0, max_size]) - ax.set_xlim([0, max_size]) - ax1.set_xlim([0, max_size]) - ax2.set_xlim([0, max_size]) - - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.spines["bottom"].set_visible(False) - ax.spines["left"].set_visible(False) - - ax1.spines["top"].set_visible(False) - ax1.spines["right"].set_visible(False) - ax1.spines["bottom"].set_visible(False) - ax1.spines["left"].set_visible(False) - - ax2.spines["top"].set_visible(False) - ax2.spines["right"].set_visible(False) - ax2.spines["bottom"].set_visible(False) - ax2.spines["left"].set_visible(False) - - ax.set_title("Input") - ax1.set_title("Reconstruction") - ax2.set_title("Canonical Reconstruction") - fig.subplots_adjust(wspace=0, hspace=0) + if "image" in m: + fig = _plot_image(input, recon, recon_canonical) + else: + fig = _plot_pc(input, recon, recon_canonical) + fig.savefig( reconstructions_path + f"sample_recons_{c}_{m}.pdf", bbox_inches="tight", dpi=300 ) diff --git a/src/br/analysis/save_reconstructions.py b/src/br/analysis/save_reconstructions.py index 5289d3d..61762e6 100644 --- a/src/br/analysis/save_reconstructions.py +++ b/src/br/analysis/save_reconstructions.py @@ -3,28 +3,32 @@ import os import sys from pathlib import Path + from br.analysis.analysis_utils import ( - setup_gpu, - str2bool, generate_reconstructions, + save_supplemental_figure_punctate_reconstructions, save_supplemental_figure_sdf_reconstructions, - save_supplemental_figure_punctate_reconstructions + setup_gpu, + str2bool, ) from br.models.load_models import get_data_and_models -test_ids_per_dataset_ = {'cellpack': ['9c1ff213-4e9e-4b73-a942-3baf9d37a50f'], - 'pcna': ['7624cd5b-715a-478e-9648-3bac4a73abe8', - '80d40c5e-65bf-43b0-8dea-b697c421ea78', - '6a3ab51f-fa68-4fe1-a13b-2b2461ed71b4', - 'aabbbca4-6c35-4f3d-9467-7d573482f236', - 'd23de56e-bacf-4ec8-8e18-39822fea777b', - 'c382794f-5baf-4b17-8574-62dccbbbaefc', - '50b52c3e-4756-4684-a281-0141525ded9f', - '8713eea5-da72-4644-96fe-ba8340edb67d'], - 'other_punctate': ['721646', '873680', '994027', '490385', '451974', '811336', '835431'], - 'npm1': ['964798', '661110', '644401', '967887', '703621'], - 'other_polymorphic': ['691110', '723687', '816468', '800894'], - } +test_ids_per_dataset_ = { + "cellpack": ["9c1ff213-4e9e-4b73-a942-3baf9d37a50f"], + "pcna": [ + "7624cd5b-715a-478e-9648-3bac4a73abe8", + "80d40c5e-65bf-43b0-8dea-b697c421ea78", + "6a3ab51f-fa68-4fe1-a13b-2b2461ed71b4", + "aabbbca4-6c35-4f3d-9467-7d573482f236", + "d23de56e-bacf-4ec8-8e18-39822fea777b", + "c382794f-5baf-4b17-8574-62dccbbbaefc", + "50b52c3e-4756-4684-a281-0141525ded9f", + "8713eea5-da72-4644-96fe-ba8340edb67d", + ], + "other_punctate": ["721646", "873680", "994027", "490385", "451974", "811336", "835431"], + "npm1": ["964798", "661110", "644401", "967887", "703621"], + "other_polymorphic": ["691110", "723687", "816468", "800894"], +} def main(args): @@ -53,16 +57,20 @@ def main(args): latent_dims, ) = get_data_and_models(args.dataset_name, batch_size, config_path + "/results/", args.debug) - # make save path directory + # make save path directory Path(args.save_path).mkdir(parents=True, exist_ok=True) if args.generate_reconstructions: - generate_reconstructions(all_models, data_list, run_names, keys, test_ids, device, args.save_path) + generate_reconstructions( + all_models, data_list, run_names, keys, test_ids, device, args.save_path + ) if args.sdf: save_supplemental_figure_sdf_reconstructions(manifest, test_ids, args.save_path) else: - save_supplemental_figure_punctate_reconstructions(manifest, test_ids, run_names, args.save_path) + save_supplemental_figure_punctate_reconstructions( + manifest, test_ids, run_names, args.save_path + ) if __name__ == "__main__": @@ -72,12 +80,18 @@ def main(args): ) parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") parser.add_argument("--debug", type=str2bool, default=False, help="Enable debug mode.") - parser.add_argument("--sdf", type=str2bool, default=True, help="Whether the experiments involve SDFs") - parser.add_argument("--test_ids", default=False, nargs='+', help="List of test set cellids to reconstruct") parser.add_argument( - "--generate_reconstructions", type=str2bool, default=False, help="Whether to skip generating reconstructions" + "--sdf", type=str2bool, default=True, help="Whether the experiments involve SDFs" + ) + parser.add_argument( + "--test_ids", default=False, nargs="+", help="List of test set cellids to reconstruct" + ) + parser.add_argument( + "--generate_reconstructions", + type=str2bool, + default=False, + help="Whether to skip generating reconstructions", ) - args = parser.parse_args() From 8f7ac1ddb3e3b879b186420f26a30ee7fd518254 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Mon, 2 Dec 2024 15:09:15 -0800 Subject: [PATCH 3/9] bugfixes to punctate plotting --- src/br/analysis/analysis_utils.py | 70 +++++++++++++++++++------ src/br/analysis/save_reconstructions.py | 2 +- 2 files changed, 55 insertions(+), 17 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index e9feafc..0df240f 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -3,7 +3,6 @@ import os import subprocess from pathlib import Path - import matplotlib.colors as mcolors import matplotlib.pyplot as plt import mesh_to_sdf @@ -681,7 +680,7 @@ def generate_reconstructions(all_models, data_list, run_names, keys, test_ids, d def save_supplemental_figure_punctate_reconstructions( - df, test_ids, run_names, reconstructions_path + df, test_ids, run_names, reconstructions_path, normalize_across_recons ): def slice_(img, slices=None, z_ind=0): if not slices: @@ -693,6 +692,14 @@ def slice_(img, slices=None, z_ind=0): img = img[:, :, mid_z - slices : mid_z + slices].max(2) return img + + def slice_points_(points, z_max, z_loc=0): + inds = np.where(points[:,z_loc] < z_max)[0] + points = points[inds, :] + inds = np.where(points[:,z_loc] > -z_max)[0] + points = points[inds, :] + return points + def _plot_image(input, recon, recon_canonical): num_slice = 8 z_ind = 0 @@ -751,15 +758,22 @@ def _plot_image(input, recon, recon_canonical): fig.subplots_adjust(wspace=0, hspace=0) return fig - def _plot_pc(input, recon, recon_canonical): - max_z = 200 - max_size = 10 - z_ind = 1 + def _plot_pc(input, recon, recon_canonical, struct, cmap): + z_max = 0.3 + max_size = 15 + z_ind = 2 fig, axes = plt.subplots(1, 3, figsize=(10, 5)) - for this_p in [input, recon, recon_canonical]: - this_p = this_p[np.where(this_p[:, z_ind] < max_z)[0]] - this_p = this_p[np.where(this_p[:, z_ind] > -max_z)[0]] - axes[i].scatter(this_p[:, 2], this_p[:, 1], c="gray", s=1) + for j, this_p in enumerate([input, recon, recon_canonical]): + import ipdb + ipdb.set_trace() + print(this_p.max(axis=0), 'pre', j) + if struct in ['NUP153', 'HIST1H2BJ', 'SMC1A', 'SON']: + this_p = slice_points_(this_p, z_max, z_ind) + print(this_p.max(axis=0), 'post', j) + if this_p.shape[-1] == 3: + axes[i].scatter(this_p[:, 1], this_p[:, 0], c="black", s=2, alpha=0.5) + else: + axes[i].scatter(this_p[:, 1], this_p[:, 0], c=cmap(this_p[:, 3]), s=2, alpha=0.5) axes[i].spines["top"].set_visible(False) axes[i].spines["right"].set_visible(False) axes[i].spines["bottom"].set_visible(False) @@ -771,10 +785,27 @@ def _plot_pc(input, recon, recon_canonical): axes[i].set_xticks([]) return fig - for i, c in enumerate(test_ids): - row_index = i - recons = [] - for m in run_names: + for m in run_names: + cmap = None + if normalize_across_recons: + all_df_input = [] + for c in test_ids: + input_path = reconstructions_path + f"{m}/input/{c}.npy" + input = np.load(input_path).squeeze() + if input.shape[-1] == 4: + this_df = pd.DataFrame(input, columns=['x', 'y', 'z', 's']) + all_df_input.append(this_df) + + all_df_input = pd.concat(all_df_input, axis=0).reset_index(drop=True) + if len(all_df_input) > 0: + all_df_input, cmap = normalize_intensities_and_get_colormap(df=all_df_input, pcts=[5, 95]) + + for i, c in enumerate(test_ids): + struct = 'pcna' + if "structure_name" in df.columns: + struct = df.loc[df['CellId'] == c]['structure_name'].iloc[0] + row_index = i + input_path = reconstructions_path + f"{m}/input/{c}.npy" input = np.load(input_path).squeeze() @@ -787,10 +818,17 @@ def _plot_pc(input, recon, recon_canonical): if "image" in m: fig = _plot_image(input, recon, recon_canonical) else: - fig = _plot_pc(input, recon, recon_canonical) + fig = _plot_pc(input, recon, recon_canonical, struct, cmap) + this_save_path_ = ( + Path(reconstructions_path) / Path(m) + ) + print(this_save_path_) + fig.savefig( + this_save_path_ / Path(f"sample_recons_{c}.pdf"), bbox_inches="tight", dpi=300 + ) fig.savefig( - reconstructions_path + f"sample_recons_{c}_{m}.pdf", bbox_inches="tight", dpi=300 + this_save_path_ / Path(f"sample_recons_{c}.png"), bbox_inches="tight", dpi=300 ) diff --git a/src/br/analysis/save_reconstructions.py b/src/br/analysis/save_reconstructions.py index 61762e6..c0793e1 100644 --- a/src/br/analysis/save_reconstructions.py +++ b/src/br/analysis/save_reconstructions.py @@ -33,7 +33,7 @@ def main(args): # Setup GPUs and set the device - setup_gpu() + # setup_gpu() device = "cuda:0" # set batch size to 1 for emission stats/features From d7472a3ba08f0596ba93203368e98c6bbeda588e Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Mon, 2 Dec 2024 16:28:13 -0800 Subject: [PATCH 4/9] fix bugs with color norm for PC recons --- src/br/analysis/analysis_utils.py | 112 ++++++++++++++---------- src/br/analysis/save_reconstructions.py | 10 ++- 2 files changed, 73 insertions(+), 49 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 0df240f..5ee53b3 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -3,6 +3,7 @@ import os import subprocess from pathlib import Path + import matplotlib.colors as mcolors import matplotlib.pyplot as plt import mesh_to_sdf @@ -692,11 +693,10 @@ def slice_(img, slices=None, z_ind=0): img = img[:, :, mid_z - slices : mid_z + slices].max(2) return img - def slice_points_(points, z_max, z_loc=0): - inds = np.where(points[:,z_loc] < z_max)[0] + inds = np.where(points[:, z_loc] < z_max)[0] points = points[inds, :] - inds = np.where(points[:,z_loc] > -z_max)[0] + inds = np.where(points[:, z_loc] > -z_max)[0] points = points[inds, :] return points @@ -758,77 +758,93 @@ def _plot_image(input, recon, recon_canonical): fig.subplots_adjust(wspace=0, hspace=0) return fig - def _plot_pc(input, recon, recon_canonical, struct, cmap): + def _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax): z_max = 0.3 max_size = 15 z_ind = 2 fig, axes = plt.subplots(1, 3, figsize=(10, 5)) - for j, this_p in enumerate([input, recon, recon_canonical]): - import ipdb - ipdb.set_trace() - print(this_p.max(axis=0), 'pre', j) - if struct in ['NUP153', 'HIST1H2BJ', 'SMC1A', 'SON']: + for index_, this_p in enumerate([input, recon, recon_canonical]): + print(this_p.max(axis=0), "pre", index_) + if struct in ["NUP153", "HIST1H2BJ", "SMC1A", "SON"]: this_p = slice_points_(this_p, z_max, z_ind) - print(this_p.max(axis=0), 'post', j) + print(this_p.max(axis=0), "post", index_) if this_p.shape[-1] == 3: - axes[i].scatter(this_p[:, 1], this_p[:, 0], c="black", s=2, alpha=0.5) + axes[index_].scatter(this_p[:, 1], this_p[:, 0], c="black", s=2, alpha=0.5) else: - axes[i].scatter(this_p[:, 1], this_p[:, 0], c=cmap(this_p[:, 3]), s=2, alpha=0.5) - axes[i].spines["top"].set_visible(False) - axes[i].spines["right"].set_visible(False) - axes[i].spines["bottom"].set_visible(False) - axes[i].spines["left"].set_visible(False) - axes[i].set_aspect("equal", adjustable="box") - axes[i].set_ylim([-max_size, max_size]) - axes[i].set_xlim([-max_size, max_size]) - axes[i].set_yticks([]) - axes[i].set_xticks([]) + if not cmap: + this_df = pd.DataFrame(input, columns=["x", "y", "z", "s"]) + all_df_input, cmap, vmin, vmax = normalize_intensities_and_get_colormap( + df=this_df, pcts=[5, 95] + ) + this_p = pd.DataFrame(this_p, columns=["x", "y", "z", "s"]) + this_p = normalize_intensities_and_get_colormap_apply(this_p, vmin, vmax) + axes[index_].scatter( + this_p["y"].values, + this_p["x"].values, + c=cmap(this_p["inorm"].values), + s=2, + alpha=0.5, + ) + axes[index_].spines["top"].set_visible(False) + axes[index_].spines["right"].set_visible(False) + axes[index_].spines["bottom"].set_visible(False) + axes[index_].spines["left"].set_visible(False) + axes[index_].set_aspect("equal", adjustable="box") + axes[index_].set_ylim([-max_size, max_size]) + axes[index_].set_xlim([-max_size, max_size]) + axes[index_].set_yticks([]) + axes[index_].set_xticks([]) return fig for m in run_names: - cmap = None - if normalize_across_recons: - all_df_input = [] - for c in test_ids: - input_path = reconstructions_path + f"{m}/input/{c}.npy" - input = np.load(input_path).squeeze() - if input.shape[-1] == 4: - this_df = pd.DataFrame(input, columns=['x', 'y', 'z', 's']) - all_df_input.append(this_df) - - all_df_input = pd.concat(all_df_input, axis=0).reset_index(drop=True) - if len(all_df_input) > 0: - all_df_input, cmap = normalize_intensities_and_get_colormap(df=all_df_input, pcts=[5, 95]) - - for i, c in enumerate(test_ids): - struct = 'pcna' + for i, this_id in enumerate(test_ids): + struct = "pcna" if "structure_name" in df.columns: - struct = df.loc[df['CellId'] == c]['structure_name'].iloc[0] + df["CellId"] = df["CellId"].astype(str) + struct = df.loc[df["CellId"] == this_id]["structure_name"].iloc[0] row_index = i - input_path = reconstructions_path + f"{m}/input/{c}.npy" + input_path = reconstructions_path + f"{m}/input/{this_id}.npy" input = np.load(input_path).squeeze() - recon_path = reconstructions_path + f"{m}/{c}.npy" + recon_path = reconstructions_path + f"{m}/{this_id}.npy" recon = np.load(recon_path).squeeze() - recon_path = reconstructions_path + f"{m}/canonical/{c}.npy" + recon_path = reconstructions_path + f"{m}/canonical/{this_id}.npy" recon_canonical = np.load(recon_path).squeeze() if "image" in m: fig = _plot_image(input, recon, recon_canonical) else: - fig = _plot_pc(input, recon, recon_canonical, struct, cmap) + cmap = None + vmin = None + vmax = None + if normalize_across_recons: + all_df_input = [] + for c in test_ids: + input_path_ = reconstructions_path + f"{m}/input/{this_id}.npy" + input_tmp = np.load(input_path_).squeeze() + if input.shape[-1] == 4: + this_df = pd.DataFrame(input_tmp, columns=["x", "y", "z", "s"]) + all_df_input.append(this_df) + if len(all_df_input) > 0: + all_df_input = pd.concat(all_df_input, axis=0).reset_index(drop=True) + _, cmap, vmin, vmax = normalize_intensities_and_get_colormap( + df=all_df_input, pcts=[5, 95] + ) + fig = _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax) - this_save_path_ = ( - Path(reconstructions_path) / Path(m) - ) - print(this_save_path_) + this_save_path_ = Path(reconstructions_path) / Path(m) + print(this_save_path_, this_id) fig.savefig( - this_save_path_ / Path(f"sample_recons_{c}.pdf"), bbox_inches="tight", dpi=300 + this_save_path_ / Path(f"sample_recons_{this_id}.pdf"), + bbox_inches="tight", + dpi=300, ) fig.savefig( - this_save_path_ / Path(f"sample_recons_{c}.png"), bbox_inches="tight", dpi=300 + this_save_path_ / Path(f"sample_recons_{this_id}.png"), + bbox_inches="tight", + dpi=300, ) diff --git a/src/br/analysis/save_reconstructions.py b/src/br/analysis/save_reconstructions.py index c0793e1..92095af 100644 --- a/src/br/analysis/save_reconstructions.py +++ b/src/br/analysis/save_reconstructions.py @@ -30,6 +30,8 @@ "other_polymorphic": ["691110", "723687", "816468", "800894"], } +projection_ = {} + def main(args): # Setup GPUs and set the device @@ -69,7 +71,7 @@ def main(args): save_supplemental_figure_sdf_reconstructions(manifest, test_ids, args.save_path) else: save_supplemental_figure_punctate_reconstructions( - manifest, test_ids, run_names, args.save_path + manifest, test_ids, run_names, args.save_path, args.normalize_across_recons ) @@ -92,6 +94,12 @@ def main(args): default=False, help="Whether to skip generating reconstructions", ) + parser.add_argument( + "--normalize_across_recons", + type=str2bool, + default=False, + help="Whether to normalize across all inputs", + ) args = parser.parse_args() From 668501d2a3627817c4e97281c05239480ccd9226 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Mon, 2 Dec 2024 16:32:41 -0800 Subject: [PATCH 5/9] remove print --- src/br/analysis/analysis_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 5ee53b3..3328cd5 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -767,7 +767,6 @@ def _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax): print(this_p.max(axis=0), "pre", index_) if struct in ["NUP153", "HIST1H2BJ", "SMC1A", "SON"]: this_p = slice_points_(this_p, z_max, z_ind) - print(this_p.max(axis=0), "post", index_) if this_p.shape[-1] == 3: axes[index_].scatter(this_p[:, 1], this_p[:, 0], c="black", s=2, alpha=0.5) else: From 84ea6ac469fa8fda35d04e97b25ae0e889a37da9 Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Mon, 2 Dec 2024 16:39:38 -0800 Subject: [PATCH 6/9] add run docs --- src/br/analysis/save_reconstructions.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/br/analysis/save_reconstructions.py b/src/br/analysis/save_reconstructions.py index 92095af..d36790b 100644 --- a/src/br/analysis/save_reconstructions.py +++ b/src/br/analysis/save_reconstructions.py @@ -114,12 +114,14 @@ def main(args): Example run: PCNA dataset - python src/br/analysis/save_reconstructions.py --save_path "./outputs_pcna/reconstructions/" --dataset_name "pcna" --generate_reconstructions True --sdf False + python src/br/analysis/save_reconstructions.py --save_path "./outputs_pcna/reconstructions/" --dataset_name "pcna" --generate_reconstructions True --sdf False --normalize_across_recons True + NPM1 dataset - python src/br/analysis/save_reconstructions.py --save_path "./outputs_npm1/reconstructions/" --dataset_name "npm1" --test_ids 964798 661110 644401 967887 703621 --generate_reconstructions True + python src/br/analysis/save_reconstructions.py --save_path "./outputs_npm1/reconstructions/" --dataset_name "npm1" --test_ids 964798 661110 644401 967887 703621 --generate_reconstructions True --sdf True Other polymorphic dataset - python src/br/analysis/save_reconstructions.py --save_path "./outputs_other_polymorphic/reconstructions/" --dataset_name "other_polymorphic" --test_ids 691110 723687 816468 800894 --generate_reconstructions True + python src/br/analysis/save_reconstructions.py --save_path "./outputs_other_polymorphic/reconstructions/" --dataset_name "other_polymorphic" --test_ids 691110 723687 816468 800894 --generate_reconstructions True --sdf True - python src/br/analysis/save_reconstructions.py --save_path "./outputs_npm1/reconstructions/" --dataset_name "npm1" --test_ids 964798 661110 644401 967887 703621 --generate_reconstructions False + Other punctate dataset + python src/br/analysis/save_reconstructions.py --save_path "./outputs_other_punctate/reconstructions/" --dataset_name "other_punctate" --generate_reconstructions True --normalize_across_recons False --sdf False """ From a8c0969be3a51845b684d1fb8d6c0f95f1a1d4bf Mon Sep 17 00:00:00 2001 From: Ritvik Vasan Date: Mon, 2 Dec 2024 22:28:28 -0800 Subject: [PATCH 7/9] add cellpack reconstruction code + pre-commit --- README.md | 6 +- src/br/analysis/analysis_utils.py | 70 ++++++++++++++------- src/br/analysis/save_reconstructions.py | 28 +++++++-- src/br/data/get_datamodules.py | 3 + src/pointcloudutils/datamodules/cellpack.py | 10 ++- subpackages/image_preprocessing/README.md | 1 + 6 files changed, 87 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 6c2e827..5095a1e 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,12 @@ # Benchmarking Representations + Code for training and benchmarking morphology appropriate representation learning methods, associated with the following manuscript. > **Interpretable representation learning for 3D multi-piece intracellular structures using point clouds** -> +> > Ritvik Vasan, Alexandra J. Ferrante, Antoine Borensztejn, Christopher L. Frick, Nathalie Gaudreault, Saurabh S. Mogre, Benjamin Morris, Guilherme G. Pires, Susanne M. Rafelski, Julie A. Theriot, Matheus P. Viana > -> bioRxiv 2024.07.25.605164; doi: https://doi.org/10.1101/2024.07.25.605164 +> bioRxiv 2024.07.25.605164; doi: https://doi.org/10.1101/2024.07.25.605164 Our analysis is organized as follows. @@ -31,4 +32,5 @@ If you'd like to reproduce this analysis on our data, check out the following do Coming soon # Contact + Allen Institute for Cell Science (cells@alleninstitute.org) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 3328cd5..c64410a 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -95,8 +95,6 @@ def get_mig_ids(gpu_uuid): elif "GPU" in line and in_gpu_section: # Encounter another GPU section break - # print(line) - if in_gpu_section: # Check for MIG devices if "MIG" in line: @@ -125,7 +123,6 @@ def config_gpu(): # Check if GPU utilization is under 20% (indicating it's idle) if utilization < 20: - # print(uuid, utilization) if is_mig: mig_ids = get_mig_ids(uuid) @@ -383,7 +380,6 @@ def viz_other_punctate(this_save_path, viz_params, stratify_key): for pc_bin in df[stratify_key].unique(): this_df = df.loc[df[stratify_key] == pc_bin].reset_index(drop=True) - print(this_df.shape, struct, pc_bin) np_arr = this_df[["x", "y", "z"]].values colors = cmap(this_df["inorm"].values)[:, :3] np_arr2 = colors @@ -603,7 +599,6 @@ def archetypes_polymorphic(this_save_path, archetypes_df, all_ret, all_features) dist = np.sum(dist, axis=1) closest_idx = np.argmin(dist) closest_real_id = all_ret.iloc[closest_idx]["CellId"] - print(dist, closest_real_id) mesh = pv.read(mesh_folder + str(closest_real_id) + ".stl") mesh.save(this_save_path / Path(f"{i}.ply")) arch_dict["archetype"].append(i) @@ -681,7 +676,7 @@ def generate_reconstructions(all_models, data_list, run_names, keys, test_ids, d def save_supplemental_figure_punctate_reconstructions( - df, test_ids, run_names, reconstructions_path, normalize_across_recons + df, test_ids, run_names, reconstructions_path, normalize_across_recons, dataset_name ): def slice_(img, slices=None, z_ind=0): if not slices: @@ -700,19 +695,24 @@ def slice_points_(points, z_max, z_loc=0): points = points[inds, :] return points - def _plot_image(input, recon, recon_canonical): + def _plot_image(input, recon, recon_canonical, dataset_name): num_slice = 8 - z_ind = 0 + + if dataset_name != "cellpack": + z_ind = 0 + else: + z_ind = 2 input = slice_(input, num_slice, z_ind) recon = slice_(recon, num_slice, z_ind) recon_canonical = slice_(recon_canonical, num_slice, 2) + if dataset_name == "cellpack": + recon = recon.T + recon_canonical = recon_canonical.T + i = 2 fig, (ax, ax1, ax2) = plt.subplots(1, 3, figsize=(8, 4)) - # ax.imshow(this[:, :, :].max(i).T, origin='lower', cmap='gray_r') - # ax1.imshow(this2[:, :, :].max(i).T, origin='lower', cmap='gray_r') - # ax2.imshow(this3[:, :, :].max(i).T, origin='lower', cmap='gray_r') ax.imshow(input, cmap="gray_r") ax1.imshow(recon, cmap="gray_r") ax2.imshow(recon_canonical, cmap="gray_r") @@ -758,17 +758,32 @@ def _plot_image(input, recon, recon_canonical): fig.subplots_adjust(wspace=0, hspace=0) return fig - def _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax): + def _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax, dataset_name): z_max = 0.3 max_size = 15 - z_ind = 2 + + if dataset_name != "cellpack": + z_ind = 2 + canon_z_ind = 2 + if (struct == "pcna") and ((recon == recon_canonical).all() == False): + canon_z_ind = 1 + else: + z_ind = 0 + canon_z_ind = 1 + xy_inds = [i for i in [0, 1, 2] if i != z_ind] fig, axes = plt.subplots(1, 3, figsize=(10, 5)) for index_, this_p in enumerate([input, recon, recon_canonical]): - print(this_p.max(axis=0), "pre", index_) - if struct in ["NUP153", "HIST1H2BJ", "SMC1A", "SON"]: - this_p = slice_points_(this_p, z_max, z_ind) + if struct in ["NUP153", "HIST1H2BJ", "SMC1A", "SON", "pcna"]: + if index_ == 2: + this_p = slice_points_(this_p, z_max, canon_z_ind) + else: + this_p = slice_points_(this_p, z_max, z_ind) if this_p.shape[-1] == 3: - axes[index_].scatter(this_p[:, 1], this_p[:, 0], c="black", s=2, alpha=0.5) + if (index_ == 2) and (canon_z_ind != z_ind): + xy_inds = [i for i in [0, 1, 2] if i != canon_z_ind] + axes[index_].scatter( + this_p[:, xy_inds[0]], this_p[:, xy_inds[1]], c="black", s=2, alpha=0.5 + ) else: if not cmap: this_df = pd.DataFrame(input, columns=["x", "y", "z", "s"]) @@ -777,9 +792,16 @@ def _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax): ) this_p = pd.DataFrame(this_p, columns=["x", "y", "z", "s"]) this_p = normalize_intensities_and_get_colormap_apply(this_p, vmin, vmax) + + if (index_ == 2) and (canon_z_ind != z_ind): + xy_inds = [i for i in [0, 1, 2] if i != canon_z_ind] + + x_vals = this_p.iloc[:, xy_inds[0]].values + y_vals = this_p.iloc[:, xy_inds[1]].values + axes[index_].scatter( - this_p["y"].values, - this_p["x"].values, + x_vals, + y_vals, c=cmap(this_p["inorm"].values), s=2, alpha=0.5, @@ -793,6 +815,10 @@ def _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax): axes[index_].set_xlim([-max_size, max_size]) axes[index_].set_yticks([]) axes[index_].set_xticks([]) + + axes[0].set_title("Input") + axes[1].set_title("Reconstruction") + axes[2].set_title("Canonical Reconstruction") return fig for m in run_names: @@ -813,7 +839,7 @@ def _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax): recon_canonical = np.load(recon_path).squeeze() if "image" in m: - fig = _plot_image(input, recon, recon_canonical) + fig = _plot_image(input, recon, recon_canonical, dataset_name) else: cmap = None vmin = None @@ -831,7 +857,9 @@ def _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax): _, cmap, vmin, vmax = normalize_intensities_and_get_colormap( df=all_df_input, pcts=[5, 95] ) - fig = _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax) + fig = _plot_pc( + input, recon, recon_canonical, struct, cmap, vmin, vmax, dataset_name + ) this_save_path_ = Path(reconstructions_path) / Path(m) print(this_save_path_, this_id) diff --git a/src/br/analysis/save_reconstructions.py b/src/br/analysis/save_reconstructions.py index d36790b..6b1bc8a 100644 --- a/src/br/analysis/save_reconstructions.py +++ b/src/br/analysis/save_reconstructions.py @@ -14,7 +14,14 @@ from br.models.load_models import get_data_and_models test_ids_per_dataset_ = { - "cellpack": ["9c1ff213-4e9e-4b73-a942-3baf9d37a50f"], + "cellpack": [ + "9c1ff213-4e9e-4b73-a942-3baf9d37a50f_0", + "9c1ff213-4e9e-4b73-a942-3baf9d37a50f_1", + "9c1ff213-4e9e-4b73-a942-3baf9d37a50f_2", + "9c1ff213-4e9e-4b73-a942-3baf9d37a50f_3", + "9c1ff213-4e9e-4b73-a942-3baf9d37a50f_4", + "9c1ff213-4e9e-4b73-a942-3baf9d37a50f_5", + ], "pcna": [ "7624cd5b-715a-478e-9648-3bac4a73abe8", "80d40c5e-65bf-43b0-8dea-b697c421ea78", @@ -30,12 +37,10 @@ "other_polymorphic": ["691110", "723687", "816468", "800894"], } -projection_ = {} - def main(args): # Setup GPUs and set the device - # setup_gpu() + setup_gpu() device = "cuda:0" # set batch size to 1 for emission stats/features @@ -48,6 +53,9 @@ def main(args): if not test_ids: test_ids = test_ids_per_dataset_[args.dataset_name] + if args.dataset_name == "cellpack": + args.debug = True + # Load data and models ( data_list, @@ -71,7 +79,12 @@ def main(args): save_supplemental_figure_sdf_reconstructions(manifest, test_ids, args.save_path) else: save_supplemental_figure_punctate_reconstructions( - manifest, test_ids, run_names, args.save_path, args.normalize_across_recons + manifest, + test_ids, + run_names, + args.save_path, + args.normalize_across_recons, + args.dataset_name, ) @@ -113,9 +126,12 @@ def main(args): """ Example run: + cellPACK dataset + python src/br/analysis/save_reconstructions.py --save_path "./outputs_cellpack/reconstructions/" --dataset_name "cellpack" --generate_reconstructions True --sdf False + PCNA dataset python src/br/analysis/save_reconstructions.py --save_path "./outputs_pcna/reconstructions/" --dataset_name "pcna" --generate_reconstructions True --sdf False --normalize_across_recons True - + NPM1 dataset python src/br/analysis/save_reconstructions.py --save_path "./outputs_npm1/reconstructions/" --dataset_name "npm1" --test_ids 964798 661110 644401 967887 703621 --generate_reconstructions True --sdf True diff --git a/src/br/data/get_datamodules.py b/src/br/data/get_datamodules.py index c302348..15c4a43 100644 --- a/src/br/data/get_datamodules.py +++ b/src/br/data/get_datamodules.py @@ -28,5 +28,8 @@ def get_data(dataset_name, batch_size, results_path, debug=False): config["subsample"]["train"] = 4 config["subsample"]["valid"] = 4 config["subsample"]["test"] = 4 + + if config["_target_"] == "pointcloudutils.datamodules.CellPackDataModule": + config["subset_test"] = True data.append(instantiate(config)) return data diff --git a/src/pointcloudutils/datamodules/cellpack.py b/src/pointcloudutils/datamodules/cellpack.py index 6067150..f27e80f 100644 --- a/src/pointcloudutils/datamodules/cellpack.py +++ b/src/pointcloudutils/datamodules/cellpack.py @@ -45,6 +45,7 @@ def __init__( num_rotations: Optional[int] = 3, upsample: Optional[bool] = False, image: Optional[bool] = False, + subset_test: Optional[bool] = False, ): """""" super().__init__() @@ -68,6 +69,7 @@ def __init__( self.norm_feats = norm_feats self.num_rotations = num_rotations self.image = image + self.subset_test = subset_test def _get_dataset(self, split): return CellPackDataset( @@ -89,6 +91,7 @@ def _get_dataset(self, split): self.num_rotations, self.upsample, self.image, + self.subset_test, ) def train_dataloader(self): @@ -150,6 +153,7 @@ def __init__( num_rotations: Optional[int] = 3, upsample: Optional[bool] = False, image: Optional[bool] = False, + subset_test: Optional[bool] = False, ): self.x_label = x_label self.scale = scale @@ -167,6 +171,7 @@ def __init__( self.num_rules = len(self.packing_rules) self.num_rotations = num_rotations self.image = image + self.subset_test = subset_test self.ref_csv = pd.read_csv(ref_path + "manifest.csv") @@ -187,8 +192,9 @@ def __init__( self.ids = _splits[split] - # if split == "test": - # self.ids = ["9c1ff213-4e9e-4b73-a942-3baf9d37a50f"] + if self.subset_test: + if split == "test": + self.ids = ["9c1ff213-4e9e-4b73-a942-3baf9d37a50f"] if norm_feats: x = self.ref_csv[ diff --git a/subpackages/image_preprocessing/README.md b/subpackages/image_preprocessing/README.md index b9d6bc3..b9dba6d 100644 --- a/subpackages/image_preprocessing/README.md +++ b/subpackages/image_preprocessing/README.md @@ -43,6 +43,7 @@ Before running the script, please ensure the following: 1. **Set the `TMPDIR`:** You need to set the `TMPDIR` environment variable, as the Snakefile requires a temporary directory. You can do this by executing: + ```bash export TMPDIR=/path/to/your/tmpdir ``` From ba411bb561d81b9bac1be1862e51103ac59238ca Mon Sep 17 00:00:00 2001 From: Fatwir Mohammed Date: Tue, 3 Dec 2024 09:50:36 -0800 Subject: [PATCH 8/9] Added a block to check for canonical reconstructions --- src/br/analysis/analysis_utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index c64410a..6b976e8 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -614,6 +614,7 @@ def generate_reconstructions(all_models, data_list, run_names, keys, test_ids, d this_run_name = run_names[j] this_key = keys[j] for batch in this_data.test_dataloader(): + canonical = None if not isinstance(batch["cell_id"], list): if isinstance(batch["cell_id"], torch.Tensor): cell_id = str(batch["cell_id"].item()) @@ -667,12 +668,13 @@ def generate_reconstructions(all_models, data_list, run_names, keys, test_ids, d this_save_path_input = Path(save_path) / Path(this_run_name) / Path("input") this_save_path_input.mkdir(parents=True, exist_ok=True) np.save(this_save_path_input / Path(f"{cell_id}.npy"), input) - - this_save_path_canon = ( - Path(save_path) / Path(this_run_name) / Path("canonical") - ) - this_save_path_canon.mkdir(parents=True, exist_ok=True) - np.save(this_save_path_canon / Path(f"{cell_id}.npy"), canonical) + + if canonical: + this_save_path_canon = ( + Path(save_path) / Path(this_run_name) / Path("canonical") + ) + this_save_path_canon.mkdir(parents=True, exist_ok=True) + np.save(this_save_path_canon / Path(f"{cell_id}.npy"), canonical) def save_supplemental_figure_punctate_reconstructions( From 5ef922e2b95b8be01d65e98b5316582f5c91fe0d Mon Sep 17 00:00:00 2001 From: Fatwir Mohammed Date: Tue, 3 Dec 2024 13:32:41 -0800 Subject: [PATCH 9/9] Added a small change to account for datasets without canonical reconstructions --- src/br/analysis/analysis_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 6b976e8..c4bb50a 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -669,7 +669,7 @@ def generate_reconstructions(all_models, data_list, run_names, keys, test_ids, d this_save_path_input.mkdir(parents=True, exist_ok=True) np.save(this_save_path_input / Path(f"{cell_id}.npy"), input) - if canonical: + if canonical is not None: this_save_path_canon = ( Path(save_path) / Path(this_run_name) / Path("canonical") )