From 3fef420ec25db3e4e0fc0a369ff18774fd522637 Mon Sep 17 00:00:00 2001 From: Leo Paillet Date: Wed, 6 Mar 2024 19:05:03 +0100 Subject: [PATCH] early draft of the training without resnet --- mask_optim_recon.py | 47 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/mask_optim_recon.py b/mask_optim_recon.py index 5b60648..aef44bc 100644 --- a/mask_optim_recon.py +++ b/mask_optim_recon.py @@ -2,6 +2,7 @@ from simca.CassiSystemOptim import CassiSystemOptim from simca.CassiSystem import CassiSystem from data_handler import CubesDataModule +from MST.simulation.train_code.architecture import * import numpy as np import snoop import matplotlib.pyplot as plt @@ -14,7 +15,7 @@ config_dataset = load_yaml_config("simca/configs/dataset.yml") config_patterns = load_yaml_config("simca/configs/pattern.yml") config_acquisition = load_yaml_config("simca/configs/acquisition.yml") -config_system = load_yaml_config("simca/configs/cassi_system_simple_optim_max_center.yml") +config_system = load_yaml_config() # Load datacubes # Generate random mask @@ -26,6 +27,46 @@ data_dir = "/local/users/ademaio/lpaillet/mst_datasets" datamodule = CubesDataModule(data_dir, batch_size=5, num_workers=2) +model_name = 'mst_plus_plus' +device = 'cuda' +lr = 0.001 -# cassi_system.dataset = datamodule.train_dataloader[i][0] -# cassi_system.wavelengths = datamodule.train_dataloader[i][1] \ No newline at end of file +cassi_system = CassiSystemOptim(system_config=config_system) +cassi_system.device = device + +cassi_system.update_optical_model(system_config=config_system) +X_vec_out, Y_vec_out = cassi_system.propagate_coded_aperture_grid() + +model = model_generator(model_name, None) +optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999)) +mse = torch.nn.MSELoss() + +def train(model_name): + + optimizer.zero_grad() + # cassi_system.dataset = datamodule.train_dataloader[i][0] + # cassi_system.wavelengths = datamodule.train_dataloader[i][1] + input_mask = np.random.randint(0,1,size=(128,128)) + cassi_system.pattern = input_mask + cassi_system.generate_filtering_cube() + + input_acq = cassi_system.image_acquisition(use_psf=False) + + model.train() + output = model(input_acq, input_mask) + loss = torch.sqrt(mse(output, cassi_system.dataset)) + + loss.backward() + optimizer.step() + + model.eval() + input_mask = np.random.randint(0,1,size=(128,128)) + cassi_system.pattern = input_mask + cassi_system.generate_filtering_cube() + # cassi_system.dataset = datamodule.test_dataloader[i][0] + # cassi_system.wavelengths = datamodule.test_dataloader[i][1] + input_acq = cassi_system.image_acquisition(use_psf=False) + + output = model(input_acq, input_mask) + loss = torch.sqrt(mse(output, cassi_system.dataset)) + model.train()