From 87c2baa21407227a752d69d7f34ef8910efae028 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Wed, 1 Feb 2023 15:47:24 -0500 Subject: [PATCH 01/11] Add first prototype of LighterWriter --- .gitignore | 4 + lighter/callbacks/__init__.py | 1 + lighter/callbacks/logger.py | 108 ++--------------- lighter/callbacks/utils.py | 113 ++++++++++++++++++ lighter/callbacks/writer.py | 78 ++++++++++++ .../experiments/monai_bundle_prototype.yaml | 9 +- 6 files changed, 216 insertions(+), 97 deletions(-) create mode 100644 lighter/callbacks/utils.py create mode 100644 lighter/callbacks/writer.py diff --git a/.gitignore b/.gitignore index eda9a9f4..58e42793 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,10 @@ tensorboard/ prototyping.ipynb checkpoints/ +# Our ignores projects/* !projects/README.md !projects/cifar10 +contrib/ +**/predictions/ + diff --git a/lighter/callbacks/__init__.py b/lighter/callbacks/__init__.py index c3a3edc8..48a352b0 100644 --- a/lighter/callbacks/__init__.py +++ b/lighter/callbacks/__init__.py @@ -1 +1,2 @@ from .logger import LighterLogger +from .writer import LighterWriter diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index 7d1e4718..f393281f 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -1,13 +1,11 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Union -import re import sys from datetime import datetime from pathlib import Path import torch import torch.distributed as dist -import torchvision from loguru import logger from monai.utils.module import optional_import from pytorch_lightning import Callback, Trainer @@ -15,22 +13,22 @@ from yaml import safe_load from lighter import LighterSystem +from lighter.callbacks.utils import LIGHTNING_TO_LIGHTER_STAGE, parse_data, check_supported_data_type, preprocess_image -LIGHTNING_TO_LIGHTER_STAGE = {"train": "train", "validate": "val", "test": "test"} OPTIONAL_IMPORTS = {} class LighterLogger(Callback): def __init__( self, - project, - log_dir, - tensorboard=False, - wandb=False, - input_type=None, - target_type=None, - pred_type=None, - max_samples=None, + project: str, + log_dir: str, + tensorboard: bool = False, + wandb: bool = False, + input_type: str = None, + target_type: str = None, + pred_type: str = None, + max_samples: int = None, ) -> None: self.project = project # Only used on rank 0, the dir is created in setup(). @@ -146,8 +144,8 @@ def _log(self, outputs: dict, mode: str, global_step: int, is_epoch=False) -> No # Image elif data_type == "image": # Check if the data type is valid. - check_image_data_type(data, data_name) - for identifier, image in parse_image_data(data): + check_supported_data_type(data, data_name) + for identifier, image in parse_data(data): name = name if identifier is None else f"{name}_{identifier}" # Slice to `max_samples` only if it less than the batch size. if self.max_samples is not None and self.max_samples < image.shape[0]: @@ -304,85 +302,3 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> def on_test_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: self._on_epoch_end(trainer, pl_module) - -def preprocess_image(image: torch.Tensor) -> torch.Tensor: - """Preprocess the image before logging it. If it is a batch of multiple images, - it will create a grid image of them. In case of 3D, a single image is displayed - with slices stacked vertically, while a batch as a grid where each column is - a different 3D image. 
- Args: - image (torch.Tensor): 2D or 3D image tensor. - Returns: - torch.Tensor: image ready for logging. - """ - image = image.detach().cpu() - # 3D image (NCDHW) - has_three_dims = image.ndim == 5 - if has_three_dims: - # Reshape 3D image from NCDHW to NC(D*H)W format - shape = image.shape - image = image.view(shape[0], shape[1], shape[2] * shape[3], shape[4]) - if image.shape[0] == 1: - image = image[0] - else: - # If more than one image, create a grid - nrow = image.shape[0] if has_three_dims else 8 - image = torchvision.utils.make_grid(image, nrow=nrow) - return image - - -def check_image_data_type(data: Any, name: str) -> None: - """Check the input image data for its type. Valid image data types are: - - torch.Tensor - - List[torch.Tensor] - - Dict[str, torch.Tensor] - - Dict[str, List[torch.Tensor]] - - Args: - data (Any): image data to check - name (str): name of the image data, for logging purposes. - """ - if isinstance(data, dict): - is_valid = all(check_image_data_type(elem) for elem in data.values()) - elif isinstance(data, list): - is_valid = all(check_image_data_type(elem) for elem in data) - elif isinstance(data, torch.Tensor): - is_valid = True - else: - is_valid = False - - if not is_valid: - logger.error( - f"`{name}` has to be a Tensor, List[Tensors], Dict[str, Tensor]" - f", or Dict[str, List[Tensor]]. `{type(data)}` is not supported." - ) - sys.exit() - - -def parse_image_data( - data: Union[Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]], List[torch.Tensor], torch.Tensor] -) -> List[Tuple[Optional[str], torch.Tensor]]: - """Given input data, this function will parse it and return a list of tuples where - each tuple contains an identifier and a tensor. - - Args: - data (Union[Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]], List[torch.Tensor], torch.Tensor]): image tensor(s). - - Returns: - List[Tuple[Optional[str], torch.Tensor]]: a list of tuples where the first element is - a string identifier (or `None` if there is only one image) and the second an image tensor. - """ - result = [] - if isinstance(data, dict): - for key, value in data.items(): - if isinstance(value, list): - for i, tensor in enumerate(value): - result.append((f"{key}_{i}", tensor) if len(value > 1) else (key, tensor)) - else: - result.append((key, value)) - elif isinstance(data, list): - for i, tensor in enumerate(data): - result.append((str(i), tensor)) - else: - result.append((None, data)) - return result diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py new file mode 100644 index 00000000..17fefd85 --- /dev/null +++ b/lighter/callbacks/utils.py @@ -0,0 +1,113 @@ +import sys +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torchvision +from loguru import logger + +LIGHTNING_TO_LIGHTER_STAGE = {"train": "train", "validate": "val", "test": "test"} + + +def parse_data( + data: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]], Dict[str, Tuple[torch.Tensor]]] +) -> List[Tuple[Optional[str], torch.Tensor]]: + """Given input data, this function will parse it and return a list of tuples where + each tuple contains an identifier and a tensor. + + Args: + data (Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]], Dict[str, Tuple[torch.Tensor]]]): + input data to parse. 
+ + Returns: + List[Tuple[Optional[str], torch.Tensor]]: a list of tuples where the first element is the string + identifier (`None` if there is only one tensor), and the second is the actual tensor. + """ + result = [] + if isinstance(data, dict): + for key, value in data.items(): + if isinstance(value, (list, tuple)): + for i, tensor in enumerate(value): + result.append((f"{key}_{i}", tensor) if len(value > 1) else (key, tensor)) + else: + result.append((key, value)) + elif isinstance(data, (list, tuple)): + for i, tensor in enumerate(data): + result.append((str(i), tensor)) + else: + result.append((None, data)) + return result + + + +def check_supported_data_type(data: Any, name: str) -> None: + """Check the input data for its type. Valid data types are: + - torch.Tensor + - List[torch.Tensor] + - Tuple[torch.Tensor] + - Dict[str, torch.Tensor] + - Dict[str, List[torch.Tensor]] + - Dict[str, Tuple[torch.Tensor]] + + Args: + data (Any): input data to check + name (str): name of the data, for identification purposes. + """ + if isinstance(data, dict): + is_valid = all(check_supported_data_type(elem) for elem in data.values()) + elif isinstance(data, (list, tuple)): + is_valid = all(check_supported_data_type(elem) for elem in data) + elif isinstance(data, torch.Tensor): + is_valid = True + else: + is_valid = False + + if not is_valid: + logger.error( + f"`{name}` has to be a Tensor, List[Tensor], Tuple[Tensor], Dict[str, Tensor], " + f"Dict[str, List[Tensor]], or Dict[str, Tuple[Tensor]]. `{type(data)}` is not supported." + ) + sys.exit() + + +def concatenate(outputs: Union[List[Any], Tuple[Any]]) -> Union[torch.Tensor, List[Union[str, int, float]]]: + # List of dicts. + if isinstance(outputs[0], dict): + # Go over dictionaries and concatenate tensors by key. + result = {key: concatenate([output[key] for output in outputs]) for key in outputs[0]} + # List of lists or tuples. + elif isinstance(outputs[0], (list, tuple)): + # Go over lists/tuples and concatenate tensors by their position. + result = [concatenate([output[idx] for output in outputs]) for idx in range(len(outputs[0]))] + # List of tensors. + elif isinstance(outputs[0], torch.Tensor): + result = torch.cat(outputs) + else: + logger.error(f"Type `{type(outputs[0])}` not supported.") + sys.exit() + return result + + +def preprocess_image(image: torch.Tensor) -> torch.Tensor: + """Preprocess the image before logging it. If it is a batch of multiple images, + it will create a grid image of them. In case of 3D, a single image is displayed + with slices stacked vertically, while a batch as a grid where each column is + a different 3D image. + Args: + image (torch.Tensor): 2D or 3D image tensor. + Returns: + torch.Tensor: image ready for logging. 
+ """ + image = image.detach().cpu() + # 3D image (NCDHW) + has_three_dims = image.ndim == 5 + if has_three_dims: + # Reshape 3D image from NCDHW to NC(D*H)W format + shape = image.shape + image = image.view(shape[0], shape[1], shape[2] * shape[3], shape[4]) + if image.shape[0] == 1: + image = image[0] + else: + # If more than one image, create a grid + nrow = image.shape[0] if has_three_dims else 8 + image = torchvision.utils.make_grid(image, nrow=nrow) + return image diff --git a/lighter/callbacks/writer.py b/lighter/callbacks/writer.py new file mode 100644 index 00000000..003e486d --- /dev/null +++ b/lighter/callbacks/writer.py @@ -0,0 +1,78 @@ +import sys +from typing import Any, Dict, List, Optional, Tuple, Union +import itertools +from pathlib import Path +from datetime import datetime + +from loguru import logger +import torch +import torchvision +from pytorch_lightning import Callback, Trainer + +from lighter import LighterSystem +from lighter.callbacks.utils import LIGHTNING_TO_LIGHTER_STAGE, parse_data, concatenate, preprocess_image + + +class LighterWriter(Callback): + def __init__(self, write_dir: str, write_as: str, write_on: str = "step", write_to_csv: bool = False) -> None: + self.write_dir = Path(write_dir) / datetime.now().strftime("%Y%m%d_%H%M%S") + self.write_as = write_as + self.write_on = write_on + self.write_to_csv = write_to_csv + + def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: + if self.write_on not in ["step", "epoch"]: + logger.error("`write_on` must be either 'step' or 'epoch'.") + sys.exit() + + if self.write_to_csv and self.write_as in ["image", "tensor"]: + logger.error(f"`write_as={self.write_as}` cannot be written to a CSV. Change `write_as` or disable `write_to_csv`.") + sys.exit() + + self.write_dir.mkdir(parents=True) + + def _write(self, outputs, indices): + for identifier, data in parse_data(outputs): + for idx, tensor in zip(indices, data): + name = f"step_{idx}" if identifier is None else f"step_{idx}_{identifier}" + if self.write_as == "tensor": + path = self.write_dir / f"{self.write_as}_{name}.pt" + torch.save(tensor, path) + elif self.write_as == "image": + path = self.write_dir / f"{self.write_as}_{name}.png" + torchvision.utils.save_image(preprocess_image(tensor), path) + elif self.write_as == "scalar": + raise NotImplementedError + elif self.write_as == "audio": + raise NotImplementedError + elif self.write_as == "video": + raise NotImplementedError + else: + logger.error(f"`write_as` does not support '{self.write_as}'.") + sys.exit() + + def on_predict_batch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + if self.write_on != "step": + return + indices = trainer.predict_loop.epoch_loop.current_batch_indices + self._write(outputs, indices) + + def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: + if self.write_on != "epoch": + return + + # Only one epoch when predicting, index the lists of outputs and batch indices accordingly. + indices = trainer.predict_loop.epoch_batch_indices[0] + outputs = outputs[0] + + # Concatenate/flatten into a list of indices. + indices = list(itertools.chain(*indices)) + # Concatenate/flatten the outputs so that each output corresponds to its index in `indices`. 
+ outputs = concatenate(outputs) + + self._write(outputs, indices) + + def on_predict_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: + # Dump the CSV + pass + diff --git a/projects/cifar10/experiments/monai_bundle_prototype.yaml b/projects/cifar10/experiments/monai_bundle_prototype.yaml index 431b2bca..bea11a02 100644 --- a/projects/cifar10/experiments/monai_bundle_prototype.yaml +++ b/projects/cifar10/experiments/monai_bundle_prototype.yaml @@ -17,11 +17,16 @@ trainer: input_type: image max_samples: 10 + - _target_: lighter.callbacks.LighterWriter + write_dir: "$@project + '/predictions' " + write_as: "tensor" + write_on: "epoch" # "step" + system: _target_: lighter.LighterSystem batch_size: 512 pin_memory: True - num_workers: 1 + num_workers: 2 model: _target_: "project.models.net.Net" @@ -106,3 +111,5 @@ system: mean: [0.5, 0.5, 0.5] std: [0.5, 0.5, 0.5] target_transform: null + + predict_dataset: "%system#test_dataset" From db4508b950c721cc3a48e603c4041eb214e70f77 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Wed, 1 Feb 2023 15:52:12 -0500 Subject: [PATCH 02/11] style --- .gitignore | 2 -- lighter/callbacks/logger.py | 3 +-- lighter/callbacks/utils.py | 12 +++++++++--- lighter/callbacks/writer.py | 22 +++++++++++++--------- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/.gitignore b/.gitignore index 58e42793..0d3b8cbc 100644 --- a/.gitignore +++ b/.gitignore @@ -147,6 +147,4 @@ checkpoints/ projects/* !projects/README.md !projects/cifar10 -contrib/ **/predictions/ - diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index f393281f..313db899 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -13,7 +13,7 @@ from yaml import safe_load from lighter import LighterSystem -from lighter.callbacks.utils import LIGHTNING_TO_LIGHTER_STAGE, parse_data, check_supported_data_type, preprocess_image +from lighter.callbacks.utils import LIGHTNING_TO_LIGHTER_STAGE, check_supported_data_type, parse_data, preprocess_image OPTIONAL_IMPORTS = {} @@ -301,4 +301,3 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> def on_test_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: self._on_epoch_end(trainer, pl_module) - diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py index 17fefd85..a7938278 100644 --- a/lighter/callbacks/utils.py +++ b/lighter/callbacks/utils.py @@ -1,6 +1,7 @@ -import sys from typing import Any, Dict, List, Optional, Tuple, Union +import sys + import torch import torchvision from loguru import logger @@ -9,7 +10,13 @@ def parse_data( - data: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]], Dict[str, Tuple[torch.Tensor]]] + data: Union[ + torch.Tensor, + List[torch.Tensor], + Dict[str, torch.Tensor], + Dict[str, List[torch.Tensor]], + Dict[str, Tuple[torch.Tensor]], + ] ) -> List[Tuple[Optional[str], torch.Tensor]]: """Given input data, this function will parse it and return a list of tuples where each tuple contains an identifier and a tensor. @@ -38,7 +45,6 @@ def parse_data( return result - def check_supported_data_type(data: Any, name: str) -> None: """Check the input data for its type. 
Valid data types are: - torch.Tensor diff --git a/lighter/callbacks/writer.py b/lighter/callbacks/writer.py index 003e486d..77347e88 100644 --- a/lighter/callbacks/writer.py +++ b/lighter/callbacks/writer.py @@ -1,16 +1,17 @@ -import sys from typing import Any, Dict, List, Optional, Tuple, Union + import itertools -from pathlib import Path +import sys from datetime import datetime +from pathlib import Path -from loguru import logger import torch import torchvision +from loguru import logger from pytorch_lightning import Callback, Trainer from lighter import LighterSystem -from lighter.callbacks.utils import LIGHTNING_TO_LIGHTER_STAGE, parse_data, concatenate, preprocess_image +from lighter.callbacks.utils import LIGHTNING_TO_LIGHTER_STAGE, concatenate, parse_data, preprocess_image class LighterWriter(Callback): @@ -26,7 +27,9 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: sys.exit() if self.write_to_csv and self.write_as in ["image", "tensor"]: - logger.error(f"`write_as={self.write_as}` cannot be written to a CSV. Change `write_as` or disable `write_to_csv`.") + logger.error( + f"`write_as={self.write_as}` cannot be written to a CSV. Change `write_as` or disable `write_to_csv`." + ) sys.exit() self.write_dir.mkdir(parents=True) @@ -34,7 +37,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: def _write(self, outputs, indices): for identifier, data in parse_data(outputs): for idx, tensor in zip(indices, data): - name = f"step_{idx}" if identifier is None else f"step_{idx}_{identifier}" + name = f"step_{idx}" if identifier is None else f"step_{idx}_{identifier}" if self.write_as == "tensor": path = self.write_dir / f"{self.write_as}_{name}.pt" torch.save(tensor, path) @@ -51,7 +54,9 @@ def _write(self, outputs, indices): logger.error(f"`write_as` does not support '{self.write_as}'.") sys.exit() - def on_predict_batch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_predict_batch_end( + self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: if self.write_on != "step": return indices = trainer.predict_loop.epoch_loop.current_batch_indices @@ -71,8 +76,7 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outpu outputs = concatenate(outputs) self._write(outputs, indices) - + def on_predict_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: # Dump the CSV pass - From 68559d6928893337b05eda94ef233963367c20f9 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Wed, 1 Feb 2023 23:14:30 -0500 Subject: [PATCH 03/11] add DDP support to LighterWriter, improve DDP in LighterLogger --- lighter/callbacks/logger.py | 30 ++++++++++-------------------- lighter/callbacks/writer.py | 13 ++++++++++++- lighter/utils/misc.py | 6 +++--- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index 7fce1ae5..0a709ef3 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -5,7 +5,6 @@ from pathlib import Path import torch -import torch.distributed as dist from loguru import logger from monai.utils.module import optional_import from pytorch_lightning import Callback, Trainer @@ -61,7 +60,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: logger.error("When using LighterLogger, set Trainer(logger=None).") sys.exit() - if 
dist.is_initialized() and dist.get_rank() != 0: + if not trainer.is_global_zero: return self.log_dir.mkdir(parents=True) @@ -93,7 +92,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: # self.wandb.config.update(config) def teardown(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: - if dist.is_initialized() and dist.get_rank() != 0: + if not trainer.is_global_zero: return self.tensorboard.close() @@ -108,9 +107,6 @@ def _log(self, outputs: dict, mode: str, global_step: int, is_epoch=False) -> No is_epoch (bool): whether the log is being done at the end of an epoch or astep. Default is False. """ - if dist.is_initialized() and dist.get_rank() != 0: - return - step_or_epoch = "epoch" if is_epoch else "step" # Loss @@ -201,9 +197,8 @@ def _on_batch_end(self, outputs: Dict, trainer: Trainer) -> None: # Accumulate the loss. if mode in ["train", "val"]: self.loss[mode] += outputs["loss"].item() - # Logging frequency. - if self.global_step_counter[mode] % trainer.log_every_n_steps == 0: - # Log. Done only on rank 0. + # Logging frequency. Log only on rank 0. + if trainer.is_global_zero and self.global_step_counter[mode] % trainer.log_every_n_steps == 0: self._log(outputs, mode, global_step=self._get_global_step(trainer)) # Increment the step counters. self.global_step_counter[mode] += 1 @@ -213,7 +208,7 @@ def _on_batch_end(self, outputs: Dict, trainer: Trainer) -> None: def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: """Performs logging at the end of an epoch. It calculates the average loss and metrics for the epoch and logs them. In distributed mode, it averages - the losses and metrics from all processes. + the losses and metrics from all ranks. Args: trainer (Trainer): Trainer, passed automatically by PyTorch Lightning. @@ -227,14 +222,8 @@ def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: if mode in ["train", "val"]: # Get the accumulated loss. loss = self.loss[mode] - # Reduce the loss to rank 0 and average it. - if dist.is_initialized(): - # Distributed communication works only tensors. - loss = torch.tensor(loss).to(pl_module.device) - # On rank 0, sum the losses from all ranks. Other ranks remain with the same loss as before. - dist.reduce(loss, dst=0) - # On rank 0, average the loss sum by dividing it with the number of processes. - loss = loss.item() / dist.get_world_size() if dist.get_rank() == 0 else loss.item() + # Reduce the loss and average it on each rank. + loss = trainer.strategy.reduce(loss, reduce_op="mean") # Divide the accumulated loss by the number of steps in the epoch. loss /= self.epoch_step_counter[mode] outputs["loss"] = loss @@ -248,8 +237,9 @@ def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: # Reset the metrics for the next epoch. metrics.reset() - # Log. Done only on rank 0. - self._log(outputs, mode, is_epoch=True, global_step=self._get_global_step(trainer)) + # Log. Only on rank 0. + if trainer.is_global_zero: + self._log(outputs, mode, is_epoch=True, global_step=self._get_global_step(trainer)) def _get_global_step(self, trainer: Trainer) -> int: """Return the global step for the current mode. 
Note that when Trainer diff --git a/lighter/callbacks/writer.py b/lighter/callbacks/writer.py index 77347e88..ca51b2fb 100644 --- a/lighter/callbacks/writer.py +++ b/lighter/callbacks/writer.py @@ -32,7 +32,18 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: ) sys.exit() - self.write_dir.mkdir(parents=True) + # Broadcast the `write_dir` so that all ranks write their predictions there. + self.write_dir = trainer.strategy.broadcast(self.write_dir) + # Let rank 0 create the `write_dir`. + if trainer.is_global_zero: + self.write_dir.mkdir(parents=True) + # If `write_dir` does not exist, the ranks are not on the same storage. + if not self.write_dir.exists(): + logger.error( + f"Rank {trainer.global_rank} is not on the same storage as rank 0." + "Please run the prediction only on nodes that are on the same storage." + ) + sys.exit() def _write(self, outputs, indices): for identifier, data in parse_data(outputs): diff --git a/lighter/utils/misc.py b/lighter/utils/misc.py index 664d7336..a1dbfbce 100644 --- a/lighter/utils/misc.py +++ b/lighter/utils/misc.py @@ -58,7 +58,7 @@ def hasarg(_callable: Callable, arg_name: str) -> bool: return arg_name in args -def countargs(callable: Callable) -> bool: +def countargs(_callable: Callable) -> bool: """Count the number of arguments that a function, class, or method accepts. Args: @@ -67,10 +67,10 @@ def countargs(callable: Callable) -> bool: Returns: int: number of arguments that it accepts. """ - return len(inspect.signature(callable).parameters.keys()) + return len(inspect.signature(_callable).parameters.keys()) -def get_name(x: Callable, include_module_name: bool = False) -> str: +def get_name(_callable: Callable, include_module_name: bool = False) -> str: """Get the name of an object, class or function. Args: From a7078c4233c27ab32b16abfa91b58a9a4ed6f16d Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Sat, 4 Feb 2023 12:23:41 -0500 Subject: [PATCH 04/11] add support for multi-type write_as --- lighter/callbacks/logger.py | 2 +- lighter/callbacks/utils.py | 38 ++++++++++-------- lighter/callbacks/writer.py | 78 +++++++++++++++++++++++-------------- 3 files changed, 71 insertions(+), 47 deletions(-) diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index 0a709ef3..f7fe6e21 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -141,7 +141,7 @@ def _log(self, outputs: dict, mode: str, global_step: int, is_epoch=False) -> No elif data_type == "image": # Check if the data type is valid. check_supported_data_type(data, data_name) - for identifier, image in parse_data(data): + for identifier, image in parse_data(data).items(): name = name if identifier is None else f"{name}_{identifier}" # Slice to `max_samples` only if it less than the batch size. if self.max_samples is not None and self.max_samples < image.shape[0]: diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py index a7938278..fd6c89ef 100644 --- a/lighter/callbacks/utils.py +++ b/lighter/callbacks/utils.py @@ -11,37 +11,41 @@ def parse_data( data: Union[ - torch.Tensor, - List[torch.Tensor], - Dict[str, torch.Tensor], - Dict[str, List[torch.Tensor]], - Dict[str, Tuple[torch.Tensor]], + Any, + List[Any], + Dict[str, Any], + Dict[str, List[Any]], + Dict[str, Tuple[Any]], ] -) -> List[Tuple[Optional[str], torch.Tensor]]: - """Given input data, this function will parse it and return a list of tuples where - each tuple contains an identifier and a tensor. 
+) -> Dict[Optional[str], Any]: + """Parse the input data as follows: + - If dict, go over all keys and values, unpacking list and tuples, and assigning them all + a unique identifier based on the original key and their position if they were a list/tuple. + - If list/tuple, enumerate them and use their position as key for each value of the list/tuple. + - If any other type, return it as-is with the key set to 'None'. A 'None' key indicates that no + identifier is needed because no parsing ocurred. Args: - data (Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]], Dict[str, Tuple[torch.Tensor]]]): + data (Union[Any, List[Any], Dict[str, Any], Dict[str, List[Any]], Dict[str, Tuple[Any]]]): input data to parse. Returns: - List[Tuple[Optional[str], torch.Tensor]]: a list of tuples where the first element is the string - identifier (`None` if there is only one tensor), and the second is the actual tensor. + Dict[Optional[str], Any]: a dict where key is either a string + identifier or `None`, and value the parsed output. """ - result = [] + result = {} if isinstance(data, dict): for key, value in data.items(): if isinstance(value, (list, tuple)): - for i, tensor in enumerate(value): - result.append((f"{key}_{i}", tensor) if len(value > 1) else (key, tensor)) + for idx, singular in enumerate(value): + result[key] = f"{key}_{idx}", singular if len(value > 1) else key, singular else: result.append((key, value)) elif isinstance(data, (list, tuple)): - for i, tensor in enumerate(data): - result.append((str(i), tensor)) + for idx, singular in enumerate(data): + result[str(idx)] = singular else: - result.append((None, data)) + result[None] = data return result diff --git a/lighter/callbacks/writer.py b/lighter/callbacks/writer.py index ca51b2fb..a481dfcf 100644 --- a/lighter/callbacks/writer.py +++ b/lighter/callbacks/writer.py @@ -11,25 +11,31 @@ from pytorch_lightning import Callback, Trainer from lighter import LighterSystem -from lighter.callbacks.utils import LIGHTNING_TO_LIGHTER_STAGE, concatenate, parse_data, preprocess_image +from lighter.callbacks.utils import concatenate, parse_data, preprocess_image class LighterWriter(Callback): - def __init__(self, write_dir: str, write_as: str, write_on: str = "step", write_to_csv: bool = False) -> None: + def __init__( + self, + write_dir: str, + write_as: Union[str, List[str], Dict[str, str], Dict[str, List[str]]], + write_on: str = "step", + write_to_csv: bool = False, + ) -> None: self.write_dir = Path(write_dir) / datetime.now().strftime("%Y%m%d_%H%M%S") self.write_as = write_as self.write_on = write_on self.write_to_csv = write_to_csv + self.parsed_write_as = None + def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: if self.write_on not in ["step", "epoch"]: logger.error("`write_on` must be either 'step' or 'epoch'.") sys.exit() - if self.write_to_csv and self.write_as in ["image", "tensor"]: - logger.error( - f"`write_as={self.write_as}` cannot be written to a CSV. Change `write_as` or disable `write_to_csv`." - ) + if self.write_on != "epoch" and self.write_to_csv: + logger.error("`write_to_csv=True` supports `write_on='epoch'` only.") sys.exit() # Broadcast the `write_dir` so that all ranks write their predictions there. 
@@ -46,24 +52,44 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: sys.exit() def _write(self, outputs, indices): - for identifier, data in parse_data(outputs): - for idx, tensor in zip(indices, data): + parsed_outputs = parse_data(outputs) + parsed_write_as = self._parse_write_as(self.write_as, parsed_outputs) + for idx in indices: + for identifier in parsed_outputs: + # Unlike a list/tuple/dict of Tensors, a single Tensor has 'None' as identifier since it doesn't need one. name = f"step_{idx}" if identifier is None else f"step_{idx}_{identifier}" - if self.write_as == "tensor": - path = self.write_dir / f"{self.write_as}_{name}.pt" - torch.save(tensor, path) - elif self.write_as == "image": - path = self.write_dir / f"{self.write_as}_{name}.png" - torchvision.utils.save_image(preprocess_image(tensor), path) - elif self.write_as == "scalar": - raise NotImplementedError - elif self.write_as == "audio": - raise NotImplementedError - elif self.write_as == "video": - raise NotImplementedError - else: - logger.error(f"`write_as` does not support '{self.write_as}'.") + self._write_by_type(name, parsed_outputs[identifier], parsed_write_as[identifier]) + + def _write_by_type(self, name, tensor, write_as): + if write_as == "tensor": + path = self.write_dir / f"{name}_{write_as}.pt" + torch.save(tensor, path) + elif write_as == "image": + path = self.write_dir / f"{name}_{write_as}.png" + torchvision.io.write_png(preprocess_image(tensor), path) + elif write_as == "video": + path = self.write_dir / f"{name}_{write_as}.mp4" + torchvision.io.write_video(path, tensor, fps=24) + elif write_as == "scalar": + raise NotImplementedError + elif write_as == "audio": + raise NotImplementedError + else: + logger.error(f"`write_as` does not support '{write_as}'.") + sys.exit() + + def _parse_write_as(self, write_as, parsed_outputs: Dict[str, Any]): + if self.parsed_write_as is None: + # If `write_as` is a string (single value), all outputs will be saved in that specified format. + if isinstance(write_as, str): + self.parsed_write_as = {key: write_as for key in parsed_outputs} + # Otherwise, `write_as` needs to match the structure of the outputs in order to assign each tensor its specified type. + else: + self.parsed_write_as = parse_data(write_as) + if not set(self.parsed_write_as) == set(parsed_outputs): + logger.error("`write_as` structure does not match the prediction's structure.") sys.exit() + return self.parsed_write_as def on_predict_batch_end( self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int @@ -76,18 +102,12 @@ def on_predict_batch_end( def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: if self.write_on != "epoch": return - # Only one epoch when predicting, index the lists of outputs and batch indices accordingly. indices = trainer.predict_loop.epoch_batch_indices[0] outputs = outputs[0] - - # Concatenate/flatten into a list of indices. + # Concatenate/flatten so that each output corresponds to its index. indices = list(itertools.chain(*indices)) - # Concatenate/flatten the outputs so that each output corresponds to its index in `indices`. 
outputs = concatenate(outputs) - self._write(outputs, indices) - def on_predict_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: # Dump the CSV - pass From 3df027abc750be8a0371c3db4ac85c5f4b1d3a67 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Mon, 6 Feb 2023 21:42:02 -0500 Subject: [PATCH 05/11] Refactor to Base, File, and Table Writer --- lighter/callbacks/__init__.py | 2 +- lighter/callbacks/utils.py | 6 +- lighter/callbacks/writer.py | 184 +++++++++++++----- .../experiments/monai_bundle_prototype.yaml | 22 +-- 4 files changed, 149 insertions(+), 65 deletions(-) diff --git a/lighter/callbacks/__init__.py b/lighter/callbacks/__init__.py index 48a352b0..f3ddc5e8 100644 --- a/lighter/callbacks/__init__.py +++ b/lighter/callbacks/__init__.py @@ -1,2 +1,2 @@ from .logger import LighterLogger -from .writer import LighterWriter +from .writer import LighterFileWriter, LighterTableWriter diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py index fd6c89ef..e570701a 100644 --- a/lighter/callbacks/utils.py +++ b/lighter/callbacks/utils.py @@ -40,7 +40,7 @@ def parse_data( for idx, singular in enumerate(value): result[key] = f"{key}_{idx}", singular if len(value > 1) else key, singular else: - result.append((key, value)) + result[key] = value elif isinstance(data, (list, tuple)): for idx, singular in enumerate(data): result[str(idx)] = singular @@ -63,9 +63,9 @@ def check_supported_data_type(data: Any, name: str) -> None: name (str): name of the data, for identification purposes. """ if isinstance(data, dict): - is_valid = all(check_supported_data_type(elem) for elem in data.values()) + is_valid = all(check_supported_data_type(elem, name) for elem in data.values()) elif isinstance(data, (list, tuple)): - is_valid = all(check_supported_data_type(elem) for elem in data) + is_valid = all(check_supported_data_type(elem, name) for elem in data) elif isinstance(data, torch.Tensor): is_valid = True else: diff --git a/lighter/callbacks/writer.py b/lighter/callbacks/writer.py index a481dfcf..bbd41c84 100644 --- a/lighter/callbacks/writer.py +++ b/lighter/callbacks/writer.py @@ -1,10 +1,12 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import itertools import sys +from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path +import pandas as pd import torch import torchvision from loguru import logger @@ -14,30 +16,57 @@ from lighter.callbacks.utils import concatenate, parse_data, preprocess_image -class LighterWriter(Callback): +class LighterBaseWriter(ABC, Callback): def __init__( self, write_dir: str, - write_as: Union[str, List[str], Dict[str, str], Dict[str, List[str]]], + write_as: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]], write_on: str = "step", - write_to_csv: bool = False, ) -> None: self.write_dir = Path(write_dir) / datetime.now().strftime("%Y%m%d_%H%M%S") self.write_as = write_as self.write_on = write_on - self.write_to_csv = write_to_csv self.parsed_write_as = None + @abstractmethod + def write( + self, + idx: int, + identifier: Optional[str], + tensor: torch.Tensor, + write_as: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]], + ): + """This method should be overridden to specify how a tensor should be saved. If the Writer + supports multiple types of saving, handle the `write_as` argument with an if-else statement. 
+ + If the Writer only supports one type, `write_as` can be ignored and `write_as=None` can be + set in the overridden `__init__()` method. + + The `idx` and `identifier` arguments can be used to specify the name of the file + or the row and column of a table for the prediction. + + Parameters: + idx (int): The index of the prediction. + identifier (Optional[str]): The identifier of the prediction. It will be `None` if there's + only one prediction, an index if the prediction is a list of predictions, a key if it's + a dict of predictions, and a key_index if it's a dict of list of predictions. + tensor (torch.Tensor): The predicted tensor. + write_as (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): + Specifies how to write the predictions. If it's a single string value, the predictions + will be saved under that type regardless of whether they are single- or multi-output + predictions. To write different outputs in the multi-output predictions using different + methods, use the appropriate format for `write_as`. + """ + def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: + if stage != "predict": + return + if self.write_on not in ["step", "epoch"]: logger.error("`write_on` must be either 'step' or 'epoch'.") sys.exit() - if self.write_on != "epoch" and self.write_to_csv: - logger.error("`write_to_csv=True` supports `write_on='epoch'` only.") - sys.exit() - # Broadcast the `write_dir` so that all ranks write their predictions there. self.write_dir = trainer.strategy.broadcast(self.write_dir) # Let rank 0 create the `write_dir`. @@ -51,39 +80,43 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: ) sys.exit() - def _write(self, outputs, indices): + def on_predict_batch_end( + self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + if self.write_on != "step": + return + indices = trainer.predict_loop.epoch_loop.current_batch_indices + self._on_batch_or_epoch_end(outputs, indices) + + def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: + if self.write_on != "epoch": + return + # Only one epoch when predicting, index the lists of outputs and batch indices accordingly. + indices = trainer.predict_loop.epoch_batch_indices[0] + outputs = outputs[0] + # Concatenate/flatten so that each output corresponds to its index. + indices = list(itertools.chain(*indices)) + outputs = concatenate(outputs) + self._on_batch_or_epoch_end(outputs, indices) + + def _on_batch_or_epoch_end(self, outputs, indices): + # Parse the outputs into a structure ready for writing. parsed_outputs = parse_data(outputs) + # Parse `write_as`. If multi-value, check if its structure matches `parsed_output`'s structure. parsed_write_as = self._parse_write_as(self.write_as, parsed_outputs) + for idx in indices: - for identifier in parsed_outputs: - # Unlike a list/tuple/dict of Tensors, a single Tensor has 'None' as identifier since it doesn't need one. 
- name = f"step_{idx}" if identifier is None else f"step_{idx}_{identifier}" - self._write_by_type(name, parsed_outputs[identifier], parsed_write_as[identifier]) - - def _write_by_type(self, name, tensor, write_as): - if write_as == "tensor": - path = self.write_dir / f"{name}_{write_as}.pt" - torch.save(tensor, path) - elif write_as == "image": - path = self.write_dir / f"{name}_{write_as}.png" - torchvision.io.write_png(preprocess_image(tensor), path) - elif write_as == "video": - path = self.write_dir / f"{name}_{write_as}.mp4" - torchvision.io.write_video(path, tensor, fps=24) - elif write_as == "scalar": - raise NotImplementedError - elif write_as == "audio": - raise NotImplementedError - else: - logger.error(f"`write_as` does not support '{write_as}'.") - sys.exit() + for identifier in parsed_outputs: # pylint: disable=consider-using-dict-items + tensor = parsed_outputs[identifier] + write_as = parsed_write_as[identifier] + self.write(idx, identifier, tensor, write_as) def _parse_write_as(self, write_as, parsed_outputs: Dict[str, Any]): if self.parsed_write_as is None: # If `write_as` is a string (single value), all outputs will be saved in that specified format. if isinstance(write_as, str): self.parsed_write_as = {key: write_as for key in parsed_outputs} - # Otherwise, `write_as` needs to match the structure of the outputs in order to assign each tensor its specified type. + # Otherwise, `write_as` needs to match the structure of the outputs in order to assign each tensor its type. else: self.parsed_write_as = parse_data(write_as) if not set(self.parsed_write_as) == set(parsed_outputs): @@ -91,23 +124,74 @@ def _parse_write_as(self, write_as, parsed_outputs: Dict[str, Any]): sys.exit() return self.parsed_write_as - def on_predict_batch_end( - self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int - ) -> None: - if self.write_on != "step": - return - indices = trainer.predict_loop.epoch_loop.current_batch_indices - self._write(outputs, indices) + +class LighterFileWriter(LighterBaseWriter): + def write(self, idx, identifier, tensor, write_as): + filename = f"{write_as}" if identifier is None else f"{identifier}_{write_as}" + write_dir = self.write_dir / str(idx) + write_dir.mkdir() + + if write_as is None: + pass + elif write_as == "tensor": + path = write_dir / f"{filename}.pt" + torch.save(tensor, path) + elif write_as == "image": + path = write_dir / f"{filename}.png" + torchvision.io.write_png(preprocess_image(tensor), path) + elif write_as == "video": + path = write_dir / f"{filename}.mp4" + torchvision.io.write_video(path, tensor, fps=24) + elif write_as == "scalar": + raise NotImplementedError + elif write_as == "audio": + raise NotImplementedError + else: + logger.error(f"`write_as` '{write_as}' not supported.") + sys.exit() + + +class LighterTableWriter(LighterBaseWriter): + def __init__(self, write_dir: str, write_as: Union[str, List[str], Dict[str, str], Dict[str, List[str]]]) -> None: + super().__init__(write_dir, write_as, write_on="epoch") + self.csv_records = {} + + def write(self, idx, identifier, tensor, write_as): + # Column name will be set to 'pred' if the identifier is None. 
+ column = "pred" if identifier is None else identifier + + if write_as is None: + record = None + elif write_as == "tensor": + record = tensor.tolist() + elif write_as == "scalar": + raise NotImplementedError + else: + logger.error(f"`write_as` '{write_as}' not supported.") + sys.exit() + + if idx not in self.csv_records: + self.csv_records[idx] = {column: record} + else: + self.csv_records[idx][column] = record def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: - if self.write_on != "epoch": - return - # Only one epoch when predicting, index the lists of outputs and batch indices accordingly. - indices = trainer.predict_loop.epoch_batch_indices[0] - outputs = outputs[0] - # Concatenate/flatten so that each output corresponds to its index. - indices = list(itertools.chain(*indices)) - outputs = concatenate(outputs) - self._write(outputs, indices) + super().on_predict_epoch_end(trainer, pl_module, outputs) + + csv_path = self.write_dir / "predictions.csv" + logger.info(f"Saving the predictions to {csv_path}") + + # Sort the dict of dicts by key and turn it into a list of dicts. + self.csv_records = [self.csv_records[key] for key in sorted(self.csv_records)] + # Gather the records from all ranks when in DDP. + if trainer.world_size > 1: + # Since `all_gather` supports tensors only, mimic the behavior using `broadcast`. + ddp_csv_records = [self.csv_records] * trainer.world_size + for rank in range(trainer.world_size): + # Broadcast the records from the current rank and save it at its designated position. + ddp_csv_records[rank] = trainer.strategy.broadcast(ddp_csv_records[rank], src=rank) + # Combine the records from all ranks. List of lists of dicts -> list of dicts. + self.csv_records = list(itertools.chain(*ddp_csv_records)) - # Dump the CSV + # Create a dataframe and save it. 
+ pd.DataFrame(self.csv_records).to_csv(csv_path) diff --git a/projects/cifar10/experiments/monai_bundle_prototype.yaml b/projects/cifar10/experiments/monai_bundle_prototype.yaml index bea11a02..f0015c8a 100644 --- a/projects/cifar10/experiments/monai_bundle_prototype.yaml +++ b/projects/cifar10/experiments/monai_bundle_prototype.yaml @@ -3,24 +3,24 @@ project: ./projects/cifar10 trainer: _target_: pytorch_lightning.Trainer max_epochs: 100 - accelerator: gpu - devices: 1 + accelerator: cpu + #devices: 1 # 2 # strategy: ddp log_every_n_steps: 10 logger: null callbacks: - - _target_: lighter.callbacks.LighterLogger - project: CIFAR10 - log_dir: "$@project + '/logs' " - tensorboard: True - wandb: True - input_type: image - max_samples: 10 + # - _target_: lighter.callbacks.LighterLogger + # project: CIFAR10 + # log_dir: "$@project + '/logs' " + # tensorboard: True + # wandb: True + # input_type: image + # max_samples: 10 - - _target_: lighter.callbacks.LighterWriter + - _target_: lighter.callbacks.LighterFileWriter write_dir: "$@project + '/predictions' " write_as: "tensor" - write_on: "epoch" # "step" + write_on: "step" # "epoch" system: _target_: lighter.LighterSystem From d9573e92e4bd5be83912a5bb850321ffd90fed66 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Tue, 7 Feb 2023 15:45:03 -0500 Subject: [PATCH 06/11] improve _parse_write_as, add more docstrings --- lighter/callbacks/logger.py | 8 ++-- lighter/callbacks/utils.py | 47 +++++++++++++++---- lighter/callbacks/writer.py | 51 +++++++++++++------- lighter/system.py | 93 +++++++++++++++++++------------------ pyproject.toml | 1 + 5 files changed, 124 insertions(+), 76 deletions(-) diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index f7fe6e21..461c94a7 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -10,7 +10,7 @@ from pytorch_lightning import Callback, Trainer from lighter import LighterSystem -from lighter.callbacks.utils import LIGHTNING_TO_LIGHTER_STAGE, check_supported_data_type, parse_data, preprocess_image +from lighter.callbacks.utils import check_supported_data_type, get_lighter_mode, parse_data, preprocess_image OPTIONAL_IMPORTS = {} @@ -193,7 +193,7 @@ def _on_batch_end(self, outputs: Dict, trainer: Trainer) -> None: trainer (Trainer): Trainer, passed automatically by PyTorch Lightning. """ if not trainer.sanity_checking: - mode = LIGHTNING_TO_LIGHTER_STAGE[trainer.state.stage] + mode = get_lighter_mode(trainer.state.stage) # Accumulate the loss. if mode in ["train", "val"]: self.loss[mode] += outputs["loss"].item() @@ -215,7 +215,7 @@ def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: pl_module (LighterSystem): LighterSystem, passed automatically by PyTorch Lightning. """ if not trainer.sanity_checking: - mode = LIGHTNING_TO_LIGHTER_STAGE[trainer.state.stage] + mode = get_lighter_mode(trainer.state.stage) outputs = {"loss": None, "metrics": None} # Loss @@ -253,7 +253,7 @@ def _get_global_step(self, trainer: Trainer) -> int: Returns: int: global step. 
""" - mode = LIGHTNING_TO_LIGHTER_STAGE[trainer.state.stage] + mode = get_lighter_mode(trainer.state.stage) # When validating in Trainer.fit(), return the train steps instead of the # val steps to correctly if mode == "val" and trainer.state.fn == "fit": diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py index e570701a..bb654e08 100644 --- a/lighter/callbacks/utils.py +++ b/lighter/callbacks/utils.py @@ -6,7 +6,18 @@ import torchvision from loguru import logger -LIGHTNING_TO_LIGHTER_STAGE = {"train": "train", "validate": "val", "test": "test"} + +def get_lighter_mode(lightning_stage: str) -> str: + """Converts the name of a PyTorch Lightnig stage to the name of its corresponding Lighter mode. + + Args: + lightning_stage (str): stage in which PyTorch Lightning Trainer is. Can be accessed using `trainer.state.stage`. + + Returns: + str: name of the Lighter mode. + """ + lightning_to_lighter = {"train": "train", "validate": "val", "test": "test"} + return lightning_to_lighter[lightning_stage] def parse_data( @@ -79,20 +90,38 @@ def check_supported_data_type(data: Any, name: str) -> None: sys.exit() -def concatenate(outputs: Union[List[Any], Tuple[Any]]) -> Union[torch.Tensor, List[Union[str, int, float]]]: +def structure_preserving_concatenate( + inputs: Union[List[Any], Tuple[Any]] +) -> Union[torch.Tensor, List[Union[str, int, float]]]: + """Recursively concatenate tensors that are either on their own or inside of other data structures (list/tuple/dict). + An input list of tensors is reduced to a single concatenated tensor, while an input list of data structures with tensors + will be reduced to a single data structure with its tensors concatenated along the key or position. + + Assumes that all elements of the input list have the same type and structure. + + Args: + inputs (Union[List[Any], Tuple[Any]]): A list or tuple of either: + - Dictionaries, each containing tensors to be concatenated by key. + - Lists/tuples, each containing tensors to be concatenated by their position. + - Tensors, which are concatenated along the first dimension. + + Returns: + Union[torch.Tensor, List[Union[str, int, float]]]: The concatenated result in the same format as the input's elements. + """ # List of dicts. - if isinstance(outputs[0], dict): + if isinstance(inputs[0], dict): # Go over dictionaries and concatenate tensors by key. - result = {key: concatenate([output[key] for output in outputs]) for key in outputs[0]} + keys = inputs[0].keys() + result = {key: structure_preserving_concatenate([input[key] for input in inputs]) for key in keys} # List of lists or tuples. - elif isinstance(outputs[0], (list, tuple)): + elif isinstance(inputs[0], (list, tuple)): # Go over lists/tuples and concatenate tensors by their position. - result = [concatenate([output[idx] for output in outputs]) for idx in range(len(outputs[0]))] + result = [structure_preserving_concatenate([input[idx] for input in inputs]) for idx in range(len(inputs[0]))] # List of tensors. 
- elif isinstance(outputs[0], torch.Tensor): - result = torch.cat(outputs) + elif isinstance(inputs[0], torch.Tensor): + result = torch.cat(inputs) else: - logger.error(f"Type `{type(outputs[0])}` not supported.") + logger.error(f"Type `{type(inputs[0])}` not supported.") sys.exit() return result diff --git a/lighter/callbacks/writer.py b/lighter/callbacks/writer.py index bbd41c84..005fe2bd 100644 --- a/lighter/callbacks/writer.py +++ b/lighter/callbacks/writer.py @@ -17,6 +17,22 @@ class LighterBaseWriter(ABC, Callback): + """Base class for a Writer. Override `self.write()` to define how a prediction should be saved. + `LighterBaseWriter` sets up the write directory, and defines `on_predict_batch_end` and + `on_predict_epoch_end`. `write_on` specifies which of the two should the writer call. + + Args: + write_dir (str): the Writer will create a directory inside of `write_dir` with date + and time as its name and store the predictions there. + write_as (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): + type in which the predictions will be stored. Passed automatically to the `write()` + abstract method and can be used to support writing different types. Should the Writer + support only one type, this argument can be removed from the overriden `__init__()`'s + arguments and set `self.write_as = None`. + write_on (str, optional): whether to write on each step or at the end of the prediction epoch. + Defaults to "step". + """ + def __init__( self, write_dir: str, @@ -37,11 +53,11 @@ def write( tensor: torch.Tensor, write_as: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]], ): - """This method should be overridden to specify how a tensor should be saved. If the Writer + """This method must be overridden to specify how a tensor should be saved. If the Writer supports multiple types of saving, handle the `write_as` argument with an if-else statement. - If the Writer only supports one type, `write_as` can be ignored and `write_as=None` can be - set in the overridden `__init__()` method. + If the Writer only supports one type, remove `write_as` from the overridden + `__init__()` method and set `self.write_as=None`. The `idx` and `identifier` arguments can be used to specify the name of the file or the row and column of a table for the prediction. @@ -102,27 +118,28 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outpu def _on_batch_or_epoch_end(self, outputs, indices): # Parse the outputs into a structure ready for writing. parsed_outputs = parse_data(outputs) - # Parse `write_as`. If multi-value, check if its structure matches `parsed_output`'s structure. - parsed_write_as = self._parse_write_as(self.write_as, parsed_outputs) + # Runs only on the first step. + if self.parsed_write_as is None: + # Parse `self.write_as`. If multi-value, check if its structure matches `parsed_output`'s structure. + self.parsed_write_as = self._parse_write_as(self.write_as, parsed_outputs) for idx in indices: for identifier in parsed_outputs: # pylint: disable=consider-using-dict-items tensor = parsed_outputs[identifier] - write_as = parsed_write_as[identifier] + write_as = self.parsed_write_as[identifier] self.write(idx, identifier, tensor, write_as) def _parse_write_as(self, write_as, parsed_outputs: Dict[str, Any]): - if self.parsed_write_as is None: - # If `write_as` is a string (single value), all outputs will be saved in that specified format. 
- if isinstance(write_as, str): - self.parsed_write_as = {key: write_as for key in parsed_outputs} - # Otherwise, `write_as` needs to match the structure of the outputs in order to assign each tensor its type. - else: - self.parsed_write_as = parse_data(write_as) - if not set(self.parsed_write_as) == set(parsed_outputs): - logger.error("`write_as` structure does not match the prediction's structure.") - sys.exit() - return self.parsed_write_as + # If `write_as` is a string (single value), all outputs will be saved in that specified format. + if isinstance(write_as, str): + parsed_write_as = {key: write_as for key in parsed_outputs} + # Otherwise, `write_as` needs to match the structure of the outputs in order to assign each tensor its type. + else: + parsed_write_as = parse_data(write_as) + if not set(parsed_write_as) == set(parsed_outputs): + logger.error("`write_as` structure does not match the prediction's structure.") + sys.exit() + return parsed_write_as class LighterFileWriter(LighterBaseWriter): diff --git a/lighter/system.py b/lighter/system.py index 96198da3..d037fc6e 100644 --- a/lighter/system.py +++ b/lighter/system.py @@ -17,6 +17,53 @@ class LighterSystem(pl.LightningModule): + """_summary_ + + Args: + model (Module): the model. + batch_size (int): batch size. + drop_last_batch (bool, optional): whether the last batch in the dataloader + should be dropped. Defaults to False. + num_workers (int, optional): number of dataloader workers. Defaults to 0. + pin_memory (bool, optional): whether to pin the dataloaders memory. Defaults to True. + optimizers (Optional[Union[Optimizer, List[Optimizer]]], optional): + a single or a list of optimizers. Defaults to None. + schedulers (Optional[Union[Callable, List[Callable]]], optional): + a single or a list of schedulers. Defaults to None. + criterion (Optional[Callable], optional): + criterion/loss function. Defaults to None. + cast_target_dtype_to (Optional[str], optional): whether to cast the target to the + specified type before calculating the loss. May be necessary for some criterions. + Defaults to None. + post_criterion_activation (Optional[str], optional): some criterions + (e.g. BCEWithLogitsLoss) require non-activated prediction for their calculaiton. + However, to calculate the metrics and log the data, it may be necessary to activate + the predictions. Defaults to None. + patch_based_inferer (Optional[Callable], optional): the patch based inferer needs to be + either a class with a `__call__` method or function that accepts two arguments - + first one is the input tensor, and the other one the model itself. It should + perform the inference over the patches and return the aggregated/averaged output. + Defaults to None. + train_metrics (Optional[Union[Metric, List[Metric]]], optional): training metric(s). + They have to be implemented using `torchmetrics`. Defaults to None. + val_metrics (Optional[Union[Metric, List[Metric]]], optional): validation metric(s). + They have to be implemented using `torchmetrics`. Defaults to None. + test_metrics (Optional[Union[Metric, List[Metric]]], optional): test metric(s). + They have to be implemented using `torchmetrics`. Defaults to None. + train_dataset (Optional[Union[Dataset, List[Dataset]]], optional): training dataset(s). + Defaults to None. + val_dataset (Optional[Union[Dataset, List[Dataset]]], optional): validation dataset(s). + Defaults to None. + test_dataset (Optional[Union[Dataset, List[Dataset]]], optional): test dataset(s). + Defaults to None. 
+ predict_dataset (Optional[Union[Dataset, List[Dataset]]], optional): predict dataset(s). + Defaults to None. + train_sampler (Optional[Sampler], optional): training sampler(s). Defaults to None. + val_sampler (Optional[Sampler], optional): validation sampler(s). Defaults to None. + test_sampler (Optional[Sampler], optional): test sampler(s). Defaults to None. + predict_sampler (Optional[Sampler], optional): predict sampler(s). Defaults to None. + """ + def __init__( self, model: Module, @@ -42,52 +89,6 @@ def __init__( test_sampler: Optional[Sampler] = None, predict_sampler: Optional[Sampler] = None, ) -> None: - """_summary_ - - Args: - model (Module): the model. - batch_size (int): batch size. - drop_last_batch (bool, optional): whether the last batch in the dataloader - should be dropped. Defaults to False. - num_workers (int, optional): number of dataloader workers. Defaults to 0. - pin_memory (bool, optional): whether to pin the dataloaders memory. Defaults to True. - optimizers (Optional[Union[Optimizer, List[Optimizer]]], optional): - a single or a list of optimizers. Defaults to None. - schedulers (Optional[Union[Callable, List[Callable]]], optional): - a single or a list of schedulers. Defaults to None. - criterion (Optional[Callable], optional): - criterion/loss function. Defaults to None. - cast_target_dtype_to (Optional[str], optional): whether to cast the target to the - specified type before calculating the loss. May be necessary for some criterions. - Defaults to None. - post_criterion_activation (Optional[str], optional): some criterions - (e.g. BCEWithLogitsLoss) require non-activated prediction for their calculaiton. - However, to calculate the metrics and log the data, it may be necessary to activate - the predictions. Defaults to None. - patch_based_inferer (Optional[Callable], optional): the patch based inferer needs to be - either a class with a `__call__` method or function that accepts two arguments - - first one is the input tensor, and the other one the model itself. It should - perform the inference over the patches and return the aggregated/averaged output. - Defaults to None. - train_metrics (Optional[Union[Metric, List[Metric]]], optional): training metric(s). - They have to be implemented using `torchmetrics`. Defaults to None. - val_metrics (Optional[Union[Metric, List[Metric]]], optional): validation metric(s). - They have to be implemented using `torchmetrics`. Defaults to None. - test_metrics (Optional[Union[Metric, List[Metric]]], optional): test metric(s). - They have to be implemented using `torchmetrics`. Defaults to None. - train_dataset (Optional[Union[Dataset, List[Dataset]]], optional): training dataset(s). - Defaults to None. - val_dataset (Optional[Union[Dataset, List[Dataset]]], optional): validation dataset(s). - Defaults to None. - test_dataset (Optional[Union[Dataset, List[Dataset]]], optional): test dataset(s). - Defaults to None. - predict_dataset (Optional[Union[Dataset, List[Dataset]]], optional): predict dataset(s). - Defaults to None. - train_sampler (Optional[Sampler], optional): training sampler(s). Defaults to None. - val_sampler (Optional[Sampler], optional): validation sampler(s). Defaults to None. - test_sampler (Optional[Sampler], optional): test sampler(s). Defaults to None. - predict_sampler (Optional[Sampler], optional): predict sampler(s). Defaults to None. - """ super().__init__() # Bypass LightningModule's check for default methods. We define them in self.setup(). 
        self._init_placeholders_for_dataloader_and_step_methods()

diff --git a/pyproject.toml b/pyproject.toml
index a97692b9..841fd362 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -177,6 +177,7 @@ disable = """
     too-many-arguments,
     not-callable
 """
+generated-members = "torch.*"

 [tool.pylint.master]
 fail-under=8

From 24319ef22063bb2bca02d8f60dfdd1366240c152 Mon Sep 17 00:00:00 2001
From: Ibrahim Hadzic
Date: Tue, 7 Feb 2023 20:35:57 -0500
Subject: [PATCH 07/11] small fix

---
 lighter/callbacks/writer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lighter/callbacks/writer.py b/lighter/callbacks/writer.py
index 005fe2bd..45324479 100644
--- a/lighter/callbacks/writer.py
+++ b/lighter/callbacks/writer.py
@@ -13,7 +13,7 @@
 from pytorch_lightning import Callback, Trainer

 from lighter import LighterSystem
-from lighter.callbacks.utils import concatenate, parse_data, preprocess_image
+from lighter.callbacks.utils import parse_data, preprocess_image, structure_preserving_concatenate


 class LighterBaseWriter(ABC, Callback):
@@ -112,7 +112,7 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outpu
         outputs = outputs[0]
         # Concatenate/flatten so that each output corresponds to its index.
         indices = list(itertools.chain(*indices))
-        outputs = concatenate(outputs)
+        outputs = structure_preserving_concatenate(outputs)
         self._on_batch_or_epoch_end(outputs, indices)

From adbfe54278010b6ad1b0b4572940142508106a84 Mon Sep 17 00:00:00 2001
From: Ibrahim Hadzic
Date: Tue, 28 Feb 2023 12:33:22 -0500
Subject: [PATCH 08/11] write_on to write_interval

---
 lighter/callbacks/writer.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/lighter/callbacks/writer.py b/lighter/callbacks/writer.py
index 45324479..008e6622 100644
--- a/lighter/callbacks/writer.py
+++ b/lighter/callbacks/writer.py
@@ -19,7 +19,7 @@ class LighterBaseWriter(ABC, Callback):
     """Base class for a Writer. Override `self.write()` to define how a prediction should be saved.
     `LighterBaseWriter` sets up the write directory, and defines `on_predict_batch_end` and
-    `on_predict_epoch_end`. `write_on` specifies which of the two the writer calls.
+    `on_predict_epoch_end`. `write_interval` specifies which of the two the writer calls.

     Args:
         write_dir (str): the Writer will create a directory inside `write_dir`, named with
             the current date and time, and store the predictions there.
         write_as (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]):
             type in which the predictions will be stored. Passed automatically to the `write()`
             abstract method and can be used to support writing different types. Should the Writer
             support only one type, this argument can be removed from the overridden `__init__()`'s
             arguments, and `self.write_as = None` can be set instead.
-        write_on (str, optional): whether to write on each step or at the end of the prediction epoch.
+        write_interval (str, optional): whether to write on each step or at the end of the prediction epoch.
             Defaults to "step".
""" @@ -37,11 +37,11 @@ def __init__( self, write_dir: str, write_as: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]], - write_on: str = "step", + write_interval: str = "step", ) -> None: self.write_dir = Path(write_dir) / datetime.now().strftime("%Y%m%d_%H%M%S") self.write_as = write_as - self.write_on = write_on + self.write_interval = write_interval self.parsed_write_as = None @@ -79,8 +79,8 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: if stage != "predict": return - if self.write_on not in ["step", "epoch"]: - logger.error("`write_on` must be either 'step' or 'epoch'.") + if self.write_interval not in ["step", "epoch"]: + logger.error("`write_interval` must be either 'step' or 'epoch'.") sys.exit() # Broadcast the `write_dir` so that all ranks write their predictions there. @@ -99,13 +99,13 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: def on_predict_batch_end( self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if self.write_on != "step": + if self.write_interval != "step": return indices = trainer.predict_loop.epoch_loop.current_batch_indices self._on_batch_or_epoch_end(outputs, indices) def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: - if self.write_on != "epoch": + if self.write_interval != "epoch": return # Only one epoch when predicting, index the lists of outputs and batch indices accordingly. indices = trainer.predict_loop.epoch_batch_indices[0] @@ -170,7 +170,7 @@ def write(self, idx, identifier, tensor, write_as): class LighterTableWriter(LighterBaseWriter): def __init__(self, write_dir: str, write_as: Union[str, List[str], Dict[str, str], Dict[str, List[str]]]) -> None: - super().__init__(write_dir, write_as, write_on="epoch") + super().__init__(write_dir, write_as, write_interval="epoch") self.csv_records = {} def write(self, idx, identifier, tensor, write_as): From cc224fe33974d888679d47e93d9c2f5c90cf58db Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Tue, 28 Feb 2023 13:08:51 -0500 Subject: [PATCH 09/11] Reorganize writers --- lighter/callbacks/__init__.py | 3 +- .../callbacks/{writer.py => writer/base.py} | 76 +------------------ lighter/callbacks/writer/file.py | 34 +++++++++ lighter/callbacks/writer/table.py | 57 ++++++++++++++ .../experiments/monai_bundle_prototype.yaml | 2 +- 5 files changed, 95 insertions(+), 77 deletions(-) rename lighter/callbacks/{writer.py => writer/base.py} (67%) create mode 100644 lighter/callbacks/writer/file.py create mode 100644 lighter/callbacks/writer/table.py diff --git a/lighter/callbacks/__init__.py b/lighter/callbacks/__init__.py index f3ddc5e8..78d3fcf5 100644 --- a/lighter/callbacks/__init__.py +++ b/lighter/callbacks/__init__.py @@ -1,2 +1,3 @@ from .logger import LighterLogger -from .writer import LighterFileWriter, LighterTableWriter +from .writer.file import LighterFileWriter +from .writer.table import LighterTableWriter diff --git a/lighter/callbacks/writer.py b/lighter/callbacks/writer/base.py similarity index 67% rename from lighter/callbacks/writer.py rename to lighter/callbacks/writer/base.py index 008e6622..76041e6b 100644 --- a/lighter/callbacks/writer.py +++ b/lighter/callbacks/writer/base.py @@ -6,14 +6,12 @@ from datetime import datetime from pathlib import Path -import pandas as pd import torch -import torchvision from loguru import logger from pytorch_lightning import Callback, Trainer 
from lighter import LighterSystem -from lighter.callbacks.utils import parse_data, preprocess_image, structure_preserving_concatenate +from lighter.callbacks.utils import parse_data, structure_preserving_concatenate class LighterBaseWriter(ABC, Callback): @@ -140,75 +138,3 @@ def _parse_write_as(self, write_as, parsed_outputs: Dict[str, Any]): logger.error("`write_as` structure does not match the prediction's structure.") sys.exit() return parsed_write_as - - -class LighterFileWriter(LighterBaseWriter): - def write(self, idx, identifier, tensor, write_as): - filename = f"{write_as}" if identifier is None else f"{identifier}_{write_as}" - write_dir = self.write_dir / str(idx) - write_dir.mkdir() - - if write_as is None: - pass - elif write_as == "tensor": - path = write_dir / f"{filename}.pt" - torch.save(tensor, path) - elif write_as == "image": - path = write_dir / f"{filename}.png" - torchvision.io.write_png(preprocess_image(tensor), path) - elif write_as == "video": - path = write_dir / f"{filename}.mp4" - torchvision.io.write_video(path, tensor, fps=24) - elif write_as == "scalar": - raise NotImplementedError - elif write_as == "audio": - raise NotImplementedError - else: - logger.error(f"`write_as` '{write_as}' not supported.") - sys.exit() - - -class LighterTableWriter(LighterBaseWriter): - def __init__(self, write_dir: str, write_as: Union[str, List[str], Dict[str, str], Dict[str, List[str]]]) -> None: - super().__init__(write_dir, write_as, write_interval="epoch") - self.csv_records = {} - - def write(self, idx, identifier, tensor, write_as): - # Column name will be set to 'pred' if the identifier is None. - column = "pred" if identifier is None else identifier - - if write_as is None: - record = None - elif write_as == "tensor": - record = tensor.tolist() - elif write_as == "scalar": - raise NotImplementedError - else: - logger.error(f"`write_as` '{write_as}' not supported.") - sys.exit() - - if idx not in self.csv_records: - self.csv_records[idx] = {column: record} - else: - self.csv_records[idx][column] = record - - def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: - super().on_predict_epoch_end(trainer, pl_module, outputs) - - csv_path = self.write_dir / "predictions.csv" - logger.info(f"Saving the predictions to {csv_path}") - - # Sort the dict of dicts by key and turn it into a list of dicts. - self.csv_records = [self.csv_records[key] for key in sorted(self.csv_records)] - # Gather the records from all ranks when in DDP. - if trainer.world_size > 1: - # Since `all_gather` supports tensors only, mimic the behavior using `broadcast`. - ddp_csv_records = [self.csv_records] * trainer.world_size - for rank in range(trainer.world_size): - # Broadcast the records from the current rank and save it at its designated position. - ddp_csv_records[rank] = trainer.strategy.broadcast(ddp_csv_records[rank], src=rank) - # Combine the records from all ranks. List of lists of dicts -> list of dicts. - self.csv_records = list(itertools.chain(*ddp_csv_records)) - - # Create a dataframe and save it. 
- pd.DataFrame(self.csv_records).to_csv(csv_path) diff --git a/lighter/callbacks/writer/file.py b/lighter/callbacks/writer/file.py new file mode 100644 index 00000000..5a6ddf5a --- /dev/null +++ b/lighter/callbacks/writer/file.py @@ -0,0 +1,34 @@ +import sys + +import torch +import torchvision +from loguru import logger + +from lighter.callbacks.utils import preprocess_image +from lighter.callbacks.writer.base import LighterBaseWriter + + +class LighterFileWriter(LighterBaseWriter): + def write(self, idx, identifier, tensor, write_as): + filename = f"{write_as}" if identifier is None else f"{identifier}_{write_as}" + write_dir = self.write_dir / str(idx) + write_dir.mkdir() + + if write_as is None: + pass + elif write_as == "tensor": + path = write_dir / f"{filename}.pt" + torch.save(tensor, path) + elif write_as == "image": + path = write_dir / f"{filename}.png" + torchvision.io.write_png(preprocess_image(tensor), path) + elif write_as == "video": + path = write_dir / f"{filename}.mp4" + torchvision.io.write_video(path, tensor, fps=24) + elif write_as == "scalar": + raise NotImplementedError + elif write_as == "audio": + raise NotImplementedError + else: + logger.error(f"`write_as` '{write_as}' not supported.") + sys.exit() diff --git a/lighter/callbacks/writer/table.py b/lighter/callbacks/writer/table.py new file mode 100644 index 00000000..ba0033a9 --- /dev/null +++ b/lighter/callbacks/writer/table.py @@ -0,0 +1,57 @@ +from typing import Any, Dict, List, Union + +import itertools +import sys + +import pandas as pd +from loguru import logger +from pytorch_lightning import Trainer + +from lighter import LighterSystem +from lighter.callbacks.writer.base import LighterBaseWriter + + +class LighterTableWriter(LighterBaseWriter): + def __init__(self, write_dir: str, write_as: Union[str, List[str], Dict[str, str], Dict[str, List[str]]]) -> None: + super().__init__(write_dir, write_as, write_interval="epoch") + self.csv_records = {} + + def write(self, idx, identifier, tensor, write_as): + # Column name will be set to 'pred' if the identifier is None. + column = "pred" if identifier is None else identifier + + if write_as is None: + record = None + elif write_as == "tensor": + record = tensor.tolist() + elif write_as == "scalar": + raise NotImplementedError + else: + logger.error(f"`write_as` '{write_as}' not supported.") + sys.exit() + + if idx not in self.csv_records: + self.csv_records[idx] = {column: record} + else: + self.csv_records[idx][column] = record + + def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: + super().on_predict_epoch_end(trainer, pl_module, outputs) + + csv_path = self.write_dir / "predictions.csv" + logger.info(f"Saving the predictions to {csv_path}") + + # Sort the dict of dicts by key and turn it into a list of dicts. + self.csv_records = [self.csv_records[key] for key in sorted(self.csv_records)] + # Gather the records from all ranks when in DDP. + if trainer.world_size > 1: + # Since `all_gather` supports tensors only, mimic the behavior using `broadcast`. + ddp_csv_records = [self.csv_records] * trainer.world_size + for rank in range(trainer.world_size): + # Broadcast the records from the current rank and save it at its designated position. + ddp_csv_records[rank] = trainer.strategy.broadcast(ddp_csv_records[rank], src=rank) + # Combine the records from all ranks. List of lists of dicts -> list of dicts. + self.csv_records = list(itertools.chain(*ddp_csv_records)) + + # Create a dataframe and save it. 
+ pd.DataFrame(self.csv_records).to_csv(csv_path) diff --git a/projects/cifar10/experiments/monai_bundle_prototype.yaml b/projects/cifar10/experiments/monai_bundle_prototype.yaml index f0015c8a..df11108a 100644 --- a/projects/cifar10/experiments/monai_bundle_prototype.yaml +++ b/projects/cifar10/experiments/monai_bundle_prototype.yaml @@ -20,7 +20,7 @@ trainer: - _target_: lighter.callbacks.LighterFileWriter write_dir: "$@project + '/predictions' " write_as: "tensor" - write_on: "step" # "epoch" + write_interval: "step" # "epoch" system: _target_: lighter.LighterSystem From fe38afa701e9a17bc8c7281e841f8e6a3d0c31d1 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Wed, 1 Mar 2023 20:09:43 -0500 Subject: [PATCH 10/11] Add init to writer package --- lighter/callbacks/writer/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 lighter/callbacks/writer/__init__.py diff --git a/lighter/callbacks/writer/__init__.py b/lighter/callbacks/writer/__init__.py new file mode 100644 index 00000000..e69de29b From 3ffbff105baffbae6cb41642ad9a826996362415 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Thu, 2 Mar 2023 10:06:54 -0500 Subject: [PATCH 11/11] format --- lighter/callbacks/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py index cedec168..d9a362cd 100644 --- a/lighter/callbacks/utils.py +++ b/lighter/callbacks/utils.py @@ -91,6 +91,7 @@ def check_supported_data_type(data: Any, name: str) -> None: return is_valid + def structure_preserving_concatenate( inputs: Union[List[Any], Tuple[Any]] ) -> Union[torch.Tensor, List[Union[str, int, float]]]:
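For illustration, a minimal custom writer built on the final layout of this series could look like the sketch below. It is not part of the patches: the `NumpyWriter` name and the NumPy-based saving are assumptions, chosen only to demonstrate the override points that `LighterBaseWriter` defines (`write()`, `write_as`, `write_interval`).

import numpy as np

from lighter.callbacks.writer.base import LighterBaseWriter


class NumpyWriter(LighterBaseWriter):
    """Sketch of a writer that saves each prediction as a .npy file."""

    def __init__(self, write_dir: str) -> None:
        # This writer supports a single format, so `write_as` is dropped from the
        # arguments and fixed to None, as the `LighterBaseWriter` docstring suggests.
        super().__init__(write_dir, write_as=None, write_interval="step")

    def write(self, idx, identifier, tensor, write_as):
        # One subdirectory per prediction index, mirroring `LighterFileWriter`.
        filename = "pred" if identifier is None else identifier
        write_dir = self.write_dir / str(idx)
        write_dir.mkdir(parents=True, exist_ok=True)
        np.save(write_dir / f"{filename}.npy", tensor.detach().cpu().numpy())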
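Such a writer is registered in the experiment YAML in the same way as the `LighterFileWriter` entry above. As a further sketch, a `LighterTableWriter` entry could look like this; `write_interval` is omitted because `LighterTableWriter.__init__()` pins it to "epoch":

trainer:
  callbacks:
    - _target_: lighter.callbacks.LighterTableWriter
      write_dir: "$@project + '/predictions'"
      write_as: "tensor"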