-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding training script and Main lightning module with all operations
- Loading branch information
1 parent
9f03905
commit bed97ca
Showing
2 changed files
with
124 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import pytorch_lightning as pl | ||
import torch | ||
import torch.nn as nn | ||
from simca.CassiSystemOptim import CassiSystemOptim | ||
from simca import load_yaml_config | ||
|
||
|
||
class EmptyModule(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.useless_linear = nn.Linear(1, 1) | ||
def forward(self, x): | ||
|
||
return x | ||
|
||
|
||
class JointReconstructionModule_V1(pl.LightningModule): | ||
|
||
def __init__(self): | ||
super().__init__() | ||
|
||
self.inittialize_cassi_system() | ||
#TODO | ||
# self.reconstruction_model = ReconstructionModel() | ||
self.reconstruction_model = EmptyModule() | ||
|
||
self.loss_fn = nn.MSELoss() | ||
|
||
def inittialize_cassi_system(self): | ||
|
||
config_system = load_yaml_config("simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml") | ||
self.config_patterns = load_yaml_config("simca/configs/pattern.yml") | ||
self.cassi_system = CassiSystemOptim(system_config=config_system) | ||
self.cassi_system.propagate_coded_aperture_grid() | ||
|
||
|
||
def forward(self, x): | ||
|
||
# generate random patterns (one for scene in the batch) | ||
patterns = self.cassi_system.generate_2D_pattern(self.config_patterns) | ||
|
||
# generate first acquisition with simca | ||
filtering_cube = self.cassi_system.generate_filtering_cube() | ||
acquired_image1 = self.cassi_system.image_acquisition(x) | ||
# process first acquisition with reconstruction model | ||
# TODO : replace by the reconstruction model | ||
reconstructed_cube = acquired_image1 | ||
|
||
return reconstructed_cube | ||
|
||
|
||
def training_step(self, batch, batch_idx): | ||
|
||
loss, y_hat, y = self._common_step(batch, batch_idx) | ||
self.log_dict( | ||
{ "train_loss": loss, | ||
}, | ||
on_step=True, | ||
on_epoch=True, | ||
prog_bar=True, | ||
) | ||
|
||
return {"loss": loss, "scores":y_hat, "y":y} | ||
|
||
def validation_step(self, batch, batch_idx): | ||
loss, y_hat, y = self._common_step(batch, batch_idx) | ||
|
||
self.log_dict( | ||
{ "val_loss": loss, | ||
}, | ||
on_step=True, | ||
on_epoch=True, | ||
prog_bar=True, | ||
) | ||
|
||
return {"loss": loss, "scores":y_hat, "y":y} | ||
|
||
def test_step(self, batch, batch_idx): | ||
loss, y_hat, y = self._common_step(batch, batch_idx) | ||
self.log_dict( | ||
{ "test_loss": loss, | ||
}, | ||
on_step=True, | ||
on_epoch=True, | ||
prog_bar=True, | ||
) | ||
return {"loss": loss, "scores":y_hat, "y":y} | ||
|
||
def predict_step(self, batch, batch_idx): | ||
loss, _, _ = self._common_step(batch, batch_idx) | ||
self.log('predict_step', loss,on_step=True, on_epoch=True, prog_bar=True, logger=True) | ||
return loss | ||
|
||
def _common_step(self, batch, batch_idx): | ||
|
||
hyperspectral_cube = batch | ||
y_hat = self.forward(hyperspectral_cube) | ||
loss = self.loss_fn(y_hat, hyperspectral_cube) | ||
|
||
return loss, y_hat | ||
|
||
def configure_optimizers(self): | ||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) | ||
return optimizer | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import pytorch_lightning as pl | ||
from data_handler import CubesDataModule | ||
from optimization_modules import JointReconstructionModule_V1 | ||
|
||
data_dir = "./datasets_reconstruction/" | ||
|
||
datamodule = CubesDataModule(data_dir, batch_size=2, num_workers=1) | ||
|
||
name = "testing_simca_reconstruction" | ||
|
||
|
||
reconstruction_module = JointReconstructionModule_V1() | ||
|
||
trainer = pl.Trainer( accelerator="cpu", | ||
max_epochs=500, | ||
log_every_n_steps=100) | ||
|
||
trainer.fit(reconstruction_module, datamodule) |