From 40ec4187cd71f9984d2a6193d0df97fe322db390 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 6 Sep 2023 14:14:34 -0700 Subject: [PATCH 1/2] log sample size at epoch and batch levels --- viscy/light/engine.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 3f6c8ce1..3e0ab99b 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -152,9 +152,12 @@ class VSUNet(LightningModule): :param float lr: learning rate in training, defaults to 1e-3 :param Literal['WarmupCosine', 'Constant'] schedule: learning rate scheduler, defaults to "Constant" - :param int log_num_samples: - number of image samples to log each training/validation epoch, - has to be smaller than batch size, defaults to 8 + :param int log_batches_per_epoch: + number of batches to log each training/validation epoch, + has to be smaller than steps per epoch, defaults to 8 + :param int log_samples_per_batch: + number of samples to log each training/validation batch, + has to be smaller than batch size, defaults to 1 :param Sequence[int] example_input_yx_shape: XY shape of the example input for network graph tracing, defaults to (256, 256) :param str test_cellpose_model_path: @@ -174,7 +177,8 @@ def __init__( loss_function: Union[nn.Module, MixedLoss] = None, lr: float = 1e-3, schedule: Literal["WarmupCosine", "Constant"] = "Constant", - log_num_samples: int = 8, + log_batches_per_epoch: int = 8, + log_samples_per_batch: int = 1, example_input_yx_shape: Sequence[int] = (256, 256), test_cellpose_model_path: str = None, test_cellpose_diameter: float = None, @@ -194,7 +198,8 @@ def __init__( self.loss_function = loss_function if loss_function else nn.MSELoss() self.lr = lr self.schedule = schedule - self.log_num_samples = log_num_samples + self.log_batches_per_epoch = log_batches_per_epoch + self.log_samples_per_batch = log_samples_per_batch self.training_step_outputs = [] self.validation_step_outputs = [] # required to log the graph @@ -229,7 +234,7 @@ def training_step(self, batch: Sample, batch_idx: int): logger=True, sync_dist=True, ) - if batch_idx == 0: + if batch_idx < self.log_batches_per_epoch: self.training_step_outputs.extend( self._detach_sample((source, target, pred)) ) @@ -241,7 +246,7 @@ def validation_step(self, batch: Sample, batch_idx: int): pred = self.forward(source) loss = self.loss_function(pred, target) self.log("loss/validate", loss, sync_dist=True) - if batch_idx == 0: + if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target, pred)) ) @@ -392,7 +397,7 @@ def configure_optimizers(self): return [optimizer], [scheduler] def _detach_sample(self, imgs: Sequence[torch.Tensor]): - num_samples = min(imgs[0].shape[0], self.log_num_samples) + num_samples = min(imgs[0].shape[0], self.log_samples_per_batch) return [ [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] for i in range(num_samples) From 40d8b847b52de3973c7430f2c6aef171aca2188d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 6 Sep 2023 14:14:42 -0700 Subject: [PATCH 2/2] update example configs --- examples/configs/fit_example.yml | 3 ++- examples/configs/predict_example.yml | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/configs/fit_example.yml b/examples/configs/fit_example.yml index a8274bcb..dd46ac27 100644 --- a/examples/configs/fit_example.yml +++ b/examples/configs/fit_example.yml @@ -44,7 +44,8 @@ model: loss_function: null lr: 0.001 schedule: Constant - log_num_samples: 8 + log_batches_per_epoch: 8 + log_samples_per_batch: 1 data: data_path: null source_channel: null diff --git a/examples/configs/predict_example.yml b/examples/configs/predict_example.yml index 3a74da51..88bf48e9 100644 --- a/examples/configs/predict_example.yml +++ b/examples/configs/predict_example.yml @@ -50,7 +50,6 @@ predict: loss_function: null lr: 0.001 schedule: Constant - log_num_samples: 8 data: data_path: null source_channel: null