Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor cli and runner. Implement reserved config keys. Add feature to pass args to Trainer's methods. #124

Merged
merged 14 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/basics/projects.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ In the above example, the path of the dataset is `/home/user/project/my_xray_dat

=== "Terminal"
```
lighter fit --config_file xray.yaml
lighter fit --config xray.yaml
```

</div>

1. Make sure to put an `__init__.py` file in this directory. Remember, this is needed for the directory to be an importable Python module.
1. Make sure to put an `__init__.py` file in this directory. Remember, this is needed for the directory to be an importable Python module.
4 changes: 2 additions & 2 deletions docs/basics/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ We just combine the Trainer and LighterSystem into a single YAML and run the com
```
=== "Terminal"
```
lighter fit --config_file cifar10.yaml
lighter fit --config cifar10.yaml
```


Congratulations!! You have run your first training example with Lighter.
Congratulations!! You have run your first training example with Lighter.
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ Say goodbye to messy scripts and notebooks. Lighter is here to help you organize
<div style="width: 49%;">
<h3 style="text-align: center">Lighter</h3>
```bash title="Terminal"
lighter fit --config_file cifar10.yaml
lighter fit --config cifar10.yaml
```
```yaml title="cifar10.yaml"
trainer:
Expand Down
23 changes: 0 additions & 23 deletions lighter/utils/cli.py

This file was deleted.

1 change: 0 additions & 1 deletion lighter/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
SUPPRESSED_MODULES = [
"fire",
"monai.bundle",
"lighter.utils.cli",
"lighter.utils.runner",
"pytorch_lightning.trainer",
"lightning_utilities",
Expand Down
111 changes: 76 additions & 35 deletions lighter/utils/runner.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,111 @@
from typing import Any, Dict
from typing import Any

from functools import partial

import fire
from monai.bundle.config_parser import ConfigParser
from pytorch_lightning import seed_everything

from lighter.system import LighterSystem
from lighter.utils.dynamic_imports import import_module_from_path

CONFIG_STRUCTURE = {"project": None, "system": {}, "trainer": {}, "args": {}, "vars": {}}
TRAINER_METHOD_NAMES = ["fit", "validate", "test", "predict", "lr_find", "scale_batch_size"]


def cli() -> None:
"""Defines the command line interface for running lightning trainer's methods."""
commands = {method: partial(run, method) for method in TRAINER_METHOD_NAMES}
fire.Fire(commands)


def parse_config(**kwargs) -> ConfigParser:
"""
Parses configuration files and updates the provided parser
with given keyword arguments. Returns an updated parser object.

Args:
**kwargs (dict): Keyword arguments containing configuration data.
config_file (str): Path to the main configuration file.
args_file (str, optional): Path to secondary configuration file for additional arguments.
Additional key-value pairs can also be provided to be added or updated in the parser.

**kwargs (dict): Keyword arguments containing 'config' and, optionally, config overrides.
Returns:
An instance of ConfigParser with parsed and merged configuration data.
An instance of ConfigParser with configuration and overrides merged and parsed.
"""
# Ensure a config file is specified.
config = kwargs.pop("config", None)
if config is None:
raise ValueError("'--config' not specified. Please provide a valid configuration file.")

# Read the config file and update it with overrides.
parser = ConfigParser(CONFIG_STRUCTURE, globals=False)
parser.read_config(config)
parser.update(kwargs)
return parser

surajpaib marked this conversation as resolved.
Show resolved Hide resolved
# Check that a config file is specified.
if "config_file" not in kwargs:
raise ValueError("--config_file not specified. Exiting.")

# Parse the config file(s).
parser = ConfigParser()
parser.read_config(kwargs.pop("config_file"))
parser.update(pairs=kwargs)
def validate_config(parser: ConfigParser) -> None:
"""
Validates the configuration parser against predefined structures and allowed method names.

# Import the project folder as a module, if specified.
project = parser.get("project", None)
if project is not None:
import_module_from_path("project", project)
This function checks if the keys in the top-level of the configuration parser are valid according to the
CONFIG_STRUCTURE. It also verifies that the 'args' section of the configuration only contains keys that
correspond to valid trainer method names as defined in TRAINER_METHOD_NAMES.

return parser
Args:
parser (ConfigParser): The configuration parser instance to validate.

Raises:
ValueError: If there are invalid keys in the top-level configuration.
ValueError: If there are invalid method names specified in the 'args' section.
"""
# Validate parser keys against structure
root_keys = parser.get().keys()
invalid_root_keys = set(root_keys) - set(CONFIG_STRUCTURE.keys()) - {"_meta_", "_requires_"}
if invalid_root_keys:
raise ValueError(f"Invalid top-level config keys: {invalid_root_keys}. Allowed keys: {CONFIG_STRUCTURE.keys()}")

surajpaib marked this conversation as resolved.
Show resolved Hide resolved
# Validate that 'args' contains only valid trainer method names.
args_keys = parser.get("args", {}).keys()
invalid_args_keys = set(args_keys) - set(TRAINER_METHOD_NAMES)
if invalid_args_keys:
raise ValueError(f"Invalid trainer method in 'args': {invalid_args_keys}. Allowed methods are: {TRAINER_METHOD_NAMES}")


def run_trainer_method(method: Dict, **kwargs: Any):
"""Call monai.bundle.run() on a Trainer method. If a project path
is defined in the config file(s), import it.
def run(method: str, **kwargs: Any):
"""Run the trainer method.

Args:
method (str): name of the Trainer method to run. ["fit", "validate", "test", "predict", "tune"].
**kwargs (Any): keyword arguments passed to the `monai.bundle.run` function.
method (str): name of the trainer method to run.
**kwargs (Any): keyword arguments that include 'config' and specific config overrides passed to `parse_config()`.
"""
# Sets the random seed to `PL_GLOBAL_SEED` env variable. If not specified, it picks a random seed.
seed_everything()

# Parse the config file(s).
# Parse and validate the config.
parser = parse_config(**kwargs)
validate_config(parser)

# Get trainer and system
trainer = parser.get_parsed_content("trainer")
# Import the project folder as a module, if specified.
project = parser.get_parsed_content("project")
if project is not None:
import_module_from_path("project", project)

# Get the main components from the parsed config.
system = parser.get_parsed_content("system")
trainer = parser.get_parsed_content("trainer")
trainer_method_args = parser.get_parsed_content(f"args#{method}", default={})

# Checks
if not isinstance(system, LighterSystem):
raise ValueError(f"Expected 'system' to be an instance of LighterSystem, got {system.__class__.__name__}.")
if not hasattr(trainer, method):
raise ValueError(f"{trainer.__class__.__name__} has no method named '{method}'.")
if any("dataloaders" in key or "datamodule" in key for key in trainer_method_args):
raise ValueError("All dataloaders should be defined as part of the LighterSystem, not passed as method arguments.")

# Save the config to checkpoints under "hyper_parameters" and log it if a logger is defined.
config = parser.get()
config.pop("_meta_")
# Save the config to model checkpoints under the "hyper_parameters" key.
config.pop("_meta_") # MONAI Bundle adds this automatically, remove it.
system.save_hyperparameters(config)
# Log the config.
if trainer.logger is not None:
trainer.logger.log_hyperparams(config)

# Run the Trainer method.
if not hasattr(trainer, method):
raise ValueError(f"Trainer has no method named {method}.")
getattr(trainer, method)(system)
# Run the trainer method.
getattr(trainer, method)(system, **trainer_method_args)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ requires = ["poetry_core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.scripts]
lighter = "lighter.utils.cli:interface"
lighter = "lighter.utils.runner:cli"

[tool.poetry]
name = "project-lighter"
Expand Down
10 changes: 5 additions & 5 deletions tests/integration/test_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import pytest

from lighter.utils.cli import run_trainer_method
from lighter.utils.runner import run

test_overrides = "./tests/integration/test_overrides.yaml"


@pytest.mark.parametrize(
("method_name", "config_file"),
("method_name", "config"),
[
( # Method name
"fit",
Expand All @@ -18,9 +18,9 @@
],
)
@pytest.mark.slow
def test_trainer_method(method_name: str, config_file: str):
def test_trainer_method(method_name: str, config: str):
""" """
kwargs = {"config_file": config_file, "args_file": test_overrides}
kwargs = {"config": [config, test_overrides]}

func_return = run_trainer_method(method_name, **kwargs)
func_return = run(method_name, **kwargs)
assert func_return is None
Loading