early draft of the training without resnet
lpaillet-laas committed Mar 6, 2024
1 parent 3fef420 commit c246675
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions mask_optim_recon.py
@@ -3,6 +3,7 @@
 from simca.CassiSystem import CassiSystem
 from data_handler import CubesDataModule
 from MST.simulation.train_code.architecture import *
+from MST.simulation.train_code.utils import *
 import numpy as np
 import snoop
 import matplotlib.pyplot as plt
@@ -41,6 +42,12 @@
 optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
 mse = torch.nn.MSELoss()
 
+def expand_mask_3d(mask):
+    mask3d = np.tile(mask[:, :, np.newaxis], (1, 1, 28))
+    mask3d = np.transpose(mask3d, [2, 0, 1])
+    mask3d = torch.from_numpy(mask3d)
+    return mask3d
+
 def train(model_name):
 
     optimizer.zero_grad()
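The new expand_mask_3d helper tiles the 2-D coded aperture across the 28 spectral channels so the mask matches the (28, H, W) layout that MST-style reconstruction models take as their second input. A minimal usage sketch (the 256×256 size and the random binary mask are illustrative assumptions, not values from this commit):

import numpy as np
import torch

def expand_mask_3d(mask):
    # (H, W) -> (H, W, 28): repeat the mask once per spectral channel
    mask3d = np.tile(mask[:, :, np.newaxis], (1, 1, 28))
    # (H, W, 28) -> (28, H, W): channel-first layout expected by the model
    mask3d = np.transpose(mask3d, [2, 0, 1])
    return torch.from_numpy(mask3d)

mask = np.random.randint(0, 2, (256, 256)).astype(np.float32)  # hypothetical binary coded aperture
print(expand_mask_3d(mask).shape)  # torch.Size([28, 256, 256])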
@@ -50,10 +57,13 @@ def train(model_name):
     cassi_system.pattern = input_mask
     cassi_system.generate_filtering_cube()
 
-    input_acq = cassi_system.image_acquisition(use_psf=False)
+    input_acq = cassi_system.image_acquisition(use_psf=False)  # H x (W + d*(28-1))
+    d = 2
+    input_acq = shift_back(input_acq, step=d)
 
     model.train()
-    output = model(input_acq, input_mask)
+    input_mask_3d = expand_mask_3d(input_mask)  # TODO, like in train_method/utils.py
+    output = model(input_acq, input_mask_3d)
     loss = torch.sqrt(mse(output, cassi_system.dataset))
 
     loss.backward()
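shift_back is imported from MST.simulation.train_code.utils and undoes the dispersion of the CASSI acquisition: channel i of the measurement is offset by step*i columns, so cropping each channel back recovers a (28, H, W) cube for the reconstruction network. A sketch of that idea under this assumption (not copied from the MST code, and written for a single unbatched measurement):

import torch

def shift_back_sketch(meas, step=2, n_channels=28):
    # meas: (H, W + step*(n_channels-1)) dispersed measurement
    H, W_ext = meas.shape
    W = W_ext - step * (n_channels - 1)
    cube = torch.zeros(n_channels, H, W)
    for i in range(n_channels):
        # channel i was shifted by step*i columns during acquisition; crop it back
        cube[i] = meas[:, step * i : step * i + W]
    return cube

With d = 2 and 28 channels, a 256-column scene (illustrative) gives a 310-column measurement, which shifts back to a (28, 256, 256) cube.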
@@ -66,7 +76,8 @@
     # cassi_system.dataset = datamodule.test_dataloader[i][0]
     # cassi_system.wavelengths = datamodule.test_dataloader[i][1]
     input_acq = cassi_system.image_acquisition(use_psf=False)
+    input_acq = shift_back(input_acq, step=d)
 
-    output = model(input_acq, input_mask)
+    output = model(input_acq, input_mask_3d)
     loss = torch.sqrt(mse(output, cassi_system.dataset))
     model.train()

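In both the training and the evaluation path, the objective is the RMSE between the reconstructed cube and the ground-truth cube held in cassi_system.dataset, so output and dataset must share the (28, H, W) layout produced above. A minimal illustration of that loss term (all shapes and values are hypothetical):

import torch

mse = torch.nn.MSELoss()
output = torch.rand(28, 256, 256)       # hypothetical reconstructed cube
target = torch.rand(28, 256, 256)       # hypothetical ground-truth cube
rmse = torch.sqrt(mse(output, target))  # same expression as the commit's loss
print(rmse.item())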