Add "id" support. Refactor Writers. Add Writer additional format extensibility. #78

Merged · 22 commits · Sep 15, 2023
Changes from 14 commits

Commits
b24ab22  Remove loss logging when predicting (ibro45, Aug 11, 2023)
4a13721  Add "id" for each batch sample ID-ing purposes. Refactor Writers, add… (ibro45, Aug 13, 2023)
2ea66bd  Remove interval arg, group_tensors fn, and on pred epoch end writing.… (ibro45, Aug 14, 2023)
ffc26ae  Small fixes (ibro45, Aug 15, 2023)
6c4563d  Remove multi opt and scheduler support. Replace remaining sys.exit's. (ibro45, Aug 15, 2023)
f8d689b  Update configure_optimizers docstring (ibro45, Aug 15, 2023)
57b4447  Fix index ID issue in DDP writing. Replace broadcast with gather in t… (ibro45, Aug 15, 2023)
605764a  Add missing if DDP check (ibro45, Aug 15, 2023)
ae1a452  Update docstrings, rename and refactor parse_data (ibro45, Aug 16, 2023)
cb63a9e  Add freezer to init file (ibro45, Aug 16, 2023)
736d9f6  Change property to attribute (ibro45, Aug 16, 2023)
fe7693a  Add support for dict metrics. Refactor system. (ibro45, Aug 16, 2023)
57bcd7a  Fix typos (ibro45, Aug 16, 2023)
799719e  Remove unused imports (ibro45, Aug 16, 2023)
21d84f3  Update logger.py to support the temp ModuleDict fix (ibro45, Aug 18, 2023)
c8eedea  Add continue to freezer and detach cpu to image logging (ibro45, Aug 20, 2023)
8874399  Remove multi_pred, refactor Writer, Logger, and optional imports (ibro45, Sep 14, 2023)
9ca6f7d  Bump gitpython from 3.1.32 to 3.1.35 (dependabot[bot], Sep 14, 2023)
01bfcd3  Bump certifi from 2023.5.7 to 2023.7.22 (dependabot[bot], Sep 14, 2023)
441a23d  Merge remote-tracking branch 'origin/dependabot/pip/certifi-2023.7.22… (ibro45, Sep 14, 2023)
d0193b8  Merge remote-tracking branch 'origin/dependabot/pip/gitpython-3.1.35'… (ibro45, Sep 14, 2023)
7ccce95  Remove add_batch_dim (ibro45, Sep 14, 2023)
1 change: 1 addition & 0 deletions lighter/callbacks/__init__.py
@@ -1,3 +1,4 @@
+from .freezer import LighterFreezer
from .logger import LighterLogger
from .writer.file import LighterFileWriter
from .writer.table import LighterTableWriter
2 changes: 1 addition & 1 deletion lighter/callbacks/freezer.py
@@ -69,7 +69,7 @@ def on_test_batch_start(
        self._on_batch_start(trainer, pl_module)

    def on_predict_batch_start(
-        self, trainer: Trainer, pl_module: LighterSystem, batch: Any, batch_idx: int, dataloader_idx: int
+        self, trainer: Trainer, pl_module: LighterSystem, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self._on_batch_start(trainer, pl_module)

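With the added default, the hook's signature lines up with PyTorch Lightning's single-dataloader convention for predict-batch hooks. An illustrative call against the signature shown above (the freezer/trainer/system/batch objects are hypothetical; Lightning normally invokes this hook itself):

    # dataloader_idx can now be omitted for single-dataloader prediction runs.
    freezer.on_predict_batch_start(trainer, system, batch, batch_idx=0)
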
11 changes: 4 additions & 7 deletions lighter/callbacks/logger.py
@@ -1,7 +1,6 @@
from typing import Any, Dict, Union

import itertools
-import sys
from datetime import datetime
from pathlib import Path

@@ -12,7 +11,7 @@
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor

from lighter import LighterSystem
-from lighter.callbacks.utils import get_lighter_mode, is_data_type_supported, parse_data, preprocess_image
+from lighter.callbacks.utils import flatten_structure, get_lighter_mode, is_data_type_supported, preprocess_image
from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS


@@ -62,8 +61,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
            stage (str): stage of the training process. Passed automatically by PyTorch Lightning.
        """
        if trainer.logger is not None:
-            logger.error("When using LighterLogger, set Trainer(logger=None).")
-            sys.exit()
+            raise ValueError("When using LighterLogger, set Trainer(logger=None).")

        if not trainer.is_global_zero:
            return
@@ -88,8 +86,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
        if self.wandb:
            OPTIONAL_IMPORTS["wandb"], wandb_available = optional_import("wandb")
            if not wandb_available:
-                logger.error("Weights & Biases not installed. To install it, run `pip install wandb`. Exiting.")
-                sys.exit()
+                raise ImportError("Weights & Biases not installed. To install it, run `pip install wandb`.")
            wandb_dir = self.log_dir / "wandb"
            wandb_dir.mkdir()
            self.wandb = OPTIONAL_IMPORTS["wandb"].init(project=self.project, dir=wandb_dir, config=self.config)
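
A hedged sketch of the setup this check now enforces; the LighterLogger constructor arguments are elided since they are not shown in this diff:

    from pytorch_lightning import Trainer
    from lighter.callbacks import LighterLogger

    # LighterLogger replaces Lightning's built-in logging, so the Trainer's own
    # logger must be disabled; otherwise setup() raises ValueError instead of
    # calling sys.exit() as before.
    trainer = Trainer(logger=None, callbacks=[LighterLogger(...)])  # args elided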
@@ -222,7 +219,7 @@ def _on_batch_end(self, outputs: Dict, trainer: Trainer) -> None:
f"`{name}` has to be a Tensor, List[Tensor], Tuple[Tensor], Dict[str, Tensor], "
f"Dict[str, List[Tensor]], or Dict[str, Tuple[Tensor]]. `{type(outputs[name])}` is not supported."
)
for identifier, item in parse_data(outputs[name]).items():
for identifier, item in flatten_structure(outputs[name]).items():
ibro45 marked this conversation as resolved.
Show resolved Hide resolved
item_name = f"{mode}/data/{name}" if identifier is None else f"{mode}/data/{name}_{identifier}"
self._log_by_type(item_name, item, self.log_types[name], global_step)

Expand Down
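
A worked example of the naming scheme in the loop above, assuming mode="train" and name="pred":

    # identifier None   -> logged as "train/data/pred"
    # identifier "a_0"  -> logged as "train/data/pred_a_0"
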
87 changes: 32 additions & 55 deletions lighter/callbacks/utils.py
@@ -45,100 +45,77 @@ def is_data_type_supported(data: Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]])
    return is_valid


-def parse_data(
+def flatten_structure(
    data: Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]], prefix: Optional[str] = None
) -> Dict[Optional[str], Any]:
    """
-    Parse the input data recursively, handling nested dictionaries, lists, and tuples.
+    Recursively parse nested data structures into a flat dictionary.

-    This function will recursively parse the input data, unpacking nested dictionaries, lists, and tuples. The result
-    will be a dictionary where each key is a unique identifier reflecting the data's original structure (dict keys
-    or list/tuple positions) and each value is a non-container data type from the input data.
+    This function flattens dictionaries, lists, and tuples, returning a dictionary where each key is constructed
+    from the original structure's keys or list/tuple indices. The values in the output dictionary are non-container
+    data types extracted from the input.

    Args:
-        data (Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]): Input data to parse.
-        prefix (Optional[str]): Current prefix for keys in the result dictionary. Defaults to None.
+        data (Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]):
+            The input data to parse. Can be of any data type, but the function is optimized
+            to handle dictionaries, lists, and tuples. Nested structures are also supported.
+
+        prefix (Optional[str]):
+            A prefix used when constructing keys for the output dictionary. Useful for recursive
+            calls to maintain context. Defaults to None.

    Returns:
-        Dict[Optional[str], Any]: A dictionary where key is either a string identifier or `None`, and value is the parsed output.
+        Dict[Optional[str], Any]:
+            A flattened dictionary where keys are unique identifiers built from the original data structure,
+            and values are non-container data extracted from the input.

    Example:
        input_data = {
            "a": [1, 2],
            "b": {"c": (3, 4), "d": 5}
        }
-        output_data = parse_data(input_data)
-        # Output:
-        # {
-        #     'a_0': 1,
-        #     'a_1': 2,
-        #     'b_c_0': 3,
-        #     'b_c_1': 4,
-        #     'b_d': 5
-        # }
+        output_data = flatten_structure(input_data)
+
+        Expected output:
+        {
+            'a_0': 1,
+            'a_1': 2,
+            'b_c_0': 3,
+            'b_c_1': 4,
+            'b_d': 5
+        }
    """
    result = {}
    if isinstance(data, dict):
        for key, value in data.items():
            # Recursively parse the value with an updated prefix
-            sub_result = parse_data(value, prefix=f"{prefix}_{key}" if prefix else key)
+            sub_result = flatten_structure(value, prefix=f"{prefix}_{key}" if prefix else key)
            result.update(sub_result)
    elif isinstance(data, (list, tuple)):
        for idx, element in enumerate(data):
            # Recursively parse the element with an updated prefix
-            sub_result = parse_data(element, prefix=f"{prefix}_{idx}" if prefix else str(idx))
+            sub_result = flatten_structure(element, prefix=f"{prefix}_{idx}" if prefix else str(idx))
            result.update(sub_result)
    else:
        # Assign the value to the result dictionary using the current prefix as its key
        result[prefix] = data
    return result
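
The renamed helper can be sanity-checked standalone; this sketch condenses the function from the diff above and exercises the docstring's example:

    def flatten_structure(data, prefix=None):
        # Same logic as in the diff, condensed for the sketch.
        result = {}
        if isinstance(data, dict):
            for key, value in data.items():
                result.update(flatten_structure(value, prefix=f"{prefix}_{key}" if prefix else key))
        elif isinstance(data, (list, tuple)):
            for idx, element in enumerate(data):
                result.update(flatten_structure(element, prefix=f"{prefix}_{idx}" if prefix else str(idx)))
        else:
            result[prefix] = data
        return result

    print(flatten_structure({"a": [1, 2], "b": {"c": (3, 4), "d": 5}}))
    # {'a_0': 1, 'a_1': 2, 'b_c_0': 3, 'b_c_1': 4, 'b_d': 5}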


def gather_tensors(
    inputs: Union[List[Union[torch.Tensor, List, Tuple, Dict]], Tuple[Union[torch.Tensor, List, Tuple, Dict]]]
) -> Union[List, Dict]:
    """Recursively gather tensors. Tensors can be standalone or inside other data structures (list/tuple/dict).
    An input list of tensors is returned as-is. Given an input list of data structures containing tensors, this
    function gathers all tensors into a list and saves it under a single data structure. Assumes that all elements
    of the input list have the same type and structure.

    Args:
        inputs (List[Union[torch.Tensor, List, Tuple, Dict]], Tuple[Union[torch.Tensor, List, Tuple, Dict]]):
            They can be:
            - Lists/tuples of dictionaries, each containing tensors to be gathered by their key.
            - Lists/tuples of lists/tuples, each containing tensors to be gathered by their position.
            - Lists/tuples of tensors, returned as-is.
            - Nested versions of the above.
            The input data structure must be the same for all elements of the list and can be arbitrarily nested.

    Returns:
        Union[List, Dict]: The gathered tensors.
    """
    # List of dicts.
    if isinstance(inputs[0], dict):
        keys = inputs[0].keys()
        return {key: gather_tensors([input[key] for input in inputs]) for key in keys}
    # List of lists or tuples.
    elif isinstance(inputs[0], (list, tuple)):
        return [gather_tensors([input[idx] for input in inputs]) for idx in range(len(inputs[0]))]
    # List of tensors.
    elif isinstance(inputs[0], torch.Tensor):
        return inputs
    else:
        raise TypeError(f"Type `{type(inputs[0])}` not supported.")
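
For intuition, a runnable sketch of gather_tensors on a list of per-step output dicts (the tensor values are made up):

    import torch

    steps = [
        {"pred": torch.tensor([0.1]), "id": torch.tensor([0])},
        {"pred": torch.tensor([0.9]), "id": torch.tensor([1])},
    ]
    gathered = gather_tensors(steps)
    # {'pred': [tensor([0.1000]), tensor([0.9000])], 'id': [tensor([0]), tensor([1])]}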


-def preprocess_image(image: torch.Tensor) -> torch.Tensor:
+def preprocess_image(image: torch.Tensor, add_batch_dim=False) -> torch.Tensor:
    """Preprocess the image before logging it. If it is a batch of multiple images,
    it will create a grid image of them. In case of 3D, a single image is displayed
    with slices stacked vertically, while a batch of 3D images is displayed as a grid
    where each column is a different 3D image.
    Args:
        image (torch.Tensor): 2D or 3D image tensor.
+        add_batch_dim (bool, optional): Whether to add a batch dimension to the input image.
+            Use only when the input image does not have a batch dimension. Defaults to False.
    Returns:
        torch.Tensor: image ready for logging.
    """
+    image = image.detach().cpu()
+    if add_batch_dim:
+        image = image.unsqueeze(0)
    # If 3D (BCDHW), concat the images vertically and horizontally.
    if image.ndim == 5:
        shape = image.shape
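
The intended call pattern, inferred from the docstring above (a hedged sketch; the grid-building body is truncated in this view):

    import torch

    # A batch of 2D images (BCHW) is turned into a single grid image.
    grid = preprocess_image(torch.rand(8, 1, 64, 64))

    # A single image without a batch dimension (CHW) needs add_batch_dim=True.
    grid = preprocess_image(torch.rand(1, 64, 64), add_batch_dim=True)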