diff --git a/README.md b/README.md index 6c2e827..5095a1e 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,12 @@ # Benchmarking Representations + Code for training and benchmarking morphology appropriate representation learning methods, associated with the following manuscript. > **Interpretable representation learning for 3D multi-piece intracellular structures using point clouds** -> +> > Ritvik Vasan, Alexandra J. Ferrante, Antoine Borensztejn, Christopher L. Frick, Nathalie Gaudreault, Saurabh S. Mogre, Benjamin Morris, Guilherme G. Pires, Susanne M. Rafelski, Julie A. Theriot, Matheus P. Viana > -> bioRxiv 2024.07.25.605164; doi: https://doi.org/10.1101/2024.07.25.605164 +> bioRxiv 2024.07.25.605164; doi: https://doi.org/10.1101/2024.07.25.605164 Our analysis is organized as follows. @@ -31,4 +32,5 @@ If you'd like to reproduce this analysis on our data, check out the following do Coming soon # Contact + Allen Institute for Cell Science (cells@alleninstitute.org) diff --git a/src/br/analysis/analysis_utils.py b/src/br/analysis/analysis_utils.py index bd74970..c4bb50a 100644 --- a/src/br/analysis/analysis_utils.py +++ b/src/br/analysis/analysis_utils.py @@ -4,22 +4,27 @@ import subprocess from pathlib import Path +import matplotlib.colors as mcolors import matplotlib.pyplot as plt +import mesh_to_sdf import numpy as np import pandas as pd import pyvista as pv import torch +import trimesh import yaml +from aicsimageio import AICSImage from sklearn.decomposition import PCA from tqdm import tqdm +from br.data.utils import get_iae_reconstruction_3d_grid from br.features.plot import plot_pc_saved, plot_stratified_pc from br.features.reconstruction import save_pcloud from br.features.utils import ( normalize_intensities_and_get_colormap, normalize_intensities_and_get_colormap_apply, ) -from br.models.utils import get_all_configs_per_dataset +from br.models.utils import get_all_configs_per_dataset, move def str2bool(v): @@ -90,8 +95,6 @@ def get_mig_ids(gpu_uuid): 
elif "GPU" in line and in_gpu_section: # Encounter another GPU section break - # print(line) - if in_gpu_section: # Check for MIG devices if "MIG" in line: @@ -120,7 +123,6 @@ def config_gpu(): # Check if GPU utilization is under 20% (indicating it's idle) if utilization < 20: - # print(uuid, utilization) if is_mig: mig_ids = get_mig_ids(uuid) @@ -378,7 +380,6 @@ def viz_other_punctate(this_save_path, viz_params, stratify_key): for pc_bin in df[stratify_key].unique(): this_df = df.loc[df[stratify_key] == pc_bin].reset_index(drop=True) - print(this_df.shape, struct, pc_bin) np_arr = this_df[["x", "y", "z"]].values colors = cmap(this_df["inorm"].values)[:, :3] np_arr2 = colors @@ -598,10 +599,416 @@ def archetypes_polymorphic(this_save_path, archetypes_df, all_ret, all_features) dist = np.sum(dist, axis=1) closest_idx = np.argmin(dist) closest_real_id = all_ret.iloc[closest_idx]["CellId"] - print(dist, closest_real_id) mesh = pv.read(mesh_folder + str(closest_real_id) + ".stl") mesh.save(this_save_path / Path(f"{i}.ply")) arch_dict["archetype"].append(i) arch_dict["CellId"].append(closest_real_id) arch_dict = pd.DataFrame(arch_dict) arch_dict.to_csv(this_save_path / "archetypes.csv") + + +def generate_reconstructions(all_models, data_list, run_names, keys, test_ids, device, save_path): + with torch.no_grad(): + for j, model in enumerate(all_models): + this_data = data_list[j] + this_run_name = run_names[j] + this_key = keys[j] + for batch in this_data.test_dataloader(): + canonical = None + if not isinstance(batch["cell_id"], list): + if isinstance(batch["cell_id"], torch.Tensor): + cell_id = str(batch["cell_id"].item()) + else: + cell_id = str(batch["cell_id"]) + else: + cell_id = str(batch["cell_id"][0]) + if cell_id in test_ids: + input = batch[this_key].detach().cpu().numpy().squeeze() + if "pointcloud_SDF" in this_run_name: + eval_scaled_img_resolution = 32 + uni_sample_points = get_iae_reconstruction_3d_grid( + bb_min=-0.5, + bb_max=0.5, + 
resolution=eval_scaled_img_resolution, + padding=0.1, + ) + uni_sample_points = uni_sample_points.unsqueeze(0).repeat( + batch[this_key].shape[0], 1, 1 + ) + batch["points"] = uni_sample_points + xhat, z, z_params = model( + move(batch, device), decode=True, inference=True, return_params=True + ) + recon = xhat[this_key].detach().cpu().numpy().squeeze() + recon = recon.reshape( + eval_scaled_img_resolution, + eval_scaled_img_resolution, + eval_scaled_img_resolution, + ) + elif ("pointcloud" in this_run_name) and ("SDF" not in this_run_name): + batch = move(batch, device) + z, z_params = model.get_embeddings(batch, inference=True) + xhat = model.decode_embeddings( + z_params, batch, decode=True, return_canonical=True + ) + recon = xhat[this_key].detach().cpu().numpy().squeeze() + canonical = recon + if "canonical" in xhat.keys(): + canonical = xhat["canonical"].detach().cpu().numpy().squeeze() + else: + z = model.encode(move(batch, device)) + xhat = model.decode(z, return_canonical=True) + recon = xhat[this_key].detach().cpu().numpy().squeeze() + canonical = xhat["canonical"].detach().cpu().numpy().squeeze() + + this_save_path = Path(save_path) / Path(this_run_name) + this_save_path.mkdir(parents=True, exist_ok=True) + np.save(this_save_path / Path(f"{cell_id}.npy"), recon) + + this_save_path_input = Path(save_path) / Path(this_run_name) / Path("input") + this_save_path_input.mkdir(parents=True, exist_ok=True) + np.save(this_save_path_input / Path(f"{cell_id}.npy"), input) + + if canonical is not None: + this_save_path_canon = ( + Path(save_path) / Path(this_run_name) / Path("canonical") + ) + this_save_path_canon.mkdir(parents=True, exist_ok=True) + np.save(this_save_path_canon / Path(f"{cell_id}.npy"), canonical) + + +def save_supplemental_figure_punctate_reconstructions( + df, test_ids, run_names, reconstructions_path, normalize_across_recons, dataset_name +): + def slice_(img, slices=None, z_ind=0): + if not slices: + return img.max(z_ind) + mid_z = 
int(img.shape[z_ind] / 2) + if z_ind == 0: + img = img[mid_z - slices : mid_z + slices].max(0) + if z_ind == 2: + img = img[:, :, mid_z - slices : mid_z + slices].max(2) + return img + + def slice_points_(points, z_max, z_loc=0): + inds = np.where(points[:, z_loc] < z_max)[0] + points = points[inds, :] + inds = np.where(points[:, z_loc] > -z_max)[0] + points = points[inds, :] + return points + + def _plot_image(input, recon, recon_canonical, dataset_name): + num_slice = 8 + + if dataset_name != "cellpack": + z_ind = 0 + else: + z_ind = 2 + + input = slice_(input, num_slice, z_ind) + recon = slice_(recon, num_slice, z_ind) + recon_canonical = slice_(recon_canonical, num_slice, 2) + + if dataset_name == "cellpack": + recon = recon.T + recon_canonical = recon_canonical.T + + i = 2 + fig, (ax, ax1, ax2) = plt.subplots(1, 3, figsize=(8, 4)) + ax.imshow(input, cmap="gray_r") + ax1.imshow(recon, cmap="gray_r") + ax2.imshow(recon_canonical, cmap="gray_r") + + ax.set_xticks([]) + ax.set_yticks([]) + ax2.set_xticks([]) + ax2.set_yticks([]) + ax1.set_xticks([]) + ax1.set_yticks([]) + # max_size = 192 + max_size = recon_canonical.shape[1] + ax.set_aspect("equal", adjustable="box") + ax1.set_aspect("equal", adjustable="box") + ax2.set_aspect("equal", adjustable="box") + + # max_size = 6 + ax.set_ylim([0, max_size]) + ax1.set_ylim([0, max_size]) + ax2.set_ylim([0, max_size]) + ax.set_xlim([0, max_size]) + ax1.set_xlim([0, max_size]) + ax2.set_xlim([0, max_size]) + + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.spines["left"].set_visible(False) + + ax1.spines["top"].set_visible(False) + ax1.spines["right"].set_visible(False) + ax1.spines["bottom"].set_visible(False) + ax1.spines["left"].set_visible(False) + + ax2.spines["top"].set_visible(False) + ax2.spines["right"].set_visible(False) + ax2.spines["bottom"].set_visible(False) + ax2.spines["left"].set_visible(False) + + ax.set_title("Input") +
ax1.set_title("Reconstruction") + ax2.set_title("Canonical Reconstruction") + fig.subplots_adjust(wspace=0, hspace=0) + return fig + + def _plot_pc(input, recon, recon_canonical, struct, cmap, vmin, vmax, dataset_name): + z_max = 0.3 + max_size = 15 + + if dataset_name != "cellpack": + z_ind = 2 + canon_z_ind = 2 + if (struct == "pcna") and ((recon == recon_canonical).all() == False): + canon_z_ind = 1 + else: + z_ind = 0 + canon_z_ind = 1 + xy_inds = [i for i in [0, 1, 2] if i != z_ind] + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + for index_, this_p in enumerate([input, recon, recon_canonical]): + if struct in ["NUP153", "HIST1H2BJ", "SMC1A", "SON", "pcna"]: + if index_ == 2: + this_p = slice_points_(this_p, z_max, canon_z_ind) + else: + this_p = slice_points_(this_p, z_max, z_ind) + if this_p.shape[-1] == 3: + if (index_ == 2) and (canon_z_ind != z_ind): + xy_inds = [i for i in [0, 1, 2] if i != canon_z_ind] + axes[index_].scatter( + this_p[:, xy_inds[0]], this_p[:, xy_inds[1]], c="black", s=2, alpha=0.5 + ) + else: + if not cmap: + this_df = pd.DataFrame(input, columns=["x", "y", "z", "s"]) + all_df_input, cmap, vmin, vmax = normalize_intensities_and_get_colormap( + df=this_df, pcts=[5, 95] + ) + this_p = pd.DataFrame(this_p, columns=["x", "y", "z", "s"]) + this_p = normalize_intensities_and_get_colormap_apply(this_p, vmin, vmax) + + if (index_ == 2) and (canon_z_ind != z_ind): + xy_inds = [i for i in [0, 1, 2] if i != canon_z_ind] + + x_vals = this_p.iloc[:, xy_inds[0]].values + y_vals = this_p.iloc[:, xy_inds[1]].values + + axes[index_].scatter( + x_vals, + y_vals, + c=cmap(this_p["inorm"].values), + s=2, + alpha=0.5, + ) + axes[index_].spines["top"].set_visible(False) + axes[index_].spines["right"].set_visible(False) + axes[index_].spines["bottom"].set_visible(False) + axes[index_].spines["left"].set_visible(False) + axes[index_].set_aspect("equal", adjustable="box") + axes[index_].set_ylim([-max_size, max_size]) + axes[index_].set_xlim([-max_size, 
max_size]) + axes[index_].set_yticks([]) + axes[index_].set_xticks([]) + + axes[0].set_title("Input") + axes[1].set_title("Reconstruction") + axes[2].set_title("Canonical Reconstruction") + return fig + + for m in run_names: + for i, this_id in enumerate(test_ids): + struct = "pcna" + if "structure_name" in df.columns: + df["CellId"] = df["CellId"].astype(str) + struct = df.loc[df["CellId"] == this_id]["structure_name"].iloc[0] + row_index = i + + input_path = reconstructions_path + f"{m}/input/{this_id}.npy" + input = np.load(input_path).squeeze() + + recon_path = reconstructions_path + f"{m}/{this_id}.npy" + recon = np.load(recon_path).squeeze() + + recon_path = reconstructions_path + f"{m}/canonical/{this_id}.npy" + recon_canonical = np.load(recon_path).squeeze() + + if "image" in m: + fig = _plot_image(input, recon, recon_canonical, dataset_name) + else: + cmap = None + vmin = None + vmax = None + if normalize_across_recons: + all_df_input = [] + for c in test_ids: + input_path_ = reconstructions_path + f"{m}/input/{c}.npy" + input_tmp = np.load(input_path_).squeeze() + if input_tmp.shape[-1] == 4: + this_df = pd.DataFrame(input_tmp, columns=["x", "y", "z", "s"]) + all_df_input.append(this_df) + if len(all_df_input) > 0: + all_df_input = pd.concat(all_df_input, axis=0).reset_index(drop=True) + _, cmap, vmin, vmax = normalize_intensities_and_get_colormap( + df=all_df_input, pcts=[5, 95] + ) + fig = _plot_pc( + input, recon, recon_canonical, struct, cmap, vmin, vmax, dataset_name + ) + + this_save_path_ = Path(reconstructions_path) / Path(m) + print(this_save_path_, this_id) + fig.savefig( + this_save_path_ / Path(f"sample_recons_{this_id}.pdf"), + bbox_inches="tight", + dpi=300, + ) + fig.savefig( + this_save_path_ / Path(f"sample_recons_{this_id}.png"), + bbox_inches="tight", + dpi=300, + ) + + +def save_supplemental_figure_sdf_reconstructions(df, test_ids, reconstructions_path): + import pyvista as pv + + pv.start_xvfb() + gt_test_sdfs = [] + gt_test_segs =
[] + for tid in test_ids: + path = df[df["CellId"] == int(tid)]["sdf_path"].values[0] + sdf = np.load(path) + path = df[df["CellId"] == int(tid)]["seg_path"].values[0] + seg = np.load(path) + gt_test_sdfs.append(sdf) + gt_test_segs.append(seg) + + gt_orig_struc = [] + gt_orig_cell = [] + gt_orig_nuc = [] + for tid in test_ids: + path = df[df["CellId"] == int(tid)]["crop_seg_masked"].values[0] + seg = AICSImage(path).data.squeeze() + path = df[df["CellId"] == int(tid)]["crop_seg"].values[0] + img = AICSImage(path).data.squeeze() + gt_orig_struc.append(seg) + gt_orig_nuc.append(img[0]) + gt_orig_cell.append(img[1]) + + eval_scaled_img_resolution = 32 + mid_slice_ = int(eval_scaled_img_resolution / 2) + uni_sample_points_grid = get_iae_reconstruction_3d_grid( + bb_min=-0.5, bb_max=0.5, resolution=eval_scaled_img_resolution, padding=0.1 + ) + gt_test_sdfs_iae = [] + mesh_folder = df["mesh_folder"].iloc[0] + for tid in test_ids: + path = mesh_folder + str(tid) + ".stl" + mesh = trimesh.load(path) + bbox = mesh.bounding_box.bounds + loc = (bbox[0] + bbox[1]) / 2 + scale_factor = (bbox[1] - bbox[0]).max() + mesh = mesh.apply_translation(-loc) + mesh = mesh.apply_scale(1 / scale_factor) + sdf_vals = mesh_to_sdf.mesh_to_sdf(mesh, query_points=uni_sample_points_grid.numpy()) + gt_test_sdfs_iae.append(sdf_vals) + + cmap_inverted = mcolors.ListedColormap(np.flipud(plt.cm.gray(np.arange(256)))) + + model_order = [ + "Classical_image_seg", + "Rotation_invariant_image_seg", + "Classical_image_SDF", + "Rotation_invariant_image_SDF", + "Rotation_invariant_pointcloud_SDF", + ] + + for split, split_ids, gt_segs, gt_sdfs, gt_test_i_sdfs in [ + ("test", test_ids, gt_test_segs, gt_test_sdfs, gt_test_sdfs_iae) + ]: + + num_rows = len(split_ids) + num_columns = len(model_order) + 4 + fig, axs = plt.subplots(num_rows, num_columns, figsize=(num_columns * 5, num_rows * 5)) + + for i, c in enumerate(split_ids): + gt_seg = gt_segs[i] + gt_sdf = np.clip(gt_sdfs[i], -2, 2) + gt_sdf_i = 
gt_test_i_sdfs[i].reshape( + eval_scaled_img_resolution, eval_scaled_img_resolution, eval_scaled_img_resolution + ) + row_index = i + recons = [] + for m in model_order: + recon_path = reconstructions_path + f"{m}/{c}.npy" + recon = np.load(recon_path).squeeze() + + if "SDF" in m or "vn" in m: + mid_z = recon.shape[0] // 2 + if ("Rotation" in m) and (m != "Rotation_invariant_pointcloud_SDF"): + z_slice = recon[mid_z, :, :].T + from scipy.ndimage import rotate + + z_slice = rotate(z_slice, angle=-135, cval=2) + z_slice = z_slice[14:-14, 14:-14] + else: + z_slice = recon[:, :, mid_z].T + else: + z_slice = recon.max(0) + recons.append(z_slice) + + struc_seg = gt_orig_struc[i] + cell_seg = gt_orig_cell[i] + nuc_seg = gt_orig_nuc[i] + + axs[row_index, 0].imshow( + cell_seg.max(0), cmap=cmap_inverted, origin="lower", alpha=0.5 + ) + axs[row_index, 0].imshow(nuc_seg.max(0), cmap=cmap_inverted, origin="lower", alpha=0.5) + axs[row_index, 0].imshow( + struc_seg.max(0), cmap=cmap_inverted, origin="lower", alpha=0.5 + ) + axs[row_index, 0].axis("off") + axs[row_index, 0].set_title("") # (f'GT Seg CellId {c}') + + axs[row_index, 1].imshow(gt_seg.max(0), cmap=cmap_inverted, origin="lower") + axs[row_index, 1].axis("off") + axs[row_index, 1].set_title("") # (f'GT Seg CellId {c}') + + for i, img in enumerate(recons[:2]): + axs[row_index, i + 2].imshow(img, cmap=cmap_inverted, origin="lower") + axs[row_index, i + 2].axis("off") + axs[row_index, i + 2].set_title("") # run_to_displ_name[model_order[i]]) + + axs[row_index, 4].imshow( + gt_sdf[:, :, mid_slice_].T, cmap="seismic", origin="lower", vmin=-2, vmax=2 + ) + axs[row_index, 4].axis("off") + axs[row_index, 4].set_title("") # (f'GT SDF CellId {c}') + + for i, img in enumerate(recons[2:4]): + axs[row_index, i + 5].imshow(img, cmap="seismic", origin="lower", vmin=-2, vmax=2) + axs[row_index, i + 5].axis("off") + axs[row_index, i + 5].set_title("") # run_to_displ_name[model_order[i]]) + + axs[row_index, 7].imshow( + gt_sdf_i[:, :, 
mid_slice_].T.clip(-0.5, 0.5), cmap="seismic", origin="lower" + ) + axs[row_index, 7].axis("off") + axs[row_index, 7].set_title("") # (f'GT SDF CellId {c}') + + axs[row_index, 8].imshow(recons[-1].clip(-0.5, 0.5), cmap="seismic", origin="lower") + axs[row_index, 8].axis("off") + axs[row_index, 8].set_title("") # (f'GT SDF CellId {c}') + + plt.tight_layout() + plt.savefig(reconstructions_path + "sample_recons.png", dpi=300, bbox_inches="tight") + plt.savefig(reconstructions_path + "sample_recons.pdf", dpi=300, bbox_inches="tight") diff --git a/src/br/analysis/save_reconstructions.py b/src/br/analysis/save_reconstructions.py new file mode 100644 index 0000000..6b1bc8a --- /dev/null +++ b/src/br/analysis/save_reconstructions.py @@ -0,0 +1,143 @@ +# Free up cache +import argparse +import os +import sys +from pathlib import Path + +from br.analysis.analysis_utils import ( + generate_reconstructions, + save_supplemental_figure_punctate_reconstructions, + save_supplemental_figure_sdf_reconstructions, + setup_gpu, + str2bool, +) +from br.models.load_models import get_data_and_models + +test_ids_per_dataset_ = { + "cellpack": [ + "9c1ff213-4e9e-4b73-a942-3baf9d37a50f_0", + "9c1ff213-4e9e-4b73-a942-3baf9d37a50f_1", + "9c1ff213-4e9e-4b73-a942-3baf9d37a50f_2", + "9c1ff213-4e9e-4b73-a942-3baf9d37a50f_3", + "9c1ff213-4e9e-4b73-a942-3baf9d37a50f_4", + "9c1ff213-4e9e-4b73-a942-3baf9d37a50f_5", + ], + "pcna": [ + "7624cd5b-715a-478e-9648-3bac4a73abe8", + "80d40c5e-65bf-43b0-8dea-b697c421ea78", + "6a3ab51f-fa68-4fe1-a13b-2b2461ed71b4", + "aabbbca4-6c35-4f3d-9467-7d573482f236", + "d23de56e-bacf-4ec8-8e18-39822fea777b", + "c382794f-5baf-4b17-8574-62dccbbbaefc", + "50b52c3e-4756-4684-a281-0141525ded9f", + "8713eea5-da72-4644-96fe-ba8340edb67d", + ], + "other_punctate": ["721646", "873680", "994027", "490385", "451974", "811336", "835431"], + "npm1": ["964798", "661110", "644401", "967887", "703621"], + "other_polymorphic": ["691110", "723687", "816468", "800894"], +} + + +def 
main(args): + # Setup GPUs and set the device + setup_gpu() + device = "cuda:0" + + # set batch size to 1 for emission stats/features + batch_size = 1 + + # Get config path from CYTODL_CONFIG_PATH + config_path = os.environ["CYTODL_CONFIG_PATH"] + + test_ids = args.test_ids + if not test_ids: + test_ids = test_ids_per_dataset_[args.dataset_name] + + if args.dataset_name == "cellpack": + args.debug = True + + # Load data and models + ( + data_list, + all_models, + run_names, + model_sizes, + manifest, + keys, + latent_dims, + ) = get_data_and_models(args.dataset_name, batch_size, config_path + "/results/", args.debug) + + # make save path directory + Path(args.save_path).mkdir(parents=True, exist_ok=True) + + if args.generate_reconstructions: + generate_reconstructions( + all_models, data_list, run_names, keys, test_ids, device, args.save_path + ) + + if args.sdf: + save_supplemental_figure_sdf_reconstructions(manifest, test_ids, args.save_path) + else: + save_supplemental_figure_punctate_reconstructions( + manifest, + test_ids, + run_names, + args.save_path, + args.normalize_across_recons, + args.dataset_name, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Script for saving reconstructions") + parser.add_argument( + "--save_path", type=str, required=True, help="Path to save the reconstructions."
+ ) + parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset.") + parser.add_argument("--debug", type=str2bool, default=False, help="Enable debug mode.") + parser.add_argument( + "--sdf", type=str2bool, default=True, help="Whether the experiments involve SDFs" + ) + parser.add_argument( + "--test_ids", default=None, nargs="+", help="List of test set cellids to reconstruct" + ) + parser.add_argument( + "--generate_reconstructions", + type=str2bool, + default=False, + help="Whether to generate reconstructions before plotting", + ) + parser.add_argument( + "--normalize_across_recons", + type=str2bool, + default=False, + help="Whether to normalize across all inputs", + ) + + args = parser.parse_args() + + # Validate that required paths are provided + if not args.save_path or not args.dataset_name: + print("Error: Required arguments are missing.") + sys.exit(1) + + main(args) + + """ + Example run: + + cellPACK dataset + python src/br/analysis/save_reconstructions.py --save_path "./outputs_cellpack/reconstructions/" --dataset_name "cellpack" --generate_reconstructions True --sdf False + + PCNA dataset + python src/br/analysis/save_reconstructions.py --save_path "./outputs_pcna/reconstructions/" --dataset_name "pcna" --generate_reconstructions True --sdf False --normalize_across_recons True + + NPM1 dataset + python src/br/analysis/save_reconstructions.py --save_path "./outputs_npm1/reconstructions/" --dataset_name "npm1" --test_ids 964798 661110 644401 967887 703621 --generate_reconstructions True --sdf True + + Other polymorphic dataset + python src/br/analysis/save_reconstructions.py --save_path "./outputs_other_polymorphic/reconstructions/" --dataset_name "other_polymorphic" --test_ids 691110 723687 816468 800894 --generate_reconstructions True --sdf True + + Other punctate dataset + python src/br/analysis/save_reconstructions.py --save_path "./outputs_other_punctate/reconstructions/" --dataset_name "other_punctate"
--generate_reconstructions True --normalize_across_recons False --sdf False + """ diff --git a/src/br/data/get_datamodules.py b/src/br/data/get_datamodules.py index c302348..15c4a43 100644 --- a/src/br/data/get_datamodules.py +++ b/src/br/data/get_datamodules.py @@ -28,5 +28,8 @@ def get_data(dataset_name, batch_size, results_path, debug=False): config["subsample"]["train"] = 4 config["subsample"]["valid"] = 4 config["subsample"]["test"] = 4 + + if config["_target_"] == "pointcloudutils.datamodules.CellPackDataModule": + config["subset_test"] = True data.append(instantiate(config)) return data diff --git a/src/pointcloudutils/datamodules/cellpack.py b/src/pointcloudutils/datamodules/cellpack.py index 6067150..f27e80f 100644 --- a/src/pointcloudutils/datamodules/cellpack.py +++ b/src/pointcloudutils/datamodules/cellpack.py @@ -45,6 +45,7 @@ def __init__( num_rotations: Optional[int] = 3, upsample: Optional[bool] = False, image: Optional[bool] = False, + subset_test: Optional[bool] = False, ): """""" super().__init__() @@ -68,6 +69,7 @@ def __init__( self.norm_feats = norm_feats self.num_rotations = num_rotations self.image = image + self.subset_test = subset_test def _get_dataset(self, split): return CellPackDataset( @@ -89,6 +91,7 @@ def _get_dataset(self, split): self.num_rotations, self.upsample, self.image, + self.subset_test, ) def train_dataloader(self): @@ -150,6 +153,7 @@ def __init__( num_rotations: Optional[int] = 3, upsample: Optional[bool] = False, image: Optional[bool] = False, + subset_test: Optional[bool] = False, ): self.x_label = x_label self.scale = scale @@ -167,6 +171,7 @@ def __init__( self.num_rules = len(self.packing_rules) self.num_rotations = num_rotations self.image = image + self.subset_test = subset_test self.ref_csv = pd.read_csv(ref_path + "manifest.csv") @@ -187,8 +192,9 @@ def __init__( self.ids = _splits[split] - # if split == "test": - # self.ids = ["9c1ff213-4e9e-4b73-a942-3baf9d37a50f"] + if self.subset_test: + if split == "test": 
+ self.ids = ["9c1ff213-4e9e-4b73-a942-3baf9d37a50f"] if norm_feats: x = self.ref_csv[ diff --git a/subpackages/image_preprocessing/README.md b/subpackages/image_preprocessing/README.md index b9d6bc3..b9dba6d 100644 --- a/subpackages/image_preprocessing/README.md +++ b/subpackages/image_preprocessing/README.md @@ -43,6 +43,7 @@ Before running the script, please ensure the following: 1. **Set the `TMPDIR`:** You need to set the `TMPDIR` environment variable, as the Snakefile requires a temporary directory. You can do this by executing: + ```bash export TMPDIR=/path/to/your/tmpdir ```