
LighterWriter - write/save predictions #40

Merged
13 commits merged on Mar 2, 2023
2 changes: 2 additions & 0 deletions .gitignore
@@ -143,6 +143,8 @@ tensorboard/
prototyping.ipynb
checkpoints/

# Our ignores
projects/*
!projects/README.md
!projects/cifar10
**/predictions/
2 changes: 2 additions & 0 deletions lighter/callbacks/__init__.py
@@ -1 +1,3 @@
from .logger import LighterLogger
from .writer.file import LighterFileWriter
from .writer.table import LighterTableWriter
49 changes: 19 additions & 30 deletions lighter/callbacks/logger.py
@@ -1,12 +1,10 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Union

import sys
from datetime import datetime
from pathlib import Path

import torch
import torch.distributed as dist
import torchvision
from loguru import logger
from monai.utils.module import optional_import
from pytorch_lightning import Callback, Trainer
@@ -20,14 +18,14 @@
class LighterLogger(Callback):
def __init__(
self,
project,
log_dir,
tensorboard=False,
wandb=False,
input_type=None,
target_type=None,
pred_type=None,
max_samples=None,
project: str,
log_dir: str,
tensorboard: bool = False,
wandb: bool = False,
input_type: str = None,
target_type: str = None,
pred_type: str = None,
max_samples: int = None,
) -> None:
self.project = project
# Only used on rank 0, the dir is created in setup().
@@ -62,7 +60,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
logger.error("When using LighterLogger, set Trainer(logger=None).")
sys.exit()

if dist.is_initialized() and dist.get_rank() != 0:
if not trainer.is_global_zero:
return

self.log_dir.mkdir(parents=True)
@@ -94,7 +92,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
# self.wandb.config.update(config)

def teardown(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
if dist.is_initialized() and dist.get_rank() != 0:
if not trainer.is_global_zero:
return
self.tensorboard.close()

@@ -109,9 +107,6 @@ def _log(self, outputs: dict, mode: str, global_step: int, is_epoch=False) -> None:
is_epoch (bool): whether the log is being done at the end
of an epoch or a step. Defaults to False.
"""
if dist.is_initialized() and dist.get_rank() != 0:
return

step_or_epoch = "epoch" if is_epoch else "step"

# Loss
@@ -229,9 +224,8 @@ def _on_batch_end(self, outputs: Dict, trainer: Trainer) -> None:
# Accumulate the loss.
if mode in ["train", "val"]:
self.loss[mode] += outputs["loss"].item()
# Logging frequency.
if self.global_step_counter[mode] % trainer.log_every_n_steps == 0:
# Log. Done only on rank 0.
# Logging frequency. Log only on rank 0.
if trainer.is_global_zero and self.global_step_counter[mode] % trainer.log_every_n_steps == 0:
self._log(outputs, mode, global_step=self._get_global_step(trainer))
# Increment the step counters.
self.global_step_counter[mode] += 1
@@ -241,7 +235,7 @@ def _on_batch_end(self, outputs: Dict, trainer: Trainer) -> None:
def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
"""Performs logging at the end of an epoch. It calculates the average
loss and metrics for the epoch and logs them. In distributed mode, it averages
the losses and metrics from all processes.
the losses and metrics from all ranks.

Args:
trainer (Trainer): Trainer, passed automatically by PyTorch Lightning.
@@ -255,14 +249,8 @@ def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
if mode in ["train", "val"]:
# Get the accumulated loss.
loss = self.loss[mode]
# Reduce the loss to rank 0 and average it.
if dist.is_initialized():
# Distributed communication works only tensors.
loss = torch.tensor(loss).to(pl_module.device)
# On rank 0, sum the losses from all ranks. Other ranks remain with the same loss as before.
dist.reduce(loss, dst=0)
# On rank 0, average the loss sum by dividing it with the number of processes.
loss = loss.item() / dist.get_world_size() if dist.get_rank() == 0 else loss.item()
# Reduce the loss and average it on each rank.
loss = trainer.strategy.reduce(loss, reduce_op="mean")
# Divide the accumulated loss by the number of steps in the epoch.
loss /= self.epoch_step_counter[mode]
outputs["loss"] = loss
@@ -276,8 +264,9 @@ def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
# Reset the metrics for the next epoch.
metrics.reset()

# Log. Done only on rank 0.
self._log(outputs, mode, is_epoch=True, global_step=self._get_global_step(trainer))
# Log. Only on rank 0.
if trainer.is_global_zero:
self._log(outputs, mode, is_epoch=True, global_step=self._get_global_step(trainer))

def _get_global_step(self, trainer: Trainer) -> int:
"""Return the global step for the current mode. Note that when Trainer
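For context, the change above replaces the manual `torch.distributed` rank checks and reductions with Lightning's `trainer.is_global_zero` and `trainer.strategy.reduce`. Below is a minimal sketch of that pattern in a standalone callback; the callback name and the assumption that `training_step` returns a dict with a "loss" key are illustrative, not part of lighter.

import torch
from pytorch_lightning import Callback, LightningModule, Trainer


class EpochLossLogger(Callback):
    """Accumulates the training loss and logs the epoch average on rank 0 only."""

    def __init__(self) -> None:
        self.loss_sum = 0.0
        self.steps = 0

    def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx) -> None:
        # Assumes `training_step` returned a dict containing a "loss" tensor.
        self.loss_sum += outputs["loss"].item()
        self.steps += 1

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Average the accumulated loss across ranks via the strategy instead of calling dist.reduce() by hand.
        loss = torch.tensor(self.loss_sum / max(self.steps, 1), device=pl_module.device)
        loss = trainer.strategy.reduce(loss, reduce_op="mean")
        # Log on rank 0 only, mirroring the `is_global_zero` checks introduced above.
        if trainer.is_global_zero:
            print(f"epoch {trainer.current_epoch}: mean loss {loss.item():.4f}")
        self.loss_sum, self.steps = 0.0, 0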
Empty file.
140 changes: 140 additions & 0 deletions lighter/callbacks/writer/base.py
@@ -0,0 +1,140 @@
from typing import Any, Dict, List, Optional, Union

import itertools
import sys
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path

import torch
from loguru import logger
from pytorch_lightning import Callback, Trainer

from lighter import LighterSystem
from lighter.callbacks.utils import parse_data, structure_preserving_concatenate


class LighterBaseWriter(ABC, Callback):
"""Base class for a Writer. Override `self.write()` to define how a prediction should be saved.
`LighterBaseWriter` sets up the write directory, and defines `on_predict_batch_end` and
`on_predict_epoch_end`. `write_interval` specifies which of the two the writer should call.

Args:
write_dir (str): the Writer will create a directory inside `write_dir`, named with the current date
and time, and store the predictions there.
write_as (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]):
type in which the predictions will be stored. Passed automatically to the `write()`
abstract method and can be used to support writing different types. Should the Writer
support only one type, this argument can be removed from the overridden `__init__()`'s
arguments and `self.write_as` set to `None`.
write_interval (str, optional): whether to write on each step or at the end of the prediction epoch.
Defaults to "step".
"""

def __init__(
self,
write_dir: str,
write_as: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]],
write_interval: str = "step",
) -> None:
self.write_dir = Path(write_dir) / datetime.now().strftime("%Y%m%d_%H%M%S")
self.write_as = write_as
self.write_interval = write_interval

self.parsed_write_as = None

@abstractmethod
def write(
self,
idx: int,
identifier: Optional[str],
tensor: torch.Tensor,
write_as: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]],
):
"""This method must be overridden to specify how a tensor should be saved. If the Writer
supports multiple types of saving, handle the `write_as` argument with an if-else statement.

If the Writer only supports one type, remove `write_as` from the overridden
`__init__()` method and set `self.write_as=None`.

The `idx` and `identifier` arguments can be used to specify the name of the file
or the row and column of a table for the prediction.

Parameters:
idx (int): The index of the prediction.
identifier (Optional[str]): The identifier of the prediction. It will be `None` if there's
only one prediction, an index if the prediction is a list of predictions, a key if it's
a dict of predictions, and a `key_index` if it's a dict of lists of predictions.
tensor (torch.Tensor): The predicted tensor.
write_as (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]):
Specifies how to write the predictions. If it's a single string value, the predictions
will be saved under that type regardless of whether they are single- or multi-output
predictions. To write different outputs of a multi-output prediction using different
methods, use the matching structure for `write_as`.
"""

def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
if stage != "predict":
return

if self.write_interval not in ["step", "epoch"]:
logger.error("`write_interval` must be either 'step' or 'epoch'.")
sys.exit()

# Broadcast the `write_dir` so that all ranks write their predictions there.
self.write_dir = trainer.strategy.broadcast(self.write_dir)
# Let rank 0 create the `write_dir`.
if trainer.is_global_zero:
self.write_dir.mkdir(parents=True)
# If `write_dir` does not exist, the ranks are not on the same storage.
if not self.write_dir.exists():
logger.error(
f"Rank {trainer.global_rank} is not on the same storage as rank 0."
"Please run the prediction only on nodes that are on the same storage."
)
sys.exit()

def on_predict_batch_end(
self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if self.write_interval != "step":
return
indices = trainer.predict_loop.epoch_loop.current_batch_indices
self._on_batch_or_epoch_end(outputs, indices)

def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None:
if self.write_interval != "epoch":
return
# There is only one epoch when predicting; index the lists of outputs and batch indices accordingly.
indices = trainer.predict_loop.epoch_batch_indices[0]
outputs = outputs[0]
# Concatenate/flatten so that each output corresponds to its index.
indices = list(itertools.chain(*indices))
outputs = structure_preserving_concatenate(outputs)
self._on_batch_or_epoch_end(outputs, indices)

def _on_batch_or_epoch_end(self, outputs, indices):
# Parse the outputs into a structure ready for writing.
parsed_outputs = parse_data(outputs)
# Runs only on the first step.
if self.parsed_write_as is None:
# Parse `self.write_as`. If multi-value, check if its structure matches `parsed_output`'s structure.
self.parsed_write_as = self._parse_write_as(self.write_as, parsed_outputs)

for idx in indices:
for identifier in parsed_outputs: # pylint: disable=consider-using-dict-items
tensor = parsed_outputs[identifier]
write_as = self.parsed_write_as[identifier]
self.write(idx, identifier, tensor, write_as)

def _parse_write_as(self, write_as, parsed_outputs: Dict[str, Any]):
# If `write_as` is a string (single value), all outputs will be saved in that specified format.
if isinstance(write_as, str):
parsed_write_as = {key: write_as for key in parsed_outputs}
# Otherwise, `write_as` needs to match the structure of the outputs in order to assign each tensor its type.
else:
parsed_write_as = parse_data(write_as)
if not set(parsed_write_as) == set(parsed_outputs):
logger.error("`write_as` structure does not match the prediction's structure.")
sys.exit()
return parsed_write_as
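To illustrate the contract above, here is a minimal sketch of a custom writer that saves every prediction as a NumPy file. `NumpyWriter` is hypothetical and not part of this PR; it shows the single-format case where `write_as` is dropped from the arguments and fixed to `None`, as the class docstring suggests.

import numpy as np
import torch

from lighter.callbacks.writer.base import LighterBaseWriter


class NumpyWriter(LighterBaseWriter):
    """Writes each prediction to `<write_dir>/<timestamp>/<idx>_<identifier>.npy`."""

    def __init__(self, write_dir: str) -> None:
        # Only one format is supported, so `write_as` is fixed to None.
        super().__init__(write_dir, write_as=None, write_interval="step")

    def write(self, idx, identifier, tensor: torch.Tensor, write_as) -> None:
        name = str(idx) if identifier is None else f"{idx}_{identifier}"
        np.save(self.write_dir / f"{name}.npy", tensor.cpu().numpy())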
34 changes: 34 additions & 0 deletions lighter/callbacks/writer/file.py
@@ -0,0 +1,34 @@
import sys

import torch
import torchvision
from loguru import logger

from lighter.callbacks.utils import preprocess_image
from lighter.callbacks.writer.base import LighterBaseWriter


class LighterFileWriter(LighterBaseWriter):
def write(self, idx, identifier, tensor, write_as):
filename = f"{write_as}" if identifier is None else f"{identifier}_{write_as}"
write_dir = self.write_dir / str(idx)
write_dir.mkdir()

if write_as is None:
pass
elif write_as == "tensor":
path = write_dir / f"{filename}.pt"
torch.save(tensor, path)
elif write_as == "image":
path = write_dir / f"{filename}.png"
torchvision.io.write_png(preprocess_image(tensor), path)
elif write_as == "video":
path = write_dir / f"{filename}.mp4"
torchvision.io.write_video(path, tensor, fps=24)
elif write_as == "scalar":
raise NotImplementedError
elif write_as == "audio":
raise NotImplementedError
else:
logger.error(f"`write_as` '{write_as}' not supported.")
sys.exit()
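A hedged usage sketch for the file writer above; `system` and `predict_dataloader` are placeholders standing in for a configured `LighterSystem` and its prediction dataloader.

from pytorch_lightning import Trainer

from lighter.callbacks import LighterFileWriter

writer = LighterFileWriter(write_dir="predictions", write_as="tensor")
trainer = Trainer(callbacks=[writer])
trainer.predict(system, dataloaders=predict_dataloader)
# With a single-output prediction (identifier is None), each sample is saved to
# predictions/<timestamp>/<idx>/tensor.pt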
57 changes: 57 additions & 0 deletions lighter/callbacks/writer/table.py
@@ -0,0 +1,57 @@
from typing import Any, Dict, List, Union

import itertools
import sys

import pandas as pd
from loguru import logger
from pytorch_lightning import Trainer

from lighter import LighterSystem
from lighter.callbacks.writer.base import LighterBaseWriter


class LighterTableWriter(LighterBaseWriter):
def __init__(self, write_dir: str, write_as: Union[str, List[str], Dict[str, str], Dict[str, List[str]]]) -> None:
super().__init__(write_dir, write_as, write_interval="epoch")
self.csv_records = {}

def write(self, idx, identifier, tensor, write_as):
# Column name will be set to 'pred' if the identifier is None.
column = "pred" if identifier is None else identifier

if write_as is None:
record = None
elif write_as == "tensor":
record = tensor.tolist()
elif write_as == "scalar":
raise NotImplementedError
else:
logger.error(f"`write_as` '{write_as}' not supported.")
sys.exit()

if idx not in self.csv_records:
self.csv_records[idx] = {column: record}
else:
self.csv_records[idx][column] = record

def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None:
super().on_predict_epoch_end(trainer, pl_module, outputs)

csv_path = self.write_dir / "predictions.csv"
logger.info(f"Saving the predictions to {csv_path}")

# Sort the dict of dicts by key and turn it into a list of dicts.
self.csv_records = [self.csv_records[key] for key in sorted(self.csv_records)]
# Gather the records from all ranks when in DDP.
if trainer.world_size > 1:
# Since `all_gather` supports tensors only, mimic the behavior using `broadcast`.
ddp_csv_records = [self.csv_records] * trainer.world_size
for rank in range(trainer.world_size):
# Broadcast the records from the current rank and save them at their designated position.
ddp_csv_records[rank] = trainer.strategy.broadcast(ddp_csv_records[rank], src=rank)
# Combine the records from all ranks. List of lists of dicts -> list of dicts.
self.csv_records = list(itertools.chain(*ddp_csv_records))

# Create a dataframe and save it.
pd.DataFrame(self.csv_records).to_csv(csv_path)
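A comparable sketch for the table writer; `system` and `predict_dataloader` are again placeholders.

from pytorch_lightning import Trainer

from lighter.callbacks import LighterTableWriter

writer = LighterTableWriter(write_dir="predictions", write_as="tensor")
trainer = Trainer(callbacks=[writer])
trainer.predict(system, dataloaders=predict_dataloader)
# At the end of the prediction epoch, the records from all ranks are gathered and
# written to predictions/<timestamp>/predictions.csv, with one row per prediction
# and a "pred" column (or one column per named output for multi-output predictions).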