From b24ab22bab148d5f9eba8d96173f0e9f2aa0c5cf Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Fri, 11 Aug 2023 14:56:03 -0400 Subject: [PATCH 01/20] Remove loss logging when predicting --- lighter/system.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/lighter/system.py b/lighter/system.py index e5f66761..f1435ac4 100644 --- a/lighter/system.py +++ b/lighter/system.py @@ -190,19 +190,11 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un # Calculate the loss. loss = self._calculate_loss(pred, target) if mode in ["train", "val"] else None - # Log the loss for monitoring purposes. - self.log( - "loss" if mode == "train" else f"{mode}_loss", - loss, - on_step=True, - on_epoch=True, - sync_dist=True, - logger=False, - batch_size=self.batch_size, - ) # Log and return the results. if mode == "predict": + # Pred postprocessing for logging or writing. + pred = apply_fns(pred, self.postprocessing["logging"]["pred"]) return pred else: # Data postprocessing for metrics @@ -212,8 +204,19 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un # Calculate the metrics for the step. metrics = self.metrics[mode](pred, target) + # Log the metrics for monitoring purposes. self.log_dict(metrics, on_step=True, on_epoch=True, sync_dist=True, logger=False, batch_size=self.batch_size) + # Log the loss for monitoring purposes. + self.log( + "loss" if mode == "train" else f"{mode}_loss", + loss, + on_step=True, + on_epoch=True, + sync_dist=True, + logger=False, + batch_size=self.batch_size, + ) # Data postprocessing for logging. input = apply_fns(input, self.postprocessing["logging"]["input"]) From 4a137217a393872d3c38beb9175cbaf897d25a28 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Sun, 13 Aug 2023 18:16:46 -0400 Subject: [PATCH 02/20] Add "id" for each batch sample ID-ing purposes. Refactor Writers, add easy extensibility for new formats --- lighter/callbacks/utils.py | 16 +- lighter/callbacks/writer/base.py | 269 +++++++++++++++++++----------- lighter/callbacks/writer/file.py | 137 +++++++++------ lighter/callbacks/writer/table.py | 96 ++++++++--- lighter/system.py | 41 +++-- 5 files changed, 359 insertions(+), 200 deletions(-) diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py index 757bce68..82e5034d 100644 --- a/lighter/callbacks/utils.py +++ b/lighter/callbacks/utils.py @@ -94,33 +94,33 @@ def parse_data( return result -def gather_tensors( +def group_tensors( inputs: Union[List[Union[torch.Tensor, List, Tuple, Dict]], Tuple[Union[torch.Tensor, List, Tuple, Dict]]] ) -> Union[List, Dict]: - """Recursively gather tensors. Tensors can be standalone or inside of other data structures (list/tuple/dict). + """Recursively group tensors. Tensors can be standalone or inside of other data structures (list/tuple/dict). An input list of tensors is returned as-is. Given an input list of data structures with tensors, this function - will gather all tensors into a list and save it under a single data structure. Assumes that all elements of + will group all tensors into a list and save it under a single data structure. Assumes that all elements of the input list have the same type and structure. Args: inputs (List[Union[torch.Tensor, List, Tuple, Dict]], Tuple[Union[torch.Tensor, List, Tuple, Dict]]): They can be: - - List/Tuples of Dictionaries, each containing tensors to be gathered by their key. - - List/Tuples of Lists/tuples, each containing tensors to be gathered by their position. 
+ - List/Tuples of Dictionaries, each containing tensors to be grouped by their key. + - List/Tuples of Lists/tuples, each containing tensors to be grouped by their position. - List/Tuples of Tensors, returned as-is. - Nested versions of the above. The input data structure must be the same for all elements of the list. They can be arbitrarily nested. Returns: - Union[List, Dict]: The gathered tensors. + Union[List, Dict]: The grouped tensors. """ # List of dicts. if isinstance(inputs[0], dict): keys = inputs[0].keys() - return {key: gather_tensors([input[key] for input in inputs]) for key in keys} + return {key: group_tensors([input[key] for input in inputs]) for key in keys} # List of lists or tuples. elif isinstance(inputs[0], (list, tuple)): - return [gather_tensors([input[idx] for input in inputs]) for idx in range(len(inputs[0]))] + return [group_tensors([input[idx] for input in inputs]) for idx in range(len(inputs[0]))] # List of tensors. elif isinstance(inputs[0], torch.Tensor): return inputs diff --git a/lighter/callbacks/writer/base.py b/lighter/callbacks/writer/base.py index ab0ee79a..da1ff3c9 100644 --- a/lighter/callbacks/writer/base.py +++ b/lighter/callbacks/writer/base.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import itertools import sys @@ -11,136 +11,201 @@ from pytorch_lightning import Callback, Trainer from lighter import LighterSystem -from lighter.callbacks.utils import gather_tensors, parse_data +from lighter.callbacks.utils import group_tensors, parse_data class LighterBaseWriter(ABC, Callback): - """Base class for a Writer. Override `self.write()` to define how a prediction should be saved. - `LighterBaseWriter` sets up the write directory, and defines `on_predict_batch_end` and - `on_predict_epoch_end`. `write_interval` specifies which of the two should the writer call. + """ + Base class for defining custom Writer. It provides the structure to save predictions in various formats. + + Subclasses should implement: + 1) `self._writers` property to specify the supported formats and their corresponding writer functions. + 2) `self.write()` method to specify the saving strategy for a prediction. Args: - write_dir (str): the Writer will create a directory inside of `write_dir` with date - and time as its name and store the predictions there. - write_format (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): - type in which the predictions will be stored. Passed automatically to the `write()` - abstract method and can be used to support writing different types. Should the Writer - support only one type, this argument can be removed from the overriden `__init__()`'s - arguments and set `self.write_format = None`. - write_interval (str, optional): whether to write on each step or at the end of the prediction epoch. - Defaults to "step". + directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside. + format (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): Desired format(s) for saving predictions. + The format will be passed to the `write` method. + interval (str, optional): Specifies when to save predictions - at every step or at the end of epoch. Defaults to "step". + additional_writers (Optional[Dict[str, Callable]]): Additional writer functions to be registered with the base writer. 
""" def __init__( self, - write_dir: str, - write_format: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]], - write_interval: str = "step", + directory: str, + format: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]], + interval: str, + additional_writers: Optional[Dict[str, Callable]] = None, ) -> None: - self.write_dir = Path(write_dir) / datetime.now().strftime("%Y%m%d_%H%M%S") - self.write_format = write_format - self.write_interval = write_interval + # Create a unique directory using the current date and time + self.directory = Path(directory) / datetime.now().strftime("%Y%m%d_%H%M%S") + self.format = format + self.interval = interval + + # Placeholder for processed format for quicker access during writes + self.parsed_format = None + + # Keeps track of last written prediction index for cases when ids aren't provided + self.last_index = 0 + + # Ensure that default writers are defined + if not hasattr(self, "_writers"): + raise NotImplementedError("Subclasses of LighterBaseWriter must implement the `_writers` property.") - self.parsed_write_format = None + # Register any additional writers passed during initialization + if additional_writers: + for format, writer_function in additional_writers.items(): + self.add_writer(format, writer_function) @abstractmethod def write( self, - idx: int, - identifier: Optional[str], tensor: torch.Tensor, - write_format: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]], - ): - """This method must be overridden to specify how a tensor should be saved. If the Writer - supports multiple types of saving, handle the `write_format` argument with an if-else statement. - - If the Writer only supports one type, remove `write_format` from the overridden - `__init__()` method and set `self.write_format=None`. - - The `idx` and `identifier` arguments can be used to specify the name of the file - or the row and column of a table for the prediction. - - Parameters: - idx (int): The index of the prediction. - identifier (Optional[str]): The identifier of the prediction. It will be `None` if there's - only one prediction, an index if the prediction is a list of predictions, a key if it's - a dict of predictions, and a key_index if it's a dict of list of predictions. - tensor (torch.Tensor): The predicted tensor. - write_format (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): - Specifies how to write the predictions. If it's a single string value, the predictions - will be saved under that type regardless of whether they are single- or multi-output - predictions. To write different outputs in the multi-output predictions using different - methods, use the appropriate format for `write_format`. + id: int, + multi_pred_id: Optional[str], + format: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]], + ) -> None: + """ + Method to define how a tensor should be saved. + + Depending on the specified format, this method should contain logic to handle the saving mechanism. + + Args: + tensor (torch.Tensor): Tensor to be saved. + id (int): Identifier for the tensor, can be used for naming or indexing. + multi_pred_id (Optional[str]): Used when there are multiple predictions for a single input. + It can represent the index of a prediction, the key of a prediction in case of a dict, + or combined key and index for a dict of lists. + format (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): Format for saving the tensor. 
""" + pass def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: + """Callback for setup stage in Pytorch Lightning Trainer.""" if stage != "predict": return - if self.write_interval not in ["step", "epoch"]: - logger.error("`write_interval` must be either 'step' or 'epoch'.") - sys.exit() + # Validate the interval parameter + if self.interval not in ["step", "epoch"]: + raise ValueError("`interval` must be either 'step' or 'epoch'.") - # Broadcast the `write_dir` so that all ranks write their predictions there. - self.write_dir = trainer.strategy.broadcast(self.write_dir) - # Let rank 0 create the `write_dir`. + # Ensure all distributed nodes write to the same directory + self.directory = trainer.strategy.broadcast(self.directory) if trainer.is_global_zero: - self.write_dir.mkdir(parents=True) - # If `write_dir` does not exist, the ranks are not on the same storage. - if not self.write_dir.exists(): - logger.error( - f"Rank {trainer.global_rank} is not on the same storage as rank 0." - "Please run the prediction only on nodes that are on the same storage." + self.directory.mkdir(parents=True) + if not self.directory.exists(): + raise RuntimeError( + f"Rank {trainer.global_rank} does not share storage with rank 0. Ensure nodes have common storage access." ) - sys.exit() def on_predict_batch_end( - self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int = 0 ) -> None: - if self.write_interval != "step": + """Callback method triggered at the end of each prediction batch/step.""" + if self.interval != "step": return - indices = trainer.predict_loop.epoch_loop.current_batch_indices - self._on_batch_or_epoch_end(outputs, indices) + + preds, ids = outputs["pred"], outputs["id"] + + # Generate IDs if not provided + if ids is None: + ids = list(range(self.last_index, self.last_index + len(preds))) + self.last_index += len(preds) + + self._on_batch_or_epoch_end(preds, ids) def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: - if self.write_interval != "epoch": + """Callback method triggered at the end of the prediction epoch.""" + if self.interval != "epoch": return - # Only one epoch when predicting, index the lists of outputs and batch indices accordingly. - indices = trainer.predict_loop.epoch_batch_indices[0] - outputs = outputs[0] - # Flatten so that each output sample corresponds to its index. - indices = list(itertools.chain(*indices)) - # Remove the batch dimension since every output is a single sample. - outputs = [output.squeeze(0) for output in outputs] - # Gather the output tensors from all samples into a single structure rather than having one structures for each sample. - outputs = gather_tensors(outputs) - self._on_batch_or_epoch_end(outputs, indices) - - def _on_batch_or_epoch_end(self, outputs, indices): - """Iterate through each output and save it in the specified format. The outputs and indices are automatically - split individually by PyTorch Lightning.""" - # Sanity check. Should never happen. If it does, something is wrong with the Trainer. - assert len(indices) == len(outputs) - # `idx` is the index of the input sample, `output` is the output of the model for that sample. - for idx, output in zip(indices, outputs): - # Parse the outputs into a structure ready for writing. - parsed_output = parse_data(output) - # Parse `self.write_format`. 
If multi-value, check if its structure matches `parsed_output`'s structure. - if self.parsed_write_format is None: - self.parsed_write_format = self._parse_write_format(self.write_format, parsed_output) - # Iterate through each prediction for the `idx`-th input sample. - for identifier, tensor in parsed_output.items(): - # Save the prediction in the specified format. - self.write(idx, identifier, tensor.detach().cpu(), self.parsed_write_format[identifier]) - - def _parse_write_format(self, write_format, parsed_outputs: Dict[str, Any]): - # If `write_format` is a string (single value), all outputs will be saved in that specified format. - if isinstance(write_format, str): - parsed_write_format = {key: write_format for key in parsed_outputs} - # Otherwise, `write_format` needs to match the structure of the outputs in order to assign each tensor its type. - else: - parsed_write_format = parse_data(write_format) - if not set(parsed_write_format) == set(parsed_outputs): - logger.error("`write_format` structure does not match the prediction's structure.") - sys.exit() - return parsed_write_format + # Only one epoch when predicting, select its outputs. + preds, ids = outputs["pred"][0], outputs["id"][0] + # Remove the batch dimension since every pred is a single sample. + preds = [pred.squeeze(0) for pred in preds] + # Group predictions from all samples into a unified structure. + preds = group_tensors(preds) + # If no ids provided, assign default sequential ids based on the prediction order. + if ids[0] is None: + ids = list(range(len(preds))) + self._on_batch_or_epoch_end(preds, ids) + + def _on_batch_or_epoch_end(self, preds, ids): + """ + Process each prediction at the end of either a batch or epoch and save in the defined format. + + Args: + preds: Predicted tensors. + ids: Corresponding identifiers for the predictions. + """ + # Sanity check to ensure matched lengths for predictions and ids. + assert len(ids) == len(preds) + for id, pred in zip(ids, preds): + # Convert predictions into a structured format suitable for writing. + parsed_pred = parse_data(pred) + # If the format hasn't been parsed yet, do it now. + if self.parsed_format is None: + self.parsed_format = parse_format(self.format, parsed_pred) + # If multiple outputs, parsed_pred will contain multiple keys. For a single output, key will be None. + for multi_pred_id, tensor in parsed_pred.items(): + # Save the prediction as per the designated format. + self.write(tensor.detach().cpu(), id, multi_pred_id, self.parsed_format[multi_pred_id]) + + def add_writer(self, format: str, writer_function: Callable) -> None: + """ + Register a new writer function for a specified format. + + Args: + format (str): Format type for which the writer is being registered. + writer_function (Callable): Function to write data in the given format. + + Raises: + ValueError: If a writer for the given format is already registered. + """ + if format in self._writers: + raise ValueError(f"Writer for format {format} already registered.") + self._writers[format] = writer_function + + def get_writer(self, format: str) -> Callable: + """ + Retrieve the registered writer function for a specified format. + + Args: + format (str): Format for which the writer function is needed. + + Returns: + Callable: Registered writer function for the given format. + + Raises: + ValueError: If no writer is registered for the specified format. 
+ """ + if format not in self._writers: + raise ValueError(f"Writer for format {format} not registered.") + return self._writers[format] + + +def parse_format(format: str, parsed_preds: Dict[str, Any]) -> Dict[str, str]: + """ + Parse the given format and align it with the structure of the predictions. + + If the format is a single string, all predictions will be saved in this format. If the format has a structure + (like a dictionary), it needs to match the structure of the predictions. + + Args: + format (str): The storage format for the predictions, either as a string or a structured format. + parsed_preds (Dict[str, Any]): Dictionary of parsed prediction data. + + Returns: + Dict[str, str]: Dictionary of parsed format data corresponding to the prediction structure. + + Raises: + ValueError: If the structure of the format does not align with the prediction structure. + """ + if isinstance(format, str): + # Assign the single format to all prediction keys. + parsed_format = {key: format for key in parsed_preds} + else: + # Ensure the structured format corresponds with the predictions' structure. + parsed_format = parse_data(format) + if not set(parsed_format) == set(parsed_preds): + raise ValueError("`format` structure does not match the prediction's structure.") + return parsed_format diff --git a/lighter/callbacks/writer/file.py b/lighter/callbacks/writer/file.py index 7536a0e8..bb82ed72 100644 --- a/lighter/callbacks/writer/file.py +++ b/lighter/callbacks/writer/file.py @@ -1,4 +1,7 @@ +from typing import Callable, Dict, Optional, Union + import sys +from pathlib import Path import torch import torchvision @@ -12,61 +15,99 @@ class LighterFileWriter(LighterBaseWriter): - def write(self, idx, identifier, tensor, write_format): - filename = f"{write_format}" if identifier is None else f"{identifier}_{write_format}" - write_dir = self.write_dir / str(idx) - write_dir.mkdir() - - if write_format is None: - pass - - # Tensor - elif write_format == "tensor": - path = write_dir / f"{filename}.pt" - torch.save(tensor, path) - - # Image - elif write_format == "image": - path = write_dir / f"{filename}.png" - torchvision.io.write_png(preprocess_image(tensor), path) - - # Video - elif write_format == "video": - path = write_dir / f"{filename}.mp4" - # Video tensor must be divisible by 2. Pad the height and width. - tensor = DivisiblePad(k=(0, 2, 2), mode="minimum")(tensor) - # Video tensor must be THWC. Permute CTHW -> THWC. - tensor = tensor.permute(1, 2, 3, 0) - # Video tensor must have 3 channels (RGB). Repeat the channel dim to convert grayscale to RGB. - if tensor.shape[-1] == 1: - tensor = tensor.repeat(1, 1, 1, 3) - # Video tensor must be in the range [0, 1]. Scale to [0, 255]. - tensor = (tensor * 255).to(torch.uint8) - torchvision.io.write_video(str(path), tensor, fps=24) - - # Scalar - elif write_format == "scalar": - raise NotImplementedError - - # Audio - elif write_format == "audio": - raise NotImplementedError + """ + Writer for writing predictions to files. Supports multiple formats, and + additional custom formats can be added either through `additional_writers` + argument at initialization, or by calling `add_writer` method after initialization. + + Args: + directory (Union[str, Path]): The directory where the files should be written. + format (str): The format in which the files should be saved. + interval (str): Interval for writing, e.g., "epoch", "batch". + additional_writers (Optional[Dict[str, Callable]]): Additional custom writer functions. 
+ """ + + def __init__( + self, directory: Union[str, Path], format: str, interval: str, additional_writers: Optional[Dict[str, Callable]] = None + ) -> None: + # Predefined writers for different formats. + self._writers = { + "tensor": write_tensor, + "image": write_image, + "video": write_video, + "sitk_nrrd": write_sitk_nrrd, + "sitk_seg_nrrd": write_seg_nrrd, + "sitk_nifti": write_sitk_nifti, + } + # Initialize the base class. + super().__init__(directory, format, interval, additional_writers) + def write(self, tensor: torch.Tensor, id: Union[int, str], multi_pred_id: Optional[Union[int, str]], format: str) -> None: + """ + Write the tensor to the specified path in the given format. + + If there are multiple predictions, a directory named `id` is created, and each file is named + after `multi_pred_id`. If there's a single prediction, the file is named after `id`. + + Args: + tensor (Tensor): The tensor to be written. + id (Union[int, str]): The primary identifier for naming. + multi_pred_id (Optional[Union[int, str]]): The secondary identifier, used if there are multiple predictions. + format (str): Format in which tensor should be written. + """ + # Determine the path for the file based on prediction count. + if multi_pred_id is not None: + path = self.directory / str(id) / str(multi_pred_id) else: - logger.error(f"`write_format` '{write_format}' not supported.") - sys.exit() + path = self.directory / str(id) + # Ensure the directory exists. + path.parent.mkdir(exist_ok=True, parents=True) + # Fetch the appropriate writer function for the format. + writer = self.get_writer(format) + # Write the tensor to the file. + writer(path, tensor) -def write_sitk_image(path: str, tensor: torch.Tensor) -> None: - """Write a SimpleITK image to disk. +def write_tensor(path, tensor): + torch.save(tensor, path.with_suffix(".pt")) - Args: - path (str): path to write the image. - tensor (torch.Tensor): tensor to write. - """ + +def write_image(path, tensor): + torchvision.io.write_png(preprocess_image(tensor), path.with_suffix(".png")) + + +def write_video(path, tensor): + # Video tensor must be divisible by 2. Pad the height and width. + tensor = DivisiblePad(k=(0, 2, 2), mode="minimum")(tensor) + # Video tensor must be THWC. Permute CTHW -> THWC. + tensor = tensor.permute(1, 2, 3, 0) + # Video tensor must have 3 channels (RGB). Repeat the channel dim to convert grayscale to RGB. + if tensor.shape[-1] == 1: + tensor = tensor.repeat(1, 1, 1, 3) + # Video tensor must be in the range [0, 1]. Scale to [0, 255]. + tensor = (tensor * 255).to(torch.uint8) + torchvision.io.write_video(str(path.with_suffix(".mp4")), tensor, fps=24) + + +def _write_sitk_image(path: str, tensor: torch.Tensor, suffix) -> None: if "sitk" not in OPTIONAL_IMPORTS: OPTIONAL_IMPORTS["sitk"], sitk_available = optional_import("SimpleITK") if not sitk_available: raise ModuleNotFoundError("SimpleITK not installed. To install it, run `pip install SimpleITK`. Exiting.") + + # Remove the channel dimension if it's equal to 1. 
+ tensor = tensor.squeeze(0) if (tensor.dim() == 4 and tensor.shape[0] == 1) else tensor sitk_image = OPTIONAL_IMPORTS["sitk"].GetImageFromArray(tensor.cpu().numpy()) - OPTIONAL_IMPORTS["sitk"].WriteImage(sitk_image, str(path), True) + OPTIONAL_IMPORTS["sitk"].WriteImage(sitk_image, str(path.with_suffix(".nrrd")), True) + + +def write_sitk_nrrd(path, tensor): + _write_sitk_image(path, tensor, suffix=".nrrd") + + +def write_seg_nrrd(path, tensor): + _write_sitk_image(path, tensor, suffix=".seg.nrrd") + + +def write_sitk_nifti(path, tensor): + _write_sitk_image(path, tensor, suffix=".nii.gz") diff --git a/lighter/callbacks/writer/table.py b/lighter/callbacks/writer/table.py index 5e490a69..017580b5 100644 --- a/lighter/callbacks/writer/table.py +++ b/lighter/callbacks/writer/table.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List, Union +from typing import Any, Callable, Dict, List, Optional, Union import itertools import sys +from pathlib import Path import pandas as pd from loguru import logger @@ -12,46 +13,91 @@ class LighterTableWriter(LighterBaseWriter): - def __init__(self, write_dir: str, write_format: Union[str, List[str], Dict[str, str], Dict[str, List[str]]]) -> None: - super().__init__(write_dir, write_format, write_interval="epoch") + """ + Writer for saving predictions in a table format. Supports multiple formats, and + additional custom formats can be added either through `additional_writers` + argument at initialization, or by calling `add_writer` method after initialization. + + Args: + directory (Path): The directory where the CSV will be saved. + format (str): The format in which the data should be saved in the CSV. + additional_writers (Optional[Dict[str, Callable]]): Additional custom writer functions. + """ + + def __init__( + self, directory: Union[str, Path], format: str, additional_writers: Optional[Dict[str, Callable]] = None + ) -> None: + # Predefined writers for different formats. + self._writers = { + "tensor": write_tensor, + } + + # Initialize the base class. + super().__init__(directory, format, "epoch", additional_writers) + + # Create a dictionary to hold CSV records for each ID. self.csv_records = {} - def write(self, idx, identifier, tensor, write_format): - # Column name will be set to 'pred' if the identifier is None. - column = "pred" if identifier is None else identifier + def write(self, tensor: Any, format: str, id: Union[int, str], multi_pred_id: Optional[Union[int, str]]) -> None: + """ + Write the tensor as a table record in the given format. - if write_format is None: - record = None - elif write_format == "tensor": - record = tensor.tolist() - elif write_format == "scalar": - raise NotImplementedError - else: - logger.error(f"`write_format` '{write_format}' not supported.") - sys.exit() + If there are multiple predictions, there will be a separate column for each prediction, + named after the corresponding `multi_pred_id`. + If single prediction, there will be a single column named "pred". + + Args: + tensor (Any): The tensor to be written. + id (Union[int, str]): The primary identifier for naming. + multi_pred_id (Optional[Union[int, str]]): The secondary identifier, used if there are multiple predictions. + format (str): Format in which tensor should be written. 
+ """ + # Determine the column name based on the presence of multi_pred_id + column = "pred" if multi_pred_id is None else multi_pred_id + + # Get the appropriate writer function for the given format + writer = self.get_writer(format) - if idx not in self.csv_records: - self.csv_records[idx] = {column: record} + # Convert the tensor to the desired format (e.g., list) + record = writer(tensor) + + # Store the record in the csv_records dictionary under the specified ID and column + if id not in self.csv_records: + self.csv_records[id] = {column: record} else: - self.csv_records[idx][column] = record + self.csv_records[id][column] = record def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: + """ + Callback method triggered at the end of the prediction epoch to dump the CSV table. + + Args: + trainer (Trainer): Pytorch Lightning Trainer instance. + pl_module (LighterSystem): Lighter system instance. + outputs (List[Any]): List of predictions. + """ + # Call the parent class's method to handle additional end-of-epoch logic super().on_predict_epoch_end(trainer, pl_module, outputs) - csv_path = self.write_dir / "predictions.csv" + # Set the path where the CSV will be saved + csv_path = self.directory / "predictions.csv" + + # Log the save path for user's reference logger.info(f"Saving the predictions to {csv_path}") - # Sort the dict of dicts by key and turn it into a list of dicts. + # Sort the records by ID and convert the dictionary to a list self.csv_records = [self.csv_records[key] for key in sorted(self.csv_records)] - # Gather the records from all ranks when in DDP. + + # If in distributed data parallel mode, gather records from all processes if trainer.world_size > 1: - # Since `all_gather` supports tensors only, mimic the behavior using `broadcast`. ddp_csv_records = [self.csv_records] * trainer.world_size for rank in range(trainer.world_size): - # Broadcast the records from the current rank and save it at its designated position. ddp_csv_records[rank] = trainer.strategy.broadcast(ddp_csv_records[rank], src=rank) - # Combine the records from all ranks. List of lists of dicts -> list of dicts. self.csv_records = list(itertools.chain(*ddp_csv_records)) - # Create a dataframe and save it. + # Convert the list of records to a dataframe and save it as a CSV file pd.DataFrame(self.csv_records).to_csv(csv_path) + + +def write_tensor(tensor: Any) -> List: + return tensor.tolist() diff --git a/lighter/system.py b/lighter/system.py index f1435ac4..4843e5dd 100644 --- a/lighter/system.py +++ b/lighter/system.py @@ -155,27 +155,34 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un For predict step, it returns pred only. """ - # Ensure that the batch is a list, a tuple, or a dict. - if not isinstance(batch, (list, tuple, dict)): + # Batch type check: + # - Dict: must contain "input" and "target" keys, and optionally "id" key. + if isinstance(batch, dict): + if set(batch.keys()) not in [{"input", "target"}, {"input", "target", "id"}]: + raise ValueError( + "A batch dict must have 'input', 'target', and, " + f"optionally 'id', as keys, but found {list(batch.keys())}" + ) + batch["id"] = None if "id" not in batch else batch["id"] + # - List/tuple: must contain two elements - input and target. After the check, convert it to dict. + elif isinstance(batch, (list, tuple)): + if len(batch) != 2: + raise ValueError( + f"A batch must consist of 2 elements - input and target. However, {len(batch)} " + "elements wer found. 
Note: if target does not exist, return `None` as target." + ) + batch = {"input": batch[0], "target": batch[1], "id": None} + # - Other types are not allowed. + else: raise TypeError( "A batch must be a list, a tuple, or a dict." - "A batch dict must have 'input' and 'target' as keys." + "A batch dict must have 'input' and 'target' keys, and optionally 'id'." "A batch list or a tuple must have 2 elements - input and target." "If target does not exist, return `None` as target." ) - # Ensure that a dict batch has input and target keys exclusively. - if isinstance(batch, dict) and set(batch.keys()) != {"input", "target"}: - raise ValueError("A batch must be a dict with 'input' and 'target' as keys.") - # Ensure that a list/tuple batch has 2 elements (input and target). - if len(batch) == 1: - raise ValueError( - "A batch must consist of 2 elements - input and target. If target does not exist, return `None` as target." - ) - if len(batch) > 2: - raise ValueError(f"A batch must consist of 2 elements - input and target, but found {len(batch)} elements.") - # Split the batch into input and target. - input, target = batch if not isinstance(batch, dict) else (batch["input"], batch["target"]) + # Split the batch into input, target, and id. + input, target, id = batch["input"], batch["target"], batch["id"] # Forward if self.inferer and mode in ["val", "test", "predict"]: @@ -195,7 +202,7 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un if mode == "predict": # Pred postprocessing for logging or writing. pred = apply_fns(pred, self.postprocessing["logging"]["pred"]) - return pred + return {"pred": pred, "id": id} else: # Data postprocessing for metrics input = apply_fns(input, self.postprocessing["metrics"]["input"]) @@ -224,7 +231,7 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un pred = apply_fns(pred, self.postprocessing["logging"]["pred"]) # Return the loss, metrics, input, target, and pred. - return {"loss": loss, "metrics": metrics, "input": input, "target": target, "pred": pred} + return {"loss": loss, "metrics": metrics, "input": input, "target": target, "pred": pred, "id": id} def _calculate_loss( self, pred: Union[torch.Tensor, List, Tuple, Dict], target: Union[torch.Tensor, List, Tuple, Dict, None] From 2ea66bdbc99db2b69e97509330ba4b28ba914106 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Mon, 14 Aug 2023 13:20:51 -0400 Subject: [PATCH 03/20] Remove interval arg, group_tensors fn, and on pred epoch end writing. Add decollate_batch when writing. --- lighter/callbacks/utils.py | 53 ++++++----- lighter/callbacks/writer/base.py | 90 ++++--------------- lighter/callbacks/writer/file.py | 24 ++--- lighter/callbacks/writer/table.py | 13 ++- .../experiments/monai_bundle_prototype.yaml | 5 +- 5 files changed, 61 insertions(+), 124 deletions(-) diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py index 82e5034d..a5af8781 100644 --- a/lighter/callbacks/utils.py +++ b/lighter/callbacks/utils.py @@ -94,51 +94,48 @@ def parse_data( return result -def group_tensors( - inputs: Union[List[Union[torch.Tensor, List, Tuple, Dict]], Tuple[Union[torch.Tensor, List, Tuple, Dict]]] -) -> Union[List, Dict]: - """Recursively group tensors. Tensors can be standalone or inside of other data structures (list/tuple/dict). - An input list of tensors is returned as-is. Given an input list of data structures with tensors, this function - will group all tensors into a list and save it under a single data structure. 
Assumes that all elements of - the input list have the same type and structure. +def parse_format(format: str, parsed_preds: Dict[str, Any]) -> Dict[str, str]: + """ + Parse the given format and align it with the structure of the predictions. + + If the format is a single string, all predictions will be saved in this format. If the format has a structure + (like a dictionary), it needs to match the structure of the predictions. Args: - inputs (List[Union[torch.Tensor, List, Tuple, Dict]], Tuple[Union[torch.Tensor, List, Tuple, Dict]]): - They can be: - - List/Tuples of Dictionaries, each containing tensors to be grouped by their key. - - List/Tuples of Lists/tuples, each containing tensors to be grouped by their position. - - List/Tuples of Tensors, returned as-is. - - Nested versions of the above. - The input data structure must be the same for all elements of the list. They can be arbitrarily nested. + format (str): The storage format for the predictions, either as a string or a structured format. + parsed_preds (Dict[str, Any]): Dictionary of parsed prediction data. Returns: - Union[List, Dict]: The grouped tensors. + Dict[str, str]: Dictionary of parsed format data corresponding to the prediction structure. + + Raises: + ValueError: If the structure of the format does not align with the prediction structure. """ - # List of dicts. - if isinstance(inputs[0], dict): - keys = inputs[0].keys() - return {key: group_tensors([input[key] for input in inputs]) for key in keys} - # List of lists or tuples. - elif isinstance(inputs[0], (list, tuple)): - return [group_tensors([input[idx] for input in inputs]) for idx in range(len(inputs[0]))] - # List of tensors. - elif isinstance(inputs[0], torch.Tensor): - return inputs + if isinstance(format, str): + # Assign the single format to all prediction keys. + parsed_format = {key: format for key in parsed_preds} else: - raise TypeError(f"Type `{type(inputs[0])}` not supported.") + # Ensure the structured format corresponds with the predictions' structure. + parsed_format = parse_data(format) + if not set(parsed_format) == set(parsed_preds): + raise ValueError("`format` structure does not match the prediction's structure.") + return parsed_format -def preprocess_image(image: torch.Tensor) -> torch.Tensor: +def preprocess_image(image: torch.Tensor, add_batch_dim=False) -> torch.Tensor: """Preprocess the image before logging it. If it is a batch of multiple images, it will create a grid image of them. In case of 3D, a single image is displayed with slices stacked vertically, while a batch of 3D images as a grid where each column is a different 3D image. Args: image (torch.Tensor): 2D or 3D image tensor. + add_batch_dim (bool, optional): Whether to add a batch dimension to the input image. + Use only when the input image does not have a batch dimension. Defaults to False. Returns: torch.Tensor: image ready for logging. """ - image = image.detach().cpu() + if add_batch_dim: + image = image.unsqueeze(0) # If 3D (BCDHW), concat the images vertically and horizontally. 
if image.ndim == 5: shape = image.shape diff --git a/lighter/callbacks/writer/base.py b/lighter/callbacks/writer/base.py index da1ff3c9..5b7a20c0 100644 --- a/lighter/callbacks/writer/base.py +++ b/lighter/callbacks/writer/base.py @@ -8,10 +8,11 @@ import torch from loguru import logger +from monai.data.utils import decollate_batch from pytorch_lightning import Callback, Trainer from lighter import LighterSystem -from lighter.callbacks.utils import group_tensors, parse_data +from lighter.callbacks.utils import parse_data, parse_format class LighterBaseWriter(ABC, Callback): @@ -19,14 +20,13 @@ class LighterBaseWriter(ABC, Callback): Base class for defining custom Writer. It provides the structure to save predictions in various formats. Subclasses should implement: - 1) `self._writers` property to specify the supported formats and their corresponding writer functions. + 1) `self._writers` attribute to specify the supported formats and their corresponding writer functions. 2) `self.write()` method to specify the saving strategy for a prediction. Args: directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside. format (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): Desired format(s) for saving predictions. The format will be passed to the `write` method. - interval (str, optional): Specifies when to save predictions - at every step or at the end of epoch. Defaults to "step". additional_writers (Optional[Dict[str, Callable]]): Additional writer functions to be registered with the base writer. """ @@ -34,13 +34,11 @@ def __init__( self, directory: str, format: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]], - interval: str, additional_writers: Optional[Dict[str, Callable]] = None, ) -> None: # Create a unique directory using the current date and time self.directory = Path(directory) / datetime.now().strftime("%Y%m%d_%H%M%S") self.format = format - self.interval = interval # Placeholder for processed format for quicker access during writes self.parsed_format = None @@ -66,12 +64,14 @@ def write( format: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]], ) -> None: """ - Method to define how a tensor should be saved. + Method to define how a tensor should be saved. The input tensor will be a single tensor without + the batch dimension. If the batch dimension is needed, apply `tensor.unsqueeze(0)` before saving, + either in this method or in the particular writer function. Depending on the specified format, this method should contain logic to handle the saving mechanism. Args: - tensor (torch.Tensor): Tensor to be saved. + tensor (torch.Tensor): Tensor to be saved. It will be a single tensor without the batch dimension. id (int): Identifier for the tensor, can be used for naming or indexing. multi_pred_id (Optional[str]): Used when there are multiple predictions for a single input. 
It can represent the index of a prediction, the key of a prediction in case of a dict, @@ -85,10 +85,6 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: if stage != "predict": return - # Validate the interval parameter - if self.interval not in ["step", "epoch"]: - raise ValueError("`interval` must be either 'step' or 'epoch'.") - # Ensure all distributed nodes write to the same directory self.directory = trainer.strategy.broadcast(self.directory) if trainer.is_global_zero: @@ -102,43 +98,17 @@ def on_predict_batch_end( self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int = 0 ) -> None: """Callback method triggered at the end of each prediction batch/step.""" - if self.interval != "step": - return - - preds, ids = outputs["pred"], outputs["id"] - - # Generate IDs if not provided - if ids is None: + # Fetch and decollate preds. + preds = decollate_batch(outputs["pred"], detach=True, pad=False) + # Fetch and decollate IDs if provided. + if outputs["id"] is not None: + ids = decollate_batch(outputs["id"], detach=True, pad=False) + # Generate IDs if not provided. An ID will be the index of the prediction. + else: ids = list(range(self.last_index, self.last_index + len(preds))) self.last_index += len(preds) - self._on_batch_or_epoch_end(preds, ids) - - def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: - """Callback method triggered at the end of the prediction epoch.""" - if self.interval != "epoch": - return - # Only one epoch when predicting, select its outputs. - preds, ids = outputs["pred"][0], outputs["id"][0] - # Remove the batch dimension since every pred is a single sample. - preds = [pred.squeeze(0) for pred in preds] - # Group predictions from all samples into a unified structure. - preds = group_tensors(preds) - # If no ids provided, assign default sequential ids based on the prediction order. - if ids[0] is None: - ids = list(range(len(preds))) - self._on_batch_or_epoch_end(preds, ids) - - def _on_batch_or_epoch_end(self, preds, ids): - """ - Process each prediction at the end of either a batch or epoch and save in the defined format. - - Args: - preds: Predicted tensors. - ids: Corresponding identifiers for the predictions. - """ - # Sanity check to ensure matched lengths for predictions and ids. - assert len(ids) == len(preds) + # Iterate over the predictions and save them. for id, pred in zip(ids, preds): # Convert predictions into a structured format suitable for writing. parsed_pred = parse_data(pred) @@ -148,7 +118,7 @@ def _on_batch_or_epoch_end(self, preds, ids): # If multiple outputs, parsed_pred will contain multiple keys. For a single output, key will be None. for multi_pred_id, tensor in parsed_pred.items(): # Save the prediction as per the designated format. - self.write(tensor.detach().cpu(), id, multi_pred_id, self.parsed_format[multi_pred_id]) + self.write(tensor, id, multi_pred_id, format=self.parsed_format[multi_pred_id]) def add_writer(self, format: str, writer_function: Callable) -> None: """ @@ -181,31 +151,3 @@ def get_writer(self, format: str) -> Callable: if format not in self._writers: raise ValueError(f"Writer for format {format} not registered.") return self._writers[format] - - -def parse_format(format: str, parsed_preds: Dict[str, Any]) -> Dict[str, str]: - """ - Parse the given format and align it with the structure of the predictions. - - If the format is a single string, all predictions will be saved in this format. 
If the format has a structure - (like a dictionary), it needs to match the structure of the predictions. - - Args: - format (str): The storage format for the predictions, either as a string or a structured format. - parsed_preds (Dict[str, Any]): Dictionary of parsed prediction data. - - Returns: - Dict[str, str]: Dictionary of parsed format data corresponding to the prediction structure. - - Raises: - ValueError: If the structure of the format does not align with the prediction structure. - """ - if isinstance(format, str): - # Assign the single format to all prediction keys. - parsed_format = {key: format for key in parsed_preds} - else: - # Ensure the structured format corresponds with the predictions' structure. - parsed_format = parse_data(format) - if not set(parsed_format) == set(parsed_preds): - raise ValueError("`format` structure does not match the prediction's structure.") - return parsed_format diff --git a/lighter/callbacks/writer/file.py b/lighter/callbacks/writer/file.py index bb82ed72..689d516b 100644 --- a/lighter/callbacks/writer/file.py +++ b/lighter/callbacks/writer/file.py @@ -23,12 +23,11 @@ class LighterFileWriter(LighterBaseWriter): Args: directory (Union[str, Path]): The directory where the files should be written. format (str): The format in which the files should be saved. - interval (str): Interval for writing, e.g., "epoch", "batch". additional_writers (Optional[Dict[str, Callable]]): Additional custom writer functions. """ def __init__( - self, directory: Union[str, Path], format: str, interval: str, additional_writers: Optional[Dict[str, Callable]] = None + self, directory: Union[str, Path], format: str, additional_writers: Optional[Dict[str, Callable]] = None ) -> None: # Predefined writers for different formats. self._writers = { @@ -40,7 +39,7 @@ def __init__( "sitk_nifti": write_sitk_nifti, } # Initialize the base class. - super().__init__(directory, format, interval, additional_writers) + super().__init__(directory, format, additional_writers) def write(self, tensor: torch.Tensor, id: Union[int, str], multi_pred_id: Optional[Union[int, str]], format: str) -> None: """ @@ -55,12 +54,8 @@ def write(self, tensor: torch.Tensor, id: Union[int, str], multi_pred_id: Option multi_pred_id (Optional[Union[int, str]]): The secondary identifier, used if there are multiple predictions. format (str): Format in which tensor should be written. """ - # Determine the path for the file based on prediction count. - if multi_pred_id is not None: - path = self.directory / str(id) / str(multi_pred_id) - else: - path = self.directory / str(id) - # Ensure the directory exists. + # Determine the path for the file based on prediction count. The suffix must be added by the writer function. + path = self.directory / str(id) if multi_pred_id is None else self.directory / str(id) / str(multi_pred_id) path.parent.mkdir(exist_ok=True, parents=True) # Fetch the appropriate writer function for the format. writer = self.get_writer(format) @@ -73,10 +68,13 @@ def write_tensor(path, tensor): def write_image(path, tensor): - torchvision.io.write_png(preprocess_image(tensor), path.with_suffix(".png")) + path = path.with_suffix(".png") + tensor = preprocess_image(tensor, add_batch_dim=True) + torchvision.io.write_png(tensor, path) def write_video(path, tensor): + path = path.with_suffix(".mp4") # Video tensor must be divisible by 2. Pad the height and width. tensor = DivisiblePad(k=(0, 2, 2), mode="minimum")(tensor) # Video tensor must be THWC. Permute CTHW -> THWC. 
@@ -86,10 +84,12 @@ def write_video(path, tensor): tensor = tensor.repeat(1, 1, 1, 3) # Video tensor must be in the range [0, 1]. Scale to [0, 255]. tensor = (tensor * 255).to(torch.uint8) - torchvision.io.write_video(str(path.with_suffix(".mp4")), tensor, fps=24) + torchvision.io.write_video(str(path), tensor, fps=24) def _write_sitk_image(path: str, tensor: torch.Tensor, suffix) -> None: + path = path.with_suffix(suffix) + if "sitk" not in OPTIONAL_IMPORTS: OPTIONAL_IMPORTS["sitk"], sitk_available = optional_import("SimpleITK") if not sitk_available: @@ -98,7 +98,7 @@ def _write_sitk_image(path: str, tensor: torch.Tensor, suffix) -> None: # Remove the channel dimension if it's equal to 1. tensor = tensor.squeeze(0) if (tensor.dim() == 4 and tensor.shape[0] == 1) else tensor sitk_image = OPTIONAL_IMPORTS["sitk"].GetImageFromArray(tensor.cpu().numpy()) - OPTIONAL_IMPORTS["sitk"].WriteImage(sitk_image, str(path.with_suffix(".nrrd")), True) + OPTIONAL_IMPORTS["sitk"].WriteImage(sitk_image, str(path), useCompression=True) def write_sitk_nrrd(path, tensor): diff --git a/lighter/callbacks/writer/table.py b/lighter/callbacks/writer/table.py index 017580b5..4fee0f08 100644 --- a/lighter/callbacks/writer/table.py +++ b/lighter/callbacks/writer/table.py @@ -33,12 +33,14 @@ def __init__( } # Initialize the base class. - super().__init__(directory, format, "epoch", additional_writers) + super().__init__(directory, format, additional_writers) - # Create a dictionary to hold CSV records for each ID. + # Create a dictionary to hold CSV records for each ID. These are populated at each batch end + # by `self.on_predict_batch_end` defined in the base class using the `write` method below. + # Finally, the records are dumped to a CSV file at the end of the epoch by `self.on_predict_epoch_end`. self.csv_records = {} - def write(self, tensor: Any, format: str, id: Union[int, str], multi_pred_id: Optional[Union[int, str]]) -> None: + def write(self, tensor: Any, id: Union[int, str], multi_pred_id: Optional[Union[int, str]], format: str) -> None: """ Write the tensor as a table record in the given format. @@ -67,7 +69,7 @@ def write(self, tensor: Any, format: str, id: Union[int, str], multi_pred_id: Op else: self.csv_records[id][column] = record - def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: + def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: """ Callback method triggered at the end of the prediction epoch to dump the CSV table. @@ -76,9 +78,6 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outpu pl_module (LighterSystem): Lighter system instance. outputs (List[Any]): List of predictions. 
""" - # Call the parent class's method to handle additional end-of-epoch logic - super().on_predict_epoch_end(trainer, pl_module, outputs) - # Set the path where the CSV will be saved csv_path = self.directory / "predictions.csv" diff --git a/projects/cifar10/experiments/monai_bundle_prototype.yaml b/projects/cifar10/experiments/monai_bundle_prototype.yaml index 9c976132..63f25a7c 100644 --- a/projects/cifar10/experiments/monai_bundle_prototype.yaml +++ b/projects/cifar10/experiments/monai_bundle_prototype.yaml @@ -20,9 +20,8 @@ trainer: max_samples: 10 - _target_: lighter.callbacks.LighterFileWriter - write_dir: "$@project + '/predictions' " - write_format: "tensor" - write_interval: "step" # "epoch" + directory: "$@project + '/predictions' " + format: "tensor" system: _target_: lighter.LighterSystem From ffc26ae7a0176add6f5fe3e7999ac1f7e966ef54 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Mon, 14 Aug 2023 20:38:30 -0400 Subject: [PATCH 04/20] Small fixes --- lighter/callbacks/writer/base.py | 23 ++++++++++++----------- lighter/callbacks/writer/file.py | 2 -- lighter/callbacks/writer/table.py | 5 ----- 3 files changed, 12 insertions(+), 18 deletions(-) diff --git a/lighter/callbacks/writer/base.py b/lighter/callbacks/writer/base.py index 5b7a20c0..aa6c7cf8 100644 --- a/lighter/callbacks/writer/base.py +++ b/lighter/callbacks/writer/base.py @@ -25,8 +25,8 @@ class LighterBaseWriter(ABC, Callback): Args: directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside. - format (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): Desired format(s) for saving predictions. - The format will be passed to the `write` method. + format (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): + Desired format(s) for saving predictions. The format will be passed to the `write` method. additional_writers (Optional[Dict[str, Callable]]): Additional writer functions to be registered with the base writer. """ @@ -41,10 +41,10 @@ def __init__( self.format = format # Placeholder for processed format for quicker access during writes - self.parsed_format = None + self._parsed_format = None # Keeps track of last written prediction index for cases when ids aren't provided - self.last_index = 0 + self._current_pred_index = 0 # Ensure that default writers are defined if not hasattr(self, "_writers"): @@ -68,7 +68,8 @@ def write( the batch dimension. If the batch dimension is needed, apply `tensor.unsqueeze(0)` before saving, either in this method or in the particular writer function. - Depending on the specified format, this method should contain logic to handle the saving mechanism. + For each supported format, there should be a corresponding writer function registered in `self._writers`, + and can be retrieved using `self.get_writer(format)`. Args: tensor (torch.Tensor): Tensor to be saved. It will be a single tensor without the batch dimension. @@ -95,7 +96,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: ) def on_predict_batch_end( - self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int = 0 + self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> None: """Callback method triggered at the end of each prediction batch/step.""" # Fetch and decollate preds. 
@@ -105,20 +106,20 @@ def on_predict_batch_end( ids = decollate_batch(outputs["id"], detach=True, pad=False) # Generate IDs if not provided. An ID will be the index of the prediction. else: - ids = list(range(self.last_index, self.last_index + len(preds))) - self.last_index += len(preds) + ids = list(range(self._current_pred_index, self._current_pred_index + len(preds))) + self._current_pred_index += len(preds) # Iterate over the predictions and save them. for id, pred in zip(ids, preds): # Convert predictions into a structured format suitable for writing. parsed_pred = parse_data(pred) # If the format hasn't been parsed yet, do it now. - if self.parsed_format is None: - self.parsed_format = parse_format(self.format, parsed_pred) + if self._parsed_format is None: + self._parsed_format = parse_format(self.format, parsed_pred) # If multiple outputs, parsed_pred will contain multiple keys. For a single output, key will be None. for multi_pred_id, tensor in parsed_pred.items(): # Save the prediction as per the designated format. - self.write(tensor, id, multi_pred_id, format=self.parsed_format[multi_pred_id]) + self.write(tensor, id, multi_pred_id, format=self._parsed_format[multi_pred_id]) def add_writer(self, format: str, writer_function: Callable) -> None: """ diff --git a/lighter/callbacks/writer/file.py b/lighter/callbacks/writer/file.py index 689d516b..f25bb995 100644 --- a/lighter/callbacks/writer/file.py +++ b/lighter/callbacks/writer/file.py @@ -1,11 +1,9 @@ from typing import Callable, Dict, Optional, Union -import sys from pathlib import Path import torch import torchvision -from loguru import logger from monai.transforms import DivisiblePad from monai.utils.module import optional_import diff --git a/lighter/callbacks/writer/table.py b/lighter/callbacks/writer/table.py index 4fee0f08..0feace56 100644 --- a/lighter/callbacks/writer/table.py +++ b/lighter/callbacks/writer/table.py @@ -1,11 +1,9 @@ from typing import Any, Callable, Dict, List, Optional, Union import itertools -import sys from pathlib import Path import pandas as pd -from loguru import logger from pytorch_lightning import Trainer from lighter import LighterSystem @@ -81,9 +79,6 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> No # Set the path where the CSV will be saved csv_path = self.directory / "predictions.csv" - # Log the save path for user's reference - logger.info(f"Saving the predictions to {csv_path}") - # Sort the records by ID and convert the dictionary to a list self.csv_records = [self.csv_records[key] for key in sorted(self.csv_records)] From 6c4563d29dfefca3013cd5fffeaba04e9fddbe04 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Mon, 14 Aug 2023 21:29:02 -0400 Subject: [PATCH 05/20] Remove multi opt and scheduler support. Replace remaininig sys.exit's. --- lighter/callbacks/logger.py | 6 ++-- lighter/callbacks/writer/file.py | 2 +- lighter/system.py | 51 +++++++++++++------------------- lighter/utils/dynamic_imports.py | 6 ++-- lighter/utils/misc.py | 3 +- 5 files changed, 27 insertions(+), 41 deletions(-) diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index 6c57ed2d..90bf7be8 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -62,8 +62,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: stage (str): stage of the training process. Passed automatically by PyTorch Lightning. 
""" if trainer.logger is not None: - logger.error("When using LighterLogger, set Trainer(logger=None).") - sys.exit() + raise ValueError("When using LighterLogger, set Trainer(logger=None).") if not trainer.is_global_zero: return @@ -88,8 +87,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: if self.wandb: OPTIONAL_IMPORTS["wandb"], wandb_available = optional_import("wandb") if not wandb_available: - logger.error("Weights & Biases not installed. To install it, run `pip install wandb`. Exiting.") - sys.exit() + raise ImportError("Weights & Biases not installed. To install it, run `pip install wandb`.") wandb_dir = self.log_dir / "wandb" wandb_dir.mkdir() self.wandb = OPTIONAL_IMPORTS["wandb"].init(project=self.project, dir=wandb_dir, config=self.config) diff --git a/lighter/callbacks/writer/file.py b/lighter/callbacks/writer/file.py index f25bb995..e4e7c03b 100644 --- a/lighter/callbacks/writer/file.py +++ b/lighter/callbacks/writer/file.py @@ -91,7 +91,7 @@ def _write_sitk_image(path: str, tensor: torch.Tensor, suffix) -> None: if "sitk" not in OPTIONAL_IMPORTS: OPTIONAL_IMPORTS["sitk"], sitk_available = optional_import("SimpleITK") if not sitk_available: - raise ModuleNotFoundError("SimpleITK not installed. To install it, run `pip install SimpleITK`. Exiting.") + raise ImportError("SimpleITK is not available. Install it with `pip install SimpleITK`.") # Remove the channel dimension if it's equal to 1. tensor = tensor.squeeze(0) if (tensor.dim() == 4 and tensor.shape[0] == 1) else tensor diff --git a/lighter/system.py b/lighter/system.py index 4843e5dd..8049ca7b 100644 --- a/lighter/system.py +++ b/lighter/system.py @@ -25,28 +25,25 @@ class LighterSystem(pl.LightningModule): should be dropped. Defaults to False. num_workers (int, optional): number of dataloader workers. Defaults to 0. pin_memory (bool, optional): whether to pin the dataloaders memory. Defaults to True. - optimizer (Optional[Union[Optimizer, List[Optimizer]]], optional): - a single or a list of optimizers. Defaults to None. - scheduler (Optional[Union[Callable, List[Callable]]], optional): - a single or a list of schedulers. Defaults to None. - criterion (Optional[Callable], optional): - criterion/loss function. Defaults to None. - datasets (Optional[Dict[str, Optional[Dataset]]], optional): + optimizer (Optimizer, optional): optimizers. Defaults to None. + scheduler (LRScheduler, optional): learning rate scheduler. Defaults to None. + criterion (Callable, optional): criterion/loss function. Defaults to None. + datasets (Dict[str, Optional[Dataset]], optional): datasets for train, val, test, and predict. Supports Defaults to None. - samplers (Optional[Dict[str, Optional[Sampler]]], optional): + samplers (Dict[str, Optional[Sampler]], optional): samplers for train, val, test, and predict. Defaults to None. - collate_fns (Optional[Dict[str, Optional[Callable]]], optional): + collate_fns (Dict[str, Optional[Callable]], optional): collate functions for train, val, test, and predict. Defaults to None. - metrics (Optional[Dict[str, Optional[Union[Metric, List[Metric]]]]], optional): + metrics (Dict[str, Optional[Union[Metric, List[Metric]]]], optional): metrics for train, val, and test. Supports a single metric or a list of metrics, implemented using `torchmetrics`. Defaults to None. 
- postprocessing (Optional[Dict[str, Optional[Callable]]], optional): + postprocessing (Dict[str, Optional[Callable]], optional): Postprocessing functions for input, target, and pred, for three stages - criterion, metrics, and logging. The postprocessing is done before each stage - for example, criterion postprocessing will be done prior to loss calculation. Note that the postprocessing of a latter stage stacks on top of the previous one(s) - for example, the logging postprocessing will be done on the data that has been postprocessed for the criterion and metrics earlier. Defaults to None. - inferer (Optional[Callable], optional): the inferer must be a class with a `__call__` + inferer (Callable, optional): the inferer must be a class with a `__call__` method that accepts two arguments - the input to infer over, and the model itself. Used in 'val', 'test', and 'predict' mode, but not in 'train'. Typically, an inferer is a sliding window or a patch-based inferer that will infer over the smaller parts of @@ -61,8 +58,8 @@ def __init__( drop_last_batch: bool = False, num_workers: int = 0, pin_memory: bool = True, - optimizer: Optional[Union[Optimizer, List[Optimizer]]] = None, - scheduler: Optional[Union[Callable, List[Callable]]] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional["LRScheduler"] = None, criterion: Optional[Callable] = None, datasets: Optional[Dict[str, Optional[Dataset]]] = None, samplers: Optional[Dict[str, Optional[Sampler]]] = None, @@ -82,8 +79,8 @@ def __init__( # Criterion, optimizer, and scheduler self.criterion = criterion - self.optimizer = ensure_list(optimizer) - self.scheduler = ensure_list(scheduler) + self.optimizer = optimizer + self.scheduler = scheduler # DataLoader specifics self.num_workers = num_workers @@ -307,24 +304,18 @@ def _base_dataloader(self, mode: str) -> DataLoader: collate_fn=collate_fn, ) - def configure_optimizers(self) -> Union[Optimizer, List[Dict[str, Union[Optimizer, "Scheduler"]]]]: + def configure_optimizers(self) -> Union[Optimizer, Tuple[Optimizer, "LRScheduler"]]: """LightningModule method. Returns optimizers and, if defined, schedulers. Returns: - Optimizer or a List of Dict of paired Optimizers and Schedulers: instantiated - optimizers and/or schedulers. + Optimizer or a tuple of Optimizer and LRScheduler: the optimizer and, if defined, the scheduler. """ - if not self.optimizer: - logger.error("Please specify 'system.optimizer' in the config. Exiting.") - sys.exit() - if not self.scheduler: - return self.optimizer - - if len(self.optimizer) != len(self.scheduler): - logger.error("Each optimizer must have its own scheduler.") - sys.exit() - - return [{"optimizer": opt, "lr_scheduler": sched} for opt, sched in zip(self.optimizer, self.scheduler)] + if self.optimizer is None: + raise ValueError("Please specify 'system.optimizer' in the config.") + if self.scheduler is None: + return {"optimizer": self.optimizer} + else: + return {"optimizer": self.optimizer, "lr_scheduler": self.scheduler} def setup(self, stage: str) -> None: """Automatically called by the LightningModule after the initialization. diff --git a/lighter/utils/dynamic_imports.py b/lighter/utils/dynamic_imports.py index 155decf5..be690ea6 100644 --- a/lighter/utils/dynamic_imports.py +++ b/lighter/utils/dynamic_imports.py @@ -19,13 +19,11 @@ def import_module_from_path(module_name: str, module_path: str) -> None: # Based on https://stackoverflow.com/a/41595552. 
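For reference, the single-optimizer contract that `configure_optimizers` now returns can be reproduced in isolation. A sketch with a toy model; only the dictionary shape mirrors the change above:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Without a scheduler, only the optimizer is returned:
optimizer_only = {"optimizer": optimizer}
# With a scheduler, PyTorch Lightning expects it under the "lr_scheduler" key:
optimizer_and_scheduler = {"optimizer": optimizer, "lr_scheduler": scheduler}
```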
if module_name in sys.modules: - logger.error(f"{module_path} has already been imported as module: {module_name}") - sys.exit() + raise ValueError(f"{module_name} has already been imported as module.") module_path = Path(module_path).resolve() / "__init__.py" if not module_path.is_file(): - logger.error(f"No `__init__.py` in `{module_path}`. Exiting.") - sys.exit() + raise FileNotFoundError(f"No `__init__.py` in `{module_path}`.") spec = importlib.util.spec_from_file_location(module_name, str(module_path)) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) diff --git a/lighter/utils/misc.py b/lighter/utils/misc.py index 1eae7cbe..74672022 100644 --- a/lighter/utils/misc.py +++ b/lighter/utils/misc.py @@ -63,8 +63,7 @@ def setattr_dot_notation(obj: Callable, attr: str, value: Any): """ if "." not in attr: if not hasattr(obj, attr): - logger.info(f"`{get_name(obj, True)}` has no attribute `{attr}`. Exiting.") - sys.exit() + raise AttributeError(f"`{get_name(obj, True)}` has no attribute `{attr}`.") setattr(obj, attr, value) # Solve recursively if the attribute is defined in dot-notation else: From f8d689b0331cf11d52b240b0ccead257317a2fd8 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Mon, 14 Aug 2023 21:34:20 -0400 Subject: [PATCH 06/20] Update configure_optimizers docstring --- lighter/system.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lighter/system.py b/lighter/system.py index 8049ca7b..a2dc322e 100644 --- a/lighter/system.py +++ b/lighter/system.py @@ -304,11 +304,11 @@ def _base_dataloader(self, mode: str) -> DataLoader: collate_fn=collate_fn, ) - def configure_optimizers(self) -> Union[Optimizer, Tuple[Optimizer, "LRScheduler"]]: + def configure_optimizers(self) -> Dict: """LightningModule method. Returns optimizers and, if defined, schedulers. Returns: - Optimizer or a tuple of Optimizer and LRScheduler: the optimizer and, if defined, the scheduler. + Dict: optimizer and, if defined, scheduler. """ if self.optimizer is None: raise ValueError("Please specify 'system.optimizer' in the config.") From 57b44479ab18eaedc07c1a623d9a2cf34c1134ed Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Tue, 15 Aug 2023 10:10:21 -0400 Subject: [PATCH 07/20] Fix index ID issue in DDP writing. Replace broadcast with gather in the TableWriter. --- lighter/callbacks/writer/base.py | 24 ++++++++++++++++++------ lighter/callbacks/writer/table.py | 23 ++++++++++++++--------- lighter/system.py | 2 +- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/lighter/callbacks/writer/base.py b/lighter/callbacks/writer/base.py index aa6c7cf8..3bcb44c2 100644 --- a/lighter/callbacks/writer/base.py +++ b/lighter/callbacks/writer/base.py @@ -43,8 +43,8 @@ def __init__( # Placeholder for processed format for quicker access during writes self._parsed_format = None - # Keeps track of last written prediction index for cases when ids aren't provided - self._current_pred_index = 0 + # When IDs are not provided, keep track of the global prediction count. Supports DDP. 
+ self._pred_count = None # Ensure that default writers are defined if not hasattr(self, "_writers"): @@ -86,10 +86,17 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: if stage != "predict": return + # Initialize the prediction count with the rank of the current process + self._pred_count = torch.distributed.get_rank() + # Ensure all distributed nodes write to the same directory - self.directory = trainer.strategy.broadcast(self.directory) + self.directory = trainer.strategy.broadcast(self.directory, src=0) if trainer.is_global_zero: self.directory.mkdir(parents=True) + # Wait for rank 0 to create the directory + trainer.strategy.barrier() + + # Ensure all distributed nodes have access to the directory if not self.directory.exists(): raise RuntimeError( f"Rank {trainer.global_rank} does not share storage with rank 0. Ensure nodes have common storage access." @@ -104,10 +111,15 @@ def on_predict_batch_end( # Fetch and decollate IDs if provided. if outputs["id"] is not None: ids = decollate_batch(outputs["id"], detach=True, pad=False) - # Generate IDs if not provided. An ID will be the index of the prediction. + # Generate IDs if not provided. An ID will be the global index of the prediction. else: - ids = list(range(self._current_pred_index, self._current_pred_index + len(preds))) - self._current_pred_index += len(preds) + ids = [] + for _ in range(len(preds)): + # Append the current prediction count to the IDs list. + ids.append(self._pred_count) + # Increment the prediction count by the total number of DDP processes. + # This ensures each process will generate unique IDs in the next batch. + self._pred_count += trainer.world_size # Iterate over the predictions and save them. for id, pred in zip(ids, preds): diff --git a/lighter/callbacks/writer/table.py b/lighter/callbacks/writer/table.py index 0feace56..0ae6b33c 100644 --- a/lighter/callbacks/writer/table.py +++ b/lighter/callbacks/writer/table.py @@ -4,6 +4,7 @@ from pathlib import Path import pandas as pd +import torch from pytorch_lightning import Trainer from lighter import LighterSystem @@ -80,17 +81,21 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> No csv_path = self.directory / "predictions.csv" # Sort the records by ID and convert the dictionary to a list - self.csv_records = [self.csv_records[key] for key in sorted(self.csv_records)] + self.csv_records = [self.csv_records[id] for id in sorted(self.csv_records)] - # If in distributed data parallel mode, gather records from all processes + # If in distributed data parallel mode, gather records from all processes to rank 0. if trainer.world_size > 1: - ddp_csv_records = [self.csv_records] * trainer.world_size - for rank in range(trainer.world_size): - ddp_csv_records[rank] = trainer.strategy.broadcast(ddp_csv_records[rank], src=rank) - self.csv_records = list(itertools.chain(*ddp_csv_records)) - - # Convert the list of records to a dataframe and save it as a CSV file - pd.DataFrame(self.csv_records).to_csv(csv_path) + # Create a list to hold the records from each process. Used on rank 0 only. + gather_csv_records = [None] * trainer.world_size if trainer.is_global_zero else None + # Each process sends its records to rank 0, which stores them in the `gather_csv_records`. 
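The ID arithmetic above is worth spelling out: each rank starts counting at its own rank and advances by the world size, so the generated IDs interleave across DDP processes without collisions. A standalone sketch of the counting scheme (not library code):

```python
def generate_batch_ids(pred_count: int, batch_size: int, world_size: int):
    """Mimic the scheme above: start at the rank, step by the world size."""
    ids = [pred_count + i * world_size for i in range(batch_size)]
    return ids, pred_count + batch_size * world_size

# With 2 DDP processes: rank 0 yields [0, 2, 4] and rank 1 yields [1, 3, 5].
ids, next_count = generate_batch_ids(pred_count=1, batch_size=3, world_size=2)
assert ids == [1, 3, 5] and next_count == 7
```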
+ torch.distributed.gather_object(self.csv_records, gather_csv_records, dst=0) + # Concatenate the gathered records + if trainer.is_global_zero: + self.csv_records = list(itertools.chain(*gather_csv_records)) + + # Save the records to a CSV file + if trainer.is_global_zero: + pd.DataFrame(self.csv_records).to_csv(csv_path) def write_tensor(tensor: Any) -> List: diff --git a/lighter/system.py b/lighter/system.py index a2dc322e..ab27645e 100644 --- a/lighter/system.py +++ b/lighter/system.py @@ -166,7 +166,7 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un if len(batch) != 2: raise ValueError( f"A batch must consist of 2 elements - input and target. However, {len(batch)} " - "elements wer found. Note: if target does not exist, return `None` as target." + "elements were found. Note: if target does not exist, return `None` as target." ) batch = {"input": batch[0], "target": batch[1], "id": None} # - Other types are not allowed. From 605764aa0f4db69e74d403003a220afcd1ea93a8 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Tue, 15 Aug 2023 10:18:41 -0400 Subject: [PATCH 08/20] Add missing if DDP check --- lighter/callbacks/writer/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lighter/callbacks/writer/base.py b/lighter/callbacks/writer/base.py index 3bcb44c2..9628e1e9 100644 --- a/lighter/callbacks/writer/base.py +++ b/lighter/callbacks/writer/base.py @@ -87,7 +87,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: return # Initialize the prediction count with the rank of the current process - self._pred_count = torch.distributed.get_rank() + self._pred_count = torch.distributed.get_rank() if trainer.world_size > 1 else 0 # Ensure all distributed nodes write to the same directory self.directory = trainer.strategy.broadcast(self.directory, src=0) From ae1a452d35b5ed8fc2a5b0e0e71e0509e3e96b3d Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Tue, 15 Aug 2023 20:55:25 -0400 Subject: [PATCH 09/20] Update docstrings, rename and refactor parse_data --- lighter/callbacks/logger.py | 4 +- lighter/callbacks/utils.py | 74 +++++++++++-------------------- lighter/callbacks/writer/base.py | 44 +++++++++++------- lighter/callbacks/writer/table.py | 15 +++---- 4 files changed, 64 insertions(+), 73 deletions(-) diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index 90bf7be8..3778dcda 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -12,7 +12,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from lighter import LighterSystem -from lighter.callbacks.utils import get_lighter_mode, is_data_type_supported, parse_data, preprocess_image +from lighter.callbacks.utils import flatten_structure, get_lighter_mode, is_data_type_supported, preprocess_image from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS @@ -220,7 +220,7 @@ def _on_batch_end(self, outputs: Dict, trainer: Trainer) -> None: f"`{name}` has to be a Tensor, List[Tensor], Tuple[Tensor], Dict[str, Tensor], " f"Dict[str, List[Tensor]], or Dict[str, Tuple[Tensor]]. `{type(outputs[name])}` is not supported." 
) - for identifier, item in parse_data(outputs[name]).items(): + for identifier, item in flatten_structure(outputs[name]).items(): item_name = f"{mode}/data/{name}" if identifier is None else f"{mode}/data/{name}_{identifier}" self._log_by_type(item_name, item, self.log_types[name], global_step) diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py index a5af8781..cb9e9e3d 100644 --- a/lighter/callbacks/utils.py +++ b/lighter/callbacks/utils.py @@ -45,48 +45,56 @@ def is_data_type_supported(data: Union[Any, List[Any], Dict[str, Union[Any, List return is_valid -def parse_data( +def flatten_structure( data: Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]], prefix: Optional[str] = None ) -> Dict[Optional[str], Any]: """ - Parse the input data recursively, handling nested dictionaries, lists, and tuples. + Recursively parse nested data structures into a flat dictionary. - This function will recursively parse the input data, unpacking nested dictionaries, lists, and tuples. The result - will be a dictionary where each key is a unique identifier reflecting the data's original structure (dict keys - or list/tuple positions) and each value is a non-container data type from the input data. + This function flattens dictionaries, lists, and tuples, returning a dictionary where each key is constructed + from the original structure's keys or list/tuple indices. The values in the output dictionary are non-container + data types extracted from the input. Args: - data (Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]): Input data to parse. - prefix (Optional[str]): Current prefix for keys in the result dictionary. Defaults to None. + data (Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]): + The input data to parse. Can be of any data type but the function is optimized + to handle dictionaries, lists, and tuples. Nested structures are also supported. + + prefix (Optional[str]): + A prefix used when constructing keys for the output dictionary. Useful for recursive + calls to maintain context. Defaults to None. Returns: - Dict[Optional[str], Any]: A dictionary where key is either a string identifier or `None`, and value is the parsed output. + Dict[Optional[str], Any]: + A flattened dictionary where keys are unique identifiers built from the original data structure, + and values are non-container data extracted from the input. 
Example: input_data = { "a": [1, 2], "b": {"c": (3, 4), "d": 5} } - output_data = parse_data(input_data) - # Output: - # { - # 'a_0': 1, - # 'a_1': 2, - # 'b_c_0': 3, - # 'b_c_1': 4, - # 'b_d': 5 - # } + output_data = flatten_structure(input_data) + + Expected output: + { + 'a_0': 1, + 'a_1': 2, + 'b_c_0': 3, + 'b_c_1': 4, + 'b_d': 5 + } """ result = {} if isinstance(data, dict): for key, value in data.items(): # Recursively parse the value with an updated prefix - sub_result = parse_data(value, prefix=f"{prefix}_{key}" if prefix else key) + sub_result = flatten_structure(value, prefix=f"{prefix}_{key}" if prefix else key) result.update(sub_result) elif isinstance(data, (list, tuple)): for idx, element in enumerate(data): # Recursively parse the element with an updated prefix - sub_result = parse_data(element, prefix=f"{prefix}_{idx}" if prefix else str(idx)) + sub_result = flatten_structure(element, prefix=f"{prefix}_{idx}" if prefix else str(idx)) result.update(sub_result) else: # Assign the value to the result dictionary using the current prefix as its key @@ -94,34 +102,6 @@ def parse_data( return result -def parse_format(format: str, parsed_preds: Dict[str, Any]) -> Dict[str, str]: - """ - Parse the given format and align it with the structure of the predictions. - - If the format is a single string, all predictions will be saved in this format. If the format has a structure - (like a dictionary), it needs to match the structure of the predictions. - - Args: - format (str): The storage format for the predictions, either as a string or a structured format. - parsed_preds (Dict[str, Any]): Dictionary of parsed prediction data. - - Returns: - Dict[str, str]: Dictionary of parsed format data corresponding to the prediction structure. - - Raises: - ValueError: If the structure of the format does not align with the prediction structure. - """ - if isinstance(format, str): - # Assign the single format to all prediction keys. - parsed_format = {key: format for key in parsed_preds} - else: - # Ensure the structured format corresponds with the predictions' structure. - parsed_format = parse_data(format) - if not set(parsed_format) == set(parsed_preds): - raise ValueError("`format` structure does not match the prediction's structure.") - return parsed_format - - def preprocess_image(image: torch.Tensor, add_batch_dim=False) -> torch.Tensor: """Preprocess the image before logging it. If it is a batch of multiple images, it will create a grid image of them. In case of 3D, a single image is displayed diff --git a/lighter/callbacks/writer/base.py b/lighter/callbacks/writer/base.py index 9628e1e9..4537e4e4 100644 --- a/lighter/callbacks/writer/base.py +++ b/lighter/callbacks/writer/base.py @@ -12,7 +12,7 @@ from pytorch_lightning import Callback, Trainer from lighter import LighterSystem -from lighter.callbacks.utils import parse_data, parse_format +from lighter.callbacks.utils import flatten_structure class LighterBaseWriter(ABC, Callback): @@ -40,9 +40,6 @@ def __init__( self.directory = Path(directory) / datetime.now().strftime("%Y%m%d_%H%M%S") self.format = format - # Placeholder for processed format for quicker access during writes - self._parsed_format = None - # When IDs are not provided, keep track of the global prediction count. Supports DDP. self._pred_count = None @@ -68,8 +65,8 @@ def write( the batch dimension. If the batch dimension is needed, apply `tensor.unsqueeze(0)` before saving, either in this method or in the particular writer function. 
-        For each supported format, there should be a corresponding writer function registered in `self._writers`,
-        and can be retrieved using `self.get_writer(format)`.
+        For each supported format, there should be a corresponding writer function registered in `self._writers`.
+        A specific writer function can be retrieved using `self.get_writer(format)`.
 
         Args:
             tensor (torch.Tensor): Tensor to be saved. It will be a single tensor without the batch dimension.
@@ -82,7 +79,13 @@ def write(
         pass
 
     def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
-        """Callback for setup stage in Pytorch Lightning Trainer."""
+        """
+        Callback function to set up necessary prerequisites: prediction count and prediction directory.
+        When executing in a distributed environment, it ensures that:
+        1. Each distributed node initializes a prediction count based on its rank.
+        2. All distributed nodes write predictions to the same directory.
+        3. The directory is accessible to all nodes, i.e., all nodes share the same storage.
+        """
         if stage != "predict":
             return
 
@@ -105,7 +108,13 @@ def on_predict_batch_end(
         self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0
     ) -> None:
-        """Callback method triggered at the end of each prediction batch/step."""
+        """
+        Callback method executed at the end of each prediction batch/step.
+
+        It decollates the predicted outputs and, if provided, the associated IDs.
+        If the IDs are not provided, it generates globally unique IDs based on the prediction count.
+        Finally, it writes the predictions according to the specified format.
+        """
         # Fetch and decollate preds.
         preds = decollate_batch(outputs["pred"], detach=True, pad=False)
         # Fetch and decollate IDs if provided.
         if outputs["id"] is not None:
             ids = decollate_batch(outputs["id"], detach=True, pad=False)
         # Generate IDs if not provided. An ID will be the global index of the prediction.
         else:
             ids = []
             for _ in range(len(preds)):
                 # Append the current prediction count to the IDs list.
                 ids.append(self._pred_count)
                 # Increment the prediction count by the total number of DDP processes.
                 # This ensures each process will generate unique IDs in the next batch.
                 self._pred_count += trainer.world_size
 
         # Iterate over the predictions and save them.
         for id, pred in zip(ids, preds):
             # Convert predictions into a structured format suitable for writing.
-            parsed_pred = parse_data(pred)
-            # If the format hasn't been parsed yet, do it now.
-            if self._parsed_format is None:
-                self._parsed_format = parse_format(self.format, parsed_pred)
-            # If multiple outputs, parsed_pred will contain multiple keys. For a single output, key will be None.
-            for multi_pred_id, tensor in parsed_pred.items():
+            pred = flatten_structure(pred)
+
+            # If a single format is provided, assign it to all pred keys. Otherwise, the format must match the pred structure.
+            format = {key: self.format for key in pred} if isinstance(self.format, str) else flatten_structure(self.format)
+            # Ensure that the format structure matches the prediction structure.
+            if not set(format) == set(pred):
+                raise ValueError("`format` structure does not match the prediction's structure.")
+
+            # If pred is multi-output, there will be a `multi_pred_id` for each output.
+            # If single-output, `multi_pred_id` will be None.
+            for multi_pred_id, tensor in pred.items():
                 # Save the prediction as per the designated format.
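The fan-out rule above can also be read standalone: a single format string is assigned to every flattened prediction key, while a structured format must flatten to exactly the same keys. A sketch using `flatten_structure` from this patch (the tensors are placeholders):

```python
import torch

from lighter.callbacks.utils import flatten_structure

pred = flatten_structure({"seg": torch.zeros(2, 2), "cls": torch.zeros(2)})

format = "tensor"  # a single format string...
format = {key: format for key in pred} if isinstance(format, str) else flatten_structure(format)
assert set(format) == set(pred)  # ...fans out to {"seg": "tensor", "cls": "tensor"}
```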
-            self.write(tensor, id, multi_pred_id, format=self._parsed_format[multi_pred_id])
+            self.write(tensor, id, multi_pred_id, format=format[multi_pred_id])
 
     def add_writer(self, format: str, writer_function: Callable) -> None:
         """
diff --git a/lighter/callbacks/writer/table.py b/lighter/callbacks/writer/table.py
index 0ae6b33c..f65e399b 100644
--- a/lighter/callbacks/writer/table.py
+++ b/lighter/callbacks/writer/table.py
@@ -42,10 +42,8 @@ def __init__(
     def write(self, tensor: Any, id: Union[int, str], multi_pred_id: Optional[Union[int, str]], format: str) -> None:
         """
         Write the tensor as a table record in the given format.
-
-        If there are multiple predictions, there will be a separate column for each prediction,
-        named after the corresponding `multi_pred_id`.
-        If single prediction, there will be a single column named "pred".
+        If there are multiple predictions, there will be a separate column for each prediction, named after the
+        corresponding `multi_pred_id`. If there is a single prediction, there will be a single column named "pred".
 
         Args:
             tensor (Any): The tensor to be written.
@@ -70,12 +68,11 @@ def write(self, tensor: Any, id: Union[int, str], multi_pred_id: Optional[Union[
 
     def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
         """
-        Callback method triggered at the end of the prediction epoch to dump the CSV table.
+        Callback invoked at the end of the prediction epoch to save predictions to a CSV file.
 
-        Args:
-            trainer (Trainer): Pytorch Lightning Trainer instance.
-            pl_module (LighterSystem): Lighter system instance.
-            outputs (List[Any]): List of predictions.
+        This method is responsible for organizing prediction records and saving them as a CSV file.
+        If running in a distributed setting, it gathers predictions from all processes
+        and then saves them from the rank 0 process.
""" # Set the path where the CSV will be saved csv_path = self.directory / "predictions.csv" From cb63a9eda92b2535367f88c726cc806553c0f055 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Tue, 15 Aug 2023 21:01:57 -0400 Subject: [PATCH 10/20] Add freezer to init file --- lighter/callbacks/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lighter/callbacks/__init__.py b/lighter/callbacks/__init__.py index 78d3fcf5..360ba889 100644 --- a/lighter/callbacks/__init__.py +++ b/lighter/callbacks/__init__.py @@ -1,3 +1,4 @@ +from .freezer import LighterFreezer from .logger import LighterLogger from .writer.file import LighterFileWriter from .writer.table import LighterTableWriter From 736d9f60865acef736281957d50e63bfed972157 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Tue, 15 Aug 2023 21:16:22 -0400 Subject: [PATCH 11/20] Change property to attribute --- lighter/callbacks/writer/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lighter/callbacks/writer/base.py b/lighter/callbacks/writer/base.py index 4537e4e4..fb070462 100644 --- a/lighter/callbacks/writer/base.py +++ b/lighter/callbacks/writer/base.py @@ -45,7 +45,7 @@ def __init__( # Ensure that default writers are defined if not hasattr(self, "_writers"): - raise NotImplementedError("Subclasses of LighterBaseWriter must implement the `_writers` property.") + raise NotImplementedError("Subclasses of LighterBaseWriter must implement the `_writers` attribute.") # Register any additional writers passed during initialization if additional_writers: From fe7693aa1a6d7611ad9d6caa04dae6d033c6f964 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Wed, 16 Aug 2023 13:21:01 -0400 Subject: [PATCH 12/20] Add support for dict metrics. Refactor system. --- lighter/callbacks/logger.py | 1 - lighter/callbacks/writer/base.py | 1 - lighter/system.py | 151 ++++++++++++++++++------------- lighter/utils/cli.py | 1 - lighter/utils/misc.py | 1 - 5 files changed, 86 insertions(+), 69 deletions(-) diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index 3778dcda..2bd99ec4 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -1,7 +1,6 @@ from typing import Any, Dict, Union import itertools -import sys from datetime import datetime from pathlib import Path diff --git a/lighter/callbacks/writer/base.py b/lighter/callbacks/writer/base.py index fb070462..065c705b 100644 --- a/lighter/callbacks/writer/base.py +++ b/lighter/callbacks/writer/base.py @@ -1,7 +1,6 @@ from typing import Any, Callable, Dict, List, Optional, Union import itertools -import sys from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path diff --git a/lighter/system.py b/lighter/system.py index ab27645e..15be957e 100644 --- a/lighter/system.py +++ b/lighter/system.py @@ -1,54 +1,51 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import sys from functools import partial import pytorch_lightning as pl import torch from loguru import logger -from torch.nn import Module +from torch.nn import Module, ModuleDict from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader, Dataset, Sampler from torchmetrics import Metric, MetricCollection from lighter.utils.collate import collate_replace_corrupted -from lighter.utils.misc import apply_fns, ensure_dict_schema, ensure_list, get_name, hasarg +from lighter.utils.misc import apply_fns, ensure_dict_schema, get_name, hasarg class 
LighterSystem(pl.LightningModule):
     """Lighter's LightningModule: wires together the model, criterion, optimizer, scheduler, dataloaders,
     metrics, and postprocessing.
 
     Args:
-        model (Module): the model.
-        batch_size (int): batch size.
-        drop_last_batch (bool, optional): whether the last batch in the dataloader
-            should be dropped. Defaults to False.
-        num_workers (int, optional): number of dataloader workers. Defaults to 0.
-        pin_memory (bool, optional): whether to pin the dataloaders memory. Defaults to True.
-        optimizer (Optimizer, optional): optimizers. Defaults to None.
-        scheduler (LRScheduler, optional): learning rate scheduler. Defaults to None.
-        criterion (Callable, optional): criterion/loss function. Defaults to None.
-        datasets (Dict[str, Optional[Dataset]], optional):
-            datasets for train, val, test, and predict. Supports Defaults to None.
-        samplers (Dict[str, Optional[Sampler]], optional):
-            samplers for train, val, test, and predict. Defaults to None.
-        collate_fns (Dict[str, Optional[Callable]], optional):
-            collate functions for train, val, test, and predict. Defaults to None.
-        metrics (Dict[str, Optional[Union[Metric, List[Metric]]]], optional):
-            metrics for train, val, and test. Supports a single metric or a list of metrics,
-            implemented using `torchmetrics`. Defaults to None.
+        model (Module): The model.
+        batch_size (int): Batch size.
+        drop_last_batch (bool, optional): Whether the last batch in the dataloader should be dropped. Defaults to False.
+        num_workers (int, optional): Number of dataloader workers. Defaults to 0.
+        pin_memory (bool, optional): Whether to pin the dataloaders' memory. Defaults to True.
+        optimizer (Optimizer, optional): Optimizer. Defaults to None.
+        scheduler (LRScheduler, optional): Learning rate scheduler. Defaults to None.
+        criterion (Callable, optional): Criterion/loss function. Defaults to None.
+        datasets (Dict[str, Dataset], optional): Datasets for train, val, test, and predict. Defaults to None.
+        samplers (Dict[str, Sampler], optional): Samplers for train, val, test, and predict. Defaults to None.
+        collate_fns (Dict[str, Union[Callable, List[Callable]]], optional):
+            Collate functions for train, val, test, and predict. Defaults to None.
+        metrics (Dict[str, Union[Metric, List[Metric], Dict[str, Metric]]], optional):
+            Metrics for train, val, and test. Supports a single metric or a list/dict of `torchmetrics` metrics.
+            Defaults to None.
-        postprocessing (Dict[str, Optional[Callable]], optional):
+        postprocessing (Dict[str, Union[Callable, List[Callable]]], optional):
             Postprocessing functions for input, target, and pred, for three stages - criterion, metrics, and logging.
             The postprocessing is done before each stage - for example, criterion postprocessing will be done prior to
             loss calculation. Note that the postprocessing of a latter stage stacks on top of the previous one(s) -
             for example, the logging postprocessing will be done on the data that has been postprocessed for the
             criterion and metrics earlier. Defaults to None.
-        inferer (Callable, optional): the inferer must be a class with a `__call__`
-            method that accepts two arguments - the input to infer over, and the model itself.
-            Used in 'val', 'test', and 'predict' mode, but not in 'train'. Typically, an inferer
-            is a sliding window or a patch-based inferer that will infer over the smaller parts of
-            the input, combine them, and return a single output. The inferers provided by MONAI
-            cover most of such cases (https://docs.monai.io/en/stable/inferers.html). Defaults to None.
+ inferer (Callable, optional): The inferer must be a class with a `__call__` method that accepts two + arguments - the input to infer over, and the model itself. Used in 'val', 'test', and 'predict' + mode, but not in 'train'. Typically, an inferer is a sliding window or a patch-based inferer + that will infer over the smaller parts of the input, combine them, and return a single output. + The inferers provided by MONAI cover most of such cases (https://docs.monai.io/en/stable/inferers.html). + Defaults to None. """ def __init__( @@ -59,13 +56,13 @@ def __init__( num_workers: int = 0, pin_memory: bool = True, optimizer: Optional[Optimizer] = None, - scheduler: Optional["LRScheduler"] = None, + scheduler: Optional[LRScheduler] = None, criterion: Optional[Callable] = None, - datasets: Optional[Dict[str, Optional[Dataset]]] = None, - samplers: Optional[Dict[str, Optional[Sampler]]] = None, - collate_fns: Optional[Dict[str, Optional[Callable]]] = None, - metrics: Optional[Dict[str, Optional[Union[Metric, List[Metric]]]]] = None, - postprocessing: Optional[Dict[str, Optional[Callable]]] = None, + datasets: Dict[str, Dataset] = None, + samplers: Dict[str, Sampler] = None, + collate_fns: Dict[str, Union[Callable, List[Callable]]] = None, + metrics: Dict[str, Union[Metric, List[Metric], Dict[str, Metric]]] = None, + postprocessing: Dict[str, Union[Callable, List[Callable]]] = None, inferer: Optional[Callable] = None, ) -> None: super().__init__() @@ -87,25 +84,15 @@ def __init__( self.pin_memory = pin_memory # Datasets, samplers, and collate functions - schema = {"train": None, "val": None, "test": None, "predict": None} - self.datasets = ensure_dict_schema(datasets, schema) - self.samplers = ensure_dict_schema(samplers, schema) - self.collate_fns = ensure_dict_schema(collate_fns, schema) + self.datasets = self._init_datasets(datasets) + self.samplers = self._init_samplers(samplers) + self.collate_fns = self._init_collate_fns(collate_fns) # Metrics - self.metrics = ensure_dict_schema(metrics, schema={"train": None, "val": None, "test": None}) - self.metrics = {mode: MetricCollection(ensure_list(metric)) for mode, metric in self.metrics.items()} - # Register the metrics to allow the LightningModule to automatically move them to the correct device. - # Currently, a workaround is needed because of https://github.com/pytorch/pytorch/issues/71203. - # Once it's fixed, we can set `self.metrics = ModuleDict(self.metrics)` directly. - for mode, mode_metrics in self.metrics.items(): - setattr(self, f"{mode}_metric", mode_metrics) - self.metrics[mode] = getattr(self, f"{mode}_metric") + self.metrics = self._init_metrics(metrics) # Postprocessing - schema = {"input": None, "target": None, "pred": None} - schema = {"criterion": schema, "metrics": schema, "logging": schema} - self.postprocessing = ensure_dict_schema(postprocessing, schema) + self.postprocessing = self._init_postprocessing(postprocessing) # Inferer for val, test, and predict self.inferer = inferer @@ -194,6 +181,17 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un # Calculate the loss. loss = self._calculate_loss(pred, target) if mode in ["train", "val"] else None + # Log the loss for monitoring purposes. + if loss is not None: + self.log( + "loss" if mode == "train" else f"{mode}_loss", + loss, + on_step=True, + on_epoch=True, + sync_dist=True, + logger=False, + batch_size=self.batch_size, + ) # Log and return the results. 
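With the `_init_*` delegation introduced above, wiring a system looks roughly as follows. A minimal sketch only; the toy model, dataset, and metric are placeholders rather than a meaningful experiment:

```python
import torch
from torch.utils.data import TensorDataset
from torchmetrics import Accuracy

from lighter import LighterSystem

model = torch.nn.Linear(4, 1)
system = LighterSystem(
    model=model,
    batch_size=4,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    criterion=torch.nn.BCEWithLogitsLoss(),
    datasets={"train": TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8, 1)).float())},
    metrics={"train": Accuracy(task="binary")},
)
```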
if mode == "predict":
@@ -206,21 +204,12 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un
             target = apply_fns(target, self.postprocessing["metrics"]["target"])
             pred = apply_fns(pred, self.postprocessing["metrics"]["pred"])
 
-            # Calculate the metrics for the step.
-            metrics = self.metrics[mode](pred, target)
-
-            # Log the metrics for monitoring purposes.
-            self.log_dict(metrics, on_step=True, on_epoch=True, sync_dist=True, logger=False, batch_size=self.batch_size)
-            # Log the loss for monitoring purposes.
-            self.log(
-                "loss" if mode == "train" else f"{mode}_loss",
-                loss,
-                on_step=True,
-                on_epoch=True,
-                sync_dist=True,
-                logger=False,
-                batch_size=self.batch_size,
-            )
+            # Calculate the step metrics.
+            # TODO: Remove the "_" prefix when fixed https://github.com/pytorch/pytorch/issues/71203
+            metrics = self.metrics["_" + mode](pred, target) if self.metrics["_" + mode] is not None else None
+            # Log the metrics.
+            if metrics is not None:
+                self.log_dict(metrics, on_step=True, on_epoch=True, sync_dist=True, logger=False, batch_size=self.batch_size)
 
             # Data postprocessing for logging.
             input = apply_fns(input, self.postprocessing["logging"]["input"])
@@ -363,7 +352,10 @@ def setup(self, stage: str) -> None:
         self.predict_step = partial(self._base_step, mode="predict")
 
     def _init_placeholders_for_dataloader_and_step_methods(self) -> None:
-        """`LighterSystem` dynamically defines the `..._dataloader()`and `..._step()` methods
+        """
+        Initializes placeholders for dataloader and step methods.
+
+        `LighterSystem` dynamically defines the `..._dataloader()` and `..._step()` methods
         in the `self.setup()` method. However, `LightningModule` expects them to be
         defined at init. To prevent it from throwing an error, the `..._dataloader()` and
         `..._step()` are initially defined as `lambda: None`, before `self.setup()` is called.
@@ -372,3 +364,32 @@ def _init_placeholders_for_dataloader_and_step_methods(self) -> None:
         self.val_dataloader = self.validation_step = lambda: None
         self.test_dataloader = self.test_step = lambda: None
         self.predict_dataloader = self.predict_step = lambda: None
+
+    def _init_datasets(self, datasets: Dict[str, Optional[Dataset]]):
+        """Ensures that the datasets have the predefined schema."""
+        return ensure_dict_schema(datasets, {"train": None, "val": None, "test": None, "predict": None})
+
+    def _init_samplers(self, samplers: Dict[str, Optional[Sampler]]):
+        """Ensures that the samplers have the predefined schema."""
+        return ensure_dict_schema(samplers, {"train": None, "val": None, "test": None, "predict": None})
+
+    def _init_collate_fns(self, collate_fns: Dict[str, Optional[Callable]]):
+        """Ensures that the collate functions have the predefined schema."""
+        return ensure_dict_schema(collate_fns, {"train": None, "val": None, "test": None, "predict": None})
+
+    def _init_metrics(self, metrics: Dict[str, Optional[Union[Metric, List[Metric], Dict[str, Metric]]]]):
+        """Ensures that the metrics have the desired schema. Wraps each mode's metrics in
+        a MetricCollection, and finally registers them with PyTorch using a ModuleDict.
+ """ + metrics = ensure_dict_schema(metrics, {"train": None, "val": None, "test": None}) + for mode, metric in metrics.items(): + metrics[mode] = MetricCollection(metric) if metric is not None else None + # TODO: Remove the prefix addition line below when fixed https://github.com/pytorch/pytorch/issues/71203 + metrics = {f"_{k}": v for k, v in metrics.items()} + return ModuleDict(metrics) + + def _init_postprocessing(self, postprocessing: Dict[str, Optional[Union[Callable, List[Callable]]]]): + """Ensures that the postprocessing functions have the predefined schema.""" + subschema = {"input": None, "target": None, "pred": None} + schema = {"criterion": subschema, "metrics": subschema, "logging": subschema} + return ensure_dict_schema(postprocessing, schema) diff --git a/lighter/utils/cli.py b/lighter/utils/cli.py index 0d942765..c74b31ec 100644 --- a/lighter/utils/cli.py +++ b/lighter/utils/cli.py @@ -1,4 +1,3 @@ -import sys from functools import partial import fire diff --git a/lighter/utils/misc.py b/lighter/utils/misc.py index 74672022..06328bb6 100644 --- a/lighter/utils/misc.py +++ b/lighter/utils/misc.py @@ -1,7 +1,6 @@ from typing import Any, Callable, Dict, List, Optional, Union import inspect -import sys from loguru import logger From 57bcd7a7765800b5ce04d953fa1cb3ecc3a91639 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Wed, 16 Aug 2023 13:34:59 -0400 Subject: [PATCH 13/20] Fix typos --- lighter/callbacks/freezer.py | 2 +- lighter/system.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lighter/callbacks/freezer.py b/lighter/callbacks/freezer.py index e4b54d5d..a5d91089 100644 --- a/lighter/callbacks/freezer.py +++ b/lighter/callbacks/freezer.py @@ -69,7 +69,7 @@ def on_test_batch_start( self._on_batch_start(trainer, pl_module) def on_predict_batch_start( - self, trainer: Trainer, pl_module: LighterSystem, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: Trainer, pl_module: LighterSystem, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> None: self._on_batch_start(trainer, pl_module) diff --git a/lighter/system.py b/lighter/system.py index 15be957e..43733027 100644 --- a/lighter/system.py +++ b/lighter/system.py @@ -378,7 +378,7 @@ def _init_collate_fns(self, collate_fns: Dict[str, Optional[Callable]]): return ensure_dict_schema(collate_fns, {"train": None, "val": None, "test": None, "predict": None}) def _init_metrics(self, metrics: Dict[str, Optional[Union[Metric, List[Metric], Dict[str, Metric]]]]): - """Ensures that the metrics have the desired schema. Wraps each mode's metrics in + """Ensures that the metrics have the predefined schema. Wraps each mode's metrics in a MetricCollection, and finally registers them with PyTorch using a ModuleDict. 
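The underscore prefix is needed because `ModuleDict` rejects keys that collide with existing `nn.Module` attributes such as `train`, which is what the linked PyTorch issue tracks. A standalone sketch of the workaround:

```python
import torch
from torch.nn import ModuleDict
from torchmetrics import Accuracy, MetricCollection

# ModuleDict({"train": ...}) raises a KeyError because nn.Module already
# defines `train`; hence the "_train"/"_val"/"_test" keys.
metrics = ModuleDict({"_train": MetricCollection([Accuracy(task="binary")])})
preds, target = torch.tensor([0.9, 0.2]), torch.tensor([1, 0])
print(metrics["_" + "train"](preds, target))
```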
""" metrics = ensure_dict_schema(metrics, {"train": None, "val": None, "test": None}) From 799719ef6f20e9c1d9247e181b0dd09208759dd7 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Wed, 16 Aug 2023 17:33:58 -0400 Subject: [PATCH 14/20] Remove unused imports --- lighter/callbacks/writer/base.py | 2 -- lighter/utils/collate.py | 3 +-- lighter/utils/misc.py | 4 +--- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/lighter/callbacks/writer/base.py b/lighter/callbacks/writer/base.py index 065c705b..31ae2bdd 100644 --- a/lighter/callbacks/writer/base.py +++ b/lighter/callbacks/writer/base.py @@ -1,12 +1,10 @@ from typing import Any, Callable, Dict, List, Optional, Union -import itertools from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path import torch -from loguru import logger from monai.data.utils import decollate_batch from pytorch_lightning import Callback, Trainer diff --git a/lighter/utils/collate.py b/lighter/utils/collate.py index b0bd8238..5dff93af 100644 --- a/lighter/utils/collate.py +++ b/lighter/utils/collate.py @@ -1,8 +1,7 @@ -from typing import Any, Callable, List +from typing import Any, Callable import random -import torch from torch.utils.data import DataLoader from torch.utils.data._utils.collate import collate_str_fn, default_collate_fn_map from torch.utils.data.dataloader import default_collate diff --git a/lighter/utils/misc.py b/lighter/utils/misc.py index 06328bb6..ae27e2bf 100644 --- a/lighter/utils/misc.py +++ b/lighter/utils/misc.py @@ -1,9 +1,7 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Union import inspect -from loguru import logger - def ensure_list(vals: Any) -> List: """Wrap the input into a list if it is not a list. If it is a None, return an empty list. From 21d84f3805a4ed774c19c8ffdee97c5e3d887eb3 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Fri, 18 Aug 2023 11:15:18 -0400 Subject: [PATCH 15/20] Update logger.py to support the temp ModuleDict fix --- lighter/callbacks/logger.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index 2bd99ec4..b75d5c31 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -257,7 +257,8 @@ def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: # Metrics # Get the torchmetrics. - metric_collection = pl_module.metrics[mode] + # TODO: Remove the "_" prefix when fixed https://github.com/pytorch/pytorch/issues/71203 + metric_collection = pl_module.metrics["_" + mode] if metric_collection is not None: # Compute the epoch metrics. metrics = metric_collection.compute() From c8eedea1c5abe75345a3d8154ad7b8b50499cb38 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Sun, 20 Aug 2023 02:21:44 +0200 Subject: [PATCH 16/20] Add continue to freezer and detach cpu to image logging --- lighter/callbacks/freezer.py | 13 ++++++++----- lighter/callbacks/logger.py | 1 + 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/lighter/callbacks/freezer.py b/lighter/callbacks/freezer.py index a5d91089..b2aba887 100644 --- a/lighter/callbacks/freezer.py +++ b/lighter/callbacks/freezer.py @@ -122,20 +122,23 @@ def _set_model_requires_grad(self, model: Union[Module, LighterSystem], requires # Leave the excluded-from-freezing parameters trainable. 
if self.except_names and name in self.except_names: param.requires_grad = True - elif self.except_name_starts_with and any(name.startswith(prefix) for prefix in self.except_name_starts_with): + continue + if self.except_name_starts_with and any(name.startswith(prefix) for prefix in self.except_name_starts_with): param.requires_grad = True + continue # Freeze/unfreeze the specified parameters, based on the `requires_grad` argument. - elif self.names and name in self.names: + if self.names and name in self.names: param.requires_grad = requires_grad frozen_layers.append(name) - elif self.name_starts_with and any(name.startswith(prefix) for prefix in self.name_starts_with): + continue + if self.name_starts_with and any(name.startswith(prefix) for prefix in self.name_starts_with): param.requires_grad = requires_grad frozen_layers.append(name) + continue # Otherwise, leave the parameter trainable. - else: - param.requires_grad = True + param.requires_grad = True self._frozen_state = not requires_grad # Log only when freezing the parameters. diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index b75d5c31..d9b48570 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -162,6 +162,7 @@ def _log_image(self, name: str, image: torch.Tensor, global_step: int) -> None: image (torch.Tensor): image to be logged. global_step (int): current global step. """ + image = image.detach().cpu() if self.tensorboard: self.tensorboard.add_image(name, image, global_step=global_step) if self.wandb: From 8874399ef7dc59b438c75560bbbc2f6a9a18b818 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Thu, 14 Sep 2023 16:38:44 -0400 Subject: [PATCH 17/20] Remove multi_pred, refactor Writer, Logger, and optional imports --- lighter/callbacks/logger.py | 142 ++++++++++++++--------------- lighter/callbacks/utils.py | 87 ------------------ lighter/callbacks/writer/base.py | 144 +++++++++--------------------- lighter/callbacks/writer/file.py | 70 ++++++--------- lighter/callbacks/writer/table.py | 63 ++++--------- lighter/utils/dynamic_imports.py | 72 ++++++++++----- 6 files changed, 203 insertions(+), 375 deletions(-) diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index d9b48570..32fb1935 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -6,12 +6,11 @@ import torch from loguru import logger -from monai.utils.module import optional_import from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from lighter import LighterSystem -from lighter.callbacks.utils import flatten_structure, get_lighter_mode, is_data_type_supported, preprocess_image +from lighter.callbacks.utils import get_lighter_mode, preprocess_image from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS @@ -74,8 +73,6 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: # Tensorboard initialization. if self.tensorboard: - # Tensorboard is a part of PyTorch, no need to check if it is not available. - OPTIONAL_IMPORTS["tensorboard"], _ = optional_import("torch.utils.tensorboard") tensorboard_dir = self.log_dir / "tensorboard" tensorboard_dir.mkdir() self.tensorboard = OPTIONAL_IMPORTS["tensorboard"].SummaryWriter(log_dir=tensorboard_dir) @@ -84,9 +81,6 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: # Wandb initialization. 
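Stepping back to the freezer change above: its matching rules reduce to an exact-name check and a prefix check, sketched standalone here with hypothetical parameter names:

```python
def matches(name: str, names=None, name_starts_with=None) -> bool:
    """Standalone sketch of the freezer's two matching rules."""
    if names and name in names:
        return True
    if name_starts_with and any(name.startswith(prefix) for prefix in name_starts_with):
        return True
    return False

assert matches("backbone.layer1.weight", name_starts_with=["backbone"])
assert not matches("head.weight", name_starts_with=["backbone"])
```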
if self.wandb: - OPTIONAL_IMPORTS["wandb"], wandb_available = optional_import("wandb") - if not wandb_available: - raise ImportError("Weights & Biases not installed. To install it, run `pip install wandb`.") wandb_dir = self.log_dir / "wandb" wandb_dir.mkdir() self.wandb = OPTIONAL_IMPORTS["wandb"].init(project=self.project, dir=wandb_dir, config=self.config) @@ -191,49 +185,44 @@ def _on_batch_end(self, outputs: Dict, trainer: Trainer) -> None: outputs (Dict): output dict from the model. trainer (Trainer): Trainer, passed automatically by PyTorch Lightning. """ - if not trainer.sanity_checking: - mode = get_lighter_mode(trainer.state.stage) - # Accumulate the loss. - if mode in ["train", "val"]: - self.loss[mode] += outputs["loss"].item() - # Logging frequency. Log only on rank 0. - if trainer.is_global_zero and self.global_step_counter[mode] % trainer.log_every_n_steps == 0: - # Get global step. - global_step = self._get_global_step(trainer) - - # Log loss. - if outputs["loss"] is not None: - self._log_scalar(f"{mode}/loss/step", outputs["loss"], global_step) - - # Log metrics. - if outputs["metrics"] is not None: - for name, metric in outputs["metrics"].items(): - self._log_scalar(f"{mode}/metrics/{name}/step", metric, global_step) - - # Log input, target, and pred. - for name in ["input", "target", "pred"]: - if self.log_types[name] is None: - continue - # Ensure data is of a valid type. - if not is_data_type_supported(outputs[name]): - raise ValueError( - f"`{name}` has to be a Tensor, List[Tensor], Tuple[Tensor], Dict[str, Tensor], " - f"Dict[str, List[Tensor]], or Dict[str, Tuple[Tensor]]. `{type(outputs[name])}` is not supported." - ) - for identifier, item in flatten_structure(outputs[name]).items(): - item_name = f"{mode}/data/{name}" if identifier is None else f"{mode}/data/{name}_{identifier}" - self._log_by_type(item_name, item, self.log_types[name], global_step) - - # Log learning rate stats. Logs at step if a scheduler's interval is step-based. - if mode == "train": - lr_stats = self.lr_monitor.get_stats(trainer, "step") - for name, value in lr_stats.items(): - self._log_scalar(f"{mode}/optimizer/{name}/step", value, global_step) - - # Increment the step counters. - self.global_step_counter[mode] += 1 - if mode in ["train", "val"]: - self.epoch_step_counter[mode] += 1 + if trainer.sanity_checking: + return + + mode = get_lighter_mode(trainer.state.stage) + + # Accumulate the loss. + if mode in ["train", "val"]: + self.loss[mode] += outputs["loss"].item() + + # Log only on rank 0 and according to the `log_every_n_steps` parameter. Otherwise, only increment the step counters. + if not trainer.is_global_zero or self.global_step_counter[mode] % trainer.log_every_n_steps != 0: + self._increment_step_counters(mode) + return + + global_step = self._get_global_step(trainer) + + # Loss. + if outputs["loss"] is not None: + self._log_scalar(f"{mode}/loss/step", outputs["loss"], global_step) + + # Metrics. + if outputs["metrics"] is not None: + for name, metric in outputs["metrics"].items(): + self._log_scalar(f"{mode}/metrics/{name}/step", metric, global_step) + + # Input, target, and pred. + for name in ["input", "target", "pred"]: + if self.log_types[name] is not None: + self._log_by_type(f"{mode}/data/{name}", outputs[name], self.log_types[name], global_step) + + # LR info. Logs at step if a scheduler's interval is step-based. 
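Note how the tensorboard and wandb setup above now indexes `OPTIONAL_IMPORTS` directly instead of calling `optional_import` inline; the refactored registry in `lighter/utils/dynamic_imports.py` presumably imports lazily on first access. A sketch of that pattern (an assumption about the implementation, not a copy of it):

```python
from monai.utils.module import optional_import


class OptionalImports:
    """Lazily import optional dependencies the first time they are accessed."""

    def __init__(self):
        self.imports = {}

    def __getitem__(self, module_name: str):
        if module_name not in self.imports:
            module, is_available = optional_import(module_name)
            if not is_available:
                raise ImportError(f"'{module_name}' is not installed. Run `pip install {module_name}`.")
            self.imports[module_name] = module
        return self.imports[module_name]


OPTIONAL_IMPORTS = OptionalImports()
```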
+ if mode == "train": + lr_stats = self.lr_monitor.get_stats(trainer, "step") + for name, value in lr_stats.items(): + self._log_scalar(f"{mode}/optimizer/{name}/step", value, global_step) + + # Increment the step counters. + self._increment_step_counters(mode) def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: """Performs logging at the end of an epoch. Logs the epoch number, the loss, and the metrics. @@ -247,16 +236,12 @@ def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: mode = get_lighter_mode(trainer.state.stage) loss, metrics = None, None - # Loss + # Get the accumulated loss over the epoch and processes. if mode in ["train", "val"]: - # Get the accumulated loss. loss = self.loss[mode] - # Reduce the loss and average it on each rank. loss = trainer.strategy.reduce(loss, reduce_op="mean") - # Divide the accumulated loss by the number of steps in the epoch. loss /= self.epoch_step_counter[mode] - # Metrics # Get the torchmetrics. # TODO: Remove the "_" prefix when fixed https://github.com/pytorch/pytorch/issues/71203 metric_collection = pl_module.metrics["_" + mode] @@ -266,28 +251,29 @@ def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: # Reset the metrics for the next epoch. metric_collection.reset() - # Log. Only on rank 0. - if trainer.is_global_zero: - # Get global step. - global_step = self._get_global_step(trainer) + # Log only on rank 0. + if not trainer.is_global_zero: + return - # Log epoch number. - self._log_scalar("epoch", trainer.current_epoch, global_step) + global_step = self._get_global_step(trainer) - # Log loss. - if loss is not None: - self._log_scalar(f"{mode}/loss/epoch", loss, global_step) + # Epoch number. + self._log_scalar("epoch", trainer.current_epoch, global_step) - # Log metrics. - if metrics is not None: - for name, metric in metrics.items(): - self._log_scalar(f"{mode}/metrics/{name}/epoch", metric, global_step) + # Loss. + if loss is not None: + self._log_scalar(f"{mode}/loss/epoch", loss, global_step) - # Log learning rate stats. Logs at epoch if a scheduler's interval is epoch-based, or if no scheduler is used. - if mode == "train": - lr_stats = self.lr_monitor.get_stats(trainer, "epoch") - for name, value in lr_stats.items(): - self._log_scalar(f"{mode}/optimizer/{name}/epoch", value, global_step) + # Metrics. + if metrics is not None: + for name, metric in metrics.items(): + self._log_scalar(f"{mode}/metrics/{name}/epoch", metric, global_step) + + # LR info. Logged at epoch if the scheduler's interval is epoch-based, or if no scheduler is used. + if mode == "train": + lr_stats = self.lr_monitor.get_stats(trainer, "epoch") + for name, value in lr_stats.items(): + self._log_scalar(f"{mode}/optimizer/{name}/epoch", value, global_step) def _get_global_step(self, trainer: Trainer) -> int: """Return the global step for the current mode. Note that when Trainer @@ -308,6 +294,16 @@ def _get_global_step(self, trainer: Trainer) -> int: return self.global_step_counter["train"] return self.global_step_counter[mode] + def _increment_step_counters(self, mode: str) -> None: + """Increment the global step and epoch step counters for the specified mode. + + Args: + mode (str): mode to increment the global step counter for. + """ + self.global_step_counter[mode] += 1 + if mode in ["train", "val"]: + self.epoch_step_counter[mode] += 1 + def on_train_epoch_start(self, trainer: Trainer, pl_module: LighterSystem) -> None: # Reset the loss and the epoch step counter for the next epoch. 
self.loss["train"] = 0 diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py index cb9e9e3d..8b8b39f0 100644 --- a/lighter/callbacks/utils.py +++ b/lighter/callbacks/utils.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - import torch import torchvision @@ -17,91 +15,6 @@ def get_lighter_mode(lightning_stage: str) -> str: return lightning_to_lighter[lightning_stage] -def is_data_type_supported(data: Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]) -> bool: - """ - Check the input data recursively for its type. Valid data types are: - - torch.Tensor - - List[torch.Tensor] - - Tuple[torch.Tensor] - - Dict[str, torch.Tensor] - - Dict[str, List[torch.Tensor]] - - Dict[str, Tuple[torch.Tensor]] - - Nested combinations of the above - - Args: - data (Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]): Input data to check. - - Returns: - bool: True if the data type is supported, False otherwise. - """ - if isinstance(data, dict): - is_valid = all(is_data_type_supported(elem) for elem in data.values()) - elif isinstance(data, (list, tuple)): - is_valid = all(is_data_type_supported(elem) for elem in data) - elif isinstance(data, torch.Tensor): - is_valid = True - else: - is_valid = False - return is_valid - - -def flatten_structure( - data: Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]], prefix: Optional[str] = None -) -> Dict[Optional[str], Any]: - """ - Recursively parse nested data structures into a flat dictionary. - - This function flattens dictionaries, lists, and tuples, returning a dictionary where each key is constructed - from the original structure's keys or list/tuple indices. The values in the output dictionary are non-container - data types extracted from the input. - - Args: - data (Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]): - The input data to parse. Can be of any data type but the function is optimized - to handle dictionaries, lists, and tuples. Nested structures are also supported. - - prefix (Optional[str]): - A prefix used when constructing keys for the output dictionary. Useful for recursive - calls to maintain context. Defaults to None. - - Returns: - Dict[Optional[str], Any]: - A flattened dictionary where keys are unique identifiers built from the original data structure, - and values are non-container data extracted from the input. - - Example: - input_data = { - "a": [1, 2], - "b": {"c": (3, 4), "d": 5} - } - output_data = flatten_structure(input_data) - - Expected output: - { - 'a_0': 1, - 'a_1': 2, - 'b_c_0': 3, - 'b_c_1': 4, - 'b_d': 5 - } - """ - result = {} - if isinstance(data, dict): - for key, value in data.items(): - # Recursively parse the value with an updated prefix - sub_result = flatten_structure(value, prefix=f"{prefix}_{key}" if prefix else key) - result.update(sub_result) - elif isinstance(data, (list, tuple)): - for idx, element in enumerate(data): - # Recursively parse the element with an updated prefix - sub_result = flatten_structure(element, prefix=f"{prefix}_{idx}" if prefix else str(idx)) - result.update(sub_result) - else: - # Assign the value to the result dictionary using the current prefix as its key - result[prefix] = data - return result - - def preprocess_image(image: torch.Tensor, add_batch_dim=False) -> torch.Tensor: """Preprocess the image before logging it. If it is a batch of multiple images, it will create a grid image of them. 
In case of 3D, a single image is displayed
diff --git a/lighter/callbacks/writer/base.py b/lighter/callbacks/writer/base.py
index 31ae2bdd..410fb873 100644
--- a/lighter/callbacks/writer/base.py
+++ b/lighter/callbacks/writer/base.py
@@ -1,15 +1,13 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Union
 
 from abc import ABC, abstractmethod
 from datetime import datetime
 from pathlib import Path
 
 import torch
-from monai.data.utils import decollate_batch
 from pytorch_lightning import Callback, Trainer
 
 from lighter import LighterSystem
-from lighter.callbacks.utils import flatten_structure
 
 
 class LighterBaseWriter(ABC, Callback):
@@ -17,63 +15,58 @@ class LighterBaseWriter(ABC, Callback):
     Base class for defining custom Writer. It provides the structure to save predictions in various formats.
 
     Subclasses should implement:
-    1) `self._writers` attribute to specify the supported formats and their corresponding writer functions.
+    1) `self.writers` property to specify the supported formats and their corresponding writer functions.
     2) `self.write()` method to specify the saving strategy for a prediction.
 
     Args:
         directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside.
-        format (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]):
-            Desired format(s) for saving predictions. The format will be passed to the `write` method.
-        additional_writers (Optional[Dict[str, Callable]]): Additional writer functions to be registered with the base writer.
+        writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function.
     """
 
-    def __init__(
-        self,
-        directory: str,
-        format: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]],
-        additional_writers: Optional[Dict[str, Callable]] = None,
-    ) -> None:
+    def __init__(self, directory: str, writer: Union[str, Callable]) -> None:
+        """
+        Initialize the LighterBaseWriter.
+
+        Args:
+            directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside.
+            writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function.
+        """
         # Create a unique directory using the current date and time
         self.directory = Path(directory) / datetime.now().strftime("%Y%m%d_%H%M%S")
-        self.format = format
 
-        # When IDs are not provided, keep track of the global prediction count. Supports DDP.
-        self._pred_count = None
+        # Check if the writer is a string and if it exists in the writers dictionary
+        if isinstance(writer, str):
+            if writer not in self.writers:
+                raise ValueError(f"Writer for format {writer} does not exist. Available writers: {self.writers.keys()}.")
+            self.writer = self.writers[writer]
+        else:
+            # If the writer is not a string, it is assumed to be a callable function
+            self.writer = writer
 
-        # Ensure that default writers are defined
-        if not hasattr(self, "_writers"):
-            raise NotImplementedError("Subclasses of LighterBaseWriter must implement the `_writers` attribute.")
+        # Prediction counter. Used when IDs are not provided. Initialized in `self.setup()` based on the DDP rank.
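+        # For example, with world_size=2 and batches of 4 predictions per rank, rank 0 starts at 0 and
+        # assigns the IDs 0, 2, 4, 6, while rank 1 starts at 1 and assigns the IDs 1, 3, 5, 7, so the
+        # generated IDs are interleaved across ranks and globally unique.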
+        self._pred_counter = None
 
-        # Register any additional writers passed during initialization
-        if additional_writers:
-            for format, writer_function in additional_writers.items():
-                self.add_writer(format, writer_function)
+    @property
+    @abstractmethod
+    def writers(self) -> Dict[str, Callable]:
+        """
+        Property to define the default writer functions.
+        """
 
     @abstractmethod
-    def write(
-        self,
-        tensor: torch.Tensor,
-        id: int,
-        multi_pred_id: Optional[str],
-        format: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]],
-    ) -> None:
+    def write(self, tensor: torch.Tensor, id: int) -> None:
         """
         Method to define how a tensor should be saved. The input tensor will be a single tensor without
         the batch dimension. If the batch dimension is needed, apply `tensor.unsqueeze(0)` before saving,
         either in this method or in the particular writer function.
 
-        For each supported format, there should be a corresponding writer function registered in `self._writers`
-        A specific writer function can be retrieved using `self.get_writer(format)`.
+        For each supported format, there should be a corresponding writer function registered in `self.writers`.
+        The writer function is selected at initialization and is available as `self.writer`.
 
         Args:
             tensor (torch.Tensor): Tensor to be saved. It will be a single tensor without the batch dimension.
             id (int): Identifier for the tensor, can be used for naming or indexing.
-            multi_pred_id (Optional[str]): Used when there are multiple predictions for a single input.
-                It can represent the index of a prediction, the key of a prediction in case of a dict,
-                or combined key and index for a dict of lists.
-            format (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): Format for saving the tensor.
         """
-        pass
 
     def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
         """
@@ -81,13 +74,13 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None
         When executing in a distributed environment, it ensures that:
         1. Each distributed node initializes a prediction count based on its rank.
         2. All distributed nodes write predictions to the same directory.
-        3. The directiory is accessible to all nodes, i.e. that all nodes share the same storage.
+        3. The directory is accessible to all nodes, i.e., all nodes share the same storage.
         """
         if stage != "predict":
             return
 
         # Initialize the prediction count with the rank of the current process
-        self._pred_count = torch.distributed.get_rank() if trainer.world_size > 1 else 0
+        self._pred_counter = torch.distributed.get_rank() if trainer.world_size > 1 else 0
 
         # Ensure all distributed nodes write to the same directory
         self.directory = trainer.strategy.broadcast(self.directory, src=0)
 
@@ -107,71 +100,16 @@
     ) -> None:
         """
         Callback method executed at the end of each prediction batch/step.
-
-        It decollates the predicted outputs and, if provided, the associated IDs. If the IDs are not provided, it generates global unique IDs based on the prediction count.
-        Finally, it writes the predictions according to the specified format.
+        If IDs are not provided, it generates them from the global prediction count, then writes the predictions using the specified writer.
         """
-        # Fetch and decollate preds.
-        preds = decollate_batch(outputs["pred"], detach=True, pad=False)
-        # Fetch and decollate IDs if provided.
-        if outputs["id"] is not None:
-            ids = decollate_batch(outputs["id"], detach=True, pad=False)
-        # Generate IDs if not provided. An ID will be the global index of the prediction.
-        else:
-            ids = []
-            for _ in range(len(preds)):
-                # Append the current prediction count to the IDs list.
-                ids.append(self._pred_count)
-                # Increment the prediction count by the total number of DDP processes.
-                # This ensures each process will generate unique IDs in the next batch.
-                self._pred_count += trainer.world_size
-
-        # Iterate over the predictions and save them.
-        for id, pred in zip(ids, preds):
-            # Convert predictions into a structured format suitable for writing.
-            pred = flatten_structure(pred)
-
-            # If a single format is provided, assign it to all pred keys. Otherwise, the format must match the pred structure.
-            format = {key: self.format for key in pred} if isinstance(self.format, str) else flatten_structure(self.format)
-            # Ensure that the format structure matches the prediction structure.
-            if not set(format) == set(pred):
-                raise ValueError("`format` structure does not match the prediction's structure.")
-
-            # If pred is multi-output, there will be a `multi_pred_id` for each output.
-            # If single-output, `multi_pred_id` will be None.
-            for multi_pred_id, tensor in pred.items():
-                # Save the prediction as per the designated format.
-                self.write(tensor, id, multi_pred_id, format=format[multi_pred_id])
-
-    def add_writer(self, format: str, writer_function: Callable) -> None:
-        """
-        Register a new writer function for a specified format.
-
-        Args:
-            format (str): Format type for which the writer is being registered.
-            writer_function (Callable): Function to write data in the given format.
-        Raises:
-            ValueError: If a writer for the given format is already registered.
-        """
-        if format in self._writers:
-            raise ValueError(f"Writer for format {format} already registered.")
-        self._writers[format] = writer_function
+        # If the IDs are not provided, generate globally unique IDs based on the prediction count. DDP is supported.
+        if outputs["id"] is None:
+            batch_size = len(outputs["pred"])
+            world_size = trainer.world_size
+            outputs["id"] = list(range(self._pred_counter, self._pred_counter + batch_size * world_size, world_size))
+            self._pred_counter += batch_size * world_size
 
-    def get_writer(self, format: str) -> Callable:
-        """
-        Retrieve the registered writer function for a specified format.
-
-        Args:
-            format (str): Format for which the writer function is needed.
-
-        Returns:
-            Callable: Registered writer function for the given format.
-
-        Raises:
-            ValueError: If no writer is registered for the specified format.
-        """
-        if format not in self._writers:
-            raise ValueError(f"Writer for format {format} not registered.")
-        return self._writers[format]
+        for id, pred in zip(outputs["id"], outputs["pred"]):
+            self.write(tensor=pred, id=id)
diff --git a/lighter/callbacks/writer/file.py b/lighter/callbacks/writer/file.py
index e4e7c03b..05534d3a 100644
--- a/lighter/callbacks/writer/file.py
+++ b/lighter/callbacks/writer/file.py
@@ -1,11 +1,13 @@
-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Dict, Union
 
+from functools import partial
 from pathlib import Path
 
+import monai
 import torch
 import torchvision
+from monai.data import metatensor_to_itk_image
 from monai.transforms import DivisiblePad
-from monai.utils.module import optional_import
 
 from lighter.callbacks.utils import preprocess_image
 from lighter.callbacks.writer.base import LighterBaseWriter
@@ -20,45 +22,38 @@ class LighterFileWriter(LighterBaseWriter):
 
     Args:
        directory (Union[str, Path]): The directory where the files should be written.
-        format (str): The format in which the files should be saved.
-        additional_writers (Optional[Dict[str, Callable]]): Additional custom writer functions.
+        writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function.
+            Available writers: "tensor", "image", "video", "itk_nrrd", "itk_seg_nrrd", "itk_nifti".
     """
 
-    def __init__(
-        self, directory: Union[str, Path], format: str, additional_writers: Optional[Dict[str, Callable]] = None
-    ) -> None:
-        # Predefined writers for different formats.
-        self._writers = {
+    def __init__(self, directory: Union[str, Path], writer: Union[str, Callable]) -> None:
+        super().__init__(directory, writer)
+
+    @property
+    def writers(self) -> Dict[str, Callable]:
+        return {
             "tensor": write_tensor,
             "image": write_image,
             "video": write_video,
-            "sitk_nrrd": write_sitk_nrrd,
-            "sitk_seg_nrrd": write_seg_nrrd,
-            "sitk_nifti": write_sitk_nifti,
+            "itk_nrrd": partial(write_itk_image, suffix=".nrrd"),
+            "itk_seg_nrrd": partial(write_itk_image, suffix=".seg.nrrd"),
+            "itk_nifti": partial(write_itk_image, suffix=".nii.gz"),
         }
-        # Initialize the base class.
-        super().__init__(directory, format, additional_writers)
 
-    def write(self, tensor: torch.Tensor, id: Union[int, str], multi_pred_id: Optional[Union[int, str]], format: str) -> None:
+    def write(self, tensor: torch.Tensor, id: Union[int, str]) -> None:
         """
         Write the tensor to the specified path in the given format.
 
-        If there are multiple predictions, a directory named `id` is created, and each file is named
-        after `multi_pred_id`. If there's a single prediction, the file is named after `id`.
-
         Args:
             tensor (Tensor): The tensor to be written.
-            id (Union[int, str]): The primary identifier for naming.
-            multi_pred_id (Optional[Union[int, str]]): The secondary identifier, used if there are multiple predictions.
-            format (str): Format in which tensor should be written.
+            id (Union[int, str]): The identifier for naming.
         """
         # Determine the path for the file based on prediction count. The suffix must be added by the writer function.
-        path = self.directory / str(id) if multi_pred_id is None else self.directory / str(id) / str(multi_pred_id)
+        path = self.directory / str(id)
         path.parent.mkdir(exist_ok=True, parents=True)
 
-        # Fetch the appropriate writer function for the format.
-        writer = self.get_writer(format)
         # Write the tensor to the file.
-        writer(path, tensor)
+        self.writer(path, tensor)
 
 
 def write_tensor(path, tensor):
@@ -85,27 +80,12 @@ def write_video(path, tensor):
     torchvision.io.write_video(str(path), tensor, fps=24)
 
 
-def _write_sitk_image(path: str, tensor: torch.Tensor, suffix) -> None:
+def write_itk_image(path: str, tensor: torch.Tensor, suffix) -> None:
     path = path.with_suffix(suffix)
 
-    if "sitk" not in OPTIONAL_IMPORTS:
-        OPTIONAL_IMPORTS["sitk"], sitk_available = optional_import("SimpleITK")
-        if not sitk_available:
-            raise ImportError("SimpleITK is not available. Install it with `pip install SimpleITK`.")
-
-    # Remove the channel dimension if it's equal to 1.
- tensor = tensor.squeeze(0) if (tensor.dim() == 4 and tensor.shape[0] == 1) else tensor - sitk_image = OPTIONAL_IMPORTS["sitk"].GetImageFromArray(tensor.cpu().numpy()) - OPTIONAL_IMPORTS["sitk"].WriteImage(sitk_image, str(path), useCompression=True) - - -def write_sitk_nrrd(path, tensor): - _write_sitk_image(path, tensor, suffix=".nrrd") - - -def write_seg_nrrd(path, tensor): - _write_sitk_image(path, tensor, suffix=".seg.nrrd") - + # TODO: Remove this code when fixed https://github.com/Project-MONAI/MONAI/issues/6985 + if tensor.meta["space"] == "RAS": + tensor.affine = monai.data.utils.orientation_ras_lps(tensor.affine) -def write_sitk_nifti(path, tensor): - _write_sitk_image(path, tensor, suffix=".nii.gz") + itk_image = metatensor_to_itk_image(tensor, channel_dim=0, dtype=tensor.dtype) + OPTIONAL_IMPORTS["itk"].imwrite(itk_image, str(path), True) diff --git a/lighter/callbacks/writer/table.py b/lighter/callbacks/writer/table.py index f65e399b..1b7efccd 100644 --- a/lighter/callbacks/writer/table.py +++ b/lighter/callbacks/writer/table.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Union import itertools from pathlib import Path @@ -13,58 +13,36 @@ class LighterTableWriter(LighterBaseWriter): """ - Writer for saving predictions in a table format. Supports multiple formats, and - additional custom formats can be added either through `additional_writers` - argument at initialization, or by calling `add_writer` method after initialization. + Writer for saving predictions in a table format. Args: directory (Path): The directory where the CSV will be saved. - format (str): The format in which the data should be saved in the CSV. - additional_writers (Optional[Dict[str, Callable]]): Additional custom writer functions. + writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function. + Available writers: "tensor". """ - def __init__( - self, directory: Union[str, Path], format: str, additional_writers: Optional[Dict[str, Callable]] = None - ) -> None: - # Predefined writers for different formats. - self._writers = { - "tensor": write_tensor, - } - - # Initialize the base class. - super().__init__(directory, format, additional_writers) - - # Create a dictionary to hold CSV records for each ID. These are populated at each batch end - # by `self.on_predict_batch_end` defined in the base class using the `write` method below. - # Finally, the records are dumped to a CSV file at the end of the epoch by `self.on_predict_epoch_end`. + def __init__(self, directory: Union[str, Path], writer: Union[str, Callable]) -> None: + super().__init__(directory, writer) self.csv_records = {} - def write(self, tensor: Any, id: Union[int, str], multi_pred_id: Optional[Union[int, str]], format: str) -> None: + @property + def writers(self) -> Dict[str, Callable]: + return { + "tensor": lambda tensor: tensor.tolist(), + } + + def write(self, tensor: Any, id: Union[int, str]) -> None: """ - Write the tensor as a table record in the given format. - If there are multiple predictions, there will be a separate column for each prediction, named after - the corresponding `multi_pred_id`. If single prediction, there will be a single column named "pred". + Write the tensor as a table record using the specified writer. Args: tensor (Any): The tensor to be written. - id (Union[int, str]): The primary identifier for naming. 
-            multi_pred_id (Optional[Union[int, str]]): The secondary identifier, used if there are multiple predictions.
-            format (str): Format in which tensor should be written.
+            id (Union[int, str]): The identifier used as the key for the record.
         """
-        # Determine the column name based on the presence of multi_pred_id
-        column = "pred" if multi_pred_id is None else multi_pred_id
+        column = "pred"
+        record = self.writer(tensor)
 
-        # Get the appropriate writer function for the given format
-        writer = self.get_writer(format)
-
-        # Convert the tensor to the desired format (e.g., list)
-        record = writer(tensor)
-
-        # Store the record in the csv_records dictionary under the specified ID and column
-        if id not in self.csv_records:
-            self.csv_records[id] = {column: record}
-        else:
-            self.csv_records[id][column] = record
+        self.csv_records.setdefault(id, {})[column] = record
 
     def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
         """
@@ -74,7 +52,6 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> No
         If training was done in a distributed setting, it gathers predictions from all processes and then saves them from the rank 0 process.
         """
-        # Set the path where the CSV will be saved
         csv_path = self.directory / "predictions.csv"
 
         # Sort the records by ID and convert the dictionary to a list
@@ -93,7 +70,3 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> No
         # Save the records to a CSV file
         if trainer.is_global_zero:
             pd.DataFrame(self.csv_records).to_csv(csv_path)
-
-
-def write_tensor(tensor: Any) -> List:
-    return tensor.tolist()
diff --git a/lighter/utils/dynamic_imports.py b/lighter/utils/dynamic_imports.py
index be690ea6..bf8e57ae 100644
--- a/lighter/utils/dynamic_imports.py
+++ b/lighter/utils/dynamic_imports.py
@@ -1,20 +1,65 @@
-from typing import Any
+from typing import Dict
 
 import importlib
 import sys
+from dataclasses import dataclass, field
 from pathlib import Path
 
 from loguru import logger
+from monai.utils.module import optional_import
 
-OPTIONAL_IMPORTS = {}
+
+@dataclass
+class OptionalImports:
+    """Dataclass for handling optional imports.
+
+    This class provides a way to handle optional imports in a convenient manner.
+    It allows importing modules that may or may not be available, and raises an ImportError if the module is not available.
+
+    Example: ::
+        from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS
+        writer = OPTIONAL_IMPORTS["tensorboard"].SummaryWriter()
+
+    Attributes:
+        imports (Dict[str, object]): A dictionary to store the imported modules.
+    """
+
+    imports: Dict[str, object] = field(default_factory=dict)
+
+    def __getitem__(self, module_name: str):
+        """Get the imported module by name.
+
+        Args:
+            module_name (str): The name of the module to import.
+
+        Raises:
+            ImportError: If the module is not available.
+
+        Returns:
+            object: The imported module.
+        """
+        if module_name not in self.imports:
+            module, module_available = optional_import(module_name)
+            if not module_available:
+                raise ImportError(f"'{module_name}' is not available. Make sure that it is installed and spelled correctly.")
+            # Cache the module only on success so that a failed import is retried on the next access.
+            self.imports[module_name] = module
+        return self.imports[module_name]
+
+
+OPTIONAL_IMPORTS = OptionalImports()
 
 
 def import_module_from_path(module_name: str, module_path: str) -> None:
-    """Given the path to a module, import it, and name it as specified.
+    """Import a module from a given path and assign it a specified name.
+
+    The imported module is also registered in `sys.modules` under the assigned name, so it can be imported as usual afterwards.
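+
+    Example (a minimal sketch, assuming a hypothetical package directory "./my_project" that contains an `__init__.py`): ::
+        import_module_from_path("project", "./my_project")
+        # The package is now registered under the assigned name and can be imported as such.
+        import project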
Args: - module_name (str): what to name the imported module. - module_path (str): path to the module to load. + module_name (str): The name to assign to the imported module. + module_path (str): The path to the module to import. + + Raises: + ValueError: If the module has already been imported. + FileNotFoundError: If the `__init__.py` file is not found in the module path. """ # Based on https://stackoverflow.com/a/41595552. @@ -29,20 +74,3 @@ def import_module_from_path(module_name: str, module_path: str) -> None: spec.loader.exec_module(module) sys.modules[module_name] = module logger.info(f"{module_path.parent} imported as '{module_name}' module.") - - -def import_attr(module_attr: str) -> Any: - """Import using dot-notation string, e.g., 'torch.nn.Module'. - - Args: - module_attr (str): dot-notation path to the attribute. - - Returns: - Any: imported attribute. - """ - # Split module from attribute name - module, attr = module_attr.rsplit(".", 1) - # Import the module - module = __import__(module, fromlist=[attr]) - # Get the attribute from the module - return getattr(module, attr) From 9ca6f7dabbe72d4674d4c84e1f6bc913d2ddc9b2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 14 Sep 2023 20:45:31 +0000 Subject: [PATCH 18/20] Bump gitpython from 3.1.32 to 3.1.35 Bumps [gitpython](https://github.com/gitpython-developers/GitPython) from 3.1.32 to 3.1.35. - [Release notes](https://github.com/gitpython-developers/GitPython/releases) - [Changelog](https://github.com/gitpython-developers/GitPython/blob/main/CHANGES) - [Commits](https://github.com/gitpython-developers/GitPython/compare/3.1.32...3.1.35) --- updated-dependencies: - dependency-name: gitpython dependency-type: indirect ... Signed-off-by: dependabot[bot] --- poetry.lock | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 091fd1ed..e90aec6b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. 
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]] name = "absl-py" @@ -744,13 +744,13 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.32" +version = "3.1.35" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" files = [ - {file = "GitPython-3.1.32-py3-none-any.whl", hash = "sha256:e3d59b1c2c6ebb9dfa7a184daf3b6dd4914237e7488a1730a6d8f6f5d0b4187f"}, - {file = "GitPython-3.1.32.tar.gz", hash = "sha256:8d9b8cb1e80b9735e8717c9362079d3ce4c6e5ddeebedd0361b228c3a67a62f6"}, + {file = "GitPython-3.1.35-py3-none-any.whl", hash = "sha256:c19b4292d7a1d3c0f653858db273ff8a6614100d1eb1528b014ec97286193c09"}, + {file = "GitPython-3.1.35.tar.gz", hash = "sha256:9cbefbd1789a5fe9bcf621bb34d3f441f3a90c8461d377f84eda73e721d9b06b"}, ] [package.dependencies] @@ -1190,6 +1190,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -1782,7 +1792,7 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.21.0", markers = 
"python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.1" @@ -2255,6 +2265,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2262,8 +2273,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2280,6 +2298,7 @@ files = [ {file = 
"PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2287,6 +2306,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, From 01bfcd33a64d1268abd58f2c38312358dada949d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 14 Sep 2023 20:45:34 +0000 Subject: [PATCH 19/20] Bump certifi from 2023.5.7 to 2023.7.22 Bumps [certifi](https://github.com/certifi/python-certifi) from 2023.5.7 to 2023.7.22. - [Commits](https://github.com/certifi/python-certifi/compare/2023.05.07...2023.07.22) --- updated-dependencies: - dependency-name: certifi dependency-type: indirect ... Signed-off-by: dependabot[bot] --- poetry.lock | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 091fd1ed..0364466a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -285,13 +285,13 @@ files = [ [[package]] name = "certifi" -version = "2023.5.7" +version = "2023.7.22" description = "Python package for providing Mozilla's CA Bundle." 
optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"}, - {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"}, + {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, + {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"}, ] [[package]] @@ -1190,6 +1190,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -1782,7 +1792,7 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.1" @@ -2255,6 +2265,7 @@ files = [ {file = 
"PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2262,8 +2273,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2280,6 +2298,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = 
"PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2287,6 +2306,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, From 7ccce957bbca1135a7055928a7626cf01bf3f2d5 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Thu, 14 Sep 2023 17:24:10 -0400 Subject: [PATCH 20/20] Remove add_batch_dim --- lighter/callbacks/logger.py | 1 - lighter/callbacks/utils.py | 6 +----- lighter/callbacks/writer/file.py | 2 +- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index 32fb1935..1c34c558 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -171,7 +171,6 @@ def _log_histogram(self, name: str, tensor: torch.Tensor, global_step: int) -> N global_step (int): current global step. """ tensor = tensor.detach().cpu() - if self.tensorboard: self.tensorboard.add_histogram(name, tensor, global_step=global_step) if self.wandb: diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py index 8b8b39f0..bdadb930 100644 --- a/lighter/callbacks/utils.py +++ b/lighter/callbacks/utils.py @@ -15,20 +15,16 @@ def get_lighter_mode(lightning_stage: str) -> str: return lightning_to_lighter[lightning_stage] -def preprocess_image(image: torch.Tensor, add_batch_dim=False) -> torch.Tensor: +def preprocess_image(image: torch.Tensor) -> torch.Tensor: """Preprocess the image before logging it. If it is a batch of multiple images, it will create a grid image of them. In case of 3D, a single image is displayed with slices stacked vertically, while a batch of 3D images as a grid where each column is a different 3D image. Args: image (torch.Tensor): 2D or 3D image tensor. - add_batch_dim (bool, optional): Whether to add a batch dimension to the input image. 
- Use only when the input image does not have a batch dimension. Defaults to False. Returns: torch.Tensor: image ready for logging. """ - if add_batch_dim: - image = image.unsqueeze(0) # If 3D (BCDHW), concat the images vertically and horizontally. if image.ndim == 5: shape = image.shape diff --git a/lighter/callbacks/writer/file.py b/lighter/callbacks/writer/file.py index 05534d3a..c2a0d579 100644 --- a/lighter/callbacks/writer/file.py +++ b/lighter/callbacks/writer/file.py @@ -62,7 +62,7 @@ def write_tensor(path, tensor): def write_image(path, tensor): path = path.with_suffix(".png") - tensor = preprocess_image(tensor, add_batch_dim=True) + tensor = preprocess_image(tensor) torchvision.io.write_png(tensor, path)
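
A minimal usage sketch of the writer interface refactored in this series. It is illustrative only: the directory names and the `save_argmax` function are hypothetical, while `LighterFileWriter` and the registered writer names come from `lighter/callbacks/writer/file.py` above.

    import torch

    from lighter.callbacks.writer.file import LighterFileWriter

    # Built-in writer selected by name. Each prediction is saved as
    # "<directory>/<date_time>/<id>.nii.gz"; the suffix is added by the writer function.
    nifti_writer = LighterFileWriter(directory="predictions", writer="itk_nifti")


    def save_argmax(path, tensor):
        # Hypothetical custom writer. It receives the suffix-less path and a single
        # prediction tensor without the batch dimension (assumed channel-first here),
        # and is responsible for adding the suffix.
        torch.save(tensor.argmax(dim=0), str(path) + ".pt")


    # Custom writers are passed as callables instead of registered names.
    argmax_writer = LighterFileWriter(directory="predictions", writer=save_argmax)

Either instance is passed to the PyTorch Lightning `Trainer` as a callback; during prediction, `on_predict_batch_end` resolves the IDs (from the batch's "id" field, or from the DDP-aware prediction counter) and calls `write` once per prediction.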