From a53a8b15294549ab1081a20ec2ff413c06f21a85 Mon Sep 17 00:00:00 2001 From: Lenz Date: Thu, 4 Apr 2024 18:27:50 +0200 Subject: [PATCH 1/4] Added ome-tiff inference pipeline --- src/nimbus_inference/nimbus.py | 27 +++--- src/nimbus_inference/utils.py | 160 ++++++++++++++++++++++++++++++++- tests/test_utils.py | 130 ++++++++++++++++++++++++--- 3 files changed, 290 insertions(+), 27 deletions(-) diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py index 4850e8c..6445e2f 100644 --- a/src/nimbus_inference/nimbus.py +++ b/src/nimbus_inference/nimbus.py @@ -1,11 +1,8 @@ from alpineer import io_utils, misc_utils from skimage.util.shape import view_as_windows import nimbus_inference -from nimbus_inference.utils import ( - prepare_normalization_dict, - predict_fovs, - predict_ome_fovs, - nimbus_preprocess, +from nimbus_inference.utils import (prepare_normalization_dict, prepare_normalization_dict_ome, + predict_fovs, predict_ome_fovs, nimbus_preprocess, ) from huggingface_hub import hf_hub_download from nimbus_inference.unet import UNet @@ -161,14 +158,14 @@ def initialize_model(self, padding="reflect"): self.checkpoint_path = os.path.join( path, "assets", - "resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt" + "resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32_cw_0.8.pt" ) if not os.path.exists(self.checkpoint_path): local_dir = os.path.join(path, "assets") print("Downloading weights from Hugging Face Hub...") self.checkpoint_path = hf_hub_download( repo_id="JLrumberger/Nimbus-Inference", - filename="resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt", + filename="resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32_cw_0.8.pt", local_dir=local_dir, local_dir_use_symlinks=False, ) @@ -192,11 +189,17 @@ def prepare_normalization_dict( if os.path.exists(self.normalization_dict_path) and not overwrite: self.normalization_dict = 
json.load(open(self.normalization_dict_path)) else: - n_jobs = os.cpu_count() if multiprocessing else 1 - self.normalization_dict = prepare_normalization_dict( - self.fov_paths, self.output_dir, quantile, self.include_channels, n_subset, n_jobs - ) + if self.suffix.lower() in [".ome.tif", ".ome.tiff"]: + self.normalization_dict = prepare_normalization_dict_ome( + self.fov_paths, self.output_dir, quantile, self.include_channels, n_subset, + n_jobs + ) + else: + self.normalization_dict = prepare_normalization_dict( + self.fov_paths, self.output_dir, quantile, self.include_channels, n_subset, + n_jobs + ) if self.include_channels == []: self.include_channels = list(self.normalization_dict.keys()) @@ -223,7 +226,7 @@ def predict_fovs(self): half_resolution=self.half_resolution, batch_size=self.batch_size, test_time_augmentation=self.test_time_aug, suffix=self.suffix, ) - elif self.suffix.lower() in [".tiff", ".tif", ".jpg", ".jpeg", ".png"]: + else: self.cell_table = predict_fovs( nimbus=self, fov_paths=self.fov_paths, output_dir=self.output_dir, normalization_dict=self.normalization_dict, diff --git a/src/nimbus_inference/utils.py b/src/nimbus_inference/utils.py index 7d74fea..ea2b4d2 100644 --- a/src/nimbus_inference/utils.py +++ b/src/nimbus_inference/utils.py @@ -12,6 +12,7 @@ from joblib.externals.loky import get_reusable_executor from skimage.segmentation import find_boundaries from skimage.measure import regionprops_table +from pyometiff import OMETIFFReader def calculate_normalization(channel_path, quantile): @@ -196,6 +197,8 @@ def predict_fovs( os.path.normpath(output_dir), os.path.basename(fov_path) ) df_fov = pd.DataFrame() + instance_path = segmentation_naming_convention(fov_path) + instance_mask = np.squeeze(io.imread(instance_path)) for channel in tqdm(os.listdir(fov_path)): channel_path = os.path.join(fov_path, channel) channel_ = channel.split(".")[0] @@ -204,8 +207,6 @@ def predict_fovs( ): continue mplex_img = np.squeeze(io.imread(channel_path)) - 
instance_path = segmentation_naming_convention(fov_path) - instance_mask = np.squeeze(io.imread(instance_path)) input_data = prepare_input_data(mplex_img, instance_mask) if half_resolution: scale = 0.5 @@ -275,5 +276,156 @@ def nimbus_preprocess(image, **kwargs): return output -def predict_ome_fovs(): - pass \ No newline at end of file +def calculate_normalization_ome(ome_path, quantile, include_channels): + """Calculates the normalization values for a given ome file + Args: + ome_path (str): path to ome file + quantile (float): quantile to use for normalization + include_channels (list): list of channels to include + Returns: + normalization_values (dict): dict with channel names as keys and norm factors as values + """ + reader = OMETIFFReader(fpath=ome_path) + img_array, metadata, _ = reader.read() + channel_names = list(metadata["Channels"].keys()) + if not include_channels: + include_channels = channel_names + # check if include_channels are included in ome file metadata + for channel in include_channels: + if channel not in channel_names: + raise ValueError(f"Channel {channel} not found in ome file metadata.") + normalization_values = {} + for channel in include_channels: + idx = channel_names.index(channel) + mplex_img = img_array[idx] + mplex_img = mplex_img.astype(np.float32) + foreground = mplex_img[mplex_img > 0] + normalization_values[channel] = np.quantile(foreground, quantile) + return normalization_values + + +def prepare_normalization_dict_ome( + fov_paths, output_dir, quantile=0.999, include_channels=[], n_subset=10, n_jobs=1, + output_name="normalization_dict.json" + ): + """Prepares the normalization dict for a list of ome.tif fovs + Args: + fov_paths (list): list of paths to fovs + output_dir (str): path to output directory + quantile (float): quantile to use for normalization + exclude_channels (list): list of channels to exclude + n_subset (int): number of fovs to use for normalization + n_jobs (int): number of jobs to use for joblib 
multiprocessing + output_name (str): name of output file + Returns: + normalization_dict (dict): dict with channel names as keys and norm factors as values + """ + normalization_dict = {} + if n_subset is not None: + random.shuffle(fov_paths) + fov_paths = fov_paths[:n_subset] + print("Iterate over fovs...") + if n_jobs > 1: + normalization_values = Parallel(n_jobs=n_jobs)( + delayed(calculate_normalization_ome)(ome_path, quantile, include_channels) + for ome_path in fov_paths + ) + else: + normalization_values = [ + calculate_normalization_ome(ome_path, quantile, include_channels) + for ome_path in fov_paths + ] + for norm_dict in normalization_values: + for channel, normalization_value in norm_dict.items(): + if channel not in normalization_dict: + normalization_dict[channel] = [] + normalization_dict[channel].append(normalization_value) + if n_jobs > 1: + get_reusable_executor().shutdown(wait=True) + for channel in normalization_dict.keys(): + normalization_dict[channel] = np.mean(normalization_dict[channel]) + # save normalization dict + with open(os.path.join(output_dir, output_name), 'w') as f: + json.dump(normalization_dict, f) + return normalization_dict + + +def predict_ome_fovs( + nimbus, fov_paths, normalization_dict, segmentation_naming_convention, output_dir, + suffix, include_channels=[], save_predictions=True, half_resolution=False, batch_size=4, + test_time_augmentation=True + ): + """Predicts the segmentation map for each mplex channel in each ome.tif fov + Args: + nimbus (Nimbus): nimbus object + fov_paths (list): list of fov paths + normalization_dict (dict): dict with channel names as keys and norm factors as values + segmentation_naming_convention (function): function to get instance mask path from fov path + output_dir (str): path to output dir + suffix (str): suffix of mplex images + include_channels (list): list of channels to include + save_predictions (bool): whether to save predictions + half_resolution (bool): whether to use half 
resolution + batch_size (int): batch size + test_time_augmentation (bool): whether to use test time augmentation + Returns: + cell_table (pd.DataFrame): cell table with predicted confidence scores per fov and cell + """ + fov_dict_list = [] + for fov_path in fov_paths: + print(f"Predicting {fov_path}...") + out_fov_path = os.path.join( + os.path.normpath(output_dir), os.path.basename(fov_path).split(".")[0] + ) + df_fov = pd.DataFrame() + reader = OMETIFFReader(fpath=fov_path) + img_array, metadata, _ = reader.read() + channel_names = list(metadata["Channels"].keys()) + instance_path = segmentation_naming_convention(fov_path) + instance_mask = np.squeeze(io.imread(instance_path)) + if not include_channels: + include_channels = channel_names + for channel in tqdm(include_channels): + idx = channel_names.index(channel) + mplex_img = np.squeeze(img_array[idx]) + input_data = prepare_input_data(mplex_img, instance_mask) + if half_resolution: + scale = 0.5 + input_data = np.squeeze(input_data) + _, h,w = input_data.shape + img = cv2.resize(input_data[0], [int(h*scale), int(w*scale)]) + binary_mask = cv2.resize( + input_data[1], [int(h*scale), int(w*scale)], interpolation=0 + ) + input_data = np.stack([img, binary_mask], axis=0)[np.newaxis,...] 
+ if test_time_augmentation: + prediction = test_time_aug( + input_data, channel, nimbus, normalization_dict, batch_size=batch_size + ) + + else: + prediction = nimbus.predict_segmentation( + input_data, + preprocess_kwargs={ + "normalize": True, "marker": channel, + "normalization_dict": normalization_dict + }, + ) + prediction = np.squeeze(prediction) + if half_resolution: + prediction = cv2.resize(prediction, (h, w)) + df = pd.DataFrame(segment_mean(instance_mask, prediction)) + if df_fov.empty: + df_fov["label"] = df["label"] + df_fov["fov"] = os.path.basename(fov_path) + df_fov[channel] = df["intensity_mean"] + if save_predictions: + os.makedirs(out_fov_path, exist_ok=True) + pred_int = (prediction*255.0).astype(np.uint8) + io.imwrite( + os.path.join(out_fov_path, channel+".tiff"), pred_int, photometric="minisblack", + # compress=0, + ) + fov_dict_list.append(df_fov) + cell_table = pd.concat(fov_dict_list, ignore_index=True) + return cell_table diff --git a/tests/test_utils.py b/tests/test_utils.py index 7882bae..77b6203 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,11 @@ -from nimbus_inference.utils import prepare_normalization_dict, calculate_normalization -from nimbus_inference.utils import predict_fovs, predict_ome_fovs, prepare_input_data +from nimbus_inference.utils import (prepare_normalization_dict, calculate_normalization, +predict_fovs, predict_ome_fovs, calculate_normalization_ome, prepare_normalization_dict_ome, +prepare_input_data) from nimbus_inference.utils import test_time_aug as tt_aug from nimbus_inference.nimbus import Nimbus from skimage import io from pyometiff import OMETIFFWriter +import pytest import numpy as np import tempfile import torch @@ -56,25 +58,30 @@ def prepare_tif_data(num_samples, temp_dir, selected_markers, random=False, std= def prepare_ome_tif_data(num_samples, temp_dir, selected_markers, random=False, std=1): np.random.seed(42) metadata_dict = { - "PhysicalSizeX" : "0.88", + "SizeX" : 256, + 
"SizeY" : 256, + "SizeC" : len(selected_markers) + 3, + "PhysicalSizeX" : 0.5, "PhysicalSizeXUnit" : "µm", - "PhysicalSizeY" : "0.88", + "PhysicalSizeY" : 0.5, "PhysicalSizeYUnit" : "µm", - "PhysicalSizeZ" : "3.3", - "PhysicalSizeZUnit" : "µm", } - + fov_paths = [] + inst_paths = [] + if isinstance(std, (int, float)) or len(std) != len(selected_markers): + std = [std] * len(selected_markers) for i in range(num_samples): metadata_dict["Channels"] = {} channels = [] - for marker in zip(selected_markers): + for j, (marker, s) in enumerate(zip(selected_markers, std)): if random: - img = np.random.rand(256, 256) * std + img = np.random.rand(256, 256) * s else: img = np.ones([256, 256]) - channels.append(img) + channels.append(img) metadata_dict["Channels"][marker] = { "Name" : marker, + "ID": str(j), "SamplesPerPixel": 1, } channel_data = np.stack(channels, axis=0) @@ -87,7 +94,17 @@ def prepare_ome_tif_data(num_samples, temp_dir, selected_markers, random=False, metadata=metadata_dict, explicit_tiffdata=False) writer.write() - return None + deepcell_dir = os.path.join(temp_dir, "deepcell_output") + os.makedirs(deepcell_dir, exist_ok=True) + inst_path = os.path.join(deepcell_dir, f"fov_{i}_whole_cell.tiff") + io.imsave( + inst_path, np.array( + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] + ).repeat(64, axis=1).repeat(64, axis=0) + ) + fov_paths.append(sample_name) + inst_paths.append(inst_path) + return fov_paths, inst_paths def test_calculate_normalization(): @@ -237,3 +254,94 @@ def segmentation_naming_convention(fov_path): ) assert os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff")) assert os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff")) + + +def test_calculate_normalization_ome(): + with tempfile.TemporaryDirectory() as temp_dir: + fov_paths, _ = prepare_ome_tif_data( + num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56", "CD45"] + ) + + norm_dict = calculate_normalization_ome( + fov_paths[0], 0.999, 
include_channels=["CD4", "CD56"] + ) + # check if we get the correct normalization values + assert np.isclose(norm_dict["CD4"], 1.0, 0.01) + assert np.isclose(norm_dict["CD56"], 1.0, 0.01) + # check if ValueError is raised if include_channels are not in the ome.tif metadata + with pytest.raises(ValueError): + calculate_normalization_ome( + fov_paths[0], 0.999, include_channels=["CD42", "CD56"] + ) + + +def test_prepare_normalization_dict_ome(): + with tempfile.TemporaryDirectory() as temp_dir: + scales = [0.5, 1.0, 1.5, 2.0, 5.0] + channels = ["CD4", "CD11c", "CD14", "CD56", "CD57"] + fov_paths, _ = prepare_ome_tif_data( + num_samples=3, temp_dir=temp_dir, selected_markers=channels, random=True, std=scales + ) + normalization_dict = prepare_normalization_dict_ome( + fov_paths, temp_dir, quantile=0.999, n_subset=10, n_jobs=1, include_channels=channels, + output_name="normalization_dict.json" + ) + # test if normalization dict got saved + assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json")) + assert normalization_dict == json.load( + open(os.path.join(temp_dir, "normalization_dict.json")) + ) + # test if normalization dict is correct + for channel, scale in zip(channels, scales): + assert np.isclose(normalization_dict[channel], scale, 0.01) + + # test if multiprocessing yields approximately the same results + normalization_dict_mp = prepare_normalization_dict_ome( + fov_paths, temp_dir, quantile=0.999, n_subset=10, n_jobs=2, include_channels=channels, + output_name="normalization_dict.json" + ) + for key in normalization_dict.keys(): + assert np.isclose(normalization_dict[key], normalization_dict_mp[key], 1e-6) + + +def test_predict_ome_fovs(): + def segmentation_naming_convention(fov_path): + temp_dir_, fov_ = os.path.split(fov_path) + fov_ = fov_.split(".")[0] + return os.path.join(temp_dir_, "deepcell_output", fov_ + "_whole_cell.tiff") + + with tempfile.TemporaryDirectory() as temp_dir: + fov_paths, _ = prepare_ome_tif_data( + num_samples=1, 
temp_dir=temp_dir, selected_markers=["CD4", "CD56"] + ) + output_dir = os.path.join(temp_dir, "nimbus_output") + nimbus = Nimbus( + fov_paths, segmentation_naming_convention, output_dir, suffix=".ome.tiff" + ) + output_dir = os.path.join(temp_dir, "nimbus_output") + nimbus.prepare_normalization_dict() + cell_table = predict_ome_fovs( + nimbus=nimbus, fov_paths=fov_paths, output_dir=output_dir, + normalization_dict=nimbus.normalization_dict, + segmentation_naming_convention=segmentation_naming_convention, suffix=".tiff", + save_predictions=False, half_resolution=True, + ) + # check if we get the correct number of cells + assert len(cell_table) == 15 + # check if we get the correct columns (fov, label, CD4, CD56) + assert np.alltrue( + set(cell_table.columns) == set(["fov", "label", "CD4", "CD56"]) + ) + # check if predictions don't get written to output_dir + assert not os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff")) + assert not os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff")) + # + # run again with save_predictions=True and check if predictions get written to output_dir + cell_table = predict_ome_fovs( + nimbus=nimbus, fov_paths=fov_paths, output_dir=output_dir, + normalization_dict=nimbus.normalization_dict, + segmentation_naming_convention=segmentation_naming_convention, suffix=".tiff", + save_predictions=True, half_resolution=True, + ) + assert os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff")) + assert os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff")) From eafcfc207675132c16e9199a289fc01d64ed10a0 Mon Sep 17 00:00:00 2001 From: Lenz Date: Mon, 8 Apr 2024 17:16:10 +0200 Subject: [PATCH 2/4] Added an abstraction layer for the different datasets to reduce code complexity everywhere else. 
--- pyproject.toml | 1 + src/nimbus_inference/nimbus.py | 61 +--- src/nimbus_inference/utils.py | 394 ++++++++++++++------------ src/nimbus_inference/viewer_widget.py | 122 +++++--- tests/test_cell_analyzer.py | 2 +- tests/test_nimbus.py | 26 +- tests/test_utils.py | 212 +++++++------- tests/test_viewer_widget.py | 50 +++- 8 files changed, 449 insertions(+), 419 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b5202d9..394a110 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "ipywidgets", "natsort", "ipython", + "zarr", ] diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py index 6445e2f..a3f1413 100644 --- a/src/nimbus_inference/nimbus.py +++ b/src/nimbus_inference/nimbus.py @@ -1,8 +1,8 @@ from alpineer import io_utils, misc_utils from skimage.util.shape import view_as_windows import nimbus_inference -from nimbus_inference.utils import (prepare_normalization_dict, prepare_normalization_dict_ome, - predict_fovs, predict_ome_fovs, nimbus_preprocess, +from nimbus_inference.utils import (prepare_normalization_dict, + predict_fovs, nimbus_preprocess, MultiplexDataset ) from huggingface_hub import hf_hub_download from nimbus_inference.unet import UNet @@ -76,15 +76,13 @@ class Nimbus(nn.Module): """Nimbus application class for predicting marker activity for cells in multiplexed images.""" def __init__( - self, fov_paths, segmentation_naming_convention, output_dir, save_predictions=True, - include_channels=[], half_resolution=True, batch_size=4, test_time_aug=True, - input_shape=[1024, 1024], suffix=".tiff", device="auto", + self, dataset: MultiplexDataset, output_dir: str, save_predictions: bool=True, + include_channels: list=[], half_resolution: bool=True, batch_size: int=4, + test_time_aug: bool=True, input_shape: list=[1024, 1024], device: str="auto", ): """Initializes a Nimbus Application. Args: - fov_paths (list): List of paths to fovs to be analyzed. 
- segmentation_naming_convention (function): Function that returns the path to the - segmentation mask for a given fov path. + dataset (MultiplexDataset): Path to directory containing fovs. output_dir (str): Path to directory to save output. save_predictions (bool): Whether to save predictions. include_channels (list): List of channels to include in analysis. @@ -97,9 +95,8 @@ def __init__( , with "cpu" as a fallback), "cpu", "cuda", or "mps". Defaults to "auto". """ super(Nimbus, self).__init__() - self.fov_paths = fov_paths + self.dataset = dataset self.include_channels = include_channels - self.segmentation_naming_convention = segmentation_naming_convention self.output_dir = output_dir self.half_resolution = half_resolution self.save_predictions = save_predictions @@ -107,7 +104,6 @@ def __init__( self.checked_inputs = False self.test_time_aug = test_time_aug self.input_shape = input_shape - self.suffix = suffix if self.output_dir != "": os.makedirs(self.output_dir, exist_ok=True) @@ -124,20 +120,6 @@ def __init__( def check_inputs(self): """check inputs for Nimbus model""" - # check if all paths in fov_paths exists - if not isinstance(self.fov_paths, (list, tuple)): - self.fov_paths = [self.fov_paths] - io_utils.validate_paths(self.fov_paths) - - # check if segmentation_naming_convention returns valid paths - path_to_segmentation = self.segmentation_naming_convention(self.fov_paths[0]) - if not os.path.exists(path_to_segmentation): - raise FileNotFoundError( - "Function segmentation_naming_convention does not return valid\ - path. 
Segmentation path {} does not exist.".format( - path_to_segmentation - ) - ) # check if output_dir exists io_utils.validate_paths([self.output_dir]) @@ -190,16 +172,10 @@ def prepare_normalization_dict( self.normalization_dict = json.load(open(self.normalization_dict_path)) else: n_jobs = os.cpu_count() if multiprocessing else 1 - if self.suffix.lower() in [".ome.tif", ".ome.tiff"]: - self.normalization_dict = prepare_normalization_dict_ome( - self.fov_paths, self.output_dir, quantile, self.include_channels, n_subset, - n_jobs - ) - else: - self.normalization_dict = prepare_normalization_dict( - self.fov_paths, self.output_dir, quantile, self.include_channels, n_subset, - n_jobs - ) + self.normalization_dict = prepare_normalization_dict( + self.dataset, self.output_dir, quantile, self.include_channels, n_subset, + n_jobs + ) if self.include_channels == []: self.include_channels = list(self.normalization_dict.keys()) @@ -218,20 +194,9 @@ def predict_fovs(self): print("Predictions will be saved in {}".format(self.output_dir)) print("Iterating through fovs will take a while...") if self.suffix.lower() in [".ome.tif", ".ome.tiff"]: - self.cell_table = predict_ome_fovs( - nimbus=self, fov_paths=self.fov_paths, output_dir=self.output_dir, - normalization_dict=self.normalization_dict, - segmentation_naming_convention=self.segmentation_naming_convention, - include_channels=self.include_channels, save_predictions=self.save_predictions, - half_resolution=self.half_resolution, batch_size=self.batch_size, - test_time_augmentation=self.test_time_aug, suffix=self.suffix, - ) - else: self.cell_table = predict_fovs( - nimbus=self, fov_paths=self.fov_paths, output_dir=self.output_dir, - normalization_dict=self.normalization_dict, - segmentation_naming_convention=self.segmentation_naming_convention, - include_channels=self.include_channels, save_predictions=self.save_predictions, + nimbus=self, dataset=self.dataset, output_dir=self.output_dir, + 
normalization_dict=self.normalization_dict, save_predictions=self.save_predictions, half_resolution=self.half_resolution, batch_size=self.batch_size, test_time_augmentation=self.test_time_aug, suffix=self.suffix, ) diff --git a/src/nimbus_inference/utils.py b/src/nimbus_inference/utils.py index ea2b4d2..34a44f1 100644 --- a/src/nimbus_inference/utils.py +++ b/src/nimbus_inference/utils.py @@ -6,81 +6,195 @@ import numpy as np import pandas as pd import imageio as io -# from skimage import io from tqdm.autonotebook import tqdm from joblib import Parallel, delayed from joblib.externals.loky import get_reusable_executor from skimage.segmentation import find_boundaries from skimage.measure import regionprops_table from pyometiff import OMETIFFReader +from alpineer import io_utils, misc_utils +from typing import Callable +import tifffile +import zarr -def calculate_normalization(channel_path, quantile): - """Calculates the normalization value for a given channel - Args: - channel_path (str): path to channel - quantile (float): quantile to use for normalization - Returns: - normalization_value (float): normalization value - """ - mplex_img = io.imread(channel_path) - mplex_img = mplex_img.astype(np.float32) - foreground = mplex_img[mplex_img > 0] - normalization_value = np.quantile(foreground, quantile) - chan = os.path.basename(channel_path).split(".")[0] - return chan, normalization_value +class LazyOMETIFFReader(OMETIFFReader): + def __init__(self, fpath: str): + """Lazy OMETIFFReader class that reads channels only when needed + Args: + fpath (str): path to ome.tif file + """ + super().__init__(fpath) + self.metadata = self.get_metadata() + self.channels = self.get_channel_names() + self.shape = self.get_shape() + + def get_metadata(self): + """Get the metadata of the OME-TIFF file + Returns: + metadata (dict): metadata of the OME-TIFF file + """ + with tifffile.TiffFile(str(self.fpath)) as tif: + if tif.is_ome: + omexml_string = tif.ome_metadata + return 
self.parse_metadata(omexml_string) + else: + raise ValueError("File is not an OME-TIFF file.") + def get_channel_names(self): + """Get the channel names of the OME-TIFF file + Returns: + channel_names (list): list of channel names + """ + if hasattr(self, "metadata"): + return list(self.metadata["Channels"].keys()) + else: + return [] + + def get_shape(self): + """Get the shape of the OME-TIFF file array data + Returns: + shape (tuple): shape of the array data + """ + with tifffile.imread(str(self.fpath), aszarr=True) as store: + z = zarr.open(store, mode='r') + shape = z.shape + return shape -def prepare_normalization_dict( - fov_paths, output_dir, quantile=0.999, include_channels=[], n_subset=10, n_jobs=1, - output_name="normalization_dict.json" - ): - """Prepares the normalization dict for a list of fovs - Args: - fov_paths (list): list of paths to fovs - output_dir (str): path to output directory - quantile (float): quantile to use for normalization - exclude_channels (list): list of channels to exclude - n_subset (int): number of fovs to use for normalization - n_jobs (int): number of jobs to use for joblib multiprocessing - output_name (str): name of output file - Returns: - normalization_dict (dict): dict with channel names as keys and norm factors as values - """ - normalization_dict = {} - if n_subset is not None: - random.shuffle(fov_paths) - fov_paths = fov_paths[:n_subset] - print("Iterate over fovs...") - for fov_path in tqdm(fov_paths): - channels = os.listdir(fov_path) - if include_channels: + def get_channel(self, channel_name: str): + """Get an individual channel from the OME-TIFF file by name + Args: + channel_name (str): name of the channel + Returns: + channel (np.array): channel image + """ + idx = self.channels.index(channel_name) + with tifffile.imread(str(self.fpath), aszarr=True) as store: + z = zarr.open(store, mode='r') + # correct DimOrder, often DimOrder is TZCYX, but image is stored as CYX, thus we remove + # the trailing dimensions + 
dim_order = self.metadata["DimOrder"] + dim_order = dim_order[-len(z.shape):] + channel_idx = dim_order.find("C") + slice_string = "z[" + ":," * channel_idx + str(idx) + "]" + channel = eval(slice_string) + return channel + + +class MultiplexDataset(): + def __init__( + self, fov_paths: list, segmentation_naming_convention: Callable = None, + suffix: str = ".tiff" + ): + """Multiplex dataset class that gives a common interface for data loading of multiplex + datasets stored as individual channel images in folders or as multi-channel tiffs. + Args: + fov_paths (list): list of paths to fovs + segmentation_naming_convention (function): function to get instance mask path from fov + path + """ + self.fov_paths = fov_paths + self.segmentation_naming_convention = segmentation_naming_convention + self.suffix = suffix + self.check_inputs() + self.multi_channel = self.is_multi_channel_tiff(fov_paths[0]) + self.fovs = self.get_fovs() + self.channels = self.get_channels() + + def check_inputs(self): + """check inputs for Nimbus model""" + # check if all paths in fov_paths exists + if not isinstance(self.fov_paths, (list, tuple)): + self.fov_paths = [self.fov_paths] + io_utils.validate_paths(self.fov_paths) + print("All inputs are valid") + + def __len__(self): + """Return the number of fovs in the dataset""" + return len(self.fov_paths) + + def is_multi_channel_tiff(self, fov_path: str): + """Check if fov is a multi-channel tiff + Args: + fov_path (str): path to fov + Returns: + multi_channel (bool): whether fov is multi-channel + """ + multi_channel = False + if fov_path.lower().endswith(("ome.tif", "ome.tiff")): + self.img_reader = LazyOMETIFFReader(fov_path) + if len(self.img_reader.shape) > 2: + multi_channel = True + return multi_channel + + def get_channels(self): + """Get the channel names for the dataset""" + if self.multi_channel: + return self.img_reader.channels + else: channels = [ - channel for channel in channels if channel.split(".")[0] in include_channels + 
channel.replace(self.suffix, "") for channel in os.listdir(self.fov_paths[0]) \ + if channel.endswith(self.suffix) ] - channel_paths = [os.path.join(fov_path, channel) for channel in channels] - if n_jobs > 1: - normalization_values = Parallel(n_jobs=n_jobs)( - delayed(calculate_normalization)(channel_path, quantile) - for channel_path in channel_paths - ) + return channels + + def get_fovs(self): + """Get the fovs in the dataset""" + return [os.path.basename(fov).replace(self.suffix, "") for fov in self.fov_paths] + + def get_channel(self, fov: str, channel: str): + """Get the channel from a fov + Args: + fov (str): name of a fov + channel (str): channel name + Returns: + channel (np.array): channel image + """ + if self.multi_channel: + return self.get_channel_stack(fov, channel) else: - normalization_values = [ - calculate_normalization(channel_path, quantile) - for channel_path in channel_paths - ] - for channel, normalization_value in normalization_values: - if channel not in normalization_dict: - normalization_dict[channel] = [] - normalization_dict[channel].append(normalization_value) - if n_jobs > 1: - get_reusable_executor().shutdown(wait=True) - for channel in normalization_dict.keys(): - normalization_dict[channel] = np.mean(normalization_dict[channel]) - # save normalization dict - with open(os.path.join(output_dir, output_name), 'w') as f: - json.dump(normalization_dict, f) - return normalization_dict + return self.get_channel_single(fov, channel) + + def get_channel_single(self, fov: str, channel: str): + """Get the channel from a fov stored as a folder with individual channel images + Args: + fov (str): name of a fov + channel (str): channel name + Returns: + channel (np.array): channel image + """ + idx = self.fovs.index(fov) + fov_path = self.fov_paths[idx] + channel_path = os.path.join(fov_path, channel + self.suffix) + channel = np.squeeze(io.imread(channel_path)) + return channel + + def get_channel_stack(self, fov: str, channel: str): + """Get 
the channel from a multi-channel tiff + Args: + fov (str): name of a fov + channel (str): channel name + data_format (str): data format + Returns: + channel (np.array): channel image + """ + idx = self.fovs.index(fov) + fov_path = self.fov_paths[idx] + self.img_reader = LazyOMETIFFReader(fov_path) + return np.squeeze(self.img_reader.get_channel(channel)) + + def get_segmentation(self, fov: str): + """Get the instance mask for a fov + Args: + fov (str): name of a fov + Returns: + instance_mask (np.array): instance mask + """ + idx = self.fovs.index(fov) + fov_path = self.fov_paths[idx] + instance_path = self.segmentation_naming_convention(fov_path) + instance_mask = np.squeeze(io.imread(instance_path)) + return instance_mask def prepare_input_data(mplex_img, instance_mask): @@ -170,19 +284,17 @@ def test_time_aug( def predict_fovs( - nimbus, fov_paths, normalization_dict, segmentation_naming_convention, output_dir, - suffix, include_channels=[], save_predictions=True, half_resolution=False, batch_size=4, - test_time_augmentation=True + nimbus, dataset: MultiplexDataset, normalization_dict: dict, output_dir: str, + suffix: str="tiff", save_predictions: bool=True, half_resolution: bool=False, + batch_size: int=4, test_time_augmentation: bool=True ): """Predicts the segmentation map for each mplex image in each fov Args: nimbus (Nimbus): nimbus object - fov_paths (list): list of fov paths + dataset (MultiplexDataset): dataset object normalization_dict (dict): dict with channel names as keys and norm factors as values - segmentation_naming_convention (function): function to get instance mask path from fov path output_dir (str): path to output dir suffix (str): suffix of mplex images - include_channels (list): list of channels to include save_predictions (bool): whether to save predictions half_resolution (bool): whether to use half resolution batch_size (int): batch size @@ -191,22 +303,15 @@ def predict_fovs( cell_table (pd.DataFrame): cell table with predicted 
confidence scores per fov and cell """ fov_dict_list = [] - for fov_path in fov_paths: + for fov_path, fov in zip(dataset.fov_paths, dataset.fovs): print(f"Predicting {fov_path}...") out_fov_path = os.path.join( os.path.normpath(output_dir), os.path.basename(fov_path) ) df_fov = pd.DataFrame() - instance_path = segmentation_naming_convention(fov_path) - instance_mask = np.squeeze(io.imread(instance_path)) - for channel in tqdm(os.listdir(fov_path)): - channel_path = os.path.join(fov_path, channel) - channel_ = channel.split(".")[0] - if not channel.endswith(suffix) or ( - include_channels != [] and channel_ not in include_channels - ): - continue - mplex_img = np.squeeze(io.imread(channel_path)) + instance_mask = dataset.get_segmentation(fov) + for channel_name in tqdm(dataset.channels): + mplex_img = dataset.get_channel(fov, channel_name) input_data = prepare_input_data(mplex_img, instance_mask) if half_resolution: scale = 0.5 @@ -219,13 +324,13 @@ def predict_fovs( input_data = np.stack([img, binary_mask], axis=0)[np.newaxis,...] 
if test_time_augmentation: prediction = test_time_aug( - input_data, channel, nimbus, normalization_dict, batch_size=batch_size + input_data, channel_name, nimbus, normalization_dict, batch_size=batch_size ) else: prediction = nimbus.predict_segmentation( input_data, preprocess_kwargs={ - "normalize": True, "marker": channel, + "normalize": True, "marker": channel_name, "normalization_dict": normalization_dict }, ) @@ -236,12 +341,12 @@ def predict_fovs( if df_fov.empty: df_fov["label"] = df["label"] df_fov["fov"] = os.path.basename(fov_path) - df_fov[channel.split(".")[0]] = df["intensity_mean"] + df_fov[channel_name] = df["intensity_mean"] if save_predictions: os.makedirs(out_fov_path, exist_ok=True) pred_int = (prediction*255.0).astype(np.uint8) io.imwrite( - os.path.join(out_fov_path, channel), pred_int, photometric="minisblack", + os.path.join(out_fov_path, channel_name + suffix), pred_int, photometric="minisblack", # compress=0, ) fov_dict_list.append(df_fov) @@ -276,41 +381,40 @@ def nimbus_preprocess(image, **kwargs): return output -def calculate_normalization_ome(ome_path, quantile, include_channels): +def calculate_normalization(dataset: MultiplexDataset, quantile: float, include_channels: list): """Calculates the normalization values for a given ome file Args: - ome_path (str): path to ome file + dataset (MultiplexDataset): dataset object quantile (float): quantile to use for normalization include_channels (list): list of channels to include Returns: normalization_values (dict): dict with channel names as keys and norm factors as values """ - reader = OMETIFFReader(fpath=ome_path) - img_array, metadata, _ = reader.read() - channel_names = list(metadata["Channels"].keys()) + + channel_names = dataset.channels if not include_channels: include_channels = channel_names + if isinstance(include_channels, str): + include_channels = [include_channels] # check if include_channels are included in ome file metadata - for channel in include_channels: - if channel 
not in channel_names: - raise ValueError(f"Channel {channel} not found in ome file metadata.") + misc_utils.verify_in_list(include_channels=include_channels, dataset_channels=channel_names) normalization_values = {} for channel in include_channels: - idx = channel_names.index(channel) - mplex_img = img_array[idx] + mplex_img = dataset.get_channel(dataset.fovs[0], channel) mplex_img = mplex_img.astype(np.float32) foreground = mplex_img[mplex_img > 0] normalization_values[channel] = np.quantile(foreground, quantile) return normalization_values -def prepare_normalization_dict_ome( - fov_paths, output_dir, quantile=0.999, include_channels=[], n_subset=10, n_jobs=1, - output_name="normalization_dict.json" +def prepare_normalization_dict( + dataset: MultiplexDataset, output_dir: str, quantile: float=0.999, + include_channels: list=[], n_subset: int=10, n_jobs: int=1, + output_name: str="normalization_dict.json" ): """Prepares the normalization dict for a list of ome.tif fovs Args: - fov_paths (list): list of paths to fovs + MultiplexDataset (list): list of paths to fovs output_dir (str): path to output directory quantile (float): quantile to use for normalization exclude_channels (list): list of channels to exclude @@ -321,19 +425,26 @@ def prepare_normalization_dict_ome( normalization_dict (dict): dict with channel names as keys and norm factors as values """ normalization_dict = {} + fov_paths = dataset.fov_paths if n_subset is not None: random.shuffle(fov_paths) fov_paths = fov_paths[:n_subset] print("Iterate over fovs...") if n_jobs > 1: normalization_values = Parallel(n_jobs=n_jobs)( - delayed(calculate_normalization_ome)(ome_path, quantile, include_channels) - for ome_path in fov_paths + delayed(calculate_normalization)( + MultiplexDataset( + [fov_path], dataset.segmentation_naming_convention, dataset.suffix + ), quantile, include_channels) + for fov_path in fov_paths ) else: normalization_values = [ - calculate_normalization_ome(ome_path, quantile, 
include_channels) - for ome_path in fov_paths + calculate_normalization( + MultiplexDataset( + [fov_path], dataset.segmentation_naming_convention, dataset.suffix + ), quantile, include_channels) + for fov_path in fov_paths ] for norm_dict in normalization_values: for channel, normalization_value in norm_dict.items(): @@ -348,84 +459,3 @@ def prepare_normalization_dict_ome( with open(os.path.join(output_dir, output_name), 'w') as f: json.dump(normalization_dict, f) return normalization_dict - - -def predict_ome_fovs( - nimbus, fov_paths, normalization_dict, segmentation_naming_convention, output_dir, - suffix, include_channels=[], save_predictions=True, half_resolution=False, batch_size=4, - test_time_augmentation=True - ): - """Predicts the segmentation map for each mplex channel in each ome.tif fov - Args: - nimbus (Nimbus): nimbus object - fov_paths (list): list of fov paths - normalization_dict (dict): dict with channel names as keys and norm factors as values - segmentation_naming_convention (function): function to get instance mask path from fov path - output_dir (str): path to output dir - suffix (str): suffix of mplex images - include_channels (list): list of channels to include - save_predictions (bool): whether to save predictions - half_resolution (bool): whether to use half resolution - batch_size (int): batch size - test_time_augmentation (bool): whether to use test time augmentation - Returns: - cell_table (pd.DataFrame): cell table with predicted confidence scores per fov and cell - """ - fov_dict_list = [] - for fov_path in fov_paths: - print(f"Predicting {fov_path}...") - out_fov_path = os.path.join( - os.path.normpath(output_dir), os.path.basename(fov_path).split(".")[0] - ) - df_fov = pd.DataFrame() - reader = OMETIFFReader(fpath=fov_path) - img_array, metadata, _ = reader.read() - channel_names = list(metadata["Channels"].keys()) - instance_path = segmentation_naming_convention(fov_path) - instance_mask = np.squeeze(io.imread(instance_path)) - if 
not include_channels: - include_channels = channel_names - for channel in tqdm(include_channels): - idx = channel_names.index(channel) - mplex_img = np.squeeze(img_array[idx]) - input_data = prepare_input_data(mplex_img, instance_mask) - if half_resolution: - scale = 0.5 - input_data = np.squeeze(input_data) - _, h,w = input_data.shape - img = cv2.resize(input_data[0], [int(h*scale), int(w*scale)]) - binary_mask = cv2.resize( - input_data[1], [int(h*scale), int(w*scale)], interpolation=0 - ) - input_data = np.stack([img, binary_mask], axis=0)[np.newaxis,...] - if test_time_augmentation: - prediction = test_time_aug( - input_data, channel, nimbus, normalization_dict, batch_size=batch_size - ) - - else: - prediction = nimbus.predict_segmentation( - input_data, - preprocess_kwargs={ - "normalize": True, "marker": channel, - "normalization_dict": normalization_dict - }, - ) - prediction = np.squeeze(prediction) - if half_resolution: - prediction = cv2.resize(prediction, (h, w)) - df = pd.DataFrame(segment_mean(instance_mask, prediction)) - if df_fov.empty: - df_fov["label"] = df["label"] - df_fov["fov"] = os.path.basename(fov_path) - df_fov[channel] = df["intensity_mean"] - if save_predictions: - os.makedirs(out_fov_path, exist_ok=True) - pred_int = (prediction*255.0).astype(np.uint8) - io.imwrite( - os.path.join(out_fov_path, channel+".tiff"), pred_int, photometric="minisblack", - # compress=0, - ) - fov_dict_list.append(df_fov) - cell_table = pd.concat(fov_dict_list, ignore_index=True) - return cell_table diff --git a/src/nimbus_inference/viewer_widget.py b/src/nimbus_inference/viewer_widget.py index e406e79..79f6123 100644 --- a/src/nimbus_inference/viewer_widget.py +++ b/src/nimbus_inference/viewer_widget.py @@ -7,25 +7,25 @@ import numpy as np from natsort import natsorted from skimage.segmentation import find_boundaries - +from nimbus_inference.utils import MultiplexDataset class NimbusViewer(object): def __init__( - self, input_dir, output_dir, 
segmentation_naming_convention=None, img_width='600px' + self, dataset: MultiplexDataset, output_dir: str, img_width='600px', suffix=".tiff" ): """Viewer for Nimbus application. Args: - input_dir (str): Path to directory containing individual channels of multiplexed images + dataset (MultiplexDataset): dataset object output_dir (str): Path to directory containing output of Nimbus application. segmentation_naming_convention (fn): Function that maps input path to segmentation path img_width (str): Width of images in viewer. + suffix (str): Suffix of images in dataset. """ self.image_width = img_width - self.input_dir = input_dir + self.dataset = dataset self.output_dir = output_dir - self.segmentation_naming_convention = segmentation_naming_convention - self.fov_names = [os.path.basename(p) for p in os.listdir(output_dir) if \ - os.path.isdir(os.path.join(output_dir, p))] + self.suffix = suffix + self.fov_names = self.dataset.fovs self.fov_names = natsorted(self.fov_names) self.update_button = widgets.Button(description="Update Image") self.update_button.on_click(self.update_button_click) @@ -65,13 +65,58 @@ def select_fov(self, change): Args: change (dict): Change dictionary from ipywidgets. """ - fov_path = os.path.join(self.output_dir, self.fov_select.value) - channels = [ - ch for ch in os.listdir(fov_path) if os.path.isfile(os.path.join(fov_path, ch)) - ] - self.red_select.options = natsorted(channels) - self.green_select.options = natsorted(channels) - self.blue_select.options = natsorted(channels) + channels = natsorted(self.dataset.channels) + self.red_select.options = channels + self.green_select.options = channels + self.blue_select.options = channels + + def overlay(self, composite_image, add_boundaries=False, add_overlay=False): + """Adds overlay to composite image. + Args: + composite_image (np.array): Composite image to add overlay to. + boundaries (bool): Whether to add boundaries to overlay. 
+ Returns: + composite_image (np.array): Composite image with overlay. + """ + seg_img = self.dataset.get_segmentation(self.fov_select.value) + seg_boundaries = find_boundaries(seg_img, mode='inner') + seg_img[seg_boundaries] = 0 + seg_img = np.clip(seg_img, 0, 1) + seg_img = np.repeat(seg_img[..., np.newaxis], 3, axis=-1) * np.max(composite_image) + background_mask = composite_image < np.max(composite_image) * 0.2 + if add_overlay: + composite_image[background_mask] += (seg_img[background_mask] * 0.2).astype( + composite_image.dtype + ) + if add_boundaries: + val = (np.max(composite_image, axis=(0,1))*0.5).astype(composite_image.dtype) + val = np.min(val[val>0]) + composite_image[seg_boundaries] = [val]*3 + else: + seg_boundaries = None + return composite_image, seg_boundaries + + def create_composite_from_dataset(self, path_dict): + """Creates composite image from input paths. + Args: + path_dict (dict): Dictionary of paths to images. + Returns: + composite_image (np.array): Composite image. + """ + for k in ["red", "green", "blue"]: + if k not in path_dict.keys(): + path_dict[k] = None + output_image = [] + for p in list(path_dict.values()): + if p: + img = self.dataset.get_channel(p["fov"], p["channel"]) + output_image.append(img) + else: + p = [p for p in path_dict.values() if p][0] + img = self.dataset.get_channel(p["fov"], p["channel"]) + output_image.append(img*0) + composite_image = np.stack(output_image, axis=-1) + return composite_image def create_composite_image(self, path_dict, add_overlay=True, add_boundaries=False): """Creates composite image from input paths. 
@@ -94,29 +139,7 @@ def create_composite_image(self, path_dict, add_overlay=True, add_boundaries=Fal output_image.append(img*0) # add overlay of instances composite_image = np.stack(output_image, axis=-1) - if self.segmentation_naming_convention and add_overlay: - fov_path = os.path.split(list(path_dict.values())[0])[0] - seg_path = self.segmentation_naming_convention(fov_path) - seg_img = io.imread(seg_path) - seg_boundaries = find_boundaries(seg_img, mode='inner') - seg_img[seg_boundaries] = 0 - seg_img = np.clip(seg_img, 0, 1) - seg_img = np.repeat(seg_img[..., np.newaxis], 3, axis=-1) * np.max(composite_image) - background_mask = composite_image < np.max(composite_image) * 0.2 - composite_image[background_mask] += (seg_img[background_mask] * 0.2).astype( - composite_image.dtype - ) - elif self.segmentation_naming_convention and add_boundaries: - fov_path = os.path.split(list(path_dict.values())[0])[0] - seg_path = self.segmentation_naming_convention(fov_path) - seg_img = io.imread(seg_path) - seg_boundaries = find_boundaries(seg_img, mode='inner') - val = (np.max(composite_image, axis=(0,1))*0.5).astype(composite_image.dtype) - val = np.min(val[val>0]) - composite_image[seg_boundaries] = [val]*3 - else: - seg_boundaries = None - return composite_image, seg_boundaries + return composite_image def layout(self): """Creates layout for viewer.""" @@ -192,25 +215,30 @@ def update_composite(self): in_path_dict = copy(path_dict) if self.red_select.value: path_dict["red"] = os.path.join( - self.output_dir, self.fov_select.value, self.red_select.value + self.output_dir, self.fov_select.value, self.red_select.value + self.suffix ) - in_path_dict["red"] = self.search_for_similar(self.red_select.value) + in_path_dict["red"] = {"fov": self.fov_select.value, "channel": self.red_select.value} if self.green_select.value: path_dict["green"] = os.path.join( - self.output_dir, self.fov_select.value, self.green_select.value + self.output_dir, self.fov_select.value, 
self.green_select.value + self.suffix ) - in_path_dict["green"] = self.search_for_similar(self.green_select.value) + in_path_dict["green"] = { + "fov": self.fov_select.value, "channel": self.green_select.value + } if self.blue_select.value: path_dict["blue"] = os.path.join( - self.output_dir, self.fov_select.value, self.blue_select.value + self.output_dir, self.fov_select.value, self.blue_select.value + self.suffix ) - in_path_dict["blue"] = self.search_for_similar(self.blue_select.value) + in_path_dict["blue"] = { + "fov": self.fov_select.value, "channel": self.blue_select.value + } non_none = [p for p in path_dict.values() if p] if not non_none: return - composite_image, _ = self.create_composite_image(path_dict) - in_composite_image, seg_boundaries = self.create_composite_image( - in_path_dict, add_overlay=False, add_boundaries=self.overlay_checkbox.value + composite_image = self.create_composite_image(path_dict) + in_composite_image = self.create_composite_from_dataset(in_path_dict) + in_composite_image, seg_boundaries = self.overlay( + in_composite_image, add_boundaries=self.overlay_checkbox.value ) in_composite_image = in_composite_image / np.quantile( in_composite_image, 0.999, axis=(0,1) @@ -230,4 +258,4 @@ def display(self): """Displays viewer.""" self.select_fov(None) self.layout() - self.update_composite() \ No newline at end of file + self.update_composite() diff --git a/tests/test_cell_analyzer.py b/tests/test_cell_analyzer.py index e34d96d..d156060 100644 --- a/tests/test_cell_analyzer.py +++ b/tests/test_cell_analyzer.py @@ -1,6 +1,6 @@ from nimbus_inference.cell_analyzer import CellAnalyzer from nimbus_inference.nimbus import Nimbus, prep_naming_convention -from tests.test_utils import prepare_ome_tif_data, prepare_tif_data +from tests.test_utils import prepare_tif_data import pandas as pd import numpy as np import tempfile diff --git a/tests/test_nimbus.py b/tests/test_nimbus.py index a954e20..ccdc034 100644 --- a/tests/test_nimbus.py +++ 
b/tests/test_nimbus.py @@ -1,5 +1,6 @@ from tests.test_utils import prepare_ome_tif_data, prepare_tif_data import tempfile +from nimbus_inference.utils import MultiplexDataset from nimbus_inference.nimbus import Nimbus, prep_naming_convention from nimbus_inference.unet import UNet from skimage.data import astronaut @@ -15,16 +16,15 @@ def test_check_inputs(): selected_markers = ["CD45", "CD3", "CD8", "ChyTr"] fov_paths, _ = prepare_tif_data(num_samples, temp_dir, selected_markers) naming_convention = prep_naming_convention(os.path.join(temp_dir, "deepcell_output")) - nimbus = Nimbus( - fov_paths=fov_paths, segmentation_naming_convention=naming_convention, - output_dir=temp_dir - ) + dataset = MultiplexDataset(fov_paths, naming_convention) + nimbus = Nimbus(dataset=dataset, output_dir=temp_dir) nimbus.check_inputs() def test_initialize_model(): + dataset = MultiplexDataset(["tests"]) nimbus = Nimbus( - fov_paths=[""], segmentation_naming_convention="", output_dir="", + dataset, output_dir="", input_shape=[512,512], batch_size=4 ) nimbus.initialize_model(padding="valid") @@ -43,19 +43,15 @@ def test_prepare_normalization_dict(): selected_markers = ["CD45", "CD3", "CD8", "ChyTr"] fov_paths,_ = prepare_tif_data(num_samples, temp_dir, selected_markers) naming_convention = prep_naming_convention(os.path.join(temp_dir, "deepcell_output")) - nimbus = Nimbus( - fov_paths, naming_convention, temp_dir, - include_channels=["CD45", "CD3", "CD8"] - ) + dataset = MultiplexDataset(fov_paths, naming_convention) + nimbus = Nimbus(dataset, temp_dir, include_channels=["CD45", "CD3", "CD8"]) # test if normalization dict gets prepared and saved nimbus.prepare_normalization_dict(overwrite=True) assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json")) assert "ChyTr" not in nimbus.normalization_dict.keys() # test if normalization dict gets loaded - nimbus_2 = Nimbus( - fov_paths, naming_convention, temp_dir, include_channels=["CD45", "CD3", "CD8"] - ) + nimbus_2 = 
Nimbus(dataset, temp_dir, include_channels=["CD45", "CD3", "CD8"]) nimbus_2.prepare_normalization_dict() assert nimbus_2.normalization_dict == nimbus.normalization_dict @@ -64,7 +60,8 @@ def test_tile_input(): image = torch.rand([1,2,768,768]) tile_size = (512, 512) output_shape = (320,320) - nimbus = Nimbus(fov_paths=[""], segmentation_naming_convention="", output_dir="") + dataset = MultiplexDataset(["tests"]) + nimbus = Nimbus(MultiplexDataset, output_dir="") nimbus.model = lambda x: x[..., 96:-96, 96:-96] tiled_input, padding = nimbus._tile_input(image, tile_size, output_shape) assert tiled_input.shape == (3,3,1,2,512,512) @@ -76,8 +73,7 @@ def test_tile_and_stitch(): image = rescale(astronaut(), 1.5, channel_axis=-1) image = np.moveaxis(image, -1, 0)[np.newaxis, ...] nimbus = Nimbus( - fov_paths=[""], segmentation_naming_convention="", output_dir="", - input_shape=[512,512], batch_size=4 + dataset="", output_dir="", input_shape=[512,512], batch_size=4 ) # check if tile and stitch works for mock model unequal input and output shape # mock model only center crops the input, so that the stitched output is equal to the input diff --git a/tests/test_utils.py b/tests/test_utils.py index 77b6203..f3c450a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,5 @@ from nimbus_inference.utils import (prepare_normalization_dict, calculate_normalization, -predict_fovs, predict_ome_fovs, calculate_normalization_ome, prepare_normalization_dict_ome, -prepare_input_data) +predict_fovs, prepare_input_data, MultiplexDataset, LazyOMETIFFReader) from nimbus_inference.utils import test_time_aug as tt_aug from nimbus_inference.nimbus import Nimbus from skimage import io @@ -109,44 +108,63 @@ def prepare_ome_tif_data(num_samples, temp_dir, selected_markers, random=False, def test_calculate_normalization(): with tempfile.TemporaryDirectory() as temp_dir: - fov_paths, _ = prepare_tif_data( + # test for single channel data + tif_fov_paths, _ = prepare_tif_data( 
num_samples=1, temp_dir=temp_dir, selected_markers=["CD4"], random=True, std=[0.5] ) channel = "CD4" - channel_path = os.path.join(fov_paths[0], channel + ".tiff") - channel_out, norm_val = calculate_normalization(channel_path, 0.999) - # test if we get the correct channel and normalization value - assert channel_out == channel - assert np.isclose(norm_val, 0.5, 0.01) + tif_dataset = MultiplexDataset(tif_fov_paths, suffix=".tiff") + ome_fov_paths, _ = prepare_ome_tif_data( + num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"], random=True, std=[0.5] + ) + ome_dataset = MultiplexDataset(ome_fov_paths, suffix=".ome.tiff") + for dataset in [tif_dataset, ome_dataset]: + norm_dict = calculate_normalization(dataset, 0.999, include_channels=[channel]) + channel_out, norm_val = list(norm_dict.items())[0] + # test if we get the correct channel and normalization value + assert channel_out == channel + assert np.isclose(norm_val, 0.5, 0.01) + + # test if ValueError is raised if include_channels are not in the dataset + with pytest.raises(ValueError): + calculate_normalization(dataset, 0.999, include_channels=["CD42", "CD56"]) def test_prepare_normalization_dict(): with tempfile.TemporaryDirectory() as temp_dir: scales = [0.5, 1.0, 1.5, 2.0, 5.0] channels = ["CD4", "CD11c", "CD14", "CD56", "CD57"] - fov_paths, _ = prepare_tif_data( + tif_fov_paths, _ = prepare_tif_data( num_samples=5, temp_dir=temp_dir, selected_markers=channels, random=True, std=scales ) - normalization_dict = prepare_normalization_dict( - fov_paths, temp_dir, quantile=0.999, n_subset=10, n_jobs=1, - output_name="normalization_dict.json" - ) - # test if normalization dict got saved - assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json")) - assert normalization_dict == json.load( - open(os.path.join(temp_dir, "normalization_dict.json")) + tif_dataset = MultiplexDataset(tif_fov_paths, suffix=".tiff") + + # test if everything works for multi channel data + ome_fov_paths, _ = 
prepare_ome_tif_data( + num_samples=5, temp_dir=temp_dir, selected_markers=channels, random=True, std=scales ) - # test if normalization dict is correct - for channel, scale in zip(channels, scales): - assert np.isclose(normalization_dict[channel], scale, 0.01) + ome_dataset = MultiplexDataset(ome_fov_paths, suffix=".ome.tiff") + for dataset in [tif_dataset, ome_dataset]: + normalization_dict = prepare_normalization_dict( + dataset, temp_dir, quantile=0.999, n_subset=10, n_jobs=1, + output_name="normalization_dict.json" + ) + # test if normalization dict got saved + assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json")) + assert normalization_dict == json.load( + open(os.path.join(temp_dir, "normalization_dict.json")) + ) + # test if normalization dict is correct + for channel, scale in zip(channels, scales): + assert np.isclose(normalization_dict[channel], scale, 0.01) - # test if multiprocessing yields approximately the same results - normalization_dict_mp = prepare_normalization_dict( - fov_paths, temp_dir, quantile=0.999, n_subset=10, n_jobs=2, - output_name="normalization_dict.json" - ) - for key in normalization_dict.keys(): - assert np.isclose(normalization_dict[key], normalization_dict_mp[key], 1e-6) + # test if multiprocessing yields approximately the same results + normalization_dict_mp = prepare_normalization_dict( + dataset, temp_dir, quantile=0.999, n_subset=10, n_jobs=2, + output_name="normalization_dict.json" + ) + for key in normalization_dict.keys(): + assert np.isclose(normalization_dict[key], normalization_dict_mp[key], 1e-6) def test_prepare_input_data(): @@ -178,9 +196,8 @@ def segmentation_naming_convention(fov_path): num_samples=1, temp_dir=temp_dir, selected_markers=[channel] ) output_dir = os.path.join(temp_dir, "nimbus_output") - nimbus = Nimbus( - fov_paths, segmentation_naming_convention, output_dir, - ) + dataset = MultiplexDataset(fov_paths, segmentation_naming_convention, suffix=".tiff") + nimbus = Nimbus(dataset, 
output_dir) nimbus.prepare_normalization_dict() mplex_img = io.imread(os.path.join(fov_paths[0], channel+".tiff")) instance_mask = io.imread(inst_paths[0]) @@ -223,16 +240,14 @@ def segmentation_naming_convention(fov_path): fov_paths, _ = prepare_tif_data( num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"] ) + dataset = MultiplexDataset(fov_paths, segmentation_naming_convention, suffix=".tiff") output_dir = os.path.join(temp_dir, "nimbus_output") - nimbus = Nimbus( - fov_paths, segmentation_naming_convention, output_dir, - ) + nimbus = Nimbus(dataset, output_dir) output_dir = os.path.join(temp_dir, "nimbus_output") nimbus.prepare_normalization_dict() cell_table = predict_fovs( - nimbus=nimbus, fov_paths=fov_paths, output_dir=output_dir, - normalization_dict=nimbus.normalization_dict, - segmentation_naming_convention=segmentation_naming_convention, suffix=".tiff", + nimbus=nimbus, dataset=dataset, output_dir=output_dir, + normalization_dict=nimbus.normalization_dict, suffix=".tiff", save_predictions=False, half_resolution=True, ) # check if we get the correct number of cells @@ -247,101 +262,64 @@ def segmentation_naming_convention(fov_path): # # run again with save_predictions=True and check if predictions get written to output_dir cell_table = predict_fovs( - nimbus=nimbus, fov_paths=fov_paths, output_dir=output_dir, - normalization_dict=nimbus.normalization_dict, - segmentation_naming_convention=segmentation_naming_convention, suffix=".tiff", + nimbus=nimbus, dataset=dataset, output_dir=output_dir, + normalization_dict=nimbus.normalization_dict, suffix=".tiff", save_predictions=True, half_resolution=True, ) assert os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff")) assert os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff")) -def test_calculate_normalization_ome(): +def test_LazyOMETIFFReader(): with tempfile.TemporaryDirectory() as temp_dir: fov_paths, _ = prepare_ome_tif_data( - num_samples=1, temp_dir=temp_dir, 
selected_markers=["CD4", "CD56", "CD45"] + num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"] ) + reader = LazyOMETIFFReader(fov_paths[0]) + assert hasattr(reader, "metadata") + assert reader.channels == ["CD4", "CD56"] + cd4_channel = reader.get_channel("CD4") + cd56_channel = reader.get_channel("CD56") + assert cd4_channel.shape == (256, 256) + assert cd56_channel.shape == (256, 256) - norm_dict = calculate_normalization_ome( - fov_paths[0], 0.999, include_channels=["CD4", "CD56"] - ) - # check if we get the correct normalization values - assert np.isclose(norm_dict["CD4"], 1.0, 0.01) - assert np.isclose(norm_dict["CD56"], 1.0, 0.01) - # check if ValueError is raised if include_channels are not in the ome.tif metadata - with pytest.raises(ValueError): - calculate_normalization_ome( - fov_paths[0], 0.999, include_channels=["CD42", "CD56"] - ) - -def test_prepare_normalization_dict_ome(): +def test_MultiplexDataset(): with tempfile.TemporaryDirectory() as temp_dir: - scales = [0.5, 1.0, 1.5, 2.0, 5.0] - channels = ["CD4", "CD11c", "CD14", "CD56", "CD57"] - fov_paths, _ = prepare_ome_tif_data( - num_samples=3, temp_dir=temp_dir, selected_markers=channels, random=True, std=scales - ) - normalization_dict = prepare_normalization_dict_ome( - fov_paths, temp_dir, quantile=0.999, n_subset=10, n_jobs=1, include_channels=channels, - output_name="normalization_dict.json" - ) - # test if normalization dict got saved - assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json")) - assert normalization_dict == json.load( - open(os.path.join(temp_dir, "normalization_dict.json")) - ) - # test if normalization dict is correct - for channel, scale in zip(channels, scales): - assert np.isclose(normalization_dict[channel], scale, 0.01) + def segmentation_naming_convention(fov_path): + temp_dir_, fov_ = os.path.split(fov_path) + fov_ = fov_.split(".")[0] + return os.path.join(temp_dir_, "deepcell_output", fov_ + "_whole_cell.tiff") - # test if 
multiprocessing yields approximately the same results - normalization_dict_mp = prepare_normalization_dict_ome( - fov_paths, temp_dir, quantile=0.999, n_subset=10, n_jobs=2, include_channels=channels, - output_name="normalization_dict.json" + fov_paths, _ = prepare_ome_tif_data( + num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"] ) - for key in normalization_dict.keys(): - assert np.isclose(normalization_dict[key], normalization_dict_mp[key], 1e-6) - + # check if check inputs raises error when inputs are incorrect + with pytest.raises(FileNotFoundError): + dataset = MultiplexDataset(["abc"], segmentation_naming_convention, suffix=".ome.tiff") + # check if we get the correct channels and fov_paths + dataset = MultiplexDataset(fov_paths, segmentation_naming_convention, suffix=".ome.tiff") + assert len(dataset) == 1 + assert dataset.channels == ["CD4", "CD56"] + assert dataset.fov_paths == fov_paths + assert dataset.multi_channel == True + cd4_channel = io.imread(fov_paths[0])[0] + cd4_channel_ = dataset.get_channel(fov="fov_0", channel="CD4") + assert np.alltrue(cd4_channel == cd4_channel_) + fov_0_seg = io.imread(segmentation_naming_convention(fov_paths[0])) + fov_0_seg_ = dataset.get_segmentation(fov="fov_0") + assert np.alltrue(fov_0_seg == fov_0_seg_) -def test_predict_ome_fovs(): - def segmentation_naming_convention(fov_path): - temp_dir_, fov_ = os.path.split(fov_path) - fov_ = fov_.split(".")[0] - return os.path.join(temp_dir_, "deepcell_output", fov_ + "_whole_cell.tiff") - - with tempfile.TemporaryDirectory() as temp_dir: - fov_paths, _ = prepare_ome_tif_data( + # test everything again with single channel data + fov_paths, _ = prepare_tif_data( num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"] ) - output_dir = os.path.join(temp_dir, "nimbus_output") - nimbus = Nimbus( - fov_paths, segmentation_naming_convention, output_dir, suffix=".ome.tiff" - ) - output_dir = os.path.join(temp_dir, "nimbus_output") - 
nimbus.prepare_normalization_dict() - cell_table = predict_ome_fovs( - nimbus=nimbus, fov_paths=fov_paths, output_dir=output_dir, - normalization_dict=nimbus.normalization_dict, - segmentation_naming_convention=segmentation_naming_convention, suffix=".tiff", - save_predictions=False, half_resolution=True, - ) - # check if we get the correct number of cells - assert len(cell_table) == 15 - # check if we get the correct columns (fov, label, CD4, CD56) - assert np.alltrue( - set(cell_table.columns) == set(["fov", "label", "CD4", "CD56"]) - ) - # check if predictions don't get written to output_dir - assert not os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff")) - assert not os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff")) - # - # run again with save_predictions=True and check if predictions get written to output_dir - cell_table = predict_ome_fovs( - nimbus=nimbus, fov_paths=fov_paths, output_dir=output_dir, - normalization_dict=nimbus.normalization_dict, - segmentation_naming_convention=segmentation_naming_convention, suffix=".tiff", - save_predictions=True, half_resolution=True, - ) - assert os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff")) - assert os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff")) + dataset = MultiplexDataset(fov_paths, segmentation_naming_convention, suffix=".tiff") + assert len(dataset) == 1 + assert dataset.channels == ["CD4", "CD56"] + assert dataset.fov_paths == fov_paths + assert dataset.multi_channel == False + cd4_channel_ = dataset.get_channel(fov="fov_0", channel="CD4") + assert np.alltrue(cd4_channel == cd4_channel_) + fov_0_seg_ = dataset.get_segmentation(fov="fov_0") + assert np.alltrue(fov_0_seg == fov_0_seg_) diff --git a/tests/test_viewer_widget.py b/tests/test_viewer_widget.py index 9e37108..fb266ab 100644 --- a/tests/test_viewer_widget.py +++ b/tests/test_viewer_widget.py @@ -1,5 +1,6 @@ from nimbus_inference.viewer_widget import NimbusViewer from nimbus_inference.nimbus import 
Nimbus, prep_naming_convention +from nimbus_inference.utils import MultiplexDataset from tests.test_utils import prepare_ome_tif_data, prepare_tif_data import numpy as np import tempfile @@ -8,36 +9,67 @@ def test_NimbusViewer(): with tempfile.TemporaryDirectory() as temp_dir: - _ = prepare_tif_data( + fov_paths, _ = prepare_tif_data( num_samples=2, temp_dir=temp_dir, selected_markers=["CD4", "CD11c", "CD56"] ) - viewer_widget = NimbusViewer(temp_dir, temp_dir) + dataset = MultiplexDataset(fov_paths) + viewer_widget = NimbusViewer(dataset, temp_dir) assert isinstance(viewer_widget, NimbusViewer) def test_composite_image(): with tempfile.TemporaryDirectory() as temp_dir: - _ = prepare_tif_data( + fov_paths, _ = prepare_tif_data( num_samples=2, temp_dir=temp_dir, selected_markers=["CD4", "CD11c", "CD56"] ) - viewer_widget = NimbusViewer(temp_dir, temp_dir) + dataset = MultiplexDataset(fov_paths) + viewer_widget = NimbusViewer(dataset, temp_dir) path_dict = { "red": os.path.join(temp_dir, "fov_0", "CD4.tiff"), "green": os.path.join(temp_dir, "fov_0", "CD11c.tiff"), } - composite_image, _ = viewer_widget.create_composite_image(path_dict) + composite_image = viewer_widget.create_composite_image(path_dict) assert isinstance(composite_image, np.ndarray) assert composite_image.shape == (256, 256, 3) path_dict["blue"] = os.path.join(temp_dir, "fov_0", "CD56.tiff") - composite_image, _ = viewer_widget.create_composite_image(path_dict) + composite_image = viewer_widget.create_composite_image(path_dict) assert composite_image.shape == (256, 256, 3) + + +def test_create_composite_from_dataset(): + with tempfile.TemporaryDirectory() as temp_dir: + fov_paths, _ = prepare_tif_data( + num_samples=2, temp_dir=temp_dir, selected_markers=["CD4", "CD11c", "CD56"] + ) + dataset = MultiplexDataset(fov_paths) + viewer_widget = NimbusViewer(dataset, temp_dir) + path_dict = { + "red": {"fov": "fov_0", "channel": "CD4"}, + "green": {"fov": "fov_0", "channel": "CD11c"}, + } + composite_image 
= viewer_widget.create_composite_from_dataset(path_dict) + assert isinstance(composite_image, np.ndarray) + assert composite_image.shape == (256, 256, 3) + + +def test_overlay(): + with tempfile.TemporaryDirectory() as temp_dir: + fov_paths, _ = prepare_tif_data( + num_samples=2, temp_dir=temp_dir, selected_markers=["CD4", "CD11c", "CD56"] + ) + path_dict = { + "red": os.path.join(temp_dir, "fov_0", "CD4.tiff"), + "green": os.path.join(temp_dir, "fov_0", "CD11c.tiff"), + } # test if segmentation gets added naming_convention = prep_naming_convention(os.path.join(temp_dir, "deepcell_output")) - viewer_widget = NimbusViewer( - temp_dir, temp_dir, segmentation_naming_convention=naming_convention + dataset = MultiplexDataset(fov_paths, naming_convention) + viewer_widget = NimbusViewer(dataset, temp_dir) + composite_image = viewer_widget.create_composite_image(path_dict) + composite_image, seg_boundaries = viewer_widget.overlay( + composite_image, add_boundaries=True ) - composite_image, seg_boundaries = viewer_widget.create_composite_image(path_dict) assert composite_image.shape == (256, 256, 3) assert seg_boundaries.shape == (256, 256) assert np.unique(seg_boundaries).tolist() == [0, 1] From f455e502ba017ed1184b6bf849e021eab4ee1b4c Mon Sep 17 00:00:00 2001 From: Lenz Date: Tue, 9 Apr 2024 16:24:20 +0200 Subject: [PATCH 3/4] Sorted out some bugs in the ViewerWidget, added notebook for ome tiff inference --- src/nimbus_inference/nimbus.py | 27 +-- src/nimbus_inference/utils.py | 92 +++++--- src/nimbus_inference/viewer_widget.py | 9 +- templates/1_Nimbus_Predict.ipynb | 34 ++- templates/1_Nimbus_Predict_OME.ipynb | 326 ++++++++++++++++++++++++++ tests/test_nimbus.py | 8 +- tests/test_utils.py | 6 +- 7 files changed, 436 insertions(+), 66 deletions(-) create mode 100644 templates/1_Nimbus_Predict_OME.ipynb diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py index a3f1413..3bdffc0 100644 --- a/src/nimbus_inference/nimbus.py +++ 
b/src/nimbus_inference/nimbus.py @@ -66,7 +66,7 @@ def segmentation_naming_convention(fov_path): Returns: seg_path (str): paths to segmentation fovs """ - fov_name = os.path.basename(fov_path) + fov_name = os.path.basename(fov_path).replace(".ome.tiff", "") return os.path.join(deepcell_output_dir, fov_name + "_whole_cell.tiff") return segmentation_naming_convention @@ -77,15 +77,14 @@ class Nimbus(nn.Module): def __init__( self, dataset: MultiplexDataset, output_dir: str, save_predictions: bool=True, - include_channels: list=[], half_resolution: bool=True, batch_size: int=4, - test_time_aug: bool=True, input_shape: list=[1024, 1024], device: str="auto", + half_resolution: bool=True, batch_size: int=4, test_time_aug: bool=True, + input_shape: list=[1024, 1024], device: str="auto", ): """Initializes a Nimbus Application. Args: dataset (MultiplexDataset): Path to directory containing fovs. output_dir (str): Path to directory to save output. save_predictions (bool): Whether to save predictions. - include_channels (list): List of channels to include in analysis. half_resolution (bool): Whether to run model on half resolution images. batch_size (int): Batch size for model inference. test_time_aug (bool): Whether to use test time augmentation. 
@@ -96,7 +95,6 @@ def __init__( """ super(Nimbus, self).__init__() self.dataset = dataset - self.include_channels = include_channels self.output_dir = output_dir self.half_resolution = half_resolution self.save_predictions = save_predictions @@ -123,8 +121,6 @@ def check_inputs(self): # check if output_dir exists io_utils.validate_paths([self.output_dir]) - if isinstance(self.include_channels, str): - self.include_channels = [self.include_channels] self.checked_inputs = True print("All inputs are valid.") @@ -173,11 +169,9 @@ def prepare_normalization_dict( else: n_jobs = os.cpu_count() if multiprocessing else 1 self.normalization_dict = prepare_normalization_dict( - self.dataset, self.output_dir, quantile, self.include_channels, n_subset, + self.dataset, self.output_dir, quantile, n_subset, n_jobs ) - if self.include_channels == []: - self.include_channels = list(self.normalization_dict.keys()) def predict_fovs(self): """Predicts cell classification for input data. @@ -193,13 +187,12 @@ def predict_fovs(self): print("Available GPUs: ", gpus) print("Predictions will be saved in {}".format(self.output_dir)) print("Iterating through fovs will take a while...") - if self.suffix.lower() in [".ome.tif", ".ome.tiff"]: - self.cell_table = predict_fovs( - nimbus=self, dataset=self.dataset, output_dir=self.output_dir, - normalization_dict=self.normalization_dict, save_predictions=self.save_predictions, - half_resolution=self.half_resolution, batch_size=self.batch_size, - test_time_augmentation=self.test_time_aug, suffix=self.suffix, - ) + self.cell_table = predict_fovs( + nimbus=self, dataset=self.dataset, output_dir=self.output_dir, + normalization_dict=self.normalization_dict, save_predictions=self.save_predictions, + half_resolution=self.half_resolution, batch_size=self.batch_size, + test_time_augmentation=self.test_time_aug, suffix=self.dataset.suffix, + ) self.cell_table.to_csv(os.path.join(self.output_dir, "nimbus_cell_table.csv"), index=False) return self.cell_table 
diff --git a/src/nimbus_inference/utils.py b/src/nimbus_inference/utils.py index 34a44f1..9e35350 100644 --- a/src/nimbus_inference/utils.py +++ b/src/nimbus_inference/utils.py @@ -6,16 +6,31 @@ import numpy as np import pandas as pd import imageio as io +from copy import copy from tqdm.autonotebook import tqdm from joblib import Parallel, delayed from joblib.externals.loky import get_reusable_executor from skimage.segmentation import find_boundaries from skimage.measure import regionprops_table from pyometiff import OMETIFFReader +from pyometiff.omexml import OMEXML from alpineer import io_utils, misc_utils from typing import Callable import tifffile import zarr +import sys, os +import logging +import os, sys + + +class HidePrints: + def __enter__(self): + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stdout.close() + sys.stdout = self._original_stdout class LazyOMETIFFReader(OMETIFFReader): @@ -37,7 +52,9 @@ def get_metadata(self): with tifffile.TiffFile(str(self.fpath)) as tif: if tif.is_ome: omexml_string = tif.ome_metadata - return self.parse_metadata(omexml_string) + with HidePrints(): + metadata = self.parse_metadata(omexml_string) + return metadata else: raise ValueError("File is not an OME-TIFF file.") @@ -71,8 +88,8 @@ def get_channel(self, channel_name: str): idx = self.channels.index(channel_name) with tifffile.imread(str(self.fpath), aszarr=True) as store: z = zarr.open(store, mode='r') - # correct DimOrder, often DimOrder is TZCYX, but image is stored as CYX, thus we remove - # the trailing dimensions + # correct DimOrder, often DimOrder is TZCYX, but image is stored as CYX, + # thus we remove the trailing dimensions dim_order = self.metadata["DimOrder"] dim_order = dim_order[-len(z.shape):] channel_idx = dim_order.find("C") @@ -84,7 +101,7 @@ def get_channel(self, channel_name: str): class MultiplexDataset(): def __init__( self, fov_paths: list, 
segmentation_naming_convention: Callable = None, - suffix: str = ".tiff" + include_channels: list = [], suffix: str = ".tiff", silent=False, ): """Multiplex dataset class that gives a common interface for data loading of multiplex datasets stored as individual channel images in folders or as multi-channel tiffs. @@ -92,14 +109,30 @@ def __init__( fov_paths (list): list of paths to fovs segmentation_naming_convention (function): function to get instance mask path from fov path + suffix (str): suffix of channel images + silent (bool): whether to print messages """ self.fov_paths = fov_paths self.segmentation_naming_convention = segmentation_naming_convention self.suffix = suffix - self.check_inputs() + self.silent = silent + self.include_channels = include_channels self.multi_channel = self.is_multi_channel_tiff(fov_paths[0]) - self.fovs = self.get_fovs() self.channels = self.get_channels() + self.check_inputs() + self.fovs = self.get_fovs() + self.channels = self.filter_channels(self.channels) + + def filter_channels(self, channels): + """Filter channels based on include_channels + Args: + channels (list): list of channel names + Returns: + channels (list): filtered list of channel names + """ + if self.include_channels: + return [channel for channel in channels if channel in self.include_channels] + return channels def check_inputs(self): """check inputs for Nimbus model""" @@ -107,7 +140,13 @@ def check_inputs(self): if not isinstance(self.fov_paths, (list, tuple)): self.fov_paths = [self.fov_paths] io_utils.validate_paths(self.fov_paths) - print("All inputs are valid") + if isinstance(self.include_channels, str): + self.include_channels = [self.include_channels] + misc_utils.verify_in_list( + include_channels=self.include_channels, dataset_channels=self.channels + ) + if not self.silent: + print("All inputs are valid") def __len__(self): """Return the number of fovs in the dataset""" @@ -143,7 +182,7 @@ def get_fovs(self): return 
[os.path.basename(fov).replace(self.suffix, "") for fov in self.fov_paths] def get_channel(self, fov: str, channel: str): - """Get the channel from a fov + """Get a channel from a fov Args: fov (str): name of a fov channel (str): channel name @@ -156,7 +195,7 @@ def get_channel(self, fov: str, channel: str): return self.get_channel_single(fov, channel) def get_channel_single(self, fov: str, channel: str): - """Get the channel from a fov stored as a folder with individual channel images + """Get a channel from a fov stored as a folder with individual channel images Args: fov (str): name of a fov channel (str): channel name @@ -170,7 +209,7 @@ def get_channel_single(self, fov: str, channel: str): return channel def get_channel_stack(self, fov: str, channel: str): - """Get the channel from a multi-channel tiff + """Get a channel from a multi-channel tiff Args: fov (str): name of a fov channel (str): channel name @@ -235,7 +274,7 @@ def test_time_aug( Args: input_data (np.array): input data for segmentation model, mplex_img and binary mask channel (str): channel name - app (tf.keras.Model): segmentation model + app (Nimbus): segmentation model normalization_dict (dict): dict with channel names as keys and norm factors as values rotate (bool): whether to rotate flip (bool): whether to flip @@ -381,25 +420,16 @@ def nimbus_preprocess(image, **kwargs): return output -def calculate_normalization(dataset: MultiplexDataset, quantile: float, include_channels: list): +def calculate_normalization(dataset: MultiplexDataset, quantile: float): """Calculates the normalization values for a given ome file Args: dataset (MultiplexDataset): dataset object quantile (float): quantile to use for normalization - include_channels (list): list of channels to include Returns: normalization_values (dict): dict with channel names as keys and norm factors as values """ - - channel_names = dataset.channels - if not include_channels: - include_channels = channel_names - if 
isinstance(include_channels, str): - include_channels = [include_channels] - # check if include_channels are included in ome file metadata - misc_utils.verify_in_list(include_channels=include_channels, dataset_channels=channel_names) normalization_values = {} - for channel in include_channels: + for channel in dataset.channels: mplex_img = dataset.get_channel(dataset.fovs[0], channel) mplex_img = mplex_img.astype(np.float32) foreground = mplex_img[mplex_img > 0] @@ -408,16 +438,14 @@ def calculate_normalization(dataset: MultiplexDataset, quantile: float, include_ def prepare_normalization_dict( - dataset: MultiplexDataset, output_dir: str, quantile: float=0.999, - include_channels: list=[], n_subset: int=10, n_jobs: int=1, - output_name: str="normalization_dict.json" + dataset: MultiplexDataset, output_dir: str, quantile: float=0.999, n_subset: int=10, + n_jobs: int=1, output_name: str="normalization_dict.json" ): """Prepares the normalization dict for a list of ome.tif fovs Args: MultiplexDataset (list): list of paths to fovs output_dir (str): path to output directory quantile (float): quantile to use for normalization - exclude_channels (list): list of channels to exclude n_subset (int): number of fovs to use for normalization n_jobs (int): number of jobs to use for joblib multiprocessing output_name (str): name of output file @@ -425,7 +453,7 @@ def prepare_normalization_dict( normalization_dict (dict): dict with channel names as keys and norm factors as values """ normalization_dict = {} - fov_paths = dataset.fov_paths + fov_paths = copy(dataset.fov_paths) if n_subset is not None: random.shuffle(fov_paths) fov_paths = fov_paths[:n_subset] @@ -434,16 +462,18 @@ def prepare_normalization_dict( normalization_values = Parallel(n_jobs=n_jobs)( delayed(calculate_normalization)( MultiplexDataset( - [fov_path], dataset.segmentation_naming_convention, dataset.suffix - ), quantile, include_channels) + [fov_path], dataset.segmentation_naming_convention, dataset.channels, 
+ dataset.suffix, True + ), quantile) for fov_path in fov_paths ) else: normalization_values = [ calculate_normalization( MultiplexDataset( - [fov_path], dataset.segmentation_naming_convention, dataset.suffix - ), quantile, include_channels) + [fov_path], dataset.segmentation_naming_convention, dataset.channels, + dataset.suffix, True + ), quantile) for fov_path in fov_paths ] for norm_dict in normalization_values: diff --git a/src/nimbus_inference/viewer_widget.py b/src/nimbus_inference/viewer_widget.py index 79f6123..db5ec90 100644 --- a/src/nimbus_inference/viewer_widget.py +++ b/src/nimbus_inference/viewer_widget.py @@ -25,8 +25,7 @@ def __init__( self.dataset = dataset self.output_dir = output_dir self.suffix = suffix - self.fov_names = self.dataset.fovs - self.fov_names = natsorted(self.fov_names) + self.fov_names = natsorted(copy(self.dataset.fovs)) self.update_button = widgets.Button(description="Update Image") self.update_button.on_click(self.update_button_click) self.overlay_checkbox = widgets.Checkbox( @@ -65,7 +64,7 @@ def select_fov(self, change): Args: change (dict): Change dictionary from ipywidgets. 
""" - channels = natsorted(self.dataset.channels) + channels = natsorted(copy(self.dataset.channels)) self.red_select.options = channels self.green_select.options = channels self.blue_select.options = channels @@ -236,6 +235,10 @@ def update_composite(self): if not non_none: return composite_image = self.create_composite_image(path_dict) + composite_image, _ = self.overlay( + composite_image, add_overlay=True + ) + in_composite_image = self.create_composite_from_dataset(in_path_dict) in_composite_image, seg_boundaries = self.overlay( in_composite_image, add_boundaries=self.overlay_checkbox.value diff --git a/templates/1_Nimbus_Predict.ipynb b/templates/1_Nimbus_Predict.ipynb index 9da6ce4..6e3c254 100644 --- a/templates/1_Nimbus_Predict.ipynb +++ b/templates/1_Nimbus_Predict.ipynb @@ -22,6 +22,7 @@ "from IPython.display import display, HTML\n", "display(HTML(\"\"))\n", "from nimbus_inference.nimbus import Nimbus, prep_naming_convention\n", + "from nimbus_inference.utils import MultiplexDataset\n", "from alpineer import io_utils\n", "from ark.utils import example_dataset\n", "from nimbus_inference.viewer_widget import NimbusViewer" @@ -53,8 +54,7 @@ "outputs": [], "source": [ "# set up the base directory\n", - "base_dir = os.path.normpath(\"../data/example_dataset\")\n", - "# base_dir = os.path.normpath(\"C:/Users/lorenz/Desktop/angelo_lab/data/example_dataset\")" + "base_dir = os.path.normpath(\"../data/example_dataset\")" ] }, { @@ -172,6 +172,29 @@ " print(\"Segmentation data does not exist for fov 0 or naming convention is incorrect\")" ] }, + { + "cell_type": "markdown", + "id": "e7717960", + "metadata": {}, + "source": [ + "Next we will use the `MultiplexDataset` class to abstract away differences in data representation. The class takes `fov_paths`, `segmentation_naming_convention` and a `suffix` and provides methods `.get_channel(fov, channel)` and `.get_segmentation(fov)` to access the data. 
The `suffix` is used to filter out files that do not end with the specified suffix." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50997492", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = MultiplexDataset(\n", + " fov_paths=fov_paths,\n", + " suffix=\".tiff\",\n", + " include_channels=include_channels,\n", + " segmentation_naming_convention=segmentation_naming_convention,\n", + ")" + ] + }, { "cell_type": "markdown", "id": "839e5240", @@ -189,15 +212,12 @@ "outputs": [], "source": [ "nimbus = Nimbus(\n", - " fov_paths=fov_paths,\n", - " segmentation_naming_convention=segmentation_naming_convention,\n", + " dataset=dataset,\n", " output_dir=nimbus_output_dir,\n", - " include_channels=include_channels,\n", " save_predictions=True,\n", " batch_size=4,\n", " test_time_aug=True,\n", " input_shape=[1024,1024],\n", - " suffix=\".tiff\",\n", " device=\"auto\",\n", ")\n", "\n", @@ -275,7 +295,7 @@ "metadata": {}, "outputs": [], "source": [ - "viewer = NimbusViewer(input_dir=tiff_dir, output_dir=nimbus_output_dir)\n", + "viewer = NimbusViewer(dataset=dataset, output_dir=nimbus_output_dir)\n", "viewer.display()" ] } diff --git a/templates/1_Nimbus_Predict_OME.ipynb b/templates/1_Nimbus_Predict_OME.ipynb new file mode 100644 index 0000000..5c83080 --- /dev/null +++ b/templates/1_Nimbus_Predict_OME.ipynb @@ -0,0 +1,326 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "673d8e3a", + "metadata": {}, + "source": [ + "# Nimbus prediction notebook for stack .ome.tiff files" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f920e689", + "metadata": {}, + "outputs": [], + "source": [ + "# import required packages\n", + "import warnings\n", + "warnings.simplefilter(\"ignore\")\n", + "import os\n", + "from IPython.display import display, HTML\n", + "display(HTML(\"\"))\n", + "from nimbus_inference.nimbus import Nimbus, prep_naming_convention\n", + "from nimbus_inference.utils import MultiplexDataset\n", + "from 
alpineer import io_utils\n", + "from ark.utils import example_dataset\n", + "from nimbus_inference.viewer_widget import NimbusViewer" + ] + }, + { + "cell_type": "markdown", + "id": "e4642fe2", + "metadata": {}, + "source": [ + "## 0: Set root directory and download example dataset\n", + "Here we are using the example data located in `/data/example_dataset/input_data`. To modify this notebook to run using your own data, simply change `base_dir` to point to your own sub-directory within the data folder. Set `base_dir`, the path to all of your imaging data (i.e. multiplexed images and segmentation masks). Subdirectory `nimbus_output` will contain all of the data generated by this notebook. In the following, we expect this folder structure:\n", + "```\n", + "|-- base_dir\n", + "| |-- image_data\n", + "| | |-- fov1.ome.tiff\n", + "| | |-- fov2.ome.tiff\n", + "| |-- segmentation\n", + "| | |-- deepcell_output\n", + "| | | |-- fov1_whole_cell.tiff\n", + "| | | |-- fov2_whole_cell.tiff\n", + "| |-- nimbus_output\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "974f8dda", + "metadata": {}, + "outputs": [], + "source": [ + "# set up the base directory\n", + "base_dir = os.path.normpath(\"../data/example_dataset\")" + ] + }, + { + "cell_type": "markdown", + "id": "0ade450f", + "metadata": {}, + "source": [ + "If you would like to test Nimbus with an example dataset, run the cell below. It will download a dataset consisting of 10 FOVs with 22 channels. You may find more information about the example dataset in the [ark-analysis README](https://github.com/angelolab/ark-analysis/blob/bc6685050dfbef4607874fbbadebd4289251c173/README.md#example-dataset). 
If you want to use your own data, skip the cell below\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37733de5", + "metadata": {}, + "outputs": [], + "source": [ + "example_dataset.get_example_dataset(dataset=\"cluster_pixels\", save_dir = base_dir, overwrite_existing = False)" + ] + }, + { + "cell_type": "markdown", + "id": "9cd2ab6c", + "metadata": {}, + "source": [ + "## 1: set file paths and parameters\n", + "\n", + "### All data, images, files, etc. must be placed in the 'data' directory, and referenced via '../data/path_to_your_data'\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "292e4524", + "metadata": {}, + "outputs": [], + "source": [ + "# set up file paths\n", + "tiff_dir = os.path.join(base_dir, \"ome_image_data\")\n", + "deepcell_output_dir = os.path.join(base_dir, \"segmentation\", \"deepcell_output\")\n", + "nimbus_output_dir = os.path.join(base_dir, \"nimbus_output\")\n", + "\n", + "# Create nimbus output directory\n", + "os.makedirs(nimbus_output_dir, exist_ok=True)\n", + "\n", + "# Check if paths exist\n", + "io_utils.validate_paths([base_dir, tiff_dir, deepcell_output_dir, nimbus_output_dir])" + ] + }, + { + "cell_type": "markdown", + "id": "ae89442a", + "metadata": {}, + "source": [ + "## 2: Set up input paths and the naming convention for the segmentation data\n", + "Store names of channels to exclude in the list below. Either predict all FOVs or specify manually the ones you want to apply Nimbus on." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65a319c9", + "metadata": {}, + "outputs": [], + "source": [ + "# define the channels to include\n", + "include_channels = [\n", + " \"CD3\", \"CD4\", \"CD8\", \"CD14\", \"CD20\", \"CD31\", \"CD45\", \"CD68\", \"CD163\", \"CK17\", \"Collagen1\",\n", + " \"ECAD\", \"Fibronectin\", \"GLUT1\", \"HLADR\", \"IDO\", \"Ki67\", \"PD1\", \"SMA\", \"Vim\"\n", + "]\n", + "\n", + "# either get all fovs in the folder...\n", + "fov_names = os.listdir(tiff_dir)\n", + "# ... or optionally, select a specific set of fovs manually\n", + "# fovs = [\"fov0\", \"fov1\"]\n", + "\n", + "# construct paths for fovs\n", + "fov_paths = [os.path.join(tiff_dir, fov_name) for fov_name in fov_names]" + ] + }, + { + "cell_type": "markdown", + "id": "8c85f682", + "metadata": {}, + "source": [ + "Define the naming convention for the segmentation data in function `segmentation_naming_convention`, that maps the `fov_name` to the path of the associated segmentation output. The below function `prep_deepcell_naming_convention` assumes that all segmentation outputs are stored in one folder, with the `fov_name` as the prefix and `_whole_cell.tiff` as the suffix, as shown below in the visualization of the folder structure. 
If this does not apply to your data, you have to define a function `segmentation_naming_convention` that takes an element from `fov_paths` and returns a valid path to the segmentation label map you want to use for that fov.\n", + "\n", + "```\n", + "|-- base_dir\n", + "| |-- image_data\n", + "| | |-- fov1.ome.tiff\n", + "| | |-- fov2.ome.tiff\n", + "| |-- segmentation\n", + "| | |-- deepcell_output\n", + "| | | |-- fov1_whole_cell.tiff\n", + "| | | |-- fov2_whole_cell.tiff\n", + "| |-- nimbus_output\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc8256e6", + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare segmentation naming convention that maps a fov_path to the according segmentation label map\n", + "segmentation_naming_convention = prep_naming_convention(deepcell_output_dir)\n", + "\n", + "# test segmentation_naming_convention\n", + "if os.path.exists(segmentation_naming_convention(fov_paths[0])):\n", + " print(\"Segmentation data exists for fov 0 and naming convention is correct\")\n", + "else:\n", + " print(\"Segmentation data does not exist for fov 0 or naming convention is incorrect\")" + ] + }, + { + "cell_type": "markdown", + "id": "e7717960", + "metadata": {}, + "source": [ + "Next we will use the `MultiplexDataset` class to abstract away differences in data representation. The class takes `fov_paths`, `segmentation_naming_convention` and a `suffix` and provides methods `.get_channel(fov, channel)` and `.get_segmentation(fov)` to access the data. The `suffix` is used to filter out files that do not end with the specified suffix. When you use .ome.tiff files make sure to set the suffix to `.ome.tiff`, otherwise the ViewerWidget won't be able to display the images." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50997492", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = MultiplexDataset(\n", + " fov_paths=fov_paths,\n", + " suffix=\".ome.tiff\",\n", + " include_channels=include_channels,\n", + " segmentation_naming_convention=segmentation_naming_convention,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "839e5240", + "metadata": {}, + "source": [ + "## 3: Load model and initialize Nimbus application\n", + "The following code initializes the Nimbus application and loads the model checkpoint. The model was trained on a diverse set of tissues, protein markers, imaging platforms and cell types and doesn't need re-training. If you want to use the model on a machine without GPU, set `test_time_aug=False` to speed up inference. If you run it on a laptop GPU and run into out-of-memory errors, consider reducing the `batch_size` to 1 and the `input_shape` to `[512,512]`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7fd0a575", + "metadata": {}, + "outputs": [], + "source": [ + "nimbus = Nimbus(\n", + " dataset=dataset,\n", + " output_dir=nimbus_output_dir,\n", + " save_predictions=True,\n", + " batch_size=4,\n", + " test_time_aug=True,\n", + " input_shape=[1024,1024],\n", + " device=\"auto\",\n", + ")\n", + "\n", + "# check if all inputs are valid\n", + "nimbus.check_inputs()" + ] + }, + { + "cell_type": "markdown", + "id": "bbce682e", + "metadata": {}, + "source": [ + "## 4: Prepare normalization dictionary \n", + "The next step is to iterate through all the fovs and calculate the 0.999 marker expression quantile for each marker individually. This is used for normalizing the marker expressions prior to predicting marker confidence scores with our model. You can set `n_subset` to estimate the quantiles on a small subset of the data and you can set `multiprocessing=True` to speed up computation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41b100e7", + "metadata": {}, + "outputs": [], + "source": [ + "nimbus.prepare_normalization_dict(\n", + " n_subset=50,\n", + " multiprocessing=True,\n", + " overwrite=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9e782794", + "metadata": {}, + "source": [ + "## 5: Make predictions with the model\n", + "Nimbus will iterate through your samples and store predictions and a file named `nimbus_cell_table.csv` that contains the mean-per-cell predicted marker confidence scores in the sub-directory called `nimbus_output`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76225704", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "cell_table = nimbus.predict_fovs()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca222e0e", + "metadata": {}, + "outputs": [], + "source": [ + "cell_table" + ] + }, + { + "cell_type": "markdown", + "id": "fdef2ab9", + "metadata": {}, + "source": [ + "## 6: View multiplexed channels and Nimbus predictions side-by-side\n", + "Select an FOV and one marker image per channel to inspect the imaging data and associated Nimbus predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f95e351", + "metadata": {}, + "outputs": [], + "source": [ + "viewer = NimbusViewer(dataset=dataset, output_dir=nimbus_output_dir)\n", + "viewer.display()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/test_nimbus.py b/tests/test_nimbus.py index ccdc034..cc2955d 100644 --- 
a/tests/test_nimbus.py +++ b/tests/test_nimbus.py @@ -43,15 +43,17 @@ def test_prepare_normalization_dict(): selected_markers = ["CD45", "CD3", "CD8", "ChyTr"] fov_paths,_ = prepare_tif_data(num_samples, temp_dir, selected_markers) naming_convention = prep_naming_convention(os.path.join(temp_dir, "deepcell_output")) - dataset = MultiplexDataset(fov_paths, naming_convention) - nimbus = Nimbus(dataset, temp_dir, include_channels=["CD45", "CD3", "CD8"]) + dataset = MultiplexDataset( + fov_paths, naming_convention, include_channels=["CD45", "CD3", "CD8"] + ) + nimbus = Nimbus(dataset, temp_dir) # test if normalization dict gets prepared and saved nimbus.prepare_normalization_dict(overwrite=True) assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json")) assert "ChyTr" not in nimbus.normalization_dict.keys() # test if normalization dict gets loaded - nimbus_2 = Nimbus(dataset, temp_dir, include_channels=["CD45", "CD3", "CD8"]) + nimbus_2 = Nimbus(dataset, temp_dir) nimbus_2.prepare_normalization_dict() assert nimbus_2.normalization_dict == nimbus.normalization_dict diff --git a/tests/test_utils.py b/tests/test_utils.py index f3c450a..f1195be 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -119,15 +119,11 @@ def test_calculate_normalization(): ) ome_dataset = MultiplexDataset(ome_fov_paths, suffix=".ome.tiff") for dataset in [tif_dataset, ome_dataset]: - norm_dict = calculate_normalization(dataset, 0.999, include_channels=[channel]) + norm_dict = calculate_normalization(dataset, 0.999) channel_out, norm_val = list(norm_dict.items())[0] # test if we get the correct channel and normalization value assert channel_out == channel assert np.isclose(norm_val, 0.5, 0.01) - - # test if ValueError is raised if include_channels are not in the dataset - with pytest.raises(ValueError): - calculate_normalization(dataset, 0.999, include_channels=["CD42", "CD56"]) def test_prepare_normalization_dict(): From e08fe2ace88eade002c068051ad8c3162688dbbc Mon Sep 17 
00:00:00 2001 From: Lenz Date: Wed, 10 Apr 2024 09:55:49 +0200 Subject: [PATCH 4/4] Merged predict notebooks into one --- templates/1_Nimbus_Predict.ipynb | 7 +- templates/1_Nimbus_Predict_OME.ipynb | 326 --------------------------- 2 files changed, 4 insertions(+), 329 deletions(-) delete mode 100644 templates/1_Nimbus_Predict_OME.ipynb diff --git a/templates/1_Nimbus_Predict.ipynb b/templates/1_Nimbus_Predict.ipynb index 6e3c254..f8067bb 100644 --- a/templates/1_Nimbus_Predict.ipynb +++ b/templates/1_Nimbus_Predict.ipynb @@ -34,7 +34,8 @@ "metadata": {}, "source": [ "## 0: Set root directory and download example dataset\n", - "Here we are using the example data located in `/data/example_dataset/input_data`. To modify this notebook to run using your own data, simply change `base_dir` to point to your own sub-directory within the data folder. Set `base_dir`, the path to all of your imaging data (i.e. multiplexed images and segmentation masks). Subdirectory `nimbus_output` will contain all of the data generated by this notebook. In the following, we expect this folder structure:\n", + "Here we are using the example data located in `/data/example_dataset/input_data`. To modify this notebook to run using your own data, simply change `base_dir` to point to your own sub-directory within the data folder. Set `base_dir`, the path to all of your imaging data (i.e. multiplexed images and segmentation masks). Subdirectory `nimbus_output` will contain all of the data generated by this notebook. In the following, we expect this folder structure, with `fov_1` and `fov_2` either being folders of individual channel images or `.ome.tiff` files that contain all channels in a single file.\n", + "\n", "```\n", "|-- base_dir\n", "| |-- image_data\n", @@ -177,7 +178,7 @@ "id": "e7717960", "metadata": {}, "source": [ - "Next we will use the `MultiplexDataset` class to abstract away differences in data representation. 
The class takes `fov_paths`, `segmentation_naming_convention` and a `suffix` and provides methods `.get_channel(fov, channel)` and `.get_segmentation(fov)` to access the data. The `suffix` is used to filter out files that do not end with the specified suffix." + "Next we will use the `MultiplexDataset` class to abstract away differences in data representation. The class takes `fov_paths`, `segmentation_naming_convention` and a `suffix` and provides methods `.get_channel(fov, channel)` and `.get_segmentation(fov)` to access the data. The `suffix` is used to filter out files that do not end with the specified suffix. When you use `.ome.tiff` files, make sure to set the suffix to `.ome.tiff`; otherwise the `NimbusViewer` widget won't be able to display the images." ] }, { @@ -189,7 +190,7 @@ "source": [ "dataset = MultiplexDataset(\n", " fov_paths=fov_paths,\n", - " suffix=\".tiff\",\n", + " suffix=\".tiff\", # or .png, .jpg, .jpeg, .tif or .ome.tiff\n", " include_channels=include_channels,\n", " segmentation_naming_convention=segmentation_naming_convention,\n", ")" diff --git a/templates/1_Nimbus_Predict_OME.ipynb b/templates/1_Nimbus_Predict_OME.ipynb deleted file mode 100644 index 5c83080..0000000 --- a/templates/1_Nimbus_Predict_OME.ipynb +++ /dev/null @@ -1,326 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "673d8e3a", - "metadata": {}, - "source": [ - "# Nimbus prediction notebook for stack .ome.tiff files" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f920e689", - "metadata": {}, - "outputs": [], - "source": [ - "# import required packages\n", - "import warnings\n", - "warnings.simplefilter(\"ignore\")\n", - "import os\n", - "from IPython.display import display, HTML\n", - "display(HTML(\"\"))\n", - "from nimbus_inference.nimbus import Nimbus, prep_naming_convention\n", - "from nimbus_inference.utils import MultiplexDataset\n", - "from alpineer import io_utils\n", - "from ark.utils import example_dataset\n", - "from 
nimbus_inference.viewer_widget import NimbusViewer" - ] - }, - { - "cell_type": "markdown", - "id": "e4642fe2", - "metadata": {}, - "source": [ - "## 0: Set root directory and download example dataset\n", - "Here we are using the example data located in `/data/example_dataset/input_data`. To modify this notebook to run using your own data, simply change `base_dir` to point to your own sub-directory within the data folder. Set `base_dir`, the path to all of your imaging data (i.e. multiplexed images and segmentation masks). Subdirectory `nimbus_output` will contain all of the data generated by this notebook. In the following, we expect this folder structure:\n", - "```\n", - "|-- base_dir\n", - "| |-- image_data\n", - "| | |-- fov1.ome.tiff\n", - "| | |-- fov2.ome.tiff\n", - "| |-- segmentation\n", - "| | |-- deepcell_output\n", - "| | | |-- fov1_whole_cell.tiff\n", - "| | | |-- fov2_whole_cell.tiff\n", - "| |-- nimbus_output\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "974f8dda", - "metadata": {}, - "outputs": [], - "source": [ - "# set up the base directory\n", - "base_dir = os.path.normpath(\"../data/example_dataset\")" - ] - }, - { - "cell_type": "markdown", - "id": "0ade450f", - "metadata": {}, - "source": [ - "If you would like to test Nimbus with an example dataset, run the cell below. It will download a dataset consisting of 10 FOVs with 22 channels. You may find more information about the example dataset in the [ark-analysis README](https://github.com/angelolab/ark-analysis/blob/bc6685050dfbef4607874fbbadebd4289251c173/README.md#example-dataset). 
If you want to use your own data, skip the cell below\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "37733de5", - "metadata": {}, - "outputs": [], - "source": [ - "example_dataset.get_example_dataset(dataset=\"cluster_pixels\", save_dir = base_dir, overwrite_existing = False)" - ] - }, - { - "cell_type": "markdown", - "id": "9cd2ab6c", - "metadata": {}, - "source": [ - "## 1: set file paths and parameters\n", - "\n", - "### All data, images, files, etc. must be placed in the 'data' directory, and referenced via '../data/path_to_your_data'\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "292e4524", - "metadata": {}, - "outputs": [], - "source": [ - "# set up file paths\n", - "tiff_dir = os.path.join(base_dir, \"ome_image_data\")\n", - "deepcell_output_dir = os.path.join(base_dir, \"segmentation\", \"deepcell_output\")\n", - "nimbus_output_dir = os.path.join(base_dir, \"nimbus_output\")\n", - "\n", - "# Create nimbus output directory\n", - "os.makedirs(nimbus_output_dir, exist_ok=True)\n", - "\n", - "# Check if paths exist\n", - "io_utils.validate_paths([base_dir, tiff_dir, deepcell_output_dir, nimbus_output_dir])" - ] - }, - { - "cell_type": "markdown", - "id": "ae89442a", - "metadata": {}, - "source": [ - "## 2: Set up input paths and the naming convention for the segmentation data\n", - "Store names of channels to exclude in the list below. Either predict all FOVs or specify manually the ones you want to apply Nimbus on." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "65a319c9", - "metadata": {}, - "outputs": [], - "source": [ - "# define the channels to include\n", - "include_channels = [\n", - " \"CD3\", \"CD4\", \"CD8\", \"CD14\", \"CD20\", \"CD31\", \"CD45\", \"CD68\", \"CD163\", \"CK17\", \"Collagen1\",\n", - " \"ECAD\", \"Fibronectin\", \"GLUT1\", \"HLADR\", \"IDO\", \"Ki67\", \"PD1\", \"SMA\", \"Vim\"\n", - "]\n", - "\n", - "# either get all fovs in the folder...\n", - "fov_names = os.listdir(tiff_dir)\n", - "# ... or optionally, select a specific set of fovs manually\n", - "# fovs = [\"fov0\", \"fov1\"]\n", - "\n", - "# construct paths for fovs\n", - "fov_paths = [os.path.join(tiff_dir, fov_name) for fov_name in fov_names]" - ] - }, - { - "cell_type": "markdown", - "id": "8c85f682", - "metadata": {}, - "source": [ - "Define the naming convention for the segmentation data in function `segmentation_naming_convention`, that maps the `fov_name` to the path of the associated segmentation output. The below function `prep_deepcell_naming_convention` assumes that all segmentation outputs are stored in one folder, with the `fov_name` as the prefix and `_whole_cell.tiff` as the suffix, as shown below in the visualization of the folder structure. 
If this does not apply to your data, you have to define a function `segmentation_naming_convention` that takes an element from `fov_paths` and returns a valid path to the segmentation label map you want to use for that fov.\n", - "\n", - "```\n", - "|-- base_dir\n", - "| |-- image_data\n", - "| | |-- fov1.ome.tiff\n", - "| | |-- fov2.ome.tiff\n", - "| |-- segmentation\n", - "| | |-- deepcell_output\n", - "| | | |-- fov1_whole_cell.tiff\n", - "| | | |-- fov2_whole_cell.tiff\n", - "| |-- nimbus_output\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fc8256e6", - "metadata": {}, - "outputs": [], - "source": [ - "# Prepare segmentation naming convention that maps a fov_path to the according segmentation label map\n", - "segmentation_naming_convention = prep_naming_convention(deepcell_output_dir)\n", - "\n", - "# test segmentation_naming_convention\n", - "if os.path.exists(segmentation_naming_convention(fov_paths[0])):\n", - " print(\"Segmentation data exists for fov 0 and naming convention is correct\")\n", - "else:\n", - " print(\"Segmentation data does not exist for fov 0 or naming convention is incorrect\")" - ] - }, - { - "cell_type": "markdown", - "id": "e7717960", - "metadata": {}, - "source": [ - "Next we will use the `MultiplexDataset` class to abstract away differences in data representation. The class takes `fov_paths`, `segmentation_naming_convention` and a `suffix` and provides methods `.get_channel(fov, channel)` and `.get_segmentation(fov)` to access the data. The `suffix` is used to filter out files that do not end with the specified suffix. When you use .ome.tiff files make sure to set the suffix to `.ome.tiff`, otherwise the ViewerWidget won't be able to display the images." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "50997492", - "metadata": {}, - "outputs": [], - "source": [ - "dataset = MultiplexDataset(\n", - " fov_paths=fov_paths,\n", - " suffix=\".ome.tiff\",\n", - " include_channels=include_channels,\n", - " segmentation_naming_convention=segmentation_naming_convention,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "839e5240", - "metadata": {}, - "source": [ - "## 3: Load model and initialize Nimbus application\n", - "The following code initializes the Nimbus application and loads the model checkpoint. The model was trained on a diverse set of tissues, protein markers, imaging platforms and cell types and doesn't need re-training. If you want to use the model on a machine without GPU, set `test_time_aug=False` to speed up inference. If you run it on a laptop GPU and run into out-of-memory errors, consider reducing the `batch_size` to 1 and the `input_shape` to `[512,512]`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7fd0a575", - "metadata": {}, - "outputs": [], - "source": [ - "nimbus = Nimbus(\n", - " dataset=dataset,\n", - " output_dir=nimbus_output_dir,\n", - " save_predictions=True,\n", - " batch_size=4,\n", - " test_time_aug=True,\n", - " input_shape=[1024,1024],\n", - " device=\"auto\",\n", - ")\n", - "\n", - "# check if all inputs are valid\n", - "nimbus.check_inputs()" - ] - }, - { - "cell_type": "markdown", - "id": "bbce682e", - "metadata": {}, - "source": [ - "## 4: Prepare normalization dictionary \n", - "The next step is to iterate through all the fovs and calculate the 0.999 marker expression quantile for each marker individually. This is used for normalizing the marker expressions prior to predicting marker confidence scores with our model. You can set `n_subset` to estimate the quantiles on a small subset of the data and you can set `multiprocessing=True` to speed up computation." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "41b100e7", - "metadata": {}, - "outputs": [], - "source": [ - "nimbus.prepare_normalization_dict(\n", - " n_subset=50,\n", - " multiprocessing=True,\n", - " overwrite=True\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "9e782794", - "metadata": {}, - "source": [ - "## 5: Make predictions with the model\n", - "Nimbus will iterate through your samples and store predictions and a file named `nimbus_cell_table.csv` that contains the mean-per-cell predicted marker confidence scores in the sub-directory called `nimbus_output`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "76225704", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "cell_table = nimbus.predict_fovs()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ca222e0e", - "metadata": {}, - "outputs": [], - "source": [ - "cell_table" - ] - }, - { - "cell_type": "markdown", - "id": "fdef2ab9", - "metadata": {}, - "source": [ - "## 6: View multiplexed channels and Nimbus predictions side-by-side\n", - "Select an FOV and one marker image per channel to inspect the imaging data and associated Nimbus predictions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2f95e351", - "metadata": {}, - "outputs": [], - "source": [ - "viewer = NimbusViewer(dataset=dataset, output_dir=nimbus_output_dir)\n", - "viewer.display()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}