From 4a20aebd91c9a4dae54a6a5f28585686af76d3d0 Mon Sep 17 00:00:00 2001 From: Talon Chandler Date: Tue, 23 Apr 2024 14:25:46 -0700 Subject: [PATCH] Phase reconstruction is invariant to voxel-size (#164) * fix bug finding focus in stack with only one slice * refactor for clarify * formatting * print -> warnings.warn * test single-slice case * fix test bugs * z-scale-invariant test object * no rescaling on output * forward simulation takes a "brightness" - simulating real microscope * fix example script * add background parameter for fluorescence forward model * test voxel-size invariance * rename I_norm -> direct_intensity * refactor to clarify discretization factor * remove comment * fix fluorescence example bug * improved docsring --------- Co-authored-by: Ivan Ivanov --- .../models/isotropic_fluorescent_thick_3d.py | 4 +- examples/models/phase_thick_3d.py | 5 +- tests/models/test_phase_thick_3d.py | 81 ++++++++++++++++++- .../models/isotropic_fluorescent_thick_3d.py | 21 ++++- waveorder/models/phase_thick_3d.py | 35 ++++---- waveorder/optics.py | 14 ++-- 6 files changed, 126 insertions(+), 34 deletions(-) diff --git a/examples/models/isotropic_fluorescent_thick_3d.py b/examples/models/isotropic_fluorescent_thick_3d.py index 21a68169..01840798 100644 --- a/examples/models/isotropic_fluorescent_thick_3d.py +++ b/examples/models/isotropic_fluorescent_thick_3d.py @@ -6,13 +6,13 @@ # Parameters # all lengths must use consistent units e.g. um simulation_arguments = { - "zyx_shape": (100, 256, 256), + "zyx_shape": (200, 256, 256), "yx_pixel_size": 6.5 / 63, "z_pixel_size": 0.25, } phantom_arguments = {"sphere_radius": 5} transfer_function_arguments = { - "wavelength_illumination": 0.532, + "wavelength_emission": 0.532, "z_padding": 0, "index_of_refraction_media": 1.3, "numerical_aperture_detection": 1.2, diff --git a/examples/models/phase_thick_3d.py b/examples/models/phase_thick_3d.py index aa9bd4aa..54d7e43d 100644 --- a/examples/models/phase_thick_3d.py +++ b/examples/models/phase_thick_3d.py @@ -14,12 +14,12 @@ "zyx_shape": (100, 256, 256), "yx_pixel_size": 6.5 / 63, "z_pixel_size": 0.25, - "wavelength_illumination": 0.532, "index_of_refraction_media": 1.3, } phantom_arguments = {"index_of_refraction_sample": 1.50, "sphere_radius": 5} transfer_function_arguments = { "z_padding": 0, + "wavelength_illumination": 0.532, "numerical_aperture_illumination": 0.9, "numerical_aperture_detection": 1.2, } @@ -61,6 +61,7 @@ zyx_phase, real_potential_transfer_function, transfer_function_arguments["z_padding"], + brightness=1e3, ) # Reconstruct @@ -69,8 +70,6 @@ real_potential_transfer_function, imag_potential_transfer_function, transfer_function_arguments["z_padding"], - simulation_arguments["z_pixel_size"], - simulation_arguments["wavelength_illumination"], ) # Display diff --git a/tests/models/test_phase_thick_3d.py b/tests/models/test_phase_thick_3d.py index 224c96c6..d60c7fa6 100644 --- a/tests/models/test_phase_thick_3d.py +++ b/tests/models/test_phase_thick_3d.py @@ -1,5 +1,5 @@ import pytest - +import numpy as np from waveorder.models import phase_thick_3d @@ -20,3 +20,82 @@ def test_calculate_transfer_function(invert_phase_contrast): assert H_re.shape == (20 + 2 * z_padding, 100, 101) assert H_im.shape == (20 + 2 * z_padding, 100, 101) + + +# Helper function for testing reconstruction invariances +def simulate_phase_recon( + z_pixel_size_um=0.1, + yx_pixel_size_um=6.5 / 63, +): + + z_fov_um = 50 + yx_fov_um = 50 + + n_z = np.int32(z_fov_um / z_pixel_size_um) + n_yx = np.int32(yx_fov_um / yx_pixel_size_um) + + # Parameters + # all lengths must use consistent units e.g. um + simulation_arguments = { + "zyx_shape": (n_z, n_yx, n_yx), + "yx_pixel_size": yx_pixel_size_um, + "z_pixel_size": z_pixel_size_um, + "index_of_refraction_media": 1.3, + } + phantom_arguments = { + "index_of_refraction_sample": 1.40, + "sphere_radius": 5, + } + transfer_function_arguments = { + "z_padding": 0, + "wavelength_illumination": 0.532, + "numerical_aperture_illumination": 0.9, + "numerical_aperture_detection": 1.3, + } + + # Create a phantom + zyx_phase = phase_thick_3d.generate_test_phantom( + **simulation_arguments, **phantom_arguments + ) + + # Calculate transfer function + ( + real_potential_transfer_function, + imag_potential_transfer_function, + ) = phase_thick_3d.calculate_transfer_function( + **simulation_arguments, **transfer_function_arguments + ) + + # Simulate + zyx_data = phase_thick_3d.apply_transfer_function( + zyx_phase, + real_potential_transfer_function, + transfer_function_arguments["z_padding"], + brightness=1000, + ) + + # Reconstruct + zyx_recon = phase_thick_3d.apply_inverse_transfer_function( + zyx_data, + real_potential_transfer_function, + imag_potential_transfer_function, + transfer_function_arguments["z_padding"], + regularization_strength=1e-3, + ) + + Z, Y, X = zyx_phase.shape + recon_center = zyx_recon[Z // 2, Y // 2, X // 2].numpy() + + return recon_center + + +def test_phase_invariance(): + recon = simulate_phase_recon() + + # test z pixel size invariance + recon1 = simulate_phase_recon(z_pixel_size_um=0.3) + assert np.abs((recon1 - recon) / recon) < 0.02 + + # test yx pixel size invariance + recon2 = simulate_phase_recon(yx_pixel_size_um=0.7 * 6.5 / 63) + assert np.abs((recon2 - recon) / recon) < 0.02 \ No newline at end of file diff --git a/waveorder/models/isotropic_fluorescent_thick_3d.py b/waveorder/models/isotropic_fluorescent_thick_3d.py index 58234852..b52e71b5 100644 --- a/waveorder/models/isotropic_fluorescent_thick_3d.py +++ b/waveorder/models/isotropic_fluorescent_thick_3d.py @@ -81,7 +81,24 @@ def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale): viewer.dims.order = (0, 1, 2) -def apply_transfer_function(zyx_object, optical_transfer_function, z_padding): +def apply_transfer_function( + zyx_object, optical_transfer_function, z_padding, background=10 +): + """Simulate imaging by applying a transfer function + + Parameters + ---------- + zyx_object : torch.Tensor + optical_transfer_function : torch.Tensor + z_padding : int + background : int, optional + constant background counts added to each voxel, by default 10 + + Returns + ------- + Simulated data : torch.Tensor + + """ if ( zyx_object.shape[0] + 2 * z_padding != optical_transfer_function.shape[0] @@ -99,7 +116,7 @@ def apply_transfer_function(zyx_object, optical_transfer_function, z_padding): zyx_data = zyx_obj_hat * optical_transfer_function data = torch.real(torch.fft.ifftn(zyx_data)) - data += 10 # Add a direct background + data += background # Add a direct background return data diff --git a/waveorder/models/phase_thick_3d.py b/waveorder/models/phase_thick_3d.py index 39555b50..ac29d356 100644 --- a/waveorder/models/phase_thick_3d.py +++ b/waveorder/models/phase_thick_3d.py @@ -12,7 +12,6 @@ def generate_test_phantom( zyx_shape, yx_pixel_size, z_pixel_size, - wavelength_illumination, index_of_refraction_media, index_of_refraction_sample, sphere_radius, @@ -24,12 +23,9 @@ def generate_test_phantom( radius=sphere_radius, blur_size=2 * yx_pixel_size, ) - zyx_phase = ( - sphere - * (index_of_refraction_sample - index_of_refraction_media) - * z_pixel_size - / wavelength_illumination - ) # phase in radians + zyx_phase = sphere * ( + index_of_refraction_sample - index_of_refraction_media + ) # refractive index increment return zyx_phase @@ -120,12 +116,19 @@ def visualize_transfer_function( def apply_transfer_function( - zyx_object, real_potential_transfer_function, z_padding + zyx_object, real_potential_transfer_function, z_padding, brightness ): # This simplified forward model only handles phase, so it resuses the fluorescence forward model # TODO: extend to absorption - return isotropic_fluorescent_thick_3d.apply_transfer_function( - zyx_object, real_potential_transfer_function, z_padding + return ( + isotropic_fluorescent_thick_3d.apply_transfer_function( + zyx_object, + real_potential_transfer_function, + z_padding, + background=0, + ) + * brightness + + brightness ) @@ -134,8 +137,6 @@ def apply_inverse_transfer_function( real_potential_transfer_function: Tensor, imaginary_potential_transfer_function: Tensor, z_padding: int, - z_pixel_size: float, # TODO: MOVE THIS PARAM TO OTF? (leaky param) - wavelength_illumination: float, # TOOD: MOVE THIS PARAM TO OTF? (leaky param) absorption_ratio: float = 0.0, reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov", regularization_strength: float = 1e-3, @@ -158,14 +159,6 @@ def apply_inverse_transfer_function( z_padding : int Padding for axial dimension. Use zero for defocus stacks that extend ~3 PSF widths beyond the sample. Pad by ~3 PSF widths otherwise. - z_pixel_size : float - spacing between axial samples in sample space - units must be consistent with wavelength_illumination - TODO: move this leaky parameter to calculate_transfer_function - wavelength_illumination : float, - illumination wavelength - units must be consistent with z_pixel_size - TODO: move this leaky parameter to calculate_transfer_function absorption_ratio : float, optional, Absorption-to-phase ratio in the sample. Use default 0 for purely phase objects. @@ -223,4 +216,4 @@ def apply_inverse_transfer_function( if z_padding != 0: f_real = f_real[z_padding:-z_padding] - return f_real * z_pixel_size / 4 / np.pi * wavelength_illumination + return f_real diff --git a/waveorder/optics.py b/waveorder/optics.py index 8033ab86..4151b9de 100644 --- a/waveorder/optics.py +++ b/waveorder/optics.py @@ -739,19 +739,23 @@ def compute_weak_object_transfer_function_3D( H1 = torch.fft.ifft2(torch.conj(SPHz_hat) * PG_hat, dim=(1, 2)) H1 = H1 * window[:, None, None] - H1 = torch.fft.fft(H1, dim=0) * z_pixel_size + H1 = torch.fft.fft(H1, dim=0) H2 = torch.fft.ifft2(SPHz_hat * torch.conj(PG_hat), dim=(1, 2)) H2 = H2 * window[:, None, None] - H2 = torch.fft.fft(H2, dim=0) * z_pixel_size + H2 = torch.fft.fft(H2, dim=0) - I_norm = torch.sum( + direct_intensity = torch.sum( illumination_pupil_support * detection_pupil * torch.conj(detection_pupil) ) - real_potential_transfer_function = (H1 + H2) / I_norm - imag_potential_transfer_function = 1j * (H1 - H2) / I_norm + real_potential_transfer_function = (H1 + H2) / direct_intensity + imag_potential_transfer_function = 1j * (H1 - H2) / direct_intensity + + # Discretization factor for unitless input and output + real_potential_transfer_function *= z_pixel_size + imag_potential_transfer_function *= z_pixel_size return real_potential_transfer_function, imag_potential_transfer_function