From 60ac2050e801f99088b8d979cc28f33b1977c04d Mon Sep 17 00:00:00 2001
From: Leo Paillet
Date: Wed, 6 Mar 2024 09:53:58 +0100
Subject: [PATCH] Update to adapt to scene

---
 main_script_optim.py                          | 341 +++++++++++-------
 simca/CassiSystemOptim.py                     |  56 +--
 simca/OpticalModelTorch.py                    |   4 +-
 ...cassi_system_optim_optics_full_triplet.yml |  76 ++--
 simca/cost_functions.py                       | 131 ++++++-
 simca/functions_acquisition_torch.py          |  13 +-
 simca/functions_optim.py                      |  52 +--
 7 files changed, 433 insertions(+), 240 deletions(-)

diff --git a/main_script_optim.py b/main_script_optim.py
index 7e70bc7..58f353b 100644
--- a/main_script_optim.py
+++ b/main_script_optim.py
@@ -19,7 +19,7 @@
 config_patterns = load_yaml_config("simca/configs/pattern.yml")
 config_acquisition = load_yaml_config("simca/configs/acquisition.yml")
 
-dataset_name = "indian_pines"
+dataset_name = "washington_mall"
 
 test = "SMILE"
 
@@ -27,7 +27,7 @@
 
 if test=="SMILE":
     config_system = load_yaml_config("simca/configs/cassi_system_simple_optim.yml")
-    aspect = 1
+    aspect = 0.2
 elif test=="EQUAL_LIGHT" or test=="MAX_CENTER":
     config_system = load_yaml_config("simca/configs/cassi_system_simple_optim_max_center.yml")
     aspect = 1
@@ -38,6 +38,9 @@
     config_system = load_yaml_config("simca/configs/cassi_system_simple_optim.yml")
     aspect = 1
 
+config_system = load_yaml_config("simca/configs/cassi_system_optim_optics_full_triplet.yml")
+#config_system = load_yaml_config("simca/configs/cassi_system_satur_test.yml")
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 if __name__ == '__main__':
@@ -47,7 +50,9 @@
 
     # time0 = time.time()
     # DATASET : Load the hyperspectral dataset
-    cassi_system.load_dataset(dataset_name, config_dataset["datasets directory"])
+    list_dataset_data = cassi_system.load_dataset(dataset_name, config_dataset["datasets directory"])
+    """ plt.imshow(np.sum(list_dataset_data[0], axis=2))
+    plt.show() """
 
     # Loop beginning if optics optim.
     cassi_system.update_optical_model(system_config=config_system)
@@ -69,30 +74,42 @@
     pos_slit_detector_list = [20/145, 40/145, 60/145, 80/145]
 
     image_counter = 0
 
-    patterns1 = []
-    corrected_patterns1 = []
-    smile_positions = []
-    corrected_smile_positions = []
-    patterns2 = []
-    width_values = []
-    cubes1 = []
-    corrected_cubes1 = []
-    cubes2 = []
-    acquisitions = []
-
-    prev_position = None
-
-    #for position in np.linspace(0.4, 0.6, np.round(cassi_system.system_config["coded aperture"]["number of pixels along X"]*0.2).astype('int')):
-    #for position in np.arange(first_pos, last_pos, step_pos):
-    for position_i in range(len(pattern_pos)):
-        position = pattern_pos[position_i]
-        pos_slit_detector = pos_slit_detector_list[position_i]
+
+
+    gen_randint = torch.Generator()
+    gen_randint.manual_seed(2009)
+    list_of_rand_seeds = torch.randint(low=0, high=2009, size=(100,), generator=gen_randint)
+    for seed_i in range(0, len(list_of_rand_seeds)):
+        seed = int(list_of_rand_seeds[seed_i])
+
+        pattern_pos = [0.76]
+        pattern_pos = [0.2]
+        pattern_pos = [0.68]
+        pos_slit_detector_list = [50/145] # 64 if pixel_size_Y=200
+        pos_slit_detector_list = [33/145]
+        patterns1 = []
+        corrected_patterns1 = []
+        smile_positions = []
+        corrected_smile_positions = []
+        patterns2 = []
+        start_width_values = []
+        width_values = []
+        cubes1 = []
+        corrected_cubes1 = []
+        cubes2 = []
+        acquisitions = []
+
+        prev_position = None
+
+        position = pattern_pos[0]
+        pos_slit_detector = pos_slit_detector_list[0]
 
         image_counter += 1
         print(f"===== Start of image acquisition {image_counter} =====")
         max_iter_cnt = 25
 
         cassi_system = CassiSystemOptim(system_config=config_system)
+        cassi_system.device = device
         # time0 = time.time()
         # DATASET : Load the hyperspectral dataset
@@ -110,7 +127,7 @@
 
         # Adjust the learning rate
         if algo == "LBFGS":
-            lr = 0.002 # default: 0.05
+            lr = 0.005 # default: 0.05
         elif algo == "ADAM":
             lr = 0.005 # default: 0.005
 
@@ -118,133 +135,197 @@
             pos_slit_detector = 0.41
         elif position == 0.7:
             pos_slit_detector = 0.124 """
-        cassi_system = optim_smile(cassi_system, position, pos_slit_detector, sigma, device, algo, lr, num_iterations, max_iter_cnt, prev_position = prev_position, plot_frequency=None)
+
+        # np.load never returns None, so guard on file existence instead
+        results_path = "./results/24-03-04_19h33/results.npz"
+        data = np.load(results_path) if os.path.exists(results_path) else None
+        if data is None:
+            cassi_system = optim_smile(cassi_system, position, pos_slit_detector, sigma, device, algo, lr, num_iterations, max_iter_cnt, prev_position = prev_position, plot_frequency=None)
+
+            pattern = cassi_system.pattern.detach().to('cpu').numpy()
+            cube = cassi_system.filtering_cube.detach().to('cpu').numpy()
-        pattern = cassi_system.pattern.detach().numpy()
-        cube = cassi_system.filtering_cube.detach().numpy()[:,:,0]
+            patterns1.append(pattern)
+            cubes1.append(cube)
+            start_position = cassi_system.array_x_positions.detach().to('cpu').numpy()
+            smile_positions.append(start_position)
-        patterns1.append(pattern)
-        cubes1.append(cube)
-        start_position = cassi_system.array_x_positions.detach().numpy()
-        smile_positions.append(start_position)
+            diffs = np.diff(start_position)
+            diffs_ind = np.nonzero(diffs)[0]
+            pos_middle = start_position[diffs_ind.min()+1:diffs_ind.max()+1]
+            poly_coeffs = np.polyfit(np.linspace(1,2, len(pos_middle)), pos_middle, deg = 2)
+            poly = np.poly1d(poly_coeffs)
+            start_position[diffs_ind.min()+1:diffs_ind.max()+1] = poly(np.linspace(1,2, len(pos_middle)))
-        diffs = np.diff(start_position)
-        diffs_ind = np.nonzero(diffs)[0]
-        pos_middle = start_position[diffs_ind.min()+1:diffs_ind.max()+1]
-        poly_coeffs =
np.polyfit(np.linspace(1,2, len(pos_middle)), pos_middle, deg = 4) - poly = np.poly1d(poly_coeffs) - start_position[diffs_ind.min()+1:diffs_ind.max()+1] = poly(np.linspace(1,2, len(pos_middle))) + corrected_smile_positions.append(start_position) - corrected_smile_positions.append(start_position) + start_position = torch.tensor(start_position) - start_position = torch.tensor(start_position) + cassi_system.array_x_positions.data = start_position + cassi_system.generate_custom_slit_pattern() + cassi_system.generate_filtering_cube() - cassi_system.array_x_positions.data = start_position - cassi_system.generate_custom_slit_pattern() - cassi_system.generate_filtering_cube() + pattern = cassi_system.pattern.detach().to('cpu').numpy() + cube = cassi_system.filtering_cube.detach().to('cpu').numpy() + corrected_patterns1.append(pattern) + corrected_cubes1.append(cube) - pattern = cassi_system.pattern.detach().numpy() - cube = cassi_system.filtering_cube.detach().numpy()[:,:,0] - corrected_patterns1.append(pattern) - corrected_cubes1.append(cube) + prev_position = (cassi_system.array_x_positions.detach()-position) + else: + patterns1.append(data["patterns_smile"][0]) + cubes1.append(data["cubes_smile"][0]) + smile_positions.append(data["smile_positions"][0]) - prev_position = (cassi_system.array_x_positions.detach()-position) + corrected_smile_positions.append(data["corrected_smile_positions"][0]) + start_position = torch.tensor(data["corrected_smile_positions"][0]) + + corrected_patterns1.append(data["corrected_patterns_smile"][0]) + corrected_cubes1.append(data["corrected_cubes_smile"][0]) + + prev_position = (start_position - position) # Adjust the learning rate if algo == "LBFGS": - lr = 0.002 # default: 0.05 + lr = 0.1 # default: 0.05 elif algo == "ADAM": - lr = 0.01 # default: 0.005 + lr = 0.01 # default: 0.01 - cassi_system = optim_width(cassi_system, start_position, pos_slit_detector, cassi_system.system_config["detector"]["number of pixels along Y"], sigma, device, algo, lr, num_iterations, max_iter_cnt, plot_frequency = None) + target = 100000 + max_iter_cnt = 25 - pattern = cassi_system.pattern.detach().numpy() - cube = cassi_system.filtering_cube.detach().numpy()[:,:,0] - acquisition = cassi_system.measurement.detach().numpy() + #num_iterations = 1 + gen = torch.Generator() + gen.manual_seed(seed) + start_width = torch.rand(size=(1,cassi_system.system_config["detector"]["number of pixels along Y"]), generator=gen)/3 + start_width_values.append(start_width.detach().to('cpu').numpy()) + # Create first histogram + cassi_system.generate_custom_pattern_parameters_slit_width(nb_slits=1, nb_rows=cassi_system.system_config["coded aperture"]["number of pixels along Y"], start_width = start_width) + cassi_system.generate_custom_slit_pattern_width(start_pattern = "corrected", start_position = start_position) + cassi_system.generate_filtering_cube() + cassi_system.filtering_cube = cassi_system.filtering_cube.to(device) + acq = cassi_system.image_acquisition(use_psf=False, chunck_size=cassi_system.system_config["detector"]["number of pixels along Y"]).detach().to('cpu').numpy() + fig_first_histo = plt.figure() + #plt.imshow(torch.sum(cassi_system.dataset, dim=2)) + #plt.imshow(acq, aspect=0.2) + plt.hist(acq[acq>100].flatten(), bins=100) + + # Run optimization + cassi_system = optim_width(cassi_system, start_position, target, cassi_system.system_config["coded aperture"]["number of pixels along Y"], start_width, device, algo, lr, num_iterations, max_iter_cnt, plot_frequency = None) + + pattern = 
cassi_system.pattern.detach().to('cpu').numpy() + cube = cassi_system.filtering_cube.detach().to('cpu').numpy() + acquisition = cassi_system.measurement.detach().to('cpu').numpy() + patterns2.append(pattern) cubes2.append(cube) acquisitions.append(acquisition) - width_values.append(cassi_system.array_x_positions.detach().numpy()) + width_values.append(cassi_system.array_x_positions.detach().to('cpu').numpy()) - #print(torch.std(cassi_system.measurement.detach())) - - print(f"Exec time: {time.time() - time_start}s") - - fig1 = plt.figure() - im1 = plt.imshow(patterns1[0], animated = True, aspect=aspect) - plt.colorbar() - - fig1bis = plt.figure() - im1bis = plt.imshow(corrected_patterns1[0], animated = True, aspect=aspect) - plt.colorbar() - - fig2 = plt.figure() - im2 = plt.imshow(cubes1[0], animated = True, aspect=aspect) - plt.colorbar() - - fig2bis = plt.figure() - im2bis = plt.imshow(corrected_cubes1[0], animated = True, aspect=aspect) - plt.colorbar() - - fig3 = plt.figure() - im3 = plt.imshow(patterns2[0], animated = True, aspect=aspect) - plt.colorbar() - - fig4 = plt.figure() - im4 = plt.imshow(cubes2[0], animated = True, aspect=aspect) - plt.colorbar() - - fig5 = plt.figure() - im5 = plt.imshow(acquisitions[0], animated = True, aspect=aspect) - plt.colorbar() - - def update1(i): - im1.set_array(patterns1[i]) - return im1, - def update1bis(i): - im1bis.set_array(corrected_patterns1[i]) - return im1bis, - def update2(i): - im2.set_array(cubes1[i]) - return im2, - def update2bis(i): - im2bis.set_array(corrected_cubes1[i]) - return im2bis, - def update3(i): - im3.set_array(patterns2[i]) - return im3, - def update4(i): - im4.set_array(cubes2[i]) - return im4, - def update5(i): - im5.set_array(acquisitions[i]) - return im5, - - animation_fig1 = anim.FuncAnimation(fig1, update1, frames=len(patterns1), interval = 1000, repeat=True) - animation_fig1bis = anim.FuncAnimation(fig1bis, update1bis, frames=len(corrected_patterns1), interval = 1000, repeat=True) - animation_fig2 = anim.FuncAnimation(fig2, update2, frames=len(cubes1), interval = 1000, repeat=True) - animation_fig2bis = anim.FuncAnimation(fig2bis, update2bis, frames=len(corrected_cubes1), interval = 1000, repeat=True) - animation_fig3 = anim.FuncAnimation(fig3, update3, frames=len(patterns2), interval = 1000, repeat=True) - animation_fig4 = anim.FuncAnimation(fig4, update4, frames=len(cubes2), interval = 1000, repeat=True) - animation_fig5 = anim.FuncAnimation(fig5, update5, frames=len(acquisitions), interval = 1000, repeat=True) - - plt.show() - - folder = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') - os.makedirs(f"./results/{folder}") - - animation_fig1.save(f"./results/{folder}/patterns_smile.gif") - animation_fig1bis.save(f"./results/{folder}/corrected_patterns_smile.gif") - animation_fig2.save(f"./results/{folder}/cubes_smile.gif") - animation_fig2bis.save(f"./results/{folder}/corrected_cubes_smile.gif") - animation_fig3.save(f"./results/{folder}/patterns_width.gif") - animation_fig4.save(f"./results/{folder}/cubes_width.gif") - animation_fig5.save(f"./results/{folder}/acquisitions.gif") - - np.savez(f"./results/{folder}/results.npz", smile_positions=np.stack(smile_positions, axis=0), patterns_smile=np.stack(patterns1, axis=0), cubes_smile = np.stack(cubes1, axis=0), - corrected_smile_positions=np.stack(corrected_smile_positions, axis=0), corrected_patterns_smile=np.stack(corrected_patterns1, axis=0), corrected_cubes_smile=np.stack(corrected_cubes1, axis=0), - width_values=np.stack(width_values, axis=0), 
patterns_width=np.stack(patterns2, axis=0), cubes_width = np.stack(cubes2, axis=0), - acquisitions=np.stack(acquisitions, axis=0)) + #print(torch.std(cassi_system.measurement.detach())) + + print(f"Exec time: {time.time() - time_start}s") + + fig1 = plt.figure() + im1 = plt.imshow(patterns1[0], animated = True, aspect=aspect) + plt.colorbar() + + fig1bis = plt.figure() + im1bis = plt.imshow(corrected_patterns1[0], animated = True, aspect=aspect) + plt.colorbar() + + fig2 = plt.figure() + im2 = plt.imshow(cubes1[0][:,:,cubes1[0].shape[2]//2], animated = True, aspect=aspect) + plt.colorbar() + + fig2bis = plt.figure() + im2bis = plt.imshow(corrected_cubes1[0][:,:,corrected_cubes1[0].shape[2]//2], animated = True, aspect=aspect) + plt.colorbar() + + fig3 = plt.figure() + im3 = plt.imshow(patterns2[0], animated = True, aspect=aspect) + plt.colorbar() + + fig4 = plt.figure() + im4 = plt.imshow(cubes2[0][:,:,cubes2[0].shape[2]//2], animated = True, aspect=aspect) + plt.colorbar() + + fig5 = plt.figure() + im5 = plt.imshow(np.clip(acquisitions[0], 1, None), animated = True, aspect=aspect, cmap="gray", norm="log") + plt.colorbar() + + def update1(i): + im1.set_array(patterns1[i]) + return im1, + def update1bis(i): + im1bis.set_array(corrected_patterns1[i]) + return im1bis, + def update2(i): + im2.set_array(cubes1[i][:,:,cubes1[0].shape[2]//2]) + return im2, + def update2bis(i): + im2bis.set_array(corrected_cubes1[i][:,:,corrected_cubes1[0].shape[2]//2]) + return im2bis, + def update3(i): + im3.set_array(patterns2[i]) + return im3, + def update4(i): + im4.set_array(cubes2[i][:,:,cubes2[0].shape[2]//2]) + return im4, + def update5(i): + im5.set_array(acquisitions[i]) + return im5, + + animation_fig1 = anim.FuncAnimation(fig1, update1, frames=len(patterns1), interval = 1000, repeat=True) + animation_fig1bis = anim.FuncAnimation(fig1bis, update1bis, frames=len(corrected_patterns1), interval = 1000, repeat=True) + animation_fig2 = anim.FuncAnimation(fig2, update2, frames=len(cubes1), interval = 1000, repeat=True) + animation_fig2bis = anim.FuncAnimation(fig2bis, update2bis, frames=len(corrected_cubes1), interval = 1000, repeat=True) + animation_fig3 = anim.FuncAnimation(fig3, update3, frames=len(patterns2), interval = 1000, repeat=True) + animation_fig4 = anim.FuncAnimation(fig4, update4, frames=len(cubes2), interval = 1000, repeat=True) + animation_fig5 = anim.FuncAnimation(fig5, update5, frames=len(acquisitions), interval = 1000, repeat=True) + + #print("Var: ", np.var(acquisitions[0][int(pos_slit_detector*145)-2:int(pos_slit_detector*145)+2].flatten())) + print("Var: ", np.var(acquisitions[0][acquisitions[0]>100].flatten())) + + + fig6 = plt.figure() + plt.hist(acquisitions[0][acquisitions[0]>100].flatten(), bins=100) + + plt.show() + + + + + folder = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') + os.makedirs(f"./results/{folder}") + + fig6.savefig(f"./results/{folder}/histo.png") + fig_first_histo.savefig(f"./results/{folder}/first_histo.png") + fig1.savefig(f"./results/{folder}/patterns_smile.png") + fig1bis.savefig(f"./results/{folder}/corrected_patterns_smile.png") + fig2.savefig(f"./results/{folder}/cubes_smile.png") + fig2bis.savefig(f"./results/{folder}/corrected_cubes_smile.png") + fig3.savefig(f"./results/{folder}/patterns_width.png") + fig4.savefig(f"./results/{folder}/cubes_width.png") + fig5.savefig(f"./results/{folder}/acquisitions.png") + #animation_fig1.save(f"./results/{folder}/patterns_smile.gif") + #animation_fig1bis.save(f"./results/{folder}/corrected_patterns_smile.gif") + 
#animation_fig2.save(f"./results/{folder}/cubes_smile.gif") + #animation_fig2bis.save(f"./results/{folder}/corrected_cubes_smile.gif") + #animation_fig3.save(f"./results/{folder}/patterns_width.gif") + #animation_fig4.save(f"./results/{folder}/cubes_width.gif") + #animation_fig5.save(f"./results/{folder}/acquisitions.gif") + + """np.savez(f"./results/{folder}/results.npz", smile_positions=np.stack(smile_positions, axis=0), patterns_smile=np.stack(patterns1, axis=0), cubes_smile = np.stack(cubes1, axis=0), + corrected_smile_positions=np.stack(corrected_smile_positions, axis=0), corrected_patterns_smile=np.stack(corrected_patterns1, axis=0), corrected_cubes_smile=np.stack(corrected_cubes1, axis=0), + start_width_values=np.stack(start_width_values, axis=0), width_values=np.stack(width_values, axis=0), patterns_width=np.stack(patterns2, axis=0), cubes_width = np.stack(cubes2, axis=0), + acquisitions=np.stack(acquisitions, axis=0), + variance=np.var(acquisitions[0][acquisitions[0]>500].flatten()))""" + np.savez(f"./results/{folder}/results.npz", smile_positions=np.stack(smile_positions, axis=0), patterns_smile=np.stack(patterns1, axis=0), + corrected_smile_positions=np.stack(corrected_smile_positions, axis=0), corrected_patterns_smile=np.stack(corrected_patterns1, axis=0), + start_width_values=np.stack(start_width_values, axis=0), width_values=np.stack(width_values, axis=0), patterns_width=np.stack(patterns2, axis=0), + acquisitions=np.stack(acquisitions, axis=0), + variance=np.var(acquisitions[0][acquisitions[0]>500].flatten())) + #np.savez(f"./results/{folder}/results.npz", patterns_smile=np.stack(patterns1, axis=0), cubes_smile = np.stack(cubes1, axis=0)) diff --git a/simca/CassiSystemOptim.py b/simca/CassiSystemOptim.py index 0f74558..9da4963 100755 --- a/simca/CassiSystemOptim.py +++ b/simca/CassiSystemOptim.py @@ -34,7 +34,6 @@ def __init__(self, system_config=None): self.wavelengths = self.set_wavelengths(self.system_config["spectral range"]["wavelength min"], self.system_config["spectral range"]["wavelength max"], self.system_config["spectral range"]["number of spectral samples"]) - self.optical_model = OpticalModelTorch(self.system_config) @@ -63,7 +62,7 @@ def create_coordinates_grid(self, nb_of_pixels_along_x, nb_of_pixels_along_y, de # Create a two-dimensional grid of coordinates X_input_grid, Y_input_grid = np.meshgrid(x, y) - return torch.from_numpy(X_input_grid), torch.from_numpy(Y_input_grid) + return torch.from_numpy(X_input_grid).float(), torch.from_numpy(Y_input_grid).float() def set_wavelengths(self, wavelength_min, wavelength_max, nb_of_spectral_samples): """ @@ -154,8 +153,7 @@ def generate_filtering_cube(self): X_init=self.X_coordinates_propagated_coded_aperture, Y_init=self.Y_coordinates_propagated_coded_aperture, X_target=self.X_detector_coordinates_grid, - Y_target=self.Y_detector_coordinates_grid) - + Y_target=self.Y_detector_coordinates_grid).to(self.device) return self.filtering_cube @@ -185,7 +183,7 @@ def image_acquisition(self, use_psf=False, chunck_size=50): return print("Please generate filtering cube first") scene = match_dataset_to_instrument(dataset, self.filtering_cube) - scene = torch.from_numpy(match_dataset_to_instrument(dataset, self.filtering_cube)) if isinstance(scene, np.ndarray) else scene + scene = torch.from_numpy(match_dataset_to_instrument(dataset, self.filtering_cube)).to(self.device) if isinstance(scene, np.ndarray) else scene.to(self.device) measurement_in_3D = generate_dd_measurement_torch(scene, self.filtering_cube, chunck_size) @@ -238,6 
+236,7 @@ def image_acquisition(self, use_psf=False, chunck_size=50): def generate_custom_pattern_parameters_slit(self, position=0.5): # Position is a float: 0 means slit is on the left edge, 1 means the slit is on the right edge self.array_x_positions = torch.zeros((self.system_config["coded aperture"]["number of pixels along Y"]))+ position + self.array_x_positions = self.array_x_positions.to(self.device) return self.array_x_positions def generate_custom_pattern_parameters_slit_width(self, nb_slits=1, nb_rows=1, start_width=1): @@ -246,14 +245,16 @@ def generate_custom_pattern_parameters_slit_width(self, nb_slits=1, nb_rows=1, s # self.array_x_positions is of shape (self.system_config["coded aperture"]["number of pixels along Y"], nb_rows) self.array_x_positions = torch.zeros((nb_slits, nb_rows))+start_width # Every slit starts with the same width + self.array_x_positions = self.array_x_positions.to(self.device) + self.array_x_positions_normalized = torch.zeros((nb_slits, nb_rows))+start_width return self.array_x_positions def generate_custom_slit_pattern_width(self, start_pattern = "line", start_position = 0): nb_slits, nb_rows = self.array_x_positions.shape - pos_slits = self.system_config["coded aperture"]["number of pixels along Y"]//(nb_slits+1) # Equally spaced slits - height_slits = self.system_config["coded aperture"]["number of pixels along X"]//nb_rows # Same length slits + pos_slits = self.system_config["coded aperture"]["number of pixels along X"]//(nb_slits+1) # Equally spaced slits + height_slits = self.system_config["coded aperture"]["number of pixels along Y"]//nb_rows # Same length slits - self.pattern = torch.zeros((self.system_config["coded aperture"]["number of pixels along Y"], self.system_config["coded aperture"]["number of pixels along X"])) # Pattern of correct size + self.pattern = torch.zeros((self.system_config["coded aperture"]["number of pixels along Y"], self.system_config["coded aperture"]["number of pixels along X"])).to(self.device) # Pattern of correct size if start_pattern == "line": if start_position != 0: start_position = start_position - pos_slits/self.system_config["coded aperture"]["number of pixels along X"] @@ -312,22 +313,35 @@ def generate_custom_slit_pattern_width(self, start_pattern = "line", start_posit # Normalize to make sure the maximum value is 1 self.pattern = self.pattern + padded/padded.max() """ - c = torch.tensor(start_position[i]) - d = self.array_x_positions[j,i]/2 - m = (c-d)*(self.system_config["coded aperture"]["number of pixels along X"]-1) - M = (c+d)*(self.system_config["coded aperture"]["number of pixels along X"]-1) - rect = torch.arange(self.system_config["coded aperture"]["number of pixels along X"]) - rect = torch.clamp(-(rect-m)*(rect-M)+1,0,1) + c = start_position[i].clone().detach() # center of the slit + #d = ((torch.tanh(1.1*self.array_x_positions[j,i])+1)/2)/2 # width of the slit at pos + d = self.array_x_positions[j,i]/2 # width of the slit at pos + m = (c-d)*(self.system_config["coded aperture"]["number of pixels along X"]-1) # left bound + M = (c+d)*(self.system_config["coded aperture"]["number of pixels along X"]-1) # right bound + rect = torch.arange(self.system_config["coded aperture"]["number of pixels along X"]).to(self.device) + clamp_M = torch.clamp(M-rect, 0, 1) + + clamp_m = torch.clamp(rect-m, 0, 1) + diff = 1-clamp_m + reg = torch.where(diff < 1, diff, -1) + clamp_m = torch.where(reg!=0, reg, 1) + clamp_m = torch.where(clamp_m!=-1, clamp_m, 0) + clamp_m = torch.roll(clamp_m, -1) + clamp_m[-1]=1 + + 
rect = clamp_M - clamp_m +1 + rect = torch.where(rect!=2, rect, 0) + rect = torch.where(rect <= 1, rect, rect-1) + #rect = torch.clamp(-(rect-m)*(rect-M)+1,0,1).to(self.device) gaussian_range = torch.arange(self.system_config["coded aperture"]["number of pixels along X"], dtype=torch.float32) center_pos = 0.5*(len(gaussian_range)-1) sigma = 1.5 - gaussian_peaks = np.exp(-((center_pos - gaussian_range) ** 2) / (2 * sigma ** 2)) + gaussian_peaks = torch.exp(-((center_pos - gaussian_range) ** 2) / (2 * sigma ** 2)).to(self.device) gaussian = gaussian_peaks /gaussian_peaks.max() - res = torch.nn.functional.conv1d(rect.unsqueeze(0), gaussian.unsqueeze(0).unsqueeze(0), padding = 144//2).squeeze() + res = torch.nn.functional.conv1d(rect.unsqueeze(0), gaussian.unsqueeze(0).unsqueeze(0), padding = (len(gaussian_range)-1)//2).squeeze().to(self.device) res = res/res.max() - self.pattern[i, :] = self.pattern[i, :] + res # Normalize to make sure the maximum value is 1 @@ -338,15 +352,15 @@ def generate_custom_slit_pattern_width(self, start_pattern = "line", start_posit def generate_custom_slit_pattern(self): # Create a grid to represent positions - grid_positions = torch.arange(self.empty_grid.shape[1], dtype=torch.float32) + grid_positions = torch.arange(self.empty_grid.shape[1], dtype=torch.float32).to(self.device) # Expand dimensions for broadcasting - expanded_x_positions = (self.array_x_positions.unsqueeze(-1)) * (self.empty_grid.shape[1]-1) - expanded_grid_positions = grid_positions.unsqueeze(0) + expanded_x_positions = ((self.array_x_positions.unsqueeze(-1)) * (self.empty_grid.shape[1]-1)).to(self.device) + expanded_grid_positions = grid_positions.unsqueeze(0).to(self.device) # Apply Gaussian-like function # Adjust 'sigma' to control the sharpness sigma = 1.5 - gaussian_peaks = torch.exp(-((expanded_grid_positions - expanded_x_positions) ** 2) / (2 * sigma ** 2)) + gaussian_peaks = torch.exp(-((expanded_grid_positions - expanded_x_positions) ** 2) / (2 * sigma ** 2)).to(self.device) # Normalize to make sure the maximum value is 1 self.pattern = gaussian_peaks / gaussian_peaks.max() diff --git a/simca/OpticalModelTorch.py b/simca/OpticalModelTorch.py index 721d9f1..85e38cb 100644 --- a/simca/OpticalModelTorch.py +++ b/simca/OpticalModelTorch.py @@ -46,7 +46,7 @@ def rotation_y_torch(theta): """ # Ensure theta is a tensor with requires_grad=True if not isinstance(theta, torch.Tensor): - theta = torch.tensor(theta, requires_grad=True) + theta = torch.tensor(theta, requires_grad=True, dtype=torch.float32) cos_theta = torch.cos(theta) sin_theta = torch.sin(theta) @@ -78,7 +78,7 @@ def rotation_x_torch(theta): """ # Ensure theta is a tensor with requires_grad=True if not isinstance(theta, torch.Tensor): - theta = torch.tensor(theta, requires_grad=True) + theta = torch.tensor(theta, requires_grad=True, dtype=torch.float32) cos_theta = torch.cos(theta) sin_theta = torch.sin(theta) diff --git a/simca/configs/cassi_system_optim_optics_full_triplet.yml b/simca/configs/cassi_system_optim_optics_full_triplet.yml index 61b1946..c09e29a 100755 --- a/simca/configs/cassi_system_optim_optics_full_triplet.yml +++ b/simca/configs/cassi_system_optim_optics_full_triplet.yml @@ -1,48 +1,48 @@ ##### Configuration file for the chosen optical system infos: - system name: HYACAMEO + system name: HYACAMEO system architecture: - system type: DD-CASSI - propagation type: simca - focal lens: 50000 - dispersive element: + system type: DD-CASSI + propagation type: simca + focal lens: 50000 + dispersive element: # dispersive 
element characteristics
-    type: tripleprism # name of the dispersive element
-    glass1: P-SK60 # glass type of the dispersive element (only used if type == 'prism')
-    glass2: SF4 # glass type of the dispersive element (only used if type == 'prism')
-    glass3: P-SK60 # glass type of the dispersive element (only used if type == 'prism')
-    A1: 21.5 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees
-    A2: 43.0 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees
-    A3: 21.5 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees
-    nd1: 1.6074
-    nd2: 1.7552
-    nd3: 1.6074
-    vd1: 56.65
-    vd2: 27.58
-    vd3: 56.65
-    continuous glass materials 1: False
-    continuous glass materials 2: False
-    continuous glass materials 3: False
-    m: 1 # grating order to consider (only used if type == 'grating') -- no units
-    G: 30 # grating density (only used if type == 'grating') -- in lines/mm
-    alpha_c: 2.6
-    delta alpha c: 0
-    delta beta c: 0
-    wavelength center: 600 # central wavelength -- in nm
+    type: tripleprism # name of the dispersive element
+    glass1: P-SK60 # glass type of the dispersive element (only used if type == 'prism')
+    glass2: SF4 # glass type of the dispersive element (only used if type == 'prism')
+    glass3: P-SK60 # glass type of the dispersive element (only used if type == 'prism')
+    A1: 21.5 # apex angle of the prism (only used if type == 'prism') -- in degrees
+    A2: 43.0 # apex angle of the prism (only used if type == 'prism') -- in degrees
+    A3: 21.5 # apex angle of the prism (only used if type == 'prism') -- in degrees
+    nd1: 1.6074
+    nd2: 1.7552
+    nd3: 1.6074
+    vd1: 56.65
+    vd2: 27.58
+    vd3: 56.65
+    continuous glass materials 1: False
+    continuous glass materials 2: False
+    continuous glass materials 3: False
+    m: 1 # grating order to consider (only used if type == 'grating') -- no units
+    G: 30 # grating density (only used if type == 'grating') -- in lines/mm
+    alpha_c: 2.6
+    delta alpha c: 0
+    delta beta c: 0
+    wavelength center: 600 # central wavelength -- in nm
 
 detector:
-  number of pixels along X: 201 # number of pixels along X axis -- no units
-  number of pixels along Y: 201 # number of pixels along Y axis -- no units
-  pixel size along X: 40 # pixel size along X -- in um
-  pixel size along Y: 40 # pixel size along Y -- in um
+  number of pixels along X: 145 # number of pixels along X axis -- no units
+  number of pixels along Y: 901 # number of pixels along Y axis -- no units
+  pixel size along X: 15 # pixel size along X -- in um
+  pixel size along Y: 15 # pixel size along Y -- in um
 
 coded aperture:
-  number of pixels along X: 11 # number of pixels along X axis -- no units
-  number of pixels along Y: 13 # number of pixels along Y axis -- no units
-  pixel size along X: 1000 # pixel size along X -- in um
-  pixel size along Y: 1000 # pixel size along Y -- in um
+  number of pixels along X: 145 # 151 # number of pixels along X axis -- no units
+  number of pixels along Y: 901 # 151 # number of pixels along Y axis -- no units
+  pixel size along X: 15 # pixel size along X -- in um
+  pixel size along Y: 15 # pixel size along Y -- in um
 
 spectral range:
-  wavelength min: 400 # minimum wavelength -- in nm
-  wavelength max: 1050 # maximum wavelength -- in nm
-  number of spectral samples: 15
+  wavelength min: 410 # minimum wavelength -- in nm
+  wavelength max: 1050 # maximum wavelength -- in nm
+  number of spectral samples: 151
diff --git a/simca/cost_functions.py
b/simca/cost_functions.py index 7ece0d0..07ac8e6 100644 --- a/simca/cost_functions.py +++ b/simca/cost_functions.py @@ -1,19 +1,22 @@ import torch from matplotlib import pyplot as plt +import math -def evaluate_slit_scanning_straightness(filtering_cube, sigma = 0.75, pos_slit=0.5): +def evaluate_slit_scanning_straightness(filtering_cube, device, sigma = 0.75, pos_slit=0.5): """ Evaluate the straightness of the slit scanning. working cost function for up to focal >100000 """ pos_cube = round(filtering_cube.shape[1]*pos_slit) - cost_value = torch.tensor(0.0, requires_grad=True) - - gaussian = torch.arange(filtering_cube.shape[1]) - pos_cube - gaussian = torch.exp(-torch.square(gaussian)/(2*sigma**2)).unsqueeze(0) - for i in range(filtering_cube.shape[2]): - """ vertical_binning = torch.sum(filtering_cube[:, :, i], axis=0) + cost_value = torch.tensor(0.0, requires_grad=True).to(device) + + gaussian = (torch.arange(filtering_cube.shape[1]) - pos_cube).to(device) + gaussian = torch.exp(-torch.square(gaussian)/(2*sigma**2)).unsqueeze(0).to(device) + gaussian = gaussian + w = filtering_cube.shape[2]//2 + """for i in range(filtering_cube.shape[2]): + vertical_binning = torch.sum(filtering_cube[:, :, i], axis=0) #max_value = torch.max(vertical_binning) std_deviation_vertical = torch.std(vertical_binning) #std_deviation_horizontal = torch.std(torch.sum(filtering_cube[:,:,i], axis=1)) @@ -23,8 +26,8 @@ def evaluate_slit_scanning_straightness(filtering_cube, sigma = 0.75, pos_slit=0 row_diffs = filtering_cube[1:, :, i] - filtering_cube[:-1, :, i] #cost_value = cost_value + max_value / std_deviation - torch.sum(torch.sum(torch.abs(row_diffs))) cost_value = cost_value + std_deviation_vertical - 0.2*torch.sum(torch.sum(torch.square(row_diffs))) #- 0.2*torch.sum(torch.sum(torch.abs(row_diffs)))""" - row_diffs = filtering_cube[1:, :, 0] - filtering_cube[:-1, :, 0] - cost_value = cost_value - torch.sum((torch.abs(filtering_cube[:, :, 0] - gaussian)+1e-8)**0.4) + 0.6*torch.sum(filtering_cube[:, pos_cube,0]) - 0.8*torch.sum(torch.sum(torch.abs(row_diffs))) #- 2*torch.sum(torch.var(filtering_cube[:, :, 0], dim=0)) + row_diffs = filtering_cube[1:, :, w] - filtering_cube[:-1, :, w] + cost_value = cost_value - torch.sum((torch.abs(filtering_cube[:, :, w] - gaussian)+1e-8)**0.4) + 0.6*torch.sum(filtering_cube[:, pos_cube,w]) - 0.8*torch.sum(torch.sum(torch.abs(row_diffs))) #- 2*torch.sum(torch.var(filtering_cube[:, :, 0], dim=0)) #delta = 2 #cost_value = cost_value - (delta**2)*torch.sum((torch.sqrt(1+((filtering_cube[:, :, 0] - gaussian)/delta)**2)-1)) # pseudo-huber loss # Minimizing the negative of cost_value to maximize the original objective @@ -47,16 +50,112 @@ def evaluate_mean_lighting(acquisition): return -cost_value -def evaluate_max_lighting(acquisition, pos_check): +def evaluate_max_lighting(widths, acquisition, target): cost_value = 0 - pos_cube = round(acquisition.shape[1]*pos_check) - - col = acquisition[:, pos_cube] - - cost_value = 2*torch.mean(col)**2 - 8*torch.var(col) + #col = acquisition[:, pos_cube] + #col = acquisition[:, pos_cube-2:pos_cube+2] + col = acquisition[acquisition>100].flatten() + #col = acquisition[93:208, 30].unsqueeze(1) + """for i in range(1, 5): + col = torch.cat((col, acquisition[93:208, 30+i*15].unsqueeze(1)), 1) """ + + #cost_value = 2*torch.mean(col)**2 - 25*torch.var(col) + #cost_value = - torch.var(col) + """ cost_value = 15*torch.mean(col)**2 - 25*torch.var(col) + cost_value = 8000*torch.mean(col)**2 - torch.sum((col-10000)**2) + cost_value = 0.75*torch.mean(col) - 
torch.mean(torch.abs(col-6000)) + cost_value = torch.mean(col) - 2*torch.std(col) + + lines = torch.mean(acquisition, axis=1) + + #cost_value = - torch.var(torch.log(col)) + + cost_value = - torch.var((torch.log(col)- torch.log(torch.tensor([40000])))**2) + cost_value = - torch.var(torch.log(col)) - torch.log(torch.var(col))# - torch.mean((torch.log(col)- torch.log(torch.tensor([14000])))**2) + #cost_value = - torch.var(torch.log(col)**2) - torch.var((torch.log(col)- torch.log(torch.tensor([20000])))**2) + cost_value = - torch.var(torch.log(col)**2) - 2*torch.var((torch.log(col)- torch.log(torch.tensor([6000])))**2) + #cost_value = - torch.sum((2000*10000*((col-2000)+(col-10000)) - (10000-2000))/2) + #cost_value = -torch.var(torch.log(col)) """ + #cost_value = -torch.var(torch.exp(col/11000)) + row_diffs = torch.abs(widths[0,1:] - widths[0,:-1]) + #cost_value = - torch.var(torch.exp(col/18100)) - torch.sum(-torch.log(1+row_diffs)) + #print(torch.var(torch.exp(col/20000))) + #print(torch.mean(torch.log(col))) + #print(torch.sum(-torch.log(1+row_diffs))) + - return -cost_value + def saturation(scene, target_, margin=0.05): + cost = 0 + for elem in scene: + if elem <= target_*(1+margin): + cost += (target_-elem) + else: + cost += (elem-target_)**3 + return cost + + def bowl(scene, target_, saturation=None): + if saturation is None: + saturation = target_*1.2 + s = saturation + t = target_ + cost = 0 + for x in scene: + if x <= t: + cost += ((t-x)/t)**2 + elif x < s: + A = -1/((s-t)**2) + 2/(t**2) + B = -1/(s-t) + t/((s-t)**2) - 2/t + C = - 1/2*A*t - B*t + cost += - math.log((s-x)/(s-t)) + 1/2*A*x**2 + B*x + C + #cost += - math.exp(80*(1-(saturation-elem)/(saturation-target_))) + #cost += (elem-target_)**4 + else: + cost += 1e18*x**2 + return cost + + def bowl_inverse(scene, target_, saturation=None): + if saturation is None: + saturation = target_*1.2 + s = saturation + t = target_ + cost = 0 + for x in scene: + if x <= t: + cost += ((t-x)/t)**2 + elif x < s: + A = -2/((s-t)**3) + 2/(t**2) + B = -1/((s-t)**2) - A*t + C = -1/(s-t) - 1/2*A*t - B*t + cost += - 1/(s-x) + 1/2*A*x**2 + B*x + C + #cost += - math.exp(80*(1-(saturation-elem)/(saturation-target_))) + #cost += (elem-target_)**4 + else: + cost += 1e18*x**2 + return cost + + + #cost_value = - torch.var(col) - torch.sum(-torch.log(1+row_diffs)) + #cost_value = - torch.var(torch.exp(col/18000)) - 10000*torch.sum(torch.log(1/(1+row_diffs))) + #cost_value = - torch.var(torch.exp(col/30000))# - 2*torch.count_nonzero(row_diffs) + #cost_value = - torch.var(torch.exp(col/20000)) - 40*torch.count_nonzero(row_diffs) + #cost_value = torch.mean(col) + #cost_value = - 2*torch.var((torch.exp(col/20000)- torch.exp(torch.tensor([9000])/20000))**2) + #cost_value = - saturation(col, 45000, margin=0.1) + + #print("Jumps: ", torch.count_nonzero(row_diffs)) + print("Var: ", torch.var(col)) + #print("Saturation: ", - saturation(col.flatten(), 120000, margin=0.1)) + print("Min: ", torch.min(col)) + print("Mean: ", torch.mean(col)) + print("Max: ", torch.max(col)) + + #cost_value = - torch.var(torch.exp(col/100000)) #- 1e-5*saturation(col.flatten(), 200000, margin=0.1) + #cost_value = - torch.var((torch.log(col.squeeze())- torch.log(torch.tensor([200000]).to('cuda')))**2) + #cost_value = - torch.var(col) #- saturation(col.flatten(), 120000, margin=0.1) + cost_value = - bowl(col, target, saturation=2.2e6) + print("Cost: ", - cost_value) + return - cost_value # def evalute_slit_scanning_straightness(filtering_cube,threshold): # """ diff --git 
a/simca/functions_acquisition_torch.py b/simca/functions_acquisition_torch.py index 4c8fa86..ded6b67 100644 --- a/simca/functions_acquisition_torch.py +++ b/simca/functions_acquisition_torch.py @@ -2,6 +2,7 @@ from tqdm import tqdm import torch from torch_geometric.nn.unpool import knn_interpolate +import snoop # TODO: sd measurement torch def generate_sd_measurement_cube(filtered_scene,X_input, Y_input, X_target, Y_target,grid_type,interp_method): @@ -27,7 +28,6 @@ def generate_sd_measurement_cube(filtered_scene,X_input, Y_input, X_target, Y_ta interp_method=interp_method) return measurement_sd - def generate_dd_measurement_torch(scene, filtering_cube,chunk_size): """ Generate DD-CASSI type system measurement from a scene and a filtering cube. ref : "Single-shot compressive spectral imaging with a dual-disperser architecture", M.Gehm et al., Optics Express, 2007 @@ -42,7 +42,7 @@ def generate_dd_measurement_torch(scene, filtering_cube,chunk_size): """ # Initialize an empty array for the result - filtered_scene = torch.empty_like(filtering_cube) + #filtered_scene = torch.empty_like(filtering_cube) # Calculate total iterations for tqdm total_iterations = (filtering_cube.shape[0] // chunk_size + 1) * (filtering_cube.shape[1] // chunk_size + 1) @@ -80,10 +80,10 @@ def interpolate_data_on_grid_positions_torch(data, X_init, Y_init, X_target, Y_t numpy.ndarray: 3D data interpolated on the target grid """ - X_init = torch.from_numpy(X_init).double() if isinstance(X_init, np.ndarray) else X_init - Y_init = torch.from_numpy(Y_init).double() if isinstance(Y_init, np.ndarray) else Y_init - X_target = torch.from_numpy(X_target).double() if isinstance(X_target, np.ndarray) else X_target - Y_target = torch.from_numpy(Y_target).double() if isinstance(Y_target, np.ndarray) else Y_target + X_init = torch.from_numpy(X_init).double().squeeze() if isinstance(X_init, np.ndarray) else X_init.squeeze() + Y_init = torch.from_numpy(Y_init).double().squeeze() if isinstance(Y_init, np.ndarray) else Y_init.squeeze() + X_target = torch.from_numpy(X_target).double().squeeze() if isinstance(X_target, np.ndarray) else X_target.squeeze() + Y_target = torch.from_numpy(Y_target).double().squeeze() if isinstance(Y_target, np.ndarray) else Y_target.squeeze() data = torch.from_numpy(data).double() if isinstance(data, np.ndarray) else data interpolated_data = torch.zeros((X_target.shape[0],X_target.shape[1],X_init.shape[2])) @@ -100,7 +100,6 @@ def interpolate_data_on_grid_positions_torch(data, X_init, Y_init, X_target, Y_t tasks = [(X_init[:, :, i], Y_init[:, :, i], data[:, :, i], X_target, Y_target, interp_method) for i in range(nb_of_grids)] - for index, zi in tqdm(enumerate(tasks), total=nb_of_grids, desc='Interpolate 3D data on grid positions'): interpolated_data[:, :, index] = worker(zi) diff --git a/simca/functions_optim.py b/simca/functions_optim.py index d62a3e2..a934696 100644 --- a/simca/functions_optim.py +++ b/simca/functions_optim.py @@ -46,23 +46,22 @@ def closure(): cassi_system.generate_filtering_cube() cassi_system.filtering_cube = cassi_system.filtering_cube.to(device) - cost_value = evaluate_slit_scanning_straightness(cassi_system.filtering_cube, sigma = sigma, pos_slit = pos_slit_detector) + cost_value = evaluate_slit_scanning_straightness(cassi_system.filtering_cube, device, sigma = sigma, pos_slit = pos_slit_detector) cost_value.backward() return cost_value optimizer.step(closure) - cost_value = evaluate_slit_scanning_straightness(cassi_system.filtering_cube, sigma = sigma, pos_slit = pos_slit_detector) + 
cost_value = evaluate_slit_scanning_straightness(cassi_system.filtering_cube, device, sigma = sigma, pos_slit = pos_slit_detector) elif algo == "ADAM": optimizer.zero_grad() # Clear previous gradients cassi_system.generate_custom_slit_pattern() cassi_system.pattern = cassi_system.pattern.to(device) - #print(pattern[:, pattern.shape[1]//2-4:pattern.shape[1]//2+4]) cassi_system.generate_filtering_cube() cassi_system.filtering_cube = cassi_system.filtering_cube.to(device) - cost_value = evaluate_slit_scanning_straightness(cassi_system.filtering_cube, sigma = sigma, pos_slit = pos_slit_detector) + cost_value = evaluate_slit_scanning_straightness(cassi_system.filtering_cube, device, sigma = sigma, pos_slit = pos_slit_detector) cost_value.backward() optimizer.step() @@ -73,15 +72,14 @@ def closure(): else: convergence_counter+=1 - if (iteration >= 50) and (convergence_counter >= max_iter_cnv): # If loss didn't decrease in 25 steps, break + if (iteration >= 50) and (convergence_counter >= max_iter_cnv): # If loss didn't decrease in max_iter_cnv steps, break break - # print("Gradients after backward:", cassi_system.array_x_positions.grad) cassi_system.array_x_positions.data = torch.relu(cassi_system.array_x_positions.data) # Prevent the parameters to be negative # Optional: Print cost_value every few iterations to monitor progress if iteration % 5 == 0: # Adjust printing frequency as needed - print(f"Iteration {iteration}, Cost: {cost_value.item()}") + print(f"\nIteration {iteration}, Cost: {cost_value.item()}") if plot_frequency is not None: if iteration % plot_frequency == 0: @@ -89,10 +87,10 @@ def closure(): plt.imshow(cassi_system.pattern.detach().numpy(), aspect=aspect_plot) plt.show() - plt.imshow(cassi_system.filtering_cube[:, :, 0].detach().numpy(), aspect=aspect_plot) + plt.imshow(cassi_system.filtering_cube[:, :, cassi_system.filtering_cube.shape[2]//2].detach().numpy(), aspect=aspect_plot) plt.show() - plt.plot(np.sum(cassi_system.filtering_cube[:, :, 0].detach().numpy(),axis=0)) + plt.plot(np.sum(cassi_system.filtering_cube[:, :, cassi_system.filtering_cube.shape[2]//2].detach().numpy(),axis=0)) plt.show() cassi_system.array_x_positions.data = torch.relu(best_x.data) @@ -103,15 +101,18 @@ def closure(): return cassi_system -def optim_width(cassi_system, position, pos_slit_detector, nb_rows, sigma, device, +def optim_width(cassi_system, position, target, nb_rows, start_width, device, algo, lr, num_iterations, max_iter_cnv, threshold = 0, plot_frequency = None, aspect_plot = 1): - #cassi_system.generate_custom_pattern_parameters_slit_width(nb_slits=1, nb_rows=nb_rows, start_width = sigma) - cassi_system.generate_custom_pattern_parameters_slit_width(nb_slits=1, nb_rows=nb_rows, start_width = 0.01) + + #start_width = torch.rand(size=(1,nb_rows), generator=gen)*1.5-0.75 + #start_width = 0.005 + cassi_system.generate_custom_pattern_parameters_slit_width(nb_slits=1, nb_rows=nb_rows, start_width = start_width) #0.005 cassi_system.array_x_positions = cassi_system.array_x_positions.to(device) # Ensure array_x_positions is a tensor with gradient tracking cassi_system.array_x_positions = cassi_system.array_x_positions.clone().detach().requires_grad_(True) + best_x = cassi_system.array_x_positions.clone().detach() convergence_counter = 0 # Counter to check convergence @@ -137,24 +138,23 @@ def closure(): cassi_system.filtering_cube = cassi_system.filtering_cube.to(device) cassi_system.image_acquisition(use_psf=False, chunck_size=cassi_system.system_config["detector"]["number of pixels along Y"]) - 
cost_value = evaluate_max_lighting(cassi_system.measurement, pos_slit_detector) + cost_value = evaluate_max_lighting(cassi_system.array_x_positions.detach(), cassi_system.measurement, target) cost_value.backward() return cost_value optimizer.step(closure) - cost_value = evaluate_max_lighting(cassi_system.measurement, pos_slit_detector) + cost_value = evaluate_max_lighting(cassi_system.array_x_positions.detach(), cassi_system.measurement, target) elif algo == "ADAM": optimizer.zero_grad() # Clear previous gradients cassi_system.generate_custom_slit_pattern_width(start_pattern = "corrected", start_position = position) cassi_system.pattern = cassi_system.pattern.to(device) - #print(pattern[:, pattern.shape[1]//2-4:pattern.shape[1]//2+4]) cassi_system.generate_filtering_cube() cassi_system.filtering_cube = cassi_system.filtering_cube.to(device) cassi_system.image_acquisition(use_psf=False, chunck_size=cassi_system.system_config["detector"]["number of pixels along Y"]) - cost_value = evaluate_max_lighting(cassi_system.measurement, pos_slit_detector) + cost_value = evaluate_max_lighting(cassi_system.array_x_positions.detach(), cassi_system.measurement, target) cost_value.backward() optimizer.step() @@ -165,32 +165,34 @@ def closure(): else: convergence_counter+=1 - if (iteration >= 50) and (convergence_counter >= max_iter_cnv): # If loss didn't decrease in 25 steps, break + if (iteration >= 50) and (convergence_counter >= max_iter_cnv): # If loss didn't decrease in max_iter_cnv steps, break break - # print("Gradients after backward:", cassi_system.array_x_positions.grad) cassi_system.array_x_positions.data = torch.relu(cassi_system.array_x_positions.data) # Prevent the parameters to be negative # Optional: Print cost_value every few iterations to monitor progress if iteration % 5 == 0: # Adjust printing frequency as needed - print(f"Iteration {iteration}, Cost: {cost_value.item()}") + print(f"\nIteration {iteration}, Cost: {cost_value.item()}") if plot_frequency is not None: if iteration % plot_frequency == 0: print(f"Exec time: {time.time() - time_start:.3f}s") + plt.figure() plt.imshow(cassi_system.pattern.detach().numpy(), aspect=aspect_plot) - plt.show() - plt.imshow(cassi_system.filtering_cube[:, :, 0].detach().numpy(), aspect=aspect_plot) - plt.show() + plt.figure() + plt.imshow(cassi_system.filtering_cube[:, :, cassi_system.filtering_cube.shape[2]//2].detach().numpy(), aspect=aspect_plot) - plt.plot(np.sum(cassi_system.filtering_cube[:, :, 0].detach().numpy(),axis=0)) - plt.show() + #plt.figure() + #plt.plot(np.sum(cassi_system.filtering_cube[:, :, cassi_system.filtering_cube.shape[2]//2].detach().numpy(),axis=0)) - plt.imshow(cassi_system.measurement.detach().numpy(), zorder=5) + plt.figure() + plt.imshow(cassi_system.measurement.detach().numpy(), cmap="gray") + plt.colorbar() plt.show() cassi_system.array_x_positions.data = torch.relu(best_x.data) + #cassi_system.array_x_positions.data = best_x.data cassi_system.generate_custom_slit_pattern_width(start_pattern = "corrected", start_position = position) cassi_system.pattern = cassi_system.pattern.to(device) cassi_system.generate_filtering_cube()
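
Note on the new `bowl` penalty in simca/cost_functions.py (editor's sketch, not part of the patch): `evaluate_max_lighting` now scores the bright pixels of the acquisition (values above 100) with `bowl(col, target, saturation=2.2e6)`, a piecewise cost that pulls under-exposed pixels up quadratically, applies a log-barrier between the target and the saturation level (the A/B/C coefficients match the value and slope of the two branches at the target), and penalizes saturated pixels enormously. Below is a minimal vectorized re-expression of that piecewise cost, using masks instead of the patch's per-pixel Python loop; the name `bowl_cost` and the example numbers are illustrative only.

import torch

def bowl_cost(x: torch.Tensor, target: float, sat: float) -> torch.Tensor:
    # Quadratic pull toward `target` from below, log-barrier between
    # `target` and `sat`, huge quadratic penalty past saturation.
    t, s = float(target), float(sat)
    # Coefficients chosen so the barrier branch matches the quadratic
    # branch in value and slope at x == t, as in the patch's bowl().
    A = -1.0 / (s - t) ** 2 + 2.0 / t ** 2
    B = -1.0 / (s - t) + t / (s - t) ** 2 - 2.0 / t
    C = -0.5 * A * t - B * t
    cost = torch.zeros_like(x)
    lo = x <= t
    mid = (x > t) & (x < s)
    hi = x >= s
    cost[lo] = ((t - x[lo]) / t) ** 2
    cost[mid] = (-torch.log((s - x[mid]) / (s - t))
                 + 0.5 * A * x[mid] ** 2 + B * x[mid] + C)
    cost[hi] = 1e18 * x[hi] ** 2
    return cost.sum()

# Illustrative usage, mirroring evaluate_max_lighting's thresholding
# (the detector shape and target/saturation values are examples):
acq = torch.rand(901, 145) * 3e6
acq.requires_grad_(True)
loss = bowl_cost(acq[acq > 100], target=1e5, sat=2.2e6)
loss.backward()

Because every above-threshold pixel feeds a single scalar cost, gradients can flow back through the measurement to the slit-width parameters as in `optim_width`; the masked form simply avoids the per-element Python loop of the original `bowl` and keeps the log-barrier argument strictly positive.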