From 633b1d8147125abaf761d7ef0abfa6c1aba82589 Mon Sep 17 00:00:00 2001 From: Antoine Rouxel Date: Fri, 8 Mar 2024 17:26:54 +0100 Subject: [PATCH] using dd-cassi now --- optimization_modules.py | 225 +++++++++++------- simca/CassiSystem_lightning.py | 25 +- ...tem_optim_optics_full_triplet_dd_cassi.yml | 48 ++++ simca/functions_acquisition_torch.py | 2 + training_simca_reconstruction.py | 43 +++- 5 files changed, 228 insertions(+), 115 deletions(-) create mode 100755 simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml diff --git a/optimization_modules.py b/optimization_modules.py index 5ecba09..701b358 100644 --- a/optimization_modules.py +++ b/optimization_modules.py @@ -5,12 +5,18 @@ from MST.simulation.train_code.architecture import * from simca import load_yaml_config import matplotlib.pyplot as plt +import torchvision import numpy as np - +from simca.functions_acquisition import * +from piqa import SSIM +from torch.utils.tensorboard import SummaryWriter +import io +import torchvision.transforms as transforms +from PIL import Image class JointReconstructionModule_V1(pl.LightningModule): - def __init__(self, model_name): + def __init__(self, model_name,log_dir="tb_logs"): super().__init__() # TODO : use a real reconstruction module @@ -21,10 +27,13 @@ def __init__(self, model_name): self.reconstruction_model.to('cpu') """ #self.reconstruction_model = EmptyModule() self.loss_fn = nn.MSELoss() + self.ssim_loss = SSIM(window_size=11, size_average=True) + + self.writer = SummaryWriter(log_dir) def on_validation_start(self,stage=None): print("---VALIDATION START---") - self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml" + self.config = "simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml" self.shift_bool = False if self.shift_bool: self.crop_value_left = 8 @@ -41,68 +50,51 @@ def on_validation_start(self,stage=None): self.cassi_system = CassiSystemOptim(system_config=config_system) self.cassi_system.propagate_coded_aperture_grid() + def _normalize_data_by_itself(self, data): + # Calculate the mean and std for each batch individually + # Keep dimensions for broadcasting + mean = torch.mean(data, dim=[1, 2], keepdim=True) + std = torch.std(data, dim=[1, 2], keepdim=True) + + # Normalize each batch by its mean and std + normalized_data = (data - mean) / std + return normalized_data + + def forward(self, x): print("---FORWARD---") hyperspectral_cube, wavelengths = x hyperspectral_cube = hyperspectral_cube.permute(0, 2, 3, 1).to(self.device) batch_size, H, W, C = hyperspectral_cube.shape - fig, ax = plt.subplots(1, 1) - plt.title(f"entry cube") - ax.imshow(hyperspectral_cube[0, :, :, 0].cpu().detach().numpy()) - plt.show() + + # fig, ax = plt.subplots(1, 1) + # plt.title(f"entry cube") + # ax.imshow(hyperspectral_cube[0, :, :, 0].cpu().detach().numpy()) + # plt.show() # print(f"batch size:{batch_size}") # generate pattern - pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size) - pattern = pattern.to(self.device) + self.pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size) + self.pattern = self.pattern.to(self.device) - plt.imshow(pattern[0, :, :].cpu().detach().numpy()) - plt.show() + # plt.imshow(pattern[0, :, :].cpu().detach().numpy()) + # plt.show() # print(f"pattern_size: {pattern.shape}") # generate first acquisition with simca - acquired_image1 = self.cassi_system.image_acquisition(hyperspectral_cube, pattern, wavelengths) - filtering_cubes = 
subsample(self.cassi_system.filtering_cube, np.linspace(450, 650, self.cassi_system.filtering_cube.shape[-1]), np.linspace(450, 650, 28)) - filtering_cubes = filtering_cubes.permute(0, 3, 1, 2).float().to(self.device) - filtering_cubes = torch.flip(filtering_cubes, dims=(2,3)) # -1 magnification - displacement_in_pix = self.cassi_system.get_displacement_in_pixels(dataset_wavelengths=wavelengths) - #print("displacement_in_pix", displacement_in_pix) - - # # vizualize first image acquisition - # plt.imshow(acquired_image1[0, :, :].cpu().detach().numpy()) - # plt.show() - # process first acquisition with reconstruction model - # TODO : replace by the real reconstruction model - if not self.shift_bool: - acquired_cubes = acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C - acquired_cubes = torch.flip(acquired_cubes, dims=(2,3)) # -1 magnification - fig, ax = plt.subplots(1, 2) - plt.title(f"true cube cropped vs measurement") - ax[0].imshow(hyperspectral_cube[0, self.crop_value_up:-self.crop_value_down, self.crop_value_left:-self.crop_value_right, 0].cpu().detach().numpy()) - ax[1].imshow(acquired_cubes[0, 0, :, :].cpu().detach().numpy()) - plt.show() - - reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes) - else: - shifted_image = self.shift_back(acquired_image1.flip(dims=(1, 2)), displacement_in_pix).float().to(self.device) - mask_3d = expand_mask_3d(self.cassi_system.pattern_crop.flip(dims=(1, 2))[:, self.crop_value_up:-self.crop_value_down, self.crop_value_left:-self.crop_value_right]).float().to(self.device) + filtering_cube = self.cassi_system.generate_filtering_cube().to(self.device) + self.acquired_image1 = self.cassi_system.image_acquisition(hyperspectral_cube, self.pattern, wavelengths).to(self.device) + - fig,ax = plt.subplots(1,2) - plt.title(f"true cube cropped vs measurement") - ax[0].imshow(hyperspectral_cube[0, self.crop_value_up:-self.crop_value_down, self.crop_value_left:-self.crop_value_right, 0].cpu().detach().numpy()) - ax[1].imshow(shifted_image[0, 0, :, :].cpu().detach().numpy()) - plt.show() + self.acquired_image1 = self._normalize_data_by_itself(self.acquired_image1) + acquired_cubes = self.acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C - reconstructed_cube = self.reconstruction_model(shifted_image, mask_3d) + filtering_cubes = subsample(filtering_cube, np.linspace(450, 650, filtering_cube.shape[-1]), np.linspace(450, 650, 28)).permute((0, 3, 1, 2)) - + reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes.to(self.device)) - #print(acquired_cubes.shape) - #print(filtering_cubes.shape) - - # return reconstructed_cube @@ -110,7 +102,23 @@ def forward(self, x): def training_step(self, batch, batch_idx): print("Training step") - loss, y_hat, y = self._common_step(batch, batch_idx) + loss,reconstructed_cube, ref_cube = self._common_step(batch, batch_idx) + + hyperspectral_cube, wavelengths = batch + hyperspectral_cube = hyperspectral_cube.permute(0, 2, 3, 1).to(self.device) + + output_images = self._convert_output_to_images(self.acquired_image1) + patterns = self._convert_output_to_images(self.pattern) + input_images = self._convert_output_to_images(hyperspectral_cube[:,:,:,0]) + + + self._log_images('train/output_images', output_images, self.global_step) + self._log_images('train/input_images', input_images, self.global_step) + self._log_images('train/patterns', patterns, self.global_step) + + spectral_filter_plot = 
self.plot_spectral_filter(ref_cube,reconstructed_cube) + self.writer.add_image('Spectral Filter', spectral_filter_plot, self.global_step) + self.log_dict( { "train_loss": loss, }, @@ -119,12 +127,19 @@ def training_step(self, batch, batch_idx): prog_bar=True, ) - return {"loss": loss, "scores":y_hat, "y":y} + return {"loss": loss} + + def _normalize_image_tensor(self, tensor): + # Normalize the tensor to the range [0, 1] + min_val = tensor.min() + max_val = tensor.max() + normalized_tensor = (tensor - min_val) / (max_val - min_val) + return normalized_tensor def validation_step(self, batch, batch_idx): print("Validation step") - loss, y_hat, y = self._common_step(batch, batch_idx) + loss,reconstructed_cube, ref_cube= self._common_step(batch, batch_idx) self.log_dict( { "val_loss": loss, @@ -134,11 +149,11 @@ def validation_step(self, batch, batch_idx): prog_bar=True, ) - return {"loss": loss, "scores":y_hat, "y":y} + return {"loss": loss} def test_step(self, batch, batch_idx): print("Test step") - loss, y_hat, y = self._common_step(batch, batch_idx) + loss,reconstructed_cube, ref_cube= self._common_step(batch, batch_idx) self.log_dict( { "test_loss": loss, }, @@ -146,55 +161,99 @@ def test_step(self, batch, batch_idx): on_epoch=True, prog_bar=True, ) - return {"loss": loss, "scores":y_hat, "y":y} + return {"loss": loss} def predict_step(self, batch, batch_idx): print("Predict step") - loss, _, _ = self._common_step(batch, batch_idx) + loss,reconstructed_cube, ref_cube= self._common_step(batch, batch_idx) self.log('predict_step', loss,on_step=True, on_epoch=True, prog_bar=True, logger=True) return loss def _common_step(self, batch, batch_idx): - y_hat = self.forward(batch) + reconstructed_cube = self.forward(batch) hyperspectral_cube, wavelengths = batch - #hyperspectral_cube = hyperspectral_cube.permute(0, 3, 2, 1) - hyperspectral_cube = hyperspectral_cube[:,:, self.crop_value_up:-self.crop_value_down, self.crop_value_left:-self.crop_value_right] - fig, ax = plt.subplots(1, 2) - plt.title(f"true cube vs reconstructed cube") - ax[0].imshow(hyperspectral_cube[0, 0, :, :].cpu().detach().numpy()) - ax[1].imshow(y_hat[0, 0, :, :].cpu().detach().numpy()) - plt.show() + hyperspectral_cube = hyperspectral_cube.permute(0, 2, 3, 1).to(self.device) + reconstructed_cube = reconstructed_cube.permute(0, 2, 3, 1).to(self.device) + ref_cube = match_dataset_to_instrument(hyperspectral_cube, reconstructed_cube[0, :, :,0]) + + # fig, ax = plt.subplots(1, 2) + # plt.title(f"true cube vs reconstructed cube") + # ax[0].imshow(hyperspectral_cube[0, :, :, 0].cpu().detach().numpy()) + # ax[1].imshow(reconstructed_cube[0, :, :, 0].cpu().detach().numpy()) + # plt.show() - #print("y_hat shape", y_hat.shape) - #print("hyperspectral_cube shape", hyperspectral_cube.shape) - loss = torch.sqrt(self.loss_fn(y_hat, hyperspectral_cube)) + loss = torch.sqrt(self.loss_fn(reconstructed_cube, ref_cube)) - return loss, y_hat, hyperspectral_cube + return loss,reconstructed_cube, ref_cube def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=4e-4) return optimizer - - def shift_back(self, inputs, d): # input [bs,256,310], [bs, 28] output [bs, 28, 256, 256] - [bs, row, col] = inputs.shape - nC = 28 - d = d[0] - d -= d.min() - self.crop_value_right = 8+int(np.round(d.max())) - output = torch.zeros(bs, nC, row, col - int(np.round(d.max()))).float().to(self.device) - for i in range(nC): - shift = int(np.round(d[i])) - #output[:, i, :, :] = inputs[:, :, step * i:step * i + col - 27 * step] step = 2 - # if 
shift >=0: - # output[:, i, :, :] = inputs[:, :, shift:row+shift] - # else: - # output[:, i, :, :] = inputs[:, :, shift-row:shift] - output[:, i, :, :] = inputs[:, :, shift:shift + col - int(np.round(d.max()))] - return output + + def _log_images(self, tag, images, global_step): + # Convert model output to image grid and log to TensorBoard + img_grid = torchvision.utils.make_grid(images) + self.writer.add_image(tag, img_grid, global_step) + + def _convert_output_to_images(self, acquired_images): + + acquired_images = acquired_images.unsqueeze(1) + + # Create a grid of images for visualization + img_grid = torchvision.utils.make_grid(acquired_images) + return img_grid + + def plot_spectral_filter(self,ref_hyperspectral_cube,recontructed_hyperspectral_cube): + + + batch_size, y,x, lmabda_ = ref_hyperspectral_cube.shape + + # Create a figure with subplots arranged horizontally + fig, axs = plt.subplots(1, batch_size, figsize=(batch_size * 5, 4)) # Adjust figure size as needed + + # Check if batch_size is 1, axs might not be iterable + if batch_size == 1: + axs = [axs] + + # Plot each spectral filter in its own subplot + for i in range(batch_size): + colors = ['b', 'g', 'r'] + for j in range(3): + pix_j_row_value = np.random.randint(0,y) + pix_j_col_value = np.random.randint(0,x) + + pix_j_ref = ref_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy() + pixe_j_reconstructed = recontructed_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy() + axs[i].plot(pixe_j_reconstructed, label="pix reconstructed" + str(j),c=colors[j]) + axs[i].plot(pix_j_ref, label="pix" + str(j), linestyle='--',c=colors[j]) + + axs[i].set_title(f"Reconstruction quality") + + axs[i].set_xlabel("Wavelength index") + axs[i].set_ylabel("pxie values") + axs[i].grid(True) + + plt.legend() + # Adjust layout + plt.tight_layout() + + # Create a buffer to save plot + buf = io.BytesIO() + plt.savefig(buf, format='png') + plt.close(fig) + buf.seek(0) + + # Convert PNG buffer to PIL Image + image = Image.open(buf) + + # Convert PIL Image to Tensor + image_tensor = transforms.ToTensor()(image) + return image_tensor + def subsample(input, origin_sampling, target_sampling): [bs, row, col, nC] = input.shape diff --git a/simca/CassiSystem_lightning.py b/simca/CassiSystem_lightning.py index 9bea7a6..9a1d6f6 100755 --- a/simca/CassiSystem_lightning.py +++ b/simca/CassiSystem_lightning.py @@ -161,7 +161,7 @@ def generate_filtering_cube(self): numpy.ndarray: filtering cube generated according to the optical system & the pattern configuration (R x C x W) """ - self.filtering_cube = interpolate_data_on_grid_positions_torch(data=self.pattern, + self.filtering_cube = interpolate_data_on_grid_positions_torch(data=self.pattern.unsqueeze(-1).repeat(1, 1, 1, self.wavelengths.shape[0]), X_init=self.X_coordinates_propagated_coded_aperture, Y_init=self.Y_coordinates_propagated_coded_aperture, X_target=self.X_detector_coordinates_grid, @@ -193,24 +193,8 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals """ self.wavelengths= self.wavelengths.to(self.device) - - - # - # plt.plot(hyperspectral_cube[0,0,0,:].cpu().numpy()) - # plt.title("Original spectrum") - # plt.show() - # print("cube shape: ", hyperspectral_cube.shape) - # print("wavelengths shape: ", wavelengths.shape) - # print("self.wavelengths shape: ", self.wavelengths.shape) dataset = self.interpolate_dataset_along_wavelengths_torch(hyperspectral_cube, wavelengths,self.wavelengths, chunck_size) - - # 
plt.plot(dataset[0,0,0,:].cpu().numpy()) - # plt.title("Interpolated spectrum") - # plt.show() - - - if dataset is None: return None @@ -226,8 +210,10 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals except: return print("Please generate filtering cube first") - scene = match_dataset_to_instrument(dataset, self.filtering_cube) - scene = torch.from_numpy(match_dataset_to_instrument(dataset, self.filtering_cube)).to(self.device) if isinstance(scene, np.ndarray) else scene.to(self.device) + + scene = match_dataset_to_instrument(dataset, self.filtering_cube[0,:,:,0]) + + # scene = torch.from_numpy(match_dataset_to_instrument(dataset, self.filtering_cube)).to(self.device) if isinstance(scene, np.ndarray) else scene.to(self.device) measurement_in_3D = generate_dd_measurement_torch(scene, self.filtering_cube, chunck_size) @@ -250,6 +236,7 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals # print("dataset shape: ", dataset.shape) # print("X coded shape: ", X_coded_aper_coordinates_crop.shape) + scene = match_dataset_to_instrument(dataset, X_coded_aper_coordinates_crop) pattern_crop = crop_center_3D(pattern, scene.shape[2], scene.shape[1]).to(self.device) diff --git a/simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml b/simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml new file mode 100755 index 0000000..e60a715 --- /dev/null +++ b/simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml @@ -0,0 +1,48 @@ +##### Configuration file for the chosen optical system + +infos: + system name: HYACAMEO + +system architecture: + system type: DD-CASSI + propagation type: simca + focal lens: 75000 + dispersive element: + # dispersive element caracteristics + type: tripleprism # name of the dispersive element + glass1: N-SK2 # glass type of the dispersive element (only used if type == 'prism') + glass2: N-SF4 # glass type of the dispersive element (only used if type == 'prism') + glass3: N-SK2 # glass type of the dispersive element (only used if type == 'prism') + A1: 26.7 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees + A2: 53.4 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees + A3: 26.7 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees + nd1: 1.6074 + nd2: 1.7552 + nd3: 1.6074 + vd1: 56.65 + vd2: 27.58 + vd3: 56.65 + continuous glass materials 1: False + continuous glass materials 2: False + continuous glass materials 3: False + m: 1 # grating order to consider (only used if type == 'grating') -- no units + G: 30 # grating density (only used if type == 'grating') -- in lines/mm + alpha_c: 0 + delta alpha c: 0 + delta beta c: 0 + wavelength center: 550 # central wavelength -- in nm +detector: + number of pixels along X: 112 # number of pixels along X axis -- no units + number of pixels along Y: 112 # number of pixels along Y axis -- no units + pixel size along X: 70 # pixel size along X -- in um + pixel size along Y: 70 # pixel size along Y -- in um +coded aperture: + number of pixels along X: 131 # number of pixels along X axis -- no units + number of pixels along Y: 131 # number of pixels along Y axis -- no units + pixel size along X: 80 # pixel size along X -- in um + pixel size along Y: 80 # pixel size along Y -- in um + +spectral range: + wavelength min: 450 # minimum wavelength -- in nm + wavelength max: 650 # maximum wavelength -- in nm + number of spectral samples: 55 diff --git 
a/simca/functions_acquisition_torch.py b/simca/functions_acquisition_torch.py index 480d30e..cc63f31 100644 --- a/simca/functions_acquisition_torch.py +++ b/simca/functions_acquisition_torch.py @@ -82,6 +82,8 @@ def interpolate_data_on_grid_positions_torch(data, X_init, Y_init, X_target, Y_t torch.Tensor: Interpolated 4D data on the target grid. """ + print(data.shape) + # Ensure tensors are on the correct device and data type device = data.device dtype = torch.float64 # Using double precision for interpolation calculations diff --git a/training_simca_reconstruction.py b/training_simca_reconstruction.py index 07f3344..f2ea614 100644 --- a/training_simca_reconstruction.py +++ b/training_simca_reconstruction.py @@ -1,26 +1,43 @@ import pytorch_lightning as pl from data_handler import CubesDataModule from optimization_modules import JointReconstructionModule_V1 -import torch +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger -data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28" + +data_dir = "./datasets_reconstruction/" #data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28" -datamodule = CubesDataModule(data_dir, batch_size=2, num_workers=5) +datamodule = CubesDataModule(data_dir, batch_size=32, num_workers=1) name = "testing_simca_reconstruction" - model_name = "mst_plus_plus" -reconstruction_module = JointReconstructionModule_V1(model_name) +log_dir = 'tb_logs' + +logger = TensorBoardLogger(log_dir, name=name) + +early_stop_callback = EarlyStopping( + monitor='val_loss', # Metric to monitor + patience=15, # Number of epochs to wait for improvement + verbose=True, + mode='min' # 'min' for metrics where lower is better, 'max' for vice versa + ) + +checkpoint_callback = ModelCheckpoint( + monitor='val_loss', # Metric to monitor + dirpath='checkpoints/', # Directory path for saving checkpoints + filename='best-checkpoint', # Checkpoint file name + save_top_k=1, # Save the top k models + mode='min', # 'min' for metrics where lower is better, 'max' for vice versa + save_last=True # Additionally, save the last checkpoint to a file named 'last.ckpt' +) + +reconstruction_module = JointReconstructionModule_V1(model_name,log_dir=log_dir+'/'+ name) -if torch.cuda.is_available(): - trainer = pl.Trainer( accelerator="gpu", - max_epochs=500, - log_every_n_steps=5) -else: - trainer = pl.Trainer( accelerator="cpu", - max_epochs=500, - log_every_n_steps=5) +trainer = pl.Trainer( logger=logger, + accelerator="gpu", + max_epochs=500, + log_every_n_steps=1) trainer.fit(reconstruction_module, datamodule)
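
Reviewer note (not part of the patch): the new forward pass leans on two ideas, per-sample standardization of the acquired measurement and linear spectral resampling of the 55-band SIMCA filtering cube down to the 28 bands the MST-style reconstruction model expects. The sketch below is illustrative only; it assumes a recent torch with tuple-dim reductions, and the names normalize_per_sample / subsample_spectral are hypothetical stand-ins, not the repository's _normalize_data_by_itself / subsample implementations.

import torch

def normalize_per_sample(data: torch.Tensor) -> torch.Tensor:
    # Standardize each (H, W) measurement in the batch by its own mean and std.
    mean = data.mean(dim=(1, 2), keepdim=True)
    std = data.std(dim=(1, 2), keepdim=True)
    return (data - mean) / std

def subsample_spectral(cube: torch.Tensor, origin_nm: torch.Tensor, target_nm: torch.Tensor) -> torch.Tensor:
    # Linearly interpolate a (b, H, W, C_origin) cube onto the target_nm wavelength grid.
    b, h, w, c = cube.shape
    flat = cube.reshape(-1, c)                               # (b*H*W, C_origin)
    idx = torch.searchsorted(origin_nm, target_nm).clamp(1, c - 1)
    lo, hi = origin_nm[idx - 1], origin_nm[idx]
    t = (target_nm - lo) / (hi - lo)                         # interpolation weights in [0, 1]
    out = flat[:, idx - 1] * (1 - t) + flat[:, idx] * t
    return out.reshape(b, h, w, -1)

# Example: 55-band DD-CASSI filtering cube -> 28 bands for the reconstruction model.
cube = torch.rand(2, 112, 112, 55)
resampled = subsample_spectral(cube, torch.linspace(450, 650, 55), torch.linspace(450, 650, 28))
print(resampled.shape)  # torch.Size([2, 112, 112, 28])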
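
Reviewer note (not part of the patch), on the SD-CASSI to DD-CASSI switch this commit makes: in a DD-CASSI layout the coded aperture is imaged through the dispersive element, so the modulation arrives as a wavelength-dependent filtering cube aligned with the detector grid, and the detector integrates the filtered scene over wavelength without the extra spatial shear of the measurement used in SD-CASSI. The one-liner below shows only that textbook measurement model as an assumption-level illustration; the repository's generate_dd_measurement_torch works chunk-by-chunk and returns the filtered 3D cube (measurement_in_3D) before the spectral sum.

import torch

def dd_cassi_measurement(scene: torch.Tensor, filtering_cube: torch.Tensor) -> torch.Tensor:
    # scene and filtering_cube: (b, R, C, W); the 2D measurement (b, R, C) is the
    # spectral sum of the element-wise filtered scene.
    return (scene * filtering_cube).sum(dim=-1)

measurement = dd_cassi_measurement(torch.rand(2, 112, 112, 55), torch.rand(2, 112, 112, 55))
print(measurement.shape)  # torch.Size([2, 112, 112])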