diff --git a/examples/slurmkit_example/slurm_apply_affine.py b/examples/slurmkit_example/slurm_apply_affine.py index dc7ecef0..88064270 100644 --- a/examples/slurmkit_example/slurm_apply_affine.py +++ b/examples/slurmkit_example/slurm_apply_affine.py @@ -1,68 +1,147 @@ import datetime -import os import glob -from mantis.cli import utils -from slurmkit import SlurmParams, slurm_function, submit_function -from natsort import natsorted +import os + +from pathlib import Path + import click -from mantis.cli.apply_affine import registration_params_from_file, rotate_n_affine_transform import numpy as np + from iohub import open_ome_zarr +from natsort import natsorted +from slurmkit import SlurmParams, slurm_function, submit_function +from mantis.analysis.AnalysisSettings import RegistrationSettings +from mantis.analysis.register import apply_affine_transform, find_overlapping_volume +from mantis.cli.apply_affine import rescale_voxel_size +from mantis.cli.utils import ( + copy_n_paste_czyx, + create_empty_hcs_zarr, + process_single_position_v2, + yaml_to_model, +) # io parameters -labelfree_data_paths = '/hpc/projects/comp.micro/mantis/2023_08_09_HEK_PCNA_H2B/2-phase3D/pcna_rac1_virtual_staining_b1_redo_1/phase3D.zarr/0/0/0' -lightsheet_data_paths = '/hpc/projects/comp.micro/mantis/2023_08_09_HEK_PCNA_H2B/1-deskew/pcna_rac1_virtual_staining_b1_redo_1/deskewed.zarr/0/0/0' -output_data_path = './registered_output.zarr' -registration_param_path = './register.yml' +source_position_dirpaths = '/input_source.zarr/*/*/*' +target_position_dirpaths = '/input_target.zarr/*/*/*' +config_filepath = ( + '../mantis/analysis/settings/example_apply_affine_settings.yml' +) +output_dirpath = './test_output.zarr' # sbatch and resource parameters -cpus_per_task = 16 +cpus_per_task = 4 mem_per_cpu = "16G" -time = 40 # minutes -simultaneous_processes_per_node = 5 - -# path handling -labelfree_data_paths = natsorted(glob.glob(labelfree_data_paths)) -lightsheet_data_paths = natsorted(glob.glob(lightsheet_data_paths)) -output_dir = os.path.dirname(output_data_path) -output_paths = utils.get_output_paths(labelfree_data_paths, output_data_path) -click.echo(f"in: {labelfree_data_paths}, out: {output_paths}") -slurm_out_path = str(os.path.join(output_dir, "slurm_output/register-%j.out")) - -# Additional registraion arguments +time = 60 # minutes +partition = 'cpu' +simultaneous_processes_per_node = ( + 8 # number of processes that are run in parallel on a single node +) + +# NOTE: parameters from here and below should not have to be changed +source_position_dirpaths = [ + Path(path) for path in natsorted(glob.glob(source_position_dirpaths)) +] +target_position_dirpaths = [ + Path(path) for path in natsorted(glob.glob(target_position_dirpaths)) +] +output_dirpath = Path(output_dirpath) +config_filepath = Path(config_filepath) + +click.echo(f"in_path: {source_position_dirpaths[0]}, out_path: {output_dirpath}") +slurm_out_path = output_dirpath.parent / "slurm_output" / "register-%j.out" + # Parse from the yaml file -settings = registration_params_from_file(registration_param_path) +settings = yaml_to_model(config_filepath, RegistrationSettings) matrix = np.array(settings.affine_transform_zyx) -output_shape_zyx = tuple(settings.output_shape_zyx) +keep_overhang = settings.keep_overhang + +# Calculate the output voxel size from the input scale and affine transform +with open_ome_zarr(source_position_dirpaths[0]) as source_dataset: + T, C, Z, Y, X = source_dataset.data.shape + source_channel_names = source_dataset.channel_names + 
source_shape_zyx = source_dataset.data.shape[-3:] + source_voxel_size = source_dataset.scale[-3:] + output_voxel_size = rescale_voxel_size(matrix[:3, :3], source_voxel_size) + +with open_ome_zarr(target_position_dirpaths[0]) as target_dataset: + target_channel_names = target_dataset.channel_names + Z_target, Y_target, X_target = target_dataset.data.shape[-3:] + target_shape_zyx = target_dataset.data.shape[-3:] -# Get the output voxel_size -with open_ome_zarr(lightsheet_data_paths[0]) as light_sheet_position: - voxel_size = tuple(light_sheet_position.scale[-3:]) +click.echo('\nREGISTRATION PARAMETERS:') +click.echo(f'Transformation matrix:\n{matrix}') +click.echo(f'Voxel size: {output_voxel_size}') +# Logic to parse time indices +if settings.time_indices == "all": + time_indices = list(range(T)) +elif isinstance(settings.time_indices, list): + time_indices = settings.time_indices +elif isinstance(settings.time_indices, int): + time_indices = [settings.time_indices] + +output_channel_names = target_channel_names +if target_position_dirpaths != source_position_dirpaths: + output_channel_names += source_channel_names + +if not keep_overhang: + # Find the largest interior rectangle + click.echo('\nFinding largest overlapping volume between source and target datasets') + Z_slice, Y_slice, X_slice = find_overlapping_volume( + source_shape_zyx, target_shape_zyx, matrix + ) + # TODO: start or stop may be None + cropped_target_shape_zyx = ( + Z_slice.stop - Z_slice.start, + Y_slice.stop - Y_slice.start, + X_slice.stop - X_slice.start, + ) + # Overwrite the previous target shape + Z_target, Y_target, X_target = cropped_target_shape_zyx[-3:] + click.echo(f'Shape of cropped output dataset: {target_shape_zyx}\n') +else: + Z_slice, Y_slice, X_slice = ( + slice(0, Z_target), + slice(0, Y_target), + slice(0, X_target), + ) + +output_metadata = { + "shape": (len(time_indices), len(output_channel_names), Z_target, Y_target, X_target), + "chunks": None, + "scale": (1,) * 2 + tuple(output_voxel_size), + "channel_names": output_channel_names, + "dtype": np.float32, +} + +# Create the output zarr mirroring source_position_dirpaths +create_empty_hcs_zarr( + store_path=output_dirpath, + position_keys=[p.parts[-3:] for p in source_position_dirpaths], + **output_metadata, +) + +# Get the affine transformation matrix +# NOTE: add any extra metadata if needed: extra_metadata = { - 'registration': { - 'affine_matrix': matrix.tolist(), - 'pre_affine_90degree_rotations_about_z': settings.pre_affine_90degree_rotations_about_z, + 'affine_transformation': { + 'transform_matrix': matrix.tolist(), } } + affine_transform_args = { 'matrix': matrix, - 'output_shape_zyx': settings.output_shape_zyx, - 'pre_affine_90degree_rotations_about_z': settings.pre_affine_90degree_rotations_about_z, + 'output_shape_zyx': target_shape_zyx, # NOTE: this is the shape of the original target dataset + 'crop_output_slicing': ([Z_slice, Y_slice, X_slice] if not keep_overhang else None), 'extra_metadata': extra_metadata, } -utils.create_empty_zarr( - position_paths=labelfree_data_paths, - output_path=output_data_path, - output_zyx_shape=output_shape_zyx, - chunk_zyx_shape=None, - voxel_size=voxel_size, -) + +copy_n_paste_kwargs = {"czyx_slicing_params": ([Z_slice, Y_slice, X_slice])} # prepare slurm parameters params = SlurmParams( - partition="cpu", + partition=partition, cpus_per_task=cpus_per_task, mem_per_cpu=mem_per_cpu, time=datetime.timedelta(minutes=time), @@ -70,20 +149,46 @@ ) # wrap our utils.process_single_position() function with 
slurmkit -slurm_process_single_position = slurm_function(utils.process_single_position) +slurm_process_single_position = slurm_function(process_single_position_v2) register_func = slurm_process_single_position( - func=rotate_n_affine_transform, + func=apply_affine_transform, + output_path=output_dirpath, + time_indices=time_indices, num_processes=simultaneous_processes_per_node, **affine_transform_args, ) -# generate an array of jobs by passing the in_path and out_path to slurm wrapped function -register_jobs = [ - submit_function( - register_func, - slurm_params=params, - input_data_path=in_path, - output_path=out_path, - ) - for in_path, out_path in zip(labelfree_data_paths, output_paths) -] +copy_n_paste_func = slurm_process_single_position( + func=copy_n_paste_czyx, + output_path=output_dirpath, + time_indices=time_indices, + num_processes=simultaneous_processes_per_node, + **copy_n_paste_kwargs, +) + +# NOTE: channels will not be processed in parallel +# NOTE: the the source and target datastores may be the same (e.g. Hummingbird datasets) +# apply affine transform to channels in the source datastore that should be registered +# as given in the config file (i.e. settings.source_channel_names) +for input_position_path in source_position_dirpaths: + for channel_name in source_channel_names: + if channel_name in settings.source_channel_names: + submit_function( + register_func, + slurm_params=params, + input_data_path=input_position_path, + input_channel_idx=[source_channel_names.index(channel_name)], + output_channel_idx=[output_channel_names.index(channel_name)], + ) + +# Copy over the channels that were not processed +for input_position_path in target_position_dirpaths: + for channel_name in target_channel_names: + if channel_name not in settings.source_channel_names: + submit_function( + copy_n_paste_func, + slurm_params=params, + input_data_path=input_position_path, + input_channel_idx=[target_channel_names.index(channel_name)], + output_channel_idx=[output_channel_names.index(channel_name)], + ) diff --git a/mantis/analysis/AnalysisSettings.py b/mantis/analysis/AnalysisSettings.py index 1ff58df8..a71428f5 100644 --- a/mantis/analysis/AnalysisSettings.py +++ b/mantis/analysis/AnalysisSettings.py @@ -1,20 +1,21 @@ -from typing import Optional +from typing import Literal, Optional, Union import numpy as np -from pydantic import ConfigDict, PositiveFloat, PositiveInt, validator -from pydantic.dataclasses import dataclass +from pydantic import BaseModel, Extra, NonNegativeInt, PositiveFloat, PositiveInt, validator -config = ConfigDict(extra="forbid") +# All settings classes inherit from MyBaseModel, which forbids extra parameters to guard against typos +class MyBaseModel(BaseModel, extra=Extra.forbid): + pass -@dataclass(config=config) -class DeskewSettings: + +class DeskewSettings(MyBaseModel): pixel_size_um: PositiveFloat ls_angle_deg: PositiveFloat px_to_scan_ratio: Optional[PositiveFloat] = None scan_step_um: Optional[PositiveFloat] = None - keep_overhang: bool = True + keep_overhang: bool = False average_n_slices: PositiveInt = 3 @validator("ls_angle_deg") @@ -28,19 +29,25 @@ def px_to_scan_ratio_check(cls, v): if v is not None: return round(float(v), 3) - def __post_init__(self): - if self.px_to_scan_ratio is None: - if self.scan_step_um is not None: - self.px_to_scan_ratio = round(self.pixel_size_um / self.scan_step_um, 3) + def __init__(self, **data): + if data.get("px_to_scan_ratio") is None: + if data.get("scan_step_um") is not None: + data["px_to_scan_ratio"] = round( + 
data["pixel_size_um"] / data["scan_step_um"], 3 + ) else: - raise TypeError("px_to_scan_ratio is not valid") + raise ValueError( + "If px_to_scan_ratio is not provided, both pixel_size_um and scan_step_um must be provided" + ) + super().__init__(**data) -@dataclass(config=config) -class RegistrationSettings: +class RegistrationSettings(MyBaseModel): + source_channel_names: list[str] + target_channel_name: str affine_transform_zyx: list - output_shape_zyx: list - pre_affine_90degree_rotations_about_z: Optional[int] = 1 + keep_overhang: bool = False + time_indices: Union[NonNegativeInt, list[NonNegativeInt], Literal["all"]] = "all" @validator("affine_transform_zyx") def check_affine_transform(cls, v): @@ -60,9 +67,3 @@ def check_affine_transform(cls, v): raise ValueError("The array must contain valid numerical values.") return v - - @validator("output_shape_zyx") - def check_output_shape_zyx(cls, v): - if not isinstance(v, list) or len(v) != 3: - raise ValueError("The output shape zyx must be a list of length 3.") - return v diff --git a/mantis/analysis/register.py b/mantis/analysis/register.py new file mode 100644 index 00000000..3db569b8 --- /dev/null +++ b/mantis/analysis/register.py @@ -0,0 +1,357 @@ +from typing import Tuple + +import ants +import largestinteriorrectangle as lir +import matplotlib.pyplot as plt +import numpy as np +import scipy.ndimage + + +def get_3D_rescaling_matrix(start_shape_zyx, scaling_factor_zyx=(1, 1, 1), end_shape_zyx=None): + center_Y_start, center_X_start = np.array(start_shape_zyx)[-2:] / 2 + if end_shape_zyx is None: + center_Y_end, center_X_end = (center_Y_start, center_X_start) + else: + center_Y_end, center_X_end = np.array(end_shape_zyx)[-2:] / 2 + + scaling_matrix = np.array( + [ + [scaling_factor_zyx[-3], 0, 0, 0], + [ + 0, + scaling_factor_zyx[-2], + 0, + -center_Y_start * scaling_factor_zyx[-2] + center_Y_end, + ], + [ + 0, + 0, + scaling_factor_zyx[-1], + -center_X_start * scaling_factor_zyx[-1] + center_X_end, + ], + [0, 0, 0, 1], + ] + ) + return scaling_matrix + + +def get_3D_rotation_matrix( + start_shape_zyx: Tuple, angle: float = 0.0, end_shape_zyx: Tuple = None +) -> np.ndarray: + """ + Rotate Transformation Matrix + + Parameters + ---------- + start_shape_zyx : Tuple + Shape of the input + angle : float, optional + Angles of rotation in degrees + end_shape_zyx : Tuple, optional + Shape of output space + + Returns + ------- + np.ndarray + Rotation matrix + """ + # TODO: make this 3D? 
+ center_Y_start, center_X_start = np.array(start_shape_zyx)[-2:] / 2 + if end_shape_zyx is None: + center_Y_end, center_X_end = (center_Y_start, center_X_start) + else: + center_Y_end, center_X_end = np.array(end_shape_zyx)[-2:] / 2 + + theta = np.radians(angle) + + rotation_matrix = np.array( + [ + [1, 0, 0, 0], + [ + 0, + np.cos(theta), + -np.sin(theta), + -center_Y_start * np.cos(theta) + + np.sin(theta) * center_X_start + + center_Y_end, + ], + [ + 0, + np.sin(theta), + np.cos(theta), + -center_Y_start * np.sin(theta) + - center_X_start * np.cos(theta) + + center_X_end, + ], + [0, 0, 0, 1], + ] + ) + return rotation_matrix + + +def convert_transform_to_ants(T_numpy: np.ndarray): + """Homogeneous 3D transformation matrix from numpy to ants + + Parameters + ---------- + T_numpy : 4x4 homogeneous transformation matrix + + Returns + ------- + Ants transformation matrix object + """ + assert T_numpy.shape == (4, 4) + + T_ants_style = T_numpy[:, :-1].ravel() + T_ants_style[-3:] = T_numpy[:3, -1] + T_ants = ants.new_ants_transform( + transform_type='AffineTransform', + ) + T_ants.set_parameters(T_ants_style) + + return T_ants + + +# def numpy_to_ants_transform_czyx(T_numpy: np.ndarray): +# """Homogeneous 3D transformation matrix from numpy to ants + +# Parameters +# ---------- +# T_numpy : 5x5 homogeneous transformation matrix + +# Returns +# ------- +# Ants transformation matrix object +# """ +# assert T_numpy.shape == (5, 5) +# shape = T_numpy.shape +# T_ants_style = T_numpy[:, :-1].ravel() +# T_ants_style[-shape[0] + 1 :] = T_numpy[-shape[0] : -1, -1] +# T_ants = ants.new_ants_transform( +# transform_type='AffineTransform', +# ) +# T_ants.set_parameters(T_ants_style) + +# return T_ants + + +def convert_transform_to_numpy(T_ants): + """ + Convert the ants transformation matrix to a numpy 3D homogeneous transform + + Modified from Jordao's dexp code + + Parameters + ---------- + T_ants : Ants transformation matrix object + + Returns + ------- + np.array + 4x4 homogeneous transformation matrix as a numpy array + + """ + + T_numpy = T_ants.parameters.reshape((3, 4), order="F") + T_numpy[:, :3] = T_numpy[:, :3].transpose() + T_numpy = np.vstack((T_numpy, np.array([0, 0, 0, 1]))) + + # Reference: + # https://sourceforge.net/p/advants/discussion/840261/thread/9fbbaab7/ + # https://github.com/netstim/leaddbs/blob/a2bb3e663cf7fceb2067ac887866124be54aca7d/helpers/ea_antsmat2mat.m + # T = original translation offset from A + # T = T + (I - A) @ centering + + T_numpy[:3, -1] += (np.eye(3) - T_numpy[:3, :3]) @ T_ants.fixed_parameters + + return T_numpy + + +def apply_affine_transform( + zyx_data: np.ndarray, + matrix: np.ndarray, + output_shape_zyx: Tuple, + method='ants', + crop_output_slicing: Tuple = None, +) -> np.ndarray: + """Apply a homogeneous affine transform to a 3D volume + + Parameters + ---------- + zyx_data : np.ndarray + 3D input array to be transformed + matrix : np.ndarray + 3D homogeneous transformation matrix + output_shape_zyx : Tuple + output target zyx shape + method : str, optional + method to use for transformation, by default 'ants' + crop_output_slicing : Tuple, optional + ZYX slices used to crop the output, e.g. to the largest interior rectangle, by default None + + Returns + ------- + np.ndarray + registered zyx data + """ + + Z, Y, X = output_shape_zyx + if crop_output_slicing is not None: + Z_slice, Y_slice, X_slice = crop_output_slicing + Z = Z_slice.stop - Z_slice.start + Y = Y_slice.stop - Y_slice.start + X = X_slice.stop - X_slice.start + + # TODO: based on the signature of this function, it should not be called on a 4D array + if zyx_data.ndim == 4: + registered_czyx = 
np.zeros((zyx_data.shape[0], Z, Y, X), dtype=np.float32) + for c in range(zyx_data.shape[0]): + registered_czyx[c] = apply_affine_transform( + zyx_data[c], + matrix, + output_shape_zyx, + method, + crop_output_slicing, + ) + return registered_czyx + else: + # Convert nans to 0 + zyx_data = np.nan_to_num(zyx_data, nan=0) + + # NOTE: default set to ANTS apply_affine method until we decide we get a benefit from using cupy + # The ants method on CPU is 10x faster than scipy on CPU. Cupy method has not been bencharked vs ANTs + + if method == 'ants': + # The output has to be a ANTImage Object + empty_target_array = np.zeros((output_shape_zyx), dtype=np.float32) + target_zyx_ants = ants.from_numpy(empty_target_array) + + T_ants = convert_transform_to_ants(matrix) + + zyx_data_ants = ants.from_numpy(zyx_data.astype(np.float32)) + registered_zyx = T_ants.apply_to_image( + zyx_data_ants, reference=target_zyx_ants + ).numpy() + + elif method == 'scipy': + registered_zyx = scipy.ndimage.affine_transform(zyx_data, matrix, output_shape_zyx) + + else: + raise ValueError(f'Unknown method {method}') + + # Crop the output to the largest interior rectangle + if crop_output_slicing is not None: + registered_zyx = registered_zyx[Z_slice, Y_slice, X_slice] + + return registered_zyx + + +def find_lir(registered_zyx: np.ndarray, plot: bool = False) -> Tuple: + # Find the lir YX + registered_yx_bool = registered_zyx[registered_zyx.shape[0] // 2].copy() + registered_yx_bool = registered_yx_bool > 0 * 1.0 + rectangle_coords_yx = lir.lir(registered_yx_bool) + + x = rectangle_coords_yx[0] + y = rectangle_coords_yx[1] + width = rectangle_coords_yx[2] + height = rectangle_coords_yx[3] + corner1_xy = (x, y) # Bottom-left corner + corner2_xy = (x + width, y) # Bottom-right corner + corner3_xy = (x + width, y + height) # Top-right corner + corner4_xy = (x, y + height) # Top-left corner + rectangle_xy = np.array((corner1_xy, corner2_xy, corner3_xy, corner4_xy)) + X_slice = slice(rectangle_xy.min(axis=0)[0], rectangle_xy.max(axis=0)[0]) + Y_slice = slice(rectangle_xy.min(axis=0)[1], rectangle_xy.max(axis=0)[1]) + + # Find the lir Z + zyx_shape = registered_zyx.shape + registered_zx_bool = registered_zyx.transpose((2, 0, 1)) > 0 + registered_zx_bool = registered_zx_bool[zyx_shape[0] // 2].copy() + rectangle_coords_zx = lir.lir(registered_zx_bool) + x = rectangle_coords_zx[0] + z = rectangle_coords_zx[1] + width = rectangle_coords_zx[2] + height = rectangle_coords_zx[3] + corner1_zx = (x, z) # Bottom-left corner + corner2_zx = (x + width, z) # Bottom-right corner + corner3_zx = (x + width, z + height) # Top-right corner + corner4_zx = (x, z + height) # Top-left corner + rectangle_zx = np.array((corner1_zx, corner2_zx, corner3_zx, corner4_zx)) + Z_slice = slice(rectangle_zx.min(axis=0)[1], rectangle_zx.max(axis=0)[1]) + + if plot: + rectangle_yx = plt.Polygon( + (corner1_xy, corner2_xy, corner3_xy, corner4_xy), + closed=True, + fill=None, + edgecolor="r", + ) + # Add the rectangle to the plot + fig, ax = plt.subplots(nrows=1, ncols=2) + ax[0].imshow(registered_yx_bool) + ax[0].add_patch(rectangle_yx) + + rectangle_zx = plt.Polygon( + (corner1_zx, corner2_zx, corner3_zx, corner4_zx), + closed=True, + fill=None, + edgecolor="r", + ) + ax[1].imshow(registered_zx_bool) + ax[1].add_patch(rectangle_zx) + plt.savefig("./lir.png") + + return (Z_slice, Y_slice, X_slice) + + +def find_overlapping_volume( + input_zyx_shape: Tuple, + target_zyx_shape: Tuple, + transformation_matrix: np.ndarray, + method: str = 'LIR', + plot: bool = False, 
+) -> Tuple: + """ + Find the overlapping rectangular volume after registration of two 3D datasets + + Parameters + ---------- + input_zyx_shape : Tuple + shape of input array + target_zyx_shape : Tuple + shape of target array + transformation_matrix : np.ndarray + affine transformation matrix + method : str, optional + method of finding the overlapping volume, by default 'LIR' + + Returns + ------- + Tuple + ZYX slices of the overlapping volume after registration + + """ + + # Make dummy volumes + img1 = np.ones(tuple(input_zyx_shape), dtype=np.float32) + img2 = np.ones(tuple(target_zyx_shape), dtype=np.float32) + + # Conver to ants objects + target_zyx_ants = ants.from_numpy(img2.astype(np.float32)) + zyx_data_ants = ants.from_numpy(img1.astype(np.float32)) + + ants_composed_matrix = convert_transform_to_ants(transformation_matrix) + + # Apply affine + registered_zyx = ants_composed_matrix.apply_to_image( + zyx_data_ants, reference=target_zyx_ants + ) + + if method == 'LIR': + print('Starting Largest interior rectangle (LIR) search') + Z_slice, Y_slice, X_slice = find_lir(registered_zyx.numpy(), plot=plot) + else: + raise ValueError(f'Unknown method {method}') + + return (Z_slice, Y_slice, X_slice) diff --git a/mantis/analysis/simple_phase_recon.py b/mantis/analysis/scripts/simple_phase_recon.py similarity index 100% rename from mantis/analysis/simple_phase_recon.py rename to mantis/analysis/scripts/simple_phase_recon.py diff --git a/mantis/analysis/settings/example_apply_affine_settings.yml b/mantis/analysis/settings/example_apply_affine_settings.yml index 75863bf7..89935f64 100644 --- a/mantis/analysis/settings/example_apply_affine_settings.yml +++ b/mantis/analysis/settings/example_apply_affine_settings.yml @@ -1,7 +1,9 @@ +source_channel_names: [Phase3D, Orientation, Retardance, Birefringence] +target_channel_name: GFP affine_transform_zyx: - [1.0, 0.0, 0.0, 0.0] - [0.0, 1.0, 0.0, 0.0] - [0.0, 0.0, 1.0, 0.0] - [0.0, 0.0, 0.0, 1.0] -output_shape_zyx: [8, 512, 128] -pre_affine_90degree_rotations_about_z: -1 +keep_overhang: false # Optional, if true datasets will have non-overlapping regions +time_indices: all # Optional, by default all time indices are processed diff --git a/mantis/analysis/stabilize.py b/mantis/analysis/stabilize.py new file mode 100644 index 00000000..1b148248 --- /dev/null +++ b/mantis/analysis/stabilize.py @@ -0,0 +1,133 @@ +import contextlib +import io +import itertools +import multiprocessing as mp + +from functools import partial +from pathlib import Path + +import ants +import click +import numpy as np + +from iohub.ngff import Position, open_ome_zarr + +from mantis.analysis.register import convert_transform_to_ants +from mantis.cli.utils import _check_nan_n_zeros + + +def stabilization_over_time_ants( + position: Position, + output_path: Path, + list_of_shifts: np.ndarray, + input_channel_idx: list, + output_channel_idx: list, + t_idx: int, + c_idx: int, + **kwargs, +) -> None: + """Load a zyx array from a Position object, apply a transformation and save the result to file""" + + click.echo(f"Processing c={c_idx}, t={t_idx}") + tx_shifts = convert_transform_to_ants(list_of_shifts[t_idx]) + + # Process CZYX vs ZYX + if input_channel_idx is not None: + czyx_data = position.data.oindex[t_idx, input_channel_idx] + if not _check_nan_n_zeros(czyx_data): + for c in input_channel_idx: + print(f'czyx_data.shape {czyx_data.shape}') + zyx_data_ants = ants.from_numpy(czyx_data[0]) + registered_zyx = tx_shifts.apply_to_image( + zyx_data_ants, reference=zyx_data_ants + ) + 
# Write to file + with open_ome_zarr(output_path, mode="r+") as output_dataset: + output_dataset[0].oindex[ + t_idx, output_channel_idx + ] = registered_zyx.numpy() + click.echo(f"Finished Writing.. t={t_idx}") + else: + click.echo(f"Skipping t={t_idx} due to all zeros or nans") + else: + zyx_data = position.data.oindex[t_idx, c_idx] + # Checking if nans or zeros and skip processing + if not _check_nan_n_zeros(zyx_data): + zyx_data_ants = ants.from_numpy(zyx_data) + # Apply transformation + registered_zyx = tx_shifts.apply_to_image(zyx_data_ants, reference=zyx_data_ants) + + # Write to file + with open_ome_zarr(output_path, mode="r+") as output_dataset: + output_dataset[0][t_idx, c_idx] = registered_zyx.numpy() + + click.echo(f"Finished Writing.. c={c_idx}, t={t_idx}") + else: + click.echo(f"Skipping c={c_idx}, t={t_idx} due to all zeros or nans") + + +def apply_stabilization_over_time_ants( + list_of_shifts: list, + input_data_path: Path, + output_path: Path, + time_indices: list = [0], + input_channel_idx: list = [], + output_channel_idx: list = [], + num_processes: int = 1, + **kwargs, +) -> None: + """Apply stabilization over time""" + # Function to be applied + # Get the reader and writer + click.echo(f"Input data path:\t{input_data_path}") + click.echo(f"Output data path:\t{str(output_path)}") + input_dataset = open_ome_zarr(str(input_data_path)) + stdout_buffer = io.StringIO() + with contextlib.redirect_stdout(stdout_buffer): + input_dataset.print_tree() + click.echo(f" Input data tree: {stdout_buffer.getvalue()}") + + T, C, _, _, _ = input_dataset.data.shape + + # Write the settings into the metadata if existing + # TODO: alternatively we can throw all extra arguments as metadata. + if 'extra_metadata' in kwargs: + # For each dictionary in the nest + with open_ome_zarr(output_path, mode='r+') as output_dataset: + for params_metadata_keys in kwargs['extra_metadata'].keys(): + output_dataset.zattrs['extra_metadata'] = kwargs['extra_metadata'] + + # Loop through (T, C), deskewing and writing as we go + click.echo(f"\nStarting multiprocess pool with {num_processes} processes") + + if input_channel_idx is None or len(input_channel_idx) == 0: + # If C is not empty, use itertools.product with both ranges + _, C, _, _, _ = input_dataset.data.shape + iterable = itertools.product(time_indices, range(C)) + partial_stabilization_over_time_ants = partial( + stabilization_over_time_ants, + input_dataset, + output_path / Path(*input_data_path.parts[-3:]), + list_of_shifts, + None, + None, + ) + else: + # If C is empty, use only the range for time_indices + iterable = itertools.product(time_indices) + partial_stabilization_over_time_ants = partial( + stabilization_over_time_ants, + input_dataset, + output_path / Path(*input_data_path.parts[-3:]), + list_of_shifts, + input_channel_idx, + output_channel_idx, + c_idx=0, + ) + + with mp.Pool(num_processes) as p: + p.starmap( + partial_stabilization_over_time_ants, + iterable, + ) + input_dataset.close() diff --git a/mantis/cli/apply_affine.py b/mantis/cli/apply_affine.py index a417ab31..82a85b31 100644 --- a/mantis/cli/apply_affine.py +++ b/mantis/cli/apply_affine.py @@ -1,121 +1,189 @@ -import multiprocessing as mp - from pathlib import Path from typing import List import click import numpy as np -import yaml from iohub import open_ome_zarr -from scipy.ndimage import affine_transform from mantis.analysis.AnalysisSettings import RegistrationSettings -from mantis.cli import utils -from mantis.cli.parsing import config_filepath, input_position_dirpaths, 
output_dirpath - - -def registration_params_from_file(registration_param_path: Path) -> RegistrationSettings: - """Parse the deskewing parameters from the yaml file""" - # Load params - with open(registration_param_path) as file: - raw_settings = yaml.safe_load(file) - settings = RegistrationSettings(**raw_settings) - return settings - - -def rotate_n_affine_transform( - zyx_data, matrix, output_shape_zyx, pre_affine_90degree_rotations_about_z: int = 0 -): - if pre_affine_90degree_rotations_about_z != 0: - rotate_volume = np.rot90( - zyx_data, k=pre_affine_90degree_rotations_about_z, axes=(1, 2) - ) - affine_volume = affine_transform( - rotate_volume, matrix=matrix, output_shape=output_shape_zyx - ) - return affine_volume +from mantis.analysis.register import apply_affine_transform, find_overlapping_volume +from mantis.cli.parsing import ( + config_filepath, + output_dirpath, + source_position_dirpaths, + target_position_dirpaths, +) +from mantis.cli.utils import ( + copy_n_paste_czyx, + create_empty_hcs_zarr, + process_single_position_v2, + yaml_to_model, +) -def apply_affine_to_scale(affine_matrix, input_scale): +def rescale_voxel_size(affine_matrix, input_scale): return np.linalg.norm(affine_matrix, axis=1) * input_scale @click.command() -@input_position_dirpaths() +@source_position_dirpaths() +@target_position_dirpaths() @config_filepath() @output_dirpath() @click.option( "--num-processes", "-j", - default=mp.cpu_count(), - help="Number of cores", + default=1, + help="Number of parallel processes", required=False, type=int, ) def apply_affine( - input_position_dirpaths: List[str], + source_position_dirpaths: List[str], + target_position_dirpaths: List[str], config_filepath: str, output_dirpath: str, num_processes: int, ): """ - Apply an affine transformation to a single position across T and C axes using the pathfile for affine transform to the phase channel + Apply an affine transformation to a single position across T and C axes based on a registration config file >> mantis apply_affine -i ./acq_name_lightsheet_deskewed.zarr/*/*/* -c ./register.yml -o ./acq_name_registerred.zarr """ + # Convert string paths to Path objects output_dirpath = Path(output_dirpath) config_filepath = Path(config_filepath) # Handle single position or wildcard filepath - output_paths = utils.get_output_paths(input_position_dirpaths, output_dirpath) - click.echo(f"List of input_pos:{input_position_dirpaths} output_pos:{output_paths}") + click.echo(f"\nInput positions: {[str(path) for path in source_position_dirpaths]}") + click.echo(f"Output position: {output_dirpath}") # Parse from the yaml file - settings = registration_params_from_file(config_filepath) + settings = yaml_to_model(config_filepath, RegistrationSettings) matrix = np.array(settings.affine_transform_zyx) - output_shape_zyx = tuple(settings.output_shape_zyx) - pre_affine_90degree_rotations_about_z = settings.pre_affine_90degree_rotations_about_z + keep_overhang = settings.keep_overhang # Calculate the output voxel size from the input scale and affine transform - with open_ome_zarr(input_position_dirpaths[0]) as input_dataset: - output_voxel_size = apply_affine_to_scale(matrix[:3, :3], input_dataset.scale[-3:]) + with open_ome_zarr(source_position_dirpaths[0]) as source_dataset: + T, C, Z, Y, X = source_dataset.data.shape + source_channel_names = source_dataset.channel_names + source_shape_zyx = source_dataset.data.shape[-3:] + source_voxel_size = source_dataset.scale[-3:] + output_voxel_size = rescale_voxel_size(matrix[:3, :3], source_voxel_size) 
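+    # rescale_voxel_size scales each axis of the source voxel size by the norm of the corresponding row of the affine's 3x3 linear part, i.e. np.linalg.norm(matrix[:3, :3], axis=1) * source_voxel_size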
+ + with open_ome_zarr(target_position_dirpaths[0]) as target_dataset: + target_channel_names = target_dataset.channel_names + Z_target, Y_target, X_target = target_dataset.data.shape[-3:] + target_shape_zyx = target_dataset.data.shape[-3:] click.echo('\nREGISTRATION PARAMETERS:') - click.echo(f'Affine transform: {matrix}') + click.echo(f'Transformation matrix:\n{matrix}') click.echo(f'Voxel size: {output_voxel_size}') - utils.create_empty_zarr( - position_paths=input_position_dirpaths, - output_path=output_dirpath, - output_zyx_shape=output_shape_zyx, - chunk_zyx_shape=None, - voxel_size=tuple(output_voxel_size), + # Logic to parse time indices + if settings.time_indices == "all": + time_indices = list(range(T)) + elif isinstance(settings.time_indices, list): + time_indices = settings.time_indices + elif isinstance(settings.time_indices, int): + time_indices = [settings.time_indices] + + output_channel_names = target_channel_names + if target_position_dirpaths != source_position_dirpaths: + output_channel_names += source_channel_names + + if not keep_overhang: + # Find the largest interior rectangle + click.echo('\nFinding largest overlapping volume between source and target datasets') + Z_slice, Y_slice, X_slice = find_overlapping_volume( + source_shape_zyx, target_shape_zyx, matrix + ) + # TODO: start or stop may be None + cropped_target_shape_zyx = ( + Z_slice.stop - Z_slice.start, + Y_slice.stop - Y_slice.start, + X_slice.stop - X_slice.start, + ) + # Overwrite the previous target shape + Z_target, Y_target, X_target = cropped_target_shape_zyx[-3:] + click.echo(f'Shape of cropped output dataset: {target_shape_zyx}\n') + else: + Z_slice, Y_slice, X_slice = ( + slice(0, Z_target), + slice(0, Y_target), + slice(0, X_target), + ) + + output_metadata = { + "shape": (len(time_indices), len(output_channel_names), Z_target, Y_target, X_target), + "chunks": None, + "scale": (1,) * 2 + tuple(output_voxel_size), + "channel_names": output_channel_names, + "dtype": np.float32, + } + + # Create the output zarr mirroring source_position_dirpaths + create_empty_hcs_zarr( + store_path=output_dirpath, + position_keys=[p.parts[-3:] for p in source_position_dirpaths], + **output_metadata, ) # Get the affine transformation matrix + # NOTE: add any extra metadata if needed: extra_metadata = { 'affine_transformation': { - 'affine_matrix': matrix.tolist(), - 'pre_affine_90degree_rotations_about_z': pre_affine_90degree_rotations_about_z, + 'transform_matrix': matrix.tolist(), } } + affine_transform_args = { 'matrix': matrix, - 'output_shape_zyx': settings.output_shape_zyx, - 'pre_affine_90degree_rotations_about_z': pre_affine_90degree_rotations_about_z, + 'output_shape_zyx': target_shape_zyx, # NOTE: this is the shape of the original target dataset + 'crop_output_slicing': ([Z_slice, Y_slice, X_slice] if not keep_overhang else None), 'extra_metadata': extra_metadata, } - # Loop over positions - for input_position_path, output_position_path in zip( - input_position_dirpaths, output_paths - ): - utils.process_single_position( - rotate_n_affine_transform, - input_data_path=input_position_path, - output_path=output_position_path, - num_processes=num_processes, - **affine_transform_args, - ) + copy_n_paste_kwargs = {"czyx_slicing_params": ([Z_slice, Y_slice, X_slice])} + + # NOTE: channels will not be processed in parallel + # NOTE: the the source and target datastores may be the same (e.g. 
Hummingbird datasets) + + # apply affine transform to channels in the source datastore that should be registered + # as given in the config file (i.e. settings.source_channel_names) + for input_position_path in source_position_dirpaths: + for channel_name in source_channel_names: + if channel_name in settings.source_channel_names: + process_single_position_v2( + apply_affine_transform, + input_data_path=input_position_path, # source store + output_path=output_dirpath, + time_indices=time_indices, + input_channel_idx=[source_channel_names.index(channel_name)], + output_channel_idx=[output_channel_names.index(channel_name)], + num_processes=num_processes, # parallel processing over time + **affine_transform_args, + ) + + # crop all channels that are not being registered and save them in the output zarr store + # Note: when target and source datastores are the same we don't process channels which + # were already registered in the previous step + for input_position_path in target_position_dirpaths: + for channel_name in target_channel_names: + if channel_name not in settings.source_channel_names: + process_single_position_v2( + copy_n_paste_czyx, + input_data_path=input_position_path, # target store + output_path=output_dirpath, + time_indices=time_indices, + input_channel_idx=[target_channel_names.index(channel_name)], + output_channel_idx=[output_channel_names.index(channel_name)], + num_processes=num_processes, + **copy_n_paste_kwargs, + ) + + +if __name__ == "__main__": + apply_affine() diff --git a/mantis/cli/deskew.py b/mantis/cli/deskew.py index 638258d7..eb49da59 100644 --- a/mantis/cli/deskew.py +++ b/mantis/cli/deskew.py @@ -1,11 +1,9 @@ import multiprocessing as mp -from dataclasses import asdict from pathlib import Path from typing import List import click -import yaml from iohub.ngff import open_ome_zarr @@ -13,17 +11,7 @@ from mantis.analysis.deskew import deskew_data, get_deskewed_data_shape from mantis.cli import utils from mantis.cli.parsing import config_filepath, input_position_dirpaths, output_dirpath - - -# TODO: consider refactoring to utils -def deskew_params_from_file(deskew_param_path: Path) -> DeskewSettings: - """Parse the deskewing parameters from the yaml file""" - # Load params - with open(deskew_param_path) as file: - raw_settings = yaml.safe_load(file) - settings = DeskewSettings(**raw_settings) - click.echo(f"Deskewing parameters: {asdict(settings)}") - return settings +from mantis.cli.utils import yaml_to_model @click.command() @@ -62,7 +50,7 @@ def deskew( # Load the first position to infer dataset information with open_ome_zarr(str(input_position_dirpaths[0]), mode="r") as input_dataset: T, C, Z, Y, X = input_dataset.data.shape - settings = deskew_params_from_file(config_filepath) + settings = yaml_to_model(config_filepath, DeskewSettings) deskewed_shape, voxel_size = get_deskewed_data_shape( (Z, Y, X), settings.ls_angle_deg, @@ -86,7 +74,7 @@ def deskew( 'px_to_scan_ratio': settings.px_to_scan_ratio, 'keep_overhang': settings.keep_overhang, 'average_n_slices': settings.average_n_slices, - 'extra_metadata': {'deskew': asdict(settings)}, + 'extra_metadata': {'deskew': settings.dict()}, } # Loop over positions diff --git a/mantis/cli/estimate_affine.py b/mantis/cli/estimate_affine.py index 2aafca3a..bdaacba0 100644 --- a/mantis/cli/estimate_affine.py +++ b/mantis/cli/estimate_affine.py @@ -1,159 +1,162 @@ -import os - -from dataclasses import asdict - +import ants import click import napari import numpy as np -import scipy -import yaml from iohub import 
open_ome_zarr -from skimage.transform import SimilarityTransform +from iohub.reader import print_info +from skimage.transform import EuclideanTransform from waveorder.focus import focus_from_transverse_band from mantis.analysis.AnalysisSettings import RegistrationSettings +from mantis.analysis.register import ( + convert_transform_to_numpy, + get_3D_rescaling_matrix, + get_3D_rotation_matrix, +) from mantis.cli.parsing import ( - labelfree_position_dirpaths, - lightsheet_position_dirpaths, output_filepath, + source_position_dirpaths, + target_position_dirpaths, ) +from mantis.cli.utils import model_to_yaml # TODO: see if at some point these globals should be hidden or exposed. -FOCUS_SLICE_ROI_SIDE = 150 -NA_DETECTION_PHASE = 1.35 -NA_DETECTION_FLUOR = 1.35 -WAVELENGTH_EMISSION_PHASE_CHANNEL = 0.45 # [um] -WAVELENGTH_EMISSION_FLUOR_CHANNEL = 0.6 # [um] - -# TODO:the current pipeline always assumes we register to fluoresence mcherry/mScarlet channel so it will change the colormaps to magenta +NA_DETECTION_SOURCE = 1.35 +NA_DETECTION_TARGET = 1.35 +WAVELENGTH_EMISSION_SOURCE_CHANNEL = 0.45 # in um +WAVELENGTH_EMISSION_TARGET_CHANNEL = 0.6 # in um +FOCUS_SLICE_ROI_WIDTH = 150 # size of central ROI used to find focal slice @click.command() -@labelfree_position_dirpaths() -@lightsheet_position_dirpaths() +@source_position_dirpaths() +@target_position_dirpaths() @output_filepath() -@click.option( - "--pre-affine-90degree-rotations-about-z", - "-k", - default=1, - help="Pre-affine 90degree rotations about z", - required=False, - type=int, -) -def estimate_phase_to_fluor_affine( - labelfree_position_dirpaths, - lightsheet_position_dirpaths, - output_filepath, - pre_affine_90degree_rotations_about_z, -): +def estimate_affine(source_position_dirpaths, target_position_dirpaths, output_filepath): """ - Estimate the affine transform between two channels (source channel and target channel) by manual inputs. + Estimate the affine transform between a source (i.e. moving) and a target (i.e. + fixed) image by selecting corresponding points in each. - mantis estimate-phase-to-fluor-affine -lf ./acq_name_labelfree_reconstructed.zarr/0/0/0 -ls ./acq_name_lightsheet_deskewed.zarr/0/0/0 -o ./register.yml + mantis estimate-affine + -s ./acq_name_labelfree_reconstructed.zarr/0/0/0 + -t ./acq_name_lightsheet_deskewed.zarr/0/0/0 + -o ./output.yml """ - assert str(output_filepath).endswith(('.yaml', '.yml')), "Output file must be a YAML file." - # Get a napari viewer() - viewer = napari.Viewer() + click.echo("\nTarget channel INFO:") + print_info(target_position_dirpaths[0], verbose=False) + click.echo("\nSource channel INFO:") + print_info(source_position_dirpaths[0], verbose=False) - print("Getting dataset info") - print("\n phase channel INFO:") - os.system(f"iohub info {labelfree_position_dirpaths[0]}") - print("\n fluorescence channel INFO:") - os.system(f"iohub info {lightsheet_position_dirpaths[0]} ") + click.echo() # prints empty line + target_channel_index = int(input("Enter target channel index: ")) + source_channel_index = int(input("Enter source channel index: ")) + pre_affine_90degree_rotations_about_z = int( + input("Rotate the source channel by 90 degrees? 
(0, 1, or -1): ") + ) - phase_channel_idx = int(input("Enter phase_channel index to process: ")) - fluor_channel_idx = int(input("Enter fluor_channel index to process: ")) + # Display volumes rescaled + with open_ome_zarr(source_position_dirpaths[0], mode="r") as source_channel_position: + source_channels = source_channel_position.channel_names + source_channel_name = source_channels[source_channel_index] + source_channel_volume = source_channel_position[0][0, source_channel_index] + source_channel_voxel_size = source_channel_position.scale[-3:] + + with open_ome_zarr(target_position_dirpaths[0], mode="r") as target_channel_position: + target_channel_name = target_channel_position.channel_names[target_channel_index] + target_channel_volume = target_channel_position[0][0, target_channel_index] + target_channel_voxel_size = target_channel_position.scale[-3:] + + # Find the infocus slice + source_channel_Z, source_channel_Y, source_channel_X = source_channel_volume.shape[-3:] + target_channel_Z, target_channel_Y, target_channel_X = target_channel_volume.shape[-3:] + + focus_source_channel_idx = focus_from_transverse_band( + source_channel_volume[ + :, + source_channel_Y // 2 + - FOCUS_SLICE_ROI_WIDTH : source_channel_Y // 2 + + FOCUS_SLICE_ROI_WIDTH, + source_channel_X // 2 + - FOCUS_SLICE_ROI_WIDTH : source_channel_X // 2 + + FOCUS_SLICE_ROI_WIDTH, + ], + NA_det=NA_DETECTION_SOURCE, + lambda_ill=WAVELENGTH_EMISSION_SOURCE_CHANNEL, + pixel_size=source_channel_voxel_size[-1], + ) - click.echo("Loading data and estimating best focus plane...") + focus_target_channel_idx = focus_from_transverse_band( + target_channel_volume[ + :, + target_channel_Y // 2 + - FOCUS_SLICE_ROI_WIDTH : target_channel_Y // 2 + + FOCUS_SLICE_ROI_WIDTH, + target_channel_X // 2 + - FOCUS_SLICE_ROI_WIDTH : target_channel_X // 2 + + FOCUS_SLICE_ROI_WIDTH, + ], + NA_det=NA_DETECTION_TARGET, + lambda_ill=WAVELENGTH_EMISSION_TARGET_CHANNEL, + pixel_size=target_channel_voxel_size[-1], + ) - # Display volumes rescaled - with open_ome_zarr(labelfree_position_dirpaths[0], mode="r") as phase_channel_position: - phase_channel_str = phase_channel_position.channel_names[phase_channel_idx] - phase_channel_volume = phase_channel_position[0][0, phase_channel_idx] - phase_channel_Z, phase_channel_Y, phase_channel_X = phase_channel_volume.shape - # Get the voxel dimensions in sample space - ( - z_sample_space_phase_channel, - y_sample_space_phase_channel, - x_sample_space_phase_channel, - ) = phase_channel_position.scale[-3:] - - # Find the infocus slice - focus_phase_channel_idx = focus_from_transverse_band( - phase_channel_position[0][ - 0, - phase_channel_idx, - :, - phase_channel_Y // 2 - - FOCUS_SLICE_ROI_SIDE : phase_channel_Y // 2 - + FOCUS_SLICE_ROI_SIDE, - phase_channel_X // 2 - - FOCUS_SLICE_ROI_SIDE : phase_channel_X // 2 - + FOCUS_SLICE_ROI_SIDE, - ], - NA_det=NA_DETECTION_PHASE, - lambda_ill=WAVELENGTH_EMISSION_PHASE_CHANNEL, - pixel_size=x_sample_space_phase_channel, - plot_path="./best_focus_phase.svg", + click.echo() + if focus_source_channel_idx not in (0, source_channel_Z - 1): + click.echo(f"Best source channel focus slice: {focus_source_channel_idx}") + else: + focus_source_channel_idx = source_channel_Z // 2 + click.echo( + f"Could not determine best source channel focus slice, using {focus_source_channel_idx}" ) - click.echo(f"Best focus phase z_idx: {focus_phase_channel_idx}") - - with open_ome_zarr(lightsheet_position_dirpaths[0], mode="r") as fluor_channel_position: - fluor_channel_str = 
fluor_channel_position.channel_names[fluor_channel_idx] - fluor_channel_volume = fluor_channel_position[0][0, fluor_channel_idx] - fluor_channel_Z, fluor_channel_Y, fluor_channel_X = fluor_channel_volume.shape - # Get the voxel dimension in sample space - ( - z_sample_space_fluor_channel, - y_sample_space_fluor_channel, - x_sample_space_fluor_channel, - ) = fluor_channel_position.scale[-3:] - - # Finding the infocus plane - focus_fluor_channel_idx = focus_from_transverse_band( - fluor_channel_position[0][ - 0, - fluor_channel_idx, - :, - fluor_channel_Y // 2 - - FOCUS_SLICE_ROI_SIDE : fluor_channel_Y // 2 - + FOCUS_SLICE_ROI_SIDE, - fluor_channel_X // 2 - - FOCUS_SLICE_ROI_SIDE : fluor_channel_X // 2 - + FOCUS_SLICE_ROI_SIDE, - ], - NA_det=NA_DETECTION_FLUOR, - lambda_ill=WAVELENGTH_EMISSION_FLUOR_CHANNEL, - pixel_size=x_sample_space_fluor_channel, - plot_path="./best_focus_fluor.svg", + + if focus_target_channel_idx not in (0, target_channel_Z - 1): + click.echo(f"Best target channel focus slice: {focus_target_channel_idx}") + else: + focus_target_channel_idx = target_channel_Z // 2 + click.echo( + f"Could not determine best target channel focus slice, using {focus_target_channel_idx}" ) - click.echo(f"Best focus fluor z_idx: {focus_fluor_channel_idx}") # Calculate scaling factors for displaying data - scaling_factor_z = z_sample_space_phase_channel / z_sample_space_fluor_channel - scaling_factor_yx = x_sample_space_phase_channel / x_sample_space_fluor_channel - + scaling_factor_z = source_channel_voxel_size[-3] / target_channel_voxel_size[-3] + scaling_factor_yx = source_channel_voxel_size[-1] / target_channel_voxel_size[-1] + click.echo( + f"Z scaling factor: {scaling_factor_z:.3f}; XY scaling factor: {scaling_factor_yx:.3f}\n" + ) # Add layers to napari with and transform # Rotate the image if needed here - phase_channel_volume_rotated = np.rot90( - phase_channel_volume, k=pre_affine_90degree_rotations_about_z, axes=(1, 2) - ) - layer_phase_channel = viewer.add_image( - phase_channel_volume_rotated, name=phase_channel_str + + # Convert to ants objects + source_zyx_ants = ants.from_numpy(source_channel_volume.astype(np.float32)) + target_zyx_ants = ants.from_numpy(target_channel_volume.astype(np.float32)) + + scaling_affine = get_3D_rescaling_matrix( + (target_channel_Z, target_channel_Y, target_channel_X), + (scaling_factor_z, scaling_factor_yx, scaling_factor_yx), + (target_channel_Z, target_channel_Y, target_channel_X), ) - layer_fluor_channel = viewer.add_image(fluor_channel_volume, name=fluor_channel_str) - layer_phase_channel.scale = (scaling_factor_z, scaling_factor_yx, scaling_factor_yx) - Z_rot, Y_rot, X_rot = phase_channel_volume_rotated.shape - layer_fluor_channel.translate = ( - 0, - 0, - phase_channel_Y * scaling_factor_yx, + rotate90_affine = get_3D_rotation_matrix( + (source_channel_Z, source_channel_Y, source_channel_X), + 90 * pre_affine_90degree_rotations_about_z, + (target_channel_Z, target_channel_Y, target_channel_X), ) + compound_affine = scaling_affine @ rotate90_affine + + # NOTE: these two functions are key to pass the function properly to ANTs + compound_affine_ants_style = compound_affine[:, :-1].ravel() + compound_affine_ants_style[-3:] = compound_affine[:3, -1] + + # Ants affine transforms + tx_manual = ants.new_ants_transform() + tx_manual.set_parameters(compound_affine_ants_style) + tx_manual = tx_manual.invert() + source_zxy_pre_reg = tx_manual.apply_to_image(source_zyx_ants, reference=target_zyx_ants) + + # Get a napari viewer + viewer = napari.Viewer() 
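+    # The pre-registered source volume and the target volume are displayed together; corresponding features clicked in each are used below to estimate the transform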
- # %% - # Manual annotation of features COLOR_CYCLE = [ "white", "cyan", @@ -165,22 +168,49 @@ def estimate_phase_to_fluor_affine( "magenta", ] + viewer.add_image(target_channel_volume, name=target_channel_name) + points_target_channel = viewer.add_points( + ndim=3, name=f"pts_{target_channel_name}", size=50, face_color=COLOR_CYCLE[0] + ) + + viewer.add_image( + source_zxy_pre_reg.numpy(), + name=source_channel_name, + blending='additive', + colormap='bop blue', + ) + points_source_channel = viewer.add_points( + ndim=3, name=f"pts_{source_channel_name}", size=50, face_color=COLOR_CYCLE[0] + ) + + # setup viewer + viewer.layers.selection.active = points_source_channel + viewer.grid.enabled = False + viewer.grid.stride = 2 + viewer.grid.shape = (-1, 2) + points_source_channel.mode = "add" + points_target_channel.mode = "add" + + # Manual annotation of features def next_on_click(layer, event, in_focus): if layer.mode == "add": - if layer is points_phase_channel: - next_layer = points_fluor_channel + if layer is points_source_channel: + next_layer = points_target_channel # Change slider value if len(next_layer.data) < 1: - prev_step_fluor_channel = ( + prev_step_target_channel = ( in_focus[1], 0, 0, ) else: - prev_step_fluor_channel = (next_layer.data[-1][0] + 1, 0, 0) + prev_step_target_channel = (next_layer.data[-1][0], 0, 0) # Add a point to the active layer - cursor_position = np.array(viewer.cursor.position) - layer.add(cursor_position) + # viewer.cursor.position is return in world coordinates + # point position needs to be converted to data coordinates before plotting + # on top of layer + cursor_position_data_coords = layer.world_to_data(viewer.cursor.position) + layer.add(cursor_position_data_coords) # Change the colors current_index = COLOR_CYCLE.index(layer.current_face_color) @@ -192,22 +222,22 @@ def next_on_click(layer, event, in_focus): next_layer.mode = "add" layer.selected_data = {} viewer.layers.selection.active = next_layer - viewer.dims.current_step = prev_step_fluor_channel + viewer.dims.current_step = prev_step_target_channel else: - next_layer = points_phase_channel + next_layer = points_source_channel # Change slider value if len(next_layer.data) < 1: - prev_step_phase_channel = ( + prev_step_source_channel = ( in_focus[0] * scaling_factor_z, 0, 0, ) else: # TODO: this +1 is not clear to me? 
- prev_step_phase_channel = (next_layer.data[-1][0] + 1, 0, 0) - cursor_position = np.array(viewer.cursor.position) - layer.add(cursor_position) + prev_step_source_channel = (next_layer.data[-1][0], 0, 0) + cursor_position_data_coords = layer.world_to_data(viewer.cursor.position) + layer.add(cursor_position_data_coords) # Change the colors current_index = COLOR_CYCLE.index(layer.current_face_color) next_index = (current_index + 1) % len(COLOR_CYCLE) @@ -218,22 +248,10 @@ def next_on_click(layer, event, in_focus): next_layer.mode = "add" layer.selected_data = {} viewer.layers.selection.active = next_layer - viewer.dims.current_step = prev_step_phase_channel - - # Create the first points layer - points_phase_channel = viewer.add_points( - ndim=3, name=f"pts_{phase_channel_str}", size=50, face_color=COLOR_CYCLE[0] - ) - points_fluor_channel = viewer.add_points( - ndim=3, name=f"pts_{fluor_channel_str}", size=50, face_color=COLOR_CYCLE[0] - ) + viewer.dims.current_step = prev_step_source_channel - # Create the second points layer - viewer.layers.selection.active = points_phase_channel - points_phase_channel.mode = "add" - points_fluor_channel.mode = "add" # Bind the mouse click callback to both point layers - in_focus = (focus_phase_channel_idx, focus_fluor_channel_idx) + in_focus = (focus_source_channel_idx, focus_target_channel_idx) def lambda_callback(layer, event): return next_on_click(layer=layer, event=event, in_focus=in_focus) @@ -243,123 +261,94 @@ def lambda_callback(layer, event): 0, 0, ) - points_phase_channel.mouse_drag_callbacks.append(lambda_callback) - points_fluor_channel.mouse_drag_callbacks.append(lambda_callback) + points_source_channel.mouse_drag_callbacks.append(lambda_callback) + points_target_channel.mouse_drag_callbacks.append(lambda_callback) input( - "\n Add at least three points in the two channels by sequentially clicking a feature on phase channel and its corresponding feature in fluorescence channel. Press when done..." + "Add at least three points in the two channels by sequentially clicking " + + "on a feature in the source channel and its corresponding feature in target channel. " + + "Select grid mode if you prefer side-by-side view. " + + "Press when done..." 
) # Get the data from the layers - pts_phase_channel = points_phase_channel.data - pts_fluor_channel = points_fluor_channel.data - - # De-apply the scaling and translation that was applied in the viewer - pts_phase_channel[:, 1:] /= scaling_factor_yx - pts_fluor_channel[:, 2] -= ( - phase_channel_Y * scaling_factor_yx - ) # subtract the translation offset for display + pts_source_channel = points_source_channel.data + pts_target_channel = points_target_channel.data # Estimate the affine transform between the points xy to make sure registration is good - transform = SimilarityTransform() - transform.estimate(pts_phase_channel[:, 1:], pts_fluor_channel[:, 1:]) + transform = EuclideanTransform() + transform.estimate(pts_source_channel[:, 1:], pts_target_channel[:, 1:]) yx_points_transformation_matrix = transform.params - z_shift = np.array([1, 0, 0, 0]) + z_translation = pts_target_channel[0, 0] - pts_source_channel[0, 0] + + z_scale_translate_matrix = np.array([[1, 0, 0, z_translation]]) # 2D to 3D matrix - zyx_affine_transform = np.vstack( - (z_shift, np.insert(yx_points_transformation_matrix, 0, 0, axis=1)) + euclidian_transform = np.vstack( + (z_scale_translate_matrix, np.insert(yx_points_transformation_matrix, 0, 0, axis=1)) ) # Insert 0 in the third entry of each row - zyx_affine_transform = np.linalg.inv(zyx_affine_transform) - # Get the transformation matrix - output_shape_zyx = (fluor_channel_Z, fluor_channel_Y, fluor_channel_X) - - # Demo: apply the affine transform to the image at the z-slice where all the points are located - aligned_image = scipy.ndimage.affine_transform( - phase_channel_volume_rotated[ - int(np.ceil(pts_phase_channel[0, 0])) : int(np.ceil(pts_phase_channel[0, 0])) + 1 - ], - zyx_affine_transform, - output_shape=(1, fluor_channel_Y, fluor_channel_X), - ) - viewer.add_image( - fluor_channel_position[0][ - 0, - fluor_channel_idx, - int(np.ceil(pts_fluor_channel[0, 0])) : int(np.ceil(pts_fluor_channel[0, 0])) + 1, - ], - name=f"middle_plane_{fluor_channel_str}", - colormap="magenta", - ) - print( - 'Showing registered pair (phase and fluorescence) with pseudo colored fluorescence in magenta' - ) - viewer.add_image(aligned_image, name=f"registered_{phase_channel_str}", opacity=0.5) - viewer.layers.remove(f"pts_{phase_channel_str}") - viewer.layers.remove(f"pts_{fluor_channel_str}") - viewer.layers[fluor_channel_str].visible = False - viewer.layers[phase_channel_str].visible = False - viewer.dims.current_step = (0, 0, 0) # Return to slice 0 - - # NOTE: This assumes within a channel will lie in the same plane - # Compute the 3D registration - # Estimate the Similarity Transform (rotation,scaling,translation) - transform = SimilarityTransform() - transform.estimate(pts_phase_channel[:, 1:], pts_fluor_channel[:, 1:]) - yx_points_transformation_matrix = transform.params - z_translation = pts_fluor_channel[0, 0] - pts_phase_channel[0, 0] - z_scale_translate_matrix = np.array([[scaling_factor_z, 0, 0, z_translation]]) - zyx_affine_transform = np.vstack( - (z_scale_translate_matrix, np.insert(yx_points_transformation_matrix, 0, 0, axis=1)) + scaling_affine = get_3D_rescaling_matrix( + (1, target_channel_Y, target_channel_X), + (scaling_factor_z, scaling_factor_yx, scaling_factor_yx), ) + manual_estimated_transform = euclidian_transform @ compound_affine - # Composite of all transforms - zyx_affine_transform = np.linalg.inv(zyx_affine_transform) # phase to fluorescence mapping - print(f"Affine Transform Matrix:\n {zyx_affine_transform}\n") - settings = RegistrationSettings( - 
affine_transform_zyx=zyx_affine_transform.tolist(), # phase to fluorescence mapping - output_shape_zyx=list(output_shape_zyx), - pre_affine_90degree_rotations_about_z=pre_affine_90degree_rotations_about_z, - ) + # NOTE: these two functions are key to pass the function properly to ANTs + manual_estimated_transform_ants_style = manual_estimated_transform[:, :-1].ravel() + manual_estimated_transform_ants_style[-3:] = manual_estimated_transform[:3, -1] - print(f"Writing registration parameters to {output_filepath}") - with open(output_filepath, "w") as f: - yaml.dump(asdict(settings), f) + # Ants affine transforms + tx_manual = ants.new_ants_transform() + tx_manual.set_parameters(manual_estimated_transform_ants_style) + tx_manual = tx_manual.invert() - # Apply the transformation to 3D volume - flag_apply_3D_transform = input("\n Apply 3D registration *this make some time* (Y/N) :") - if flag_apply_3D_transform == "Y" or flag_apply_3D_transform == "y": - print("Applying 3D Affine Transform...") - # Rotate the image first + source_zxy_manual_reg = tx_manual.apply_to_image( + source_zyx_ants, reference=target_zyx_ants + ) - phase_volume_rotated = np.rot90( - phase_channel_position[0][0, phase_channel_idx], - k=pre_affine_90degree_rotations_about_z, - axes=(1, 2), + click.echo("\nShowing registered source image in magenta") + viewer.grid.enabled = False + viewer.add_image( + source_zxy_manual_reg.numpy(), + name=f"registered_{source_channel_name}", + colormap="magenta", + blending='additive', + ) + viewer.layers.remove(f"pts_{source_channel_name}") + viewer.layers.remove(f"pts_{target_channel_name}") + viewer.layers[source_channel_name].visible = False + + # Ants affine transforms + T_manual_numpy = convert_transform_to_numpy(tx_manual) + click.echo(f'Estimated affine transformation matrix:\n{T_manual_numpy}\n') + + additional_source_channels = source_channels.copy() + additional_source_channels.remove(source_channel_name) + if target_channel_name in additional_source_channels: + additional_source_channels.remove(target_channel_name) + + flag_apply_to_all_channels = 'N' + if len(additional_source_channels) > 0: + flag_apply_to_all_channels = str( + input( + f"Would you like to register these additional source channels: {additional_source_channels}? (y/N): " + ) ) - registered_3D_volume = scipy.ndimage.affine_transform( - phase_volume_rotated, - zyx_affine_transform, - output_shape=output_shape_zyx, - ) - viewer.add_image( - registered_3D_volume, - name=f"registered_volume_{phase_channel_str}", - opacity=1.0, - ) + source_channel_names = [source_channel_name] + if flag_apply_to_all_channels in ('Y', 'y'): + source_channel_names += additional_source_channels - viewer.add_image( - fluor_channel_position[0][0, fluor_channel_idx], - name=f"{fluor_channel_str}", - opacity=0.5, - colormap="magenta", - ) + model = RegistrationSettings( + source_channel_names=source_channel_names, + target_channel_name=target_channel_name, + affine_transform_zyx=T_manual_numpy.tolist(), + ) + click.echo(f"Writing registration parameters to {output_filepath}") + model_to_yaml(model, output_filepath) - viewer.layers[f"registered_{phase_channel_str}"].visible = False - viewer.layers[f"{phase_channel_str}"].visible = False - viewer.layers[f"middle_plane_{fluor_channel_str}"].visible = False - input("\n Displaying registered channels. 
Press to close...") +if __name__ == "__main__": + estimate_affine() diff --git a/mantis/cli/main.py b/mantis/cli/main.py index 60e46487..e1af5110 100644 --- a/mantis/cli/main.py +++ b/mantis/cli/main.py @@ -2,9 +2,10 @@ from mantis.cli.apply_affine import apply_affine from mantis.cli.deskew import deskew -from mantis.cli.estimate_affine import estimate_phase_to_fluor_affine +from mantis.cli.estimate_affine import estimate_affine from mantis.cli.estimate_bleaching import estimate_bleaching from mantis.cli.estimate_deskew import estimate_deskew +from mantis.cli.optimize_affine import optimize_affine from mantis.cli.run_acquisition import run_acquisition from mantis.cli.update_scale_metadata import update_scale_metadata @@ -26,6 +27,7 @@ def cli(): cli.add_command(estimate_bleaching) cli.add_command(estimate_deskew) cli.add_command(deskew) -cli.add_command(estimate_phase_to_fluor_affine) +cli.add_command(estimate_affine) +cli.add_command(optimize_affine) cli.add_command(apply_affine) cli.add_command(update_scale_metadata) diff --git a/mantis/cli/optimize_affine.py b/mantis/cli/optimize_affine.py new file mode 100644 index 00000000..0ce4da93 --- /dev/null +++ b/mantis/cli/optimize_affine.py @@ -0,0 +1,138 @@ +import ants +import click +import napari +import numpy as np + +from iohub import open_ome_zarr + +from mantis.analysis.AnalysisSettings import RegistrationSettings +from mantis.analysis.register import convert_transform_to_ants, convert_transform_to_numpy +from mantis.cli.parsing import ( + config_filepath, + output_filepath, + source_position_dirpaths, + target_position_dirpaths, +) +from mantis.cli.utils import model_to_yaml, yaml_to_model + +# TODO: maybe a CLI call? +T_IDX = 0 + + +@click.command() +@source_position_dirpaths() +@target_position_dirpaths() +@config_filepath() +@output_filepath() +@click.option( + "--display-viewer", + "-d", + is_flag=True, + help="Display the registered channels in a napari viewer", +) +@click.option( + "--optimizer-verbose", + "-v", + is_flag=True, + help="Show verbose output of optimizer", +) +def optimize_affine( + source_position_dirpaths, + target_position_dirpaths, + config_filepath, + output_filepath, + display_viewer, + optimizer_verbose, +): + """ + Optimize the affine transform between source and target channels using ANTs library. 
+ + mantis optimize-affine -s ./acq_name_virtual_staining_reconstructed.zarr/0/0/0 -t ./acq_name_lightsheet_deskewed.zarr/0/0/0 -c ./transform.yml -o ./optimized_transform.yml -d -v + """ + + settings = yaml_to_model(config_filepath, RegistrationSettings) + + # Load the source volume + with open_ome_zarr(source_position_dirpaths[0]) as source_position: + source_channel_names = source_position.channel_names + # NOTE: using the first channel in the config to register + source_channel_index = source_channel_names.index(settings.source_channel_names[0]) + source_channel_name = source_channel_names[source_channel_index] + source_data_zyx = source_position[0][T_IDX, source_channel_index].astype(np.float32) + + # Load the target volume + with open_ome_zarr(target_position_dirpaths[0]) as target_position: + target_channel_names = target_position.channel_names + target_channel_index = target_channel_names.index(settings.target_channel_name) + target_channel_name = target_channel_names[target_channel_index] + target_channel_zyx = target_position[0][T_IDX, target_channel_index] + + source_zyx_ants = ants.from_numpy(source_data_zyx) + target_zyx_ants = ants.from_numpy(target_channel_zyx.astype(np.float32)) + click.echo( + f"\nOptimizing registration using source channel {source_channel_name} and target channel {target_channel_name}" + ) + + # Affine Transforms + # numpy to ants + T_pre_optimize_numpy = np.array(settings.affine_transform_zyx) + T_pre_optimize = convert_transform_to_ants(T_pre_optimize_numpy) + + # Apply transformation to source prior to optimization of the matrix + source_zyx_pre_optim = T_pre_optimize.apply_to_image( + source_zyx_ants, reference=target_zyx_ants + ) + + click.echo("Running ANTS optimizer...") + # Optimization + tx_opt = ants.registration( + fixed=target_zyx_ants, + moving=source_zyx_pre_optim, + type_of_transform="Similarity", + verbose=optimizer_verbose, + ) + + tx_opt_mat = ants.read_transform(tx_opt["fwdtransforms"][0]) + tx_opt_numpy = convert_transform_to_numpy(tx_opt_mat) + composed_matrix = T_pre_optimize_numpy @ tx_opt_numpy + + # Saving the parameters + click.echo(f"Writing registration parameters to {output_filepath}") + # copy config settings and modify only ones that change + output_settings = settings.copy() + output_settings.affine_transform_zyx = composed_matrix.tolist() + model_to_yaml(output_settings, output_filepath) + + if display_viewer: + composed_matrix_ants = convert_transform_to_ants(composed_matrix) + source_registered = composed_matrix_ants.apply_to_image( + source_zyx_ants, reference=target_zyx_ants + ) + + viewer = napari.Viewer() + source_pre_opt_layer = viewer.add_image( + source_zyx_pre_optim.numpy(), + name="source_pre_optimization", + colormap="cyan", + opacity=0.5, + ) + source_pre_opt_layer.visible = False + + viewer.add_image( + source_registered.numpy(), + name="source_post_optimization", + colormap="cyan", + blending="additive", + ) + viewer.add_image( + target_position[0][0, target_channel_index], + name="target", + colormap="magenta", + blending="additive", + ) + + input("\n Displaying registered channels. 
Press to close...") + + +if __name__ == "__main__": + optimize_affine() diff --git a/mantis/cli/parsing.py b/mantis/cli/parsing.py index 2527487c..97a79021 100644 --- a/mantis/cli/parsing.py +++ b/mantis/cli/parsing.py @@ -21,40 +21,50 @@ def _validate_and_process_paths(ctx: click.Context, opt: click.Option, value: st return input_paths +def _str_to_path(ctx: click.Context, opt: click.Option, value: str) -> Path: + return Path(value) + + def input_position_dirpaths() -> Callable: def decorator(f: Callable) -> Callable: return click.option( "--input-position-dirpaths", "-i", + required=True, cls=OptionEatAll, type=tuple, callback=_validate_and_process_paths, + help='Paths to input positions, for example: "input.zarr/0/0/0" or "input.zarr/*/*/*"', )(f) return decorator -def labelfree_position_dirpaths() -> Callable: +def source_position_dirpaths() -> Callable: def decorator(f: Callable) -> Callable: return click.option( - "--labelfree-position-dirpaths", - "-lf", + "--source-position-dirpaths", + "-s", + required=True, cls=OptionEatAll, type=tuple, callback=_validate_and_process_paths, + help='Paths to source positions, for example: "source.zarr/0/0/0" or "source.zarr/*/*/*"', )(f) return decorator -def lightsheet_position_dirpaths() -> Callable: +def target_position_dirpaths() -> Callable: def decorator(f: Callable) -> Callable: return click.option( - "--lightsheet-position-dirpaths", - "-ls", + "--target-position-dirpaths", + "-t", + required=True, cls=OptionEatAll, type=tuple, callback=_validate_and_process_paths, + help='Paths to target positions, for example: "target.zarr/0/0/0" or "target.zarr/*/*/*"', )(f) return decorator @@ -67,7 +77,8 @@ def decorator(f: Callable) -> Callable: "-c", required=True, type=click.Path(exists=True, file_okay=True, dir_okay=False), - help="Path to YAML configuration file", + callback=_str_to_path, + help="Path to YAML configuration file.", )(f) return decorator @@ -81,6 +92,7 @@ def decorator(f: Callable) -> Callable: required=True, type=click.Path(exists=False, file_okay=False, dir_okay=True), help="Path to output directory", + callback=_str_to_path, )(f) return decorator diff --git a/mantis/cli/utils.py b/mantis/cli/utils.py index 37876d6f..73d68f80 100644 --- a/mantis/cli/utils.py +++ b/mantis/cli/utils.py @@ -10,11 +10,15 @@ import click import numpy as np +import yaml from iohub.ngff import Position, open_ome_zarr from iohub.ngff_meta import TransformationMeta +from numpy.typing import DTypeLike +from tqdm import tqdm +# TODO: replace this with recOrder recOrder.cli.utils.create_empty_hcs() def create_empty_zarr( position_paths: list[Path], output_path: Path, @@ -83,6 +87,86 @@ def create_empty_zarr( input_dataset.close() +# TODO: convert all code to use this function from now on +def create_empty_hcs_zarr( + store_path: Path, + position_keys: list[Tuple[str]], + channel_names: list[str], + shape: Tuple[int], + chunks: Tuple[int] = None, + scale: Tuple[float] = (1, 1, 1, 1, 1), + dtype: DTypeLike = np.float32, + max_chunk_size_bytes=500e6, +) -> None: + """ + If the plate does not exist, create an empty zarr plate. + If the plate exists, append positions and channels if they are not + already in the plate. + Parameters + ---------- + store_path : Path + hcs plate path + position_keys : list[Tuple[str]] + Position keys, will append if not present in the plate. + e.g. [("A", "1", "0"), ("A", "1", "1")] + shape : Tuple[int] + chunks : Tuple[int] + scale : Tuple[float] + channel_names : list[str] + Channel names, will append if not present in metadata. 
+ dtype : DTypeLike + + Modifying from recOrder + https://github.com/mehta-lab/recOrder/blob/d31ad910abf84c65ba927e34561f916651cbb3e8/recOrder/cli/utils.py#L12 + """ + MAX_CHUNK_SIZE = max_chunk_size_bytes # in bytes + bytes_per_pixel = np.dtype(dtype).itemsize + + # Limiting the chunking to 500MB + if chunks is None: + chunk_zyx_shape = list(shape[-3:]) + # chunk_zyx_shape[-3] > 1 ensures while loop will not stall if single + # XY image is larger than MAX_CHUNK_SIZE + while ( + chunk_zyx_shape[-3] > 1 + and np.prod(chunk_zyx_shape) * bytes_per_pixel > MAX_CHUNK_SIZE + ): + chunk_zyx_shape[-3] = np.ceil(chunk_zyx_shape[-3] / 2).astype(int) + chunk_zyx_shape = tuple(chunk_zyx_shape) + + chunks = 2 * (1,) + chunk_zyx_shape + + # Create plate + output_plate = open_ome_zarr( + str(store_path), layout="hcs", mode="a", channel_names=channel_names + ) + + # Create positions + for position_key in position_keys: + position_key_string = "/".join(position_key) + # Check if position is already in the store, if not create it + if position_key_string not in output_plate.zgroup: + position = output_plate.create_position(*position_key) + _ = position.create_zeros( + name="0", + shape=shape, + chunks=chunks, + dtype=dtype, + transform=[TransformationMeta(type="scale", scale=scale)], + ) + else: + position = output_plate[position_key_string] + + # Check if channel_names are already in the store, if not append them + for channel_name in channel_names: + # Read channel names directly from metadata to avoid race conditions + metadata_channel_names = [ + channel.label for channel in position.metadata.omero.channels + ] + if channel_name not in metadata_channel_names: + position.append_channel(channel_name, resize_arrays=True) + + def get_output_paths(input_paths: list[Path], output_zarr_path: Path) -> list[Path]: """Generates a mirrored output path list given an input list of positions""" list_output_path = [] @@ -94,21 +178,74 @@ def get_output_paths(input_paths: list[Path], output_zarr_path: Path) -> list[Pa return list_output_path -def apply_transform_to_zyx_and_save( +def apply_function_to_zyx_and_save( func, position: Position, output_path: Path, t_idx: int, c_idx: int, **kwargs ) -> None: """Load a zyx array from a Position object, apply a transformation and save the result to file""" click.echo(f"Processing c={c_idx}, t={t_idx}") + zyx_data = position[0][t_idx, c_idx] + if _check_nan_n_zeros(zyx_data): + click.echo(f"Skipping c={c_idx}, t={t_idx} due to all zeros or nans") + else: + # Apply function + processed_zyx = func(zyx_data, **kwargs) + + # Write to file + with open_ome_zarr(output_path, mode="r+") as output_dataset: + output_dataset[0][t_idx, c_idx] = processed_zyx + + click.echo(f"Finished Writing.. 
c={c_idx}, t={t_idx}") + + +# NOTE WIP +def apply_transform_to_zyx_and_save_v2( + func, + position: Position, + output_path: Path, + input_channel_indices: list[int], + output_channel_indices: list[int], + t_idx: int, + c_idx: int = None, + **kwargs, +) -> None: + """Load a zyx array from a Position object, apply a transformation to CZYX or ZYX and save the result to file""" + click.echo(f"Processing c={c_idx}, t={t_idx}") + + # TODO: temporary fix to slumkit issue + if _is_nested(input_channel_indices): + # print(f'input_channel_indices: {input_channel_indices}') + input_channel_indices = [int(x) for x in input_channel_indices if x.isdigit()] + if _is_nested(output_channel_indices): + # print(f'input_channel_indices: {output_channel_indices}') + output_channel_indices = [int(x) for x in output_channel_indices if x.isdigit()] + click.echo(f'input_channel_indices: {input_channel_indices}') - # Apply transformation - registered_zyx = func(zyx_data, **kwargs) + # Process CZYX vs ZYX + if input_channel_indices is not None: + czyx_data = position.data.oindex[t_idx, input_channel_indices] + if not _check_nan_n_zeros(czyx_data): + transformed_czyx = func(czyx_data, **kwargs) + # Write to file + with open_ome_zarr(output_path, mode="r+") as output_dataset: + output_dataset[0].oindex[t_idx, output_channel_indices] = transformed_czyx + click.echo(f"Finished Writing.. t={t_idx}") + else: + click.echo(f"Skipping t={t_idx} due to all zeros or nans") + else: + zyx_data = position.data.oindex[t_idx, c_idx] + # Checking if nans or zeros and skip processing + if not _check_nan_n_zeros(zyx_data): + # Apply transformation + transformed_zyx = func(zyx_data, **kwargs) - # Write to file - with open_ome_zarr(output_path, mode="r+") as output_dataset: - output_dataset[0][t_idx, c_idx] = registered_zyx + # Write to file + with open_ome_zarr(output_path, mode="r+") as output_dataset: + output_dataset[0][t_idx, c_idx] = transformed_zyx - click.echo(f"Finished Writing.. c={c_idx}, t={t_idx}") + click.echo(f"Finished Writing.. 
c={c_idx}, t={t_idx}")
+        else:
+            click.echo(f"Skipping c={c_idx}, t={t_idx} due to all zeros or nans")
 
 
 def process_single_position(
@@ -158,7 +295,7 @@ def process_single_position(
     with mp.Pool(num_processes) as p:
         p.starmap(
             partial(
-                apply_transform_to_zyx_and_save,
+                apply_function_to_zyx_and_save,
                 func,
                 input_dataset,
                 str(output_path),
@@ -166,3 +303,299 @@
             ),
             itertools.product(range(T), range(C)),
         )
+
+
+# TODO: modify how we get the time and channels like recOrder (isinstance(input, list) or isinstance(input, int) or all)
+def process_single_position_v2(
+    func,
+    input_data_path: Path,
+    output_path: Path,
+    time_indices: list = [0],
+    input_channel_idx: list = [],
+    output_channel_idx: list = [],
+    num_processes: int = mp.cpu_count(),
+    **kwargs,
+) -> None:
+    """Register a single position with multiprocessing parallelization over T and C"""
+    # Function to be applied
+    click.echo(f"Function to be applied: \t{func}")
+
+    # Get the reader and writer
+    click.echo(f"Input data path:\t{input_data_path}")
+    click.echo(f"Output data path:\t{str(output_path)}")
+    input_dataset = open_ome_zarr(str(input_data_path))
+    stdout_buffer = io.StringIO()
+    with contextlib.redirect_stdout(stdout_buffer):
+        input_dataset.print_tree()
+    click.echo(f" Input data tree: {stdout_buffer.getvalue()}")
+
+    # Find time indices
+    if time_indices == "all":
+        time_indices = range(input_dataset.data.shape[0])
+    elif isinstance(time_indices, list):
+        time_indices = time_indices
+
+    # Check for invalid times
+    time_ubound = input_dataset.data.shape[0] - 1
+    if np.max(time_indices) > time_ubound:
+        raise ValueError(
+            f"time_indices = {time_indices} includes a time index beyond the maximum index of the dataset = {time_ubound}"
+        )
+
+    # Check the arguments for the function
+    all_func_params = inspect.signature(func).parameters.keys()
+    # Extract the relevant kwargs for the function 'func'
+    func_args = {}
+    non_func_args = {}
+
+    for k, v in kwargs.items():
+        if k in all_func_params:
+            func_args[k] = v
+        else:
+            non_func_args[k] = v
+
+    # Write the settings into the metadata if existing
+    if 'extra_metadata' in non_func_args:
+        # For each dictionary in the nest
+        with open_ome_zarr(output_path, mode='r+') as output_dataset:
+            for params_metadata_keys in kwargs['extra_metadata'].keys():
+                output_dataset.zattrs['extra_metadata'] = non_func_args['extra_metadata']
+
+    # Loop through (T, C), applying the function and writing as we go
+    if input_channel_idx is None or len(input_channel_idx) == 0:
+        # If input_channel_idx is empty, process every channel separately:
+        # use itertools.product over both the time and channel ranges
+        _, C, _, _, _ = input_dataset.data.shape
+        iterable = itertools.product(time_indices, range(C))
+        partial_apply_transform_to_zyx_and_save = partial(
+            apply_transform_to_zyx_and_save_v2,
+            func,
+            input_dataset,
+            output_path / Path(*input_data_path.parts[-3:]),
+            input_channel_indices=None,
+            **func_args,
+        )
+    else:
+        # If input_channel_idx is given, the channels are passed together and
+        # we iterate over the time indices only
+        iterable = itertools.product(time_indices)
+        partial_apply_transform_to_zyx_and_save = partial(
+            apply_transform_to_zyx_and_save_v2,
+            func,
+            input_dataset,
+            output_path / Path(*input_data_path.parts[-3:]),
+            input_channel_idx,
+            output_channel_idx,
+            c_idx=0,
+            **func_args,
+        )
+
+    click.echo(f"\nStarting multiprocess pool with {num_processes} processes")
+    with mp.Pool(num_processes) as p:
+        p.starmap(
+            partial_apply_transform_to_zyx_and_save,
+            iterable,
+        )
+
+
+def copy_n_paste(zyx_data: np.ndarray, zyx_slicing_params: list) -> np.ndarray:
+    """
+    Load a zyx array and
crop given a list of ZYX slices() + + Parameters + ---------- + zyx_data : np.ndarray + data to copy + zyx_slicing_params : list + list of slicing parameters for z,y,x + + Returns + ------- + np.ndarray + crop of the input zyx_data given the slicing parameters + """ + zyx_data = np.nan_to_num(zyx_data, nan=0) + zyx_data_sliced = zyx_data[ + zyx_slicing_params[0], + zyx_slicing_params[1], + zyx_slicing_params[2], + ] + return zyx_data_sliced + + +def copy_n_paste_czyx(czyx_data: np.ndarray, czyx_slicing_params: list) -> np.ndarray: + """ + Load a zyx array and crop given a list of ZYX slices() + + Parameters + ---------- + czyx_data : np.ndarray + data to copy + czyx_slicing_params : list + list of slicing parameters for z,y,x + + Returns + ------- + np.ndarray + crop of the input czyx_data given the slicing parameters + """ + czyx_data_sliced = czyx_data[ + :, + czyx_slicing_params[0], + czyx_slicing_params[1], + czyx_slicing_params[2], + ] + return czyx_data_sliced + + +def append_channels(input_data_path: Path, target_data_path: Path) -> None: + """ + Append channels to a target zarr store + + Parameters + ---------- + input_data_path : Path + input zarr path = /input.zarr + target_data_path : Path + target zarr path = /target.zarr + """ + appending_dataset = open_ome_zarr(input_data_path, mode="r") + appending_channel_names = appending_dataset.channel_names + with open_ome_zarr(target_data_path, mode="r+") as dataset: + target_data_channel_names = dataset.channel_names + num_channels = len(target_data_channel_names) - 1 + print(f"channels in target {target_data_channel_names}") + print(f"adding channels {appending_channel_names}") + for name, position in tqdm(dataset.positions(), desc='Positions'): + for i, appending_channel_idx in enumerate( + tqdm(appending_channel_names, desc='Channel', leave=False) + ): + position.append_channel(appending_channel_idx) + position["0"][:, num_channels + i + 1] = appending_dataset[str(name)][0][:, i] + dataset.print_tree() + appending_dataset.close() + + +def model_to_yaml(model, yaml_path: Path) -> None: + """ + Save a model's dictionary representation to a YAML file. + + Borrowing from recOrder==0.4.0 + + Parameters + ---------- + model : object + The model object to convert to YAML. + yaml_path : Path + The path to the output YAML file. + + Raises + ------ + TypeError + If the `model` object does not have a `dict()` method. + + Notes + ----- + This function converts a model object into a dictionary representation + using the `dict()` method. It removes any fields with None values before + writing the dictionary to a YAML file. + + Examples + -------- + >>> from my_model import MyModel + >>> model = MyModel() + >>> model_to_yaml(model, 'model.yaml') + + """ + yaml_path = Path(yaml_path) + + if not hasattr(model, "dict"): + raise TypeError("The 'model' object does not have a 'dict()' method.") + + model_dict = model.dict() + + # Remove None-valued fields + clean_model_dict = {key: value for key, value in model_dict.items() if value is not None} + + with open(yaml_path, "w+") as f: + yaml.dump(clean_model_dict, f, default_flow_style=False, sort_keys=False) + + +def yaml_to_model(yaml_path: Path, model): + """ + Load model settings from a YAML file and create a model instance. + + Borrowing from recOrder==0.4.0 + + Parameters + ---------- + yaml_path : Path + The path to the YAML file containing the model settings. + model : class + The model class used to create an instance with the loaded settings. 
+ + Returns + ------- + object + An instance of the model class with the loaded settings. + + Raises + ------ + TypeError + If the provided model is not a class or does not have a callable constructor. + FileNotFoundError + If the YAML file specified by `yaml_path` does not exist. + + Notes + ----- + This function loads model settings from a YAML file using `yaml.safe_load()`. + It then creates an instance of the provided `model` class using the loaded settings. + + Examples + -------- + >>> from my_model import MyModel + >>> model = yaml_to_model('model.yaml', MyModel) + + """ + yaml_path = Path(yaml_path) + + if not callable(getattr(model, "__init__", None)): + raise TypeError("The provided model must be a class with a callable constructor.") + + try: + with open(yaml_path, "r") as file: + raw_settings = yaml.safe_load(file) + except FileNotFoundError: + raise FileNotFoundError(f"The YAML file '{yaml_path}' does not exist.") + + return model(**raw_settings) + + +def _is_nested(lst): + return any(isinstance(i, list) for i in lst) or any(isinstance(i, str) for i in lst) + + +def _check_nan_n_zeros(input_array): + """ + Checks if any of the channels are all zeros or nans and returns true + """ + if len(input_array.shape) == 3: + # Check if all the values are zeros or nans + if np.all(input_array == 0) or np.all(np.isnan(input_array)): + # Return true + return True + elif len(input_array.shape) == 4: + # Get the number of channels + num_channels = input_array.shape[0] + # Loop through the channels + for c in range(num_channels): + # Get the channel + zyx_array = input_array[c, :, :, :] + + # Check if all the values are zeros or nans + if np.all(zyx_array == 0) or np.all(np.isnan(zyx_array)): + # Return true + return True + else: + raise ValueError("Input array must be 3D or 4D") + + # Return false + return False diff --git a/mantis/tests/conftest.py b/mantis/tests/conftest.py index d36d7a2b..9d0da86d 100644 --- a/mantis/tests/conftest.py +++ b/mantis/tests/conftest.py @@ -56,11 +56,37 @@ def example_plate(tmp_path): plate_path, layout="hcs", mode="w", - channel_names=["GFP", "RFP"], + channel_names=["GFP", "RFP", "Phase3D", "Orientation", "Retardance", "Birefringence"], ) for row, col, fov in position_list: position = plate_dataset.create_position(row, col, fov) - position.create_zeros("0", (3, 2, 4, 5, 6), dtype=np.uint16) + position["0"] = np.random.uniform(0.0, 255.0, size=(3, 6, 4, 5, 6)).astype(np.float32) + + yield plate_path, plate_dataset + + +@pytest.fixture(scope="function") +def example_plate_2(tmp_path): + plate_path = tmp_path / "plate.zarr" + position_list = ( + ("A", "1", "0"), + ("B", "1", "0"), + ("B", "2", "0"), + ) + + # Generate input dataset + plate_dataset = open_ome_zarr( + plate_path, + layout="hcs", + mode="w", + channel_names=["GFP", "RFP"], + ) + + for row, col, fov in position_list: + position = plate_dataset.create_position(row, col, fov) + position["0"] = np.random.randint( + 0, np.iinfo(np.uint16).max, size=(3, 2, 4, 5, 6), dtype=np.uint16 + ) yield plate_path, plate_dataset diff --git a/mantis/tests/test_analysis/test_affine.py b/mantis/tests/test_analysis/test_affine.py new file mode 100644 index 00000000..d6544ea2 --- /dev/null +++ b/mantis/tests/test_analysis/test_affine.py @@ -0,0 +1,55 @@ +import ants +import numpy as np + +from mantis.analysis.register import ( + apply_affine_transform, + convert_transform_to_ants, + convert_transform_to_numpy, +) + + +def test_numpy_to_ants_transform_zyx(): + T_numpy = np.eye(4) + T_ants = 
convert_transform_to_ants(T_numpy) + assert isinstance(T_ants, ants.core.ants_transform.ANTsTransform) + + +def test_ants_to_numpy_transform_zyx(): + T_ants = ants.new_ants_transform(transform_type='AffineTransform') + T_ants.set_parameters(np.eye(12)) + T_numpy = convert_transform_to_numpy(T_ants) + assert isinstance(T_numpy, np.ndarray) + assert T_numpy.shape == (4, 4) + + +def test_affine_transform(): + # Create input data + zyx_data = np.ones((10, 10, 10)) + matrix = np.eye(4) + output_shape_zyx = (10, 10, 10) + + # Call the function + result = apply_affine_transform(zyx_data, matrix, output_shape_zyx) + + # Check the result + assert isinstance(result, np.ndarray) + assert result.shape == output_shape_zyx + + +def test_3d_translation(): + # Create input data + zyx_data = np.ones((10, 10, 10)) + matrix = np.eye(4) + translation = np.array([-3, 1, 4]) + matrix[:3, -1] = translation + output_shape_zyx = (10, 10, 10) + + # Call the function + result = apply_affine_transform(zyx_data, matrix, output_shape_zyx) + + # Check the result + assert isinstance(result, np.ndarray) + assert result.shape == output_shape_zyx + assert np.all( + result[3:10, 0:9, 0:6] == 1 + ) # Test if the shifts where going to the right direction diff --git a/mantis/tests/test_analysis/test_analysis_settings.py b/mantis/tests/test_analysis/test_analysis_settings.py index 018079da..299cd199 100644 --- a/mantis/tests/test_analysis/test_analysis_settings.py +++ b/mantis/tests/test_analysis/test_analysis_settings.py @@ -1,20 +1,28 @@ import numpy as np import pytest +from pydantic import ValidationError + from mantis.analysis.AnalysisSettings import DeskewSettings, RegistrationSettings def test_deskew_settings(): # Test extra parameter - with pytest.raises(TypeError): - DeskewSettings(typo_param="test") + with pytest.raises(ValidationError): + DeskewSettings( + pixel_size_um=0.116, ls_angle_deg=36, scan_step_um=0.313, typo_param="test" + ) # Test negative value - with pytest.raises(TypeError): - DeskewSettings(pixel_size_um=-3) + with pytest.raises(ValidationError): + DeskewSettings(pixel_size_um=-3, ls_angle_deg=36, scan_step_um=0.313) + + # Test light sheet angle range + with pytest.raises(ValueError): + DeskewSettings(pixel_size_um=0.116, ls_angle_deg=90, scan_step_um=0.313) # Test px_to_scan_ratio logic - with pytest.raises(TypeError): + with pytest.raises(ValueError): DeskewSettings(pixel_size_um=0.116, ls_angle_deg=36, scan_step_um=None) @@ -26,17 +34,31 @@ def test_example_deskew_settings(example_deskew_settings): def test_apply_affine_settings(): # Test extra parameter - with pytest.raises(TypeError): - RegistrationSettings(typo_param="test") + with pytest.raises(ValidationError): + RegistrationSettings( + source_channel_index=0, + target_channel_index=0, + affine_transform_zyx=np.identity(4).tolist(), + typo_param="test", + ) # Test wrong output shape size - with pytest.raises(TypeError): - RegistrationSettings(output_shape_zyx=[1, 2, 3, 4]) + with pytest.raises(ValidationError): + RegistrationSettings( + source_channel_index=0, + target_channel_index=0, + affine_transform_zyx=np.identity(4).tolist(), + typo_param="test", + ) # Test wrong matrix shape - with pytest.raises(TypeError): - random_array = np.random.rand(5, 5) - RegistrationSettings(affine_transform_zyx=random_array.tolist()) + with pytest.raises(ValidationError): + RegistrationSettings( + source_channel_index=0, + target_channel_index=0, + affine_transform_zyx=np.identity(5).tolist(), + typo_param="test", + ) def 
test_example_apply_affine_settings(example_apply_affine_settings): diff --git a/mantis/tests/test_cli/test_apply_affine_cli.py b/mantis/tests/test_cli/test_apply_affine_cli.py index 72e9d859..20fb2c46 100644 --- a/mantis/tests/test_cli/test_apply_affine_cli.py +++ b/mantis/tests/test_cli/test_apply_affine_cli.py @@ -3,12 +3,15 @@ from click.testing import CliRunner from numpy import testing -from mantis.cli.apply_affine import apply_affine_to_scale +from mantis.cli.apply_affine import rescale_voxel_size from mantis.cli.main import cli -def test_apply_affine_cli(tmp_path, example_plate, example_apply_affine_settings): +def test_apply_affine_cli( + tmp_path, example_plate, example_plate_2, example_apply_affine_settings +): plate_path, _ = example_plate + plate_path_2, _ = example_plate_2 config_path, _ = example_apply_affine_settings output_path = tmp_path / "output.zarr" @@ -17,8 +20,10 @@ def test_apply_affine_cli(tmp_path, example_plate, example_apply_affine_settings cli, [ "apply-affine", - "-i", + "-s", str(plate_path) + "/A/1/0", + "-t", + str(plate_path_2) + "/A/1/0", # test could be improved with different stores "-c", str(config_path), "-o", @@ -27,8 +32,8 @@ def test_apply_affine_cli(tmp_path, example_plate, example_apply_affine_settings catch_exceptions=False, ) - assert output_path.exists() assert result.exit_code == 0 + assert output_path.exists() def test_apply_affine_to_scale(): @@ -37,18 +42,18 @@ def test_apply_affine_to_scale(): # Test real positive m1_diag = np.array([2, 3, 4]) m1 = np.diag(m1_diag) - output1 = apply_affine_to_scale(m1, input) + output1 = rescale_voxel_size(m1, input) testing.assert_allclose(m1_diag, output1) # Test real with negative m2_diag = np.array([2, -3, 4]) m2 = np.diag(m2_diag) - output2 = apply_affine_to_scale(m2, input) + output2 = rescale_voxel_size(m2, input) testing.assert_allclose(np.abs(m2_diag), output2) # Test transpose m3 = np.array([[0, 2, 0], [1, 0, 0], [0, 0, 3]]) - output3 = apply_affine_to_scale(m3, input) + output3 = rescale_voxel_size(m3, input) testing.assert_allclose(np.array([2, 1, 3]), output3) # Test rotation @@ -60,5 +65,5 @@ def test_apply_affine_to_scale(): [0, 3 * np.sin(theta), 3 * np.cos(theta)], ] ) - output4 = apply_affine_to_scale(m4, input) + output4 = rescale_voxel_size(m4, input) testing.assert_allclose(np.array([2, 3, 3]), output4) diff --git a/mantis/tests/test_cli/test_estimate_phase_to_fluor_affine_cli.py b/mantis/tests/test_cli/test_estimate_phase_to_fluor_affine_cli.py deleted file mode 100644 index 7eae6f94..00000000 --- a/mantis/tests/test_cli/test_estimate_phase_to_fluor_affine_cli.py +++ /dev/null @@ -1,25 +0,0 @@ -from click.testing import CliRunner - -from mantis.cli.main import cli - - -def test_estimate_phase_to_fluor_affine_cli(tmp_path, example_plate): - plate_path, _ = example_plate - output_path = tmp_path / "config.yaml" - - runner = CliRunner() - result = runner.invoke( - cli, - [ - "estimate-phase-to-fluor-affine", - "-lf", - str(plate_path) + "/A/1/0", - "-ls", - str(plate_path) + "/B/1/0", # test could be improved with different stores - "-o", - str(output_path), - ], - ) - - # Weak test - assert "Enter phase_channel index to process" in result.output diff --git a/mantis/tests/test_cli/test_optimize_affine.py b/mantis/tests/test_cli/test_optimize_affine.py new file mode 100644 index 00000000..35f9b592 --- /dev/null +++ b/mantis/tests/test_cli/test_optimize_affine.py @@ -0,0 +1,31 @@ +from click.testing import CliRunner + +from mantis.cli.main import cli + + +def 
test_optimize_affine_cli(tmp_path, example_plate, example_apply_affine_settings): + plate_path, _ = example_plate + config_path, _ = example_apply_affine_settings + output_path = tmp_path / "config.yaml" + + runner = CliRunner() + result = runner.invoke( + cli, + [ + "optimize-affine", + "-s", + str(plate_path) + "/A/1/0", + "-t", + str(plate_path) + "/B/1/0", # test could be improved with different stores + "-c", + str(config_path), + "-o", + str(output_path), + ], + ) + + # Weak test + # NOTE: we changed the output of the function so this is no longer printed. Do we need to compare with something? + # assert "Getting dataset info" in result.output + assert result.exit_code == 0 + assert output_path.exists() diff --git a/pyproject.toml b/pyproject.toml index 14ffc2b7..2672f689 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,8 @@ dependencies = [ "slurmkit @ git+https://github.com/royerlab/slurmkit", "tifffile", "waveorder @ git+https://github.com/mehta-lab/waveorder", + "largestinteriorrectangle", + "antspyx", ]
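
Beyond the tests above, the registration pieces added in this diff can also be exercised directly from Python. The sketch below is illustrative only and is not part of the patch: the YAML path and zarr store paths are hypothetical, and it reuses only calls that appear in the changed files (yaml_to_model, convert_transform_to_ants, ants.from_numpy, apply_to_image).

import ants
import numpy as np

from iohub import open_ome_zarr

from mantis.analysis.AnalysisSettings import RegistrationSettings
from mantis.analysis.register import convert_transform_to_ants
from mantis.cli.utils import yaml_to_model

# Load the settings written by `mantis estimate-affine` / `mantis optimize-affine`
# (hypothetical path; any RegistrationSettings YAML with affine_transform_zyx works)
settings = yaml_to_model("./optimized_transform.yml", RegistrationSettings)
tx = convert_transform_to_ants(np.array(settings.affine_transform_zyx))

# Apply the transform to the first timepoint of one source position,
# resampling onto the target grid (store paths are hypothetical)
with open_ome_zarr("./source.zarr/A/1/0") as source_position:
    c_idx = source_position.channel_names.index(settings.source_channel_names[0])
    source_zyx = ants.from_numpy(source_position[0][0, c_idx].astype(np.float32))

with open_ome_zarr("./target.zarr/A/1/0") as target_position:
    target_zyx = ants.from_numpy(target_position[0][0, 0].astype(np.float32))

registered_zyx = tx.apply_to_image(source_zyx, reference=target_zyx).numpy()

This mirrors the pre-optimization step inside optimize-affine and can be a quick sanity check of a transform on a single position without writing a new zarr store.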