Skip to content

Commit

Permalink
Merge pull request #78 from AllenCell/visualize_pointclouds
Browse files Browse the repository at this point in the history
Visualize pointclouds
  • Loading branch information
ritvikvasan authored Dec 13, 2024
2 parents 6f43518 + 877ae43 commit 8432a89
Show file tree
Hide file tree
Showing 8 changed files with 447 additions and 1,142 deletions.
6 changes: 4 additions & 2 deletions docs/PREPROCESSING.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,17 @@ Preprocessing is divided into three steps that use two different virtual environ

# Punctate structures: Generate pointclouds

Edit the data paths in the following file to match the location of the outputs of the alignment, masking, and registration step, then run it.
Use the preprocessed data manifest generated via the alignment, masking, and registration steps from image as input to the pointcloud generation step

```
src
└── br
└── data
   └── preprocessing
      └── pc_preprocessing
         └── punctate_cyto.py <- Point cloud sampling from raw images for punctate structures here
         └── pcna.py <- Point cloud sampling from raw images for DNA replication foci dataset here
         └── punctate_nuc.py <- Point cloud sampling from raw images of nuclear structures from the WTC-11 hIPS single cell image dataset here
         └── punctate_cyto.py <- Point cloud sampling from raw images of cytoplasmic structures from the WTC-11 hIPS single cell image dataset here
```

# Polymorphic structures: Generate SDFs
Expand Down
73 changes: 73 additions & 0 deletions src/br/analysis/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import trimesh
import yaml
from aicsimageio import AICSImage
from skimage import measure
from sklearn.decomposition import PCA
from tqdm import tqdm

Expand Down Expand Up @@ -1025,3 +1026,75 @@ def save_supplemental_figure_sdf_reconstructions(df, test_ids, reconstructions_p
plt.tight_layout()
plt.savefig(reconstructions_path + "sample_recons.png", dpi=300, bbox_inches="tight")
plt.savefig(reconstructions_path + "sample_recons.pdf", dpi=300, bbox_inches="tight")


# utility plot functions
def plot_image(ax_array, struct, nuc, mem, vmin, vmax, num_slices=None, show_nuc_countour=True):
mid_z = int(struct.shape[0] / 2)

if num_slices is None:
num_slices = mid_z * 2
z_interp = np.linspace(mid_z - num_slices / 2, mid_z + num_slices / 2, num_slices + 1).astype(
int
)
if z_interp.max() == struct.shape[0]:
z_interp = z_interp[:-1]

struct = np.where(mem, struct, 0)
mem = mem[z_interp].max(0)
nuc = nuc[z_interp].max(0)
mem_contours = measure.find_contours(mem, 0.5)
nuc_contours = measure.find_contours(nuc, 0.5)

for ind, _ in enumerate(ax_array):
this_struct = struct
if ind > 0:
this_struct = np.zeros(struct.shape)
ax_array[ind].imshow(this_struct[z_interp].max(0), cmap="gray_r", vmin=vmin, vmax=vmax)
if ind == 0:
if show_nuc_countour:
for contour in nuc_contours:
ax_array[ind].plot(contour[:, 1], contour[:, 0], linewidth=1, c="cyan")
for contour in mem_contours:
ax_array[ind].plot(contour[:, 1], contour[:, 0], linewidth=1, c="magenta")
ax_array[ind].axis("off")
return ax_array, z_interp


def plot_pointcloud(
this_ax_array,
points_all,
z_interp,
cmap,
save_path=None,
name=None,
center=None,
save=False,
center_slice=False,
):
this_p = points_all.loc[points_all["z"] < max(z_interp)]
if center_slice:
this_p = this_p.loc[this_p["z"] > min(z_interp)]
print(this_p.shape)
intensity = this_p.inorm.values
this_ax_array.scatter(
this_p["x"].values, this_p["y"].values, c=cmap(intensity), s=0.3, alpha=0.5
)
this_ax_array.axis("off")
if save:
z_center, y_center, x_center = center[0], center[1], center[2]

# Center and scale for viz
this_p["z"] = this_p["z"] - z_center
this_p["y"] = this_p["y"] - y_center
this_p["x"] = this_p["x"] - x_center

this_p["z"] = 0.1 * this_p["z"]
this_p["x"] = 0.1 * this_p["x"]
this_p["y"] = 0.1 * this_p["y"]
Path(save_path).mkdir(parents=True, exist_ok=True)
colors = cmap(this_p["inorm"].values)[:, :3]
np_arr = this_p[["x", "y", "z"]].values
np_arr2 = colors
np_arr = np.concatenate([np_arr, np_arr2], axis=1)
np.save(Path(save_path) / Path(f"{name}.npy"), np_arr)
225 changes: 225 additions & 0 deletions src/br/analysis/visualize_pointclouds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
import argparse
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from br.analysis.analysis_utils import plot_image, plot_pointcloud
from br.data.preprocessing.pc_preprocessing.pcna import (
compute_labels as compute_labels_pcna,
)
from br.data.preprocessing.pc_preprocessing.punctate_cyto import (
compute_labels as compute_labels_var_cyto,
)
from br.data.preprocessing.pc_preprocessing.punctate_nuc import (
compute_labels as compute_labels_var_nuc,
)
from br.features.utils import normalize_intensities_and_get_colormap_apply

dataset_dict = {
"pcna": {"raw_ind": 2, "nuc_ind": 6, "mem_ind": None},
"other_punctate": {"raw_ind": 2, "nuc_ind": 3, "mem_ind": 4},
}

viz_norms = {
"CETN2": [440, 800],
"NUP153": [420, 600],
"HIST1H2BJ": [450, 2885],
"SON": [420, 1500],
"SLC25A17": [400, 515],
"RAB5A": [420, 600],
"SMC1A": [450, 630],
}

cell_ids_ = {
"pcna": [
"7624cd5b-715a-478e-9648-3bac4a73abe8",
"80d40c5e-65bf-43b0-8dea-b697c421ea78",
"6a3ab51f-fa68-4fe1-a13b-2b2461ed71b4",
"aabbbca4-6c35-4f3d-9467-7d573482f236",
"d23de56e-bacf-4ec8-8e18-39822fea777b",
"c382794f-5baf-4b17-8574-62dccbbbaefc",
"50b52c3e-4756-4684-a281-0141525ded9f",
"8713eea5-da72-4644-96fe-ba8340edb67d",
],
"other_punctate": [721646, 873680, 994027, 490385, 451974, 811336, 835431],
}


def main(args):

# make save path directory
Path(args.save_path).mkdir(parents=True, exist_ok=True)

orig_image_df = pd.read_parquet(args.preprocessed_manifest)

if args.global_path:
orig_image_df["registered_path"] = orig_image_df["registered_path"].apply(
lambda x: args.global_path + x
)

assert args.dataset_name in ["pcna", "other_punctate"]
dataset_ = dataset_dict[args.dataset_name]
raw_ind = dataset_["raw_ind"]
nuc_ind = dataset_["nuc_ind"]
mem_ind = dataset_["mem_ind"]

strat = args.class_column
strat_val = args.class_label

if not strat:
cell_ids = cell_ids_[args.dataset_name]
orig_image_df = orig_image_df.loc[orig_image_df["CellId"].isin(cell_ids)].reset_index(
drop=True
)
else:
orig_image_df = orig_image_df.loc[orig_image_df[strat] == strat_val].sample(n=1)

for _, this_image in orig_image_df.iterrows():
cell_id = this_image["CellId"]
if not strat:
strat_val = this_image["structure_name"]

if args.dataset_name == "pcna":
points_all, _, img, center = compute_labels_pcna(this_image, False)
vmin, vmax = None, None
num_slices = 15
center_slice = True
elif args.dataset_name == "other_punctate":
assert strat == "structure_name"
if strat_val in ["CETN2", "RAB5A", "SLC25A17"]:
points_all, _, img, center = compute_labels_var_cyto(this_image, False)
center_slice = False
num_slices = None
else:
center_slice = True
points_all, _, img, center = compute_labels_var_nuc(this_image, False)
num_slices = 1
this_viz_norm = viz_norms[strat_val]
vmin = this_viz_norm[0]
vmax = this_viz_norm[1]

img_raw = img[raw_ind]
img_nuc = img[nuc_ind]
img_raw = np.where(img_raw < 60000, img_raw, img_raw.min())
img_mem = img_nuc
if mem_ind is not None:
img_mem = img[mem_ind]

if (args.dataset_name == "other_punctate") and (
strat_val in ["CETN2", "RAB5A", "SLC25A17"]
):
img_raw = np.where(img_mem, img_raw, 0) # mask by mem/nuc seg
else:
img_raw = np.where(img_nuc, img_raw, 0) # mask by mem/nuc seg

# Sample sparse point cloud and get images
probs2 = points_all["s"].values
probs2 = np.where(probs2 < 0, 0, probs2)
probs2 = probs2 / probs2.sum()
idxs2 = np.random.choice(np.arange(len(probs2)), size=2048, replace=True, p=probs2)
points = points_all.iloc[idxs2].reset_index(drop=True)

if not vmin:
vmin = points["s"].min()
vmax = points["s"].max()
points = normalize_intensities_and_get_colormap_apply(points, vmin, vmax)
points_all = normalize_intensities_and_get_colormap_apply(points_all, vmin, vmax)

save = True
fig, ax_array = plt.subplots(1, 3, figsize=(10, 5))

ax_array, z_interp = plot_image(
ax_array,
img_raw,
img_nuc,
img_mem,
vmin,
vmax,
num_slices=num_slices,
show_nuc_countour=True,
)
ax_array[0].set_title("Raw image")

name = strat_val + "_" + str(cell_id)

plot_pointcloud(
ax_array[1],
points_all,
z_interp,
plt.get_cmap("YlGnBu"),
save_path=args.save_path,
name=name,
center=center,
save=False,
center_slice=center_slice,
)
ax_array[1].set_title("Sampling dense PC")

plot_pointcloud(
ax_array[2],
points,
z_interp,
plt.get_cmap("YlGnBu"),
save_path=args.save_path,
name=name,
center=center,
save=save,
center_slice=center_slice,
)
ax_array[2].set_title("Sampling sparse PC")
print(f"Saving {name}.png")
fig.savefig(Path(args.save_path) / Path(f"{name}.png"), bbox_inches="tight", dpi=300)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Script for computing features")
parser.add_argument("--save_path", type=str, required=True, help="Path to save results.")
parser.add_argument(
"--global_path",
type=str,
default=None,
required=False,
help="Path to append to relative paths in preprocessed manifest",
)
parser.add_argument(
"--preprocessed_manifest",
type=str,
required=True,
help="Path to processed single cell image manifest.",
)
parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.")
parser.add_argument(
"--class_column",
type=str,
default=None,
required=False,
help="Column name of class to use for sampling, e.g. cell_stage_fine",
)
parser.add_argument(
"--class_label",
type=str,
default=None,
required=False,
help="Specific class label to sample, e.g. lateS",
)
args = parser.parse_args()

# Validate that required paths are provided
if not args.save_path or not args.dataset_name:
print("Error: Required arguments are missing.")
sys.exit(1)

main(args)

"""
Example run:
PCNA dataset
python visualize_pointclouds.py --save_path "./plot_pcs_test" --preprocessed_manifest "./subpackages/image_preprocessing/tmp_output_pcna/processed/manifest.parquet" --dataset_name "pcna" --class_column "cell_stage_fine" --class_label "lateS" --global_path "./subpackages/image_preprocessing/"
Other punctate dataset
python visualize_pointclouds.py --save_path "./plot_pcs_test" --preprocessed_manifest "./subpackages/image_preprocessing/tmp_output_variance/processed/manifest.parquet" --dataset_name "other_punctate" --class_column "structure_name" --class_label "CETN2" --global_path "./subpackages/image_preprocessing/"
"""
Loading

0 comments on commit 8432a89

Please sign in to comment.