Skip to content

Commit

Permalink
working sd cassi with batch and good config
Browse files Browse the repository at this point in the history
  • Loading branch information
arouxel-laas committed Mar 7, 2024
1 parent e969a22 commit 789af26
Show file tree
Hide file tree
Showing 5 changed files with 628 additions and 44 deletions.
24 changes: 6 additions & 18 deletions optimization_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,33 @@ def __init__(self):
# TODO : use a real reconstruction module
# self.reconstruction_model = ReconstructionModel()
self.reconstruction_model = EmptyModule()

self.loss_fn = nn.MSELoss()

def setup(self, stage=None):
print("---SETUP---")
config_system = load_yaml_config("simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml")
self.config_patterns = load_yaml_config("simca/configs/pattern.yml")
self.cassi_system = CassiSystemOptim(system_config=config_system)
self.cassi_system.propagate_coded_aperture_grid()
def on_validation_start(self,stage=None):
print("---VALIDATION START---")

def on_train_start(self,stage=None):
print("---TRAIN START---")
config_system = load_yaml_config("simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml")
self.config_patterns = load_yaml_config("simca/configs/pattern.yml")
self.cassi_system = CassiSystemOptim(system_config=config_system)
self.cassi_system.propagate_coded_aperture_grid()

def on_validation_start(self,stage=None):
print("---VALDIATION START---")
config_system = load_yaml_config("simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml")
self.config_patterns = load_yaml_config("simca/configs/pattern.yml")
self.cassi_system = CassiSystemOptim(system_config=config_system)
self.cassi_system.propagate_coded_aperture_grid()
def forward(self, x):
print("---FORWARD---")

hyperspectral_cube, wavelengths = x
# generate random patterns (one for scene in the batch)

hyperspectral_cube = hyperspectral_cube.permute(0, 3, 2, 1)
batch_size, H, W, C = hyperspectral_cube.shape

# generate pattern
pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size)
pattern = pattern.to(self.device)

# generate first acquisition with simca
# filtering_cube = self.cassi_system.generate_filtering_cube()
acquired_image1 = self.cassi_system.image_acquisition(hyperspectral_cube, pattern,wavelengths)
displacement_in_pix = self.cassi_system.get_displacement_in_pixels(dataset_wavelengths=wavelengths)
print("displacement_in_pix", displacement_in_pix)

# vizualize first image acquisition
plt.imshow(acquired_image1[0, :, :].cpu().detach().numpy())
plt.show()

Expand Down
Loading

0 comments on commit 789af26

Please sign in to comment.