From be6b7e65bbcab15444cb9c61dcbb1cc6c45d63d4 Mon Sep 17 00:00:00 2001 From: Talon Chandler Date: Thu, 5 Dec 2024 15:26:29 -0800 Subject: [PATCH] wip fast psfs --- examples/models/phase_thick_3d.py | 40 ++++++++++++++++++++++++++++++ waveorder/models/phase_thick_3d.py | 21 ++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/examples/models/phase_thick_3d.py b/examples/models/phase_thick_3d.py index 54d7e43..37d84fd 100644 --- a/examples/models/phase_thick_3d.py +++ b/examples/models/phase_thick_3d.py @@ -52,17 +52,57 @@ imag_potential_transfer_function, zyx_scale, ) +import torch +import time + +psf = torch.real(torch.fft.fftn(real_potential_transfer_function)) +psf = torch.fft.fftshift(psf, dim=(0, 1, 2)) +viewer.add_image(psf.numpy(), name="PSF", scale=zyx_scale) + input("Showing OTFs. Press to continue...") viewer.layers.select_all() viewer.layers.remove_selected() # Simulate +# Time the apply_transfer_function method +start_time = time.time() zyx_data = phase_thick_3d.apply_transfer_function( zyx_phase, real_potential_transfer_function, transfer_function_arguments["z_padding"], brightness=1e3, ) +end_time = time.time() +print(f"\tapply_transfer_function took\t{end_time - start_time:.4f} seconds") + +# Time the conv3d method +torch.set_grad_enabled(False) +start_time = time.time() +zyx_phase_cuda = zyx_phase[None, None].to("cuda:1") +psf_cuda = psf[None, None, 25:-25, 50:-51, 50:-51].to("cuda:1") +end_time = time.time() +print(f"\tTransfer to GPU took\t\t{end_time - start_time:.4f} seconds") +start_time = time.time() +zyx_data_cuda = torch.nn.functional.conv3d( + zyx_phase_cuda, + psf_cuda, + padding="same", +) +end_time = time.time() +import pdb; pdb.set_trace() +print(f"\tconv3d took\t\t\t{end_time - start_time:.4f} seconds") +# Time the transfer to CPU +start_time = time.time() +# zyx_data_cuda.cpu() +end_time = time.time() +print(f"\tTransfer to CPU took\t\t{end_time - start_time:.4f} seconds") + +viewer.add_image(zyx_data.numpy(), name="Data", scale=zyx_scale) +viewer.add_image(zyx_data_cuda[0,0].detach().cpu().numpy(), name="Data-truncpsf", scale=zyx_scale) +viewer.add_image(psf_cuda[0,0].detach().cpu().numpy(), name="psf", scale=zyx_scale) +import pdb + +pdb.set_trace() # Reconstruct zyx_recon = phase_thick_3d.apply_inverse_transfer_function( diff --git a/waveorder/models/phase_thick_3d.py b/waveorder/models/phase_thick_3d.py index ec6df00..ec001a9 100644 --- a/waveorder/models/phase_thick_3d.py +++ b/waveorder/models/phase_thick_3d.py @@ -30,6 +30,27 @@ def generate_test_phantom( return zyx_phase +def calculate_point_spread_function( + zyx_shape, + zyx_scale, + wavelength_illumination, + index_of_refraction_media, + numerical_aperture_illumination, + numerical_aperture_detection, + invert_phase_contrast=False, +): + psf_shape = sampling.point_spread_function_shape( + zyx_shape[0], + zyx_scale, + numerical_aperture_illumination, + numerical_aperture_detection, + index_of_refraction_media + ) + psf = torch.zeros(psf_shape) + + # Calculate PS and P on the Fourier grid + + def calculate_transfer_function( zyx_shape,