-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
isotropic_fluorescent_thick_3d
model (#128)
* typo * model outline * prototype transfer function * 3d phantom + visualize transfer function * refactor apply_transfer_function * typo * refactor padding (with gpt docs + tests) * complete example * test apply_inverse_transfer_function * TV reconstructions raise NotImplementedError * `pad_zyx` -> `pad_zyx_along_z` * Simplify `data += 10` * Update tests/test_util.py Co-authored-by: Ziwen Liu <[email protected]> --------- Co-authored-by: Ziwen Liu <[email protected]>
- Loading branch information
1 parent
c645c84
commit 1a707e1
Showing
8 changed files
with
368 additions
and
41 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import napari | ||
import numpy as np | ||
|
||
from waveorder.models import isotropic_fluorescent_thick_3d | ||
|
||
# Parameters | ||
# all lengths must use consistent units e.g. um | ||
simulation_arguments = { | ||
"zyx_shape": (100, 256, 256), | ||
"yx_pixel_size": 6.5 / 63, | ||
"z_pixel_size": 0.25, | ||
} | ||
phantom_arguments = {"sphere_radius": 5} | ||
transfer_function_arguments = { | ||
"wavelength_illumination": 0.532, | ||
"z_padding": 0, | ||
"index_of_refraction_media": 1.3, | ||
"numerical_aperture_detection": 1.2, | ||
} | ||
|
||
# Create a phantom | ||
zyx_fluorescence_density = ( | ||
isotropic_fluorescent_thick_3d.generate_test_phantom( | ||
**simulation_arguments, **phantom_arguments | ||
) | ||
) | ||
|
||
# Calculate transfer function | ||
optical_transfer_function = ( | ||
isotropic_fluorescent_thick_3d.calculate_transfer_function( | ||
**simulation_arguments, **transfer_function_arguments | ||
) | ||
) | ||
|
||
# Display transfer function | ||
viewer = napari.Viewer() | ||
zyx_scale = np.array( | ||
[ | ||
simulation_arguments["z_pixel_size"], | ||
simulation_arguments["yx_pixel_size"], | ||
simulation_arguments["yx_pixel_size"], | ||
] | ||
) | ||
isotropic_fluorescent_thick_3d.visualize_transfer_function( | ||
viewer, | ||
optical_transfer_function, | ||
zyx_scale, | ||
) | ||
input("Showing OTFs. Press <enter> to continue...") | ||
viewer.layers.select_all() | ||
viewer.layers.remove_selected() | ||
|
||
# Simulate | ||
zyx_data = isotropic_fluorescent_thick_3d.apply_transfer_function( | ||
zyx_fluorescence_density, | ||
optical_transfer_function, | ||
transfer_function_arguments["z_padding"], | ||
) | ||
|
||
# Reconstruct | ||
zyx_recon = isotropic_fluorescent_thick_3d.apply_inverse_transfer_function( | ||
zyx_data, | ||
optical_transfer_function, | ||
transfer_function_arguments["z_padding"], | ||
) | ||
|
||
# Display | ||
viewer.add_image( | ||
zyx_fluorescence_density.numpy(), name="Phantom", scale=zyx_scale | ||
) | ||
viewer.add_image(zyx_data.numpy(), name="Data", scale=zyx_scale) | ||
viewer.add_image(zyx_recon.numpy(), name="Reconstruction", scale=zyx_scale) | ||
input("Showing object, data, and recon. Press <enter> to quit...") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import pytest | ||
import torch | ||
from waveorder.models import isotropic_fluorescent_thick_3d | ||
|
||
|
||
@pytest.mark.parametrize("axial_flip", (True, False)) | ||
def test_calculate_transfer_function(axial_flip): | ||
z_padding = 5 | ||
transfer_function = ( | ||
isotropic_fluorescent_thick_3d.calculate_transfer_function( | ||
zyx_shape=(20, 100, 101), | ||
yx_pixel_size=6.5 / 40, | ||
z_pixel_size=2, | ||
wavelength_illumination=0.5, | ||
z_padding=z_padding, | ||
index_of_refraction_media=1.0, | ||
numerical_aperture_detection=0.55, | ||
axial_flip=axial_flip, | ||
) | ||
) | ||
|
||
assert transfer_function.shape == (20 + 2 * z_padding, 100, 101) | ||
|
||
|
||
def test_apply_inverse_transfer_function(): | ||
# Create sample data | ||
zyx_data = torch.randn(10, 5, 5) | ||
z_padding = 2 | ||
optical_transfer_function = torch.randn(10 + 2 * z_padding, 5, 5) | ||
|
||
# Test Tikhonov method | ||
result_tikhonov = ( | ||
isotropic_fluorescent_thick_3d.apply_inverse_transfer_function( | ||
zyx_data, | ||
optical_transfer_function, | ||
z_padding, | ||
method="Tikhonov", | ||
reg_re=1e-3, | ||
) | ||
) | ||
assert result_tikhonov.shape == (10, 5, 5) | ||
|
||
# TODO: Fix TV method | ||
# result_tv = isotropic_fluorescent_thick_3d.apply_inverse_transfer_function( | ||
# zyx_data, | ||
# optical_transfer_function, | ||
# z_padding, | ||
# method="TV", | ||
# reg_re=1e-3, | ||
# rho=1e-3, | ||
# itr=10, | ||
# ) | ||
# assert result_tv.shape == (10, 5, 5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import torch | ||
from waveorder import optics, util | ||
|
||
|
||
def generate_test_phantom( | ||
zyx_shape, | ||
yx_pixel_size, | ||
z_pixel_size, | ||
sphere_radius, | ||
): | ||
sphere, _, _ = util.generate_sphere_target( | ||
zyx_shape, yx_pixel_size, z_pixel_size, sphere_radius | ||
) | ||
|
||
return sphere | ||
|
||
|
||
def calculate_transfer_function( | ||
zyx_shape, | ||
yx_pixel_size, | ||
z_pixel_size, | ||
wavelength_illumination, | ||
z_padding, | ||
index_of_refraction_media, | ||
numerical_aperture_detection, | ||
axial_flip=False, | ||
): | ||
radial_frequencies = util.generate_radial_frequencies( | ||
zyx_shape[1:], yx_pixel_size | ||
) | ||
|
||
z_total = zyx_shape[0] + 2 * z_padding | ||
z_position_list = torch.fft.ifftshift( | ||
(torch.arange(z_total) - z_total // 2) * z_pixel_size | ||
) | ||
if axial_flip: | ||
z_position_list = torch.flip(z_position_list, dims=(0,)) | ||
|
||
det_pupil = optics.generate_pupil( | ||
radial_frequencies, | ||
numerical_aperture_detection, | ||
wavelength_illumination, | ||
) | ||
|
||
propagation_kernel = optics.generate_propagation_kernel( | ||
radial_frequencies, | ||
det_pupil, | ||
wavelength_illumination / index_of_refraction_media, | ||
z_position_list, | ||
) | ||
|
||
point_spread_function = ( | ||
torch.abs(torch.fft.ifft2(propagation_kernel, dim=(1, 2))) ** 2 | ||
) | ||
optical_transfer_function = torch.fft.fftn( | ||
point_spread_function, dim=(0, 1, 2) | ||
) | ||
optical_transfer_function /= torch.max( | ||
torch.abs(optical_transfer_function) | ||
) # normalize | ||
|
||
return optical_transfer_function | ||
|
||
|
||
def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale): | ||
arrays = [ | ||
(torch.imag(optical_transfer_function), "Im(OTF)"), | ||
(torch.real(optical_transfer_function), "Re(OTF)"), | ||
] | ||
|
||
for array in arrays: | ||
lim = 0.1 * torch.max(torch.abs(array[0])) | ||
viewer.add_image( | ||
torch.fft.ifftshift(array[0]).cpu().numpy(), | ||
name=array[1], | ||
colormap="bwr", | ||
contrast_limits=(-lim, lim), | ||
scale=1 / zyx_scale, | ||
) | ||
viewer.dims.order = (0, 1, 2) | ||
|
||
|
||
def apply_transfer_function(zyx_object, optical_transfer_function, z_padding): | ||
if ( | ||
zyx_object.shape[0] + 2 * z_padding | ||
!= optical_transfer_function.shape[0] | ||
): | ||
raise ValueError( | ||
"Please check padding: ZYX_obj.shape[0] + 2 * Z_pad != H_re.shape[0]" | ||
) | ||
if z_padding > 0: | ||
optical_transfer_function = optical_transfer_function[ | ||
z_padding:-z_padding | ||
] | ||
|
||
# Very simple simulation, consider adding noise and bkg knobs | ||
zyx_obj_hat = torch.fft.fftn(zyx_object) | ||
zyx_data = zyx_obj_hat * optical_transfer_function | ||
data = torch.real(torch.fft.ifftn(zyx_data)) | ||
|
||
data += 10 # Add a direct background | ||
return data | ||
|
||
|
||
def apply_inverse_transfer_function( | ||
zyx_data, | ||
optical_transfer_function, | ||
z_padding, | ||
method="Tikhonov", | ||
reg_re=1e-3, | ||
rho=1e-3, | ||
itr=10, | ||
): | ||
# Handle padding | ||
zyx_padded = util.pad_zyx_along_z(zyx_data, z_padding) | ||
|
||
# Reconstruct | ||
if method == "Tikhonov": | ||
f_real = util.single_variable_tikhonov_deconvolution_3D( | ||
zyx_padded, optical_transfer_function, reg_re=reg_re | ||
) | ||
|
||
elif method == "TV": | ||
raise NotImplementedError | ||
f_real = util.single_variable_admm_tv_deconvolution_3D( | ||
zyx_padded, | ||
optical_transfer_function, | ||
reg_re=reg_re, | ||
rho=rho, | ||
itr=itr, | ||
) | ||
|
||
# Unpad | ||
if z_padding != 0: | ||
f_real = f_real[z_padding:-z_padding] | ||
|
||
return f_real |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.