Skip to content

Commit

Permalink
adding training script and Main lightning module with all operations
Browse files Browse the repository at this point in the history
  • Loading branch information
arouxel-laas committed Mar 6, 2024
1 parent 9f03905 commit bed97ca
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 0 deletions.
106 changes: 106 additions & 0 deletions optimization_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import pytorch_lightning as pl
import torch
import torch.nn as nn
from simca.CassiSystemOptim import CassiSystemOptim
from simca import load_yaml_config


class EmptyModule(nn.Module):
def __init__(self):
super().__init__()
self.useless_linear = nn.Linear(1, 1)
def forward(self, x):

return x


class JointReconstructionModule_V1(pl.LightningModule):

def __init__(self):
super().__init__()

self.inittialize_cassi_system()
#TODO
# self.reconstruction_model = ReconstructionModel()
self.reconstruction_model = EmptyModule()

self.loss_fn = nn.MSELoss()

def inittialize_cassi_system(self):

config_system = load_yaml_config("simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml")
self.config_patterns = load_yaml_config("simca/configs/pattern.yml")
self.cassi_system = CassiSystemOptim(system_config=config_system)
self.cassi_system.propagate_coded_aperture_grid()


def forward(self, x):

# generate random patterns (one for scene in the batch)
patterns = self.cassi_system.generate_2D_pattern(self.config_patterns)

# generate first acquisition with simca
filtering_cube = self.cassi_system.generate_filtering_cube()
acquired_image1 = self.cassi_system.image_acquisition(x)
# process first acquisition with reconstruction model
# TODO : replace by the reconstruction model
reconstructed_cube = acquired_image1

return reconstructed_cube


def training_step(self, batch, batch_idx):

loss, y_hat, y = self._common_step(batch, batch_idx)
self.log_dict(
{ "train_loss": loss,
},
on_step=True,
on_epoch=True,
prog_bar=True,
)

return {"loss": loss, "scores":y_hat, "y":y}

def validation_step(self, batch, batch_idx):
loss, y_hat, y = self._common_step(batch, batch_idx)

self.log_dict(
{ "val_loss": loss,
},
on_step=True,
on_epoch=True,
prog_bar=True,
)

return {"loss": loss, "scores":y_hat, "y":y}

def test_step(self, batch, batch_idx):
loss, y_hat, y = self._common_step(batch, batch_idx)
self.log_dict(
{ "test_loss": loss,
},
on_step=True,
on_epoch=True,
prog_bar=True,
)
return {"loss": loss, "scores":y_hat, "y":y}

def predict_step(self, batch, batch_idx):
loss, _, _ = self._common_step(batch, batch_idx)
self.log('predict_step', loss,on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss

def _common_step(self, batch, batch_idx):

hyperspectral_cube = batch
y_hat = self.forward(hyperspectral_cube)
loss = self.loss_fn(y_hat, hyperspectral_cube)

return loss, y_hat

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer


18 changes: 18 additions & 0 deletions training_simca_reconstruction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytorch_lightning as pl
from data_handler import CubesDataModule
from optimization_modules import JointReconstructionModule_V1

data_dir = "./datasets_reconstruction/"

datamodule = CubesDataModule(data_dir, batch_size=2, num_workers=1)

name = "testing_simca_reconstruction"


reconstruction_module = JointReconstructionModule_V1()

trainer = pl.Trainer( accelerator="cpu",
max_epochs=500,
log_every_n_steps=100)

trainer.fit(reconstruction_module, datamodule)

0 comments on commit bed97ca

Please sign in to comment.