diff --git a/examples/map_to_hsv.py b/examples/map_to_hsv.py new file mode 100644 index 00000000..cae6bd15 --- /dev/null +++ b/examples/map_to_hsv.py @@ -0,0 +1,84 @@ +from mantis.analysis import visualization as vz +from pathlib import Path +from iohub import open_ome_zarr +import numpy as np +import glob +from natsort import natsorted +from mantis.cli import utils + +if __name__ == "__main__": + HSV_method = "JCh" # "RO" or "PRO" or "JCh" + hsv_channels = ["Retardance - recon", "Orientation - recon"] + num_processes = 10 + input_data_paths = "/hpc/projects/comp.micro/zebrafish/2023_02_02_zebrafish_casper/2-prototype-reconstruction/fov6-reconstruction.zarr/*/*/*" + output_data_path = f"./test_{HSV_method}_new_method2.zarr" + + input_data_paths = [Path(path) for path in natsorted(glob.glob(input_data_paths))] + output_data_path = Path(output_data_path) + + # Taking the input sample + with open_ome_zarr(input_data_paths[0], mode="r") as dataset: + T, C, Z, Y, X = dataset.data.shape + dataset_scale = dataset.scale + channel_names = dataset.channel_names + + input_channel_idx = [] + # FIXME: these should be part of a config + # FIXME:this is hardcoded to spit out 3 chans for RGB + output_channel_idx = [0, 1, 2] + time_indices = list(range(T)) + + if HSV_method == "PRO": + # hsv_channels = ["Orientation", "Retardance", "Phase"] + HSV_func = vz.HSV_PRO + hsv_func_kwargs = dict( + channel_order=output_channel_idx, max_val_V=0.5, max_val_S=1.0 + ) + + elif HSV_method == "RO": + # hsv_channels = [ + # "Orientation", + # "Retardance", + # ] + HSV_func = vz.HSV_RO + hsv_func_kwargs = dict(channel_order=output_channel_idx, max_val_V=0.5) + + elif HSV_method == "JCh": + # hsv_channels = ["Orientation", "Retardance"] + HSV_func = vz.JCh_mapping + hsv_func_kwargs = dict( + channel_order=output_channel_idx, max_val_ret=150, noise_level=1 + ) + for channel in hsv_channels: + if channel in channel_names: + input_channel_idx.append(channel_names.index(channel)) + rgb_channel_names = ["Red", "Green", "Blue"] + + # Here the functions will output an RGB image + output_metadata = { + "shape": (len(time_indices), len(rgb_channel_names), Z, Y, X), + "chunks": None, + "scale": dataset_scale, + "channel_names": rgb_channel_names, + "dtype": np.float32, + } + + utils.create_empty_hcs_zarr( + store_path=output_data_path, + position_keys=[p.parts[-3:] for p in input_data_paths], + **output_metadata, + ) + + for input_position_path in input_data_paths: + utils.process_single_position_v2( + HSV_func, + input_data_path=input_position_path, # source store + output_path=output_data_path, # target store + time_indices=time_indices, + input_channel_idx=input_channel_idx, + output_channel_idx=output_channel_idx, + num_processes=num_processes, # parallel processing over time + **hsv_func_kwargs, + ) + +# %% diff --git a/mantis/analysis/register.py b/mantis/analysis/register.py index 3db569b8..e6a34f72 100644 --- a/mantis/analysis/register.py +++ b/mantis/analysis/register.py @@ -267,7 +267,7 @@ def find_lir(registered_zyx: np.ndarray, plot: bool = False) -> Tuple: # Find the lir Z zyx_shape = registered_zyx.shape registered_zx_bool = registered_zyx.transpose((2, 0, 1)) > 0 - registered_zx_bool = registered_zx_bool[zyx_shape[0] // 2].copy() + registered_zx_bool = registered_zx_bool[zyx_shape[2] // 2].copy() rectangle_coords_zx = lir.lir(registered_zx_bool) x = rectangle_coords_zx[0] z = rectangle_coords_zx[1] diff --git a/mantis/analysis/visualization.py b/mantis/analysis/visualization.py new file mode 100644 index 00000000..5837ec04 --- /dev/null +++ b/mantis/analysis/visualization.py @@ -0,0 +1,147 @@ +import numpy as np + +from colorspacious import cspace_convert +from skimage.color import hsv2rgb +from skimage.exposure import rescale_intensity + + +def HSV_PRO(czyx, channel_order, max_val_V: float = 1.0, max_val_S: float = 1.0): + """ + HSV encoding of retardance + orientation + phase image with hsv colormap (orientation in h, retardance in s, phase in v) + Parameters + ---------- + czyx : numpy.ndarray + channel_order: list + the order in which the channels should be stacked i.e([orientation_c_idx, retardance_c_idx,phase_c_idx]) + the 0 index corresponds to the orientation image (range from 0 to pi) + the 1 index corresponds to the retardance image + the 2 index corresponds to the Phase image + max_val_V : float + raise the brightness of the phase channel by 1/max_val_V + max_val_S : float + raise the brightness of the retardance channel by 1/max_val_S + Returns: + RGB with HSV (Orientation, Retardance, Phase) + """ + + C, Z, Y, X = czyx.shape + assert C == 3, "The input array must have 3 channels" + print(f"channel_order: {channel_order}") + + czyx_out = np.zeros((3, Z, Y, X), dtype=np.float32) + # Normalize the stack + ordered_stack = np.stack( + ( + # Normalize the first channel by dividing by pi + czyx[channel_order[0]] / np.pi, + # Normalize the second channel and rescale intensity + rescale_intensity( + czyx[channel_order[1]], + in_range=( + np.min(czyx[channel_order[1]]), + np.max(czyx[channel_order[1]]), + ), + out_range=(0, 1), + ) + / max_val_S, + # Normalize the third channel and rescale intensity + rescale_intensity( + czyx[channel_order[2]], + in_range=( + np.min(czyx[channel_order[2]]), + np.max(czyx[channel_order[2]]), + ), + out_range=(0, 1), + ) + / max_val_V, + ), + axis=0, + ) + czyx_out = hsv2rgb(ordered_stack, channel_axis=0) + return czyx_out + + +def HSV_RO(czyx, channel_order: list[int], max_val_V: int = 1): + """ + visualize retardance + orientation with hsv colormap (orientation in h, saturation=1 s, retardance in v) + Parameters + ---------- + czyx : numpy.ndarray + channel_order: list + the order in which the channels should be stacked i.e([orientation_c_idx, retardance_c_idx]) + the 0 index corresponds to the orientation image (range from 0 to pi) + the 1 index corresponds to the retardance image + max_val_V : float + raise the brightness of the phase channel by 1/max_val_V + Returns: + + RGB with HSV (Orientation, _____ , Retardance) + """ + C, Z, Y, X = czyx.shape + assert C == 2, "The input array must have 2 channels" + czyx_out = np.zeros((3, Z, Y, X), dtype=np.float32) + ordered_stack = np.stack( + ( + # Normalize the first channel by dividing by pi and then rescale intensity + czyx[channel_order[0]] / np.pi, + # Set the second channel to ones = Saturation 1 + np.ones_like(czyx[channel_order[0]]), + # Normalize the third channel and rescale intensity + np.minimum( + 1, + rescale_intensity( + czyx[channel_order[1]], + in_range=( + np.min(czyx[channel_order[1]]), + np.max(czyx[channel_order[1]]), + ), + out_range=(0, max_val_V), + ), + ), + ), + axis=0, + ) + # HSV-RO encoding + czyx_out = hsv2rgb(ordered_stack, channel_axis=0) + return czyx_out + + +def JCh_mapping(czyx, channel_order: list[int], max_val_ret: int = None, noise_level: int = 1): + """ + JCh retardance + orientation + phase image with hsv colormap (orientation in h, retardance in s, phase in v) + Parameters + ---------- + czyx : numpy.ndarray + channel_order: list + the order in which the channels should be stacked i.e([retardance_c_idx, orientation_c_idx]) + the 0 index corresponds to the retardance image + the 1 index corresponds to the orientation image (range from 0 to pi) + + max_val_V : float + raise the brightness of the phase channel by 1/max_val_ret + Returns: + RGB with JCh (Retardance, Orientation) + """ + # retardance, orientation + C, Z, Y, X = czyx.shape + assert C == 2, "The input array must have 2 channels" + + # Retardance,chroma,Hue + czyx_out = np.zeros((3, Z, Y, X), dtype=np.float32) + for z_idx in range(Z): + # Retardance + if max_val_ret is None: + max_val_ret = np.max(czyx[channel_order[0], z_idx]) + retardance = np.clip(czyx[channel_order[0], z_idx], 0, max_val_ret) + # Chroma of each pixel, set to 60 by default, with noise handling + chroma = np.where(czyx[channel_order[0], z_idx] < noise_level, 0, 60) + # Orientation 180 to 360 to match periodic hue + hue = czyx[channel_order[1], z_idx] * 360 / np.pi + # Stack arrays in the correct order (Y, X, 3) + I_JCh = np.stack((retardance, chroma, hue), axis=-1) + # Transpose to shape for the skimage or colorspace functions + JCh_rgb = cspace_convert(I_JCh, "JCh", "sRGB1") + JCh_rgb = np.clip(JCh_rgb, 0, 1) + czyx_out[:, z_idx] = np.transpose(JCh_rgb, (2, 0, 1)) + + return czyx_out diff --git a/mantis/cli/apply_affine.py b/mantis/cli/apply_affine.py index 82a85b31..105232ec 100644 --- a/mantis/cli/apply_affine.py +++ b/mantis/cli/apply_affine.py @@ -108,7 +108,8 @@ def apply_affine( ) # Overwrite the previous target shape Z_target, Y_target, X_target = cropped_target_shape_zyx[-3:] - click.echo(f'Shape of cropped output dataset: {target_shape_zyx}\n') + cropped_target_shape_zyx = Z_target, Y_target, X_target + click.echo(f'Shape of cropped output dataset: {cropped_target_shape_zyx}\n') else: Z_slice, Y_slice, X_slice = ( slice(0, Z_target), diff --git a/mantis/cli/estimate_affine.py b/mantis/cli/estimate_affine.py index bdaacba0..6e985d27 100644 --- a/mantis/cli/estimate_affine.py +++ b/mantis/cli/estimate_affine.py @@ -6,6 +6,7 @@ from iohub import open_ome_zarr from iohub.reader import print_info from skimage.transform import EuclideanTransform +from skimage.transform import SimilarityTransform from waveorder.focus import focus_from_transverse_band from mantis.analysis.AnalysisSettings import RegistrationSettings @@ -33,7 +34,15 @@ @source_position_dirpaths() @target_position_dirpaths() @output_filepath() -def estimate_affine(source_position_dirpaths, target_position_dirpaths, output_filepath): +@click.option( + "--similarity-flag", + '-x', + is_flag=True, + help='flag to use similarity transform (rotation, translation, scaling) default:Eucledian (rotation, translation)', +) +def estimate_affine( + source_position_dirpaths, target_position_dirpaths, output_filepath, similarity_flag +): """ Estimate the affine transform between a source (i.e. moving) and a target (i.e. fixed) image by selecting corresponding points in each. @@ -42,6 +51,7 @@ def estimate_affine(source_position_dirpaths, target_position_dirpaths, output_f -s ./acq_name_labelfree_reconstructed.zarr/0/0/0 -t ./acq_name_lightsheet_deskewed.zarr/0/0/0 -o ./output.yml + -x flag to use similarity transform (rotation, translation, scaling) default:Eucledian (rotation, translation) """ click.echo("\nTarget channel INFO:") @@ -72,35 +82,39 @@ def estimate_affine(source_position_dirpaths, target_position_dirpaths, output_f source_channel_Z, source_channel_Y, source_channel_X = source_channel_volume.shape[-3:] target_channel_Z, target_channel_Y, target_channel_X = target_channel_volume.shape[-3:] - focus_source_channel_idx = focus_from_transverse_band( - source_channel_volume[ - :, - source_channel_Y // 2 - - FOCUS_SLICE_ROI_WIDTH : source_channel_Y // 2 - + FOCUS_SLICE_ROI_WIDTH, - source_channel_X // 2 - - FOCUS_SLICE_ROI_WIDTH : source_channel_X // 2 - + FOCUS_SLICE_ROI_WIDTH, - ], - NA_det=NA_DETECTION_SOURCE, - lambda_ill=WAVELENGTH_EMISSION_SOURCE_CHANNEL, - pixel_size=source_channel_voxel_size[-1], - ) + if source_channel_Z < 2 or target_channel_Z < 2: + focus_source_channel_idx = 1 + focus_target_channel_idx = 1 + else: + focus_source_channel_idx = focus_from_transverse_band( + source_channel_volume[ + :, + source_channel_Y // 2 + - FOCUS_SLICE_ROI_WIDTH : source_channel_Y // 2 + + FOCUS_SLICE_ROI_WIDTH, + source_channel_X // 2 + - FOCUS_SLICE_ROI_WIDTH : source_channel_X // 2 + + FOCUS_SLICE_ROI_WIDTH, + ], + NA_det=NA_DETECTION_SOURCE, + lambda_ill=WAVELENGTH_EMISSION_SOURCE_CHANNEL, + pixel_size=source_channel_voxel_size[-1], + ) - focus_target_channel_idx = focus_from_transverse_band( - target_channel_volume[ - :, - target_channel_Y // 2 - - FOCUS_SLICE_ROI_WIDTH : target_channel_Y // 2 - + FOCUS_SLICE_ROI_WIDTH, - target_channel_X // 2 - - FOCUS_SLICE_ROI_WIDTH : target_channel_X // 2 - + FOCUS_SLICE_ROI_WIDTH, - ], - NA_det=NA_DETECTION_TARGET, - lambda_ill=WAVELENGTH_EMISSION_TARGET_CHANNEL, - pixel_size=target_channel_voxel_size[-1], - ) + focus_target_channel_idx = focus_from_transverse_band( + target_channel_volume[ + :, + target_channel_Y // 2 + - FOCUS_SLICE_ROI_WIDTH : target_channel_Y // 2 + + FOCUS_SLICE_ROI_WIDTH, + target_channel_X // 2 + - FOCUS_SLICE_ROI_WIDTH : target_channel_X // 2 + + FOCUS_SLICE_ROI_WIDTH, + ], + NA_det=NA_DETECTION_TARGET, + lambda_ill=WAVELENGTH_EMISSION_TARGET_CHANNEL, + pixel_size=target_channel_voxel_size[-1], + ) click.echo() if focus_source_channel_idx not in (0, source_channel_Z - 1): @@ -276,24 +290,36 @@ def lambda_callback(layer, event): pts_target_channel = points_target_channel.data # Estimate the affine transform between the points xy to make sure registration is good - transform = EuclideanTransform() - transform.estimate(pts_source_channel[:, 1:], pts_target_channel[:, 1:]) - yx_points_transformation_matrix = transform.params + if similarity_flag: + # Similarity transform (rotation, translation, scaling) + transform = SimilarityTransform() + transform.estimate(pts_source_channel, pts_target_channel) + manual_estimated_transform = transform.params @ compound_affine - z_translation = pts_target_channel[0, 0] - pts_source_channel[0, 0] + else: + # Euclidean transform (rotation, translation) limiting this dataset's scale and just z-translation + transform = EuclideanTransform() + transform.estimate(pts_source_channel[:, 1:], pts_target_channel[:, 1:]) + yx_points_transformation_matrix = transform.params - z_scale_translate_matrix = np.array([[1, 0, 0, z_translation]]) + z_translation = pts_target_channel[0, 0] - pts_source_channel[0, 0] - # 2D to 3D matrix - euclidian_transform = np.vstack( - (z_scale_translate_matrix, np.insert(yx_points_transformation_matrix, 0, 0, axis=1)) - ) # Insert 0 in the third entry of each row + z_scale_translate_matrix = np.array([[1, 0, 0, z_translation]]) - scaling_affine = get_3D_rescaling_matrix( - (1, target_channel_Y, target_channel_X), - (scaling_factor_z, scaling_factor_yx, scaling_factor_yx), - ) - manual_estimated_transform = euclidian_transform @ compound_affine + # 2D to 3D matrix + euclidian_transform = np.vstack( + ( + z_scale_translate_matrix, + np.insert(yx_points_transformation_matrix, 0, 0, axis=1), + ) + ) # Insert 0 in the third entry of each row + + scaling_affine = get_3D_rescaling_matrix( + (1, target_channel_Y, target_channel_X), + (scaling_factor_z, scaling_factor_yx, scaling_factor_yx), + ) + + manual_estimated_transform = euclidian_transform @ compound_affine # NOTE: these two functions are key to pass the function properly to ANTs manual_estimated_transform_ants_style = manual_estimated_transform[:, :-1].ravel() @@ -349,6 +375,8 @@ def lambda_callback(layer, event): click.echo(f"Writing registration parameters to {output_filepath}") model_to_yaml(model, output_filepath) + input("Press to close the viewer and exit...") + if __name__ == "__main__": estimate_affine() diff --git a/mantis/cli/utils.py b/mantis/cli/utils.py index 73d68f80..99a5ae74 100644 --- a/mantis/cli/utils.py +++ b/mantis/cli/utils.py @@ -8,14 +8,27 @@ from pathlib import Path from typing import Tuple +import ants import click +import largestinteriorrectangle as lir import numpy as np +import scipy.ndimage as ndi import yaml from iohub.ngff import Position, open_ome_zarr from iohub.ngff_meta import TransformationMeta -from numpy.typing import DTypeLike from tqdm import tqdm +from numpy.typing import DTypeLike +import torch +import matplotlib.pyplot as plt + +from numpy.typing import DTypeLike +import torch +import matplotlib.pyplot as plt +from natsort import natsorted +import glob +from typing import List +import os # TODO: replace this with recOrder recOrder.cli.utils.create_empty_hcs() @@ -87,86 +100,6 @@ def create_empty_zarr( input_dataset.close() -# TODO: convert all code to use this function from now on -def create_empty_hcs_zarr( - store_path: Path, - position_keys: list[Tuple[str]], - channel_names: list[str], - shape: Tuple[int], - chunks: Tuple[int] = None, - scale: Tuple[float] = (1, 1, 1, 1, 1), - dtype: DTypeLike = np.float32, - max_chunk_size_bytes=500e6, -) -> None: - """ - If the plate does not exist, create an empty zarr plate. - If the plate exists, append positions and channels if they are not - already in the plate. - Parameters - ---------- - store_path : Path - hcs plate path - position_keys : list[Tuple[str]] - Position keys, will append if not present in the plate. - e.g. [("A", "1", "0"), ("A", "1", "1")] - shape : Tuple[int] - chunks : Tuple[int] - scale : Tuple[float] - channel_names : list[str] - Channel names, will append if not present in metadata. - dtype : DTypeLike - - Modifying from recOrder - https://github.com/mehta-lab/recOrder/blob/d31ad910abf84c65ba927e34561f916651cbb3e8/recOrder/cli/utils.py#L12 - """ - MAX_CHUNK_SIZE = max_chunk_size_bytes # in bytes - bytes_per_pixel = np.dtype(dtype).itemsize - - # Limiting the chunking to 500MB - if chunks is None: - chunk_zyx_shape = list(shape[-3:]) - # chunk_zyx_shape[-3] > 1 ensures while loop will not stall if single - # XY image is larger than MAX_CHUNK_SIZE - while ( - chunk_zyx_shape[-3] > 1 - and np.prod(chunk_zyx_shape) * bytes_per_pixel > MAX_CHUNK_SIZE - ): - chunk_zyx_shape[-3] = np.ceil(chunk_zyx_shape[-3] / 2).astype(int) - chunk_zyx_shape = tuple(chunk_zyx_shape) - - chunks = 2 * (1,) + chunk_zyx_shape - - # Create plate - output_plate = open_ome_zarr( - str(store_path), layout="hcs", mode="a", channel_names=channel_names - ) - - # Create positions - for position_key in position_keys: - position_key_string = "/".join(position_key) - # Check if position is already in the store, if not create it - if position_key_string not in output_plate.zgroup: - position = output_plate.create_position(*position_key) - _ = position.create_zeros( - name="0", - shape=shape, - chunks=chunks, - dtype=dtype, - transform=[TransformationMeta(type="scale", scale=scale)], - ) - else: - position = output_plate[position_key_string] - - # Check if channel_names are already in the store, if not append them - for channel_name in channel_names: - # Read channel names directly from metadata to avoid race conditions - metadata_channel_names = [ - channel.label for channel in position.metadata.omero.channels - ] - if channel_name not in metadata_channel_names: - position.append_channel(channel_name, resize_arrays=True) - - def get_output_paths(input_paths: list[Path], output_zarr_path: Path) -> list[Path]: """Generates a mirrored output path list given an input list of positions""" list_output_path = [] @@ -198,56 +131,6 @@ def apply_function_to_zyx_and_save( click.echo(f"Finished Writing.. c={c_idx}, t={t_idx}") -# NOTE WIP -def apply_transform_to_zyx_and_save_v2( - func, - position: Position, - output_path: Path, - input_channel_indices: list[int], - output_channel_indices: list[int], - t_idx: int, - c_idx: int = None, - **kwargs, -) -> None: - """Load a zyx array from a Position object, apply a transformation to CZYX or ZYX and save the result to file""" - click.echo(f"Processing c={c_idx}, t={t_idx}") - - # TODO: temporary fix to slumkit issue - if _is_nested(input_channel_indices): - # print(f'input_channel_indices: {input_channel_indices}') - input_channel_indices = [int(x) for x in input_channel_indices if x.isdigit()] - if _is_nested(output_channel_indices): - # print(f'input_channel_indices: {output_channel_indices}') - output_channel_indices = [int(x) for x in output_channel_indices if x.isdigit()] - click.echo(f'input_channel_indices: {input_channel_indices}') - - # Process CZYX vs ZYX - if input_channel_indices is not None: - czyx_data = position.data.oindex[t_idx, input_channel_indices] - if not _check_nan_n_zeros(czyx_data): - transformed_czyx = func(czyx_data, **kwargs) - # Write to file - with open_ome_zarr(output_path, mode="r+") as output_dataset: - output_dataset[0].oindex[t_idx, output_channel_indices] = transformed_czyx - click.echo(f"Finished Writing.. t={t_idx}") - else: - click.echo(f"Skipping t={t_idx} due to all zeros or nans") - else: - zyx_data = position.data.oindex[t_idx, c_idx] - # Checking if nans or zeros and skip processing - if not _check_nan_n_zeros(zyx_data): - # Apply transformation - transformed_zyx = func(zyx_data, **kwargs) - - # Write to file - with open_ome_zarr(output_path, mode="r+") as output_dataset: - output_dataset[0][t_idx, c_idx] = transformed_zyx - - click.echo(f"Finished Writing.. c={c_idx}, t={t_idx}") - else: - click.echo(f"Skipping c={c_idx}, t={t_idx} due to all zeros or nans") - - def process_single_position( func, input_data_path: Path, @@ -284,11 +167,11 @@ def process_single_position( # Write the settings into the metadata if existing # TODO: alternatively we can throw all extra arguments as metadata. - if 'extra_metadata' in non_func_args: + if "extra_metadata" in non_func_args: # For each dictionary in the nest - with open_ome_zarr(output_path, mode='r+') as output_dataset: - for params_metadata_keys in kwargs['extra_metadata'].keys(): - output_dataset.zattrs['extra_metadata'] = non_func_args['extra_metadata'] + with open_ome_zarr(output_path, mode="r+") as output_dataset: + for params_metadata_keys in kwargs["extra_metadata"].keys(): + output_dataset.zattrs["extra_metadata"] = non_func_args["extra_metadata"] # Loop through (T, C), deskewing and writing as we go click.echo(f"\nStarting multiprocess pool with {num_processes} processes") @@ -305,95 +188,207 @@ def process_single_position( ) -# TODO: modifiy how we get the time and channesl like recOrder (isinstance(input, list) or instance(input,int) or all) -def process_single_position_v2( - func, - input_data_path: Path, - output_path: Path, - time_indices: list = [0], - input_channel_idx: list = [], - output_channel_idx: list = [], - num_processes: int = mp.cpu_count(), - **kwargs, -) -> None: - """Register a single position with multiprocessing parallelization over T and C""" - # Function to be applied - click.echo(f"Function to be applied: \t{func}") +def scale_affine(start_shape_zyx, scaling_factor_zyx=(1, 1, 1), end_shape_zyx=None): + center_Y_start, center_X_start = np.array(start_shape_zyx)[-2:] / 2 + if end_shape_zyx is None: + center_Y_end, center_X_end = (center_Y_start, center_X_start) + else: + center_Y_end, center_X_end = np.array(end_shape_zyx)[-2:] / 2 + + scaling_matrix = np.array( + [ + [scaling_factor_zyx[-3], 0, 0, 0], + [ + 0, + scaling_factor_zyx[-2], + 0, + -center_Y_start * scaling_factor_zyx[-2] + center_Y_end, + ], + [ + 0, + 0, + scaling_factor_zyx[-1], + -center_X_start * scaling_factor_zyx[-1] + center_X_end, + ], + [0, 0, 0, 1], + ] + ) + return scaling_matrix - # Get the reader and writer - click.echo(f"Input data path:\t{input_data_path}") - click.echo(f"Output data path:\t{str(output_path)}") - input_dataset = open_ome_zarr(str(input_data_path)) - stdout_buffer = io.StringIO() - with contextlib.redirect_stdout(stdout_buffer): - input_dataset.print_tree() - click.echo(f" Input data tree: {stdout_buffer.getvalue()}") - # Find time indices - if time_indices == "all": - time_indices = range(input_dataset.data.shape[0]) - elif isinstance(time_indices, list): - time_indices = time_indices +def rotate_affine( + start_shape_zyx: Tuple, angle: float = 0.0, end_shape_zyx: Tuple = None +) -> np.ndarray: + """ + Rotate Transformation Matrix - # Check for invalid times - time_ubound = input_dataset.data.shape[0] - 1 - if np.max(time_indices) > time_ubound: - raise ValueError( - f"time_indices = {time_indices} includes a time index beyond the maximum index of the dataset = {time_ubound}" - ) + Parameters + ---------- + start_shape_zyx : Tuple + Shape of the input + angle : float, optional + Angles of rotation in degrees + end_shape_zyx : Tuple, optional + Shape of output space - # Check the arguments for the function - all_func_params = inspect.signature(func).parameters.keys() - # Extract the relevant kwargs for the function 'func' - func_args = {} - non_func_args = {} + Returns + ------- + np.ndarray + Rotation matrix + """ + # TODO: make this 3D? + center_Y_start, center_X_start = np.array(start_shape_zyx)[-2:] / 2 + if end_shape_zyx is None: + center_Y_end, center_X_end = (center_Y_start, center_X_start) + else: + center_Y_end, center_X_end = np.array(end_shape_zyx)[-2:] / 2 + + theta = np.radians(angle) + + rotation_matrix = np.array( + [ + [1, 0, 0, 0], + [ + 0, + np.cos(theta), + -np.sin(theta), + -center_Y_start * np.cos(theta) + + np.sin(theta) * center_X_start + + center_Y_end, + ], + [ + 0, + np.sin(theta), + np.cos(theta), + -center_Y_start * np.sin(theta) + - center_X_start * np.cos(theta) + + center_X_end, + ], + [0, 0, 0, 1], + ] + ) - for k, v in kwargs.items(): - if k in all_func_params: - func_args[k] = v - else: - non_func_args[k] = v + affine_rot_n_scale_matrix_zyx = rotation_matrix - # Write the settings into the metadata if existing - if 'extra_metadata' in non_func_args: - # For each dictionary in the nest - with open_ome_zarr(output_path, mode='r+') as output_dataset: - for params_metadata_keys in kwargs['extra_metadata'].keys(): - output_dataset.zattrs['extra_metadata'] = non_func_args['extra_metadata'] + return affine_rot_n_scale_matrix_zyx + + +def affine_transform( + zyx_data: np.ndarray, + matrix: np.ndarray, + output_shape_zyx: Tuple, + method="ants", + crop_output_slicing: bool = None, +) -> np.ndarray: + """_summary_ + + Parameters + ---------- + zyx_data : np.ndarray + 3D input array to be transformed + matrix : np.ndarray + 3D Homogenous transformation matrix + output_shape_zyx : Tuple + output target zyx shape + method : str, optional + method to use for transformation, by default 'ants' + crop_output : bool, optional + crop the output to the largest interior rectangle, by default False + + Returns + ------- + np.ndarray + registered zyx data + """ + # Convert nans to 0 + zyx_data = np.nan_to_num(zyx_data, nan=0) + + # NOTE: default set to ANTS apply_affine method until we decide we get a benefit from using cupy + # The ants method on CPU is 10x faster than scipy on CPU. Cupy method has not been bencharked vs ANTs + + if method == "ants": + # The output has to be a ANTImage Object + empty_target_array = np.zeros((output_shape_zyx), dtype=np.float32) + target_zyx_ants = ants.from_numpy(empty_target_array) + + T_ants = numpy_to_ants_transform_zyx(matrix) + + zyx_data_ants = ants.from_numpy(zyx_data.astype(np.float32)) + registered_zyx = T_ants.apply_to_image( + zyx_data_ants, reference=target_zyx_ants + ).numpy() + + elif method == "scipy": + registered_zyx = ndi.affine_transform(zyx_data, matrix, output_shape_zyx) - # Loop through (T, C), deskewing and writing as we go - if input_channel_idx is None or len(input_channel_idx) == 0: - # If C is not empty, use itertools.product with both ranges - _, C, _, _, _ = input_dataset.data.shape - iterable = itertools.product(time_indices, range(C)) - partial_apply_transform_to_zyx_and_save = partial( - apply_transform_to_zyx_and_save_v2, - func, - input_dataset, - output_path / Path(*input_data_path.parts[-3:]), - input_channel_indices=None, - **func_args, - ) else: - # If C is empty, use only the range for time_indices - iterable = itertools.product(time_indices) - partial_apply_transform_to_zyx_and_save = partial( - apply_transform_to_zyx_and_save_v2, - func, - input_dataset, - output_path / Path(*input_data_path.parts[-3:]), - input_channel_idx, - output_channel_idx, - c_idx=0, - **func_args, + raise ValueError(f"Unknown method {method}") + + # Crop the output to the largest interior rectangle + if crop_output_slicing is not None: + Z_slice, Y_slice, X_slice = crop_output_slicing + registered_zyx = registered_zyx[Z_slice, Y_slice, X_slice] + + return registered_zyx + + +def find_lir(registered_zyx: np.ndarray, plot: bool = False) -> Tuple: + # Find the lir YX + registered_yx_bool = registered_zyx[registered_zyx.shape[0] // 2].copy() + registered_yx_bool = registered_yx_bool > 0 * 1.0 + rectangle_coords_yx = lir.lir(registered_yx_bool) + + x = rectangle_coords_yx[0] + y = rectangle_coords_yx[1] + width = rectangle_coords_yx[2] + height = rectangle_coords_yx[3] + corner1_xy = (x, y) # Bottom-left corner + corner2_xy = (x + width, y) # Bottom-right corner + corner3_xy = (x + width, y + height) # Top-right corner + corner4_xy = (x, y + height) # Top-left corner + rectangle_xy = np.array((corner1_xy, corner2_xy, corner3_xy, corner4_xy)) + X_slice = slice(rectangle_xy.min(axis=0)[0], rectangle_xy.max(axis=0)[0]) + Y_slice = slice(rectangle_xy.min(axis=0)[1], rectangle_xy.max(axis=0)[1]) + + # Find the lir Z + zyx_shape = registered_zyx.shape + registered_zx_bool = registered_zyx.transpose((2, 0, 1)) > 0 + registered_zx_bool = registered_zx_bool[zyx_shape[0] // 2].copy() + rectangle_coords_zx = lir.lir(registered_zx_bool) + x = rectangle_coords_zx[0] + z = rectangle_coords_zx[1] + width = rectangle_coords_zx[2] + height = rectangle_coords_zx[3] + corner1_zx = (x, z) # Bottom-left corner + corner2_zx = (x + width, z) # Bottom-right corner + corner3_zx = (x + width, z + height) # Top-right corner + corner4_zx = (x, z + height) # Top-left corner + rectangle_zx = np.array((corner1_zx, corner2_zx, corner3_zx, corner4_zx)) + Z_slice = slice(rectangle_zx.min(axis=0)[1], rectangle_zx.max(axis=0)[1]) + + if plot: + rectangle_yx = plt.Polygon( + (corner1_xy, corner2_xy, corner3_xy, corner4_xy), + closed=True, + fill=None, + edgecolor="r", + ) + # Add the rectangle to the plot + fig, ax = plt.subplots(nrows=1, ncols=2) + ax[0].imshow(registered_yx_bool) + ax[0].add_patch(rectangle_yx) + + rectangle_zx = plt.Polygon( + (corner1_zx, corner2_zx, corner3_zx, corner4_zx), + closed=True, + fill=None, + edgecolor="r", ) + ax[1].imshow(registered_zx_bool) + ax[1].add_patch(rectangle_zx) + plt.savefig("./lir.png") - click.echo(f"\nStarting multiprocess pool with {num_processes} processes") - with mp.Pool(num_processes) as p: - p.starmap( - partial_apply_transform_to_zyx_and_save, - iterable, - ) + return (Z_slice, Y_slice, X_slice) def copy_n_paste(zyx_data: np.ndarray, zyx_slicing_params: list) -> np.ndarray: @@ -446,6 +441,87 @@ def copy_n_paste_czyx(czyx_data: np.ndarray, czyx_slicing_params: list) -> np.nd return czyx_data_sliced +def find_lir_slicing_params( + input_zyx_shape: Tuple, + target_zyx_shape: Tuple, + transformation_matrix: np.ndarray, + plot: bool = False, +) -> Tuple: + """ + Find the largest internal rectangle between the transformed input and the target + and return the cropping parameters + + Parameters + ---------- + input_zyx_shape : Tuple + shape of input array + target_zyx_shape : Tuple + shape of target array + transformation_matrix : np.ndarray + transformation matrix between input and target + + Returns + ------- + Tuple + Slicing parameters to crop LIR + + """ + print("Starting Largest interior rectangle (LIR) search") + + # Make dummy volumes + img1 = np.ones(tuple(input_zyx_shape), dtype=np.float32) + img2 = np.ones(tuple(target_zyx_shape), dtype=np.float32) + + # Conver to ants objects + target_zyx_ants = ants.from_numpy(img2.astype(np.float32)) + zyx_data_ants = ants.from_numpy(img1.astype(np.float32)) + + ants_composed_matrix = numpy_to_ants_transform_zyx(transformation_matrix) + + # Apply affine + registered_zyx = ants_composed_matrix.apply_to_image( + zyx_data_ants, reference=target_zyx_ants + ) + + Z_slice, Y_slice, X_slice = find_lir(registered_zyx.numpy(), plot=plot) + + # registered_zyx_bool = registered_zyx.numpy().copy() + # registered_zyx_bool = registered_zyx_bool > 0 + # # NOTE: we use the center of the volume as reference + # rectangle_coords_yx = lir.lir(registered_zyx_bool[registered_zyx.shape[0] // 2]) + + # # Find the overlap in XY + # x = rectangle_coords_yx[0] + # y = rectangle_coords_yx[1] + # width = rectangle_coords_yx[2] + # height = rectangle_coords_yx[3] + # corner1_xy = (x, y) # Bottom-left corner + # corner2_xy = (x + width, y) # Bottom-right corner + # corner3_xy = (x + width, y + height) # Top-right corner + # corner4_xy = (x, y + height) # Top-left corner + # rectangle_xy = np.array((corner1_xy, corner2_xy, corner3_xy, corner4_xy)) + # X_slice = slice(rectangle_xy.min(axis=0)[0], rectangle_xy.max(axis=0)[0]) + # Y_slice = slice(rectangle_xy.min(axis=0)[1], rectangle_xy.max(axis=0)[1]) + + # # Find the overlap in Z + # registered_zx = registered_zyx.numpy() + # registered_zx = registered_zx.transpose((2, 0, 1)) > 0 + # rectangle_coords_zx = lir.lir(registered_zx[registered_zyx.shape[0] // 2].copy()) + # x = rectangle_coords_zx[0] + # y = rectangle_coords_zx[1] + # width = rectangle_coords_zx[2] + # height = rectangle_coords_zx[3] + # corner1_zx = (x, y) # Bottom-left corner + # corner2_zx = (x + width, y) # Bottom-right corner + # corner3_zx = (x + width, y + height) # Top-right corner + # corner4_zx = (x, y + height) # Top-left corner + # rectangle_zx = np.array((corner1_zx, corner2_zx, corner3_zx, corner4_zx)) + # Z_slice = slice(rectangle_zx.min(axis=0)[1], rectangle_zx.max(axis=0)[1]) + + print(f"Slicing parameters Z:{Z_slice}, Y:{Y_slice}, X:{X_slice}") + return (Z_slice, Y_slice, X_slice) + + def append_channels(input_data_path: Path, target_data_path: Path) -> None: """ Append channels to a target zarr store @@ -464,9 +540,9 @@ def append_channels(input_data_path: Path, target_data_path: Path) -> None: num_channels = len(target_data_channel_names) - 1 print(f"channels in target {target_data_channel_names}") print(f"adding channels {appending_channel_names}") - for name, position in tqdm(dataset.positions(), desc='Positions'): + for name, position in tqdm(dataset.positions(), desc="Positions"): for i, appending_channel_idx in enumerate( - tqdm(appending_channel_names, desc='Channel', leave=False) + tqdm(appending_channel_names, desc="Channel", leave=False) ): position.append_channel(appending_channel_idx) position["0"][:, num_channels + i + 1] = appending_dataset[str(name)][0][:, i] @@ -474,20 +550,192 @@ def append_channels(input_data_path: Path, target_data_path: Path) -> None: appending_dataset.close() -def model_to_yaml(model, yaml_path: Path) -> None: - """ - Save a model's dictionary representation to a YAML file. - - Borrowing from recOrder==0.4.0 +def numpy_to_ants_transform_zyx(T_numpy: np.ndarray): + """Homogeneous 3D transformation matrix from numpy to ants Parameters ---------- - model : object - The model object to convert to YAML. - yaml_path : Path - The path to the output YAML file. + numpy_transform :4x4 homogenous matrix - Raises + Returns + ------- + Ants transformation matrix object + """ + assert T_numpy.shape == (4, 4) + + T_ants_style = T_numpy[:, :-1].ravel() + T_ants_style[-3:] = T_numpy[:3, -1] + T_ants = ants.new_ants_transform( + transform_type="AffineTransform", + ) + T_ants.set_parameters(T_ants_style) + + return T_ants + + +def ants_to_numpy_transform_zyx(T_ants): + """ + Convert the ants transformation matrix to numpy 3D homogenous transform + + Modified from Jordao's dexp code + + Parameters + ---------- + T_ants : Ants transfromation matrix object + + Returns + ------- + np.array + Converted Ants to numpy array + + """ + + T_numpy = T_ants.parameters.reshape((3, 4), order="F") + T_numpy[:, :3] = T_numpy[:, :3].transpose() + T_numpy = np.vstack((T_numpy, np.array([0, 0, 0, 1]))) + + # Reference: + # https://sourceforge.net/p/advants/discussion/840261/thread/9fbbaab7/ + # https://github.com/netstim/leaddbs/blob/a2bb3e663cf7fceb2067ac887866124be54aca7d/helpers/ea_antsmat2mat.m + # T = original translation offset from A + # T = T + (I - A) @ centering + + T_numpy[:3, -1] += (np.eye(3) - T_numpy[:3, :3]) @ T_ants.fixed_parameters + + return T_numpy + + +def apply_stabilization_over_time_ants( + list_of_shifts: list, + input_data_path: Path, + output_path: Path, + time_indices: list = [0], + input_channel_idx: list = [], + output_channel_idx: list = [], + num_processes: int = 1, + **kwargs, +) -> None: + """Apply stabilization over time""" + # Function to be applied + # Get the reader and writer + click.echo(f"Input data path:\t{input_data_path}") + click.echo(f"Output data path:\t{str(output_path)}") + input_dataset = open_ome_zarr(str(input_data_path)) + stdout_buffer = io.StringIO() + with contextlib.redirect_stdout(stdout_buffer): + input_dataset.print_tree() + click.echo(f" Input data tree: {stdout_buffer.getvalue()}") + + T, C, _, _, _ = input_dataset.data.shape + + # Write the settings into the metadata if existing + # TODO: alternatively we can throw all extra arguments as metadata. + if "extra_metadata" in kwargs: + # For each dictionary in the nest + with open_ome_zarr(output_path, mode="r+") as output_dataset: + for params_metadata_keys in kwargs["extra_metadata"].keys(): + output_dataset.zattrs["extra_metadata"] = kwargs["extra_metadata"] + + # Loop through (T, C), deskewing and writing as we go + click.echo(f"\nStarting multiprocess pool with {num_processes} processes") + + if input_channel_idx is None or len(input_channel_idx) == 0: + # If C is not empty, use itertools.product with both ranges + _, C, _, _, _ = input_dataset.data.shape + iterable = itertools.product(time_indices, range(C)) + partial_stabilization_over_time_ants = partial( + stabilization_over_time_ants, + input_dataset, + output_path / Path(*input_data_path.parts[-3:]), + list_of_shifts, + None, + None, + ) + else: + # If C is empty, use only the range for time_indices + iterable = itertools.product(time_indices) + partial_stabilization_over_time_ants = partial( + stabilization_over_time_ants, + input_dataset, + output_path / Path(*input_data_path.parts[-3:]), + list_of_shifts, + input_channel_idx, + output_channel_idx, + c_idx=0, + ) + + with mp.Pool(num_processes) as p: + p.starmap( + partial_stabilization_over_time_ants, + iterable, + ) + input_dataset.close() + + +def stabilization_over_time_ants( + position: Position, + output_path: Path, + list_of_shifts: np.ndarray, + input_channel_idx: list, + output_channel_idx: list, + t_idx: int, + c_idx: int, + **kwargs, +) -> None: + """Load a zyx array from a Position object, apply a transformation and save the result to file""" + + click.echo(f"Processing c={c_idx}, t={t_idx}") + tx_shifts = numpy_to_ants_transform_zyx(list_of_shifts[t_idx]) + + # Process CZYX vs ZYX + if input_channel_idx is not None: + czyx_data = position.data.oindex[t_idx, input_channel_idx] + if not _check_nan_n_zeros(czyx_data): + for c in input_channel_idx: + print(f"czyx_data.shape {czyx_data.shape}") + zyx_data_ants = ants.from_numpy(czyx_data[0]) + registered_zyx = tx_shifts.apply_to_image( + zyx_data_ants, reference=zyx_data_ants + ) + # Write to file + with open_ome_zarr(output_path, mode="r+") as output_dataset: + output_dataset[0].oindex[ + t_idx, output_channel_idx + ] = registered_zyx.numpy() + click.echo(f"Finished Writing.. t={t_idx}") + else: + click.echo(f"Skipping t={t_idx} due to all zeros or nans") + else: + zyx_data = position.data.oindex[t_idx, c_idx] + # Checking if nans or zeros and skip processing + if not _check_nan_n_zeros(zyx_data): + zyx_data_ants = ants.from_numpy(zyx_data) + # Apply transformation + registered_zyx = tx_shifts.apply_to_image(zyx_data_ants, reference=zyx_data_ants) + + # Write to file + with open_ome_zarr(output_path, mode="r+") as output_dataset: + output_dataset[0][t_idx, c_idx] = registered_zyx.numpy() + + click.echo(f"Finished Writing.. c={c_idx}, t={t_idx}") + else: + click.echo(f"Skipping c={c_idx}, t={t_idx} due to all zeros or nans") + + +def model_to_yaml(model, yaml_path: Path) -> None: + """ + Save a model's dictionary representation to a YAML file. + + Borrowing from recOrder==0.4.0 + + Parameters + ---------- + model : object + The model object to convert to YAML. + yaml_path : Path + The path to the output YAML file. + + Raises ------ TypeError If the `model` object does not have a `dict()` method. @@ -569,6 +817,87 @@ def yaml_to_model(yaml_path: Path, model): return model(**raw_settings) +# TODO: convert all code to use this function from now on +def create_empty_hcs_zarr( + store_path: Path, + position_keys: list[Tuple[str]], + channel_names: list[str], + shape: Tuple[int], + chunks: Tuple[int] = None, + scale: Tuple[float] = (1, 1, 1, 1, 1), + dtype: DTypeLike = np.float32, + max_chunk_size_bytes=500e6, +) -> None: + """ + If the plate does not exist, create an empty zarr plate. + If the plate exists, append positions and channels if they are not + already in the plate. + Parameters + ---------- + store_path : Path + hcs plate path + position_keys : list[Tuple[str]] + Position keys, will append if not present in the plate. + e.g. [("A", "1", "0"), ("A", "1", "1")] + shape : Tuple[int] + chunks : Tuple[int] + scale : Tuple[float] + channel_names : list[str] + Channel names, will append if not present in metadata. + dtype : DTypeLike + + Modifying from recOrder + https://github.com/mehta-lab/recOrder/blob/d31ad910abf84c65ba927e34561f916651cbb3e8/recOrder/cli/utils.py#L12 + """ + MAX_CHUNK_SIZE = max_chunk_size_bytes # in bytes + bytes_per_pixel = np.dtype(dtype).itemsize + + # Limiting the chunking to 500MB + if chunks is None: + chunk_zyx_shape = list(shape[-3:]) + # chunk_zyx_shape[-3] > 1 ensures while loop will not stall if single + # XY image is larger than MAX_CHUNK_SIZE + while ( + chunk_zyx_shape[-3] > 1 + and np.prod(chunk_zyx_shape) * bytes_per_pixel > MAX_CHUNK_SIZE + ): + chunk_zyx_shape[-3] = np.ceil(chunk_zyx_shape[-3] / 2).astype(int) + chunk_zyx_shape = tuple(chunk_zyx_shape) + + chunks = 2 * (1,) + chunk_zyx_shape + click.echo(f"Chunk size: {chunks}") + + # Create plate + output_plate = open_ome_zarr( + str(store_path), layout="hcs", mode="a", channel_names=channel_names + ) + + # Create positions + for position_key in position_keys: + position_key_string = "/".join(position_key) + # Check if position is already in the store, if not create it + if position_key_string not in output_plate.zgroup: + position = output_plate.create_position(*position_key) + _ = position.create_zeros( + name="0", + shape=shape, + chunks=chunks, + dtype=dtype, + transform=[TransformationMeta(type="scale", scale=scale)], + ) + else: + position = output_plate[position_key_string] + + # Check if channel_names are already in the store, if not append them + for channel_name in channel_names: + # Read channel names directly from metadata to avoid race conditions + metadata_channel_names = [ + channel.label for channel in position.metadata.omero.channels + ] + if channel_name not in metadata_channel_names: + position.append_channel(channel_name, resize_arrays=True) + + def _is_nested(lst): return any(isinstance(i, list) for i in lst) or any(isinstance(i, str) for i in lst) @@ -599,3 +928,302 @@ def _check_nan_n_zeros(input_array): # Return false return False + + +def nuc_mem_segmentation(czyx_data, **cellpose_kwargs) -> np.ndarray: + """Segment nuclei and membranes using cellpose""" + + from cellpose import models + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Get the key/values under this dictionary + # cellpose_params = cellpose_params.get('cellpose_params', {}) + cellpose_params = cellpose_kwargs["cellpose_kwargs"] + Z_center_slice = slice(int(cellpose_params["z_idx"]), int(cellpose_params["z_idx"]) + 1) + Z_slice = slice(int(cellpose_params["z_idx"]) - 3, int(cellpose_params["z_idx"]) + 3) + C, Z, Y, X = czyx_data.shape + + czyx_data[0][czyx_data[0] < 0] = 0 + czyx_data[1][czyx_data[1] < 0] = 0 + + czyx_data_mip = np.zeros((C, 1, Y, X)) + for c in range(C): + czyx_data_mip[c, 0] = np.max(czyx_data[c, Z_slice], axis=0) + cyx_data = czyx_data_mip[:, 0] + + if "nucleus_segmentation" in cellpose_params: + nuc_seg_kwargs = cellpose_params["nucleus_segmentation"] + if "membrane_segmentation" in cellpose_params: + mem_seg_kwargs = cellpose_params["membrane_segmentation"] + + # Initialize Cellpose models + cyto_model = models.Cellpose(gpu=True, model_type=cellpose_params["mem_model_path"]) + nuc_model = models.CellposeModel( + model_type=cellpose_params["nuc_model_path"], device=torch.device(device) + ) + + nuc_masks = nuc_model.eval(cyx_data[0], **nuc_seg_kwargs)[0] + mem_masks, _, _, _ = cyto_model.eval(cyx_data[1], **mem_seg_kwargs) + + # Save + segmentation_stack = np.zeros_like(czyx_data_mip) + zyx_mask = np.stack((nuc_masks, mem_masks)) + segmentation_stack[:, 0:1] = zyx_mask[:, np.newaxis] + + return segmentation_stack + + +## NOTE WIP +def apply_transform_to_zyx_and_save_v2( + func, + position: Position, + output_path: Path, + input_channel_indices: list[int], + output_channel_indices: list[int], + t_idx: int, + t_idx_out: int, + c_idx: int = None, + **kwargs, +) -> None: + """Load a zyx array from a Position object, apply a transformation to CZYX or ZYX and save the result to file""" + + # TODO: temporary fix to slumkit issue + if _is_nested(input_channel_indices): + input_channel_indices = [int(x) for x in input_channel_indices if x.isdigit()] + if _is_nested(output_channel_indices): + output_channel_indices = [int(x) for x in output_channel_indices if x.isdigit()] + click.echo(f"input_channel_indices: {input_channel_indices}") + + # Process CZYX vs ZYX + if input_channel_indices is not None: + click.echo(f"Processing t={t_idx}") + + czyx_data = position.data.oindex[t_idx, input_channel_indices] + if not _check_nan_n_zeros(czyx_data): + transformed_czyx = func(czyx_data, **kwargs) + # Write to file + with open_ome_zarr(output_path, mode="r+") as output_dataset: + output_dataset[0].oindex[t_idx_out, output_channel_indices] = transformed_czyx + click.echo(f"Finished Writing.. t={t_idx}") + else: + click.echo(f"Skipping t={t_idx} due to all zeros or nans") + else: + click.echo(f"Processing c={c_idx}, t={t_idx}") + + zyx_data = position.data.oindex[t_idx, c_idx] + # Checking if nans or zeros and skip processing + if not _check_nan_n_zeros(zyx_data): + # Apply transformation + transformed_zyx = func(zyx_data, **kwargs) + + # Write to file + with open_ome_zarr(output_path, mode="r+") as output_dataset: + output_dataset[0][t_idx_out, c_idx] = transformed_zyx + + click.echo(f"Finished Writing.. c={c_idx}, t={t_idx}") + else: + click.echo(f"Skipping c={c_idx}, t={t_idx} due to all zeros or nans") + + +# TODO: modifiy how we get the time and channesl like recOrder (isinstance(input, list) or instance(input,int) or all) +def process_single_position_v2( + func, + input_data_path: Path, + output_path: Path, + time_indices: list = [0], + time_indices_out: list = [0], + input_channel_idx: list = [], + output_channel_idx: list = [], + num_processes: int = mp.cpu_count(), + **kwargs, +) -> None: + """Register a single position with multiprocessing parallelization over T and C""" + # Function to be applied + click.echo(f"Function to be applied: \t{func}") + + # Get the reader and writer + click.echo(f"Input data path:\t{input_data_path}") + click.echo(f"Output data path:\t{str(output_path)}") + input_dataset = open_ome_zarr(str(input_data_path)) + stdout_buffer = io.StringIO() + with contextlib.redirect_stdout(stdout_buffer): + input_dataset.print_tree() + click.echo(f" Input data tree: {stdout_buffer.getvalue()}") + + # Find time indices + if time_indices == "all": + time_indices = range(input_dataset.data.shape[0]) + time_indices_out = time_indices + elif isinstance(time_indices, list): + time_indices_out = range(len(time_indices)) + + # Check for invalid times + time_ubound = input_dataset.data.shape[0] - 1 + if np.max(time_indices) > time_ubound: + raise ValueError( + f"time_indices = {time_indices} includes a time index beyond the maximum index of the dataset = {time_ubound}" + ) + + # Check the arguments for the function + all_func_params = inspect.signature(func).parameters.keys() + # Extract the relevant kwargs for the function 'func' + func_args = {} + non_func_args = {} + + for k, v in kwargs.items(): + if k in all_func_params: + func_args[k] = v + else: + non_func_args[k] = v + + # Write the settings into the metadata if existing + if "extra_metadata" in non_func_args: + # For each dictionary in the nest + with open_ome_zarr(output_path, mode="r+") as output_dataset: + for params_metadata_keys in kwargs["extra_metadata"].keys(): + output_dataset.zattrs["extra_metadata"] = non_func_args["extra_metadata"] + + # Loop through (T, C), deskewing and writing as we go + click.echo(f"\nStarting multiprocess pool with {num_processes} processes") + + if input_channel_idx is None or len(input_channel_idx) == 0: + # If C is not empty, use itertools.product with both ranges + _, C, _, _, _ = input_dataset.data.shape + iterable = [ + (time_idx, time_idx_out, c) + for (time_idx, time_idx_out), c in itertools.product( + zip(time_indices, time_indices_out), range(C) + ) + ] + partial_apply_transform_to_zyx_and_save = partial( + apply_transform_to_zyx_and_save_v2, + func, + input_dataset, + output_path / Path(*input_data_path.parts[-3:]), + input_channel_indices=None, + **func_args, + ) + else: + # If C is empty, use only the range for time_indices + iterable = list(zip(time_indices, time_indices_out)) + partial_apply_transform_to_zyx_and_save = partial( + apply_transform_to_zyx_and_save_v2, + func, + input_dataset, + output_path / Path(*input_data_path.parts[-3:]), + input_channel_idx, + output_channel_idx, + c_idx=0, + **func_args, + ) + with mp.Pool(num_processes) as p: + p.starmap( + partial_apply_transform_to_zyx_and_save, + iterable, + ) + + +def nuc_mem_segmentation_3D(czyx_data, zyx_slicing, **cellpose_kwargs): + from cellpose import models + from skimage.exposure import rescale_intensity + from skimage.util import invert + + Z_slice = zyx_slicing[0] + Y_slice = zyx_slicing[1] + X_slice = zyx_slicing[2] + czyx_data = czyx_data[:, Z_slice, Y_slice, X_slice] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + segmentation_stack = np.zeros_like(czyx_data) + click.echo(f'Segmentation Stack shape {segmentation_stack.shape}') + cellpose_params = cellpose_kwargs["cellpose_kwargs"] + c_idx = 0 + if "nucleus_kwargs" in cellpose_params: + click.echo('Segmenting Nuclei') + nuc_seg_kwargs = cellpose_params["nucleus_kwargs"] + + model_nucleus_3D = models.CellposeModel( + pretrained_model=cellpose_params["nuc_model_path"], + # net_avg=True, #Note removed CP3.0 + gpu=True, + device=torch.device(device), + ) + nuc_segmentation, _, _ = model_nucleus_3D.eval(czyx_data, **nuc_seg_kwargs) + segmentation_stack[c_idx] = nuc_segmentation.astype(np.uint16) + c_idx += 1 + if "membrane_kwargs" in cellpose_params: + click.echo('Segmenting Membrane') + mem_seg_kwargs = cellpose_params["membrane_kwargs"] + + model_membrane_3D = models.CellposeModel( + pretrained_model=cellpose_params["mem_model_path"], + # net_avg=True, + gpu=True, + device=torch.device(device), + ) + c_idx_mem, c_idx_nuc = mem_seg_kwargs['channels'] + czyx_data[c_idx_mem] = rescale_intensity( + invert(czyx_data[c_idx_mem]), out_range='uint16' + ) + czyx_data[c_idx_nuc] = rescale_intensity( + czyx_data[c_idx_nuc], + out_range=(np.min(czyx_data[c_idx_mem]), np.max(czyx_data[c_idx_mem])), + ) + + mem_segmentation, _, _ = model_membrane_3D.eval(czyx_data, **mem_seg_kwargs) + segmentation_stack[c_idx] = mem_segmentation.astype(np.uint16) + + return segmentation_stack + + +def _assign_available_gpu(): + """ + Assign an available GPU if there is one. + """ + import tensorflow as tf + + # Get the list of available GPUs + available_gpus = ( + os.popen("nvidia-smi --query-gpu=index --format=csv,noheader").read().split("\n")[:-1] + ) + gpus = tf.config.experimental.list_physical_devices("GPU") + if gpus: + try: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + except RuntimeError as e: + print(e) + + # Check if there are any available GPUs + if available_gpus: + # Assign the first available GPU + os.environ["CUDA_VISIBLE_DEVICES"] = available_gpus[0] + print(f"GPU {available_gpus[0]} assigned.") + else: + print("No available GPUs.") + + +def denoise_nuc_mem( + czyx_data: np.ndarray, + model_path: Path, +) -> np.ndarray: + from n2v.models import N2V + + _assign_available_gpu() + + # NOTE: this assumes channel_order = ['nuc', 'mem'] + C, Z, _, _ = czyx_data.shape + model_path = Path(model_path) + basedir_path = model_path.parent + model_nuc = N2V(config=None, name=model_path.parts[-1], basedir=basedir_path) + + z_stack = [] + for z_idx in range(Z): + yx_data_nuc = model_nuc.predict(czyx_data[0, z_idx], axes="YX") + z_stack.append(yx_data_nuc) + z_stack = np.stack(z_stack, axis=0) + z_stack = z_stack[np.newaxis] + print(z_stack.shape) + return z_stack diff --git a/pyproject.toml b/pyproject.toml index 2672f689..cd3d7f55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "waveorder @ git+https://github.com/mehta-lab/waveorder", "largestinteriorrectangle", "antspyx", + "colorspacious", ]