Skip to content

Commit

Permalink
Phase reconstruction is invariant to voxel-size (#164)
Browse files Browse the repository at this point in the history
* fix bug finding focus in stack with only one slice

* refactor for clarify

* formatting

* print -> warnings.warn

* test single-slice case

* fix test bugs

* z-scale-invariant test object

* no rescaling on output

* forward simulation takes a "brightness" - simulating real microscope

* fix example script

* add background parameter for fluorescence forward model

* test voxel-size invariance

* rename I_norm -> direct_intensity

* refactor to clarify discretization factor

* remove comment

* fix fluorescence example bug

* improved docsring

---------

Co-authored-by: Ivan Ivanov <[email protected]>
  • Loading branch information
talonchandler and ieivanov authored Apr 23, 2024
1 parent d08f296 commit 4a20aeb
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 34 deletions.
4 changes: 2 additions & 2 deletions examples/models/isotropic_fluorescent_thick_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
# Parameters
# all lengths must use consistent units e.g. um
simulation_arguments = {
"zyx_shape": (100, 256, 256),
"zyx_shape": (200, 256, 256),
"yx_pixel_size": 6.5 / 63,
"z_pixel_size": 0.25,
}
phantom_arguments = {"sphere_radius": 5}
transfer_function_arguments = {
"wavelength_illumination": 0.532,
"wavelength_emission": 0.532,
"z_padding": 0,
"index_of_refraction_media": 1.3,
"numerical_aperture_detection": 1.2,
Expand Down
5 changes: 2 additions & 3 deletions examples/models/phase_thick_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
"zyx_shape": (100, 256, 256),
"yx_pixel_size": 6.5 / 63,
"z_pixel_size": 0.25,
"wavelength_illumination": 0.532,
"index_of_refraction_media": 1.3,
}
phantom_arguments = {"index_of_refraction_sample": 1.50, "sphere_radius": 5}
transfer_function_arguments = {
"z_padding": 0,
"wavelength_illumination": 0.532,
"numerical_aperture_illumination": 0.9,
"numerical_aperture_detection": 1.2,
}
Expand Down Expand Up @@ -61,6 +61,7 @@
zyx_phase,
real_potential_transfer_function,
transfer_function_arguments["z_padding"],
brightness=1e3,
)

# Reconstruct
Expand All @@ -69,8 +70,6 @@
real_potential_transfer_function,
imag_potential_transfer_function,
transfer_function_arguments["z_padding"],
simulation_arguments["z_pixel_size"],
simulation_arguments["wavelength_illumination"],
)

# Display
Expand Down
81 changes: 80 additions & 1 deletion tests/models/test_phase_thick_3d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest

import numpy as np
from waveorder.models import phase_thick_3d


Expand All @@ -20,3 +20,82 @@ def test_calculate_transfer_function(invert_phase_contrast):

assert H_re.shape == (20 + 2 * z_padding, 100, 101)
assert H_im.shape == (20 + 2 * z_padding, 100, 101)


# Helper function for testing reconstruction invariances
def simulate_phase_recon(
z_pixel_size_um=0.1,
yx_pixel_size_um=6.5 / 63,
):

z_fov_um = 50
yx_fov_um = 50

n_z = np.int32(z_fov_um / z_pixel_size_um)
n_yx = np.int32(yx_fov_um / yx_pixel_size_um)

# Parameters
# all lengths must use consistent units e.g. um
simulation_arguments = {
"zyx_shape": (n_z, n_yx, n_yx),
"yx_pixel_size": yx_pixel_size_um,
"z_pixel_size": z_pixel_size_um,
"index_of_refraction_media": 1.3,
}
phantom_arguments = {
"index_of_refraction_sample": 1.40,
"sphere_radius": 5,
}
transfer_function_arguments = {
"z_padding": 0,
"wavelength_illumination": 0.532,
"numerical_aperture_illumination": 0.9,
"numerical_aperture_detection": 1.3,
}

# Create a phantom
zyx_phase = phase_thick_3d.generate_test_phantom(
**simulation_arguments, **phantom_arguments
)

# Calculate transfer function
(
real_potential_transfer_function,
imag_potential_transfer_function,
) = phase_thick_3d.calculate_transfer_function(
**simulation_arguments, **transfer_function_arguments
)

# Simulate
zyx_data = phase_thick_3d.apply_transfer_function(
zyx_phase,
real_potential_transfer_function,
transfer_function_arguments["z_padding"],
brightness=1000,
)

# Reconstruct
zyx_recon = phase_thick_3d.apply_inverse_transfer_function(
zyx_data,
real_potential_transfer_function,
imag_potential_transfer_function,
transfer_function_arguments["z_padding"],
regularization_strength=1e-3,
)

Z, Y, X = zyx_phase.shape
recon_center = zyx_recon[Z // 2, Y // 2, X // 2].numpy()

return recon_center


def test_phase_invariance():
recon = simulate_phase_recon()

# test z pixel size invariance
recon1 = simulate_phase_recon(z_pixel_size_um=0.3)
assert np.abs((recon1 - recon) / recon) < 0.02

# test yx pixel size invariance
recon2 = simulate_phase_recon(yx_pixel_size_um=0.7 * 6.5 / 63)
assert np.abs((recon2 - recon) / recon) < 0.02
21 changes: 19 additions & 2 deletions waveorder/models/isotropic_fluorescent_thick_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,24 @@ def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale):
viewer.dims.order = (0, 1, 2)


def apply_transfer_function(zyx_object, optical_transfer_function, z_padding):
def apply_transfer_function(
zyx_object, optical_transfer_function, z_padding, background=10
):
"""Simulate imaging by applying a transfer function
Parameters
----------
zyx_object : torch.Tensor
optical_transfer_function : torch.Tensor
z_padding : int
background : int, optional
constant background counts added to each voxel, by default 10
Returns
-------
Simulated data : torch.Tensor
"""
if (
zyx_object.shape[0] + 2 * z_padding
!= optical_transfer_function.shape[0]
Expand All @@ -99,7 +116,7 @@ def apply_transfer_function(zyx_object, optical_transfer_function, z_padding):
zyx_data = zyx_obj_hat * optical_transfer_function
data = torch.real(torch.fft.ifftn(zyx_data))

data += 10 # Add a direct background
data += background # Add a direct background
return data


Expand Down
35 changes: 14 additions & 21 deletions waveorder/models/phase_thick_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def generate_test_phantom(
zyx_shape,
yx_pixel_size,
z_pixel_size,
wavelength_illumination,
index_of_refraction_media,
index_of_refraction_sample,
sphere_radius,
Expand All @@ -24,12 +23,9 @@ def generate_test_phantom(
radius=sphere_radius,
blur_size=2 * yx_pixel_size,
)
zyx_phase = (
sphere
* (index_of_refraction_sample - index_of_refraction_media)
* z_pixel_size
/ wavelength_illumination
) # phase in radians
zyx_phase = sphere * (
index_of_refraction_sample - index_of_refraction_media
) # refractive index increment

return zyx_phase

Expand Down Expand Up @@ -120,12 +116,19 @@ def visualize_transfer_function(


def apply_transfer_function(
zyx_object, real_potential_transfer_function, z_padding
zyx_object, real_potential_transfer_function, z_padding, brightness
):
# This simplified forward model only handles phase, so it resuses the fluorescence forward model
# TODO: extend to absorption
return isotropic_fluorescent_thick_3d.apply_transfer_function(
zyx_object, real_potential_transfer_function, z_padding
return (
isotropic_fluorescent_thick_3d.apply_transfer_function(
zyx_object,
real_potential_transfer_function,
z_padding,
background=0,
)
* brightness
+ brightness
)


Expand All @@ -134,8 +137,6 @@ def apply_inverse_transfer_function(
real_potential_transfer_function: Tensor,
imaginary_potential_transfer_function: Tensor,
z_padding: int,
z_pixel_size: float, # TODO: MOVE THIS PARAM TO OTF? (leaky param)
wavelength_illumination: float, # TOOD: MOVE THIS PARAM TO OTF? (leaky param)
absorption_ratio: float = 0.0,
reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
regularization_strength: float = 1e-3,
Expand All @@ -158,14 +159,6 @@ def apply_inverse_transfer_function(
z_padding : int
Padding for axial dimension. Use zero for defocus stacks that
extend ~3 PSF widths beyond the sample. Pad by ~3 PSF widths otherwise.
z_pixel_size : float
spacing between axial samples in sample space
units must be consistent with wavelength_illumination
TODO: move this leaky parameter to calculate_transfer_function
wavelength_illumination : float,
illumination wavelength
units must be consistent with z_pixel_size
TODO: move this leaky parameter to calculate_transfer_function
absorption_ratio : float, optional,
Absorption-to-phase ratio in the sample.
Use default 0 for purely phase objects.
Expand Down Expand Up @@ -223,4 +216,4 @@ def apply_inverse_transfer_function(
if z_padding != 0:
f_real = f_real[z_padding:-z_padding]

return f_real * z_pixel_size / 4 / np.pi * wavelength_illumination
return f_real
14 changes: 9 additions & 5 deletions waveorder/optics.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,19 +739,23 @@ def compute_weak_object_transfer_function_3D(

H1 = torch.fft.ifft2(torch.conj(SPHz_hat) * PG_hat, dim=(1, 2))
H1 = H1 * window[:, None, None]
H1 = torch.fft.fft(H1, dim=0) * z_pixel_size
H1 = torch.fft.fft(H1, dim=0)

H2 = torch.fft.ifft2(SPHz_hat * torch.conj(PG_hat), dim=(1, 2))
H2 = H2 * window[:, None, None]
H2 = torch.fft.fft(H2, dim=0) * z_pixel_size
H2 = torch.fft.fft(H2, dim=0)

I_norm = torch.sum(
direct_intensity = torch.sum(
illumination_pupil_support
* detection_pupil
* torch.conj(detection_pupil)
)
real_potential_transfer_function = (H1 + H2) / I_norm
imag_potential_transfer_function = 1j * (H1 - H2) / I_norm
real_potential_transfer_function = (H1 + H2) / direct_intensity
imag_potential_transfer_function = 1j * (H1 - H2) / direct_intensity

# Discretization factor for unitless input and output
real_potential_transfer_function *= z_pixel_size
imag_potential_transfer_function *= z_pixel_size

return real_potential_transfer_function, imag_potential_transfer_function

Expand Down

0 comments on commit 4a20aeb

Please sign in to comment.