Support subloss logging through loss dicts #111

Merged · 6 commits · May 24, 2024

Changes from 5 commits
49 changes: 34 additions & 15 deletions lighter/system.py
@@ -205,41 +205,60 @@ def _base_step(self, batch: Dict, batch_idx: int, mode: str) -> Union[Dict[str,
         target = apply_fns(target, self.postprocessing["logging"]["target"])
         pred = apply_fns(pred, self.postprocessing["logging"]["pred"])
 
+        # If the loss is a dict, the sublosses must be combined under "total" key.
+        if isinstance(loss, dict) and "total" not in loss:
+            raise ValueError("The loss dictionary must include a 'total' key that combines all sublosses. Example: {'total': combined_loss, 'subloss1': loss1, ...}")
+
         # Logging
         self._log_stats(loss, metrics, mode, batch_idx)
 
-        return {"loss": loss, "metrics": metrics, "input": input, "target": target, "pred": pred, "id": id}
+        # Return the loss as required by Lightning as well as other data that can be used in hooks or callbacks.
+        return {
+            "loss": loss["total"] if isinstance(loss, dict) else loss,
+            "metrics": metrics,
+            "input": input,
+            "target": target,
+            "pred": pred,
+            "id": id,
+        }
 
-    def _log_stats(self, loss: torch.Tensor, metrics: MetricCollection, mode: str, batch_idx: int) -> None:
+    def _log_stats(
+        self, loss: Union[torch.Tensor, Dict[str, torch.Tensor]], metrics: MetricCollection, mode: str, batch_idx: int
+    ) -> None:
         """
         Logs the loss, metrics, and optimizer statistics.
 
         Args:
-            loss (torch.Tensor): Calculated loss.
+            loss (Union[torch.Tensor, Dict[str, torch.Tensor]]): Calculated loss or a dict of sublosses.
             metrics (MetricCollection): Calculated metrics.
             mode (str): Mode of operation (train/val/test/predict).
             batch_idx (int): Index of current batch.
         """
         if self.trainer.logger is None:
             return
 
-        # Arguments for self.log()
-        log_kwargs = {"logger": True, "batch_size": self.batch_size}
-        on_step_log_kwargs = {"on_epoch": False, "on_step": True, "sync_dist": False}
-        on_epoch_log_kwargs = {"on_epoch": True, "on_step": False, "sync_dist": True}
+        on_step_log = partial(self.log, logger=True, batch_size=self.batch_size, on_step=True, on_epoch=False, sync_dist=False)
+        on_epoch_log = partial(self.log, logger=True, batch_size=self.batch_size, on_step=False, on_epoch=True, sync_dist=True)
 
         # Loss
         if loss is not None:
-            self.log(f"{mode}/loss/step", loss, **log_kwargs, **on_step_log_kwargs)
-            self.log(f"{mode}/loss/epoch", loss, **log_kwargs, **on_epoch_log_kwargs)
+            if not isinstance(loss, dict):
+                on_step_log(f"{mode}/loss/step", loss)
+                on_epoch_log(f"{mode}/loss/epoch", loss)
+            else:
+                for name, subloss in loss.items():
+                    on_step_log(f"{mode}/loss/{name}/step", subloss)
+                    on_epoch_log(f"{mode}/loss/{name}/epoch", subloss)
         # Metrics
         if metrics is not None:
-            for k, v in metrics.items():
-                self.log(f"{mode}/metrics/{k}/step", v, **log_kwargs, **on_step_log_kwargs)
-                self.log(f"{mode}/metrics/{k}/epoch", v, **log_kwargs, **on_epoch_log_kwargs)
+            for name, metric in metrics.items():
+                if not isinstance(metric, Metric):
+                    raise TypeError(f"Expected type for metric is 'Metric', got '{type(metric).__name__}' instead.")
+                on_step_log(f"{mode}/metrics/{name}/step", metric)
+                on_epoch_log(f"{mode}/metrics/{name}/epoch", metric)
         # Optimizer's lr, momentum, beta. Logged in train mode and once per epoch.
         if mode == "train" and batch_idx == 0:
-            for k, v in get_optimizer_stats(self.optimizer).items():
-                self.log(f"{mode}/{k}", v, **log_kwargs, **on_epoch_log_kwargs)
+            for name, optimizer_stat in get_optimizer_stats(self.optimizer).items():
+                on_epoch_log(f"{mode}/{name}", optimizer_stat)
Contributor comment on lines +263 to +264:
Ensure robust configuration of optimizers.

Consider adding error handling for the case where a scheduler is expected but not specified. This helps prevent runtime errors and improves the robustness of the system.

Refactor dynamic method setup.

Consider refactoring the dynamic setup of methods to improve readability and maintainability, for example by moving the logic for each stage into its own method and calling those methods from the setup method, as sketched below.
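
One possible shape of that refactor, as a hedged sketch: it assumes a setup(self, stage) hook that currently wires up every stage inline, and the MySystem, _setup_fit, _setup_eval, and _setup_predict names are hypothetical rather than part of Lighter's API.

class MySystem:  # hypothetical stand-in for a LighterSystem-like class
    def setup(self, stage: str) -> None:
        # setup() only dispatches; each stage's wiring lives in its own helper.
        if stage == "fit":
            self._setup_fit()
        elif stage in ("validate", "test"):
            self._setup_eval(stage)
        elif stage == "predict":
            self._setup_predict()

    def _setup_fit(self) -> None:
        ...  # train/val datasets, samplers, and step methods go here

    def _setup_eval(self, stage: str) -> None:
        ...  # validation/test-specific wiring goes here

    def _setup_predict(self) -> None:
        ...  # predict-specific wiring goes here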


     def _base_dataloader(self, mode: str) -> DataLoader:
         """Instantiate the dataloader for a mode (train/val/test/predict).
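
To see what the change enables, here is a sketch (not taken from the PR) of a criterion that returns a loss dict; the CombinedLoss name and the BCE/Dice split are illustrative only. With such a criterion, _log_stats logs each entry as {mode}/loss/<name>/step and {mode}/loss/<name>/epoch, and _base_step returns the 'total' entry as the loss that Lightning optimizes.

import torch
import torch.nn.functional as F


class CombinedLoss(torch.nn.Module):
    """Illustrative criterion returning sublosses plus their combination under 'total'."""

    def __init__(self, dice_weight: float = 0.5):
        super().__init__()
        self.dice_weight = dice_weight

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict:
        bce = F.binary_cross_entropy_with_logits(pred, target)
        probs = torch.sigmoid(pred)
        dice = 1 - (2 * (probs * target).sum() + 1) / (probs.sum() + target.sum() + 1)
        # The 'total' key is mandatory; without it, LighterSystem raises the ValueError shown in the diff.
        return {"total": bce + self.dice_weight * dice, "bce": bce, "dice": dice}

A training run with a criterion like this would then produce logger keys such as train/loss/total/step, train/loss/bce/epoch, and train/loss/dice/epoch.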