Skip to content

Commit

Permalink
fixed size bug
Browse files Browse the repository at this point in the history
  • Loading branch information
lpaillet-laas committed Mar 7, 2024
1 parent 05c4dc2 commit a4541df
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
10 changes: 5 additions & 5 deletions optimization_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,18 @@ def forward(self, x):
print("---FORWARD---")

hyperspectral_cube, wavelengths = x
hyperspectral_cube = hyperspectral_cube.permute(0, 3, 2, 1)
hyperspectral_cube = hyperspectral_cube.permute(0, 3, 2, 1).to(self.device)
batch_size, H, W, C = hyperspectral_cube.shape
print(f"batch size:{batch_size}")
# print(f"batch size:{batch_size}")
# generate pattern
pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size)
pattern = pattern.to(self.device)
print(f"pattern_size: {pattern.shape}")
# print(f"pattern_size: {pattern.shape}")

# generate first acquisition with simca

acquired_image1 = self.cassi_system.image_acquisition(hyperspectral_cube, pattern, wavelengths)
filtering_cubes = self.cassi_system.filtering_cube.permute(0, 3, 1, 2)[:,:28,:,:].float()
filtering_cubes = self.cassi_system.filtering_cube.permute(0, 3, 1, 2)[:,:28,:,:].float().to(self.device)
displacement_in_pix = self.cassi_system.get_displacement_in_pixels(dataset_wavelengths=wavelengths)
#print("displacement_in_pix", displacement_in_pix)

Expand All @@ -56,7 +56,7 @@ def forward(self, x):
# TODO : replace by the real reconstruction model
# mask_3d = expand_mask_3d(patterns)
acquired_cubes = acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1))
acquired_cubes = torch.flip(acquired_cubes, dims=(1,)).float() # -1 magnification
acquired_cubes = torch.flip(acquired_cubes, dims=(1,)).float().to(self.device) # -1 magnification

#print(acquired_cubes.shape)
#print(filtering_cubes.shape)
Expand Down
24 changes: 14 additions & 10 deletions simca/CassiSystem_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals
# plt.plot(hyperspectral_cube[0,0,0,:].cpu().numpy())
# plt.title("Original spectrum")
# plt.show()
print("cube shape: ", hyperspectral_cube.shape)
print("wavelengths shape: ", wavelengths.shape)
print("self.wavelengths shape: ", self.wavelengths.shape)
# print("cube shape: ", hyperspectral_cube.shape)
# print("wavelengths shape: ", wavelengths.shape)
# print("self.wavelengths shape: ", self.wavelengths.shape)
dataset = self.interpolate_dataset_along_wavelengths_torch(hyperspectral_cube, wavelengths,self.wavelengths, chunck_size)


Expand Down Expand Up @@ -247,8 +247,8 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals
self.X_coded_aper_coordinates = X_coded_aper_coordinates_crop
self.Y_coded_aper_coordinates = Y_coded_aper_coordinates_crop

print("dataset shape: ", dataset.shape)
print("X coded shape: ", X_coded_aper_coordinates_crop.shape)
# print("dataset shape: ", dataset.shape)
# print("X coded shape: ", X_coded_aper_coordinates_crop.shape)

scene = match_dataset_to_instrument(dataset, X_coded_aper_coordinates_crop)

Expand All @@ -263,11 +263,11 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals
plt.show()

# filtered_scene = scene * pattern_crop[..., None].repeat((1, 1, scene.shape[2]))
print(f"scene: {scene.shape}")
print(f"pattern_crop: {pattern_crop.shape}")
# print(f"scene: {scene.shape}")
# print(f"pattern_crop: {pattern_crop.shape}")
filtered_scene = scene * pattern_crop

print(f"filtered_scene: {filtered_scene.shape}")
# print(f"filtered_scene: {filtered_scene.shape}")


plt.imshow(pattern_crop[0,:,:,0].cpu().numpy())
Expand Down Expand Up @@ -494,9 +494,13 @@ def interpolate_dataset_along_wavelengths_torch(self, hyperspectral_cube, wavele
except:
self.dataset = hyperspectral_cube
self.dataset_wavelengths = wavelengths

#print(self.dataset.shape)
#print(self.dataset_wavelengths.shape)

self.dataset = hyperspectral_cube
self.dataset_wavelengths = wavelengths



self.dataset_wavelengths = torch.from_numpy(self.dataset_wavelengths) if isinstance(self.dataset_wavelengths,
np.ndarray) else self.dataset_wavelengths
Expand Down Expand Up @@ -547,7 +551,7 @@ def interpolate_data_along_wavelength_torch_old(self, data, current_sampling, ne
new_sampling (numpy.ndarray): new sampling for the 3rd axis
chunk_size (int): size of the chunks to use for the interpolation
"""
print(self.device)
# print(self.device)
# Generate the coordinates for the original grid
x = torch.arange(data.shape[0]).float()
y = torch.arange(data.shape[1]).float()
Expand Down

0 comments on commit a4541df

Please sign in to comment.