diff --git a/recOrder/cli/apply_inverse_models.py b/recOrder/cli/apply_inverse_models.py index 98de1b03..f2fe8fb7 100644 --- a/recOrder/cli/apply_inverse_models.py +++ b/recOrder/cli/apply_inverse_models.py @@ -5,11 +5,13 @@ import numpy as np import torch from waveorder.models import ( + inplane_oriented_thick_pol3d_vector, inplane_oriented_thick_pol3d, isotropic_fluorescent_thick_3d, isotropic_thin_3d, phase_thick_3d, ) +from waveorder.stokes import stokes_after_adr, _s12_to_orientation def radians_to_nanometers(retardance_rad, wavelength_illumination_um): @@ -219,10 +221,52 @@ def birefringence_and_phase( retardance = radians_to_nanometers( reconstructed_parameters_3d[0], wavelength_illumination ) + # Load singular system + U = torch.tensor( + np.array(transfer_function_dataset["singular_system_U"]) + ) + S = torch.tensor( + np.array(transfer_function_dataset["singular_system_S"][0]) + ) + Vh = torch.tensor( + np.array(transfer_function_dataset["singular_system_Vh"]) + ) + singular_system = (U, S, Vh) + + # Convert retardance and orientation to stokes + stokes = stokes_after_adr(*reconstructed_parameters_3d) + + stokes = torch.nan_to_num_(torch.stack(stokes), nan=0.0) # very rare nans from previous like + + # Apply reconstruction + joint_recon_params = inplane_oriented_thick_pol3d_vector.apply_inverse_transfer_function( + szyx_data=stokes, + singular_system=singular_system, + intensity_to_stokes_matrix=None, + **settings_phase.apply_inverse.dict(), + ) + + new_ret = ( + joint_recon_params[1] ** 2 + joint_recon_params[2] ** 2 + ) ** (0.5) + new_ori = _s12_to_orientation( + joint_recon_params[1], -joint_recon_params[2] + ) + + # Convert stokes to retardance and orientation + # new_ret, new_ori, _ = estimate_ar_from_stokes012(*joint_recon_params) + + # Convert retardance + new_ret_nm = radians_to_nanometers(new_ret, wavelength_illumination) # Save output = torch.stack( - (retardance,) + reconstructed_parameters_3d[1:] + (zyx_phase,) + (retardance,) + + reconstructed_parameters_3d[1:] + + (zyx_phase,) + + (new_ret_nm,) + + (new_ori,) + + (joint_recon_params[0],) ) return output diff --git a/recOrder/cli/apply_inverse_transfer_function.py b/recOrder/cli/apply_inverse_transfer_function.py index b4f97a38..d8b99eaf 100644 --- a/recOrder/cli/apply_inverse_transfer_function.py +++ b/recOrder/cli/apply_inverse_transfer_function.py @@ -77,6 +77,10 @@ def get_reconstruction_output_metadata(position_path: Path, config_path: Path): channel_names.append("Phase2D") elif recon_dim == 3: channel_names.append("Phase3D") + if recon_biref and recon_phase: + channel_names.append("Retardance_Joint_Decon") + channel_names.append("Orientation_Joint_Decon") + channel_names.append("Phase_Joint_Decon") if recon_fluo: fluor_name = settings.input_channel_names[0] if recon_dim == 2: @@ -313,7 +317,7 @@ def apply_inverse_transfer_function_cli( settings = utils.yaml_to_model(config_filepath, ReconstructionSettings) gb_ram_request = 0 - gb_per_element = 4 / 2**30 # bytes_per_float32 / bytes_per_gb + gb_per_element = 4 / 2 ** 30 # bytes_per_float32 / bytes_per_gb voxel_resource_multiplier = 4 fourier_resource_multiplier = 32 input_memory = Z * Y * X * gb_per_element @@ -336,13 +340,13 @@ def apply_inverse_transfer_function_cli( f"{cpu_request} CPU{'s' if cpu_request > 1 else ''} and " f"{gb_ram_request} GB of memory per CPU." ) - executor = submitit.AutoExecutor(folder="logs") + executor = submitit.AutoExecutor(folder="logs") #, cluster="debug") executor.update_parameters( slurm_array_parallelism=np.min([50, num_jobs]), slurm_mem_per_cpu=f"{gb_ram_request}G", slurm_cpus_per_task=cpu_request, - slurm_time=60, + slurm_time=600, slurm_partition="cpu", # more slurm_*** resource parameters here ) diff --git a/recOrder/cli/compute_transfer_function.py b/recOrder/cli/compute_transfer_function.py index 1b31a812..5e8e992d 100644 --- a/recOrder/cli/compute_transfer_function.py +++ b/recOrder/cli/compute_transfer_function.py @@ -4,6 +4,7 @@ import numpy as np from iohub.ngff import open_ome_zarr, Position from waveorder.models import ( + inplane_oriented_thick_pol3d_vector, inplane_oriented_thick_pol3d, isotropic_fluorescent_thick_3d, isotropic_thin_3d, @@ -20,6 +21,70 @@ from recOrder.io import utils +def generate_and_save_vector_birefringence_transfer_function( + settings: ReconstructionSettings, dataset: Position, zyx_shape: tuple +): + """Generates and saves the vector birefringence transfer function + to the dataset, based on the settings. + + Parameters + ---------- + settings : ReconstructionSettings + dataset : NGFF Node + The dataset that will be updated. + zyx_shape : tuple + A tuple of integers specifying the input data's shape in (Z, Y, X) order + """ + echo_headline( + "Generating vector birefringence transfer function with settings:" + ) + echo_settings(settings.birefringence.transfer_function) + echo_settings(settings.phase.transfer_function) + + num_elements = np.array(zyx_shape).prod() + max_tf_elements = 1e7 # empirical, based on memory usage + transverse_downsample_factor = np.ceil(np.sqrt(num_elements / max_tf_elements)) + echo_headline(f"Downsampling transfer function in X and Y by {transverse_downsample_factor}x") + + sfZYX_transfer_function, _, singular_system= ( + inplane_oriented_thick_pol3d_vector.calculate_transfer_function( + zyx_shape=zyx_shape, + scheme=str(len(settings.input_channel_names)) + "-State", + **settings.birefringence.transfer_function.dict(), + **settings.phase.transfer_function.dict(), + transverse_downsample_factor=transverse_downsample_factor, + ) + ) + + U, S, Vh = singular_system + chunks = (1, 1, 1, zyx_shape[1], zyx_shape[2]) + + # Add dummy channels + for i in range(3): + dataset.append_channel(f"ch{i}") + + dataset.create_image( + "vector_transfer_function", + sfZYX_transfer_function.cpu().numpy(), + chunks=chunks, + ) + dataset.create_image( + "singular_system_U", + U.cpu().numpy(), + chunks=chunks, + ) + dataset.create_image( + "singular_system_S", + S[None].cpu().numpy(), + chunks=chunks, + ) + dataset.create_image( + "singular_system_Vh", + Vh.cpu().numpy(), + chunks=chunks, + ) + + def generate_and_save_birefringence_transfer_function(settings, dataset): """Generates and saves the birefringence transfer function to the dataset, based on the settings. @@ -40,9 +105,9 @@ def generate_and_save_birefringence_transfer_function(settings, dataset): ) ) # Save - dataset[ - "intensity_to_stokes_matrix" - ] = intensity_to_stokes_matrix.cpu().numpy()[None, None, None, ...] + dataset["intensity_to_stokes_matrix"] = ( + intensity_to_stokes_matrix.cpu().numpy()[None, None, None, ...] + ) def generate_and_save_phase_transfer_function( @@ -200,6 +265,10 @@ def compute_transfer_function_cli( generate_and_save_fluorescence_transfer_function( settings, output_dataset, zyx_shape ) + if settings.birefringence is not None and settings.phase is not None: + generate_and_save_vector_birefringence_transfer_function( + settings, output_dataset, zyx_shape + ) # Write settings to metadata output_dataset.zattrs["settings"] = settings.dict() diff --git a/recOrder/io/utils.py b/recOrder/io/utils.py index c29769cc..81165502 100644 --- a/recOrder/io/utils.py +++ b/recOrder/io/utils.py @@ -8,7 +8,6 @@ from iohub import open_ome_zarr - def add_index_to_path(path: Path): """Takes a path to a file or folder and appends the smallest index that does not already exist in that folder. diff --git a/recOrder/io/visualization.py b/recOrder/io/visualization.py index 524b131d..11469ea5 100644 --- a/recOrder/io/visualization.py +++ b/recOrder/io/visualization.py @@ -51,18 +51,21 @@ def ret_ori_overlay( overlay_final = np.zeros_like(retardance) if cmap == "JCh": - J = ret_ - C = np.ones_like(J) * 60 + J_MAX = 65 + C_MAX = 60 + + J = (ret_ / ret_max) * J_MAX + C = np.ones_like(J) * C_MAX C[ret_ < ret_min] = 0 h = ori_ JCh = np.stack((J, C, h), axis=-1) - JCh_rgb = cspace_convert(JCh, "JCh", "sRGB1") + JCh_rgb = cspace_convert(JCh, "JCh", "sRGB255") JCh_rgb[JCh_rgb < 0] = 0 - JCh_rgb[JCh_rgb > 1] = 1 + JCh_rgb[JCh_rgb > 255] = 255 - overlay_final = JCh_rgb + overlay_final = JCh_rgb.astype(np.uint8) elif cmap == "HSV": I_hsv = np.moveaxis( np.stack( diff --git a/setup.cfg b/setup.cfg index f36021c7..9b705ac8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,7 @@ include_package_data = True python_requires = >=3.10 setup_requires = setuptools_scm install_requires = - waveorder==2.2.0rc0 + waveorder @ git+https://github.com/mehta-lab/waveorder.git@1cb7d53a135771368e26065d9e427535e0475858 click>=8.0.1 natsort>=7.1.1 colorspacious>=1.1.2