
Support subloss logging through loss dicts
ibro45 committed Apr 26, 2024
1 parent 3c2fffa commit 388c9b6
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions lighter/system.py
@@ -205,41 +205,60 @@ def _base_step(self, batch: Dict, batch_idx: int, mode: str) -> Union[Dict[str,
         target = apply_fns(target, self.postprocessing["logging"]["target"])
         pred = apply_fns(pred, self.postprocessing["logging"]["pred"])

+        # If the loss is a dict, sum the sublosses under "combined" key. Any weightings should be applied in the criterion.
+        if isinstance(loss, dict):
+            if "combined" in loss:
+                raise ValueError("The loss dictionary cannot contain a key 'combined'.")
+            loss["combined"] = sum(loss.values())
+
         # Logging
         self._log_stats(loss, metrics, mode, batch_idx)

-        return {"loss": loss, "metrics": metrics, "input": input, "target": target, "pred": pred, "id": id}
+        # Return the loss as required by Lightning as well as other data that can be used in hooks or callbacks.
+        return {
+            "loss": loss["combined"] if isinstance(loss, dict) else loss,
+            "metrics": metrics,
+            "input": input,
+            "target": target,
+            "pred": pred,
+            "id": id,
+        }

-    def _log_stats(self, loss: torch.Tensor, metrics: MetricCollection, mode: str, batch_idx: int) -> None:
+    def _log_stats(
+        self, loss: Union[torch.Tensor, Dict[str, torch.Tensor]], metrics: MetricCollection, mode: str, batch_idx: int
+    ) -> None:
         """
         Logs the loss, metrics, and optimizer statistics.
         Args:
-            loss (torch.Tensor): Calculated loss.
+            loss (Union[torch.Tensor, Dict[str, torch.Tensor]]): Calculated loss or a dict of sublosses.
             metrics (MetricCollection): Calculated metrics.
             mode (str): Mode of operation (train/val/test/predict).
             batch_idx (int): Index of current batch.
         """
         if self.trainer.logger is None:
             return

         # Arguments for self.log()
-        log_kwargs = {"logger": True, "batch_size": self.batch_size}
-        on_step_log_kwargs = {"on_epoch": False, "on_step": True, "sync_dist": False}
-        on_epoch_log_kwargs = {"on_epoch": True, "on_step": False, "sync_dist": True}
+        on_step_log = partial(self.log, logger=True, batch_size=self.batch_size, on_step=True, on_epoch=False, sync_dist=False)
+        on_epoch_log = partial(self.log, logger=True, batch_size=self.batch_size, on_step=False, on_epoch=True, sync_dist=True)

         # Loss
         if loss is not None:
-            self.log(f"{mode}/loss/step", loss, **log_kwargs, **on_step_log_kwargs)
-            self.log(f"{mode}/loss/epoch", loss, **log_kwargs, **on_epoch_log_kwargs)
+            if not isinstance(loss, dict):
+                on_step_log(f"{mode}/loss/step", loss)
+                on_epoch_log(f"{mode}/loss/epoch", loss)
+            else:
+                for name, subloss in loss.items():
+                    on_step_log(f"{mode}/loss/{name}/step", subloss)
+                    on_epoch_log(f"{mode}/loss/{name}/epoch", subloss)
         # Metrics
         if metrics is not None:
-            for k, v in metrics.items():
-                self.log(f"{mode}/metrics/{k}/step", v, **log_kwargs, **on_step_log_kwargs)
-                self.log(f"{mode}/metrics/{k}/epoch", v, **log_kwargs, **on_epoch_log_kwargs)
+            for name, metric in metrics.items():
+                on_step_log(f"{mode}/metrics/{name}/step", metric)
+                on_epoch_log(f"{mode}/metrics/{name}/epoch", metric)
         # Optimizer's lr, momentum, beta. Logged in train mode and once per epoch.
         if mode == "train" and batch_idx == 0:
-            for k, v in get_optimizer_stats(self.optimizer).items():
-                self.log(f"{mode}/{k}", v, **log_kwargs, **on_epoch_log_kwargs)
+            for name, optimizer_stat in get_optimizer_stats(self.optimizer).items():
+                on_epoch_log(f"{mode}/{name}", optimizer_stat)

     def _base_dataloader(self, mode: str) -> DataLoader:
         """Instantiate the dataloader for a mode (train/val/test/predict.
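For context, a minimal sketch of how a criterion can take advantage of this change, assuming it is invoked as criterion(pred, target). The WeightedSumLoss class, its subloss names, and its weights below are hypothetical illustrations, not part of this commit; the only behavior mirrored from the diff is that a dict returned by the criterion has its values summed under the reserved "combined" key, each subloss is logged under its own name, and any weighting is applied inside the criterion itself.

import torch
from torch import nn

class WeightedSumLoss(nn.Module):
    """Hypothetical criterion that returns its sublosses as a dict (weights applied here)."""

    def __init__(self, l1_weight: float = 1.0, mse_weight: float = 0.5):
        super().__init__()
        self.l1 = nn.L1Loss()
        self.mse = nn.MSELoss()
        self.l1_weight = l1_weight
        self.mse_weight = mse_weight

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict:
        # No "combined" key here; _base_step adds it by summing the values.
        return {
            "l1": self.l1_weight * self.l1(pred, target),
            "mse": self.mse_weight * self.mse(pred, target),
        }

# What _base_step does with such a dict (mirroring the diff above):
pred, target = torch.rand(4, 3), torch.rand(4, 3)
loss = WeightedSumLoss()(pred, target)
assert "combined" not in loss          # otherwise a ValueError is raised
loss["combined"] = sum(loss.values())  # the summed value is what Lightning receives as "loss"
# Logged keys then look like "train/loss/l1/step", "train/loss/mse/epoch",
# and "train/loss/combined/step" / "train/loss/combined/epoch".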

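The other change in this diff replaces the shared self.log() keyword-argument dicts with functools.partial wrappers. A self-contained sketch of that pattern follows; the log function below is a stand-in used only to show the pre-bound arguments, not Lightning's actual self.log.

from functools import partial

def log(name: str, value: float, *, logger: bool, batch_size: int, on_step: bool, on_epoch: bool, sync_dist: bool) -> None:
    """Stand-in logger that just prints what it would record."""
    print(name, value, dict(logger=logger, batch_size=batch_size, on_step=on_step, on_epoch=on_epoch, sync_dist=sync_dist))

# Pre-bind the shared keyword arguments once, as the diff does with self.log.
on_step_log = partial(log, logger=True, batch_size=32, on_step=True, on_epoch=False, sync_dist=False)
on_epoch_log = partial(log, logger=True, batch_size=32, on_step=False, on_epoch=True, sync_dist=True)

on_step_log("train/loss/step", 0.42)    # same as log("train/loss/step", 0.42, logger=True, ...)
on_epoch_log("train/loss/epoch", 0.42)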