Skip to content

Commit

Permalink
isotropic_fluorescent_thick_3d model (#128)
Browse files Browse the repository at this point in the history
* typo

* model outline

* prototype transfer function

* 3d phantom + visualize transfer function

* refactor apply_transfer_function

* typo

* refactor padding (with gpt docs + tests)

* complete example

* test apply_inverse_transfer_function

* TV reconstructions raise NotImplementedError

* `pad_zyx` -> `pad_zyx_along_z`

* Simplify `data += 10`

* Update tests/test_util.py

Co-authored-by: Ziwen Liu <[email protected]>

---------

Co-authored-by: Ziwen Liu <[email protected]>
  • Loading branch information
talonchandler and ziw-liu authored Jun 16, 2023
1 parent c645c84 commit 1a707e1
Show file tree
Hide file tree
Showing 8 changed files with 368 additions and 41 deletions.
File renamed without changes.
73 changes: 73 additions & 0 deletions examples/models/isotropic_fluorescent_thick_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import napari
import numpy as np

from waveorder.models import isotropic_fluorescent_thick_3d

# Parameters
# all lengths must use consistent units e.g. um
simulation_arguments = {
"zyx_shape": (100, 256, 256),
"yx_pixel_size": 6.5 / 63,
"z_pixel_size": 0.25,
}
phantom_arguments = {"sphere_radius": 5}
transfer_function_arguments = {
"wavelength_illumination": 0.532,
"z_padding": 0,
"index_of_refraction_media": 1.3,
"numerical_aperture_detection": 1.2,
}

# Create a phantom
zyx_fluorescence_density = (
isotropic_fluorescent_thick_3d.generate_test_phantom(
**simulation_arguments, **phantom_arguments
)
)

# Calculate transfer function
optical_transfer_function = (
isotropic_fluorescent_thick_3d.calculate_transfer_function(
**simulation_arguments, **transfer_function_arguments
)
)

# Display transfer function
viewer = napari.Viewer()
zyx_scale = np.array(
[
simulation_arguments["z_pixel_size"],
simulation_arguments["yx_pixel_size"],
simulation_arguments["yx_pixel_size"],
]
)
isotropic_fluorescent_thick_3d.visualize_transfer_function(
viewer,
optical_transfer_function,
zyx_scale,
)
input("Showing OTFs. Press <enter> to continue...")
viewer.layers.select_all()
viewer.layers.remove_selected()

# Simulate
zyx_data = isotropic_fluorescent_thick_3d.apply_transfer_function(
zyx_fluorescence_density,
optical_transfer_function,
transfer_function_arguments["z_padding"],
)

# Reconstruct
zyx_recon = isotropic_fluorescent_thick_3d.apply_inverse_transfer_function(
zyx_data,
optical_transfer_function,
transfer_function_arguments["z_padding"],
)

# Display
viewer.add_image(
zyx_fluorescence_density.numpy(), name="Phantom", scale=zyx_scale
)
viewer.add_image(zyx_data.numpy(), name="Data", scale=zyx_scale)
viewer.add_image(zyx_recon.numpy(), name="Reconstruction", scale=zyx_scale)
input("Showing object, data, and recon. Press <enter> to quit...")
53 changes: 53 additions & 0 deletions tests/models/test_isotropic_fluorescent_thick_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
import torch
from waveorder.models import isotropic_fluorescent_thick_3d


@pytest.mark.parametrize("axial_flip", (True, False))
def test_calculate_transfer_function(axial_flip):
z_padding = 5
transfer_function = (
isotropic_fluorescent_thick_3d.calculate_transfer_function(
zyx_shape=(20, 100, 101),
yx_pixel_size=6.5 / 40,
z_pixel_size=2,
wavelength_illumination=0.5,
z_padding=z_padding,
index_of_refraction_media=1.0,
numerical_aperture_detection=0.55,
axial_flip=axial_flip,
)
)

assert transfer_function.shape == (20 + 2 * z_padding, 100, 101)


def test_apply_inverse_transfer_function():
# Create sample data
zyx_data = torch.randn(10, 5, 5)
z_padding = 2
optical_transfer_function = torch.randn(10 + 2 * z_padding, 5, 5)

# Test Tikhonov method
result_tikhonov = (
isotropic_fluorescent_thick_3d.apply_inverse_transfer_function(
zyx_data,
optical_transfer_function,
z_padding,
method="Tikhonov",
reg_re=1e-3,
)
)
assert result_tikhonov.shape == (10, 5, 5)

# TODO: Fix TV method
# result_tv = isotropic_fluorescent_thick_3d.apply_inverse_transfer_function(
# zyx_data,
# optical_transfer_function,
# z_padding,
# method="TV",
# reg_re=1e-3,
# rho=1e-3,
# itr=10,
# )
# assert result_tv.shape == (10, 5, 5)
36 changes: 36 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from waveorder import util
import torch
import pytest


def test_gen_coordinate():
Expand All @@ -8,3 +9,38 @@ def test_gen_coordinate():

assert frr.shape == YX_shape
assert frr[0, 0] == 0


# test util.pad_zyx function
@pytest.fixture
def zyx_data():
return torch.ones((3, 4, 5)) # Example input data


def test_pad_zyx_negative_padding():
zyx_data = torch.zeros((3, 4, 5))
z_padding = -1
with pytest.raises(Exception):
util.pad_zyx_along_z(zyx_data, z_padding)


def test_pad_zyx_no_padding(zyx_data):
z_padding = 0
result = util.pad_zyx_along_z(zyx_data, z_padding)
assert torch.all(result == zyx_data)


def test_pad_zyx_small_padding(zyx_data):
z_padding = 2
result = util.pad_zyx_along_z(zyx_data, z_padding)
assert result.shape == (7, 4, 5)
assert torch.all(result[:2] == torch.flip(zyx_data[:2], dims=[0]))
assert torch.all(result[-2:] == torch.flip(zyx_data[-2:], dims=[0]))


def test_pad_zyx_large_padding(zyx_data):
z_padding = 5
result = util.pad_zyx_along_z(zyx_data, z_padding)
assert result.shape == (13, 4, 5)
assert torch.all(result[:5] == 0)
assert torch.all(result[-5:] == 0)
137 changes: 137 additions & 0 deletions waveorder/models/isotropic_fluorescent_thick_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import torch
from waveorder import optics, util


def generate_test_phantom(
zyx_shape,
yx_pixel_size,
z_pixel_size,
sphere_radius,
):
sphere, _, _ = util.generate_sphere_target(
zyx_shape, yx_pixel_size, z_pixel_size, sphere_radius
)

return sphere


def calculate_transfer_function(
zyx_shape,
yx_pixel_size,
z_pixel_size,
wavelength_illumination,
z_padding,
index_of_refraction_media,
numerical_aperture_detection,
axial_flip=False,
):
radial_frequencies = util.generate_radial_frequencies(
zyx_shape[1:], yx_pixel_size
)

z_total = zyx_shape[0] + 2 * z_padding
z_position_list = torch.fft.ifftshift(
(torch.arange(z_total) - z_total // 2) * z_pixel_size
)
if axial_flip:
z_position_list = torch.flip(z_position_list, dims=(0,))

det_pupil = optics.generate_pupil(
radial_frequencies,
numerical_aperture_detection,
wavelength_illumination,
)

propagation_kernel = optics.generate_propagation_kernel(
radial_frequencies,
det_pupil,
wavelength_illumination / index_of_refraction_media,
z_position_list,
)

point_spread_function = (
torch.abs(torch.fft.ifft2(propagation_kernel, dim=(1, 2))) ** 2
)
optical_transfer_function = torch.fft.fftn(
point_spread_function, dim=(0, 1, 2)
)
optical_transfer_function /= torch.max(
torch.abs(optical_transfer_function)
) # normalize

return optical_transfer_function


def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale):
arrays = [
(torch.imag(optical_transfer_function), "Im(OTF)"),
(torch.real(optical_transfer_function), "Re(OTF)"),
]

for array in arrays:
lim = 0.1 * torch.max(torch.abs(array[0]))
viewer.add_image(
torch.fft.ifftshift(array[0]).cpu().numpy(),
name=array[1],
colormap="bwr",
contrast_limits=(-lim, lim),
scale=1 / zyx_scale,
)
viewer.dims.order = (0, 1, 2)


def apply_transfer_function(zyx_object, optical_transfer_function, z_padding):
if (
zyx_object.shape[0] + 2 * z_padding
!= optical_transfer_function.shape[0]
):
raise ValueError(
"Please check padding: ZYX_obj.shape[0] + 2 * Z_pad != H_re.shape[0]"
)
if z_padding > 0:
optical_transfer_function = optical_transfer_function[
z_padding:-z_padding
]

# Very simple simulation, consider adding noise and bkg knobs
zyx_obj_hat = torch.fft.fftn(zyx_object)
zyx_data = zyx_obj_hat * optical_transfer_function
data = torch.real(torch.fft.ifftn(zyx_data))

data += 10 # Add a direct background
return data


def apply_inverse_transfer_function(
zyx_data,
optical_transfer_function,
z_padding,
method="Tikhonov",
reg_re=1e-3,
rho=1e-3,
itr=10,
):
# Handle padding
zyx_padded = util.pad_zyx_along_z(zyx_data, z_padding)

# Reconstruct
if method == "Tikhonov":
f_real = util.single_variable_tikhonov_deconvolution_3D(
zyx_padded, optical_transfer_function, reg_re=reg_re
)

elif method == "TV":
raise NotImplementedError
f_real = util.single_variable_admm_tv_deconvolution_3D(
zyx_padded,
optical_transfer_function,
reg_re=reg_re,
rho=rho,
itr=itr,
)

# Unpad
if z_padding != 0:
f_real = f_real[z_padding:-z_padding]

return f_real
3 changes: 2 additions & 1 deletion waveorder/models/isotropic_thin_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def apply_transfer_function(

# sum and add background
data = zyx_absorption_data + zyx_phase_data
data = torch.tensor(data + 10) # Add a direct background
data += 10 # Add a direct background
return data


Expand Down Expand Up @@ -203,6 +203,7 @@ def apply_inverse_transfer_function(

# ADMM deconvolution with anisotropic TV regularization
elif method == "TV":
raise NotImplementedError
absorption, phase = util.dual_variable_admm_tv_deconv_2d(
AHA, b_vec, rho=rho, itr=itr
)
Expand Down
Loading

0 comments on commit 1a707e1

Please sign in to comment.