diff --git a/optimization_modules_with_resnet.py b/optimization_modules_with_resnet.py index 5f63332..8e90e8f 100755 --- a/optimization_modules_with_resnet.py +++ b/optimization_modules_with_resnet.py @@ -41,6 +41,9 @@ def __init__(self, model_name,log_dir="tb_logs",reconstruction_checkpoint=None): self.writer = SummaryWriter(log_dir) + # for param in self.reconstruction_model.parameters(): + # param.requires_grad = False + def on_validation_start(self,stage=None): print("---VALIDATION START---") self.config = "simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml"