From 35fc17359b8e3d2070a0843aeb8de3fb6dd64b59 Mon Sep 17 00:00:00 2001 From: Antoine Rouxel Date: Sat, 9 Mar 2024 19:33:12 +0100 Subject: [PATCH] adding training with resnet and reconstruction --- optimization_modules_with_resnet_v2.py | 6 +++--- training_simca_reconstruction_with_resnet_v2.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/optimization_modules_with_resnet_v2.py b/optimization_modules_with_resnet_v2.py index 25bf528..62081bf 100755 --- a/optimization_modules_with_resnet_v2.py +++ b/optimization_modules_with_resnet_v2.py @@ -239,7 +239,7 @@ def _common_step(self, batch, batch_idx): loss1 = torch.sqrt(self.loss_fn(reconstructed_cube, ref_cube)) loss2 = torch.sum(torch.abs((total_sum_pattern - total_half_pattern_equal_1)/(self.pattern.shape[1]*self.pattern.shape[2]))**2) - loss = loss1 + loss2 + loss = loss1 ssim_loss = self.ssim_loss(torch.clamp(reconstructed_cube.permute(0, 3, 1, 2), 0, 1), ref_cube.permute(0, 3, 1, 2)) print(f"loss1 {loss1}") @@ -247,7 +247,7 @@ def _common_step(self, batch, batch_idx): return loss, ssim_loss, reconstructed_cube, ref_cube def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=4e-4) + optimizer = torch.optim.Adam(self.parameters(), lr=1e-4) return { "optimizer":optimizer, "lr_scheduler":{ "scheduler":torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6), @@ -382,4 +382,4 @@ def forward(ctx, input): @staticmethod def backward(ctx, grad_output): # For backward pass, just pass the gradients through unchanged - return grad_output \ No newline at end of file + return grad_output diff --git a/training_simca_reconstruction_with_resnet_v2.py b/training_simca_reconstruction_with_resnet_v2.py index f8e7fff..22148eb 100755 --- a/training_simca_reconstruction_with_resnet_v2.py +++ b/training_simca_reconstruction_with_resnet_v2.py @@ -11,20 +11,20 @@ # data_dir = "./datasets_reconstruction/" data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28" -data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28" +# data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28" datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=5) datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') name = "testing_simca_reconstruction_full" model_name = "dauhst_9" -reconstruction_checkpoint = "/home/lpaillet/Documents/simca/tb_logs/testing_simca_reconstruction/version_24/checkpoints/epoch=499-step=18000.ckpt" +reconstruction_checkpoint = "./checkpoints/epoch=499-step=18000.ckpt" resnet_checkpoint = None log_dir = 'tb_logs' train = True -retrain_recons = False +retrain_recons = True logger = TensorBoardLogger(log_dir, name=name) @@ -49,12 +49,15 @@ sub_module = JointReconstructionModule_V1(model_name, log_dir) sub_module.load_state_dict(checkpoint["state_dict"]) + +resnet_checkpoint = "./checkpoints/best-checkpoint_resnet_only_24-03-09_18h05.ckpt" + if not retrain_recons or not train: sub_module.eval() reconstruction_module = JointReconstructionModule_V3(sub_module, log_dir=log_dir+'/'+ name, - reconstruction_checkpoint = reconstruction_checkpoint) + resnet_checkpoint=resnet_checkpoint) if torch.cuda.is_available(): @@ -73,4 +76,4 @@ if train: trainer.fit(reconstruction_module, datamodule) else: - trainer.predict(reconstruction_module, datamodule, ckpt_path=resnet_checkpoint) \ No newline at end of file + trainer.predict(reconstruction_module, datamodule, ckpt_path=resnet_checkpoint)