diff --git a/lighter/system.py b/lighter/system.py
index fdf4aa82..a2479436 100644
--- a/lighter/system.py
+++ b/lighter/system.py
@@ -205,16 +205,32 @@ def _base_step(self, batch: Dict, batch_idx: int, mode: str) -> Union[Dict[str,
             target = apply_fns(target, self.postprocessing["logging"]["target"])
             pred = apply_fns(pred, self.postprocessing["logging"]["pred"])
 
+        # If the loss is a dict, sum the sublosses under the "combined" key. Any weighting should be applied in the criterion.
+        if isinstance(loss, dict):
+            if "combined" in loss:
+                raise ValueError("The loss dictionary cannot contain a key 'combined'.")
+            loss["combined"] = sum(loss.values())
+
         # Logging
         self._log_stats(loss, metrics, mode, batch_idx)
-        return {"loss": loss, "metrics": metrics, "input": input, "target": target, "pred": pred, "id": id}
-
-    def _log_stats(self, loss: torch.Tensor, metrics: MetricCollection, mode: str, batch_idx: int) -> None:
+        # Return the loss as required by Lightning, as well as other data that can be used in hooks or callbacks.
+        return {
+            "loss": loss["combined"] if isinstance(loss, dict) else loss,
+            "metrics": metrics,
+            "input": input,
+            "target": target,
+            "pred": pred,
+            "id": id,
+        }
+
+    def _log_stats(
+        self, loss: Union[torch.Tensor, Dict[str, torch.Tensor]], metrics: MetricCollection, mode: str, batch_idx: int
+    ) -> None:
         """
         Logs the loss, metrics, and optimizer statistics.
 
         Args:
-            loss (torch.Tensor): Calculated loss.
+            loss (Union[torch.Tensor, Dict[str, torch.Tensor]]): Calculated loss or a dict of sublosses.
             metrics (MetricCollection): Calculated metrics.
             mode (str): Mode of operation (train/val/test/predict).
             batch_idx (int): Index of current batch.
@@ -222,24 +238,27 @@ def _log_stats(self, loss: torch.Tensor, metrics: MetricCollection, mode: str, b
         if self.trainer.logger is None:
             return
 
-        # Arguments for self.log()
-        log_kwargs = {"logger": True, "batch_size": self.batch_size}
-        on_step_log_kwargs = {"on_epoch": False, "on_step": True, "sync_dist": False}
-        on_epoch_log_kwargs = {"on_epoch": True, "on_step": False, "sync_dist": True}
+        on_step_log = partial(self.log, logger=True, batch_size=self.batch_size, on_step=True, on_epoch=False, sync_dist=False)
+        on_epoch_log = partial(self.log, logger=True, batch_size=self.batch_size, on_step=False, on_epoch=True, sync_dist=True)
 
         # Loss
         if loss is not None:
-            self.log(f"{mode}/loss/step", loss, **log_kwargs, **on_step_log_kwargs)
-            self.log(f"{mode}/loss/epoch", loss, **log_kwargs, **on_epoch_log_kwargs)
+            if not isinstance(loss, dict):
+                on_step_log(f"{mode}/loss/step", loss)
+                on_epoch_log(f"{mode}/loss/epoch", loss)
+            else:
+                for name, subloss in loss.items():
+                    on_step_log(f"{mode}/loss/{name}/step", subloss)
+                    on_epoch_log(f"{mode}/loss/{name}/epoch", subloss)
 
         # Metrics
         if metrics is not None:
-            for k, v in metrics.items():
-                self.log(f"{mode}/metrics/{k}/step", v, **log_kwargs, **on_step_log_kwargs)
-                self.log(f"{mode}/metrics/{k}/epoch", v, **log_kwargs, **on_epoch_log_kwargs)
+            for name, metric in metrics.items():
+                on_step_log(f"{mode}/metrics/{name}/step", metric)
+                on_epoch_log(f"{mode}/metrics/{name}/epoch", metric)
 
         # Optimizer's lr, momentum, beta. Logged in train mode and once per epoch.
if mode == "train" and batch_idx == 0: - for k, v in get_optimizer_stats(self.optimizer).items(): - self.log(f"{mode}/{k}", v, **log_kwargs, **on_epoch_log_kwargs) + for name, optimizer_stat in get_optimizer_stats(self.optimizer).items(): + on_epoch_log(f"{mode}/{name}", optimizer_stat) def _base_dataloader(self, mode: str) -> DataLoader: """Instantiate the dataloader for a mode (train/val/test/predict).