Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visualize pointclouds #78

Merged
merged 10 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions src/br/analysis/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import trimesh
import yaml
from aicsimageio import AICSImage
from skimage import measure
from sklearn.decomposition import PCA
from tqdm import tqdm

Expand Down Expand Up @@ -1025,3 +1026,75 @@ def save_supplemental_figure_sdf_reconstructions(df, test_ids, reconstructions_p
plt.tight_layout()
plt.savefig(reconstructions_path + "sample_recons.png", dpi=300, bbox_inches="tight")
plt.savefig(reconstructions_path + "sample_recons.pdf", dpi=300, bbox_inches="tight")


# utility plot functions
def plot_image(ax_array, struct, nuc, mem, vmin, vmax, num_slices=None, show_nuc_countour=True):
    """Draw a max-intensity z-projection of ``struct`` with segmentation outlines.

    The first axis shows ``struct`` (zeroed wherever ``mem`` is falsy),
    max-projected over the selected z-slices, with contours of the projected
    ``mem`` (magenta) and optionally ``nuc`` (cyan) drawn on top. Every other
    axis receives an all-zero image of the same 2D size, drawn with the same
    colormap and limits. All axes are switched off.

    Args:
        ax_array: Indexable/iterable collection of matplotlib axes.
        struct: Image volume; axis 0 is used as z.
        nuc: Nuclear segmentation volume, indexed with the same z-slices as
            ``struct`` (assumed to share its shape -- not checked here).
        mem: Membrane segmentation volume used both as the mask for ``struct``
            and for the magenta outline (same shape assumption as ``nuc``).
        vmin: Lower display limit forwarded to ``imshow``.
        vmax: Upper display limit forwarded to ``imshow``.
        num_slices: Number of z-slices to project, centred on the middle
            slice. ``None`` selects the whole stack.
        show_nuc_countour: Whether to draw the nuclear outline. (The spelling
            is kept because callers pass this argument by keyword.)

    Returns:
        Tuple ``(ax_array, z_interp)`` where ``z_interp`` is the sorted integer
        array of z indices that were projected.
    """
    depth = struct.shape[0]
    mid_z = int(depth / 2)

    if num_slices is None:
        num_slices = mid_z * 2
    z_interp = np.linspace(
        mid_z - num_slices / 2, mid_z + num_slices / 2, num_slices + 1
    ).astype(int)
    # Keep every index inside the volume. Without this, a window larger than
    # the stack produced negative indices (which silently wrap around to the
    # far end of the stack) or indices past the end (IndexError). For windows
    # that already fit, this reproduces the previous behaviour exactly,
    # including dropping a trailing index equal to ``depth``.
    z_interp = np.unique(np.clip(z_interp, 0, depth - 1))

    struct = np.where(mem, struct, 0)
    struct_proj = struct[z_interp].max(0)
    # One blank 2D image is shared by all secondary panels instead of building
    # a full zero volume per panel and projecting it.
    blank = np.zeros(struct_proj.shape)

    mem_contours = measure.find_contours(mem[z_interp].max(0), 0.5)
    nuc_contours = []
    if show_nuc_countour:
        nuc_contours = measure.find_contours(nuc[z_interp].max(0), 0.5)

    for ind, ax in enumerate(ax_array):
        is_first = ind == 0
        ax.imshow(struct_proj if is_first else blank, cmap="gray_r", vmin=vmin, vmax=vmax)
        if is_first:
            # find_contours returns (row, col) pairs; plot expects (x, y).
            for contour in nuc_contours:
                ax.plot(contour[:, 1], contour[:, 0], linewidth=1, c="cyan")
            for contour in mem_contours:
                ax.plot(contour[:, 1], contour[:, 0], linewidth=1, c="magenta")
        ax.axis("off")
    return ax_array, z_interp


def plot_pointcloud(
    this_ax_array,
    points_all,
    z_interp,
    cmap,
    save_path=None,
    name=None,
    center=None,
    save=False,
    center_slice=False,
):
    """Scatter a z-filtered point cloud on an axis and optionally export it.

    Points with ``z`` strictly below ``max(z_interp)`` are kept; when
    ``center_slice`` is true, points must also lie strictly above
    ``min(z_interp)``. The kept points are drawn at their ``x``/``y``
    positions, coloured by ``cmap`` applied to the ``inorm`` column, and the
    axis is switched off.

    When ``save`` is true, the kept points are centred on ``center``, scaled
    by 0.1 and written to ``<save_path>/<name>.npy`` as an ``(N, 6)`` array
    with columns ``x, y, z, r, g, b``. ``points_all`` is never modified.

    Args:
        this_ax_array: A single axis object providing ``scatter`` and ``axis``.
        points_all: DataFrame with ``x``, ``y``, ``z`` and ``inorm`` columns.
        z_interp: Iterable of z indices delimiting the slab to show.
        cmap: Callable mapping an array of intensities to an ``(N, >=3)``
            colour array (e.g. a matplotlib colormap).
        save_path: Output directory; created if missing. Required for saving.
        name: Output file stem. Required for saving.
        center: Sequence whose first three entries are the ``z, y, x`` centre
            subtracted before scaling. Required for saving.
        save: Whether to write the ``.npy`` file.
        center_slice: Whether to also apply the lower z bound.

    Raises:
        ValueError: If ``save`` is true but ``save_path``, ``name`` or
            ``center`` is missing.
    """
    if save and (save_path is None or name is None or center is None):
        raise ValueError("save=True requires save_path, name and center to be provided.")

    keep = points_all["z"] < max(z_interp)
    if center_slice:
        keep = keep & (points_all["z"] > min(z_interp))
    this_p = points_all.loc[keep]

    colors = np.asarray(cmap(this_p["inorm"].values))
    this_ax_array.scatter(this_p["x"].values, this_p["y"].values, c=colors, s=0.3, alpha=0.5)
    this_ax_array.axis("off")
    if not save:
        return

    z_center, y_center, x_center = center[0], center[1], center[2]

    # Center and scale for viz. Work on a fresh array instead of assigning
    # columns on the filtered frame, which triggered SettingWithCopyWarning.
    offset = np.array([x_center, y_center, z_center], dtype=float)
    coords = 0.1 * (this_p[["x", "y", "z"]].to_numpy(dtype=float) - offset)

    out_dir = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Only the RGB channels are exported; any alpha column is dropped.
    np.save(out_dir / f"{name}.npy", np.concatenate([coords, colors[:, :3]], axis=1))
223 changes: 223 additions & 0 deletions src/br/analysis/visualize_pointclouds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import argparse
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from br.analysis.analysis_utils import plot_image, plot_pointcloud
from br.data.preprocessing.pc_preprocessing.pcna import (
compute_labels as compute_labels_pcna,
)
from br.data.preprocessing.pc_preprocessing.punctate_cyto import (
compute_labels as compute_labels_var_cyto,
)
from br.data.preprocessing.pc_preprocessing.punctate_nuc import (
compute_labels as compute_labels_var_nuc,
)
from br.features.utils import normalize_intensities_and_get_colormap_apply

# Per-dataset channel indices used by main() to index the image returned by
# the compute_labels helpers: raw structure, nuclear segmentation and membrane
# segmentation. A ``mem_ind`` of None makes main() reuse the nuclear
# segmentation in place of a membrane segmentation.
dataset_dict = dict(
    pcna=dict(raw_ind=2, nuc_ind=6, mem_ind=None),
    other_punctate=dict(raw_ind=2, nuc_ind=3, mem_ind=4),
)

# Display limits per structure name, stored as [vmin, vmax]; main() looks
# these up for the "other_punctate" dataset.
viz_norms = dict(
    CETN2=[440, 800],
    NUP153=[420, 600],
    HIST1H2BJ=[450, 2885],
    SON=[420, 1500],
    SLC25A17=[400, 515],
    RAB5A=[420, 600],
    SMC1A=[450, 630],
)

# CellIds visualized when no class column is given on the command line.
cell_ids_ = dict(
    pcna=[
        "7624cd5b-715a-478e-9648-3bac4a73abe8",
        "80d40c5e-65bf-43b0-8dea-b697c421ea78",
        "6a3ab51f-fa68-4fe1-a13b-2b2461ed71b4",
        "aabbbca4-6c35-4f3d-9467-7d573482f236",
        "d23de56e-bacf-4ec8-8e18-39822fea777b",
        "c382794f-5baf-4b17-8574-62dccbbbaefc",
        "50b52c3e-4756-4684-a281-0141525ded9f",
        "8713eea5-da72-4644-96fe-ba8340edb67d",
    ],
    other_punctate=[721646, 873680, 994027, 490385, 451974, 811336, 835431],
)


def main(args):
    """Save raw-image / dense / sparse point-cloud figures for selected cells.

    For every selected manifest row this computes the dense point cloud via
    the dataset-specific ``compute_labels`` helper, draws a three-panel figure
    (raw image, dense point cloud, 2048-point intensity-weighted resample),
    writes it to ``<save_path>/<label>_<CellId>.png`` and exports the sparse
    cloud as ``.npy`` through ``plot_pointcloud``.

    Row selection:
        * no ``class_column``: the predefined ``cell_ids_`` of the dataset;
          the per-row label is read from the ``structure_name`` column.
        * ``class_column`` given: one randomly sampled row whose
          ``class_column`` equals ``class_label``.

    Args:
        args: Namespace with ``save_path``, ``preprocessed_manifest``,
            ``global_path``, ``dataset_name``, ``class_column`` and
            ``class_label`` attributes.

    Raises:
        ValueError: For an unknown dataset name, for a class column other
            than ``structure_name`` with the ``other_punctate`` dataset, when
            no manifest row matches the requested class, or when no point has
            a positive sampling weight.
    """
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if args.dataset_name not in dataset_dict:
        raise ValueError(
            f"Unknown dataset_name {args.dataset_name!r}; expected one of {sorted(dataset_dict)}."
        )

    strat = args.class_column
    # The display limits of "other_punctate" are keyed by structure name, so a
    # user-supplied class column must be "structure_name". Leaving the column
    # unset is valid: the label then comes from each row (previously an
    # unconditional assert made that default path fail).
    if args.dataset_name == "other_punctate" and strat and strat != "structure_name":
        raise ValueError(
            "The 'other_punctate' dataset only supports class_column='structure_name', "
            f"got {strat!r}."
        )

    # make save path directory
    save_dir = Path(args.save_path)
    save_dir.mkdir(parents=True, exist_ok=True)

    orig_image_df = pd.read_parquet(args.preprocessed_manifest)

    if args.global_path:
        # Plain string prefixing (not a path join): the prefix has to carry
        # its own trailing separator if one is needed.
        orig_image_df["registered_path"] = orig_image_df["registered_path"].apply(
            lambda x: args.global_path + x
        )

    dataset_ = dataset_dict[args.dataset_name]
    raw_ind = dataset_["raw_ind"]
    nuc_ind = dataset_["nuc_ind"]
    mem_ind = dataset_["mem_ind"]

    if not strat:
        cell_ids = cell_ids_[args.dataset_name]
        orig_image_df = orig_image_df.loc[orig_image_df["CellId"].isin(cell_ids)].reset_index(
            drop=True
        )
    else:
        matches = orig_image_df.loc[orig_image_df[strat] == args.class_label]
        if matches.empty:
            raise ValueError(
                f"No manifest rows with {strat!r} equal to {args.class_label!r}."
            )
        orig_image_df = matches.sample(n=1)

    cyto_structures = ("CETN2", "RAB5A", "SLC25A17")
    cmap = plt.get_cmap("YlGnBu")

    for _, this_image in orig_image_df.iterrows():
        cell_id = this_image["CellId"]
        strat_val = args.class_label if strat else this_image["structure_name"]
        is_cyto = args.dataset_name == "other_punctate" and strat_val in cyto_structures

        if args.dataset_name == "pcna":
            points_all, _, img, center = compute_labels_pcna(this_image, False)
            # No fixed display limits; they are derived from the samples below.
            vmin, vmax = None, None
            num_slices = 15
            center_slice = True
        else:
            if is_cyto:
                points_all, _, img, center = compute_labels_var_cyto(this_image, False)
                center_slice = False
                num_slices = None
            else:
                points_all, _, img, center = compute_labels_var_nuc(this_image, False)
                center_slice = True
                num_slices = 1
            vmin, vmax = viz_norms[strat_val]

        img_raw = img[raw_ind]
        img_nuc = img[nuc_ind]
        # Values >= 60000 are replaced by the image minimum (presumably
        # saturated/outlier pixels -- TODO confirm the threshold's origin).
        img_raw = np.where(img_raw < 60000, img_raw, img_raw.min())
        img_mem = img_nuc if mem_ind is None else img[mem_ind]

        # mask by mem/nuc seg
        img_raw = np.where(img_mem if is_cyto else img_nuc, img_raw, 0)

        # Sample sparse point cloud and get images
        weights = points_all["s"].values
        weights = np.where(weights < 0, 0, weights)
        total = weights.sum()
        if not total > 0:
            raise ValueError(
                f"Cell {cell_id}: no point has a positive 's' value to sample from."
            )
        idxs = np.random.choice(len(weights), size=2048, replace=True, p=weights / total)
        points = points_all.iloc[idxs].reset_index(drop=True)

        # ``is None`` rather than truthiness, so a legitimate limit of 0 is kept.
        if vmin is None:
            vmin = points["s"].min()
            vmax = points["s"].max()
        points = normalize_intensities_and_get_colormap_apply(points, vmin, vmax)
        points_all = normalize_intensities_and_get_colormap_apply(points_all, vmin, vmax)

        fig, ax_array = plt.subplots(1, 3, figsize=(10, 5))

        ax_array, z_interp = plot_image(
            ax_array,
            img_raw,
            img_nuc,
            img_mem,
            vmin,
            vmax,
            num_slices=num_slices,
            show_nuc_countour=True,
        )
        ax_array[0].set_title("Raw image")

        name = f"{strat_val}_{cell_id}"

        # Dense cloud: plotted only.
        plot_pointcloud(
            ax_array[1],
            points_all,
            z_interp,
            cmap,
            save_path=args.save_path,
            name=name,
            center=center,
            save=False,
            center_slice=center_slice,
        )
        ax_array[1].set_title("Sampling dense PC")

        # Sparse cloud: plotted and exported as <name>.npy.
        plot_pointcloud(
            ax_array[2],
            points,
            z_interp,
            cmap,
            save_path=args.save_path,
            name=name,
            center=center,
            save=True,
            center_slice=center_slice,
        )
        ax_array[2].set_title("Sampling sparse PC")
        print(f'Saving {name}.png')
        fig.savefig(save_dir / f"{name}.png", bbox_inches="tight", dpi=300)
        # Release the figure; otherwise one stays open per processed cell.
        plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Script for visualizing raw images alongside dense and sparse point clouds"
    )
    parser.add_argument("--save_path", type=str, required=True, help="Path to save results.")
    parser.add_argument(
        "--global_path",
        type=str,
        default=None,
        required=False,
        help="Path prefix to prepend to relative paths in preprocessed manifest",
    )
    parser.add_argument(
        "--preprocessed_manifest",
        type=str,
        required=True,
        help="Path to processed single cell image manifest.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
        # Reject unknown datasets at parse time, before any file is read.
        choices=sorted(dataset_dict),
        help="Name of the dataset.",
    )
    parser.add_argument(
        "--class_column",
        type=str,
        default=None,
        required=False,
        help="Column name of class to use for sampling, e.g. cell_stage_fine",
    )
    parser.add_argument(
        "--class_label",
        type=str,
        default=None,
        required=False,
        help="Specific class label to sample, e.g. lateS",
    )
    args = parser.parse_args()

    # Validate that required paths are provided (argparse guarantees presence,
    # but not that the values are non-empty strings).
    if not args.save_path or not args.dataset_name:
        print("Error: Required arguments are missing.")
        sys.exit(1)

    main(args)

"""
Example run:

PCNA dataset
python visualize_pointclouds.py --save_path "./plot_pcs_test" --preprocessed_manifest "./subpackages/image_preprocessing/tmp_output_pcna/processed/manifest.parquet" --dataset_name "pcna" --class_column "cell_stage_fine" --class_label "lateS" --global_path "./subpackages/image_preprocessing/"

Other punctate dataset
python visualize_pointclouds.py --save_path "./plot_pcs_test" --preprocessed_manifest "./subpackages/image_preprocessing/tmp_output_variance/processed/manifest.parquet" --dataset_name "other_punctate" --class_column "structure_name" --class_label "CETN2" --global_path "./subpackages/image_preprocessing/"
"""
8 changes: 7 additions & 1 deletion src/br/data/preprocessing/pc_preprocessing/punctate_cyto.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,24 @@
"endosomes": 500,
"peroxisomes": 500,
"centrioles": 500,
"RAB5A": 500,
"SLC25A17": 500,
"CETN2": 500,
}
REP_DICT = {
"endosomes": True,
"peroxisomes": True,
"centrioles": True,
"RAB5A": True,
"SLC25A17": True,
"CETN2": True,
}


def compute_labels(row, save=True):
num_points = 20480
path = row["crop_raw"]
structure_name = row["Structure"]
structure_name = row["structure_name"]
img_full = AICSImage(path).data[0]
raw = img_full[2] # raw struct

Expand Down
Loading
Loading