Skip to content

Commit

Permalink
Simplify batch splitting, add None collate support for when target is…
Browse files Browse the repository at this point in the history
… not expected (#65)

* parse_data works recursively

* [WIP] simplify batch splitting and data to model and criterion parsing

* Collate support for None, allowing returning None as target in datasets

* Remove the check if target is None

* Update logger.py
  • Loading branch information
ibro45 authored May 14, 2023
1 parent 1876e71 commit f25361e
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 106 deletions.
1 change: 1 addition & 0 deletions lighter/callbacks/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
return

self.log_dir.mkdir(parents=True)
logger.info(f"Logging to {self.log_dir}")

# Loguru log file.
logger.add(sink=self.log_dir / f"{stage}.log")
Expand Down
59 changes: 38 additions & 21 deletions lighter/callbacks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@ def get_lighter_mode(lightning_stage: str) -> str:
return lightning_to_lighter[lightning_stage]


def is_data_type_supported(data: Any) -> bool:
"""Check the input data for its type. Valid data types are:
def is_data_type_supported(data: Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]) -> bool:
"""
Check the input data recursively for its type. Valid data types are:
- torch.Tensor
- List[torch.Tensor]
- Tuple[torch.Tensor]
- Dict[str, torch.Tensor]
- Dict[str, List[torch.Tensor]]
- Dict[str, Tuple[torch.Tensor]]
- Nested combinations of the above
Args:
data (Any): input data to check
data (Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]): Input data to check.
Returns:
bool: True if the data type is supported, False otherwise.
Expand All @@ -44,36 +46,51 @@ def is_data_type_supported(data: Any) -> bool:


def parse_data(
data: Union[Any, List[Any], Dict[str, Any], Dict[str, List[Any]], Dict[str, Tuple[Any]]]
data: Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]], prefix: Optional[str] = None
) -> Dict[Optional[str], Any]:
"""Parse the input data as follows:
- If dict, go over all keys and values, unpacking list and tuples, and assigning them all
a unique identifier based on the original key and their position if they were a list/tuple.
- If list/tuple, enumerate them and use their position as key for each value of the list/tuple.
- If any other type, return it as-is with the key set to 'None'. A 'None' key indicates that no
identifier is needed because no parsing occurred.
"""
Parse the input data recursively, handling nested dictionaries, lists, and tuples.
This function will recursively parse the input data, unpacking nested dictionaries, lists, and tuples. The result
will be a dictionary where each key is a unique identifier reflecting the data's original structure (dict keys
or list/tuple positions) and each value is a non-container data type from the input data.
Args:
data (Any, List[Any], Dict[str, Any], Dict[str, List[Any]], Dict[str, Tuple[Any]):
input data to parse.
data (Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]): Input data to parse.
prefix (Optional[str]): Current prefix for keys in the result dictionary. Defaults to None.
Returns:
Dict[Optional[str], Any]: a dict where key is either a string
identifier or `None`, and value the parsed output.
Dict[Optional[str], Any]: A dictionary where key is either a string identifier or `None`, and value is the parsed output.
Example:
input_data = {
"a": [1, 2],
"b": {"c": (3, 4), "d": 5}
}
output_data = parse_data(input_data)
# Output:
# {
# 'a_0': 1,
# 'a_1': 2,
# 'b_c_0': 3,
# 'b_c_1': 4,
# 'b_d': 5
# }
"""
result = {}
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, (list, tuple)):
for idx, element in enumerate(value):
result[f"{key}_{idx}" if len(value) > 1 else key] = element
else:
result[key] = value
# Recursively parse the value with an updated prefix
sub_result = parse_data(value, prefix=f"{prefix}_{key}" if prefix else key)
result.update(sub_result)
elif isinstance(data, (list, tuple)):
for idx, element in enumerate(data):
result[str(idx)] = element
# Recursively parse the element with an updated prefix
sub_result = parse_data(element, prefix=f"{prefix}_{idx}" if prefix else str(idx))
result.update(sub_result)
else:
result[None] = data
# Assign the value to the result dictionary using the current prefix as its key
result[prefix] = data
return result


Expand Down
100 changes: 32 additions & 68 deletions lighter/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from torch.utils.data import DataLoader, Dataset, Sampler
from torchmetrics import Metric, MetricCollection

from lighter.utils.collate import collate_fn_replace_corrupted
from lighter.utils.misc import countargs, ensure_list, get_name, hasarg
from lighter.utils.collate import collate_replace_corrupted
from lighter.utils.misc import ensure_list, get_name, hasarg
from lighter.utils.model import reshape_pred_if_single_value_prediction


Expand Down Expand Up @@ -165,6 +165,10 @@ def forward(self, input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Ten
Returns:
Any: output of the model.
"""
# Freeze the layers if specified so.
if self.freezer is not None:
self.freezer(self.model, self.global_step, self.current_epoch)

# Keyword arguments to pass to the forward method
kwargs = {}
if hasarg(self.model.forward, "epoch"):
Expand All @@ -174,24 +178,7 @@ def forward(self, input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Ten
# Add `step` argument if forward accepts it
kwargs["step"] = self.global_step

# Type not supported.
if not isinstance(input, (torch.Tensor, tuple, list, dict)):
logger.error(f"Input type '{type(input)}' not supported.")
sys.exit()

# Freeze the layers if specified so.
if self.freezer is not None:
self.freezer(self.model, self.global_step, self.current_epoch)

# Unpack Tuple or List. Only if num of args passed is less than or equal to num of args accepted.
if isinstance(input, (tuple, list)) and (len(input) + len(kwargs)) <= countargs(self.model):
return self.model(*input, **kwargs)
# Unpack Dict. Only if dict's keys match criterion's keyword arguments.
elif isinstance(input, dict) and all(hasarg(self.model, name) for name in input):
return self.model(**input, **kwargs)
# Tensor, Tuple, List, or Dict, as-is, not unpacked.
else:
return self.model(input, **kwargs)
return self.model(input, **kwargs)

def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Union[Dict[str, Any], Any]:
"""Base step for all modes ("train", "val", "test", "predict")
Expand All @@ -209,8 +196,27 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un
For predict step, it returns pred only.
"""
# Split the batch into input and target. Target will be `None` if not provided.
input, target = self._split_batch(batch)
# Ensure that the batch is a list, a tuple, or a dict.
if not isinstance(batch, (list, tuple, dict)):
raise TypeError(
"A batch must be a list, a tuple, or a dict."
"A batch dict must have 'input' and 'target' as keys."
"A batch list or a tuple must have 2 elements - input and target."
"If target does not exist, return `None` as target."
)
# Ensure that a dict batch has input and target keys exclusively.
if isinstance(batch, dict) and set(batch.keys()) != {"input", "target"}:
raise ValueError("A batch must be a dict with 'input' and 'target' as keys.")
# Ensure that a list/tuple batch has 2 elements (input and target).
if len(batch) == 1:
raise ValueError(
"A batch must consist of 2 elements - input and target. If target does not exist, return `None` as target."
)
if len(batch) > 2:
raise ValueError(f"A batch must consist of 2 elements - input and target, but found {len(batch)} elements.")

# Split the batch into input and target.
input, target = batch if not isinstance(batch, dict) else (batch["input"], batch["target"])

# Forward
if self.inferer and mode in ["val", "test", "predict"]:
Expand Down Expand Up @@ -258,8 +264,8 @@ def _calculate_loss(
# Keyword arguments to pass to the loss/criterion function
kwargs = {}
if hasarg(self.criterion.forward, "target"):
# Add `target` argument if forward accepts it. Casting performed if specified.
kwargs["target"] = target.to(self._cast_target_dtype_to)
# Add `target` argument if forward accepts it. Cast it if it is a tensor and if the target type is specified.
kwargs["target"] = target if not isinstance(target, torch.Tensor) else target.to(self._cast_target_dtype_to)
else:
if not self._target_not_used_reported and not self.trainer.sanity_checking:
self._target_not_used_reported = True
Expand All @@ -272,21 +278,7 @@ def _calculate_loss(
"behavior you expected, redefine your criterion "
"so that it has a `target` argument."
)

# Type not supported.
if not isinstance(pred, (torch.Tensor, tuple, list, dict)):
logger.error(f"Pred type '{type(pred)}' not supported.")
sys.exit()

# Unpack Tuple or List. Only if num of args passed is less than or equal to num of args accepted.
if isinstance(pred, (tuple, list)) and (len(pred) + len(kwargs)) <= countargs(self.criterion):
return self.criterion(*pred, **kwargs)
# Unpack Dict. Only if dict's keys match criterion's keyword arguments' names.
elif isinstance(pred, dict) and all(hasarg(self.criterion, name) for name in pred):
return self.criterion(**pred, **kwargs)
# Tensor, Tuple, List, or Dict, as-is, not unpacked.
else:
return self.criterion(pred, **kwargs)
return self.criterion(pred, **kwargs)

def _base_dataloader(self, mode: str) -> DataLoader:
"""Instantiate the dataloader for a mode (train/val/test/predict).
Expand Down Expand Up @@ -319,7 +311,7 @@ def _base_dataloader(self, mode: str) -> DataLoader:

# A dataset can return None when a corrupted example occurs. This collate
# function replaces None's with valid examples from the dataset.
collate_fn = partial(collate_fn_replace_corrupted, dataset=dataset, default_collate_fn=collate_fn)
collate_fn = partial(collate_replace_corrupted, dataset=dataset, default_collate_fn=collate_fn)
return DataLoader(
dataset,
sampler=sampler,
Expand Down Expand Up @@ -395,34 +387,6 @@ def setup(self, stage: str) -> None:
self.predict_dataloader = partial(self._base_dataloader, mode="predict")
self.predict_step = partial(self._base_step, mode="predict")

def _split_batch(self, batch) -> Tuple[torch.Tensor, Optional[Any]]:
"""Split the batch into input and target. Target will be `None` if not provided.
Args:
batch (List, Tuple): output of the DataLoader and input to the model.
Returns:
Tuple(torch.Tensor, Optional[Any]): input and target.
"""
# Check if the batch format is correct.
if len(batch) > 2:
raise ValueError(
"Found more than 2 items in the batch. `LighterSystem` requires the dataloader to return either "
"input tensor(s) only, or a two-element tuple/list consisting of input tensor(s) and target(s)."
)

# Report the batch split type. Only on the first call.
if not self._batch_type_reported:
self._batch_type_reported = True
if len(batch) == 1:
logger.info("Target not provided. Using `None` as target. Ignore if intended.")
else:
logger.info("Using the first item as input and the second item as target.")

# Split the batch into input and target. Target will be `None` if not provided.
input, target = (batch, None) if len(batch) == 1 else batch
return input, target

def _init_placeholders_for_dataloader_and_step_methods(self) -> None:
"""`LighterSystem` dynamically defines the `..._dataloader()`and `..._step()` methods
in the `self.setup()` method. However, `LightningModule` expects them to be defined at
Expand Down
14 changes: 9 additions & 5 deletions lighter/utils/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,31 @@

import torch
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import collate_str_fn, default_collate_fn_map
from torch.utils.data.dataloader import default_collate

# Collate support for None. Just as a string, None is not collated. Allows elements of the batch to be None.
default_collate_fn_map.update({type(None): collate_str_fn})

def collate_fn_replace_corrupted(batch: List[Any], dataset: DataLoader, default_collate_fn: Callable = None) -> torch.Tensor:

def collate_replace_corrupted(batch: Any, dataset: DataLoader, default_collate_fn: Callable = None) -> Any:
"""Collate function that allows to replace corrupted examples in the batch.
The dataloader should return `None` when that occurs.
The `None`s in the batch are replaced with other, randomly-selected, examples.
Args:
batch (List[Any]): batch from the DataLoader.
batch (Any): batch from the DataLoader.
dataset (Dataset): dataset that the DataLoader is passing through. Needs to be fixed
in place with functools.partial before passing it to DataLoader's 'collate_fn' option
as 'collate_fn' should only have a single argument - batch. Example:
```
collate_fn = functools.partial(collate_fn_replace_corrupted, dataset=dataset)
collate_fn = functools.partial(collate_replace_corrupted, dataset=dataset)
loader = DataLoader(dataset, ..., collate_fn=collate_fn).
```
default_collate_fn (Callable): the collate function to call once the batch has no corrupted examples.
If `None`, `torch.utils.data.dataloader.default_collate` is called. Defaults to None.
Returns:
torch.Tensor: batch with new examples instead of corrupted ones.
Any: batch with new examples instead of corrupted ones.
"""
# Use `torch.utils.data.dataloader.default_collate` if no other default collate function is specified.
default_collate_fn = default_collate_fn if default_collate_fn is not None else default_collate
Expand All @@ -39,6 +43,6 @@ def collate_fn_replace_corrupted(batch: List[Any], dataset: DataLoader, default_
# Replace a corrupted example with another randomly selected example.
batch.extend([dataset[random.randint(0, len(dataset) - 1)] for _ in range(num_corrupted)])
# Recursive call to replace the replacements if they are corrupted.
return collate_fn_replace_corrupted(batch, dataset)
return collate_replace_corrupted(batch, dataset)
# Finally, when the whole batch is fine, apply the default collate function.
return default_collate_fn(batch)
12 changes: 0 additions & 12 deletions lighter/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,6 @@ def hasarg(_callable: Callable, arg_name: str) -> bool:
return arg_name in args


def countargs(_callable: Callable) -> bool:
"""Count the number of arguments that a function, class, or method accepts.
Args:
callable (Callable): function, class, or method to inspect.
Returns:
int: number of arguments that it accepts.
"""
return len(inspect.signature(_callable).parameters.keys())


def get_name(_callable: Callable, include_module_name: bool = False) -> str:
"""Get the name of an object, class or function.
Expand Down

0 comments on commit f25361e

Please sign in to comment.