diff --git a/examples/pl_basics/backbone_image_classifier.py b/examples/pl_basics/backbone_image_classifier.py index f09feec900d51..95f4385842488 100644 --- a/examples/pl_basics/backbone_image_classifier.py +++ b/examples/pl_basics/backbone_image_classifier.py @@ -124,7 +124,9 @@ def predict_dataloader(self): def cli_main(): - cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False) + cli = LightningCLI( + LitClassifier, MyDataModule, seed_everything_default=1234, save_config_kwargs={"overwrite": True}, run=False + ) cli.trainer.fit(cli.model, datamodule=cli.datamodule) cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule) diff --git a/examples/pl_basics/profiler_example.py b/examples/pl_basics/profiler_example.py index 39c147c938d06..517c33c052e7b 100644 --- a/examples/pl_basics/profiler_example.py +++ b/examples/pl_basics/profiler_example.py @@ -107,7 +107,10 @@ def cli_main(): sys.argv += DEFAULT_CMD_LINE LightningCLI( - ModelToProfile, CIFAR10DataModule, save_config_overwrite=True, trainer_defaults={"profiler": PyTorchProfiler()} + ModelToProfile, + CIFAR10DataModule, + save_config_kwargs={"overwrite": True}, + trainer_defaults={"profiler": PyTorchProfiler()}, ) diff --git a/examples/pl_domain_templates/imagenet.py b/examples/pl_domain_templates/imagenet.py index 0a3b55d2a6a04..776871e636529 100644 --- a/examples/pl_domain_templates/imagenet.py +++ b/examples/pl_domain_templates/imagenet.py @@ -190,5 +190,5 @@ def test_dataloader(self): ], }, seed_everything_default=42, - save_config_overwrite=True, + save_config_kwargs={"overwrite": True}, ) diff --git a/examples/pl_hpu/mnist_sample.py b/examples/pl_hpu/mnist_sample.py index d48dd3da25994..b0422bb1dfc78 100644 --- a/examples/pl_hpu/mnist_sample.py +++ b/examples/pl_hpu/mnist_sample.py @@ -66,7 +66,7 @@ def configure_optimizers(self): "plugins": lazy_instance(HPUPrecisionPlugin, precision=16), }, run=False, - save_config_overwrite=True, + save_config_kwargs={"overwrite": True}, ) # Run the model ⚡ diff --git a/examples/pl_integrations/dali_image_classifier.py b/examples/pl_integrations/dali_image_classifier.py index 7385196a0ba41..49d55c55e6af0 100644 --- a/examples/pl_integrations/dali_image_classifier.py +++ b/examples/pl_integrations/dali_image_classifier.py @@ -194,7 +194,9 @@ def cli_main(): if not _DALI_AVAILABLE: return - cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False) + cli = LightningCLI( + LitClassifier, MyDataModule, seed_everything_default=1234, save_config_kwargs={"overwrite": True}, run=False + ) cli.trainer.fit(cli.model, datamodule=cli.datamodule) cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) diff --git a/examples/pl_servable_module/production.py b/examples/pl_servable_module/production.py index 4005fecb7307d..aef5a29a1f5c3 100644 --- a/examples/pl_servable_module/production.py +++ b/examples/pl_servable_module/production.py @@ -109,7 +109,7 @@ def cli_main(): ProductionReadyModel, CIFAR10DataModule, seed_everything_default=42, - save_config_overwrite=True, + save_config_kwargs={"overwrite": True}, run=False, trainer_defaults={ "callbacks": [ServableModuleValidator()], diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 5f8289a76c35c..c3b9c14917836 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -98,6 +98,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the deprecated `pytorch_lightning.profiler.*` classes in favor of `pytorch_lightning.profilers` ([#16059](https://github.com/PyTorchLightning/pytorch-lightning/pull/16059)) +- Removed the deprecated `pytorch_lightning.utilities.cli` module in favor of `pytorch_lightning.cli` ([#16116](https://github.com/PyTorchLightning/pytorch-lightning/pull/16116)) + + ### Fixed - Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253)) diff --git a/src/pytorch_lightning/_graveyard/__init__.py b/src/pytorch_lightning/_graveyard/__init__.py index a4f1d5ca661cf..9bd8c1e31874b 100644 --- a/src/pytorch_lightning/_graveyard/__init__.py +++ b/src/pytorch_lightning/_graveyard/__init__.py @@ -14,6 +14,7 @@ import pytorch_lightning._graveyard.accelerator import pytorch_lightning._graveyard.callbacks +import pytorch_lightning._graveyard.cli import pytorch_lightning._graveyard.core import pytorch_lightning._graveyard.legacy_import_unpickler import pytorch_lightning._graveyard.loggers diff --git a/src/pytorch_lightning/_graveyard/cli.py b/src/pytorch_lightning/_graveyard/cli.py new file mode 100644 index 0000000000000..cba1692ec71ac --- /dev/null +++ b/src/pytorch_lightning/_graveyard/cli.py @@ -0,0 +1,45 @@ +import sys +from typing import Any + + +def _patch_sys_modules() -> None: + # TODO: Remove in v2.0.0 + self = sys.modules[__name__] + sys.modules["pytorch_lightning.utilities.cli"] = self + + +class LightningCLI: + # TODO: Remove in v2.0.0 + def __init__(self, *_: Any, **__: Any) -> None: + raise NotImplementedError( + "`pytorch_lightning.utilities.cli.LightningCLI` was deprecated in v1.7.0 and is no" + " longer supported as of v1.9.0. Please use `pytorch_lightning.cli.LightningCLI` instead" + ) + + +class SaveConfigCallback: + # TODO: Remove in v2.0.0 + def __init__(self, *_: Any, **__: Any) -> None: + raise NotImplementedError( + "`pytorch_lightning.utilities.cli.SaveConfigCallback` was deprecated in v1.7.0 and is no" + " longer supported as of v1.9.0. Please use `pytorch_lightning.cli.SaveConfigCallback` instead" + ) + + +class LightningArgumentParser: + # TODO: Remove in v2.0.0 + def __init__(self, *_: Any, **__: Any) -> None: + raise NotImplementedError( + "`pytorch_lightning.utilities.cli.LightningArgumentParser` was deprecated in v1.7.0 and is no" + " longer supported as of v1.9.0. Please use `pytorch_lightning.cli.LightningArgumentParser` instead" + ) + + +def instantiate_class(*_: Any, **__: Any) -> None: + raise NotImplementedError( + "`pytorch_lightning.utilities.cli.instantiate_class` was deprecated in v1.7.0 and is no" + " longer supported as of v1.9.0. Please use `pytorch_lightning.cli.instantiate_class` instead" + ) + + +_patch_sys_modules() diff --git a/src/pytorch_lightning/cli.py b/src/pytorch_lightning/cli.py index 54871e6173efb..7e1566088405b 100644 --- a/src/pytorch_lightning/cli.py +++ b/src/pytorch_lightning/cli.py @@ -279,7 +279,6 @@ def __init__( args: ArgsType = None, run: bool = True, auto_configure_optimizers: bool = True, - auto_registry: bool = False, **kwargs: Any, # Remove with deprecations of v1.10 ) -> None: """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which @@ -323,7 +322,6 @@ def __init__( ``dict`` or ``jsonargparse.Namespace``. run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer` method. If set to ``False``, the trainer and model classes will be instantiated only. - auto_registry: Whether to automatically fill up the registries with all defined subclasses. """ self.save_config_callback = save_config_callback self.save_config_kwargs = save_config_kwargs or {} @@ -345,10 +343,6 @@ def __init__( self._datamodule_class = datamodule_class or LightningDataModule self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data - from pytorch_lightning.utilities.cli import _populate_registries - - _populate_registries(auto_registry) - main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs) self.setup_parser(run, main_kwargs, subparser_kwargs) self.parse_arguments(self.parser, args) @@ -371,15 +365,14 @@ def _handle_deprecated_params(self, kwargs: dict) -> None: ) self.seed_everything_default = False - for name in ["save_config_filename", "save_config_overwrite", "save_config_multifile"]: - if name in kwargs: - value = kwargs.pop(name) - key = name.replace("save_config_", "").replace("filename", "config_filename") - self.save_config_kwargs[key] = value - rank_zero_deprecation( - f"LightningCLI's {name!r} init parameter is deprecated from v1.8 and will " - f"be removed in v1.10. Use `save_config_kwargs={{'{key}': ...}}` instead." - ) + for name in kwargs.keys() & ["save_config_filename", "save_config_overwrite", "save_config_multifile"]: + value = kwargs.pop(name) + key = name.replace("save_config_", "").replace("filename", "config_filename") + self.save_config_kwargs[key] = value + rank_zero_deprecation( + f"LightningCLI's {name!r} init parameter is deprecated from v1.8 and will " + f"be removed in v1.10. Use `save_config_kwargs={{'{key}': ...}}` instead." + ) for name in kwargs.keys() & ["description", "env_prefix", "env_parse"]: value = kwargs.pop(name) diff --git a/src/pytorch_lightning/utilities/cli.py b/src/pytorch_lightning/utilities/cli.py deleted file mode 100644 index 9916fd75cb2a6..0000000000000 --- a/src/pytorch_lightning/utilities/cli.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Deprecated utilities for LightningCLI.""" - -import inspect -from types import ModuleType -from typing import Any, Generator, List, Optional, Tuple, Type - -import torch -from lightning_utilities.core.inheritance import get_all_subclasses -from torch.optim import Optimizer - -import pytorch_lightning as pl -import pytorch_lightning.cli as new_cli -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation - -_deprecate_registry_message = ( - "`LightningCLI`'s registries were deprecated in v1.7 and will be removed " - "in v1.9. Now any imported subclass is automatically available by name in " - "`LightningCLI` without any need to explicitly register it." -) - -_deprecate_auto_registry_message = ( - "`LightningCLI.auto_registry` parameter was deprecated in v1.7 and will be removed " - "in v1.9. Now any imported subclass is automatically available by name in " - "`LightningCLI` without any need to explicitly register it." -) - - -class _Registry(dict): # Remove in v1.9 - def __call__( - self, cls: Type, key: Optional[str] = None, override: bool = False, show_deprecation: bool = True - ) -> Type: - """Registers a class mapped to a name. - - Args: - cls: the class to be mapped. - key: the name that identifies the provided class. - override: Whether to override an existing key. - """ - if key is None: - key = cls.__name__ - elif not isinstance(key, str): - raise TypeError(f"`key` must be a str, found {key}") - - if key not in self or override: - self[key] = cls - - self._deprecation(show_deprecation) - return cls - - def register_classes( - self, module: ModuleType, base_cls: Type, override: bool = False, show_deprecation: bool = True - ) -> None: - """This function is an utility to register all classes from a module.""" - for cls in self.get_members(module, base_cls): - self(cls=cls, override=override, show_deprecation=show_deprecation) - - @staticmethod - def get_members(module: ModuleType, base_cls: Type) -> Generator[Type, None, None]: - return ( - cls - for _, cls in inspect.getmembers(module, predicate=inspect.isclass) - if issubclass(cls, base_cls) and cls != base_cls - ) - - @property - def names(self) -> List[str]: - """Returns the registered names.""" - self._deprecation() - return list(self.keys()) - - @property - def classes(self) -> Tuple[Type, ...]: - """Returns the registered classes.""" - self._deprecation() - return tuple(self.values()) - - def __str__(self) -> str: - return f"Registered objects: {self.names}" - - def _deprecation(self, show_deprecation: bool = True) -> None: - if show_deprecation and not getattr(self, "deprecation_shown", False): - rank_zero_deprecation(_deprecate_registry_message) - self.deprecation_shown = True - - -OPTIMIZER_REGISTRY = _Registry() -LR_SCHEDULER_REGISTRY = _Registry() -CALLBACK_REGISTRY = _Registry() -MODEL_REGISTRY = _Registry() -DATAMODULE_REGISTRY = _Registry() -LOGGER_REGISTRY = _Registry() - - -def _populate_registries(subclasses: bool) -> None: # Remove in v1.9 - if subclasses: - rank_zero_deprecation(_deprecate_auto_registry_message) - # this will register any subclasses from all loaded modules including userland - for cls in get_all_subclasses(torch.optim.Optimizer): - OPTIMIZER_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler): - LR_SCHEDULER_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(pl.Callback): - CALLBACK_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(pl.LightningModule): - MODEL_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(pl.LightningDataModule): - DATAMODULE_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(pl.loggers.Logger): - LOGGER_REGISTRY(cls, show_deprecation=False) - else: - # manually register torch's subclasses and our subclasses - OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer, show_deprecation=False) - LR_SCHEDULER_REGISTRY.register_classes( - torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler, show_deprecation=False - ) - CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback, show_deprecation=False) - LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.Logger, show_deprecation=False) - # `ReduceLROnPlateau` does not subclass `_LRScheduler` - LR_SCHEDULER_REGISTRY(cls=new_cli.ReduceLROnPlateau, show_deprecation=False) - - -def _deprecation(cls: Type) -> None: - rank_zero_deprecation( - f"`pytorch_lightning.utilities.cli.{cls.__name__}` has been deprecated in v1.7 and will be removed in v1.9." - f" Use the equivalent class in `pytorch_lightning.cli.{cls.__name__}` instead." - ) - - -class LightningArgumentParser(new_cli.LightningArgumentParser): - def __init__(self, *args: Any, **kwargs: Any) -> None: - _deprecation(type(self)) - super().__init__(*args, **kwargs) - - -class SaveConfigCallback(new_cli.SaveConfigCallback): - def __init__(self, *args: Any, **kwargs: Any) -> None: - _deprecation(type(self)) - super().__init__(*args, **kwargs) - - -class LightningCLI(new_cli.LightningCLI): - def __init__(self, *args: Any, **kwargs: Any) -> None: - _deprecation(type(self)) - super().__init__(*args, **kwargs) - - -def instantiate_class(*args: Any, **kwargs: Any) -> Any: - rank_zero_deprecation( - "`pytorch_lightning.utilities.cli.instantiate_class` has been deprecated in v1.7 and will be removed in v1.9." - " Use the equivalent function in `pytorch_lightning.cli.instantiate_class` instead." - ) - return new_cli.instantiate_class(*args, **kwargs) diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py index 7f07baf315f33..eb5292f3008e8 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py @@ -13,16 +13,13 @@ # limitations under the License. from unittest import mock -from unittest.mock import Mock import pytest import pytorch_lightning.loggers.base as logger_base -import pytorch_lightning.utilities.cli as old_cli from pytorch_lightning import Trainer -from pytorch_lightning.cli import LightningCLI, SaveConfigCallback +from pytorch_lightning.cli import LightningCLI from pytorch_lightning.core.module import LightningModule -from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.utilities.rank_zero import rank_zero_only @@ -136,32 +133,3 @@ def test_deprecated_dataloader_reset(): trainer = Trainer() with pytest.deprecated_call(match="reset_train_val_dataloaders` has been deprecated in v1.7"): trainer.reset_train_val_dataloaders() - - -def test_lightningCLI_registries_register(): - with pytest.deprecated_call(match=old_cli._deprecate_registry_message): - - @old_cli.CALLBACK_REGISTRY - class CustomCallback(SaveConfigCallback): - pass - - -def test_lightningCLI_registries_register_automatically(): - with pytest.deprecated_call(match=old_cli._deprecate_auto_registry_message): - with mock.patch("sys.argv", ["any.py"]): - LightningCLI(BoringModel, run=False, auto_registry=True) - - -def test_lightningCLI_old_module_deprecation(): - with pytest.deprecated_call(match=r"LightningCLI.*deprecated in v1.7.*Use the equivalent class"): - with mock.patch("sys.argv", ["any.py"]): - old_cli.LightningCLI(BoringModel, run=False) - - with pytest.deprecated_call(match=r"SaveConfigCallback.*deprecated in v1.7.*Use the equivalent class"): - old_cli.SaveConfigCallback(Mock(), Mock(), Mock()) - - with pytest.deprecated_call(match=r"LightningArgumentParser.*deprecated in v1.7.*Use the equivalent class"): - old_cli.LightningArgumentParser() - - with pytest.deprecated_call(match=r"instantiate_class.*deprecated in v1.7.*Use the equivalent function"): - assert isinstance(old_cli.instantiate_class(tuple(), {"class_path": "pytorch_lightning.Trainer"}), Trainer) diff --git a/tests/tests_pytorch/graveyard/test_cli.py b/tests/tests_pytorch/graveyard/test_cli.py new file mode 100644 index 0000000000000..b902093e46d3d --- /dev/null +++ b/tests/tests_pytorch/graveyard/test_cli.py @@ -0,0 +1,23 @@ +import pytest + + +def test_lightningCLI_old_module_removal(): + from pytorch_lightning.utilities.cli import LightningCLI + + with pytest.raises(NotImplementedError, match=r"LightningCLI.*no longer supported as of v1.9"): + LightningCLI() + + from pytorch_lightning.utilities.cli import SaveConfigCallback + + with pytest.raises(NotImplementedError, match=r"SaveConfigCallback.*no longer supported as of v1.9"): + SaveConfigCallback() + + from pytorch_lightning.utilities.cli import LightningArgumentParser + + with pytest.raises(NotImplementedError, match=r"LightningArgumentParser.*no longer supported as of v1.9"): + LightningArgumentParser() + + from pytorch_lightning.utilities.cli import instantiate_class + + with pytest.raises(NotImplementedError, match=r"instantiate_class.*no longer supported as of v1.9"): + instantiate_class() diff --git a/tests/tests_pytorch/strategies/scripts/cli_script.py b/tests/tests_pytorch/strategies/scripts/cli_script.py index 17f0d29392eb9..61f46005dc7cd 100644 --- a/tests/tests_pytorch/strategies/scripts/cli_script.py +++ b/tests/tests_pytorch/strategies/scripts/cli_script.py @@ -20,5 +20,5 @@ BoringModel, BoringDataModule, seed_everything_default=42, - save_config_overwrite=True, + save_config_kwargs={"overwrite": True}, )