From 110ca842c3edd3e90901b8d5ec947955489255e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 20 Jul 2022 19:54:13 +0200 Subject: [PATCH 1/8] Promote the CLI out of utilities --- .../image_classifier_4_lightning_module.py | 2 +- ...image_classifier_5_lightning_datamodule.py | 2 +- examples/pl_basics/autoencoder.py | 2 +- .../pl_basics/backbone_image_classifier.py | 2 +- examples/pl_basics/profiler_example.py | 2 +- .../computer_vision_fine_tuning.py | 2 +- examples/pl_domain_templates/imagenet.py | 2 +- examples/pl_hpu/mnist_sample.py | 2 +- .../pl_integrations/dali_image_classifier.py | 2 +- examples/pl_servable_module/production.py | 2 +- src/pytorch_lightning/cli.py | 699 ++++++++++++++++++ src/pytorch_lightning/utilities/cli.py | 698 +---------------- tests/tests_app/core/scripts/lightning_cli.py | 4 +- tests/tests_app/core/scripts/registry.py | 2 +- .../tests_app/utilities/test_introspection.py | 2 +- .../deprecated_api/test_remove_1-9.py | 2 +- .../tests_pytorch/{utilities => }/test_cli.py | 61 +- 17 files changed, 749 insertions(+), 739 deletions(-) create mode 100644 src/pytorch_lightning/cli.py rename tests/tests_pytorch/{utilities => }/test_cli.py (97%) diff --git a/examples/convert_from_pt_to_pl/image_classifier_4_lightning_module.py b/examples/convert_from_pt_to_pl/image_classifier_4_lightning_module.py index ec7ff5edb2de6..6c2c84014a5f2 100644 --- a/examples/convert_from_pt_to_pl/image_classifier_4_lightning_module.py +++ b/examples/convert_from_pt_to_pl/image_classifier_4_lightning_module.py @@ -25,7 +25,7 @@ from pytorch_lightning import cli_lightning_logo, LightningModule from pytorch_lightning.demos.boring_classes import Net from pytorch_lightning.demos.mnist_datamodule import MNIST -from pytorch_lightning.utilities.cli import LightningCLI +from pytorch_lightning.cli import LightningCLI DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") diff --git a/examples/convert_from_pt_to_pl/image_classifier_5_lightning_datamodule.py b/examples/convert_from_pt_to_pl/image_classifier_5_lightning_datamodule.py index 3e1357e2dfbb5..10340611aa310 100644 --- a/examples/convert_from_pt_to_pl/image_classifier_5_lightning_datamodule.py +++ b/examples/convert_from_pt_to_pl/image_classifier_5_lightning_datamodule.py @@ -25,7 +25,7 @@ from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule from pytorch_lightning.demos.boring_classes import Net from pytorch_lightning.demos.mnist_datamodule import MNIST -from pytorch_lightning.utilities.cli import LightningCLI +from pytorch_lightning.cli import LightningCLI DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") diff --git a/examples/pl_basics/autoencoder.py b/examples/pl_basics/autoencoder.py index 6fcbeafa4b389..9b1e571f0a474 100644 --- a/examples/pl_basics/autoencoder.py +++ b/examples/pl_basics/autoencoder.py @@ -25,7 +25,7 @@ from pytorch_lightning import callbacks, cli_lightning_logo, LightningDataModule, LightningModule, Trainer from pytorch_lightning.demos.mnist_datamodule import MNIST -from pytorch_lightning.utilities.cli import LightningCLI +from pytorch_lightning.cli import LightningCLI from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from pytorch_lightning.utilities.rank_zero import rank_zero_only diff --git a/examples/pl_basics/backbone_image_classifier.py b/examples/pl_basics/backbone_image_classifier.py index be9a255cf990f..3e1854df85b54 100644 --- a/examples/pl_basics/backbone_image_classifier.py +++ 
b/examples/pl_basics/backbone_image_classifier.py @@ -24,7 +24,7 @@ from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule from pytorch_lightning.demos.mnist_datamodule import MNIST -from pytorch_lightning.utilities.cli import LightningCLI +from pytorch_lightning.cli import LightningCLI from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: diff --git a/examples/pl_basics/profiler_example.py b/examples/pl_basics/profiler_example.py index 050740e3ce314..0fc27aa60f652 100644 --- a/examples/pl_basics/profiler_example.py +++ b/examples/pl_basics/profiler_example.py @@ -32,7 +32,7 @@ from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule from pytorch_lightning.profilers.pytorch import PyTorchProfiler -from pytorch_lightning.utilities.cli import LightningCLI +from pytorch_lightning.cli import LightningCLI DEFAULT_CMD_LINE = ( "fit", diff --git a/examples/pl_domain_templates/computer_vision_fine_tuning.py b/examples/pl_domain_templates/computer_vision_fine_tuning.py index dc31d79ab0032..fedd837de0348 100644 --- a/examples/pl_domain_templates/computer_vision_fine_tuning.py +++ b/examples/pl_domain_templates/computer_vision_fine_tuning.py @@ -56,7 +56,7 @@ from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule from pytorch_lightning.callbacks.finetuning import BaseFinetuning -from pytorch_lightning.utilities.cli import LightningCLI +from pytorch_lightning.cli import LightningCLI from pytorch_lightning.utilities.rank_zero import rank_zero_info log = logging.getLogger(__name__) diff --git a/examples/pl_domain_templates/imagenet.py b/examples/pl_domain_templates/imagenet.py index 8e8c9bd0f0105..bfbfedb58f990 100644 --- a/examples/pl_domain_templates/imagenet.py +++ b/examples/pl_domain_templates/imagenet.py @@ -48,7 +48,7 @@ from pytorch_lightning import LightningModule from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar from pytorch_lightning.strategies import ParallelStrategy -from pytorch_lightning.utilities.cli import LightningCLI +from pytorch_lightning.cli import LightningCLI class ImageNetLightningModel(LightningModule): diff --git a/examples/pl_hpu/mnist_sample.py b/examples/pl_hpu/mnist_sample.py index de5d7c62ba1d1..9b7fae8f17d0c 100644 --- a/examples/pl_hpu/mnist_sample.py +++ b/examples/pl_hpu/mnist_sample.py @@ -18,7 +18,7 @@ from pytorch_lightning import LightningModule from pytorch_lightning.demos.mnist_datamodule import MNISTDataModule from pytorch_lightning.plugins import HPUPrecisionPlugin -from pytorch_lightning.utilities.cli import LightningCLI +from pytorch_lightning.cli import LightningCLI class LitClassifier(LightningModule): diff --git a/examples/pl_integrations/dali_image_classifier.py b/examples/pl_integrations/dali_image_classifier.py index 5d5bfc1fa9769..a9741b9a68fc4 100644 --- a/examples/pl_integrations/dali_image_classifier.py +++ b/examples/pl_integrations/dali_image_classifier.py @@ -24,7 +24,7 @@ from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule from pytorch_lightning.demos.mnist_datamodule import MNIST -from pytorch_lightning.utilities.cli import LightningCLI +from pytorch_lightning.cli import LightningCLI from pytorch_lightning.utilities.imports import _DALI_AVAILABLE, _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: diff --git a/examples/pl_servable_module/production.py b/examples/pl_servable_module/production.py index f1b148eec3f81..e738f02f04f5f 100644 --- 
a/examples/pl_servable_module/production.py +++ b/examples/pl_servable_module/production.py @@ -13,7 +13,7 @@ from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule from pytorch_lightning.serve import ServableModule, ServableModuleValidator -from pytorch_lightning.utilities.cli import LightningCLI +from pytorch_lightning.cli import LightningCLI DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") diff --git a/src/pytorch_lightning/cli.py b/src/pytorch_lightning/cli.py new file mode 100644 index 0000000000000..f07c311d1ffdf --- /dev/null +++ b/src/pytorch_lightning/cli.py @@ -0,0 +1,699 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from functools import partial, update_wrapper +from types import MethodType +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import torch +from torch.optim import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _RequirementAvailable +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn + +_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.0") + +if _JSONARGPARSE_SIGNATURES_AVAILABLE: + import docstring_parser + from jsonargparse import ( + ActionConfigFile, + ArgumentParser, + class_from_function, + Namespace, + register_unresolvable_import_paths, + set_config_read_mode, + ) + + register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483 + set_config_read_mode(fsspec_enabled=True) +else: + locals()["ArgumentParser"] = object + locals()["Namespace"] = object + + +class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): + def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None: + super().__init__(optimizer, *args, **kwargs) + self.monitor = monitor + + +# LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch: +LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau) +LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau] +LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]] + + +class LightningArgumentParser(ArgumentParser): + """Extension of jsonargparse's ArgumentParser for pytorch-lightning.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize argument parser that supports configuration file input. + + For full details of accepted arguments see `ArgumentParser.__init__ + `_. 
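+
+        A minimal sketch of standalone usage (assumes ``Trainer`` is added under the
+        ``trainer`` nested key, as ``LightningCLI`` does by default)::
+
+            parser = LightningArgumentParser()
+            parser.add_lightning_class_args(Trainer, "trainer")
+            cfg = parser.parse_args(["--trainer.max_epochs=3"])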
+ """ + if not _JSONARGPARSE_SIGNATURES_AVAILABLE: + raise ModuleNotFoundError( + f"{_JSONARGPARSE_SIGNATURES_AVAILABLE}. Try `pip install -U 'jsonargparse[signatures]'`." + ) + super().__init__(*args, **kwargs) + self.add_argument( + "-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format." + ) + self.callback_keys: List[str] = [] + # separate optimizers and lr schedulers to know which were added + self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} + self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} + + def add_lightning_class_args( + self, + lightning_class: Union[ + Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]], + Type[Trainer], + Type[LightningModule], + Type[LightningDataModule], + Type[Callback], + ], + nested_key: str, + subclass_mode: bool = False, + required: bool = True, + ) -> List[str]: + """Adds arguments from a lightning class to a nested key of the parser. + + Args: + lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}. + nested_key: Name of the nested namespace to store arguments. + subclass_mode: Whether allow any subclass of the given class. + required: Whether the argument group is required. + + Returns: + A list with the names of the class arguments added. + """ + if callable(lightning_class) and not isinstance(lightning_class, type): + lightning_class = class_from_function(lightning_class) + + if isinstance(lightning_class, type) and issubclass( + lightning_class, (Trainer, LightningModule, LightningDataModule, Callback) + ): + if issubclass(lightning_class, Callback): + self.callback_keys.append(nested_key) + if subclass_mode: + return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required) + return self.add_class_arguments( + lightning_class, + nested_key, + fail_untyped=False, + instantiate=not issubclass(lightning_class, Trainer), + sub_configs=True, + ) + raise MisconfigurationException( + f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: " + "Trainer, LightningModule, LightningDataModule, or Callback." + ) + + def add_optimizer_args( + self, + optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,), + nested_key: str = "optimizer", + link_to: str = "AUTOMATIC", + ) -> None: + """Adds arguments from an optimizer class to a nested key of the parser. + + Args: + optimizer_class: Any subclass of :class:`torch.optim.Optimizer`. Use tuple to allow subclasses. + nested_key: Name of the nested namespace to store arguments. + link_to: Dot notation of a parser key to set arguments or AUTOMATIC. 
+ """ + if isinstance(optimizer_class, tuple): + assert all(issubclass(o, Optimizer) for o in optimizer_class) + else: + assert issubclass(optimizer_class, Optimizer) + kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} + if isinstance(optimizer_class, tuple): + self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) + else: + self.add_class_arguments(optimizer_class, nested_key, sub_configs=True, **kwargs) + self._optimizers[nested_key] = (optimizer_class, link_to) + + def add_lr_scheduler_args( + self, + lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple, + nested_key: str = "lr_scheduler", + link_to: str = "AUTOMATIC", + ) -> None: + """Adds arguments from a learning rate scheduler class to a nested key of the parser. + + Args: + lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. Use + tuple to allow subclasses. + nested_key: Name of the nested namespace to store arguments. + link_to: Dot notation of a parser key to set arguments or AUTOMATIC. + """ + if isinstance(lr_scheduler_class, tuple): + assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) + else: + assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) + kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} + if isinstance(lr_scheduler_class, tuple): + self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) + else: + self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs) + self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) + + +class SaveConfigCallback(Callback): + """Saves a LightningCLI config to the log_dir when training starts. + + Args: + parser: The parser object used to parse the configuration. + config: The parsed configuration that will be saved. + config_filename: Filename for the config file. + overwrite: Whether to overwrite an existing config file. + multifile: When input is multiple config files, saved config preserves this structure. + + Raises: + RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run + """ + + def __init__( + self, + parser: LightningArgumentParser, + config: Namespace, + config_filename: str, + overwrite: bool = False, + multifile: bool = False, + ) -> None: + self.parser = parser + self.config = config + self.config_filename = config_filename + self.overwrite = overwrite + self.multifile = multifile + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: + log_dir = trainer.log_dir # this broadcasts the directory + assert log_dir is not None + config_path = os.path.join(log_dir, self.config_filename) + fs = get_filesystem(log_dir) + + if not self.overwrite: + # check if the file exists on rank 0 + file_exists = fs.isfile(config_path) if trainer.is_global_zero else False + # broadcast whether to fail to all ranks + file_exists = trainer.strategy.broadcast(file_exists) + if file_exists: + raise RuntimeError( + f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting" + " results of a previous run. You can delete the previous config file," + " set `LightningCLI(save_config_callback=None)` to disable config saving," + " or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file." + ) + + # save the file on rank 0 + if trainer.is_global_zero: + # save only on rank zero to avoid race conditions. 
+            # the `log_dir` needs to be created as we rely on the logger to do it usually
+            # but it hasn't logged anything at this point
+            fs.makedirs(log_dir, exist_ok=True)
+            self.parser.save(
+                self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
+            )
+
+
+class LightningCLI:
+    """Implementation of a configurable command line tool for pytorch-lightning."""
+
+    def __init__(
+        self,
+        model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None,
+        datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None,
+        save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
+        save_config_filename: str = "config.yaml",
+        save_config_overwrite: bool = False,
+        save_config_multifile: bool = False,
+        trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
+        trainer_defaults: Optional[Dict[str, Any]] = None,
+        seed_everything_default: Union[bool, int] = True,
+        description: str = "pytorch-lightning trainer command line tool",
+        env_prefix: str = "PL",
+        env_parse: bool = False,
+        parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None,
+        subclass_mode_model: bool = False,
+        subclass_mode_data: bool = False,
+        run: bool = True,
+        auto_registry: bool = False,
+    ) -> None:
+        """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which
+        are called / instantiated using a parsed configuration file and / or command line args.
+
+        Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``.
+        A full configuration yaml is parsed from ``PL_CONFIG`` if it is set.
+        Individual settings can likewise be parsed from variables such as ``PL_TRAINER__MAX_EPOCHS``.
+
+        For more info, read :ref:`the CLI docs `.
+
+        .. warning:: ``LightningCLI`` is in beta and subject to change.
+
+        Args:
+            model_class: An optional :class:`~pytorch_lightning.core.module.LightningModule` class to train on or a
+                callable which returns a :class:`~pytorch_lightning.core.module.LightningModule` instance when
+                called. If ``None``, you can pass a registered model with ``--model=MyModel``.
+            datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
+                callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
+                called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
+            save_config_callback: A callback class to save the training config.
+            save_config_filename: Filename for the config file.
+            save_config_overwrite: Whether to overwrite an existing config file.
+            save_config_multifile: When input is multiple config files, saved config preserves this structure.
+            trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
+                callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
+            trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through
+                this argument will not be configurable from a configuration file and will always be present for
+                this particular CLI. Alternatively, configurable callbacks can be added as explained in
+                :ref:`the CLI docs `.
+            seed_everything_default: Value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
+                seed argument. Set to True to automatically choose a valid seed.
+ Setting it to False will not call seed_everything. + description: Description of the tool shown when running ``--help``. + env_prefix: Prefix for environment variables. + env_parse: Whether environment variable parsing is enabled. + parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``. + subclass_mode_model: Whether model can be any `subclass + `_ + of the given class. + subclass_mode_data: Whether datamodule can be any `subclass + `_ + of the given class. + run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer` + method. If set to ``False``, the trainer and model classes will be instantiated only. + auto_registry: Whether to automatically fill up the registries with all defined subclasses. + """ + self.save_config_callback = save_config_callback + self.save_config_filename = save_config_filename + self.save_config_overwrite = save_config_overwrite + self.save_config_multifile = save_config_multifile + self.trainer_class = trainer_class + self.trainer_defaults = trainer_defaults or {} + self.seed_everything_default = seed_everything_default + + if self.seed_everything_default is None: + rank_zero_deprecation( + "Setting `LightningCLI.seed_everything_default` to `None` is deprecated in v1.7 " + "and will be removed in v1.9. Set it to `False` instead." + ) + self.seed_everything_default = False + + self.model_class = model_class + # used to differentiate between the original value and the processed value + self._model_class = model_class or LightningModule + self.subclass_mode_model = (model_class is None) or subclass_mode_model + + self.datamodule_class = datamodule_class + # used to differentiate between the original value and the processed value + self._datamodule_class = datamodule_class or LightningDataModule + self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data + + from pytorch_lightning.utilities.cli import _populate_registries + _populate_registries(auto_registry) + + main_kwargs, subparser_kwargs = self._setup_parser_kwargs( + parser_kwargs or {}, # type: ignore # github.com/python/mypy/issues/6463 + {"description": description, "env_prefix": env_prefix, "default_env": env_parse}, + ) + self.setup_parser(run, main_kwargs, subparser_kwargs) + self.parse_arguments(self.parser) + + self.subcommand = self.config["subcommand"] if run else None + + self._set_seed() + + self.before_instantiate_classes() + self.instantiate_classes() + + if self.subcommand is not None: + self._run_subcommand(self.subcommand) + + def _setup_parser_kwargs( + self, kwargs: Dict[str, Any], defaults: Dict[str, Any] + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + if kwargs.keys() & self.subcommands().keys(): + # `kwargs` contains arguments per subcommand + return defaults, kwargs + main_kwargs = defaults + main_kwargs.update(kwargs) + return main_kwargs, {} + + def init_parser(self, **kwargs: Any) -> LightningArgumentParser: + """Method that instantiates the argument parser.""" + kwargs.setdefault("dump_header", [f"pytorch_lightning=={pl.__version__}"]) + return LightningArgumentParser(**kwargs) + + def setup_parser( + self, add_subcommands: bool, main_kwargs: Dict[str, Any], subparser_kwargs: Dict[str, Any] + ) -> None: + """Initialize and setup the parser, subcommands, and arguments.""" + self.parser = self.init_parser(**main_kwargs) + if add_subcommands: + self._subcommand_method_arguments: Dict[str, List[str]] = {} + self._add_subcommands(self.parser, **subparser_kwargs) + else: + 
self._add_arguments(self.parser)
+
+    def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        """Adds default arguments to the parser."""
+        parser.add_argument(
+            "--seed_everything",
+            type=Union[bool, int],
+            default=self.seed_everything_default,
+            help=(
+                "Set to an int to run seed_everything with this value before class instantiation. "
+                "Set to True to use a random seed."
+            ),
+        )
+
+    def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        """Adds arguments from the core classes to the parser."""
+        parser.add_lightning_class_args(self.trainer_class, "trainer")
+        trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
+        parser.set_defaults(trainer_defaults)
+
+        parser.add_lightning_class_args(self._model_class, "model", subclass_mode=self.subclass_mode_model)
+
+        if self.datamodule_class is not None:
+            parser.add_lightning_class_args(self._datamodule_class, "data", subclass_mode=self.subclass_mode_data)
+        else:
+            # this should not be required because the user might want to use the `LightningModule` dataloaders
+            parser.add_lightning_class_args(
+                self._datamodule_class, "data", subclass_mode=self.subclass_mode_data, required=False
+            )
+
+    def _add_arguments(self, parser: LightningArgumentParser) -> None:
+        # default + core + custom arguments
+        self.add_default_arguments_to_parser(parser)
+        self.add_core_arguments_to_parser(parser)
+        self.add_arguments_to_parser(parser)
+        # add default optimizer args if necessary
+        if not parser._optimizers:  # skip if already added by the user in `add_arguments_to_parser`
+            parser.add_optimizer_args((Optimizer,))
+        if not parser._lr_schedulers:  # skip if already added by the user in `add_arguments_to_parser`
+            parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
+        self.link_optimizers_and_lr_schedulers(parser)
+
+    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        """Implement to add extra arguments to the parser or link arguments.
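+
+        A sketch of a typical override (the ``data.batch_size`` and ``model.batch_size`` keys are
+        hypothetical and depend on the actual class signatures)::
+
+            class MyLightningCLI(LightningCLI):
+                def add_arguments_to_parser(self, parser):
+                    parser.link_arguments("data.batch_size", "model.batch_size")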
+
+        Args:
+            parser: The parser object to which arguments can be added.
+        """
+
+    @staticmethod
+    def subcommands() -> Dict[str, Set[str]]:
+        """Defines the list of available subcommands and the arguments to skip."""
+        return {
+            "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
+            "validate": {"model", "dataloaders", "datamodule"},
+            "test": {"model", "dataloaders", "datamodule"},
+            "predict": {"model", "dataloaders", "datamodule"},
+            "tune": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
+        }
+
+    def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None:
+        """Adds subcommands to the input parser."""
+        parser_subcommands = parser.add_subcommands()
+        # the user might have passed a builder function
+        trainer_class = (
+            self.trainer_class if isinstance(self.trainer_class, type) else class_from_function(self.trainer_class)
+        )
+        # register all subcommands in separate subcommand parsers under the main parser
+        for subcommand in self.subcommands():
+            subcommand_parser = self._prepare_subcommand_parser(trainer_class, subcommand, **kwargs.get(subcommand, {}))
+            fn = getattr(trainer_class, subcommand)
+            # extract the first line description in the docstring for the subcommand help message
+            description = _get_short_description(fn)
+            parser_subcommands.add_subcommand(subcommand, subcommand_parser, help=description)
+
+    def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any) -> LightningArgumentParser:
+        parser = self.init_parser(**kwargs)
+        self._add_arguments(parser)
+        # subcommand arguments
+        skip = self.subcommands()[subcommand]
+        added = parser.add_method_arguments(klass, subcommand, skip=skip)
+        # need to save which arguments were added to pass them to the method later
+        self._subcommand_method_arguments[subcommand] = added
+        return parser
+
+    @staticmethod
+    def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None:
+        """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``."""
+        optimizers_and_lr_schedulers = {**parser._optimizers, **parser._lr_schedulers}
+        for key, (class_type, link_to) in optimizers_and_lr_schedulers.items():
+            if link_to == "AUTOMATIC":
+                continue
+            if isinstance(class_type, tuple):
+                parser.link_arguments(key, link_to)
+            else:
+                add_class_path = _add_class_path_generator(class_type)
+                parser.link_arguments(key, link_to, compute_fn=add_class_path)
+
+    def parse_arguments(self, parser: LightningArgumentParser) -> None:
+        """Parses command line arguments and stores them in ``self.config``."""
+        self.config = parser.parse_args()
+
+    def before_instantiate_classes(self) -> None:
+        """Implement to run some code before instantiating the classes."""
+
+    def instantiate_classes(self) -> None:
+        """Instantiates the classes and sets their attributes."""
+        self.config_init = self.parser.instantiate_classes(self.config)
+        self.datamodule = self._get(self.config_init, "data")
+        self.model = self._get(self.config_init, "model")
+        self._add_configure_optimizers_method_to_model(self.subcommand)
+        self.trainer = self.instantiate_trainer()
+
+    def instantiate_trainer(self, **kwargs: Any) -> Trainer:
+        """Instantiates the trainer.
+
+        Args:
+            kwargs: Any custom trainer arguments.
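+
+        A sketch of manual use (assumes the CLI was created with ``run=False``; ``MyModel`` is a
+        hypothetical ``LightningModule``)::
+
+            cli = LightningCLI(MyModel, run=False)
+            trainer = cli.instantiate_trainer(logger=False)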
+ """ + extra_callbacks = [self._get(self.config_init, c) for c in self._parser(self.subcommand).callback_keys] + trainer_config = {**self._get(self.config_init, "trainer", default={}), **kwargs} + return self._instantiate_trainer(trainer_config, extra_callbacks) + + def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer: + key = "callbacks" + if key in config: + if config[key] is None: + config[key] = [] + elif not isinstance(config[key], list): + config[key] = [config[key]] + config[key].extend(callbacks) + if key in self.trainer_defaults: + value = self.trainer_defaults[key] + config[key] += value if isinstance(value, list) else [value] + if self.save_config_callback and not config.get("fast_dev_run", False): + config_callback = self.save_config_callback( + self._parser(self.subcommand), + self.config.get(str(self.subcommand), self.config), + self.save_config_filename, + overwrite=self.save_config_overwrite, + multifile=self.save_config_multifile, + ) + config[key].append(config_callback) + else: + rank_zero_warn( + f"The `{self.trainer_class.__qualname__}` class does not expose the `{key}` argument so they will" + " not be included." + ) + return self.trainer_class(**config) + + def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser: + if subcommand is None: + return self.parser + # return the subcommand parser for the subcommand passed + action_subcommand = self.parser._subcommands_action + return action_subcommand._name_parser_map[subcommand] + + @staticmethod + def configure_optimizers( + lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None + ) -> Any: + """Override to customize the :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` + method. + + Args: + lightning_module: A reference to the model. + optimizer: The optimizer. + lr_scheduler: The learning rate scheduler (if used). + """ + if lr_scheduler is None: + return optimizer + if isinstance(lr_scheduler, ReduceLROnPlateau): + return { + "optimizer": optimizer, + "lr_scheduler": {"scheduler": lr_scheduler, "monitor": lr_scheduler.monitor}, + } + return [optimizer], [lr_scheduler] + + def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None: + """Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` method + if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'.""" + parser = self._parser(subcommand) + + def get_automatic( + class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] + ) -> List[str]: + automatic = [] + for key, (base_class, link_to) in register.items(): + if not isinstance(base_class, tuple): + base_class = (base_class,) + if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class): + automatic.append(key) + return automatic + + optimizers = get_automatic(Optimizer, parser._optimizers) + lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers) + + if len(optimizers) == 0: + return + + if len(optimizers) > 1 or len(lr_schedulers) > 1: + raise MisconfigurationException( + f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer " + f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. 
In this case the user " + "is expected to link the argument groups and implement `configure_optimizers`, see " + "https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html" + "#optimizers-and-learning-rate-schedulers" + ) + + optimizer_class = parser._optimizers[optimizers[0]][0] + optimizer_init = self._get(self.config_init, optimizers[0]) + if not isinstance(optimizer_class, tuple): + optimizer_init = _global_add_class_path(optimizer_class, optimizer_init) + if not optimizer_init: + # optimizers were registered automatically but not passed by the user + return + + lr_scheduler_init = None + if lr_schedulers: + lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0] + lr_scheduler_init = self._get(self.config_init, lr_schedulers[0]) + if not isinstance(lr_scheduler_class, tuple): + lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init) + + if is_overridden("configure_optimizers", self.model): + _warn( + f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by " + f"`{self.__class__.__name__}.configure_optimizers`." + ) + + optimizer = instantiate_class(self.model.parameters(), optimizer_init) + lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None + fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler) + update_wrapper(fn, self.configure_optimizers) # necessary for `is_overridden` + # override the existing method + self.model.configure_optimizers = MethodType(fn, self.model) + + def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any: + """Utility to get a config value which might be inside a subcommand.""" + return config.get(str(self.subcommand), config).get(key, default) + + def _run_subcommand(self, subcommand: str) -> None: + """Run the chosen subcommand.""" + before_fn = getattr(self, f"before_{subcommand}", None) + if callable(before_fn): + before_fn() + + default = getattr(self.trainer, subcommand) + fn = getattr(self, subcommand, default) + fn_kwargs = self._prepare_subcommand_kwargs(subcommand) + fn(**fn_kwargs) + + after_fn = getattr(self, f"after_{subcommand}", None) + if callable(after_fn): + after_fn() + + def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]: + """Prepares the keyword arguments to pass to the subcommand to run.""" + fn_kwargs = { + k: v for k, v in self.config_init[subcommand].items() if k in self._subcommand_method_arguments[subcommand] + } + fn_kwargs["model"] = self.model + if self.datamodule is not None: + fn_kwargs["datamodule"] = self.datamodule + return fn_kwargs + + def _set_seed(self) -> None: + """Sets the seed.""" + config_seed = self._get(self.config, "seed_everything") + if config_seed is False: + return + if config_seed is True: + # user requested seeding, choose randomly + config_seed = seed_everything(workers=True) + else: + config_seed = seed_everything(config_seed, workers=True) + self.config["seed_everything"] = config_seed + + +def _class_path_from_class(class_type: Type) -> str: + return class_type.__module__ + "." 
+ class_type.__name__ + + +def _global_add_class_path( + class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None +) -> Dict[str, Any]: + if isinstance(init_args, Namespace): + init_args = init_args.as_dict() + return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}} + + +def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]: + def add_class_path(init_args: Namespace) -> Dict[str, Any]: + return _global_add_class_path(class_type, init_args) + + return add_class_path + + +def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: + """Instantiates a class with the given args and init. + + Args: + args: Positional arguments required for instantiation. + init: Dict of the form {"class_path":...,"init_args":...}. + + Returns: + The instantiated class object. + """ + kwargs = init.get("init_args", {}) + if not isinstance(args, tuple): + args = (args,) + class_module, class_name = init["class_path"].rsplit(".", 1) + module = __import__(class_module, fromlist=[class_name]) + args_class = getattr(module, class_name) + return args_class(*args, **kwargs) + + +def _get_short_description(component: object) -> Optional[str]: + if component.__doc__ is None: + return None + try: + docstring = docstring_parser.parse(component.__doc__) + return docstring.short_description + except (ValueError, docstring_parser.ParseError) as ex: + rank_zero_warn(f"Failed parsing docstring for {component}: {ex}") diff --git a/src/pytorch_lightning/utilities/cli.py b/src/pytorch_lightning/utilities/cli.py index f9d3375a6c6d8..5d0a1a2e106fb 100644 --- a/src/pytorch_lightning/utilities/cli.py +++ b/src/pytorch_lightning/utilities/cli.py @@ -11,44 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Utilities for LightningCLI.""" +"""Deprecated utilities for LightningCLI.""" import inspect -import os -from functools import partial, update_wrapper -from types import MethodType, ModuleType -from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union +from types import ModuleType +from typing import Any, Generator, List, Optional, Tuple, Type import torch from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer -from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _RequirementAvailable +import pytorch_lightning.cli as new_cli from pytorch_lightning.utilities.meta import get_all_subclasses -from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn - -_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.0") - -if _JSONARGPARSE_SIGNATURES_AVAILABLE: - import docstring_parser - from jsonargparse import ( - ActionConfigFile, - ArgumentParser, - class_from_function, - Namespace, - register_unresolvable_import_paths, - set_config_read_mode, - ) - - register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483 - set_config_read_mode(fsspec_enabled=True) -else: - locals()["ArgumentParser"] = object - locals()["Namespace"] = object +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation _deprecate_registry_message = ( @@ -130,17 +105,6 @@ def _deprecation(self, show_deprecation: bool = True) -> None: LOGGER_REGISTRY = _Registry() -class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): - def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None: - super().__init__(optimizer, *args, **kwargs) - self.monitor = monitor - - -# LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch: -LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau) -LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau] -LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]] - def _populate_registries(subclasses: bool) -> None: # Remove in v1.9 if subclasses: @@ -167,643 +131,35 @@ def _populate_registries(subclasses: bool) -> None: # Remove in v1.9 CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback, show_deprecation=False) LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.Logger, show_deprecation=False) # `ReduceLROnPlateau` does not subclass `_LRScheduler` - LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau, show_deprecation=False) + LR_SCHEDULER_REGISTRY(cls=new_cli.ReduceLROnPlateau, show_deprecation=False) -class LightningArgumentParser(ArgumentParser): - """Extension of jsonargparse's ArgumentParser for pytorch-lightning.""" +def _deprecation(cls: Type) -> None: + rank_zero_deprecation( + f"`{cls.__qualname__}` has been deprecated in v1.7 and will be removed in v1.9." + f" Use the equivalent class in `pytorch_lightning.cli.{cls.__name__}` instead." + ) +class LightningArgumentParser(new_cli.LightningArgumentParser): def __init__(self, *args: Any, **kwargs: Any) -> None: - """Initialize argument parser that supports configuration file input. 
- - For full details of accepted arguments see `ArgumentParser.__init__ - `_. - """ - if not _JSONARGPARSE_SIGNATURES_AVAILABLE: - raise ModuleNotFoundError( - f"{_JSONARGPARSE_SIGNATURES_AVAILABLE}. Try `pip install -U 'jsonargparse[signatures]'`." - ) + _deprecation(type(self)) super().__init__(*args, **kwargs) - self.add_argument( - "-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format." - ) - self.callback_keys: List[str] = [] - # separate optimizers and lr schedulers to know which were added - self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} - self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} - - def add_lightning_class_args( - self, - lightning_class: Union[ - Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]], - Type[Trainer], - Type[LightningModule], - Type[LightningDataModule], - Type[Callback], - ], - nested_key: str, - subclass_mode: bool = False, - required: bool = True, - ) -> List[str]: - """Adds arguments from a lightning class to a nested key of the parser. - - Args: - lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}. - nested_key: Name of the nested namespace to store arguments. - subclass_mode: Whether allow any subclass of the given class. - required: Whether the argument group is required. - - Returns: - A list with the names of the class arguments added. - """ - if callable(lightning_class) and not isinstance(lightning_class, type): - lightning_class = class_from_function(lightning_class) - - if isinstance(lightning_class, type) and issubclass( - lightning_class, (Trainer, LightningModule, LightningDataModule, Callback) - ): - if issubclass(lightning_class, Callback): - self.callback_keys.append(nested_key) - if subclass_mode: - return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required) - return self.add_class_arguments( - lightning_class, - nested_key, - fail_untyped=False, - instantiate=not issubclass(lightning_class, Trainer), - sub_configs=True, - ) - raise MisconfigurationException( - f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: " - "Trainer, LightningModule, LightningDataModule, or Callback." - ) - - def add_optimizer_args( - self, - optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,), - nested_key: str = "optimizer", - link_to: str = "AUTOMATIC", - ) -> None: - """Adds arguments from an optimizer class to a nested key of the parser. - - Args: - optimizer_class: Any subclass of :class:`torch.optim.Optimizer`. Use tuple to allow subclasses. - nested_key: Name of the nested namespace to store arguments. - link_to: Dot notation of a parser key to set arguments or AUTOMATIC. 
- """ - if isinstance(optimizer_class, tuple): - assert all(issubclass(o, Optimizer) for o in optimizer_class) - else: - assert issubclass(optimizer_class, Optimizer) - kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} - if isinstance(optimizer_class, tuple): - self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) - else: - self.add_class_arguments(optimizer_class, nested_key, sub_configs=True, **kwargs) - self._optimizers[nested_key] = (optimizer_class, link_to) - - def add_lr_scheduler_args( - self, - lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple, - nested_key: str = "lr_scheduler", - link_to: str = "AUTOMATIC", - ) -> None: - """Adds arguments from a learning rate scheduler class to a nested key of the parser. - - Args: - lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. Use - tuple to allow subclasses. - nested_key: Name of the nested namespace to store arguments. - link_to: Dot notation of a parser key to set arguments or AUTOMATIC. - """ - if isinstance(lr_scheduler_class, tuple): - assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) - else: - assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) - kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} - if isinstance(lr_scheduler_class, tuple): - self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) - else: - self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs) - self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) - - -class SaveConfigCallback(Callback): - """Saves a LightningCLI config to the log_dir when training starts. - - Args: - parser: The parser object used to parse the configuration. - config: The parsed configuration that will be saved. - config_filename: Filename for the config file. - overwrite: Whether to overwrite an existing config file. - multifile: When input is multiple config files, saved config preserves this structure. - - Raises: - RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run - """ - - def __init__( - self, - parser: LightningArgumentParser, - config: Namespace, - config_filename: str, - overwrite: bool = False, - multifile: bool = False, - ) -> None: - self.parser = parser - self.config = config - self.config_filename = config_filename - self.overwrite = overwrite - self.multifile = multifile - - def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: - log_dir = trainer.log_dir # this broadcasts the directory - assert log_dir is not None - config_path = os.path.join(log_dir, self.config_filename) - fs = get_filesystem(log_dir) - - if not self.overwrite: - # check if the file exists on rank 0 - file_exists = fs.isfile(config_path) if trainer.is_global_zero else False - # broadcast whether to fail to all ranks - file_exists = trainer.strategy.broadcast(file_exists) - if file_exists: - raise RuntimeError( - f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting" - " results of a previous run. You can delete the previous config file," - " set `LightningCLI(save_config_callback=None)` to disable config saving," - " or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file." - ) - - # save the file on rank 0 - if trainer.is_global_zero: - # save only on rank zero to avoid race conditions. 
- # the `log_dir` needs to be created as we rely on the logger to do it usually - # but it hasn't logged anything at this point - fs.makedirs(log_dir, exist_ok=True) - self.parser.save( - self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile - ) - - -class LightningCLI: - """Implementation of a configurable command line tool for pytorch-lightning.""" - - def __init__( - self, - model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None, - datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None, - save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback, - save_config_filename: str = "config.yaml", - save_config_overwrite: bool = False, - save_config_multifile: bool = False, - trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer, - trainer_defaults: Optional[Dict[str, Any]] = None, - seed_everything_default: Union[bool, int] = True, - description: str = "pytorch-lightning trainer command line tool", - env_prefix: str = "PL", - env_parse: bool = False, - parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None, - subclass_mode_model: bool = False, - subclass_mode_data: bool = False, - run: bool = True, - auto_registry: bool = False, - ) -> None: - """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which - are called / instantiated using a parsed configuration file and / or command line args. - - Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``. - A full configuration yaml would be parsed from ``PL_CONFIG`` if set. - Individual settings are so parsed from variables named for example ``PL_TRAINER__MAX_EPOCHS``. - - For more info, read :ref:`the CLI docs `. - - .. warning:: ``LightningCLI`` is in beta and subject to change. - - Args: - model_class: An optional :class:`~pytorch_lightning.core.module.LightningModule` class to train on or a - callable which returns a :class:`~pytorch_lightning.core.module.LightningModule` instance when - called. If ``None``, you can pass a registered model with ``--model=MyModel``. - datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a - callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when - called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``. - save_config_callback: A callback class to save the training config. - save_config_filename: Filename for the config file. - save_config_overwrite: Whether to overwrite an existing config file. - save_config_multifile: When input is multiple config files, saved config preserves this structure. - trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a - callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called. - trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through - this argument will not be configurable from a configuration file and will always be present for - this particular CLI. Alternatively, configurable callbacks can be added as explained in - :ref:`the CLI docs `. - seed_everything_default: Value for the :func:`~pytorch_lightning.utilities.seed.seed_everything` - seed argument. Set to True to automatically choose a valid seed. 
- Setting it to False will not call seed_everything. - description: Description of the tool shown when running ``--help``. - env_prefix: Prefix for environment variables. - env_parse: Whether environment variable parsing is enabled. - parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``. - subclass_mode_model: Whether model can be any `subclass - `_ - of the given class. - subclass_mode_data: Whether datamodule can be any `subclass - `_ - of the given class. - run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer` - method. If set to ``False``, the trainer and model classes will be instantiated only. - auto_registry: Whether to automatically fill up the registries with all defined subclasses. - """ - self.save_config_callback = save_config_callback - self.save_config_filename = save_config_filename - self.save_config_overwrite = save_config_overwrite - self.save_config_multifile = save_config_multifile - self.trainer_class = trainer_class - self.trainer_defaults = trainer_defaults or {} - self.seed_everything_default = seed_everything_default - - if self.seed_everything_default is None: - rank_zero_deprecation( - "Setting `LightningCLI.seed_everything_default` to `None` is deprecated in v1.7 " - "and will be removed in v1.9. Set it to `False` instead." - ) - self.seed_everything_default = False - - self.model_class = model_class - # used to differentiate between the original value and the processed value - self._model_class = model_class or LightningModule - self.subclass_mode_model = (model_class is None) or subclass_mode_model - - self.datamodule_class = datamodule_class - # used to differentiate between the original value and the processed value - self._datamodule_class = datamodule_class or LightningDataModule - self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data - - _populate_registries(auto_registry) - - main_kwargs, subparser_kwargs = self._setup_parser_kwargs( - parser_kwargs or {}, # type: ignore # github.com/python/mypy/issues/6463 - {"description": description, "env_prefix": env_prefix, "default_env": env_parse}, - ) - self.setup_parser(run, main_kwargs, subparser_kwargs) - self.parse_arguments(self.parser) - - self.subcommand = self.config["subcommand"] if run else None - - self._set_seed() - - self.before_instantiate_classes() - self.instantiate_classes() - - if self.subcommand is not None: - self._run_subcommand(self.subcommand) - - def _setup_parser_kwargs( - self, kwargs: Dict[str, Any], defaults: Dict[str, Any] - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - if kwargs.keys() & self.subcommands().keys(): - # `kwargs` contains arguments per subcommand - return defaults, kwargs - main_kwargs = defaults - main_kwargs.update(kwargs) - return main_kwargs, {} - - def init_parser(self, **kwargs: Any) -> LightningArgumentParser: - """Method that instantiates the argument parser.""" - kwargs.setdefault("dump_header", [f"pytorch_lightning=={pl.__version__}"]) - return LightningArgumentParser(**kwargs) - - def setup_parser( - self, add_subcommands: bool, main_kwargs: Dict[str, Any], subparser_kwargs: Dict[str, Any] - ) -> None: - """Initialize and setup the parser, subcommands, and arguments.""" - self.parser = self.init_parser(**main_kwargs) - if add_subcommands: - self._subcommand_method_arguments: Dict[str, List[str]] = {} - self._add_subcommands(self.parser, **subparser_kwargs) - else: - self._add_arguments(self.parser) - - def add_default_arguments_to_parser(self, parser: 
LightningArgumentParser) -> None: - """Adds default arguments to the parser.""" - parser.add_argument( - "--seed_everything", - type=Union[bool, int], - default=self.seed_everything_default, - help=( - "Set to an int to run seed_everything with this value before classes instantiation." - "Set to True to use a random seed." - ), - ) - - def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None: - """Adds arguments from the core classes to the parser.""" - parser.add_lightning_class_args(self.trainer_class, "trainer") - trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"} - parser.set_defaults(trainer_defaults) - - parser.add_lightning_class_args(self._model_class, "model", subclass_mode=self.subclass_mode_model) - - if self.datamodule_class is not None: - parser.add_lightning_class_args(self._datamodule_class, "data", subclass_mode=self.subclass_mode_data) - else: - # this should not be required because the user might want to use the `LightningModule` dataloaders - parser.add_lightning_class_args( - self._datamodule_class, "data", subclass_mode=self.subclass_mode_data, required=False - ) - - def _add_arguments(self, parser: LightningArgumentParser) -> None: - # default + core + custom arguments - self.add_default_arguments_to_parser(parser) - self.add_core_arguments_to_parser(parser) - self.add_arguments_to_parser(parser) - # add default optimizer args if necessary - if not parser._optimizers: # already added by the user in `add_arguments_to_parser` - parser.add_optimizer_args((Optimizer,)) - if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser` - parser.add_lr_scheduler_args(LRSchedulerTypeTuple) - self.link_optimizers_and_lr_schedulers(parser) - - def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: - """Implement to add extra arguments to the parser or link arguments. 
- - Args: - parser: The parser object to which arguments can be added - """ - - @staticmethod - def subcommands() -> Dict[str, Set[str]]: - """Defines the list of available subcommands and the arguments to skip.""" - return { - "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"}, - "validate": {"model", "dataloaders", "datamodule"}, - "test": {"model", "dataloaders", "datamodule"}, - "predict": {"model", "dataloaders", "datamodule"}, - "tune": {"model", "train_dataloaders", "val_dataloaders", "datamodule"}, - } - - def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None: - """Adds subcommands to the input parser.""" - parser_subcommands = parser.add_subcommands() - # the user might have passed a builder function - trainer_class = ( - self.trainer_class if isinstance(self.trainer_class, type) else class_from_function(self.trainer_class) - ) - # register all subcommands in separate subcommand parsers under the main parser - for subcommand in self.subcommands(): - subcommand_parser = self._prepare_subcommand_parser(trainer_class, subcommand, **kwargs.get(subcommand, {})) - fn = getattr(trainer_class, subcommand) - # extract the first line description in the docstring for the subcommand help message - description = _get_short_description(fn) - parser_subcommands.add_subcommand(subcommand, subcommand_parser, help=description) - - def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any) -> LightningArgumentParser: - parser = self.init_parser(**kwargs) - self._add_arguments(parser) - # subcommand arguments - skip = self.subcommands()[subcommand] - added = parser.add_method_arguments(klass, subcommand, skip=skip) - # need to save which arguments were added to pass them to the method later - self._subcommand_method_arguments[subcommand] = added - return parser - - @staticmethod - def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None: - """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``.""" - optimizers_and_lr_schedulers = {**parser._optimizers, **parser._lr_schedulers} - for key, (class_type, link_to) in optimizers_and_lr_schedulers.items(): - if link_to == "AUTOMATIC": - continue - if isinstance(class_type, tuple): - parser.link_arguments(key, link_to) - else: - add_class_path = _add_class_path_generator(class_type) - parser.link_arguments(key, link_to, compute_fn=add_class_path) - - def parse_arguments(self, parser: LightningArgumentParser) -> None: - """Parses command line arguments and stores it in ``self.config``.""" - self.config = parser.parse_args() - - def before_instantiate_classes(self) -> None: - """Implement to run some code before instantiating the classes.""" - - def instantiate_classes(self) -> None: - """Instantiates the classes and sets their attributes.""" - self.config_init = self.parser.instantiate_classes(self.config) - self.datamodule = self._get(self.config_init, "data") - self.model = self._get(self.config_init, "model") - self._add_configure_optimizers_method_to_model(self.subcommand) - self.trainer = self.instantiate_trainer() - - def instantiate_trainer(self, **kwargs: Any) -> Trainer: - """Instantiates the trainer. - - Args: - kwargs: Any custom trainer arguments. 
- """ - extra_callbacks = [self._get(self.config_init, c) for c in self._parser(self.subcommand).callback_keys] - trainer_config = {**self._get(self.config_init, "trainer", default={}), **kwargs} - return self._instantiate_trainer(trainer_config, extra_callbacks) - - def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer: - key = "callbacks" - if key in config: - if config[key] is None: - config[key] = [] - elif not isinstance(config[key], list): - config[key] = [config[key]] - config[key].extend(callbacks) - if key in self.trainer_defaults: - value = self.trainer_defaults[key] - config[key] += value if isinstance(value, list) else [value] - if self.save_config_callback and not config.get("fast_dev_run", False): - config_callback = self.save_config_callback( - self._parser(self.subcommand), - self.config.get(str(self.subcommand), self.config), - self.save_config_filename, - overwrite=self.save_config_overwrite, - multifile=self.save_config_multifile, - ) - config[key].append(config_callback) - else: - rank_zero_warn( - f"The `{self.trainer_class.__qualname__}` class does not expose the `{key}` argument so they will" - " not be included." - ) - return self.trainer_class(**config) - - def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser: - if subcommand is None: - return self.parser - # return the subcommand parser for the subcommand passed - action_subcommand = self.parser._subcommands_action - return action_subcommand._name_parser_map[subcommand] - - @staticmethod - def configure_optimizers( - lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None - ) -> Any: - """Override to customize the :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` - method. - - Args: - lightning_module: A reference to the model. - optimizer: The optimizer. - lr_scheduler: The learning rate scheduler (if used). - """ - if lr_scheduler is None: - return optimizer - if isinstance(lr_scheduler, ReduceLROnPlateau): - return { - "optimizer": optimizer, - "lr_scheduler": {"scheduler": lr_scheduler, "monitor": lr_scheduler.monitor}, - } - return [optimizer], [lr_scheduler] - - def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None: - """Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` method - if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'.""" - parser = self._parser(subcommand) - - def get_automatic( - class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] - ) -> List[str]: - automatic = [] - for key, (base_class, link_to) in register.items(): - if not isinstance(base_class, tuple): - base_class = (base_class,) - if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class): - automatic.append(key) - return automatic - - optimizers = get_automatic(Optimizer, parser._optimizers) - lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers) - - if len(optimizers) == 0: - return - - if len(optimizers) > 1 or len(lr_schedulers) > 1: - raise MisconfigurationException( - f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer " - f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. 
In this case the user " - "is expected to link the argument groups and implement `configure_optimizers`, see " - "https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html" - "#optimizers-and-learning-rate-schedulers" - ) - - optimizer_class = parser._optimizers[optimizers[0]][0] - optimizer_init = self._get(self.config_init, optimizers[0]) - if not isinstance(optimizer_class, tuple): - optimizer_init = _global_add_class_path(optimizer_class, optimizer_init) - if not optimizer_init: - # optimizers were registered automatically but not passed by the user - return - - lr_scheduler_init = None - if lr_schedulers: - lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0] - lr_scheduler_init = self._get(self.config_init, lr_schedulers[0]) - if not isinstance(lr_scheduler_class, tuple): - lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init) - - if is_overridden("configure_optimizers", self.model): - _warn( - f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by " - f"`{self.__class__.__name__}.configure_optimizers`." - ) - - optimizer = instantiate_class(self.model.parameters(), optimizer_init) - lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None - fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler) - update_wrapper(fn, self.configure_optimizers) # necessary for `is_overridden` - # override the existing method - self.model.configure_optimizers = MethodType(fn, self.model) - - def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any: - """Utility to get a config value which might be inside a subcommand.""" - return config.get(str(self.subcommand), config).get(key, default) - - def _run_subcommand(self, subcommand: str) -> None: - """Run the chosen subcommand.""" - before_fn = getattr(self, f"before_{subcommand}", None) - if callable(before_fn): - before_fn() - - default = getattr(self.trainer, subcommand) - fn = getattr(self, subcommand, default) - fn_kwargs = self._prepare_subcommand_kwargs(subcommand) - fn(**fn_kwargs) - - after_fn = getattr(self, f"after_{subcommand}", None) - if callable(after_fn): - after_fn() - def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]: - """Prepares the keyword arguments to pass to the subcommand to run.""" - fn_kwargs = { - k: v for k, v in self.config_init[subcommand].items() if k in self._subcommand_method_arguments[subcommand] - } - fn_kwargs["model"] = self.model - if self.datamodule is not None: - fn_kwargs["datamodule"] = self.datamodule - return fn_kwargs - - def _set_seed(self) -> None: - """Sets the seed.""" - config_seed = self._get(self.config, "seed_everything") - if config_seed is False: - return - if config_seed is True: - # user requested seeding, choose randomly - config_seed = seed_everything(workers=True) - else: - config_seed = seed_everything(config_seed, workers=True) - self.config["seed_everything"] = config_seed - - -def _class_path_from_class(class_type: Type) -> str: - return class_type.__module__ + "." 
+ class_type.__name__ - - -def _global_add_class_path( - class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None -) -> Dict[str, Any]: - if isinstance(init_args, Namespace): - init_args = init_args.as_dict() - return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}} - - -def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]: - def add_class_path(init_args: Namespace) -> Dict[str, Any]: - return _global_add_class_path(class_type, init_args) - - return add_class_path - - -def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: - """Instantiates a class with the given args and init. +class SaveConfigCallback(new_cli.SaveConfigCallback): + def __init__(self, *args: Any, **kwargs: Any) -> None: + _deprecation(type(self)) + super().__init__(*args, **kwargs) - Args: - args: Positional arguments required for instantiation. - init: Dict of the form {"class_path":...,"init_args":...}. - Returns: - The instantiated class object. - """ - kwargs = init.get("init_args", {}) - if not isinstance(args, tuple): - args = (args,) - class_module, class_name = init["class_path"].rsplit(".", 1) - module = __import__(class_module, fromlist=[class_name]) - args_class = getattr(module, class_name) - return args_class(*args, **kwargs) +class LightningCLI(new_cli.LightningCLI): + def __init__(self, *args: Any, **kwargs: Any) -> None: + _deprecation(type(self)) + super().__init__(*args, **kwargs) -def _get_short_description(component: object) -> Optional[str]: - if component.__doc__ is None: - return None - try: - docstring = docstring_parser.parse(component.__doc__) - return docstring.short_description - except (ValueError, docstring_parser.ParseError) as ex: - rank_zero_warn(f"Failed parsing docstring for {component}: {ex}") +def instantiate_class(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.cli.instantiate_class` has been deprecated in v1.7 and will be removed in v1.9." + " Use the equivalent function in `pytorch_lightning.cli.instantiate_class` instead." 
+ ) + return new_cli.instantiate_class(*args, **kwargs) \ No newline at end of file diff --git a/tests/tests_app/core/scripts/lightning_cli.py b/tests/tests_app/core/scripts/lightning_cli.py index e6f2e7b3b0198..db2e4ccba0196 100644 --- a/tests/tests_app/core/scripts/lightning_cli.py +++ b/tests/tests_app/core/scripts/lightning_cli.py @@ -5,9 +5,7 @@ from torch.utils.data import DataLoader, Dataset if _is_pytorch_lightning_available(): - from pytorch_lightning import LightningDataModule, LightningModule - from pytorch_lightning.utilities import cli - + from pytorch_lightning import LightningDataModule, LightningModule, cli if __name__ == "__main__": diff --git a/tests/tests_app/core/scripts/registry.py b/tests/tests_app/core/scripts/registry.py index 35d6921756f36..16898374eb023 100644 --- a/tests/tests_app/core/scripts/registry.py +++ b/tests/tests_app/core/scripts/registry.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader, Dataset from pytorch_lightning import LightningDataModule, LightningModule - from pytorch_lightning.utilities.cli import LightningCLI + from pytorch_lightning.cli import LightningCLI class RandomDataset(Dataset): def __init__(self, size, length): diff --git a/tests/tests_app/utilities/test_introspection.py b/tests/tests_app/utilities/test_introspection.py index 5d0c5a80d0155..623301b075d42 100644 --- a/tests/tests_app/utilities/test_introspection.py +++ b/tests/tests_app/utilities/test_introspection.py @@ -9,7 +9,7 @@ if _is_pytorch_lightning_available(): from pytorch_lightning import Trainer - from pytorch_lightning.utilities.cli import LightningCLI + from pytorch_lightning.cli import LightningCLI from tests_app import _PROJECT_ROOT diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py index 9c7d02d499ab4..d2267c36c7317 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py @@ -27,7 +27,7 @@ from pytorch_lightning.profiler.pytorch import PyTorchProfiler, RegisterRecordFunction, ScheduleWrapper from pytorch_lightning.profiler.simple import SimpleProfiler from pytorch_lightning.profiler.xla import XLAProfiler -from pytorch_lightning.utilities.cli import ( +from pytorch_lightning.cli import ( _deprecate_auto_registry_message, _deprecate_registry_message, CALLBACK_REGISTRY, diff --git a/tests/tests_pytorch/utilities/test_cli.py b/tests/tests_pytorch/test_cli.py similarity index 97% rename from tests/tests_pytorch/utilities/test_cli.py rename to tests/tests_pytorch/test_cli.py index caafa9a3ca719..083d6ab25c363 100644 --- a/tests/tests_pytorch/utilities/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import inspect import json import os @@ -27,19 +26,14 @@ import pytest import torch import yaml -from packaging import version +from tests_pytorch.helpers.runif import RunIf +from tests_pytorch.helpers.utils import no_warning_call from torch.optim import SGD from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint -from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel -from pytorch_lightning.loggers import _COMET_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE, TensorBoardLogger -from pytorch_lightning.plugins.environments import SLURMEnvironment -from pytorch_lightning.strategies import DDPStrategy -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _TPU_AVAILABLE -from pytorch_lightning.utilities.cli import ( +from pytorch_lightning.cli import ( _JSONARGPARSE_SIGNATURES_AVAILABLE, instantiate_class, LightningArgumentParser, @@ -47,14 +41,13 @@ LRSchedulerTypeTuple, SaveConfigCallback, ) +from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel +from pytorch_lightning.loggers import _COMET_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE, TensorBoardLogger +from pytorch_lightning.plugins.environments import SLURMEnvironment +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE -from tests_pytorch.helpers.runif import RunIf -from tests_pytorch.helpers.utils import no_warning_call - -torchvision_version = version.parse("0") -if _TORCHVISION_AVAILABLE: - torchvision_version = version.parse(__import__("torchvision").__version__) if _JSONARGPARSE_SIGNATURES_AVAILABLE: from jsonargparse import lazy_instance @@ -524,42 +517,6 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai assert isinstance(cli.model.submodule2, BoringModel) -@pytest.mark.skipif(torchvision_version < version.parse("0.8.0"), reason="torchvision>=0.8.0 is required") -def test_lightning_cli_torch_modules(tmpdir): - class TestModule(BoringModel): - def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None): - super().__init__() - self.activation = activation - self.transform = transform - - config = """model: - activation: - class_path: torch.nn.LeakyReLU - init_args: - negative_slope: 0.2 - transform: - - class_path: torchvision.transforms.Resize - init_args: - size: 64 - - class_path: torchvision.transforms.CenterCrop - init_args: - size: 64 - """ - config_path = tmpdir / "config.yaml" - with open(config_path, "w") as f: - f.write(config) - - cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={str(config_path)}"] - - with mock.patch("sys.argv", ["any.py"] + cli_args): - cli = LightningCLI(TestModule, run=False) - - assert isinstance(cli.model.activation, torch.nn.LeakyReLU) - assert cli.model.activation.negative_slope == 0.2 - assert len(cli.model.transform) == 2 - assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform) - - class BoringModelRequiredClasses(BoringModel): def __init__(self, num_classes: int, batch_size: int = 8): super().__init__() From 
39c39675550078a325eb4cf750f7941565479d3c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Jul 2022 17:59:09 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../image_classifier_4_lightning_module.py | 2 +- .../image_classifier_5_lightning_datamodule.py | 2 +- examples/pl_basics/autoencoder.py | 2 +- examples/pl_basics/backbone_image_classifier.py | 2 +- examples/pl_basics/profiler_example.py | 2 +- examples/pl_domain_templates/imagenet.py | 2 +- examples/pl_hpu/mnist_sample.py | 2 +- examples/pl_integrations/dali_image_classifier.py | 2 +- examples/pl_servable_module/production.py | 2 +- src/pytorch_lightning/cli.py | 1 + src/pytorch_lightning/utilities/cli.py | 6 +++--- tests/tests_app/core/scripts/lightning_cli.py | 2 +- .../deprecated_api/test_remove_1-9.py | 14 +++++++------- tests/tests_pytorch/test_cli.py | 4 ++-- 14 files changed, 23 insertions(+), 22 deletions(-) diff --git a/examples/convert_from_pt_to_pl/image_classifier_4_lightning_module.py b/examples/convert_from_pt_to_pl/image_classifier_4_lightning_module.py index 6c2c84014a5f2..1b2caa5637865 100644 --- a/examples/convert_from_pt_to_pl/image_classifier_4_lightning_module.py +++ b/examples/convert_from_pt_to_pl/image_classifier_4_lightning_module.py @@ -23,9 +23,9 @@ from torchmetrics import Accuracy from pytorch_lightning import cli_lightning_logo, LightningModule +from pytorch_lightning.cli import LightningCLI from pytorch_lightning.demos.boring_classes import Net from pytorch_lightning.demos.mnist_datamodule import MNIST -from pytorch_lightning.cli import LightningCLI DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") diff --git a/examples/convert_from_pt_to_pl/image_classifier_5_lightning_datamodule.py b/examples/convert_from_pt_to_pl/image_classifier_5_lightning_datamodule.py index 10340611aa310..0d970b561263e 100644 --- a/examples/convert_from_pt_to_pl/image_classifier_5_lightning_datamodule.py +++ b/examples/convert_from_pt_to_pl/image_classifier_5_lightning_datamodule.py @@ -23,9 +23,9 @@ from torchmetrics import Accuracy from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule +from pytorch_lightning.cli import LightningCLI from pytorch_lightning.demos.boring_classes import Net from pytorch_lightning.demos.mnist_datamodule import MNIST -from pytorch_lightning.cli import LightningCLI DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") diff --git a/examples/pl_basics/autoencoder.py b/examples/pl_basics/autoencoder.py index 9b1e571f0a474..0fd9ddae18020 100644 --- a/examples/pl_basics/autoencoder.py +++ b/examples/pl_basics/autoencoder.py @@ -24,8 +24,8 @@ from torch.utils.data import DataLoader, random_split from pytorch_lightning import callbacks, cli_lightning_logo, LightningDataModule, LightningModule, Trainer -from pytorch_lightning.demos.mnist_datamodule import MNIST from pytorch_lightning.cli import LightningCLI +from pytorch_lightning.demos.mnist_datamodule import MNIST from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from pytorch_lightning.utilities.rank_zero import rank_zero_only diff --git a/examples/pl_basics/backbone_image_classifier.py b/examples/pl_basics/backbone_image_classifier.py index 3e1854df85b54..f09feec900d51 100644 --- a/examples/pl_basics/backbone_image_classifier.py +++ b/examples/pl_basics/backbone_image_classifier.py @@ -23,8 +23,8 @@ from 
torch.utils.data import DataLoader, random_split from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule -from pytorch_lightning.demos.mnist_datamodule import MNIST from pytorch_lightning.cli import LightningCLI +from pytorch_lightning.demos.mnist_datamodule import MNIST from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: diff --git a/examples/pl_basics/profiler_example.py b/examples/pl_basics/profiler_example.py index 0fc27aa60f652..6df8f769973c6 100644 --- a/examples/pl_basics/profiler_example.py +++ b/examples/pl_basics/profiler_example.py @@ -31,8 +31,8 @@ import torchvision.transforms as T from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule -from pytorch_lightning.profilers.pytorch import PyTorchProfiler from pytorch_lightning.cli import LightningCLI +from pytorch_lightning.profilers.pytorch import PyTorchProfiler DEFAULT_CMD_LINE = ( "fit", diff --git a/examples/pl_domain_templates/imagenet.py b/examples/pl_domain_templates/imagenet.py index bfbfedb58f990..93284963db4b4 100644 --- a/examples/pl_domain_templates/imagenet.py +++ b/examples/pl_domain_templates/imagenet.py @@ -47,8 +47,8 @@ from pytorch_lightning import LightningModule from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar -from pytorch_lightning.strategies import ParallelStrategy from pytorch_lightning.cli import LightningCLI +from pytorch_lightning.strategies import ParallelStrategy class ImageNetLightningModel(LightningModule): diff --git a/examples/pl_hpu/mnist_sample.py b/examples/pl_hpu/mnist_sample.py index 9b7fae8f17d0c..d48dd3da25994 100644 --- a/examples/pl_hpu/mnist_sample.py +++ b/examples/pl_hpu/mnist_sample.py @@ -16,9 +16,9 @@ from torch.nn import functional as F from pytorch_lightning import LightningModule +from pytorch_lightning.cli import LightningCLI from pytorch_lightning.demos.mnist_datamodule import MNISTDataModule from pytorch_lightning.plugins import HPUPrecisionPlugin -from pytorch_lightning.cli import LightningCLI class LitClassifier(LightningModule): diff --git a/examples/pl_integrations/dali_image_classifier.py b/examples/pl_integrations/dali_image_classifier.py index a9741b9a68fc4..7385196a0ba41 100644 --- a/examples/pl_integrations/dali_image_classifier.py +++ b/examples/pl_integrations/dali_image_classifier.py @@ -23,8 +23,8 @@ from torch.utils.data import random_split from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule -from pytorch_lightning.demos.mnist_datamodule import MNIST from pytorch_lightning.cli import LightningCLI +from pytorch_lightning.demos.mnist_datamodule import MNIST from pytorch_lightning.utilities.imports import _DALI_AVAILABLE, _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: diff --git a/examples/pl_servable_module/production.py b/examples/pl_servable_module/production.py index e738f02f04f5f..3ecd72376417a 100644 --- a/examples/pl_servable_module/production.py +++ b/examples/pl_servable_module/production.py @@ -12,8 +12,8 @@ from PIL import Image as PILImage from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule -from pytorch_lightning.serve import ServableModule, ServableModuleValidator from pytorch_lightning.cli import LightningCLI +from pytorch_lightning.serve import ServableModule, ServableModuleValidator DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") diff --git a/src/pytorch_lightning/cli.py b/src/pytorch_lightning/cli.py index 
f07c311d1ffdf..1b65f3ad72a2c 100644 --- a/src/pytorch_lightning/cli.py +++ b/src/pytorch_lightning/cli.py @@ -329,6 +329,7 @@ def __init__( self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data from pytorch_lightning.utilities.cli import _populate_registries + _populate_registries(auto_registry) main_kwargs, subparser_kwargs = self._setup_parser_kwargs( diff --git a/src/pytorch_lightning/utilities/cli.py b/src/pytorch_lightning/utilities/cli.py index 5d0a1a2e106fb..3a5ddc4e7a930 100644 --- a/src/pytorch_lightning/utilities/cli.py +++ b/src/pytorch_lightning/utilities/cli.py @@ -25,7 +25,6 @@ from pytorch_lightning.utilities.meta import get_all_subclasses from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation - _deprecate_registry_message = ( "`LightningCLI`'s registries were deprecated in v1.7 and will be removed " "in v1.9. Now any imported subclass is automatically available by name in " @@ -105,7 +104,6 @@ def _deprecation(self, show_deprecation: bool = True) -> None: LOGGER_REGISTRY = _Registry() - def _populate_registries(subclasses: bool) -> None: # Remove in v1.9 if subclasses: rank_zero_deprecation(_deprecate_auto_registry_message) @@ -140,11 +138,13 @@ def _deprecation(cls: Type) -> None: f" Use the equivalent class in `pytorch_lightning.cli.{cls.__name__}` instead." ) + class LightningArgumentParser(new_cli.LightningArgumentParser): def __init__(self, *args: Any, **kwargs: Any) -> None: _deprecation(type(self)) super().__init__(*args, **kwargs) + class SaveConfigCallback(new_cli.SaveConfigCallback): def __init__(self, *args: Any, **kwargs: Any) -> None: _deprecation(type(self)) @@ -162,4 +162,4 @@ def instantiate_class(*args: Any, **kwargs: Any) -> Any: "`pytorch_lightning.utilities.cli.instantiate_class` has been deprecated in v1.7 and will be removed in v1.9." " Use the equivalent function in `pytorch_lightning.cli.instantiate_class` instead." 
) - return new_cli.instantiate_class(*args, **kwargs) \ No newline at end of file + return new_cli.instantiate_class(*args, **kwargs) diff --git a/tests/tests_app/core/scripts/lightning_cli.py b/tests/tests_app/core/scripts/lightning_cli.py index db2e4ccba0196..7baf780f291c9 100644 --- a/tests/tests_app/core/scripts/lightning_cli.py +++ b/tests/tests_app/core/scripts/lightning_cli.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader, Dataset if _is_pytorch_lightning_available(): - from pytorch_lightning import LightningDataModule, LightningModule, cli + from pytorch_lightning import cli, LightningDataModule, LightningModule if __name__ == "__main__": diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py index d2267c36c7317..5b8c62af9cd80 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py @@ -19,6 +19,13 @@ import pytorch_lightning.loggers.base as logger_base from pytorch_lightning import Trainer from pytorch_lightning.accelerators.gpu import GPUAccelerator +from pytorch_lightning.cli import ( + _deprecate_auto_registry_message, + _deprecate_registry_message, + CALLBACK_REGISTRY, + LightningCLI, + SaveConfigCallback, +) from pytorch_lightning.core.module import LightningModule from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.profiler.advanced import AdvancedProfiler @@ -27,13 +34,6 @@ from pytorch_lightning.profiler.pytorch import PyTorchProfiler, RegisterRecordFunction, ScheduleWrapper from pytorch_lightning.profiler.simple import SimpleProfiler from pytorch_lightning.profiler.xla import XLAProfiler -from pytorch_lightning.cli import ( - _deprecate_auto_registry_message, - _deprecate_registry_message, - CALLBACK_REGISTRY, - LightningCLI, - SaveConfigCallback, -) from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE from pytorch_lightning.utilities.rank_zero import rank_zero_only from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 083d6ab25c363..ca32b8964f8a8 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -26,8 +26,6 @@ import pytest import torch import yaml -from tests_pytorch.helpers.runif import RunIf -from tests_pytorch.helpers.utils import no_warning_call from torch.optim import SGD from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR @@ -48,6 +46,8 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests_pytorch.helpers.runif import RunIf +from tests_pytorch.helpers.utils import no_warning_call if _JSONARGPARSE_SIGNATURES_AVAILABLE: from jsonargparse import lazy_instance From 587d36934b378dbfb5be9e7faf297ee1ecb4fb91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 20 Jul 2022 20:44:24 +0200 Subject: [PATCH 3/8] Fix --- tests/tests_pytorch/deprecated_api/test_remove_1-9.py | 10 ++++------ tests/tests_pytorch/test_cli.py | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py index 5b8c62af9cd80..ab79884773067 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py @@ -20,12 +20,10 @@ from 
pytorch_lightning import Trainer from pytorch_lightning.accelerators.gpu import GPUAccelerator from pytorch_lightning.cli import ( - _deprecate_auto_registry_message, - _deprecate_registry_message, - CALLBACK_REGISTRY, LightningCLI, SaveConfigCallback, ) +import pytorch_lightning.utilities.cli as old_cli from pytorch_lightning.core.module import LightningModule from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.profiler.advanced import AdvancedProfiler @@ -152,15 +150,15 @@ def test_deprecated_dataloader_reset(): def test_lightningCLI_registries_register(): - with pytest.deprecated_call(match=_deprecate_registry_message): + with pytest.deprecated_call(match=old_cli._deprecate_registry_message): - @CALLBACK_REGISTRY + @old_cli.CALLBACK_REGISTRY class CustomCallback(SaveConfigCallback): pass def test_lightningCLI_registries_register_automatically(): - with pytest.deprecated_call(match=_deprecate_auto_registry_message): + with pytest.deprecated_call(match=old_cli._deprecate_auto_registry_message): with mock.patch("sys.argv", ["any.py"]): LightningCLI(BoringModel, run=False, auto_registry=True) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index ca32b8964f8a8..1847bca6e4004 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -550,7 +550,7 @@ def add_arguments_to_parser(self, parser): parser.link_arguments("data.batch_size", "model.init_args.batch_size") parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate") - cli_args[-1] = "--model=tests_pytorch.utilities.test_cli.BoringModelRequiredClasses" + cli_args[-1] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses" with mock.patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI( From d3d5324b13792fc76f90e6a03f323b77364a354e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 20 Jul 2022 21:04:37 +0200 Subject: [PATCH 4/8] Deprecations and CHANGELOG --- src/pytorch_lightning/CHANGELOG.md | 2 ++ src/pytorch_lightning/utilities/cli.py | 2 +- .../deprecated_api/test_remove_1-9.py | 16 ++++++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 6aed707726079..115589bc38590 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -173,6 +173,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated LightningCLI's registries in favor of importing the respective package ([#13221](https://github.com/PyTorchLightning/pytorch-lightning/pull/13221)) +- Deprecated the public utilities in `pytorch_lightning.utilities.cli` in favor of the equivalent copies in `pytorch_lightning.cli` ([#13767](https://github.com/PyTorchLightning/pytorch-lightning/pull/13767)) + - Deprecated `pytorch_lightning.profiler` in favor of `pytorch_lightning.profilers` ([#12308](https://github.com/PyTorchLightning/pytorch-lightning/pull/12308)) diff --git a/src/pytorch_lightning/utilities/cli.py b/src/pytorch_lightning/utilities/cli.py index 3a5ddc4e7a930..285b5361f9cd2 100644 --- a/src/pytorch_lightning/utilities/cli.py +++ b/src/pytorch_lightning/utilities/cli.py @@ -134,7 +134,7 @@ def _populate_registries(subclasses: bool) -> None: # Remove in v1.9 def _deprecation(cls: Type) -> None: rank_zero_deprecation( - f"`{cls.__qualname__}` has been deprecated in v1.7 and will be removed in v1.9." 
+ f"`pytorch_lightning.utilities.cli.{cls.__name__}` has been deprecated in v1.7 and will be removed in v1.9." f" Use the equivalent class in `pytorch_lightning.cli.{cls.__name__}` instead." ) diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py index ab79884773067..3763e02934b0e 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py @@ -13,6 +13,7 @@ # limitations under the License. from unittest import mock +from unittest.mock import Mock import pytest @@ -163,6 +164,21 @@ def test_lightningCLI_registries_register_automatically(): LightningCLI(BoringModel, run=False, auto_registry=True) +def test_lightningCLI_old_module_deprecation(): + with pytest.deprecated_call(match=r"LightningCLI.*deprecated in v1.7.*Use the equivalent class"): + with mock.patch("sys.argv", ["any.py"]): + old_cli.LightningCLI(BoringModel, run=False) + + with pytest.deprecated_call(match=r"SaveConfigCallback.*deprecated in v1.7.*Use the equivalent class"): + old_cli.SaveConfigCallback(Mock(), Mock(), Mock()) + + with pytest.deprecated_call(match=r"LightningArgumentParser.*deprecated in v1.7.*Use the equivalent class"): + old_cli.LightningArgumentParser() + + with pytest.deprecated_call(match=r"instantiate_class.*deprecated in v1.7.*Use the equivalent function"): + assert isinstance(old_cli.instantiate_class(tuple(), {'class_path': 'pytorch_lightning.Trainer'}), Trainer) + + def test_profiler_deprecation_warning(): assert "Profiler` is deprecated in v1.7" in Profiler.__doc__ From f60563319c6717d9e943f0e03f825528f241981a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Jul 2022 19:06:15 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/deprecated_api/test_remove_1-9.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py index 3763e02934b0e..54c59bec62b5d 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py @@ -18,13 +18,10 @@ import pytest import pytorch_lightning.loggers.base as logger_base +import pytorch_lightning.utilities.cli as old_cli from pytorch_lightning import Trainer from pytorch_lightning.accelerators.gpu import GPUAccelerator -from pytorch_lightning.cli import ( - LightningCLI, - SaveConfigCallback, -) -import pytorch_lightning.utilities.cli as old_cli +from pytorch_lightning.cli import LightningCLI, SaveConfigCallback from pytorch_lightning.core.module import LightningModule from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.profiler.advanced import AdvancedProfiler @@ -176,7 +173,7 @@ def test_lightningCLI_old_module_deprecation(): old_cli.LightningArgumentParser() with pytest.deprecated_call(match=r"instantiate_class.*deprecated in v1.7.*Use the equivalent function"): - assert isinstance(old_cli.instantiate_class(tuple(), {'class_path': 'pytorch_lightning.Trainer'}), Trainer) + assert isinstance(old_cli.instantiate_class(tuple(), {"class_path": "pytorch_lightning.Trainer"}), Trainer) def test_profiler_deprecation_warning(): From 6bc0105a871fba10ad349f39af85d4281ecdd1fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 20 Jul 2022 
21:06:53 +0200 Subject: [PATCH 6/8] Docs --- .../cli/lightning_cli_advanced_2.rst | 2 +- .../cli/lightning_cli_advanced_3.rst | 10 +++++----- .../cli/lightning_cli_expert.rst | 18 +++++++++--------- docs/source-pytorch/cli/lightning_cli_faq.rst | 6 +++--- .../cli/lightning_cli_intermediate.rst | 2 +- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/source-pytorch/cli/lightning_cli_advanced_2.rst b/docs/source-pytorch/cli/lightning_cli_advanced_2.rst index 0474699db706d..dd80114070225 100644 --- a/docs/source-pytorch/cli/lightning_cli_advanced_2.rst +++ b/docs/source-pytorch/cli/lightning_cli_advanced_2.rst @@ -15,7 +15,7 @@ pass - class LightningCLI(pl.utilities.cli.LightningCLI): + class LightningCLI(pl.cli.LightningCLI): def __init__(self, *args, trainer_class=NoFitTrainer, run=False, **kwargs): super().__init__(*args, trainer_class=trainer_class, run=run, **kwargs) diff --git a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst index 0e9c3f406d7ec..df062061022c9 100644 --- a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst +++ b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst @@ -15,7 +15,7 @@ pass - class LightningCLI(pl.utilities.cli.LightningCLI): + class LightningCLI(pl.cli.LightningCLI): def __init__(self, *args, trainer_class=NoFitTrainer, run=False, **kwargs): super().__init__(*args, trainer_class=trainer_class, run=run, **kwargs) @@ -88,7 +88,7 @@ Similar to the callbacks, any parameter in :class:`~pytorch_lightning.trainer.tr :class:`~pytorch_lightning.core.module.LightningModule` and :class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have a class as type hint can be configured the same way using :code:`class_path` and :code:`init_args`. If the package that defines a subclass is -imported before the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class is run, the name can be used instead of +imported before the :class:`~pytorch_lightning.cli.LightningCLI` class is run, the name can be used instead of the full import path. From the command line, the syntax is the following: @@ -117,7 +117,7 @@ callback appended. Here is an example: .. note:: - Serialized config files (e.g. ``--print_config`` or :class:`~pytorch_lightning.utilities.cli.SaveConfigCallback`) + Serialized config files (e.g. ``--print_config`` or :class:`~pytorch_lightning.cli.SaveConfigCallback`) always contain the full ``class_path``s, even when the class-name shorthand notation is used on the command line or in input config files. @@ -306,7 +306,7 @@ example can be when one wants to add support for multiple optimizers: .. code-block:: python - from pytorch_lightning.utilities.cli import instantiate_class + from pytorch_lightning.cli import instantiate_class class MyModel(LightningModule): @@ -330,7 +330,7 @@ example can be when one wants to add support for multiple optimizers: cli = MyLightningCLI(MyModel) The value given to :code:`optimizer*_init` will always be a dictionary including :code:`class_path` and -:code:`init_args` entries. The function :func:`~pytorch_lightning.utilities.cli.instantiate_class` +:code:`init_args` entries. The function :func:`~pytorch_lightning.cli.instantiate_class` takes care of importing the class defined in :code:`class_path` and instantiating it using some positional arguments, in this case :code:`self.parameters()`, and the :code:`init_args`. Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`. 
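To make the ``instantiate_class`` contract described above concrete, here is a minimal sketch (illustrative only, not part of the patch; the optimizer choice and hyperparameters are arbitrary):

.. code-block:: python

    import torch

    from pytorch_lightning.cli import instantiate_class

    # stand-in for a LightningModule; only its parameters are needed here
    model = torch.nn.Linear(4, 2)

    # the dictionary shape the parser produces for `optimizer*_init`
    optimizer_init = {"class_path": "torch.optim.Adam", "init_args": {"lr": 1e-3}}

    # imports `torch.optim.Adam` from `class_path` and calls it as
    # Adam(model.parameters(), lr=1e-3): the first argument supplies the
    # positional arguments and `init_args` supplies the keyword arguments
    optimizer = instantiate_class(model.parameters(), optimizer_init)
    assert isinstance(optimizer, torch.optim.Adam)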
diff --git a/docs/source-pytorch/cli/lightning_cli_expert.rst b/docs/source-pytorch/cli/lightning_cli_expert.rst index 50292b3dd251a..60454f5e9bd82 100644 --- a/docs/source-pytorch/cli/lightning_cli_expert.rst +++ b/docs/source-pytorch/cli/lightning_cli_expert.rst @@ -15,7 +15,7 @@ pass - class LightningCLI(pl.utilities.cli.LightningCLI): + class LightningCLI(pl.cli.LightningCLI): def __init__(self, *args, trainer_class=NoFitTrainer, run=False, **kwargs): super().__init__(*args, trainer_class=trainer_class, run=run, **kwargs) @@ -62,23 +62,23 @@ Eliminate config boilerplate (Advanced) Customize the LightningCLI ************************** -The init parameters of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class can be used to customize some +The init parameters of the :class:`~pytorch_lightning.cli.LightningCLI` class can be used to customize some things, namely: the description of the tool, enabling parsing of environment variables and additional arguments to instantiate the trainer and configuration parser. Nevertheless, the init arguments are not enough for many use cases. For this reason, the class is designed so that it can be extended to customize different parts of the command line tool. The argument parser class used by -:class:`~pytorch_lightning.utilities.cli.LightningCLI` is -:class:`~pytorch_lightning.utilities.cli.LightningArgumentParser` which is an extension of python's argparse, thus +:class:`~pytorch_lightning.cli.LightningCLI` is +:class:`~pytorch_lightning.cli.LightningArgumentParser`, which is an extension of Python's argparse, thus adding arguments can be done using the :func:`add_argument` method. In contrast to argparse, it has additional methods to add arguments; for example, :func:`add_class_arguments` adds all arguments from the init of a class, though it requires parameters to have type hints. For more details about this, please refer to the `respective documentation `_. -The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the -:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include +The :class:`~pytorch_lightning.cli.LightningCLI` class has the +:meth:`~pytorch_lightning.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include more arguments. After parsing, the configuration is stored in the :code:`config` attribute of the class instance. The -:class:`~pytorch_lightning.utilities.cli.LightningCLI` class also has two methods that can be used to run code before +:class:`~pytorch_lightning.cli.LightningCLI` class also has two methods that can be used to run code before and after the trainer runs: :code:`before_<subcommand>` and :code:`after_<subcommand>`. A realistic example for these would be to send an email before and after the execution, as sketched after this hunk. The code for the :code:`fit` subcommand would be something like: @@ -104,7 +104,7 @@ instantiating the trainer class can be found in :code:`self.config['fit']['train .. tip:: - Have a look at the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class API reference to learn about other + Have a look at the :class:`~pytorch_lightning.cli.LightningCLI` class API reference to learn about other methods that can be extended to customize a CLI. 
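A rough sketch of the :code:`before_<subcommand>`/:code:`after_<subcommand>` hooks discussed above (illustrative only, not part of the patch; the printed notifications are hypothetical placeholders for real logic such as sending an email):

.. code-block:: python

    from pytorch_lightning.cli import LightningCLI
    from pytorch_lightning.demos.boring_classes import BoringModel


    class MyLightningCLI(LightningCLI):
        def before_fit(self):
            # called by the CLI right before `trainer.fit(...)` runs
            print("about to start fitting")

        def after_fit(self):
            # called by the CLI right after `trainer.fit(...)` returns
            print("fitting finished")


    if __name__ == "__main__":
        cli = MyLightningCLI(BoringModel)

Invoked as ``python main.py fit``, this would print the first message, run ``trainer.fit``, and then print the second.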
---- @@ -211,7 +211,7 @@ A more compact version that avoids writing a dictionary would be: ************************ Connect two config files ************************ -Another case in which it might be desired to extend :class:`~pytorch_lightning.utilities.cli.LightningCLI` is that the +Another case in which it might be desirable to extend :class:`~pytorch_lightning.cli.LightningCLI` is when the model and data module depend on a common parameter. For example, in some cases both classes need to know the :code:`batch_size`. Giving the same value twice in a config file is a burden and error prone. To avoid this, the parser can be configured so that a value is only given once and then propagated accordingly. With a tool implemented diff --git a/docs/source-pytorch/cli/lightning_cli_faq.rst b/docs/source-pytorch/cli/lightning_cli_faq.rst index ca1be71cae7f8..672e27979f7c9 100644 --- a/docs/source-pytorch/cli/lightning_cli_faq.rst +++ b/docs/source-pytorch/cli/lightning_cli_faq.rst @@ -15,7 +15,7 @@ pass - class LightningCLI(pl.utilities.cli.LightningCLI): + class LightningCLI(pl.cli.LightningCLI): def __init__(self, *args, trainer_class=NoFitTrainer, run=False, **kwargs): super().__init__(*args, trainer_class=trainer_class, run=run, **kwargs) @@ -65,7 +65,7 @@ there is a failure an exception is raised and the full stack trace printed. Reproducibility with the LightningCLI ************************************* The topic of reproducibility is complex and it is impossible to guarantee reproducibility by just providing a class that -people can use in unexpected ways. Nevertheless, the :class:`~pytorch_lightning.utilities.cli.LightningCLI` tries to +people can use in unexpected ways. Nevertheless, the :class:`~pytorch_lightning.cli.LightningCLI` tries to give a framework and recommendations to make reproducibility simpler. When an experiment is run, it is good practice to use a stable version of the source code, either being a released @@ -85,7 +85,7 @@ For every CLI implemented, users are encouraged to learn how to run it by readin :code:`--help` option and use the :code:`--print_config` option to guide the writing of config files. A few more details that might not be clear by only reading the help are the following. -:class:`~pytorch_lightning.utilities.cli.LightningCLI` is based on argparse and as such follows the same arguments style +:class:`~pytorch_lightning.cli.LightningCLI` is based on argparse and as such follows the same argument style as many POSIX command line tools. Long options are prefixed with two dashes, and their corresponding values should be provided with a space or an equal sign, as :code:`--option value` or :code:`--option=value`. 
Command line options are parsed from left to right, therefore if a setting appears multiple times the value most to the right will override diff --git a/docs/source-pytorch/cli/lightning_cli_intermediate.rst b/docs/source-pytorch/cli/lightning_cli_intermediate.rst index 5f2cd3bca272d..6ed4921305c0d 100644 --- a/docs/source-pytorch/cli/lightning_cli_intermediate.rst +++ b/docs/source-pytorch/cli/lightning_cli_intermediate.rst @@ -82,7 +82,7 @@ The simplest way to control a model with the CLI is to wrap it in the LightningC # main.py import torch - from pytorch_lightning.utilities.cli import LightningCLI + from pytorch_lightning.cli import LightningCLI # simple demo classes for your convenience from pytorch_lightning.demos.boring_classes import DemoModel, BoringDataModule From 10ac9a1bb2b68a3e590064829176b7f5be5e34f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 20 Jul 2022 21:29:09 +0200 Subject: [PATCH 7/8] Add back test --- tests/tests_pytorch/test_cli.py | 47 +++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 1847bca6e4004..f415817574538 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -26,11 +26,6 @@ import pytest import torch import yaml -from torch.optim import SGD -from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR - -from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.cli import ( _JSONARGPARSE_SIGNATURES_AVAILABLE, instantiate_class, @@ -39,6 +34,11 @@ LRSchedulerTypeTuple, SaveConfigCallback, ) +from torch.optim import SGD +from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR + +from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel from pytorch_lightning.loggers import _COMET_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE, TensorBoardLogger from pytorch_lightning.plugins.environments import SLURMEnvironment @@ -46,6 +46,7 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.utils import no_warning_call @@ -517,6 +518,42 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai assert isinstance(cli.model.submodule2, BoringModel) +@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="Tests a bug with torchvision, but it's not available") +def test_lightning_cli_torch_modules(tmpdir): + class TestModule(BoringModel): + def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None): + super().__init__() + self.activation = activation + self.transform = transform + + config = """model: + activation: + class_path: torch.nn.LeakyReLU + init_args: + negative_slope: 0.2 + transform: + - class_path: torchvision.transforms.Resize + init_args: + size: 64 + - class_path: torchvision.transforms.CenterCrop + init_args: + size: 64 + """ + config_path = tmpdir / 
"config.yaml" + with open(config_path, "w") as f: + f.write(config) + + cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={str(config_path)}"] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = LightningCLI(TestModule, run=False) + + assert isinstance(cli.model.activation, torch.nn.LeakyReLU) + assert cli.model.activation.negative_slope == 0.2 + assert len(cli.model.transform) == 2 + assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform) + + class BoringModelRequiredClasses(BoringModel): def __init__(self, num_classes: int, batch_size: int = 8): super().__init__() From 2218b807f48abbb181cb3737feb07c58dd90542f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Jul 2022 19:58:52 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/test_cli.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index f415817574538..9a00e0eb2fc75 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -26,6 +26,11 @@ import pytest import torch import yaml +from torch.optim import SGD +from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR + +from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.cli import ( _JSONARGPARSE_SIGNATURES_AVAILABLE, instantiate_class, @@ -34,11 +39,6 @@ LRSchedulerTypeTuple, SaveConfigCallback, ) -from torch.optim import SGD -from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR - -from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel from pytorch_lightning.loggers import _COMET_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE, TensorBoardLogger from pytorch_lightning.plugins.environments import SLURMEnvironment