
Configure the number of image samples logged at each epoch and batch #49

Merged · 2 commits · Sep 10, 2023
3 changes: 2 additions & 1 deletion examples/configs/fit_example.yml
@@ -44,7 +44,8 @@ model:
   loss_function: null
   lr: 0.001
   schedule: Constant
-  log_num_samples: 8
+  log_batches_per_epoch: 8
+  log_samples_per_batch: 1
 data:
   data_path: null
   source_channel: null
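Taken together, the two new keys bound the number of images logged per epoch: at most `log_batches_per_epoch` batches contribute, and at most `log_samples_per_batch` samples are taken from each. A minimal sketch of that arithmetic, assuming the gating added in `viscy/light/engine.py` below (the helper name and the example values are hypothetical, not part of this PR):

```python
# Sketch: how many sample images reach the logger per epoch,
# given the first-N-batches gating introduced in this PR.
def images_logged_per_epoch(
    steps_per_epoch: int,
    batch_size: int,
    log_batches_per_epoch: int = 8,
    log_samples_per_batch: int = 1,
) -> int:
    # Only the first `log_batches_per_epoch` batches are logged,
    # and at most `log_samples_per_batch` samples from each batch.
    batches = min(steps_per_epoch, log_batches_per_epoch)
    samples = min(batch_size, log_samples_per_batch)
    return batches * samples

# With the defaults from fit_example.yml: 8 batches x 1 sample = 8 images.
assert images_logged_per_epoch(steps_per_epoch=100, batch_size=32) == 8
```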
1 change: 0 additions & 1 deletion examples/configs/predict_example.yml
@@ -50,7 +50,6 @@ predict:
   loss_function: null
   lr: 0.001
   schedule: Constant
-  log_num_samples: 8
 data:
   data_path: null
   source_channel: null
21 changes: 13 additions & 8 deletions viscy/light/engine.py
@@ -152,9 +152,12 @@ class VSUNet(LightningModule):
     :param float lr: learning rate in training, defaults to 1e-3
     :param Literal['WarmupCosine', 'Constant'] schedule:
         learning rate scheduler, defaults to "Constant"
-    :param int log_num_samples:
-        number of image samples to log each training/validation epoch,
-        has to be smaller than batch size, defaults to 8
+    :param int log_batches_per_epoch:
+        number of batches to log each training/validation epoch,
+        has to be smaller than steps per epoch, defaults to 8
+    :param int log_samples_per_batch:
+        number of samples to log each training/validation batch,
+        has to be smaller than batch size, defaults to 1
     :param Sequence[int] example_input_yx_shape:
         XY shape of the example input for network graph tracing, defaults to (256, 256)
     :param str test_cellpose_model_path:
Expand All @@ -174,7 +177,8 @@ def __init__(
loss_function: Union[nn.Module, MixedLoss] = None,
lr: float = 1e-3,
schedule: Literal["WarmupCosine", "Constant"] = "Constant",
log_num_samples: int = 8,
log_batches_per_epoch: int = 8,
log_samples_per_batch: int = 1,
example_input_yx_shape: Sequence[int] = (256, 256),
test_cellpose_model_path: str = None,
test_cellpose_diameter: float = None,
@@ -194,7 +198,8 @@ def __init__(
         self.loss_function = loss_function if loss_function else nn.MSELoss()
         self.lr = lr
         self.schedule = schedule
-        self.log_num_samples = log_num_samples
+        self.log_batches_per_epoch = log_batches_per_epoch
+        self.log_samples_per_batch = log_samples_per_batch
         self.training_step_outputs = []
         self.validation_step_outputs = []
         # required to log the graph
@@ -229,7 +234,7 @@ def training_step(self, batch: Sample, batch_idx: int):
             logger=True,
             sync_dist=True,
         )
-        if batch_idx == 0:
+        if batch_idx < self.log_batches_per_epoch:
             self.training_step_outputs.extend(
                 self._detach_sample((source, target, pred))
             )
@@ -241,7 +246,7 @@ def validation_step(self, batch: Sample, batch_idx: int):
         pred = self.forward(source)
         loss = self.loss_function(pred, target)
         self.log("loss/validate", loss, sync_dist=True)
-        if batch_idx == 0:
+        if batch_idx < self.log_batches_per_epoch:
            self.validation_step_outputs.extend(
                self._detach_sample((source, target, pred))
            )
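These two hunks replace the hard-coded `batch_idx == 0` check, which only ever logged the first batch of an epoch, with a threshold so the first `log_batches_per_epoch` batches each contribute samples. A rough standalone sketch of that gating, with the `LightningModule` plumbing elided (the placeholder strings stand in for the output of `_detach_sample`):

```python
# Sketch of the per-epoch gating added to training_step/validation_step.
log_batches_per_epoch = 8  # same default as the PR
training_step_outputs = []  # stands in for self.training_step_outputs

for batch_idx in range(100):  # one epoch of 100 steps
    detached = [f"sample-from-batch-{batch_idx}"]  # placeholder for _detach_sample(...)
    if batch_idx < log_batches_per_epoch:  # was: if batch_idx == 0
        training_step_outputs.extend(detached)

# Exactly 8 batches logged per epoch instead of 1.
assert len(training_step_outputs) == log_batches_per_epoch
```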
@@ -392,7 +397,7 @@ def configure_optimizers(self):
         return [optimizer], [scheduler]

     def _detach_sample(self, imgs: Sequence[torch.Tensor]):
-        num_samples = min(imgs[0].shape[0], self.log_num_samples)
+        num_samples = min(imgs[0].shape[0], self.log_samples_per_batch)
         return [
             [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs]
             for i in range(num_samples)
         ]
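With this last hunk, `_detach_sample` caps the per-batch count with the new `log_samples_per_batch` knob instead of the old epoch-level one. Roughly what it computes, as a self-contained sketch with a random 5-D batch; the (B, C, Z, Y, X) shape is illustrative, not taken from the PR:

```python
import numpy as np
import torch

def detach_sample(imgs, log_samples_per_batch: int = 1):
    # Take at most `log_samples_per_batch` samples from this batch.
    num_samples = min(imgs[0].shape[0], log_samples_per_batch)
    return [
        # Detach from the autograd graph, move to CPU, max-project along
        # the depth axis (axis 1 of the per-sample (C, Z, Y, X) array),
        # and drop singleton dimensions for 2-D display.
        [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs]
        for i in range(num_samples)
    ]

# Example: a (batch=4, channel=1, depth=5, 16, 16) source/target/prediction triple.
batch = [torch.rand(4, 1, 5, 16, 16) for _ in range(3)]
rows = detach_sample(batch, log_samples_per_batch=2)
assert len(rows) == 2 and rows[0][0].shape == (16, 16)
```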