Skip to content

Commit

Permalink
more accurate sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
lpaillet-laas committed Mar 7, 2024
1 parent c97c03b commit 5f2c1c8
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions optimization_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __init__(self, model_name):

def on_validation_start(self,stage=None):
print("---VALIDATION START---")
#self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml"
self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi_shifted.yml"
self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml"
#self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi_shifted.yml"
config_system = load_yaml_config(self.config)
self.config_patterns = load_yaml_config("simca/configs/pattern.yml")
self.cassi_system = CassiSystemOptim(system_config=config_system)
Expand All @@ -46,7 +46,9 @@ def forward(self, x):
# generate first acquisition with simca

acquired_image1 = self.cassi_system.image_acquisition(hyperspectral_cube, pattern, wavelengths)
filtering_cubes = self.cassi_system.filtering_cube.permute(0, 3, 1, 2)[:,:28,:,:].float().to(self.device)
filtering_cubes = subsample(self.cassi_system.filtering_cube, np.linspace(450, 650, self.cassi_system.filtering_cube.shape[-1]), np.linspace(450, 650, 28))
filtering_cubes = filtering_cubes.permute(0, 3, 1, 2).float().to(self.device)
filtering_cubes = torch.flip(filtering_cubes, dims=(2,)) # -1 magnification
displacement_in_pix = self.cassi_system.get_displacement_in_pixels(dataset_wavelengths=wavelengths)
#print("displacement_in_pix", displacement_in_pix)

Expand Down Expand Up @@ -141,6 +143,15 @@ def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer

def subsample(input, origin_sampling, target_sampling):
[bs, row, col, nC] = input.shape
output = torch.zeros(bs, row, col, len(target_sampling))
for i in range(len(target_sampling)):
sample = target_sampling[i]
idx = np.abs(origin_sampling-sample).argmin()
output[:,:,:,i] = input[:,:,:,idx]
return output

def expand_mask_3d(mask_batch):
if len(mask_batch.shape)==3:
mask3d = mask_batch.unsqueeze(-1).repeat((1, 1, 1, 28))
Expand Down

0 comments on commit 5f2c1c8

Please sign in to comment.