Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

isotropic_fluorescent_thick_3d model #128

Merged
merged 14 commits into from
Jun 16, 2023
File renamed without changes.
73 changes: 73 additions & 0 deletions examples/models/isotropic_fluorescent_thick_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import napari
import numpy as np

from waveorder.models import isotropic_fluorescent_thick_3d

# Parameters
# all lengths must use consistent units e.g. um
simulation_arguments = {
"zyx_shape": (100, 256, 256),
"yx_pixel_size": 6.5 / 63,
"z_pixel_size": 0.25,
}
phantom_arguments = {"sphere_radius": 5}
transfer_function_arguments = {
"wavelength_illumination": 0.532,
"z_padding": 0,
"index_of_refraction_media": 1.3,
"numerical_aperture_detection": 1.2,
}

# Create a phantom
zyx_fluorescence_density = (
isotropic_fluorescent_thick_3d.generate_test_phantom(
**simulation_arguments, **phantom_arguments
)
)

# Calculate transfer function
optical_transfer_function = (
isotropic_fluorescent_thick_3d.calculate_transfer_function(
**simulation_arguments, **transfer_function_arguments
)
)

# Display transfer function
viewer = napari.Viewer()
zyx_scale = np.array(
[
simulation_arguments["z_pixel_size"],
simulation_arguments["yx_pixel_size"],
simulation_arguments["yx_pixel_size"],
]
)
isotropic_fluorescent_thick_3d.visualize_transfer_function(
viewer,
optical_transfer_function,
zyx_scale,
)
input("Showing OTFs. Press <enter> to continue...")
viewer.layers.select_all()
viewer.layers.remove_selected()

# Simulate
zyx_data = isotropic_fluorescent_thick_3d.apply_transfer_function(
zyx_fluorescence_density,
optical_transfer_function,
transfer_function_arguments["z_padding"],
)

# Reconstruct
zyx_recon = isotropic_fluorescent_thick_3d.apply_inverse_transfer_function(
zyx_data,
optical_transfer_function,
transfer_function_arguments["z_padding"],
)

# Display
viewer.add_image(
zyx_fluorescence_density.numpy(), name="Phantom", scale=zyx_scale
)
viewer.add_image(zyx_data.numpy(), name="Data", scale=zyx_scale)
viewer.add_image(zyx_recon.numpy(), name="Reconstruction", scale=zyx_scale)
input("Showing object, data, and recon. Press <enter> to quit...")
53 changes: 53 additions & 0 deletions tests/models/test_isotropic_fluorescent_thick_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
import torch
from waveorder.models import isotropic_fluorescent_thick_3d


@pytest.mark.parametrize("axial_flip", (True, False))
def test_calculate_transfer_function(axial_flip):
z_padding = 5
transfer_function = (
isotropic_fluorescent_thick_3d.calculate_transfer_function(
zyx_shape=(20, 100, 101),
yx_pixel_size=6.5 / 40,
z_pixel_size=2,
wavelength_illumination=0.5,
z_padding=z_padding,
index_of_refraction_media=1.0,
numerical_aperture_detection=0.55,
axial_flip=axial_flip,
)
)

assert transfer_function.shape == (20 + 2 * z_padding, 100, 101)


def test_apply_inverse_transfer_function():
# Create sample data
zyx_data = torch.randn(10, 5, 5)
z_padding = 2
optical_transfer_function = torch.randn(10 + 2 * z_padding, 5, 5)

# Test Tikhonov method
result_tikhonov = (
isotropic_fluorescent_thick_3d.apply_inverse_transfer_function(
zyx_data,
optical_transfer_function,
z_padding,
method="Tikhonov",
reg_re=1e-3,
)
)
assert result_tikhonov.shape == (10, 5, 5)

# TODO: Fix TV method
# result_tv = isotropic_fluorescent_thick_3d.apply_inverse_transfer_function(
# zyx_data,
# optical_transfer_function,
# z_padding,
# method="TV",
# reg_re=1e-3,
# rho=1e-3,
# itr=10,
# )
# assert result_tv.shape == (10, 5, 5)
36 changes: 36 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from waveorder import util
import torch
import pytest


def test_gen_coordinate():
Expand All @@ -8,3 +9,38 @@ def test_gen_coordinate():

assert frr.shape == YX_shape
assert frr[0, 0] == 0


# test util.pad_zyx function
@pytest.fixture
def zyx_data():
return torch.ones((3, 4, 5)) # Example input data


def test_pad_zyx_negative_padding():
zyx_data = torch.zeros((3, 4, 5))
z_padding = -1
with pytest.raises(Exception):
util.pad_zyx_along_z(zyx_data, z_padding)


def test_pad_zyx_no_padding(zyx_data):
z_padding = 0
result = util.pad_zyx_along_z(zyx_data, z_padding)
assert torch.all(result == zyx_data)


def test_pad_zyx_small_padding(zyx_data):
z_padding = 2
result = util.pad_zyx_along_z(zyx_data, z_padding)
assert result.shape == (7, 4, 5)
assert torch.all(result[:2] == torch.flip(zyx_data[:2], dims=[0]))
assert torch.all(result[-2:] == torch.flip(zyx_data[-2:], dims=[0]))


def test_pad_zyx_large_padding(zyx_data):
z_padding = 5
result = util.pad_zyx_along_z(zyx_data, z_padding)
assert result.shape == (13, 4, 5)
assert torch.all(result[:5] == 0)
assert torch.all(result[-5:] == 0)
137 changes: 137 additions & 0 deletions waveorder/models/isotropic_fluorescent_thick_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import torch
from waveorder import optics, util


def generate_test_phantom(
zyx_shape,
yx_pixel_size,
z_pixel_size,
sphere_radius,
):
sphere, _, _ = util.generate_sphere_target(
zyx_shape, yx_pixel_size, z_pixel_size, sphere_radius
)

return sphere


def calculate_transfer_function(
zyx_shape,
yx_pixel_size,
z_pixel_size,
wavelength_illumination,
z_padding,
index_of_refraction_media,
numerical_aperture_detection,
axial_flip=False,
):
radial_frequencies = util.generate_radial_frequencies(
zyx_shape[1:], yx_pixel_size
)

z_total = zyx_shape[0] + 2 * z_padding
z_position_list = torch.fft.ifftshift(
(torch.arange(z_total) - z_total // 2) * z_pixel_size
)
if axial_flip:
z_position_list = torch.flip(z_position_list, dims=(0,))

det_pupil = optics.generate_pupil(
radial_frequencies,
numerical_aperture_detection,
wavelength_illumination,
)

propagation_kernel = optics.generate_propagation_kernel(
radial_frequencies,
det_pupil,
wavelength_illumination / index_of_refraction_media,
z_position_list,
)

point_spread_function = (
torch.abs(torch.fft.ifft2(propagation_kernel, dim=(1, 2))) ** 2
)
optical_transfer_function = torch.fft.fftn(
point_spread_function, dim=(0, 1, 2)
)
optical_transfer_function /= torch.max(
torch.abs(optical_transfer_function)
) # normalize

return optical_transfer_function


def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale):
arrays = [
(torch.imag(optical_transfer_function), "Im(OTF)"),
(torch.real(optical_transfer_function), "Re(OTF)"),
]

for array in arrays:
lim = 0.1 * torch.max(torch.abs(array[0]))
viewer.add_image(
torch.fft.ifftshift(array[0]).cpu().numpy(),
name=array[1],
colormap="bwr",
contrast_limits=(-lim, lim),
scale=1 / zyx_scale,
)
viewer.dims.order = (0, 1, 2)


def apply_transfer_function(zyx_object, optical_transfer_function, z_padding):
if (
zyx_object.shape[0] + 2 * z_padding
!= optical_transfer_function.shape[0]
):
raise ValueError(
"Please check padding: ZYX_obj.shape[0] + 2 * Z_pad != H_re.shape[0]"
)
if z_padding > 0:
optical_transfer_function = optical_transfer_function[
z_padding:-z_padding
]

# Very simple simulation, consider adding noise and bkg knobs
zyx_obj_hat = torch.fft.fftn(zyx_object)
zyx_data = zyx_obj_hat * optical_transfer_function
data = torch.real(torch.fft.ifftn(zyx_data))

data += 10 # Add a direct background
return data


def apply_inverse_transfer_function(
zyx_data,
optical_transfer_function,
z_padding,
method="Tikhonov",
reg_re=1e-3,
rho=1e-3,
itr=10,
):
# Handle padding
zyx_padded = util.pad_zyx_along_z(zyx_data, z_padding)

# Reconstruct
if method == "Tikhonov":
f_real = util.single_variable_tikhonov_deconvolution_3D(
zyx_padded, optical_transfer_function, reg_re=reg_re
)

elif method == "TV":
raise NotImplementedError
f_real = util.single_variable_admm_tv_deconvolution_3D(
zyx_padded,
optical_transfer_function,
reg_re=reg_re,
rho=rho,
itr=itr,
)

# Unpad
if z_padding != 0:
f_real = f_real[z_padding:-z_padding]

return f_real
3 changes: 2 additions & 1 deletion waveorder/models/isotropic_thin_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def apply_transfer_function(

# sum and add background
data = zyx_absorption_data + zyx_phase_data
data = torch.tensor(data + 10) # Add a direct background
data += 10 # Add a direct background
return data


Expand Down Expand Up @@ -203,6 +203,7 @@ def apply_inverse_transfer_function(

# ADMM deconvolution with anisotropic TV regularization
elif method == "TV":
raise NotImplementedError
absorption, phase = util.dual_variable_admm_tv_deconv_2d(
AHA, b_vec, rho=rho, itr=itr
)
Expand Down
Loading