Skip to content

Commit

Permalink
early draft of the training without resnet
Browse files Browse the repository at this point in the history
  • Loading branch information
lpaillet-laas committed Mar 6, 2024
1 parent 69c30a5 commit 3fef420
Showing 1 changed file with 44 additions and 3 deletions.
47 changes: 44 additions & 3 deletions mask_optim_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from simca.CassiSystemOptim import CassiSystemOptim
from simca.CassiSystem import CassiSystem
from data_handler import CubesDataModule
from MST.simulation.train_code.architecture import *
import numpy as np
import snoop
import matplotlib.pyplot as plt
Expand All @@ -14,7 +15,7 @@
config_dataset = load_yaml_config("simca/configs/dataset.yml")
config_patterns = load_yaml_config("simca/configs/pattern.yml")
config_acquisition = load_yaml_config("simca/configs/acquisition.yml")
config_system = load_yaml_config("simca/configs/cassi_system_simple_optim_max_center.yml")
config_system = load_yaml_config()

# Load datacubes
# Generate random mask
Expand All @@ -26,6 +27,46 @@

data_dir = "/local/users/ademaio/lpaillet/mst_datasets"
datamodule = CubesDataModule(data_dir, batch_size=5, num_workers=2)
model_name = 'mst_plus_plus'
device = 'cuda'
lr = 0.001

# cassi_system.dataset = datamodule.train_dataloader[i][0]
# cassi_system.wavelengths = datamodule.train_dataloader[i][1]
cassi_system = CassiSystemOptim(system_config=config_system)
cassi_system.device = device

cassi_system.update_optical_model(system_config=config_system)
X_vec_out, Y_vec_out = cassi_system.propagate_coded_aperture_grid()

model = model_generator(model_name, None)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
mse = torch.nn.MSELoss()

def train(model_name):

optimizer.zero_grad()
# cassi_system.dataset = datamodule.train_dataloader[i][0]
# cassi_system.wavelengths = datamodule.train_dataloader[i][1]
input_mask = np.random.randint(0,1,size=(128,128))
cassi_system.pattern = input_mask
cassi_system.generate_filtering_cube()

input_acq = cassi_system.image_acquisition(use_psf=False)

model.train()
output = model(input_acq, input_mask)
loss = torch.sqrt(mse(output, cassi_system.dataset))

loss.backward()
optimizer.step()

model.eval()
input_mask = np.random.randint(0,1,size=(128,128))
cassi_system.pattern = input_mask
cassi_system.generate_filtering_cube()
# cassi_system.dataset = datamodule.test_dataloader[i][0]
# cassi_system.wavelengths = datamodule.test_dataloader[i][1]
input_acq = cassi_system.image_acquisition(use_psf=False)

output = model(input_acq, input_mask)
loss = torch.sqrt(mse(output, cassi_system.dataset))
model.train()

0 comments on commit 3fef420

Please sign in to comment.