From ebe80e3e8722f36b6216b4497b77ef955d937e5b Mon Sep 17 00:00:00 2001
From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com>
Date: Tue, 27 Jun 2023 09:37:52 +0200
Subject: [PATCH] Fix use of jsonargparse avoiding reliance on non-public
 internal logic (#1620)

---
 requirements/base.txt                       |  2 +-
 src/flash/core/utilities/flash_cli.py       | 18 +++++++---
 src/flash/core/utilities/lightning_cli.py   | 41 +++------------------
 tests/core/utilities/test_lightning_cli.py  | 28 +++++++--------
 4 files changed, 32 insertions(+), 57 deletions(-)

diff --git a/requirements/base.txt b/requirements/base.txt
index 96d017ea82..e460970415 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -8,7 +8,7 @@ torchmetrics >0.7.0, <0.11.0  # strict
 pytorch-lightning >1.8.0, <2.0.0  # strict
 pyDeprecate >0.2.0
 pandas >1.1.0, <=1.5.2
-jsonargparse[signatures] >4.0.0, <=4.9.0
+jsonargparse[signatures] >=4.22.0, <4.23.0
 click >=7.1.2, <=8.1.3
 protobuf <=3.20.1
 fsspec[http] >=2022.5.0,<=2023.6.0
diff --git a/src/flash/core/utilities/flash_cli.py b/src/flash/core/utilities/flash_cli.py
index 1de8f5f9df..132fc85479 100644
--- a/src/flash/core/utilities/flash_cli.py
+++ b/src/flash/core/utilities/flash_cli.py
@@ -20,8 +20,7 @@
 from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
 
 import pytorch_lightning as pl
-from jsonargparse import ArgumentParser
-from jsonargparse.signatures import get_class_signature_functions
+from jsonargparse import ArgumentParser, class_from_function
 from lightning_utilities.core.overrides import is_overridden
 from pytorch_lightning import LightningModule, Trainer
 
@@ -31,7 +30,6 @@
     LightningArgumentParser,
     LightningCLI,
     SaveConfigCallback,
-    class_from_function,
 )
 
 from flash.core.utilities.stability import beta
@@ -107,6 +105,16 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
+def get_class_signature_functions(classes):
+    signatures = []
+    for num, cls in enumerate(classes):
+        if cls.__new__ is not object.__new__ and not any(cls.__new__ is c.__new__ for c in classes[num + 1 :]):
+            signatures.append((cls, cls.__new__))
+        if not any(cls.__init__ is c.__init__ for c in classes[num + 1 :]):
+            signatures.append((cls, cls.__init__))
+    return signatures
+
+
 def get_overlapping_args(func_a, func_b) -> Set[str]:
     func_a = get_class_signature_functions([func_a])[0][1]
     func_b = get_class_signature_functions([func_b])[0][1]
@@ -214,7 +222,7 @@ def add_arguments_to_parser(self, parser) -> None:
     def add_subcommand_from_function(self, subcommands, function, function_name=None):
         subcommand = ArgumentParser()
         if get_kwarg_name(function) == "data_module_kwargs":
-            datamodule_function = class_from_function(function, return_type=self.local_datamodule_class)
+            datamodule_function = class_from_function(function, self.local_datamodule_class)
             subcommand.add_class_arguments(
                 datamodule_function,
                 fail_untyped=False,
@@ -233,7 +241,7 @@ def add_subcommand_from_function(self, subcommands, function, function_name=None
                 },
             )
         else:
-            datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class)
+            datamodule_function = class_from_function(drop_kwargs(function), self.local_datamodule_class)
             subcommand.add_class_arguments(datamodule_function, fail_untyped=False)
         subcommand_name = function_name or function.__name__
         subcommands.add_subcommand(subcommand_name, subcommand)
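Note on the flash_cli.py changes above: class_from_function now comes from jsonargparse's public API instead of Flash's vendored copy, with the return type passed positionally to match the public signature, and the private jsonargparse.signatures.get_class_signature_functions import is replaced by a small local backport. A minimal sketch of what class_from_function provides, assuming jsonargparse >=4.22; CsvDataset and from_csv are hypothetical stand-ins for Flash's datamodule factory functions:

    from jsonargparse import ArgumentParser, class_from_function


    class CsvDataset:
        def __init__(self, path: str, batch_size: int = 32):
            self.path = path
            self.batch_size = batch_size


    def from_csv(path: str, batch_size: int = 32) -> CsvDataset:
        # Factory function; its signature becomes the generated class's __init__.
        return CsvDataset(path, batch_size=batch_size)


    # As in the patched calls, the return type is passed positionally, so it
    # does not have to be recovered from a (possibly missing) return annotation.
    CsvDatasetClass = class_from_function(from_csv, CsvDataset)

    parser = ArgumentParser()
    parser.add_class_arguments(CsvDatasetClass, "data", fail_untyped=False)
    cfg = parser.parse_args(["--data.path=train.csv", "--data.batch_size=64"])
    objects = parser.instantiate_classes(cfg)
    assert isinstance(objects.data, CsvDataset)  # instantiation calls from_csv

Passing the return type explicitly is what lets the drop_kwargs-wrapped functions above work, since wrapping can strip the annotation that class_from_function would otherwise introspect.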
diff --git a/src/flash/core/utilities/lightning_cli.py b/src/flash/core/utilities/lightning_cli.py
index 37ce4a470e..e288c964e4 100644
--- a/src/flash/core/utilities/lightning_cli.py
+++ b/src/flash/core/utilities/lightning_cli.py
@@ -4,14 +4,11 @@
 import os
 import warnings
 from argparse import Namespace
-from functools import wraps
 from types import MethodType
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
 
 import torch
-from jsonargparse import ActionConfigFile, ArgumentParser, set_config_read_mode
-from jsonargparse.signatures import ClassFromFunctionBase
-from jsonargparse.typehints import ClassType
+from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, set_config_read_mode
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -25,46 +22,16 @@
 LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
 
 
-def class_from_function(
-    func: Callable[..., ClassType],
-    return_type: Optional[Type[ClassType]] = None,
-) -> Type[ClassType]:
-    """Creates a dynamic class which if instantiated is equivalent to calling func.
-
-    Args:
-        func: A function that returns an instance of a class. It must have a return type annotation.
-    """
-
-    @wraps(func)
-    def __new__(cls, *args, **kwargs):
-        return func(*args, **kwargs)
-
-    if return_type is None:
-        return_type = inspect.signature(func).return_annotation
-
-    if isinstance(return_type, str):
-        raise RuntimeError("Classmethod instantiation is not supported when the return type annotation is a string.")
-
-    class ClassFromFunction(return_type, ClassFromFunctionBase):  # type: ignore
-        pass
-
-    ClassFromFunction.__new__ = __new__  # type: ignore
-    ClassFromFunction.__doc__ = func.__doc__
-    ClassFromFunction.__name__ = func.__name__
-
-    return ClassFromFunction
-
-
 class LightningArgumentParser(ArgumentParser):
     """Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
 
-    def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         """Initialize argument parser that supports configuration file input.
 
         For full details of accepted arguments see `ArgumentParser.__init__ `_.
         """
-        super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs)
+        super().__init__(*args, **kwargs)
         self.add_argument(
             "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
         )
@@ -95,7 +62,7 @@ def add_lightning_class_args(
 
         if inspect.isclass(lightning_class) and issubclass(
             cast(type, lightning_class),
-            (Trainer, LightningModule, LightningDataModule, Callback, ClassFromFunctionBase),
+            (Trainer, LightningModule, LightningDataModule, Callback),
         ):
             if issubclass(cast(type, lightning_class), Callback):
                 self.callback_keys.append(nested_key)
diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py
index e2b4c9d8c0..44b9a9c559 100644
--- a/tests/core/utilities/test_lightning_cli.py
+++ b/tests/core/utilities/test_lightning_cli.py
@@ -40,7 +40,7 @@ def test_default_args(mock_argparse, tmpdir):
     """Tests default argument parser for Trainer."""
     mock_argparse.return_value = Namespace(**Trainer.default_attributes())
 
-    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+    parser = LightningArgumentParser(add_help=False)
     args = parser.parse_args([])
 
     args.max_epochs = 5
@@ -54,7 +54,7 @@
 @pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--default_root_dir=./"], []])
 def test_add_argparse_args_redefined(cli_args):
     """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
-    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+    parser = LightningArgumentParser(add_help=False)
     parser.add_lightning_class_args(Trainer, None)
 
     args = parser.parse_args(cli_args)
@@ -79,11 +79,11 @@ def test_add_argparse_args_redefined(cli_args):
         ("--auto_lr_find=True --auto_scale_batch_size=power", {"auto_lr_find": True, "auto_scale_batch_size": "power"}),
         (
             "--auto_lr_find any_string --auto_scale_batch_size ON",
-            {"auto_lr_find": "any_string", "auto_scale_batch_size": True},
+            {"auto_lr_find": "any_string", "auto_scale_batch_size": "ON"},
         ),
-        ("--auto_lr_find=Yes --auto_scale_batch_size=On", {"auto_lr_find": True, "auto_scale_batch_size": True}),
-        ("--auto_lr_find Off --auto_scale_batch_size No", {"auto_lr_find": False, "auto_scale_batch_size": False}),
-        ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", {"auto_lr_find": True, "auto_scale_batch_size": False}),
+        ("--auto_lr_find=Yes --auto_scale_batch_size=On", {"auto_lr_find": True, "auto_scale_batch_size": "On"}),
+        ("--auto_lr_find Off --auto_scale_batch_size No", {"auto_lr_find": False, "auto_scale_batch_size": "No"}),
+        ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", {"auto_lr_find": True, "auto_scale_batch_size": "FALSE"}),
         ("--limit_train_batches=100", {"limit_train_batches": 100}),
         ("--limit_train_batches 0.8", {"limit_train_batches": 0.8}),
     ],
@@ -91,7 +91,7 @@ def test_add_argparse_args_redefined(cli_args):
 def test_parse_args_parsing(cli_args, expected):
     """Test parsing simple types and None optionals not modified."""
     cli_args = cli_args.split(" ") if cli_args else []
-    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+    parser = LightningArgumentParser(add_help=False)
     parser.add_lightning_class_args(Trainer, None)
     with patch("sys.argv", ["any.py"] + cli_args):
         args = parser.parse_args()
@@ -112,7 +112,7 @@
 )
 def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
     """Test parsing complex types."""
-    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+    parser = LightningArgumentParser(add_help=False)
     parser.add_lightning_class_args(Trainer, None)
     with patch("sys.argv", ["any.py"] + cli_args):
         args = parser.parse_args()
@@ -137,7 +137,7 @@ def test_parse_args_parsing_gpus(mocker, cli_args, expected_gpu):
     """Test parsing of gpus and instantiation of Trainer."""
     mocker.patch("lightning_lite.utilities.device_parser._get_all_available_gpus", return_value=[0, 1])
     cli_args = cli_args.split(" ") if cli_args else []
-    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+    parser = LightningArgumentParser(add_help=False)
    parser.add_lightning_class_args(Trainer, None)
     with patch("sys.argv", ["any.py"] + cli_args):
         args = parser.parse_args()
@@ -310,8 +310,8 @@ def test_lightning_cli_args(tmpdir):
         config = yaml.safe_load(f.read())
     assert "model" not in config
     assert "model" not in cli.config
-    assert config["data"] == cli.config["data"]
-    assert config["trainer"] == cli.config["trainer"]
+    assert config["data"] == cli.config["data"].as_dict()
+    assert config["trainer"] == cli.config["trainer"].as_dict()
 
 
 @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
@@ -363,9 +363,9 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
     assert os.path.isfile(config_path)
     with open(config_path) as f:
         config = yaml.safe_load(f.read())
-    assert config["model"] == cli.config["model"]
-    assert config["data"] == cli.config["data"]
-    assert config["trainer"] == cli.config["trainer"]
+    assert config["model"] == cli.config["model"].as_dict()
+    assert config["data"] == cli.config["data"].as_dict()
+    assert config["trainer"] == cli.config["trainer"].as_dict()
 
 
 def any_model_any_data_cli():
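The test updates track two behavior changes in jsonargparse 4.22 relative to the previously pinned <=4.9.0 releases. First, the parse_as_dict parser mode is gone: parse results are jsonargparse Namespace objects, so comparisons against dicts loaded from the saved YAML configs now go through Namespace.as_dict(). Second, values for Union-typed Trainer options are resolved against the union members in order, so for auto_scale_batch_size (seemingly annotated Union[str, bool] in pytorch-lightning 1.8/1.9, with str first) inputs like "ON" or "FALSE" now stay strings instead of being coerced to booleans, while auto_lr_find (bool first) still coerces boolean-like strings. A minimal sketch of the Namespace side, assuming jsonargparse >=4.22; the trainer.* argument names here are illustrative only:

    from jsonargparse import ArgumentParser

    parser = ArgumentParser()
    # Dotted argument names produce a nested namespace under "trainer".
    parser.add_argument("--trainer.max_epochs", type=int, default=1)
    parser.add_argument("--trainer.precision", type=str, default="32")

    cfg = parser.parse_args(["--trainer.max_epochs=5"])
    assert cfg.trainer.max_epochs == 5  # attribute access on a Namespace, not a dict
    assert cfg.trainer.as_dict() == {"max_epochs": 5, "precision": "32"}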