diff --git a/docs/PREPROCESSING.md b/docs/PREPROCESSING.md index 4d74e0f..e529c11 100644 --- a/docs/PREPROCESSING.md +++ b/docs/PREPROCESSING.md @@ -37,7 +37,7 @@ Preprocessing is divided into three steps that use two different virtual environ # Punctate structures: Generate pointclouds -Edit the data paths in the following file to match the location of the outputs of the alignment, masking, and registration step, then run it. +Use the preprocessed data manifest generated via the alignment, masking, and registration steps from image as input to the pointcloud generation step ``` src @@ -45,7 +45,9 @@ src └── data    └── preprocessing       └── pc_preprocessing -          └── punctate_cyto.py <- Point cloud sampling from raw images for punctate structures here +          └── pcna.py <- Point cloud sampling from raw images for DNA replication foci dataset here +          └── punctate_nuc.py <- Point cloud sampling from raw images of nuclear structures from the WTC-11 hIPS single cell image dataset here +          └── punctate_cyto.py <- Point cloud sampling from raw images of cytoplasmic structures from the WTC-11 hIPS single cell image dataset here ``` # Polymorphic structures: Generate SDFs diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index 4559023..379455f 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -15,6 +15,7 @@ import trimesh import yaml from aicsimageio import AICSImage +from skimage import measure from sklearn.decomposition import PCA from tqdm import tqdm @@ -1025,3 +1026,75 @@ def save_supplemental_figure_sdf_reconstructions(df, test_ids, reconstructions_p plt.tight_layout() plt.savefig(reconstructions_path + "sample_recons.png", dpi=300, bbox_inches="tight") plt.savefig(reconstructions_path + "sample_recons.pdf", dpi=300, bbox_inches="tight") + + +# utility plot functions +def plot_image(ax_array, struct, nuc, mem, vmin, vmax, num_slices=None, show_nuc_countour=True): + mid_z = int(struct.shape[0] / 2) + + if num_slices is None: + num_slices = mid_z * 2 + z_interp = np.linspace(mid_z - num_slices / 2, mid_z + num_slices / 2, num_slices + 1).astype( + int + ) + if z_interp.max() == struct.shape[0]: + z_interp = z_interp[:-1] + + struct = np.where(mem, struct, 0) + mem = mem[z_interp].max(0) + nuc = nuc[z_interp].max(0) + mem_contours = measure.find_contours(mem, 0.5) + nuc_contours = measure.find_contours(nuc, 0.5) + + for ind, _ in enumerate(ax_array): + this_struct = struct + if ind > 0: + this_struct = np.zeros(struct.shape) + ax_array[ind].imshow(this_struct[z_interp].max(0), cmap="gray_r", vmin=vmin, vmax=vmax) + if ind == 0: + if show_nuc_countour: + for contour in nuc_contours: + ax_array[ind].plot(contour[:, 1], contour[:, 0], linewidth=1, c="cyan") + for contour in mem_contours: + ax_array[ind].plot(contour[:, 1], contour[:, 0], linewidth=1, c="magenta") + ax_array[ind].axis("off") + return ax_array, z_interp + + +def plot_pointcloud( + this_ax_array, + points_all, + z_interp, + cmap, + save_path=None, + name=None, + center=None, + save=False, + center_slice=False, +): + this_p = points_all.loc[points_all["z"] < max(z_interp)] + if center_slice: + this_p = this_p.loc[this_p["z"] > min(z_interp)] + print(this_p.shape) + intensity = this_p.inorm.values + this_ax_array.scatter( + this_p["x"].values, this_p["y"].values, c=cmap(intensity), s=0.3, alpha=0.5 + ) + this_ax_array.axis("off") + if save: + z_center, y_center, x_center = center[0], center[1], center[2] + + # Center and scale for viz + this_p["z"] = this_p["z"] - z_center + this_p["y"] = this_p["y"] - y_center + this_p["x"] = this_p["x"] - x_center + + this_p["z"] = 0.1 * this_p["z"] + this_p["x"] = 0.1 * this_p["x"] + this_p["y"] = 0.1 * this_p["y"] + Path(save_path).mkdir(parents=True, exist_ok=True) + colors = cmap(this_p["inorm"].values)[:, :3] + np_arr = this_p[["x", "y", "z"]].values + np_arr2 = colors + np_arr = np.concatenate([np_arr, np_arr2], axis=1) + np.save(Path(save_path) / Path(f"{name}.npy"), np_arr) diff --git a/src/br/analysis/visualize_pointclouds.py b/src/br/analysis/visualize_pointclouds.py new file mode 100644 index 0000000..4401dad --- /dev/null +++ b/src/br/analysis/visualize_pointclouds.py @@ -0,0 +1,225 @@ +import argparse +import sys +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from br.analysis.analysis_utils import plot_image, plot_pointcloud +from br.data.preprocessing.pc_preprocessing.pcna import ( + compute_labels as compute_labels_pcna, +) +from br.data.preprocessing.pc_preprocessing.punctate_cyto import ( + compute_labels as compute_labels_var_cyto, +) +from br.data.preprocessing.pc_preprocessing.punctate_nuc import ( + compute_labels as compute_labels_var_nuc, +) +from br.features.utils import normalize_intensities_and_get_colormap_apply + +dataset_dict = { + "pcna": {"raw_ind": 2, "nuc_ind": 6, "mem_ind": None}, + "other_punctate": {"raw_ind": 2, "nuc_ind": 3, "mem_ind": 4}, +} + +viz_norms = { + "CETN2": [440, 800], + "NUP153": [420, 600], + "HIST1H2BJ": [450, 2885], + "SON": [420, 1500], + "SLC25A17": [400, 515], + "RAB5A": [420, 600], + "SMC1A": [450, 630], +} + +cell_ids_ = { + "pcna": [ + "7624cd5b-715a-478e-9648-3bac4a73abe8", + "80d40c5e-65bf-43b0-8dea-b697c421ea78", + "6a3ab51f-fa68-4fe1-a13b-2b2461ed71b4", + "aabbbca4-6c35-4f3d-9467-7d573482f236", + "d23de56e-bacf-4ec8-8e18-39822fea777b", + "c382794f-5baf-4b17-8574-62dccbbbaefc", + "50b52c3e-4756-4684-a281-0141525ded9f", + "8713eea5-da72-4644-96fe-ba8340edb67d", + ], + "other_punctate": [721646, 873680, 994027, 490385, 451974, 811336, 835431], +} + + +def main(args): + + # make save path directory + Path(args.save_path).mkdir(parents=True, exist_ok=True) + + orig_image_df = pd.read_parquet(args.preprocessed_manifest) + + if args.global_path: + orig_image_df["registered_path"] = orig_image_df["registered_path"].apply( + lambda x: args.global_path + x + ) + + assert args.dataset_name in ["pcna", "other_punctate"] + dataset_ = dataset_dict[args.dataset_name] + raw_ind = dataset_["raw_ind"] + nuc_ind = dataset_["nuc_ind"] + mem_ind = dataset_["mem_ind"] + + strat = args.class_column + strat_val = args.class_label + + if not strat: + cell_ids = cell_ids_[args.dataset_name] + orig_image_df = orig_image_df.loc[orig_image_df["CellId"].isin(cell_ids)].reset_index( + drop=True + ) + else: + orig_image_df = orig_image_df.loc[orig_image_df[strat] == strat_val].sample(n=1) + + for _, this_image in orig_image_df.iterrows(): + cell_id = this_image["CellId"] + if not strat: + strat_val = this_image["structure_name"] + + if args.dataset_name == "pcna": + points_all, _, img, center = compute_labels_pcna(this_image, False) + vmin, vmax = None, None + num_slices = 15 + center_slice = True + elif args.dataset_name == "other_punctate": + assert strat == "structure_name" + if strat_val in ["CETN2", "RAB5A", "SLC25A17"]: + points_all, _, img, center = compute_labels_var_cyto(this_image, False) + center_slice = False + num_slices = None + else: + center_slice = True + points_all, _, img, center = compute_labels_var_nuc(this_image, False) + num_slices = 1 + this_viz_norm = viz_norms[strat_val] + vmin = this_viz_norm[0] + vmax = this_viz_norm[1] + + img_raw = img[raw_ind] + img_nuc = img[nuc_ind] + img_raw = np.where(img_raw < 60000, img_raw, img_raw.min()) + img_mem = img_nuc + if mem_ind is not None: + img_mem = img[mem_ind] + + if (args.dataset_name == "other_punctate") and ( + strat_val in ["CETN2", "RAB5A", "SLC25A17"] + ): + img_raw = np.where(img_mem, img_raw, 0) # mask by mem/nuc seg + else: + img_raw = np.where(img_nuc, img_raw, 0) # mask by mem/nuc seg + + # Sample sparse point cloud and get images + probs2 = points_all["s"].values + probs2 = np.where(probs2 < 0, 0, probs2) + probs2 = probs2 / probs2.sum() + idxs2 = np.random.choice(np.arange(len(probs2)), size=2048, replace=True, p=probs2) + points = points_all.iloc[idxs2].reset_index(drop=True) + + if not vmin: + vmin = points["s"].min() + vmax = points["s"].max() + points = normalize_intensities_and_get_colormap_apply(points, vmin, vmax) + points_all = normalize_intensities_and_get_colormap_apply(points_all, vmin, vmax) + + save = True + fig, ax_array = plt.subplots(1, 3, figsize=(10, 5)) + + ax_array, z_interp = plot_image( + ax_array, + img_raw, + img_nuc, + img_mem, + vmin, + vmax, + num_slices=num_slices, + show_nuc_countour=True, + ) + ax_array[0].set_title("Raw image") + + name = strat_val + "_" + str(cell_id) + + plot_pointcloud( + ax_array[1], + points_all, + z_interp, + plt.get_cmap("YlGnBu"), + save_path=args.save_path, + name=name, + center=center, + save=False, + center_slice=center_slice, + ) + ax_array[1].set_title("Sampling dense PC") + + plot_pointcloud( + ax_array[2], + points, + z_interp, + plt.get_cmap("YlGnBu"), + save_path=args.save_path, + name=name, + center=center, + save=save, + center_slice=center_slice, + ) + ax_array[2].set_title("Sampling sparse PC") + print(f"Saving {name}.png") + fig.savefig(Path(args.save_path) / Path(f"{name}.png"), bbox_inches="tight", dpi=300) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Script for computing features") + parser.add_argument("--save_path", type=str, required=True, help="Path to save results.") + parser.add_argument( + "--global_path", + type=str, + default=None, + required=False, + help="Path to append to relative paths in preprocessed manifest", + ) + parser.add_argument( + "--preprocessed_manifest", + type=str, + required=True, + help="Path to processed single cell image manifest.", + ) + parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") + parser.add_argument( + "--class_column", + type=str, + default=None, + required=False, + help="Column name of class to use for sampling, e.g. cell_stage_fine", + ) + parser.add_argument( + "--class_label", + type=str, + default=None, + required=False, + help="Specific class label to sample, e.g. lateS", + ) + args = parser.parse_args() + + # Validate that required paths are provided + if not args.save_path or not args.dataset_name: + print("Error: Required arguments are missing.") + sys.exit(1) + + main(args) + + """ + Example run: + + PCNA dataset + python visualize_pointclouds.py --save_path "./plot_pcs_test" --preprocessed_manifest "./subpackages/image_preprocessing/tmp_output_pcna/processed/manifest.parquet" --dataset_name "pcna" --class_column "cell_stage_fine" --class_label "lateS" --global_path "./subpackages/image_preprocessing/" + + Other punctate dataset + python visualize_pointclouds.py --save_path "./plot_pcs_test" --preprocessed_manifest "./subpackages/image_preprocessing/tmp_output_variance/processed/manifest.parquet" --dataset_name "other_punctate" --class_column "structure_name" --class_label "CETN2" --global_path "./subpackages/image_preprocessing/" + """ diff --git a/src/br/data/preprocessing/pc_preprocessing/pcna.py b/src/br/data/preprocessing/pc_preprocessing/pcna.py index 15fb431..5e1f080 100644 --- a/src/br/data/preprocessing/pc_preprocessing/pcna.py +++ b/src/br/data/preprocessing/pc_preprocessing/pcna.py @@ -1,4 +1,6 @@ +import argparse from multiprocessing import Pool +from pathlib import Path import numpy as np import pandas as pd @@ -64,12 +66,12 @@ def compute_labels(row, save=True): cell_id = str(row["CellId"]) - save_path = path_prefix + cell_id + ".ply" + save_path = Path(path_prefix) / Path(cell_id + ".ply") new_cents = new_cents.astype(float) cloud = PyntCloud(new_cents) - cloud.to_file(save_path) + cloud.to_file(str(save_path)) def get_center_of_mass(img): @@ -77,13 +79,21 @@ def get_center_of_mass(img): return np.floor(center_of_mass + 0.5).astype(int) -if __name__ == "__main__": - df = pd.read_csv(PCNA_SINGLE_CELL_PATH) +def main(args): + + # make save path directory + Path(args.save_path).mkdir(parents=True, exist_ok=True) + + df = pd.read_parquet(args.preprocessed_manifest) + + if args.global_path: + df["registered_path"] = df["registered_path"].apply(lambda x: args.global_path + x) - path_prefix = SAVE_LOCATION + global path_prefix + path_prefix = args.save_path all_rows = [] - for ind, row in tqdm(df.iterrows(), total=len(df)): + for _, row in tqdm(df.iterrows(), total=len(df)): all_rows.append(row) with Pool(40) as p: @@ -97,3 +107,31 @@ def get_center_of_mass(img): desc="compute_everything", ) ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Script for computing point clouds for PCNA dataset" + ) + parser.add_argument("--save_path", type=str, required=True, help="Path to save results.") + parser.add_argument( + "--global_path", + type=str, + default=None, + required=False, + help="Path to append to relative paths in preprocessed manifest", + ) + parser.add_argument( + "--preprocessed_manifest", + type=str, + required=True, + help="Path to processed single cell image manifest.", + ) + args = parser.parse_args() + main(args) + + """ + Example run: + + python src/br/data/preprocessing/pc_preprocessing/pcna --save_path "./make_pcs_test" --preprocessed_manifest "./subpackages/image_preprocessing/tmp_output_pcna/processed/manifest.parquet" --global_path "./subpackages/image_preprocessing/" + """ diff --git a/src/br/data/preprocessing/pc_preprocessing/punctate_cyto.py b/src/br/data/preprocessing/pc_preprocessing/punctate_cyto.py index 0aef8e3..d1a97ae 100644 --- a/src/br/data/preprocessing/pc_preprocessing/punctate_cyto.py +++ b/src/br/data/preprocessing/pc_preprocessing/punctate_cyto.py @@ -1,5 +1,7 @@ +import argparse import warnings from multiprocessing import Pool +from pathlib import Path import numpy as np import pandas as pd @@ -15,18 +17,24 @@ "endosomes": 500, "peroxisomes": 500, "centrioles": 500, + "RAB5A": 500, + "SLC25A17": 500, + "CETN2": 500, } REP_DICT = { "endosomes": True, "peroxisomes": True, "centrioles": True, + "RAB5A": True, + "SLC25A17": True, + "CETN2": True, } def compute_labels(row, save=True): num_points = 20480 path = row["crop_raw"] - structure_name = row["Structure"] + structure_name = row["structure_name"] img_full = AICSImage(path).data[0] raw = img_full[2] # raw struct @@ -90,11 +98,11 @@ def compute_labels(row, save=True): cell_id = str(row["CellId"]) - save_path = path_prefix + cell_id + ".ply" + save_path = Path(path_prefix) / Path(cell_id + ".ply") new_cents = new_cents.astype(float) cloud = PyntCloud(new_cents) - cloud.to_file(save_path) + cloud.to_file(str(save_path)) def get_center_of_mass(img): @@ -102,13 +110,22 @@ def get_center_of_mass(img): return np.floor(center_of_mass + 0.5).astype(int) -if __name__ == "__main__": - df = pd.read_parquet(SINGLE_CELL_IMAGE_PATH) +def main(args): + + # make save path directory + Path(args.save_path).mkdir(parents=True, exist_ok=True) + + df = pd.read_parquet(args.preprocessed_manifest) + df = df.loc[df["structure_name"].isin(SKEW_EXP_DICT.keys())] - path_prefix = SAVE_LOCATION + if args.global_path: + df["registered_path"] = df["registered_path"].apply(lambda x: args.global_path + x) + + global path_prefix + path_prefix = args.save_path all_rows = [] - for ind, row in tqdm(df.iterrows(), total=len(df)): + for _, row in tqdm(df.iterrows(), total=len(df)): all_rows.append(row) with Pool(40) as p: @@ -122,3 +139,31 @@ def get_center_of_mass(img): desc="compute_everything", ) ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Script for computing point clouds for cytoplasmic structures from WTC-11 hIPS single cell image dataset" + ) + parser.add_argument("--save_path", type=str, required=True, help="Path to save results.") + parser.add_argument( + "--global_path", + type=str, + default=None, + required=False, + help="Path to append to relative paths in preprocessed manifest", + ) + parser.add_argument( + "--preprocessed_manifest", + type=str, + required=True, + help="Path to processed single cell image manifest.", + ) + args = parser.parse_args() + main(args) + + """ + Example run: + + python src/br/data/preprocessing/pc_preprocessing/punctate_cyto.py --save_path "./make_pcs_test" --preprocessed_manifest "./subpackages/image_preprocessing/tmp_output_variance/processed/manifest.parquet" --global_path "./subpackages/image_preprocessing/" + """ diff --git a/src/br/data/preprocessing/pc_preprocessing/punctate_nuc.py b/src/br/data/preprocessing/pc_preprocessing/punctate_nuc.py index 495b08c..3c551c0 100644 --- a/src/br/data/preprocessing/pc_preprocessing/punctate_nuc.py +++ b/src/br/data/preprocessing/pc_preprocessing/punctate_nuc.py @@ -1,3 +1,7 @@ +import argparse +from multiprocessing import Pool +from pathlib import Path + import numpy as np import pandas as pd from pyntcloud import PyntCloud @@ -5,6 +9,8 @@ from skimage.io import imread from tqdm import tqdm +STRUCTS = ["HIST1H2BJ", "NUP153", "SMC1A", "SON"] + def compute_labels(row, save=True): path = row["registered_path"] @@ -64,11 +70,11 @@ def compute_labels(row, save=True): cell_id = str(row["CellId"]) - save_path = path_prefix + cell_id + ".ply" + save_path = Path(path_prefix) / Path(cell_id + ".ply") new_cents = new_cents.astype(float) cloud = PyntCloud(new_cents) - cloud.to_file(save_path) + cloud.to_file(str(save_path)) def get_center_of_mass(img): @@ -76,19 +82,23 @@ def get_center_of_mass(img): return np.floor(center_of_mass + 0.5).astype(int) -if __name__ == "__main__": - df = pd.read_parquet(SINGLE_CELL_IMAGE_PATH) +def main(args): + + # make save path directory + Path(args.save_path).mkdir(parents=True, exist_ok=True) + + df = pd.read_parquet(args.preprocessed_manifest) + df = df.loc[df["structure_name"].isin(STRUCTS)] - path_prefix = SAVE_LOCATION + if args.global_path: + df["registered_path"] = df["registered_path"].apply(lambda x: args.global_path + x) + + global path_prefix + path_prefix = args.save_path all_rows = [] - for ind, row in tqdm(df.iterrows(), total=len(df)): + for _, row in tqdm(df.iterrows(), total=len(df)): all_rows.append(row) - # if str(row['CellId']) == '660844': - # print('yes') - # compute_labels(row) - - from multiprocessing import Pool with Pool(40) as p: _ = tuple( @@ -101,3 +111,31 @@ def get_center_of_mass(img): desc="compute_everything", ) ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Script for computing point clouds for nuclear structures from WTC-11 hIPS single cell image dataset" + ) + parser.add_argument("--save_path", type=str, required=True, help="Path to save results.") + parser.add_argument( + "--global_path", + type=str, + default=None, + required=False, + help="Path to append to relative paths in preprocessed manifest", + ) + parser.add_argument( + "--preprocessed_manifest", + type=str, + required=True, + help="Path to processed single cell image manifest.", + ) + args = parser.parse_args() + main(args) + + """ + Example run: + + python src/br/data/preprocessing/pc_preprocessing/punctate_nuc.py --save_path "./make_pcs_test" --preprocessed_manifest "./subpackages/image_preprocessing/tmp_output_variance/processed/manifest.parquet" --global_path "./subpackages/image_preprocessing/" + """ diff --git a/src/br/notebooks/visualize_pointclouds.ipynb b/src/br/notebooks/visualize_pointclouds.ipynb deleted file mode 100644 index 555899f..0000000 --- a/src/br/notebooks/visualize_pointclouds.ipynb +++ /dev/null @@ -1,421 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "e4c82855-27a4-4942-a437-f09a5b666db4", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "import os\n", - "from pathlib import Path\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "from pyntcloud import PyntCloud\n", - "from skimage import measure\n", - "from skimage.io import imread\n", - "\n", - "from br.data.preprocessing.pc_preprocessing.pcna import (\n", - " compute_labels as compute_labels_pcna,\n", - ")\n", - "from br.data.preprocessing.pc_preprocessing.punctate_cyto import (\n", - " compute_labels as compute_labels_var_cyto,\n", - ")\n", - "from br.data.preprocessing.pc_preprocessing.punctate_nuc import (\n", - " compute_labels as compute_labels_var_nuc,\n", - ")\n", - "from br.features.utils import normalize_intensities_and_get_colormap_apply" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9cf4ecc9-18c1-4315-bb50-90a338c4b2b7", - "metadata": {}, - "outputs": [], - "source": [ - "# utility plot functions\n", - "def plot_image(ax_array, struct, nuc, mem, vmin, vmax, num_slices=None, show_nuc_countour=True):\n", - " mid_z = int(struct.shape[0] / 2)\n", - "\n", - " if num_slices is None:\n", - " num_slices = mid_z * 2\n", - " z_interp = np.linspace(mid_z - num_slices / 2, mid_z + num_slices / 2, num_slices + 1).astype(\n", - " int\n", - " )\n", - " if z_interp.max() == struct.shape[0]:\n", - " z_interp = z_interp[:-1]\n", - "\n", - " struct = np.where(mem, struct, 0)\n", - " mem = mem[z_interp].max(0)\n", - " nuc = nuc[z_interp].max(0)\n", - " mem_contours = measure.find_contours(mem, 0.5)\n", - " nuc_contours = measure.find_contours(nuc, 0.5)\n", - "\n", - " for ind, _ in enumerate(ax_array):\n", - " this_struct = struct\n", - " if ind > 0:\n", - " this_struct = np.zeros(struct.shape)\n", - " ax_array[ind].imshow(this_struct[z_interp].max(0), cmap=\"gray_r\", vmin=vmin, vmax=vmax)\n", - " if ind == 0:\n", - " if show_nuc_countour:\n", - " for contour in nuc_contours:\n", - " ax_array[ind].plot(contour[:, 1], contour[:, 0], linewidth=1, c=\"cyan\")\n", - " for contour in mem_contours:\n", - " ax_array[ind].plot(contour[:, 1], contour[:, 0], linewidth=1, c=\"magenta\")\n", - " ax_array[ind].axis(\"off\")\n", - " return ax_array, z_interp\n", - "\n", - "\n", - "def plot_pointcloud(\n", - " this_ax_array,\n", - " points_all,\n", - " z_interp,\n", - " cmap,\n", - " save_path=None,\n", - " name=None,\n", - " center=None,\n", - " save=False,\n", - " center_slice=False,\n", - "):\n", - " this_p = points_all.loc[points_all[\"z\"] < max(z_interp)]\n", - " if center_slice:\n", - " this_p = this_p.loc[this_p[\"z\"] > min(z_interp)]\n", - " print(this_p.shape)\n", - " intensity = this_p.inorm.values\n", - " this_ax_array.scatter(\n", - " this_p[\"x\"].values, this_p[\"y\"].values, c=cmap(intensity), s=0.3, alpha=0.5\n", - " )\n", - " this_ax_array.axis(\"off\")\n", - " if save:\n", - " z_center, y_center, x_center = center[0], center[1], center[2]\n", - "\n", - " # Center and scale for viz\n", - " this_p[\"z\"] = this_p[\"z\"] - z_center\n", - " this_p[\"y\"] = this_p[\"y\"] - y_center\n", - " this_p[\"x\"] = this_p[\"x\"] - x_center\n", - "\n", - " this_p[\"z\"] = 0.1 * this_p[\"z\"]\n", - " this_p[\"x\"] = 0.1 * this_p[\"x\"]\n", - " this_p[\"y\"] = 0.1 * this_p[\"y\"]\n", - " Path(save_path).mkdir(parents=True, exist_ok=True)\n", - " colors = cmap(this_p[\"inorm\"].values)[:, :3]\n", - " np_arr = this_p[[\"x\", \"y\", \"z\"]].values\n", - " np_arr2 = colors\n", - " np_arr = np.concatenate([np_arr, np_arr2], axis=1)\n", - " np.save(Path(save_path) / Path(f\"{name}.npy\"), np_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fb693743-c6f0-43f0-b18e-fda4e0892a4b", - "metadata": {}, - "outputs": [], - "source": [ - "# Set paths\n", - "os.chdir(\"../../\")\n", - "save_path = \"./viz_variance_pointclouds2/\"\n", - "# save_path = './viz_pcna_pointclouds/'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f25f30ad-6f79-4ad0-8970-c9c83cff339b", - "metadata": {}, - "outputs": [], - "source": [ - "# PCNA\n", - "# PCNA_SINGLE_CELL_PROCESSED_PATH = \"\"\n", - "# ORIG_SINGLE_CELL_MANIFEST = \"\"\n", - "# orig_image_df = pd.read_parquet(PCNA_SINGLE_CELL_PROCESSED_PATH)\n", - "# df_all = pd.read_csv(ORIG_SINGLE_CELL_MANIFEST)\n", - "# orig_image_df = orig_image_df.merge(df_all[['CellId', 'crop_raw', 'crop_seg']], on='CellId')\n", - "# raw_ind = 2\n", - "# nuc_ind = 6\n", - "# mem_ind = 7\n", - "\n", - "# Other punctate\n", - "# PUNCTATE_SINGLE_CELL_PROCESSED_PATH = \"\"\n", - "# ORIG_SINGLE_CELL_MANIFEST = \"\"\n", - "orig_image_df = pd.read_parquet(PUNCTATE_SINGLE_CELL_PROCESSED_PATH)\n", - "df_full = pd.read_csv(ORIG_SINGLE_CELL_MANIFEST)\n", - "orig_image_df = orig_image_df.merge(df_full[[\"CellId\", \"crop_seg\"]], on=\"CellId\")\n", - "raw_ind = 2\n", - "nuc_ind = 3\n", - "mem_ind = 4\n", - "\n", - "# for nuc structures\n", - "df = pd.read_parquet(PUNCTATE_SINGLE_CELL_PROCESSED_PATH)\n", - "orig_image_df = orig_image_df.merge(df[[\"registered_path\", \"CellId\"]], on=\"CellId\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7a64912b-8d65-4e6d-9037-13bd46652a3a", - "metadata": {}, - "outputs": [], - "source": [ - "# Sample CellId\n", - "# strat = 'cell_stage_fine'\n", - "# strat_val = 'lateS-G2'\n", - "\n", - "strat = \"Structure\"\n", - "strat_val = \"SON\"\n", - "this_image = orig_image_df.loc[orig_image_df[strat] == strat_val].sample(n=1)\n", - "# this_image = orig_image_df.loc[orig_image_df['CellId'] == 'c6b66235-554c-4fd3-b0a2-a1e5468afb64']\n", - "cell_id = this_image[\"CellId\"].iloc[0]\n", - "strat_val = this_image[strat].iloc[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a0e26050-d34c-4c5d-872e-e47d3656fdd7", - "metadata": {}, - "outputs": [], - "source": [ - "strat_val" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b147ef5f-a6d8-47c3-9f59-d2c8e96fb695", - "metadata": {}, - "outputs": [], - "source": [ - "# Sample dense point cloud and get images\n", - "\n", - "# points_all, struct, img, center = compute_labels_pcna(this_image.iloc[0], False)\n", - "points_all, struct, img, center = compute_labels_var_nuc(this_image.iloc[0], False)\n", - "# points_all, struct, img, center = compute_labels_var_cyto(this_image.iloc[0], False)\n", - "img_raw = img[raw_ind]\n", - "img_nuc = img[nuc_ind]\n", - "img_mem = img[mem_ind]\n", - "img_raw = np.where(img_raw < 60000, img_raw, img_raw.min())\n", - "\n", - "# from saved PC\n", - "# points_all = PyntCloud.from_file(this_image['pcloud_path_updated_morepoints'].iloc[0]).points\n", - "# z_center, y_center, x_center = center[0], center[1], center[2]\n", - "# # Center and scale for viz\n", - "# points_all[\"z\"] = points_all[\"z\"] + z_center\n", - "# points_all[\"y\"] = points_all[\"y\"] + y_center\n", - "# points_all[\"x\"] = points_all[\"x\"] + x_center\n", - "\n", - "# points_saved = PyntCloud.from_file(this_image['pcloud_path_structure_norm'].iloc[0]).points\n", - "# z_center, y_center, x_center = center[0], center[1], center[2]\n", - "# # Center and scale for viz\n", - "# points_saved[\"z\"] = points_saved[\"z\"] + z_center\n", - "# points_saved[\"y\"] = points_saved[\"y\"] + y_center\n", - "# points_saved[\"x\"] = points_saved[\"x\"] + x_center" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2cbe7541-b31c-4353-a784-8f3e7619cb78", - "metadata": {}, - "outputs": [], - "source": [ - "# Sample sparse point cloud and get images\n", - "\n", - "probs2 = points_all[\"s\"].values\n", - "probs2 = np.where(probs2 < 0, 0, probs2)\n", - "probs2 = probs2 / probs2.sum()\n", - "idxs2 = np.random.choice(np.arange(len(probs2)), size=2048, replace=True, p=probs2)\n", - "points = points_all.iloc[idxs2].reset_index(drop=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ea1d39cb-0d9e-4352-8d1f-0578a28a069e", - "metadata": {}, - "outputs": [], - "source": [ - "# Apply contrast to point clouds\n", - "\n", - "# for perox 415, 515\n", - "# for endos 440, 600\n", - "vmin = 420\n", - "vmax = 600\n", - "points = normalize_intensities_and_get_colormap_apply(points, vmin, vmax)\n", - "points_all = normalize_intensities_and_get_colormap_apply(points_all, vmin, vmax)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f5e44410-2f1f-4250-98cb-689f29c0c45d", - "metadata": {}, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "\n", - "save = True\n", - "center_slice = False\n", - "\n", - "fig, ax_array = plt.subplots(1, 3, figsize=(10, 5))\n", - "\n", - "ax_array, z_interp = plot_image(\n", - " ax_array, img_raw, img_nuc, img_mem, vmin, vmax, num_slices=15, show_nuc_countour=True\n", - ")\n", - "ax_array[0].set_title(\"Raw image\")\n", - "\n", - "name = strat_val + \"_\" + str(cell_id)\n", - "\n", - "plot_pointcloud(\n", - " ax_array[1],\n", - " points_all,\n", - " z_interp,\n", - " plt.get_cmap(\"YlGnBu\"),\n", - " save_path=save_path,\n", - " name=name,\n", - " center=center,\n", - " save=False,\n", - " center_slice=center_slice,\n", - ")\n", - "ax_array[1].set_title(\"Sampling dense PC\")\n", - "\n", - "plot_pointcloud(\n", - " ax_array[2],\n", - " points,\n", - " z_interp,\n", - " plt.get_cmap(\"YlGnBu\"),\n", - " save_path=save_path,\n", - " name=name,\n", - " center=center,\n", - " save=save,\n", - " center_slice=center_slice,\n", - ")\n", - "ax_array[2].set_title(\"Sampling sparse PC\")\n", - "\n", - "# plt.show()\n", - "fig.savefig(save_path + f\"/{name}.png\", bbox_inches=\"tight\", dpi=300)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9c44e606-1326-4e9d-aba2-bfb8a570a1c5", - "metadata": {}, - "outputs": [], - "source": [ - "# add scalebar\n", - "# df = orig_image_df\n", - "# from aicsimageio import AICSImage\n", - "# from matplotlib_scalebar.scalebar import ScaleBar\n", - "# selected_cellids=['c6b66235-554c-4fd3-b0a2-a1e5468afb64']\n", - "# sub_slice_list = ['PCNA']\n", - "# padding=10\n", - "# imgs = []\n", - "# for i, cell_id in enumerate(selected_cellids):\n", - "# structure=sub_slice_list[i]\n", - "# img_path = df.loc[df['CellId'] == cell_id, 'crop_raw'].values[0]\n", - "# seg_path = df.loc[df['CellId'] == cell_id, 'crop_seg'].values[0]\n", - "# img = AICSImage(img_path).data.squeeze()[-1]\n", - "# mem = AICSImage(seg_path).data.squeeze()[1]\n", - "\n", - "# mem = mem[z_interp].max(0)\n", - "# mem_contours = measure.find_contours(mem, 0.5)\n", - "\n", - "# if structure in ['HIST1H2BJ','NUP153','SMC1A']:\n", - "# seg_idx = 0\n", - "# else:\n", - "# seg_idx = 1\n", - "# seg = AICSImage(seg_path).data.squeeze()[seg_idx]\n", - "# binary_mask = seg.astype(bool)\n", - "# background_value = np.median(img)\n", - "# masked_img = np.full_like(img, fill_value=background_value)\n", - "# masked_img[binary_mask] = img[binary_mask]\n", - "# if structure == 'NUP153':\n", - "# slice = masked_img.shape[0] // 2\n", - "# displ_img = masked_img[slice,:,:]\n", - "# else:\n", - "# displ_img = masked_img.max(0)\n", - "# rows = np.any(displ_img, axis=1)\n", - "# cols = np.any(displ_img, axis=0)\n", - "# rmin, rmax = np.where(rows)[0][[0, -1]]\n", - "# cmin, cmax = np.where(cols)[0][[0, -1]]\n", - "# rmin = max(rmin - padding, 0)\n", - "# rmax = min(rmax + padding, displ_img.shape[0])\n", - "# cmin = max(cmin - padding, 0)\n", - "# cmax = min(cmax + padding, displ_img.shape[1])\n", - "# res_img = displ_img[rmin:rmax, cmin:cmax]\n", - "# imgs.append(res_img)\n", - "\n", - "# # scale_formatter = lambda value, unit: f\"\"\n", - "\n", - "# max_ht, max_wid = imgs[0].shape\n", - "# for i, cell_id in enumerate(selected_cellids, start=1):\n", - "# structure=sub_slice_list[i-1]\n", - "# img = imgs[i-1]\n", - "# background_value = np.median(img)\n", - "\n", - "# pad_height = max(max_ht - img.shape[0], 0)\n", - "# pad_width = max(max_wid - img.shape[1], 0)\n", - "\n", - "# pad_img = np.pad(img,\n", - "# pad_width=((pad_height//2, pad_height - pad_height//2),\n", - "# (pad_width//2, pad_width - pad_width//2)),\n", - "# mode='constant',\n", - "# constant_values=background_value)\n", - "\n", - "# fig, ax = plt.subplots(1,1, figsize=(8, 5))\n", - "# ax.imshow(pad_img, cmap='gray_r')\n", - "# for contour in mem_contours:\n", - "# ax.plot(contour[:, 1], contour[:, 0], linewidth=1, c='magenta')\n", - "# ax.set_title(f'CellId: {cell_id}; {structure}')\n", - "# ax.axis('off')\n", - "# scalebar = ScaleBar(0.108333, 'um', length_fraction=0.25,\n", - "# location='upper right',\n", - "# frameon=True,\n", - "# color='black',\n", - "# scale_loc='bottom',\n", - "# box_color='white',\n", - "# box_alpha=1)\n", - "# #scale_formatter=scale_formatter)\n", - "# ax.add_artist(scalebar)\n", - "\n", - "# fig.savefig(save_path + 'mids-lates_scalaebar.png', bbox_inches='tight', dpi=300)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "41e7a65b-3afc-4bd6-8ff2-0d67f0a256f8", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/src/br/notebooks/visualize_reconstructions.ipynb b/src/br/notebooks/visualize_reconstructions.ipynb deleted file mode 100644 index 43f985f..0000000 --- a/src/br/notebooks/visualize_reconstructions.ipynb +++ /dev/null @@ -1,695 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "1f0404d5-49ff-4700-993f-c692d6248a58", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "import os\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" # see issue #152\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-e23853d6-1ca4-59e9-ac9a-1887267908f3\"\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import torch\n", - "import yaml\n", - "from hydra.utils import instantiate\n", - "from PIL import Image\n", - "from sklearn.manifold import TSNE\n", - "from torch.utils.data import DataLoader, Dataset\n", - "\n", - "from br.features.rotation_invariance import rotation_image_batch_z, rotation_pc_batch_z\n", - "from br.models.load_models import load_model_from_path\n", - "\n", - "device = \"cuda:0\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5a6562cb-d3b6-46ef-98c1-e50a96c2c370", - "metadata": {}, - "outputs": [], - "source": [ - "# Set paths\n", - "os.chdir(\"/allen/aics/modeling/ritvik/projects/benchmarking_representations/\")\n", - "save_path = \"./test_cellpack_recons/\"\n", - "results_path = \"./configs/results/\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eb9a887b-f408-4fc9-b01a-2a556a80571c", - "metadata": {}, - "outputs": [], - "source": [ - "# Load data yaml and test batch\n", - "cellpack_data = \"./configs/data/cellpack/pc.yaml\"\n", - "with open(cellpack_data) as stream:\n", - " cellpack_data = yaml.safe_load(stream)\n", - "data = instantiate(cellpack_data)\n", - "batch = next(iter(data.test_dataloader()))" - ] - }, - { - "cell_type": "markdown", - "id": "5a94fcd7-44c7-4b17-a7e1-fba0f498bc95", - "metadata": {}, - "source": [ - "# Save examples of raw data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "91e97c87-cd4e-4bf4-807b-d2b040783f1e", - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "this_save_path = Path(save_path) / Path(\"panel_a\")\n", - "this_save_path.mkdir(parents=True, exist_ok=True)\n", - "\n", - "all_arr = []\n", - "for i in range(6):\n", - " np_arr = batch[\"pcloud\"][i].numpy()\n", - " new_array = np.zeros(np_arr.shape)\n", - " z = np_arr[:, 0]\n", - " # inds = np.where(z > 0.1)[0]\n", - " new_array[:, 0] = np_arr[:, 2]\n", - " new_array[:, 1] = z\n", - " new_array[:, 2] = np_arr[:, 1]\n", - " new_array = new_array[inds]\n", - " all_arr.append(new_array)\n", - " np.save(this_save_path / Path(f\"{i}.npy\"), new_array)" - ] - }, - { - "cell_type": "markdown", - "id": "ed9cbe31-6e56-42fc-8bfa-277a8a3908b1", - "metadata": {}, - "source": [ - "# Visualize reconstructions and rotation invariance " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0b0f770c-0c2d-491e-aa2c-2b1d20652763", - "metadata": {}, - "outputs": [], - "source": [ - "# utility function for plotting\n", - "def plot_pc(this_p, axes, max_size, color=\"gray\", x_ind=2, y_ind=1):\n", - " axes.scatter(this_p[:, x_ind], this_p[:, y_ind], c=color, s=1)\n", - " axes.spines[\"top\"].set_visible(False)\n", - " axes.spines[\"right\"].set_visible(False)\n", - " axes.spines[\"bottom\"].set_visible(False)\n", - " axes.spines[\"left\"].set_visible(False)\n", - " axes.set_aspect(\"equal\", adjustable=\"box\")\n", - " axes.set_ylim([-max_size, max_size])\n", - " axes.set_xlim([-max_size, max_size])\n", - " axes.set_yticks([])\n", - " axes.set_xticks([])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b2c9ebef-92e7-408a-be52-0e4e2ccbfa63", - "metadata": {}, - "outputs": [], - "source": [ - "models, names, sizes = load_model_from_path(\"cellpack\", results_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1a636dfc-3e3f-47d1-9692-dba55342e73d", - "metadata": {}, - "outputs": [], - "source": [ - "names" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e6aac001-643a-4d2e-8835-bcae40204783", - "metadata": {}, - "outputs": [], - "source": [ - "model = models[-3]\n", - "this_name = names[-3]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ab4359fb-8bd2-415f-9093-9a9196358efc", - "metadata": {}, - "outputs": [], - "source": [ - "for key in batch.keys():\n", - " if key not in [\n", - " \"split\",\n", - " \"bf_meta_dict\",\n", - " \"egfp_meta_dict\",\n", - " \"filenames\",\n", - " \"image_meta_dict\",\n", - " \"cell_id\",\n", - " ]:\n", - " if not isinstance(batch[key], list):\n", - " batch[key] = batch[key].to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "181fec61-fea8-47b7-b8e6-26fa22b399eb", - "metadata": {}, - "outputs": [], - "source": [ - "this_save_path = Path(save_path) / Path(f\"Recons_{this_name}\")\n", - "this_save_path.mkdir(parents=True, exist_ok=True)\n", - "\n", - "this_key = \"pcloud\"\n", - "\n", - "max_z = {0: 20, 1: 20, 2: 20, 3: 1, 4: 20, 5: 20}\n", - "max_size = 10\n", - "\n", - "all_thetas = [\n", - " 0,\n", - " 1 * 90,\n", - " 2 * 90,\n", - " 3 * 90,\n", - "]\n", - "\n", - "\n", - "all_xhat = []\n", - "all_canon = []\n", - "all_input = []\n", - "with torch.no_grad():\n", - " for jl, theta in enumerate(all_thetas):\n", - " this_input_rot = rotation_pc_batch_z(\n", - " batch,\n", - " theta,\n", - " )\n", - " batch_input = {this_key: torch.tensor(this_input_rot).to(device).float()}\n", - " z, z_params = model.get_embeddings(batch_input, inference=True)\n", - " xhat = model.decode_embeddings(z_params, batch_input, decode=True, return_canonical=True)\n", - " all_input.append(this_input_rot)\n", - " if theta == 0:\n", - " for ind in range(6):\n", - " this_p = this_input_rot[ind]\n", - " this_max_z = max_z[ind]\n", - " this_p = this_p[np.where(this_p[:, 0] < this_max_z)[0]]\n", - " this_p = this_p[np.where(this_p[:, 0] > -this_max_z)[0]]\n", - " fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", - " plot_pc(this_p, ax, max_size, \"black\")\n", - " fig.savefig(this_save_path / f\"input_{ind}.png\")\n", - "\n", - " if \"canonical\" in xhat.keys():\n", - " this_canon = xhat[\"canonical\"].detach().cpu().numpy()\n", - " all_canon.append(this_canon)\n", - " if theta == 0:\n", - " for ind in range(6):\n", - " this_p = this_canon[ind]\n", - " this_max_z = max_z[ind]\n", - " this_p = this_p[np.where(this_p[:, 1] < this_max_z)[0]]\n", - " this_p = this_p[np.where(this_p[:, 1] > -this_max_z)[0]]\n", - " fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", - " plot_pc(this_p, ax, max_size, \"black\", x_ind=2, y_ind=1)\n", - " fig.savefig(this_save_path / f\"canon_{ind}.png\")\n", - " this_recon = xhat[this_key].detach().cpu().numpy()\n", - " all_xhat.append(this_recon)\n", - " if theta == 0:\n", - " for ind in range(6):\n", - " this_p = this_recon[ind]\n", - " this_max_z = max_z[ind]\n", - " this_p = this_p[np.where(this_p[:, 0] < this_max_z)[0]]\n", - " this_p = this_p[np.where(this_p[:, 0] > -this_max_z)[0]]\n", - " fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", - " plot_pc(this_p, ax, max_size, \"black\")\n", - " fig.savefig(this_save_path / f\"recon_{ind}.png\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a4ac9950-9a17-4234-a62a-0edc5fcf2034", - "metadata": {}, - "outputs": [], - "source": [ - "all_input[0][0].max(axis=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0bf8e776-bab0-4827-babe-1b5e866482f5", - "metadata": {}, - "outputs": [], - "source": [ - "i = 0 # rot ind\n", - "ind = 0 # rule\n", - "max_z = 1\n", - "max_size = 10\n", - "\n", - "# this_p = all_input[i][ind].detach().cpu().numpy()\n", - "this_p = all_xhat[i][ind]\n", - "this_p = this_p[np.where(this_p[:, 0] < max_z)[0]]\n", - "this_p = this_p[np.where(this_p[:, 0] > -max_z)[0]]\n", - "print(this_p.max(axis=0))\n", - "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", - "plot_pc(this_p, ax, max_size, \"black\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d0099d68-31aa-4fed-83ee-0328e903ec87", - "metadata": {}, - "outputs": [], - "source": [ - "fig, axes = plt.subplots(1, 4, figsize=(10, 5))\n", - "\n", - "\n", - "ind = 0\n", - "max_z = 200\n", - "max_size = 10\n", - "for i in range(4):\n", - " this_p = all_input[i][ind].detach().cpu().numpy()\n", - " this_p = this_p[np.where(this_p[:, 1] < max_z)[0]]\n", - " this_p = this_p[np.where(this_p[:, 1] > -max_z)[0]]\n", - " print(this_p.shape)\n", - " plot_pc(this_p, axes[i], max_size)\n", - "\n", - "# fig.savefig('./cellpack_rot_test/7abfecf1-44db-468a-b799-4959a23cfb0d_pc_rot_input.png', dpi=300, bbox_inches='tight')\n", - "\n", - "fig, axes = plt.subplots(1, 4, figsize=(10, 5))\n", - "for i in range(4):\n", - " this_p = all_canon[i][ind]\n", - " this_p = this_p[np.where(this_p[:, 1] < max_z)[0]]\n", - " this_p = this_p[np.where(this_p[:, 1] > -max_z)[0]]\n", - " plot_pc(this_p, axes[i], max_size)\n", - "\n", - "# fig.savefig('./cellpack_rot_test/7abfecf1-44db-468a-b799-4959a23cfb0d_pc_rot_canon.png', dpi=300, bbox_inches='tight')\n", - "fig, axes = plt.subplots(1, 4, figsize=(10, 5))\n", - "for i in range(4):\n", - " this_p = all_xhat[i][ind]\n", - " this_p = this_p[np.where(this_p[:, 1] < max_z)[0]]\n", - " this_p = this_p[np.where(this_p[:, 1] > -max_z)[0]]\n", - " plot_pc(this_p, axes[i], max_size)\n", - "\n", - "# fig.savefig('./cellpack_rot_test/7abfecf1-44db-468a-b799-4959a23cfb0d_pc_rot_recon_classical.png', dpi=300, bbox_inches='tight')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f22fa598-5f46-4654-8df0-bb4dd888d5f4", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9cf4dae1-3b26-41ae-9284-afab824bdfb9", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "98c9ceed-cdf5-47e7-8e2d-4da1ae06a617", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15c81de7-c116-4070-88c0-3bc3aaa8e383", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4cb85790-6f68-4ebf-9c66-0d95d40f182c", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a853643c-71bc-45bd-9513-e1ead83282af", - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "gg = pd.read_csv(\"/allen/aics/modeling/ritvik/forSaurabh/manifest.csv\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d494b69c-307e-49a5-b074-1df88e009413", - "metadata": {}, - "outputs": [], - "source": [ - "path = gg.loc[gg[\"CellId\"] == \"9c1ff213-4e9e-4b73-a942-3baf9d37a50f\"][\"nucobj_path\"].iloc[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5120ec07-727d-49b0-b54e-c5142a4cd8f5", - "metadata": {}, - "outputs": [], - "source": [ - "this_save_path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7f91bc8a-3983-4bbd-aadf-c7349a1bc6ec", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "mi.set_variant(\"scalar_rgb\")\n", - "import os\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import mitsuba as mi\n", - "import numpy as np\n", - "import trimesh\n", - "from mitsuba import ScalarTransform4f as T\n", - "from trimesh.transformations import rotation_matrix\n", - "\n", - "\n", - "def plot(this_mesh_path, angle, angle2=None, angle3=None, name=None):\n", - " myMesh = trimesh.load(this_mesh_path)\n", - "\n", - " # Scale the mesh to approximately one unit based on the height\n", - " sf = 1.0\n", - " myMesh.apply_scale(sf / myMesh.extents.max())\n", - "\n", - " # for 3_1\n", - " myMesh = myMesh.apply_transform(rotation_matrix(np.deg2rad(angle), [0, 0, -1]))\n", - " if angle2:\n", - " myMesh = myMesh.apply_transform(rotation_matrix(np.deg2rad(angle2), [0, 1, 0]))\n", - "\n", - " if angle3:\n", - " myMesh = myMesh.apply_transform(rotation_matrix(np.deg2rad(angle3), [1, 0, 0]))\n", - " # myMesh = myMesh.apply_transform(rotation_matrix(np.deg2rad(0), [1,0,0]))\n", - "\n", - " # Translate the mesh so that it's centroid is at the origin and rests on the ground plane\n", - " myMesh.apply_translation(\n", - " [\n", - " -myMesh.bounds[0, 0] - myMesh.extents[0] / 2.0,\n", - " -myMesh.bounds[0, 1] - myMesh.extents[1] / 2.0,\n", - " -myMesh.bounds[0, 2],\n", - " ]\n", - " )\n", - "\n", - " # Fix the mesh normals for the mesh\n", - " myMesh.fix_normals()\n", - "\n", - " # Write the mesh to an external file (Wavefront .obj)\n", - " with open(\"mesh.obj\", \"w\") as f:\n", - " f.write(trimesh.exchange.export.export_obj(myMesh, include_normals=True))\n", - "\n", - " # Create a sensor that is used for rendering the scene\n", - " def load_sensor(r, phi, theta):\n", - " # Apply two rotations to convert from spherical coordinates to world 3D coordinates.\n", - " origin = T.rotate([0, 0, 1], phi).rotate([0, 1, 0], theta) @ mi.ScalarPoint3f([0, 0, r])\n", - "\n", - " return mi.load_dict(\n", - " {\n", - " \"type\": \"perspective\",\n", - " \"fov\": 15.0,\n", - " \"to_world\": T.look_at(\n", - " origin=origin, target=[0, 0, myMesh.extents[2] / 2], up=[0, 0, 1]\n", - " ),\n", - " \"sampler\": {\"type\": \"independent\", \"sample_count\": 16},\n", - " \"film\": {\n", - " \"type\": \"hdrfilm\",\n", - " \"width\": 1024,\n", - " \"height\": 768,\n", - " \"rfilter\": {\n", - " \"type\": \"tent\",\n", - " },\n", - " \"pixel_format\": \"rgb\",\n", - " },\n", - " }\n", - " )\n", - "\n", - " # Scene parameters\n", - " relativeLightHeight = 8\n", - "\n", - " # A scene dictionary contains the description of the rendering scene.\n", - " scene2 = mi.load_dict(\n", - " {\n", - " \"type\": \"scene\",\n", - " # The keys below correspond to object IDs and can be chosen arbitrarily\n", - " \"integrator\": {\"type\": \"path\"},\n", - " \"mesh\": {\n", - " \"type\": \"obj\",\n", - " \"filename\": \"mesh.obj\",\n", - " \"face_normals\": True, # This prevents smoothing of sharp-corners by discarding surface-normals. Useful for engineering CAD.\n", - " \"bsdf\": {\n", - " # 'type': 'diffuse',\n", - " # 'reflectance': {\n", - " # 'type': 'rgb',\n", - " # 'value': [0.1, 0.27, 0.86]\n", - " # }\n", - " # 'type': 'plastic',\n", - " # 'diffuse_reflectance': {\n", - " # 'type': 'rgb',\n", - " # 'value': [0.1, 0.27, 0.36]\n", - " # },\n", - " # 'int_ior': 1.9\n", - " # 'type': 'roughplastic'\n", - " \"type\": \"pplastic\",\n", - " \"diffuse_reflectance\": {\"type\": \"rgb\", \"value\": [0.05, 0.03, 0.1]},\n", - " \"alpha\": 0.02,\n", - " },\n", - " },\n", - " # A general emitter is used for illuminating the entire scene (renders the background white)\n", - " \"light\": {\"type\": \"constant\", \"radiance\": 1.0},\n", - " \"areaLight\": {\n", - " \"type\": \"rectangle\",\n", - " # The height of the light can be adjusted below\n", - " \"to_world\": T.translate([0, 0.0, myMesh.bounds[1, 2] + relativeLightHeight])\n", - " .scale(1.0)\n", - " .rotate([1, 0, 0], 5.0),\n", - " \"flip_normals\": True,\n", - " \"emitter\": {\n", - " \"type\": \"area\",\n", - " \"radiance\": {\n", - " \"type\": \"spectrum\",\n", - " \"value\": 30.0,\n", - " },\n", - " },\n", - " },\n", - " \"floor\": {\n", - " \"type\": \"disk\",\n", - " \"to_world\": T.scale(3).translate([0.0, 0.0, 0.0]),\n", - " \"material\": {\n", - " \"type\": \"diffuse\",\n", - " \"reflectance\": {\"type\": \"rgb\", \"value\": 0.75},\n", - " },\n", - " },\n", - " }\n", - " )\n", - "\n", - " sensor_count = 1\n", - "\n", - " radius = 4\n", - " phis = [130.0]\n", - " theta = 60.0\n", - "\n", - " sensors = [load_sensor(radius, phi, theta) for phi in phis]\n", - "\n", - " \"\"\"\n", - " Render the Scene\n", - " The render samples are specified in spp\n", - " \"\"\"\n", - " image = mi.render(scene2, sensor=sensors[0], spp=256)\n", - "\n", - " # Write the output\n", - "\n", - " save_path = this_save_path\n", - " mi.util.write_bitmap(str(save_path) + f\"{name}.png\", image)\n", - " # mi.util.write_bitmap(save_path + \".exr\", image)\n", - "\n", - " # Display the output in an Image\n", - " plt.imshow(image ** (1.0 / 2.2))\n", - " plt.axis(\"off\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6ece02c4-3736-42f0-a7e2-6b22d79148a3", - "metadata": {}, - "outputs": [], - "source": [ - "path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0d723074-5583-42a4-a0f6-da8f9242db23", - "metadata": {}, - "outputs": [], - "source": [ - "plot(path, 0, 90, 0, \"nuc\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f8e4427d-b444-4ff6-9f52-769f1a67f985", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "568b4ed2-dc37-4a84-8c31-5010cdba56ce", - "metadata": {}, - "outputs": [], - "source": [ - "aa = np.load(\n", - " \"/allen/aics/modeling/ritvik/projects/benchmarking_representations/viz_pcna_pointclouds/midS-lateS_c6b66235-554c-4fd3-b0a2-a1e5468afb64.npy\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a617fa5f-c771-4647-a0fa-50f6b423bd55", - "metadata": {}, - "outputs": [], - "source": [ - "aa.max(axis=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "037c5bae-26cf-4155-874a-87f396af097d", - "metadata": {}, - "outputs": [], - "source": [ - "bb = np.load(\n", - " \"/allen/aics/modeling/ritvik/projects/benchmarking_representations/notebooks_old/variance_all_punctate/pcna/latent_walks/viz/midS-lateS_0_1.npy\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "62c5f71f-da3f-4068-9f7a-ecde178b29dc", - "metadata": {}, - "outputs": [], - "source": [ - "bb.max(axis=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "52ff4691-41b5-4d24-8206-8f523f13555d", - "metadata": {}, - "outputs": [], - "source": [ - "aa = np.load(\n", - " \"/allen/aics/modeling/ritvik/projects/benchmarking_representations/viz_variance_pointclouds2/NUP153_692417.npy\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "09eff7ff-beaa-4d88-b5ba-4082eb53d541", - "metadata": {}, - "outputs": [], - "source": [ - "aa.max(axis=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cafed763-05c6-4634-988b-aa67a47a6091", - "metadata": {}, - "outputs": [], - "source": [ - "bb = np.load(\n", - " \"/allen/aics/modeling/ritvik/projects/benchmarking_representations/test_var_punctate_embeddings/latent_walks/structure_name_NUP153_0_0.npy\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9504b891-3c49-4185-b426-b96908f61cdf", - "metadata": {}, - "outputs": [], - "source": [ - "bb.max(axis=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c46aa811-c3ab-471f-bd1a-b010604d5b5c", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}