Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
lpaillet-laas committed Mar 8, 2024
1 parent 5f2c1c8 commit 9db4887
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions optimization_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def forward(self, x):
acquired_image1 = self.cassi_system.image_acquisition(hyperspectral_cube, pattern, wavelengths)
filtering_cubes = subsample(self.cassi_system.filtering_cube, np.linspace(450, 650, self.cassi_system.filtering_cube.shape[-1]), np.linspace(450, 650, 28))
filtering_cubes = filtering_cubes.permute(0, 3, 1, 2).float().to(self.device)
filtering_cubes = torch.flip(filtering_cubes, dims=(2,)) # -1 magnification
filtering_cubes = torch.flip(filtering_cubes, dims=(2,3)) # -1 magnification
displacement_in_pix = self.cassi_system.get_displacement_in_pixels(dataset_wavelengths=wavelengths)
#print("displacement_in_pix", displacement_in_pix)

Expand All @@ -60,7 +60,7 @@ def forward(self, x):
# TODO : replace by the real reconstruction model
if self.config == "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml":
acquired_cubes = acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C
acquired_cubes = torch.flip(acquired_cubes, dims=(2,)) # -1 magnification
acquired_cubes = torch.flip(acquired_cubes, dims=(2,3)) # -1 magnification
reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes)
else:
mask_3d = expand_mask_3d(pattern).float().to(self.device)
Expand Down Expand Up @@ -140,7 +140,7 @@ def _common_step(self, batch, batch_idx):
return loss, y_hat, hyperspectral_cube

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
optimizer = torch.optim.Adam(self.parameters(), lr=4e-4)
return optimizer

def subsample(input, origin_sampling, target_sampling):
Expand Down
2 changes: 1 addition & 1 deletion training_simca_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28"
#data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28"

datamodule = CubesDataModule(data_dir, batch_size=16, num_workers=11)
datamodule = CubesDataModule(data_dir, batch_size=5, num_workers=11)

name = "testing_simca_reconstruction"

Expand Down

0 comments on commit 9db4887

Please sign in to comment.