From 057d66db55bcbbe0abd54efe565b5baa5b014637 Mon Sep 17 00:00:00 2001
From: Leo Paillet
Date: Fri, 8 Mar 2024 17:53:11 +0100
Subject: [PATCH] differentiable subsampling

---
 optimization_modules.py          | 24 +++++++++---------------
 training_simca_reconstruction.py | 19 +++++++++++++------
 2 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/optimization_modules.py b/optimization_modules.py
index 701b358..214fdce 100644
--- a/optimization_modules.py
+++ b/optimization_modules.py
@@ -19,13 +19,7 @@ class JointReconstructionModule_V1(pl.LightningModule):
     def __init__(self, model_name,log_dir="tb_logs"):
         super().__init__()

-        # TODO : use a real reconstruction module
         self.reconstruction_model = model_generator(model_name, None)
-        """ if torch.cuda.is_available():
-            self.reconstruction_model = self.reconstruction_model.cuda()
-        else:
-            self.reconstruction_model.to('cpu') """
-        #self.reconstruction_model = EmptyModule()

         self.loss_fn = nn.MSELoss()
         self.ssim_loss = SSIM(window_size=11, size_average=True)
@@ -91,9 +85,9 @@ def forward(self, x):
         self.acquired_image1 = self._normalize_data_by_itself(self.acquired_image1)
         acquired_cubes = self.acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C

-        filtering_cubes = subsample(filtering_cube, np.linspace(450, 650, filtering_cube.shape[-1]), np.linspace(450, 650, 28)).permute((0, 3, 1, 2))
+        filtering_cubes = subsample(filtering_cube, torch.linspace(450, 650, filtering_cube.shape[-1]), torch.linspace(450, 650, 28)).permute((0, 3, 1, 2)).float().to(self.device)

-        reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes.to(self.device))
+        reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes)

         return reconstructed_cube

@@ -227,14 +221,14 @@ def plot_spectral_filter(self,ref_hyperspectral_cube,recontructed_hyperspectral_
             pix_j_col_value = np.random.randint(0,x)

             pix_j_ref = ref_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy()
-            pixe_j_reconstructed = recontructed_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy()
-            axs[i].plot(pixe_j_reconstructed, label="pix reconstructed" + str(j),c=colors[j])
+            pix_j_reconstructed = recontructed_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy()
+            axs[i].plot(pix_j_reconstructed, label="pix reconstructed" + str(j),c=colors[j])
             axs[i].plot(pix_j_ref, label="pix" + str(j), linestyle='--',c=colors[j])

             axs[i].set_title(f"Reconstruction quality")
             axs[i].set_xlabel("Wavelength index")
-            axs[i].set_ylabel("pxie values")
+            axs[i].set_ylabel("pix values")
             axs[i].grid(True)

         plt.legend()
@@ -257,12 +251,12 @@ def plot_spectral_filter(self,ref_hyperspectral_cube,recontructed_hyperspectral_

 def subsample(input, origin_sampling, target_sampling):
     [bs, row, col, nC] = input.shape
-    output = torch.zeros(bs, row, col, len(target_sampling))
+    indices = torch.zeros(len(target_sampling), dtype=torch.int)
     for i in range(len(target_sampling)):
         sample = target_sampling[i]
-        idx = np.abs(origin_sampling-sample).argmin()
-        output[:,:,:,i] = input[:,:,:,idx]
-    return output
+        idx = torch.abs(origin_sampling-sample).argmin()
+        indices[i] = idx
+    return input[:,:,:,indices]

 def expand_mask_3d(mask_batch):
     if len(mask_batch.shape)==3:
diff --git a/training_simca_reconstruction.py b/training_simca_reconstruction.py
index f2ea614..8780ea2 100644
--- a/training_simca_reconstruction.py
+++ b/training_simca_reconstruction.py
@@ -3,12 +3,13 @@
 from optimization_modules import JointReconstructionModule_V1
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
+import torch

-data_dir = "./datasets_reconstruction/"
+data_dir = "./datasets_reconstruction/cave_1024_28"
 #data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28"

-datamodule = CubesDataModule(data_dir, batch_size=32, num_workers=1)
+datamodule = CubesDataModule(data_dir, batch_size=2, num_workers=1)

 name = "testing_simca_reconstruction"
 model_name = "mst_plus_plus"
@@ -35,9 +36,15 @@

 reconstruction_module = JointReconstructionModule_V1(model_name,log_dir=log_dir+'/'+ name)

-trainer = pl.Trainer( logger=logger,
-                        accelerator="gpu",
-                        max_epochs=500,
-                        log_every_n_steps=1)
+if torch.cuda.is_available():
+    trainer = pl.Trainer( logger=logger,
+                            accelerator="gpu",
+                            max_epochs=500,
+                            log_every_n_steps=1)
+else:
+    trainer = pl.Trainer( logger=logger,
+                            accelerator="cpu",
+                            max_epochs=500,
+                            log_every_n_steps=1)

 trainer.fit(reconstruction_module, datamodule)
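
Note on the subsample change: the patched helper builds an index tensor and selects the requested spectral bands directly from the input cube, so the selection stays inside the autograd graph and gradients reach the filtering cube. Below is a minimal standalone sketch of that behaviour, not the repository's code: it mirrors the patched helper but uses torch.long for the index dtype, and the cube size and 128-band wavelength grid are illustrative assumptions.

import torch

def subsample(input, origin_sampling, target_sampling):
    # Nearest-neighbour band selection by advanced indexing, as in the patch.
    indices = torch.zeros(len(target_sampling), dtype=torch.long)
    for i in range(len(target_sampling)):
        indices[i] = torch.abs(origin_sampling - target_sampling[i]).argmin()
    return input[:, :, :, indices]

# Toy check: resample a 1 x 4 x 4 x 128 cube to 28 bands between 450 and 650 nm.
cube = torch.rand(1, 4, 4, 128, requires_grad=True)
out = subsample(cube, torch.linspace(450, 650, 128), torch.linspace(450, 650, 28))
out.sum().backward()
print(cube.grad.abs().sum() > 0)  # tensor(True): the selection is part of the autograd graph

The switch from np.linspace/np.abs to their torch equivalents in forward() also keeps the wavelength grids as tensors, so the nearest-index lookup now runs entirely in torch.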