Improve style (#27)
* Add CI, packaging and tests

* Simplify the trainer method config; reverse the config_file list for project importing

* Update the test

* Improve style

* More minor changes

* Rename `predict` to `forward` to prevent confusion with predict mode

* Fix wrongly merged test

* Delete dependabot.yml

---------

Co-authored-by: Suraj Pai <[email protected]>
ibro45 and surajpaib authored Feb 1, 2023
1 parent e8868de commit 1e0c213
Showing 6 changed files with 55 additions and 59 deletions.
6 changes: 3 additions & 3 deletions assets/images/coverage.svg
18 changes: 9 additions & 9 deletions lighter/callbacks/logger.py
@@ -1,6 +1,5 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import re
 import sys
 from datetime import datetime
 from pathlib import Path
@@ -11,8 +10,6 @@
 from loguru import logger
 from monai.utils.module import optional_import
 from pytorch_lightning import Callback, Trainer
-from torch.utils import tensorboard
-from yaml import safe_load
 
 from lighter import LighterSystem
@@ -71,16 +68,18 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
         self.log_dir.mkdir(parents=True)
 
         # Load the dumped config file to log it to the loggers.
-        # config = safe_load(open(self.log_dir / "config.yaml"))
+        # config = yaml.safe_load(open(self.log_dir / "config.yaml"))
 
         # Loguru log file.
         # logger.add(sink=self.log_dir / f"{stage}.log")
 
         # Tensorboard initialization.
         if self.tensorboard:
+            # Tensorboard is part of PyTorch, so no availability check is needed.
+            OPTIONAL_IMPORTS["tensorboard"], _ = optional_import("torch.utils.tensorboard")
             tensorboard_dir = self.log_dir / "tensorboard"
             tensorboard_dir.mkdir()
-            self.tensorboard = tensorboard.SummaryWriter(log_dir=tensorboard_dir)
+            self.tensorboard = OPTIONAL_IMPORTS["tensorboard"].SummaryWriter(log_dir=tensorboard_dir)
             # self.tensorboard.add_hparams(config)
 
         # Wandb initialization.
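
Note: MONAI's optional_import returns the module together with a success flag instead of raising at import time, which is the pattern used above. A minimal sketch (the log directory is a hypothetical value):

from monai.utils.module import optional_import

# optional_import defers the import and reports success via a flag
# instead of raising ImportError at module load time.
tensorboard, has_tensorboard = optional_import("torch.utils.tensorboard")
if has_tensorboard:
    writer = tensorboard.SummaryWriter(log_dir="logs/tensorboard")  # hypothetical path
    writer.add_scalar("demo/value", 0.5, global_step=0)
    writer.close()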
@@ -343,17 +342,17 @@ def check_image_data_type(data: Any, name: str) -> None:
         name (str): name of the image data, for logging purposes.
     """
     if isinstance(data, dict):
-        is_valid = all(check_image_data_type(elem) for elem in data.values())
+        is_valid = all(check_image_data_type(elem, name) for elem in data.values())
     elif isinstance(data, list):
-        is_valid = all(check_image_data_type(elem) for elem in data)
+        is_valid = all(check_image_data_type(elem, name) for elem in data)
     elif isinstance(data, torch.Tensor):
         is_valid = True
     else:
         is_valid = False
 
     if not is_valid:
         logger.error(
-            f"`{name}` has to be a Tensor, List[Tensors], Dict[str, Tensor]"
+            f"`{name}` has to be a Tensor, List[Tensor], Dict[str, Tensor]"
             f", or Dict[str, List[Tensor]]. `{type(data)}` is not supported."
         )
         sys.exit()
@@ -366,7 +365,8 @@ def parse_image_data(
         each tuple contains an identifier and a tensor.
 
     Args:
-        data (Union[Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]], List[torch.Tensor], torch.Tensor]): image tensor(s).
+        data (Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor],
+            Dict[str, List[torch.Tensor]]]): image tensor(s).
 
     Returns:
         List[Tuple[Optional[str], torch.Tensor]]: a list of tuples where the first element is
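
For reference, these are the four structures check_image_data_type and parse_image_data accept — a sketch with hypothetical shapes and keys:

import torch

# Hypothetical examples of the four accepted image-data structures.
tensor = torch.rand(8, 1, 64, 64)                        # Tensor
tensor_list = [torch.rand(1, 64, 64) for _ in range(8)]  # List[Tensor]
tensor_dict = {"input": tensor, "pred": tensor}          # Dict[str, Tensor]
tensor_dict_list = {"input": tensor_list}                # Dict[str, List[Tensor]]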
33 changes: 14 additions & 19 deletions lighter/system.py
@@ -166,7 +166,7 @@ def forward(self, input: Union[torch.Tensor, List, Tuple]) -> Union[torch.Tensor
         return self.model(input, **kwargs)
 
     def _base_step(self, batch: Tuple, batch_idx: int, mode: str) -> Union[Dict[str, Any], Any]:
-        """Base step for all modes ('train', 'val', 'test', 'predict')
+        """Base step for all modes ("train", "val", "test", "predict")
 
         Args:
             batch (Tuple):
@@ -183,16 +183,15 @@ def _base_step(self, batch: Tuple, batch_idx: int, mode: str) -> Union[Dict[str, Any], Any]:
         """
         input, target = batch if len(batch) == 2 else (batch[:-1], batch[-1])
 
-        # Predict
+        # Forward
         if self._patch_based_inferer and mode in ["val", "test", "predict"]:
             # TODO: Patch-based inference doesn't support multiple inputs yet
             pred = self._patch_based_inferer(input, self)
         else:
             pred = self(input)
 
         pred = reshape_pred_if_single_value_prediction(pred, target)
 
-        # Calculate the loss
+        # Calculate the loss.
         loss = None
         if mode in ["train", "val"]:
             loss = self._calculate_loss(pred, target)
@@ -203,11 +202,11 @@ def _base_step(self, batch: Tuple, batch_idx: int, mode: str) -> Union[Dict[str, Any], Any]:
         if self._post_criterion_activation is not None:
             pred = self._post_criterion_activation(pred)
 
-        # In predict mode, skip metrics and logging parts and return the predicted value
+        # In predict mode, skip metrics and logging parts and return the predicted value.
         if mode == "predict":
             return pred
 
-        # Calculate the metrics for the step
+        # Calculate the metrics for the step.
         step_metrics = getattr(self, f"{mode}_metrics")(pred, target)
 
         return {"loss": loss, "metrics": step_metrics, "input": input, "target": target, "pred": pred}
@@ -256,14 +255,14 @@ def _calculate_loss(self, pred: Union[torch.Tensor, List, Tuple], target: Union[
         return self.criterion(pred, **kwargs)
 
     def _base_dataloader(self, mode: str) -> DataLoader:
-        """Instantiate the dataloader for a mode (train/val/test).
+        """Instantiate the dataloader for a mode (train/val/test/predict).
         Includes a collate function that enables the DataLoader to replace
         None's (alias for corrupted examples) in the batch with valid examples.
         To make use of it, write a try-except in your Dataset that handles
         corrupted data by returning None instead.
 
         Args:
-            mode (str): mode for which to create the dataloader ['train', 'val', 'test'].
+            mode (str): mode for which to create the dataloader ["train", "val", "test", "predict"].
 
         Returns:
             DataLoader: instantiated DataLoader.
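
The docstring above expects the dataset to signal corrupted examples by returning None; a minimal sketch of such a dataset (the class name and loading logic are illustrative, not part of the codebase):

import torch
from torch.utils.data import Dataset

class SafeImageDataset(Dataset):
    """Hypothetical dataset that returns None for corrupted examples,
    letting the collate function described above replace them."""

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        try:
            return torch.load(self.paths[idx])  # may raise on a corrupted file
        except Exception:
            return None  # signals a corrupted example to the collate function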
@@ -299,36 +298,32 @@ def _base_dataloader(self, mode: str) -> DataLoader:
             collate_fn=collate_fn,
         )
 
-    def configure_optimizers(self) -> Dict:
+    def configure_optimizers(self) -> Union[Optimizer, List[Dict[str, Union[Optimizer, "Scheduler"]]]]:
         """LightningModule method. Returns optimizers and, if defined, schedulers.
 
         Returns:
-            Single optimizer: If only a optimizer is provided
-            Tuple of dictionaries: a tuple of `dict` with keys `optimizer` and `lr_scheduler`
+            Optimizer or a List of Dict of paired Optimizers and Schedulers: instantiated
+                optimizers and/or schedulers.
         """
         if not self.optimizers:
-            logger.error("Please specify 'optimizers' in the config. Exiting.")
+            logger.error("Please specify 'system.optimizers' in the config. Exiting.")
             sys.exit()
         if not self.schedulers:
             return self.optimizers
 
         if len(self.optimizers) != len(self.schedulers):
-            logger.error("'optimizers' and 'schedulers' should be paired")
+            logger.error("Each optimizer must have its own scheduler.")
             sys.exit()
 
-        optim_sched_paired = []
-        for optimizer, scheduler in zip(self.optimizers, self.schedulers):
-            optim_sched_paired.append({"optimizer": optimizer, "lr_scheduler": scheduler})
-
-        return tuple(optim_sched_paired)
+        return [{"optimizer": opt, "lr_scheduler": sched} for opt, sched in zip(self.optimizers, self.schedulers)]

     def setup(self, stage: str) -> None:
         """Automatically called by the LightningModule after the initialization.
         `LighterSystem`'s setup checks if the required dataset is provided in the config and
         sets up LightningModule methods for the stage in which the system is.
 
         Args:
-            stage (str): passed by PyTorch Lightning. ['fit', 'validate', 'test'].
+            stage (str): passed by PyTorch Lightning. ["fit", "validate", "test"].
         """
         # Stage-specific PyTorch Lightning methods. Defined dynamically so that the system
         # only has methods used in the stage and for which the configuration was provided.
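
To illustrate the new configure_optimizers return value, here is the pairing for a hypothetical two-optimizer setup:

import torch

model = torch.nn.Linear(10, 2)
optimizers = [
    torch.optim.Adam(model.parameters(), lr=1e-3),
    torch.optim.SGD(model.parameters(), lr=1e-2),
]
schedulers = [torch.optim.lr_scheduler.StepLR(opt, step_size=10) for opt in optimizers]

# What configure_optimizers returns for this configuration:
paired = [{"optimizer": opt, "lr_scheduler": sched} for opt, sched in zip(optimizers, schedulers)]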
21 changes: 11 additions & 10 deletions lighter/utils/cli.py
@@ -44,15 +44,16 @@ def run_trainer_method(method_name, method_config: Dict, **kwargs: Any):
     project_imported = False
     # Handle multiple configs. Start from the config file specified last as it overrides the previous ones.
     for config in reversed(ensure_list(kwargs["config_file"])):
-        config = yaml.safe_load(open(config, "r"))
-        if "project" not in config:
-            continue
-        # Only one config file can specify the project path
-        if project_imported:
-            logger.error("`project` must be specified in one config only. Exiting.")
-            sys.exit()
-        # Import it as a module named 'project'.
-        import_module_from_path("project", config["project"])
-        project_imported = True
+        with open(config, "r", encoding="utf-8") as config:
+            config = yaml.safe_load(config)
+            if "project" not in config:
+                continue
+            # Only one config file can specify the project path
+            if project_imported:
+                logger.error("`project` must be specified in one config only. Exiting.")
+                sys.exit()
+            # Import it as a module named 'project'.
+            import_module_from_path("project", config["project"])
+            project_imported = True
     # Run the Trainer method.
     run(method_name, **method_config, **kwargs)
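
A standalone sketch of the override order implemented above — the config file specified last takes precedence, and only one file may declare project (the file names are hypothetical):

import yaml

# Hypothetical config files, e.g. passed as --config_file base.yaml experiment.yaml
config_files = ["base.yaml", "experiment.yaml"]

# Iterating in reverse means the last-specified file is inspected first,
# mirroring run_trainer_method above.
for path in reversed(config_files):
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if "project" in config:
        print(f"Importing project from {config['project']} (declared in {path})")
        break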
32 changes: 16 additions & 16 deletions lighter/utils/misc.py
@@ -6,31 +6,31 @@
 from loguru import logger
 
 
-def ensure_list(x: Any) -> List:
+def ensure_list(vals: Any) -> List:
     """Wrap the input into a list if it is not a list. If it is a None, return an empty list.
 
     Args:
-        x (Any): input to wrap into a list.
+        vals (Any): input to wrap into a list.
 
     Returns:
         List: output list.
     """
-    if isinstance(x, list):
-        return x
-    if isinstance(x, tuple):
-        return list(x)
-    if x is None:
+    if isinstance(vals, list):
+        return vals
+    if isinstance(vals, tuple):
+        return list(vals)
+    if vals is None:
         return []
-    return [x]
+    return [vals]
 
 
-def dot_notation_setattr(obj: Callable, attr: str, value: Any):
-    """Set object's attribute. May use dot notation.
+def setattr_dot_notation(obj: Callable, attr: str, value: Any):
+    """Set object's attribute. Supports dot notation.
 
     Args:
         obj (Callable): object.
         attr (str): attribute name of the object.
-        value (Any): attribute value to be set.
+        value (Any): value to set the attribute to.
     """
     if "." not in attr:
         if not hasattr(obj, attr):
@@ -40,14 +40,14 @@ def dot_notation_setattr(obj: Callable, attr: str, value: Any):
     # Solve recursively if the attribute is defined in dot-notation
     else:
         obj_name, attr = attr.split(".", maxsplit=1)
-        dot_notation_setattr(getattr(obj, obj_name), attr, value)
+        setattr_dot_notation(getattr(obj, obj_name), attr, value)
 
 
 def hasarg(_callable: Callable, arg_name: str) -> bool:
     """Check if a function, class, or method has an argument with the specified name.
 
     Args:
-        callable (Callable): function, class, or method to inspect.
+        _callable (Callable): function, class, or method to inspect.
         arg_name (str): argument name to check for.
 
     Returns:
@@ -74,15 +74,15 @@ def get_name(x: Callable, include_module_name: bool = False) -> str:
     """Get the name of an object, class or function.
 
     Args:
-        x (Callable): object, class or function.
+        _callable (Callable): object, class or function.
         include_module_name (bool, optional): whether to include the name of the module from
             which it comes. Defaults to False.
 
     Returns:
         str: name
     """
-    name = type(x).__name__ if isinstance(x, object) else x.__name__
+    name = type(_callable).__name__ if isinstance(_callable, object) else _callable.__name__
     if include_module_name:
-        module = type(x).__module__ if isinstance(x, object) else x.__module__
+        module = type(_callable).__module__ if isinstance(_callable, object) else _callable.__module__
         name = f"{module}.{name}"
     return name
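
A quick usage sketch of the renamed helpers (the example objects are hypothetical; ensure_list and setattr_dot_notation behave as defined above):

import torch

from lighter.utils.misc import ensure_list, get_name, setattr_dot_notation

# ensure_list normalizes scalars, tuples, and None into lists.
assert ensure_list(None) == []
assert ensure_list((1, 2)) == [1, 2]
assert ensure_list("a") == ["a"]

# setattr_dot_notation reaches nested attributes via dot notation.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
setattr_dot_notation(model, "0.bias", torch.nn.Parameter(torch.zeros(4)))

print(get_name(model, include_module_name=True))  # e.g. "torch.nn.modules.container.Sequential"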
4 changes: 2 additions & 2 deletions lighter/utils/model.py
@@ -1,7 +1,7 @@
 import torch
 from torch.nn import Identity, Module, Sequential
 
-from lighter.utils.misc import dot_notation_setattr
+from lighter.utils.misc import setattr_dot_notation
 
 
 def reshape_pred_if_single_value_prediction(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
@@ -33,7 +33,7 @@ def replace_layer_with(model: Module, layer_name: str, new_layer: Module) -> Module:
     Returns:
         Module: PyTorch model with the new layer set at the specified location.
     """
-    dot_notation_setattr(model, layer_name, new_layer)
+    setattr_dot_notation(model, layer_name, new_layer)
     return model


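A usage sketch for replace_layer_with (assumes torchvision is available; using it to strip a classification head is an illustrative, common use):

import torch
from torch.nn import Identity
from torchvision.models import resnet18

from lighter.utils.model import replace_layer_with

# Replace the classification head with an Identity layer, turning the
# network into a feature extractor.
model = replace_layer_with(resnet18(), "fc", Identity())
features = model(torch.rand(1, 3, 224, 224))  # shape: (1, 512)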
