diff --git a/examples/models/inplane_oriented_thick_pol3d_vector.py b/examples/models/inplane_oriented_thick_pol3d_vector.py index 72f01d4..c173632 100644 --- a/examples/models/inplane_oriented_thick_pol3d_vector.py +++ b/examples/models/inplane_oriented_thick_pol3d_vector.py @@ -66,19 +66,19 @@ intensity_to_stokes_matrix, ) -# Reconstruct -fzyx_object_recon = ( - inplane_oriented_thick_pol3d_vector.apply_inverse_transfer_function( - szyx_data, - singular_system, - intensity_to_stokes_matrix, - regularization_strength=1e-1, - ) -) +# from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer + +# add_transfer_function_to_viewer( +# viewer, +# singular_system[1], +# zyx_scale=(z_pixel_size, yx_pixel_size, yx_pixel_size), +# layer_name="Singular Values", +# ) +# import pdb; pdb.set_trace() + # Display arrays = [ - (fzyx_object_recon, "Object - recon"), (szyx_data, "Data"), (fzyx_object, "Object"), ] @@ -86,6 +86,22 @@ for array in arrays: viewer.add_image(torch.real(array[0]).cpu().numpy(), name=array[1]) + +# Reconstruct +for reg_strength in [0.005, 0.008, 0.01, 0.05, 0.1]: + fzyx_object_recon = ( + inplane_oriented_thick_pol3d_vector.apply_inverse_transfer_function( + szyx_data, + singular_system, + intensity_to_stokes_matrix, + regularization_strength=reg_strength, + ) + ) + viewer.add_image( + torch.real(fzyx_object_recon).cpu().numpy(), + name=f"Object - recon, reg_strength={reg_strength}", + ) + viewer.grid.enabled = True viewer.grid.shape = (2, 5) import pdb diff --git a/waveorder/models/inplane_oriented_thick_pol3d_vector.py b/waveorder/models/inplane_oriented_thick_pol3d_vector.py index 1df364a..94e9938 100644 --- a/waveorder/models/inplane_oriented_thick_pol3d_vector.py +++ b/waveorder/models/inplane_oriented_thick_pol3d_vector.py @@ -1,9 +1,10 @@ import torch +import tqdm import numpy as np from torch import Tensor from typing import Literal -from torch.nn.functional import avg_pool3d +from torch.nn.functional import avg_pool3d, interpolate from waveorder import optics, sampling, stokes, util from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer @@ -40,6 +41,7 @@ def calculate_transfer_function( numerical_aperture_detection, invert_phase_contrast=False, fourier_oversample_factor=1, + transverse_downsample_factor=1, ): if z_padding != 0: raise NotImplementedError("Padding not implemented for this model") @@ -58,15 +60,34 @@ def calculate_transfer_function( yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist)) z_factor = int(np.ceil(z_pixel_size / axial_nyquist)) + print("YX factor:", yx_factor) + print("Z factor:", z_factor) + + tf_calculation_shape = ( + zyx_shape[0] * z_factor * fourier_oversample_factor, + int( + np.ceil( + zyx_shape[1] + * yx_factor + * fourier_oversample_factor + / transverse_downsample_factor + ) + ), + int( + np.ceil( + zyx_shape[2] + * yx_factor + * fourier_oversample_factor + / transverse_downsample_factor + ) + ), + ) + sfZYX_transfer_function, intensity_to_stokes_matrix = ( _calculate_wrap_unsafe_transfer_function( swing, scheme, - ( - zyx_shape[0] * z_factor * fourier_oversample_factor, - zyx_shape[1] * yx_factor * fourier_oversample_factor, - zyx_shape[2] * yx_factor * fourier_oversample_factor, - ), + tf_calculation_shape, yx_pixel_size / yx_factor, z_pixel_size / z_factor, wavelength_illumination, @@ -96,11 +117,29 @@ def calculate_transfer_function( pooled_sfZYX_transfer_function.shape[1], zyx_shape[0] + 2 * z_padding, ) + zyx_shape[1:] + + cropped = sampling.nd_fourier_central_cuboid( + pooled_sfZYX_transfer_function, sfzyx_out_shape + ) + + # Compute singular system on cropped and downsampled + U, S, Vh = calculate_singular_system(cropped) + + # Interpolate to final size in YX + def complex_interpolate(tensor, zyx_shape): + interpolated_real = interpolate(tensor.real, size=zyx_shape) + interpolated_imag = interpolate(tensor.imag, size=zyx_shape) + return interpolated_real + 1j * interpolated_imag + + full_cropped = complex_interpolate(cropped, zyx_shape) + full_U = complex_interpolate(U, zyx_shape) + full_S = interpolate(S[None], size=zyx_shape)[0] # S is real + full_Vh = complex_interpolate(Vh, zyx_shape) + return ( - sampling.nd_fourier_central_cuboid( - pooled_sfZYX_transfer_function, sfzyx_out_shape - ), + full_cropped, intensity_to_stokes_matrix, + (full_U, full_S, full_Vh), ) @@ -142,6 +181,7 @@ def _calculate_wrap_unsafe_transfer_function( z_frequencies = torch.fft.fftfreq(z_total, d=z_pixel_size) # 2D pupils + print("\tCalculating pupils...") ill_pupil = optics.generate_pupil( radial_frequencies, numerical_aperture_illumination, @@ -187,6 +227,8 @@ def _calculate_wrap_unsafe_transfer_function( P_3D = torch.abs(torch.fft.ifft(P, dim=-3)).type(torch.complex64) S_3D = torch.fft.ifft(S, dim=-3) + + print("\tCalculating greens tensor spectrum...") G_3D = optics.generate_greens_tensor_spectrum( zyx_shape=(z_total, zyx_shape[1], zyx_shape[2]), zyx_pixel_size=(z_pixel_size, yx_pixel_size, yx_pixel_size),