Skip to content

Commit

Permalink
Update to adapt to scene
Browse files Browse the repository at this point in the history
  • Loading branch information
lpaillet-laas committed Mar 6, 2024
1 parent 9884211 commit 60ac205
Show file tree
Hide file tree
Showing 7 changed files with 433 additions and 240 deletions.
341 changes: 210 additions & 131 deletions main_script_optim.py

Large diffs are not rendered by default.

56 changes: 35 additions & 21 deletions simca/CassiSystemOptim.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(self, system_config=None):
self.wavelengths = self.set_wavelengths(self.system_config["spectral range"]["wavelength min"],
self.system_config["spectral range"]["wavelength max"],
self.system_config["spectral range"]["number of spectral samples"])

self.optical_model = OpticalModelTorch(self.system_config)


Expand Down Expand Up @@ -63,7 +62,7 @@ def create_coordinates_grid(self, nb_of_pixels_along_x, nb_of_pixels_along_y, de
# Create a two-dimensional grid of coordinates
X_input_grid, Y_input_grid = np.meshgrid(x, y)

return torch.from_numpy(X_input_grid), torch.from_numpy(Y_input_grid)
return torch.from_numpy(X_input_grid).float(), torch.from_numpy(Y_input_grid).float()

def set_wavelengths(self, wavelength_min, wavelength_max, nb_of_spectral_samples):
"""
Expand Down Expand Up @@ -154,8 +153,7 @@ def generate_filtering_cube(self):
X_init=self.X_coordinates_propagated_coded_aperture,
Y_init=self.Y_coordinates_propagated_coded_aperture,
X_target=self.X_detector_coordinates_grid,
Y_target=self.Y_detector_coordinates_grid)

Y_target=self.Y_detector_coordinates_grid).to(self.device)

return self.filtering_cube

Expand Down Expand Up @@ -185,7 +183,7 @@ def image_acquisition(self, use_psf=False, chunck_size=50):
return print("Please generate filtering cube first")

scene = match_dataset_to_instrument(dataset, self.filtering_cube)
scene = torch.from_numpy(match_dataset_to_instrument(dataset, self.filtering_cube)) if isinstance(scene, np.ndarray) else scene
scene = torch.from_numpy(match_dataset_to_instrument(dataset, self.filtering_cube)).to(self.device) if isinstance(scene, np.ndarray) else scene.to(self.device)

measurement_in_3D = generate_dd_measurement_torch(scene, self.filtering_cube, chunck_size)

Expand Down Expand Up @@ -238,6 +236,7 @@ def image_acquisition(self, use_psf=False, chunck_size=50):
def generate_custom_pattern_parameters_slit(self, position=0.5):
# Position is a float: 0 means slit is on the left edge, 1 means the slit is on the right edge
self.array_x_positions = torch.zeros((self.system_config["coded aperture"]["number of pixels along Y"]))+ position
self.array_x_positions = self.array_x_positions.to(self.device)
return self.array_x_positions

def generate_custom_pattern_parameters_slit_width(self, nb_slits=1, nb_rows=1, start_width=1):
Expand All @@ -246,14 +245,16 @@ def generate_custom_pattern_parameters_slit_width(self, nb_slits=1, nb_rows=1, s
# self.array_x_positions is of shape (self.system_config["coded aperture"]["number of pixels along Y"], nb_rows)

self.array_x_positions = torch.zeros((nb_slits, nb_rows))+start_width # Every slit starts with the same width
self.array_x_positions = self.array_x_positions.to(self.device)
self.array_x_positions_normalized = torch.zeros((nb_slits, nb_rows))+start_width
return self.array_x_positions

def generate_custom_slit_pattern_width(self, start_pattern = "line", start_position = 0):
nb_slits, nb_rows = self.array_x_positions.shape
pos_slits = self.system_config["coded aperture"]["number of pixels along Y"]//(nb_slits+1) # Equally spaced slits
height_slits = self.system_config["coded aperture"]["number of pixels along X"]//nb_rows # Same length slits
pos_slits = self.system_config["coded aperture"]["number of pixels along X"]//(nb_slits+1) # Equally spaced slits
height_slits = self.system_config["coded aperture"]["number of pixels along Y"]//nb_rows # Same length slits

self.pattern = torch.zeros((self.system_config["coded aperture"]["number of pixels along Y"], self.system_config["coded aperture"]["number of pixels along X"])) # Pattern of correct size
self.pattern = torch.zeros((self.system_config["coded aperture"]["number of pixels along Y"], self.system_config["coded aperture"]["number of pixels along X"])).to(self.device) # Pattern of correct size
if start_pattern == "line":
if start_position != 0:
start_position = start_position - pos_slits/self.system_config["coded aperture"]["number of pixels along X"]
Expand Down Expand Up @@ -312,22 +313,35 @@ def generate_custom_slit_pattern_width(self, start_pattern = "line", start_posit
# Normalize to make sure the maximum value is 1
self.pattern = self.pattern + padded/padded.max() """

c = torch.tensor(start_position[i])
d = self.array_x_positions[j,i]/2
m = (c-d)*(self.system_config["coded aperture"]["number of pixels along X"]-1)
M = (c+d)*(self.system_config["coded aperture"]["number of pixels along X"]-1)
rect = torch.arange(self.system_config["coded aperture"]["number of pixels along X"])
rect = torch.clamp(-(rect-m)*(rect-M)+1,0,1)
c = start_position[i].clone().detach() # center of the slit
#d = ((torch.tanh(1.1*self.array_x_positions[j,i])+1)/2)/2 # width of the slit at pos
d = self.array_x_positions[j,i]/2 # width of the slit at pos
m = (c-d)*(self.system_config["coded aperture"]["number of pixels along X"]-1) # left bound
M = (c+d)*(self.system_config["coded aperture"]["number of pixels along X"]-1) # right bound
rect = torch.arange(self.system_config["coded aperture"]["number of pixels along X"]).to(self.device)
clamp_M = torch.clamp(M-rect, 0, 1)

clamp_m = torch.clamp(rect-m, 0, 1)
diff = 1-clamp_m
reg = torch.where(diff < 1, diff, -1)
clamp_m = torch.where(reg!=0, reg, 1)
clamp_m = torch.where(clamp_m!=-1, clamp_m, 0)
clamp_m = torch.roll(clamp_m, -1)
clamp_m[-1]=1

rect = clamp_M - clamp_m +1
rect = torch.where(rect!=2, rect, 0)
rect = torch.where(rect <= 1, rect, rect-1)
#rect = torch.clamp(-(rect-m)*(rect-M)+1,0,1).to(self.device)

gaussian_range = torch.arange(self.system_config["coded aperture"]["number of pixels along X"], dtype=torch.float32)
center_pos = 0.5*(len(gaussian_range)-1)
sigma = 1.5
gaussian_peaks = np.exp(-((center_pos - gaussian_range) ** 2) / (2 * sigma ** 2))
gaussian_peaks = torch.exp(-((center_pos - gaussian_range) ** 2) / (2 * sigma ** 2)).to(self.device)
gaussian = gaussian_peaks /gaussian_peaks.max()
res = torch.nn.functional.conv1d(rect.unsqueeze(0), gaussian.unsqueeze(0).unsqueeze(0), padding = 144//2).squeeze()
res = torch.nn.functional.conv1d(rect.unsqueeze(0), gaussian.unsqueeze(0).unsqueeze(0), padding = (len(gaussian_range)-1)//2).squeeze().to(self.device)

res = res/res.max()

self.pattern[i, :] = self.pattern[i, :] + res

# Normalize to make sure the maximum value is 1
Expand All @@ -338,15 +352,15 @@ def generate_custom_slit_pattern_width(self, start_pattern = "line", start_posit
def generate_custom_slit_pattern(self):

# Create a grid to represent positions
grid_positions = torch.arange(self.empty_grid.shape[1], dtype=torch.float32)
grid_positions = torch.arange(self.empty_grid.shape[1], dtype=torch.float32).to(self.device)
# Expand dimensions for broadcasting
expanded_x_positions = (self.array_x_positions.unsqueeze(-1)) * (self.empty_grid.shape[1]-1)
expanded_grid_positions = grid_positions.unsqueeze(0)
expanded_x_positions = ((self.array_x_positions.unsqueeze(-1)) * (self.empty_grid.shape[1]-1)).to(self.device)
expanded_grid_positions = grid_positions.unsqueeze(0).to(self.device)

# Apply Gaussian-like function
# Adjust 'sigma' to control the sharpness
sigma = 1.5
gaussian_peaks = torch.exp(-((expanded_grid_positions - expanded_x_positions) ** 2) / (2 * sigma ** 2))
gaussian_peaks = torch.exp(-((expanded_grid_positions - expanded_x_positions) ** 2) / (2 * sigma ** 2)).to(self.device)

# Normalize to make sure the maximum value is 1
self.pattern = gaussian_peaks / gaussian_peaks.max()
Expand Down
4 changes: 2 additions & 2 deletions simca/OpticalModelTorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def rotation_y_torch(theta):
"""
# Ensure theta is a tensor with requires_grad=True
if not isinstance(theta, torch.Tensor):
theta = torch.tensor(theta, requires_grad=True)
theta = torch.tensor(theta, requires_grad=True, dtype=torch.float32)

cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
Expand Down Expand Up @@ -78,7 +78,7 @@ def rotation_x_torch(theta):
"""
# Ensure theta is a tensor with requires_grad=True
if not isinstance(theta, torch.Tensor):
theta = torch.tensor(theta, requires_grad=True)
theta = torch.tensor(theta, requires_grad=True, dtype=torch.float32)

cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
Expand Down
76 changes: 38 additions & 38 deletions simca/configs/cassi_system_optim_optics_full_triplet.yml
Original file line number Diff line number Diff line change
@@ -1,48 +1,48 @@
##### Configuration file for the chosen optical system

infos:
system name: HYACAMEO
system name: HYACAMEO

system architecture:
system type: DD-CASSI
propagation type: simca
focal lens: 50000
dispersive element:
system type: DD-CASSI
propagation type: simca
focal lens: 50000
dispersive element:
# dispersive element caracteristics
type: tripleprism # name of the dispersive element
glass1: P-SK60 # glass type of the dispersive element (only used if type == 'prism')
glass2: SF4 # glass type of the dispersive element (only used if type == 'prism')
glass3: P-SK60 # glass type of the dispersive element (only used if type == 'prism')
A1: 21.5 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees
A2: 43.0 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees
A3: 21.5 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees
nd1: 1.6074
nd2: 1.7552
nd3: 1.6074
vd1: 56.65
vd2: 27.58
vd3: 56.65
continuous glass materials 1: False
continuous glass materials 2: False
continuous glass materials 3: False
m: 1 # grating order to consider (only used if type == 'grating') -- no units
G: 30 # grating density (only used if type == 'grating') -- in lines/mm
alpha_c: 2.6
delta alpha c: 0
delta beta c: 0
wavelength center: 600 # central wavelength -- in nm
type: tripleprism # name of the dispersive element
glass1: P-SK60 # glass type of the dispersive element (only used if type == 'prism')
glass2: SF4 # glass type of the dispersive element (only used if type == 'prism')
glass3: P-SK60 # glass type of the dispersive element (only used if type == 'prism')
A1: 21.5 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees
A2: 43.0 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees
A3: 21.5 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees
nd1: 1.6074
nd2: 1.7552
nd3: 1.6074
vd1: 56.65
vd2: 27.58
vd3: 56.65
continuous glass materials 1: False
continuous glass materials 2: False
continuous glass materials 3: False
m: 1 # grating order to consider (only used if type == 'grating') -- no units
G: 30 # grating density (only used if type == 'grating') -- in lines/mm
alpha_c: 2.6
delta alpha c: 0
delta beta c: 0
wavelength center: 600 # central wavelength -- in nm
detector:
number of pixels along X: 201 # number of pixels along X axis -- no units
number of pixels along Y: 201 # number of pixels along Y axis -- no units
pixel size along X: 40 # pixel size along X -- in um
pixel size along Y: 40 # pixel size along Y -- in um
number of pixels along X: 145 # number of pixels along X axis -- no units
number of pixels along Y: 901 # number of pixels along Y axis -- no units
pixel size along X: 15 # pixel size along X -- in um
pixel size along Y: 15 # pixel size along Y -- in um
coded aperture:
number of pixels along X: 11 # number of pixels along X axis -- no units
number of pixels along Y: 13 # number of pixels along Y axis -- no units
pixel size along X: 1000 # pixel size along X -- in um
pixel size along Y: 1000 # pixel size along Y -- in um
number of pixels along X: 145 # 151 # number of pixels along X axis -- no units
number of pixels along Y: 901 # 151 # number of pixels along Y axis -- no units
pixel size along X: 15 # pixel size along X -- in um
pixel size along Y: 15 # pixel size along Y -- in um

spectral range:
wavelength min: 400 # minimum wavelength -- in nm
wavelength max: 1050 # maximum wavelength -- in nm
number of spectral samples: 15
wavelength min: 410 # minimum wavelength -- in nm
wavelength max: 1050 # maximum wavelength -- in nm
number of spectral samples: 151
Loading

0 comments on commit 60ac205

Please sign in to comment.