diff --git a/mask_optim_recon.py b/mask_optim_recon.py index aef44bc..6166620 100644 --- a/mask_optim_recon.py +++ b/mask_optim_recon.py @@ -3,6 +3,7 @@ from simca.CassiSystem import CassiSystem from data_handler import CubesDataModule from MST.simulation.train_code.architecture import * +from MST.simulation.train_code.utils import * import numpy as np import snoop import matplotlib.pyplot as plt @@ -41,6 +42,12 @@ optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999)) mse = torch.nn.MSELoss() +def expand_mask_3d(mask): + mask3d = np.tile(mask[:, :, np.newaxis], (1, 1, 28)) + mask3d = np.transpose(mask3d, [2, 0, 1]) + mask3d = torch.from_numpy(mask3d) + return mask3d + def train(model_name): optimizer.zero_grad() @@ -50,10 +57,13 @@ def train(model_name): cassi_system.pattern = input_mask cassi_system.generate_filtering_cube() - input_acq = cassi_system.image_acquisition(use_psf=False) + input_acq = cassi_system.image_acquisition(use_psf=False) # H x (W + d*(28-1)) + d=2 + input_acq = shift_back(input_acq, step=d) model.train() - output = model(input_acq, input_mask) + input_mask_3d = expand_mask_3d(input_mask) #TODO, like in train_method/utils.py + output = model(input_acq, input_mask_3d) loss = torch.sqrt(mse(output, cassi_system.dataset)) loss.backward() @@ -66,7 +76,8 @@ def train(model_name): # cassi_system.dataset = datamodule.test_dataloader[i][0] # cassi_system.wavelengths = datamodule.test_dataloader[i][1] input_acq = cassi_system.image_acquisition(use_psf=False) + input_acq = shift_back(input_acq, step=d) - output = model(input_acq, input_mask) + output = model(input_acq, input_mask_3d) loss = torch.sqrt(mse(output, cassi_system.dataset)) model.train()