diff --git a/odak/learn/wave/classical.py b/odak/learn/wave/classical.py index bba42146..13d50b64 100644 --- a/odak/learn/wave/classical.py +++ b/odak/learn/wave/classical.py @@ -223,7 +223,7 @@ def get_transfer_function_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, fy = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nv, dtype = torch.float32, device = device) FY, FX = torch.meshgrid(fx, fy, indexing = 'ij') k = wavenumber(wavelength) - H = torch.exp(1j* k * distance * (1 - (FX * wavelength) ** 2 - (FY * wavelength) ** 2) ** 0.5).to(device) + H = torch.exp(1j*distance* (k- torch.pi *wavelength*(FX**2 + FY**2))) return H diff --git a/test/test_learn_wave_propagate_beam.py b/test/test_learn_wave_propagate_beam.py index 3b1ce19d..0c7f244d 100644 --- a/test/test_learn_wave_propagate_beam.py +++ b/test/test_learn_wave_propagate_beam.py @@ -9,41 +9,41 @@ def test(): wavelength = 532e-9 # (1) pixel_pitch = 8e-6 # (2) distance = 0.5e-2 # (3) - propagation_type = 'Bandlimited Angular Spectrum' # (4) + propagation_types = ['Angular Spectrum', 'Transfer Function Fresnel'] # (4) k = odak.learn.wave.wavenumber(wavelength) # (5) - + amplitude = torch.zeros(500, 500) amplitude[200:300, 200:300 ] = 1. # (5) phase = torch.randn_like(amplitude) * 2 * odak.pi # (6) hologram = odak.learn.wave.generate_complex_field(amplitude, phase) # (7) - - image_plane = odak.learn.wave.propagate_beam( - hologram, - k, - distance, - pixel_pitch, - wavelength, - propagation_type, - zero_padding = [True, False, True] # (8) - ) # (9) - - image_intensity = odak.learn.wave.calculate_amplitude(image_plane) ** 2 # (10) - hologram_intensity = amplitude ** 2 - - odak.learn.tools.save_image( - 'image_intensity.png', - image_intensity, - cmin = 0., - cmax = 1. - ) # (11) - odak.learn.tools.save_image( - 'hologram_intensity.png', - hologram_intensity, - cmin = 0., - cmax = 1. - ) # (12) + for propagation_type in propagation_types: + image_plane = odak.learn.wave.propagate_beam( + hologram, + k, + distance, + pixel_pitch, + wavelength, + propagation_type, + zero_padding = [True, False, True] # (8) + ) # (9) + + image_intensity = odak.learn.wave.calculate_amplitude(image_plane) ** 2 # (10) + hologram_intensity = amplitude ** 2 + + odak.learn.tools.save_image( + 'image_intensity_'+'{}.png'.format(propagation_type.replace(' ', '_')), + image_intensity, + cmin = 0., + cmax = 1. + ) # (11) + odak.learn.tools.save_image( + 'hologram_intensity_'+'{}.png'.format(propagation_type.replace(' ', '_')), + hologram_intensity, + cmin = 0., + cmax = 1. + ) # (12) assert True == True