Skip to content

Commit

Permalink
wip fast psfs
Browse files Browse the repository at this point in the history
  • Loading branch information
talonchandler committed Dec 5, 2024
1 parent b83dedd commit be6b7e6
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
40 changes: 40 additions & 0 deletions examples/models/phase_thick_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,57 @@
imag_potential_transfer_function,
zyx_scale,
)
import torch
import time

psf = torch.real(torch.fft.fftn(real_potential_transfer_function))
psf = torch.fft.fftshift(psf, dim=(0, 1, 2))
viewer.add_image(psf.numpy(), name="PSF", scale=zyx_scale)

input("Showing OTFs. Press <enter> to continue...")
viewer.layers.select_all()
viewer.layers.remove_selected()

# Simulate
# Time the apply_transfer_function method
start_time = time.time()
zyx_data = phase_thick_3d.apply_transfer_function(
zyx_phase,
real_potential_transfer_function,
transfer_function_arguments["z_padding"],
brightness=1e3,
)
end_time = time.time()
print(f"\tapply_transfer_function took\t{end_time - start_time:.4f} seconds")

# Time the conv3d method
torch.set_grad_enabled(False)
start_time = time.time()
zyx_phase_cuda = zyx_phase[None, None].to("cuda:1")
psf_cuda = psf[None, None, 25:-25, 50:-51, 50:-51].to("cuda:1")
end_time = time.time()
print(f"\tTransfer to GPU took\t\t{end_time - start_time:.4f} seconds")
start_time = time.time()
zyx_data_cuda = torch.nn.functional.conv3d(
zyx_phase_cuda,
psf_cuda,
padding="same",
)
end_time = time.time()
import pdb; pdb.set_trace()
print(f"\tconv3d took\t\t\t{end_time - start_time:.4f} seconds")
# Time the transfer to CPU
start_time = time.time()
# zyx_data_cuda.cpu()
end_time = time.time()
print(f"\tTransfer to CPU took\t\t{end_time - start_time:.4f} seconds")

viewer.add_image(zyx_data.numpy(), name="Data", scale=zyx_scale)
viewer.add_image(zyx_data_cuda[0,0].detach().cpu().numpy(), name="Data-truncpsf", scale=zyx_scale)
viewer.add_image(psf_cuda[0,0].detach().cpu().numpy(), name="psf", scale=zyx_scale)
import pdb

pdb.set_trace()

# Reconstruct
zyx_recon = phase_thick_3d.apply_inverse_transfer_function(
Expand Down
21 changes: 21 additions & 0 deletions waveorder/models/phase_thick_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,27 @@ def generate_test_phantom(

return zyx_phase

def calculate_point_spread_function(
zyx_shape,
zyx_scale,
wavelength_illumination,
index_of_refraction_media,
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=False,
):
psf_shape = sampling.point_spread_function_shape(
zyx_shape[0],
zyx_scale,
numerical_aperture_illumination,
numerical_aperture_detection,
index_of_refraction_media
)
psf = torch.zeros(psf_shape)

# Calculate PS and P on the Fourier grid



def calculate_transfer_function(
zyx_shape,
Expand Down

0 comments on commit be6b7e6

Please sign in to comment.