From f25361e24b462a947972004ff6c4ac1d2bcbd028 Mon Sep 17 00:00:00 2001 From: Ibrahim Hadzic Date: Sun, 14 May 2023 18:26:03 -0400 Subject: [PATCH] Simplify batch splitting, add None collate support for when target is not expected (#65) * parse_data works recursively * [WIP] simplify batch splitting and data to model and criterion parsing * Collate support for None, allowing returning None as target in datasets * Remove the check if target is None * Update logger.py --- lighter/callbacks/logger.py | 1 + lighter/callbacks/utils.py | 59 +++++++++++++-------- lighter/system.py | 100 ++++++++++++------------------------ lighter/utils/collate.py | 14 +++-- lighter/utils/misc.py | 12 ----- 5 files changed, 80 insertions(+), 106 deletions(-) diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index 4023e4d6..e7f689b9 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -68,6 +68,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: return self.log_dir.mkdir(parents=True) + logger.info(f"Logging to {self.log_dir}") # Loguru log file. logger.add(sink=self.log_dir / f"{stage}.log") diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py index 4306e051..757bce68 100644 --- a/lighter/callbacks/utils.py +++ b/lighter/callbacks/utils.py @@ -17,17 +17,19 @@ def get_lighter_mode(lightning_stage: str) -> str: return lightning_to_lighter[lightning_stage] -def is_data_type_supported(data: Any) -> bool: - """Check the input data for its type. Valid data types are: +def is_data_type_supported(data: Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]) -> bool: + """ + Check the input data recursively for its type. Valid data types are: - torch.Tensor - List[torch.Tensor] - Tuple[torch.Tensor] - Dict[str, torch.Tensor] - Dict[str, List[torch.Tensor]] - Dict[str, Tuple[torch.Tensor]] + - Nested combinations of the above Args: - data (Any): input data to check + data (Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]): Input data to check. Returns: bool: True if the data type is supported, False otherwise. @@ -44,36 +46,51 @@ def is_data_type_supported(data: Any) -> bool: def parse_data( - data: Union[Any, List[Any], Dict[str, Any], Dict[str, List[Any]], Dict[str, Tuple[Any]]] + data: Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]], prefix: Optional[str] = None ) -> Dict[Optional[str], Any]: - """Parse the input data as follows: - - If dict, go over all keys and values, unpacking list and tuples, and assigning them all - a unique identifier based on the original key and their position if they were a list/tuple. - - If list/tuple, enumerate them and use their position as key for each value of the list/tuple. - - If any other type, return it as-is with the key set to 'None'. A 'None' key indicates that no - identifier is needed because no parsing ocurred. + """ + Parse the input data recursively, handling nested dictionaries, lists, and tuples. + + This function will recursively parse the input data, unpacking nested dictionaries, lists, and tuples. The result + will be a dictionary where each key is a unique identifier reflecting the data's original structure (dict keys + or list/tuple positions) and each value is a non-container data type from the input data. Args: - data (Any, List[Any], Dict[str, Any], Dict[str, List[Any]], Dict[str, Tuple[Any]): - input data to parse. + data (Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]): Input data to parse. + prefix (Optional[str]): Current prefix for keys in the result dictionary. Defaults to None. Returns: - Dict[Optional[str], Any]: a dict where key is either a string - identifier or `None`, and value the parsed output. + Dict[Optional[str], Any]: A dictionary where key is either a string identifier or `None`, and value is the parsed output. + + Example: + input_data = { + "a": [1, 2], + "b": {"c": (3, 4), "d": 5} + } + output_data = parse_data(input_data) + # Output: + # { + # 'a_0': 1, + # 'a_1': 2, + # 'b_c_0': 3, + # 'b_c_1': 4, + # 'b_d': 5 + # } """ result = {} if isinstance(data, dict): for key, value in data.items(): - if isinstance(value, (list, tuple)): - for idx, element in enumerate(value): - result[f"{key}_{idx}" if len(value) > 1 else key] = element - else: - result[key] = value + # Recursively parse the value with an updated prefix + sub_result = parse_data(value, prefix=f"{prefix}_{key}" if prefix else key) + result.update(sub_result) elif isinstance(data, (list, tuple)): for idx, element in enumerate(data): - result[str(idx)] = element + # Recursively parse the element with an updated prefix + sub_result = parse_data(element, prefix=f"{prefix}_{idx}" if prefix else str(idx)) + result.update(sub_result) else: - result[None] = data + # Assign the value to the result dictionary using the current prefix as its key + result[prefix] = data return result diff --git a/lighter/system.py b/lighter/system.py index faf73ac3..327274fa 100644 --- a/lighter/system.py +++ b/lighter/system.py @@ -11,8 +11,8 @@ from torch.utils.data import DataLoader, Dataset, Sampler from torchmetrics import Metric, MetricCollection -from lighter.utils.collate import collate_fn_replace_corrupted -from lighter.utils.misc import countargs, ensure_list, get_name, hasarg +from lighter.utils.collate import collate_replace_corrupted +from lighter.utils.misc import ensure_list, get_name, hasarg from lighter.utils.model import reshape_pred_if_single_value_prediction @@ -165,6 +165,10 @@ def forward(self, input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Ten Returns: Any: output of the model. """ + # Freeze the layers if specified so. + if self.freezer is not None: + self.freezer(self.model, self.global_step, self.current_epoch) + # Keyword arguments to pass to the forward method kwargs = {} if hasarg(self.model.forward, "epoch"): @@ -174,24 +178,7 @@ def forward(self, input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Ten # Add `step` argument if forward accepts it kwargs["step"] = self.global_step - # Type not supported. - if not isinstance(input, (torch.Tensor, tuple, list, dict)): - logger.error(f"Input type '{type(input)}' not supported.") - sys.exit() - - # Freeze the layers if specified so. - if self.freezer is not None: - self.freezer(self.model, self.global_step, self.current_epoch) - - # Unpack Tuple or List. Only if num of args passed is less than or equal to num of args accepted. - if isinstance(input, (tuple, list)) and (len(input) + len(kwargs)) <= countargs(self.model): - return self.model(*input, **kwargs) - # Unpack Dict. Only if dict's keys match criterion's keyword arguments. - elif isinstance(input, dict) and all(hasarg(self.model, name) for name in input): - return self.model(**input, **kwargs) - # Tensor, Tuple, List, or Dict, as-is, not unpacked. - else: - return self.model(input, **kwargs) + return self.model(input, **kwargs) def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Union[Dict[str, Any], Any]: """Base step for all modes ("train", "val", "test", "predict") @@ -209,8 +196,27 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un For predict step, it returns pred only. """ - # Split the batch into input and target. Target will be `None` if not provided. - input, target = self._split_batch(batch) + # Ensure that the batch is a list, a tuple, or a dict. + if not isinstance(batch, (list, tuple, dict)): + raise TypeError( + "A batch must be a list, a tuple, or a dict." + "A batch dict must have 'input' and 'target' as keys." + "A batch list or a tuple must have 2 elements - input and target." + "If target does not exist, return `None` as target." + ) + # Ensure that a dict batch has input and target keys exclusively. + if isinstance(batch, dict) and set(batch.keys()) != {"input", "target"}: + raise ValueError("A batch must be a dict with 'input' and 'target' as keys.") + # Ensure that a list/tuple batch has 2 elements (input and target). + if len(batch) == 1: + raise ValueError( + "A batch must consist of 2 elements - input and target. If target does not exist, return `None` as target." + ) + if len(batch) > 2: + raise ValueError(f"A batch must consist of 2 elements - input and target, but found {len(batch)} elements.") + + # Split the batch into input and target. + input, target = batch if not isinstance(batch, dict) else (batch["input"], batch["target"]) # Forward if self.inferer and mode in ["val", "test", "predict"]: @@ -258,8 +264,8 @@ def _calculate_loss( # Keyword arguments to pass to the loss/criterion function kwargs = {} if hasarg(self.criterion.forward, "target"): - # Add `target` argument if forward accepts it. Casting performed if specified. - kwargs["target"] = target.to(self._cast_target_dtype_to) + # Add `target` argument if forward accepts it. Cast it if it is a tensor and if the target type is specified. + kwargs["target"] = target if not isinstance(target, torch.Tensor) else target.to(self._cast_target_dtype_to) else: if not self._target_not_used_reported and not self.trainer.sanity_checking: self._target_not_used_reported = True @@ -272,21 +278,7 @@ def _calculate_loss( "behavior you expected, redefine your criterion " "so that it has a `target` argument." ) - - # Type not supported. - if not isinstance(pred, (torch.Tensor, tuple, list, dict)): - logger.error(f"Pred type '{type(pred)}' not supported.") - sys.exit() - - # Unpack Tuple or List. Only if num of args passed is less than or equal to num of args accepted. - if isinstance(pred, (tuple, list)) and (len(pred) + len(kwargs)) <= countargs(self.criterion): - return self.criterion(*pred, **kwargs) - # Unpack Dict. Only if dict's keys match criterion's keyword arguments' names. - elif isinstance(pred, dict) and all(hasarg(self.criterion, name) for name in pred): - return self.criterion(**pred, **kwargs) - # Tensor, Tuple, List, or Dict, as-is, not unpacked. - else: - return self.criterion(pred, **kwargs) + return self.criterion(pred, **kwargs) def _base_dataloader(self, mode: str) -> DataLoader: """Instantiate the dataloader for a mode (train/val/test/predict). @@ -319,7 +311,7 @@ def _base_dataloader(self, mode: str) -> DataLoader: # A dataset can return None when a corrupted example occurs. This collate # function replaces None's with valid examples from the dataset. - collate_fn = partial(collate_fn_replace_corrupted, dataset=dataset, default_collate_fn=collate_fn) + collate_fn = partial(collate_replace_corrupted, dataset=dataset, default_collate_fn=collate_fn) return DataLoader( dataset, sampler=sampler, @@ -395,34 +387,6 @@ def setup(self, stage: str) -> None: self.predict_dataloader = partial(self._base_dataloader, mode="predict") self.predict_step = partial(self._base_step, mode="predict") - def _split_batch(self, batch) -> Tuple[torch.Tensor, Optional[Any]]: - """Split the batch into input and target. Target will be `None` if not provided. - - Args: - batch (List, Tuple): output of the DataLoader and input to the model. - - Returns: - Tuple(torch.Tensor, Optional[Any]): input and target. - """ - # Check if the batch format is correct. - if len(batch) > 2: - raise ValueError( - "Found more than 2 items in the batch. `LighterSystem` requires the dataloader to return either " - "input tensor(s) only, or a two-element tuple/list consisting of input tensor(s) and target(s)." - ) - - # Report the batch split type. Only on the first call. - if not self._batch_type_reported: - self._batch_type_reported = True - if len(batch) == 1: - logger.info("Target not provided. Using `None` as target. Ignore if intended.") - else: - logger.info("Using the first item as input and the second item as target.") - - # Split the batch into input and target. Target will be `None` if not provided. - input, target = (batch, None) if len(batch) == 1 else batch - return input, target - def _init_placeholders_for_dataloader_and_step_methods(self) -> None: """`LighterSystem` dynamically defines the `..._dataloader()`and `..._step()` methods in the `self.setup()` method. However, `LightningModule` excepts them to be defined at diff --git a/lighter/utils/collate.py b/lighter/utils/collate.py index d2c462d3..b0bd8238 100644 --- a/lighter/utils/collate.py +++ b/lighter/utils/collate.py @@ -4,27 +4,31 @@ import torch from torch.utils.data import DataLoader +from torch.utils.data._utils.collate import collate_str_fn, default_collate_fn_map from torch.utils.data.dataloader import default_collate +# Collate support for None. Just as a string, None is not collated. Allows elements of the batch to be None. +default_collate_fn_map.update({type(None): collate_str_fn}) -def collate_fn_replace_corrupted(batch: List[Any], dataset: DataLoader, default_collate_fn: Callable = None) -> torch.Tensor: + +def collate_replace_corrupted(batch: Any, dataset: DataLoader, default_collate_fn: Callable = None) -> Any: """Collate function that allows to replace corrupted examples in the batch. The dataloader should return `None` when that occurs. The `None`s in the batch are replaced with other, randomly-selected, examples. Args: - batch (List[Any]): batch from the DataLoader. + batch (Any): batch from the DataLoader. dataset (Dataset): dataset that the DataLoader is passing through. Needs to be fixed in place with functools.partial before passing it to DataLoader's 'collate_fn' option as 'collate_fn' should only have a single argument - batch. Example: ``` - collate_fn = functools.partial(collate_fn_replace_corrupted, dataset=dataset)` + collate_fn = functools.partial(collate_replace_corrupted, dataset=dataset)` loader = DataLoader(dataset, ..., collate_fn=collate_fn). ``` default_collate_fn (Callable): the collate function to call once the batch has no corrupted examples. If `None`, `torch.utils.data.dataloader.default_collate` is called. Defaults to None. Returns: - torch.Tensor: batch with new examples instead of corrupted ones. + Any: batch with new examples instead of corrupted ones. """ # Use `torch.utils.data.dataloader.default_collate` if no other default collate function is specified. default_collate_fn = default_collate_fn if default_collate_fn is not None else default_collate @@ -39,6 +43,6 @@ def collate_fn_replace_corrupted(batch: List[Any], dataset: DataLoader, default_ # Replace a corrupted example with another randomly selected example. batch.extend([dataset[random.randint(0, len(dataset) - 1)] for _ in range(num_corrupted)]) # Recursive call to replace the replacements if they are corrupted. - return collate_fn_replace_corrupted(batch, dataset) + return collate_replace_corrupted(batch, dataset) # Finally, when the whole batch is fine, apply the default collate function. return default_collate_fn(batch) diff --git a/lighter/utils/misc.py b/lighter/utils/misc.py index a1dbfbce..b4927c94 100644 --- a/lighter/utils/misc.py +++ b/lighter/utils/misc.py @@ -58,18 +58,6 @@ def hasarg(_callable: Callable, arg_name: str) -> bool: return arg_name in args -def countargs(_callable: Callable) -> bool: - """Count the number of arguments that a function, class, or method accepts. - - Args: - callable (Callable): function, class, or method to inspect. - - Returns: - int: number of arguments that it accepts. - """ - return len(inspect.signature(_callable).parameters.keys()) - - def get_name(_callable: Callable, include_module_name: bool = False) -> str: """Get the name of an object, class or function.