From 535ce73d98b448019e73a39fd58108cbc7823c1c Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Tue, 20 Dec 2022 21:48:08 +0100 Subject: [PATCH] - Stubs resolver now supports sys.version_info conditionals. - When stub type fails to parse, keep it with empty type if not among known parameters. - Avoid issues with invalid value defaults, e.g. torch.optim.optimizer.required. - Fixed fail_untyped=False not propagated to subclass --*.help actions. --- CHANGELOG.rst | 1 + jsonargparse/_stubs_resolver.py | 111 +++++++++++++++------- jsonargparse/core.py | 3 + jsonargparse/optionals.py | 2 +- jsonargparse/signatures.py | 6 +- jsonargparse/typehints.py | 9 +- jsonargparse_tests/test_signatures.py | 18 ++++ jsonargparse_tests/test_stubs_resolver.py | 54 +++++++++-- setup.cfg | 3 + 9 files changed, 162 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2f98f5ef..f4f6c13f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -34,6 +34,7 @@ Fixed - Pure dataclass instance default being considered as a subclass type. - Discard ``init_args`` after ``class_path`` change causes error `#205 `__. +- ``fail_untyped=False`` not propagated to subclass ``--*.help`` actions. - Issues reported by CodeQL. Changed diff --git a/jsonargparse/_stubs_resolver.py b/jsonargparse/_stubs_resolver.py index 2f431541..85e3f360 100644 --- a/jsonargparse/_stubs_resolver.py +++ b/jsonargparse/_stubs_resolver.py @@ -1,6 +1,8 @@ import ast import inspect import sys +from copy import deepcopy +from importlib import import_module from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -9,7 +11,7 @@ typeshed_client_support, typing_extensions_import, ) -from .util import import_object, unique +from .util import unique if TYPE_CHECKING: # pragma: no cover import typeshed_client as tc @@ -20,12 +22,12 @@ kinds = inspect._ParameterKind -def import_module(name: str): +def import_module_or_none(path: str): + if path.endswith('.__init__'): + path = path[:-9] try: - if '.' in name: - return import_object(name) - return __import__(name) - except Exception: + return import_module(path) + except ModuleNotFoundError: return None @@ -48,6 +50,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: module_path = self.module_path[:-node.level] if node.module: module_path.append(node.module) + node = deepcopy(node) node.module = '.'.join(module_path) node.level = 0 for alias in node.names: @@ -61,7 +64,13 @@ def find(self, node: ast.AST, module_path: str) -> Dict[str, Tuple[Optional[str] def ast_annassign_to_assign(node: ast.AnnAssign) -> ast.Assign: - return ast.Assign(targets=[node.target], value=node.value, type_ignores=[], lineno=1, end_lineno=1) + return ast.Assign( + targets=[node.target], + value=node.value, + type_ignores=[], + lineno=node.lineno, + end_lineno=node.lineno, + ) class AssignsVisitor(ast.NodeVisitor): @@ -81,6 +90,36 @@ def find(self, node: ast.AST) -> Dict[str, ast.Assign]: return self.assigns_found +class MethodsVisitor(ast.NodeVisitor): + + method_found: Optional[ast.FunctionDef] + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + if not self.method_found and node.name == self.method_name: + self.method_found = node + + def visit_If(self, node: ast.If) -> None: + test_ast = ast.parse('___test___ = 0') + test_ast.body[0].value = node.test # type: ignore + exec_vars = {'sys': sys} + try: + exec(compile(test_ast, filename="", mode="exec"), exec_vars, exec_vars) + except Exception: + pass + else: + if exec_vars['___test___']: + node.orelse = [] + else: + node.body = [] + self.generic_visit(node) + + def find(self, node: ast.AST, method_name: str) -> Optional[ast.FunctionDef]: + self.method_name = method_name + self.method_found = None + self.visit(node) + return self.method_found + + stubs_resolver = None @@ -93,16 +132,6 @@ def get_stubs_resolver(): return stubs_resolver -def ast_get_class_method(node: ast.AST, method_name: str) -> Optional[ast.FunctionDef]: - method_ast = None - if isinstance(node, ast.ClassDef): - for elem in node.body: - if isinstance(elem, ast.FunctionDef) and elem.name == method_name: - method_ast = elem - break - return method_ast - - def get_mro_method_parent(parent, method_name): while hasattr(parent, '__dict__') and method_name not in parent.__dict__: try: @@ -112,6 +141,19 @@ def get_mro_method_parent(parent, method_name): return None if parent is object else parent +def get_source_module(path: str, component) -> tc.ModulePath: + if component is None: + module_path, name = path.rsplit('.', 1) + component = getattr(import_module_or_none(module_path), name, None) + if component is not None: + module = inspect.getmodule(component) + assert module is not None + module_path = module.__name__ + if getattr(module, '__file__', '').endswith('__init__.py'): + module_path += '.__init__' + return tc.ModulePath(tuple(module_path.split('.'))) + + class StubsResolver(tc.Resolver): def __init__(self, search_context = None) -> None: @@ -120,13 +162,13 @@ def __init__(self, search_context = None) -> None: self._module_assigns_cache: Dict[str, Dict[str, ast.Assign]] = {} self._module_imports_cache: Dict[str, Dict[str, Tuple[Optional[str], str]]] = {} - def get_imported_info(self, name: str) -> Optional[tc.ImportedInfo]: - resolved = super().get_fully_qualified_name(name) + def get_imported_info(self, path: str, component=None) -> Optional[tc.ImportedInfo]: + resolved = self.get_fully_qualified_name(path) imported_info = None if isinstance(resolved, tc.ImportedInfo): - imported_info = resolved - elif isinstance(resolved, tc.NameInfo): - source_module = tc.ModulePath(tuple(name.split('.')[:-1])) + resolved = resolved.info + if isinstance(resolved, tc.NameInfo): + source_module = get_source_module(path, component) imported_info = tc.ImportedInfo(source_module=source_module, info=resolved) return imported_info @@ -135,12 +177,14 @@ def get_component_imported_info(self, component, parent) -> Optional[tc.Imported parent = type(component.__self__) component = getattr(parent, component.__name__) if not parent: - return self.get_imported_info(f'{component.__module__}.{component.__name__}') + return self.get_imported_info(f'{component.__module__}.{component.__name__}', component) parent = get_mro_method_parent(parent, component.__name__) - stub_import = parent and self.get_imported_info(f'{parent.__module__}.{parent.__name__}') + stub_import = parent and self.get_imported_info(f'{parent.__module__}.{parent.__name__}', component) if stub_import and isinstance(stub_import.info.ast, ast.AST): - method_ast = ast_get_class_method(stub_import.info.ast, component.__name__) - if method_ast is not None: + method_ast = MethodsVisitor().find(stub_import.info.ast, component.__name__) + if method_ast is None: + stub_import = None + else: name_info = tc.NameInfo(name=component.__qualname__, is_exported=False, ast=method_ast) stub_import = tc.ImportedInfo(source_module=stub_import.source_module, info=name_info) return stub_import @@ -169,7 +213,7 @@ def get_module_stub_imports(self, module_path: str): def add_import_aliases(self, aliases, stub_import: tc.ImportedInfo): module_path = '.'.join(stub_import.source_module) - module = import_module(module_path) + module = import_module_or_none(module_path) stub_ast: Optional[ast.AST] = None if isinstance(stub_import.info.ast, (ast.Assign, ast.AnnAssign)): stub_ast = stub_import.info.ast.value @@ -196,7 +240,7 @@ def add_module_aliases(self, aliases, module_path, module, node): self.add_module_aliases(aliases, module_path, module, value.value) elif name in self.get_module_stub_imports(module_path): imported_module_path, imported_name = self.get_module_stub_imports(module_path)[name] - imported_module = import_module(imported_module_path) + imported_module = import_module_or_none(imported_module_path) if hasattr(imported_module, imported_name): source = imported_module_path value = getattr(imported_module, imported_name) @@ -255,10 +299,12 @@ def get_arg_type(arg_ast, aliases): try: exec(compile(type_ast, filename="", mode="exec"), exec_vars, exec_vars) except NameError as ex: + ex_from = None for name, alias_exception in bad_aliases.items(): if str(ex) == f"name '{name}' is not defined": - raise NameError(str(alias_exception)) from ex - raise ex + ex_from = alias_exception + break + raise ex from ex_from return exec_vars['___arg_type___'] @@ -285,6 +331,7 @@ def get_stub_types(params, component, parent, logger) -> Optional[Dict[str, Any] try: types[name] = get_arg_type(arg_ast, aliases) except Exception as ex: - logger.debug(f'Failed to use type stub for parameter {name}', exc_info=ex) - continue + logger.debug(f'Failed to parse type stub for {component.__qualname__!r} parameter {name!r}', exc_info=ex) + if name not in known_params: + types[name] = inspect._empty return types diff --git a/jsonargparse/core.py b/jsonargparse/core.py index d26f69f4..ad1d0a00 100644 --- a/jsonargparse/core.py +++ b/jsonargparse/core.py @@ -1218,6 +1218,7 @@ def _apply_actions( cfg: Union[Namespace, Dict[str, Any]], parent_key: str = '', prev_cfg: Optional[Namespace] = None, + skip_fn: Optional[Callable[[Any], bool]] = None, ) -> Namespace: """Runs _check_value_key on actions present in config.""" if isinstance(cfg, dict): @@ -1262,6 +1263,8 @@ def _apply_actions( action_dest = action.dest if subcommand is None else subcommand+'.'+action.dest value = cfg[action_dest] + if skip_fn and skip_fn(value): + continue with lenient_check_context(), load_value_context(self.parser_mode): value = self._check_value_key(action, value, action_dest, prev_cfg) if isinstance(action, _ActionConfigLoad): diff --git a/jsonargparse/optionals.py b/jsonargparse/optionals.py index 3cf7fc40..61f3f489 100644 --- a/jsonargparse/optionals.py +++ b/jsonargparse/optionals.py @@ -65,7 +65,7 @@ def import_typeshed_client(): import typeshed_client return typeshed_client else: - return __import__('argparse').Namespace(ImportedInfo=object, Resolver=object) + return __import__('argparse').Namespace(ImportedInfo=object, ModulePath=object, Resolver=object) class UndefinedException(Exception): diff --git a/jsonargparse/signatures.py b/jsonargparse/signatures.py index 319ad114..1e3cf034 100644 --- a/jsonargparse/signatures.py +++ b/jsonargparse/signatures.py @@ -330,13 +330,13 @@ def _add_signature_parameter( kwargs['type'] = annotation elif annotation != inspect_empty: try: - is_subclass_typehint = ActionTypeHint.is_subclass_typehint(annotation) + is_subclass_typehint = ActionTypeHint.is_subclass_typehint(annotation, all_subtypes=False) kwargs['type'] = annotation - sub_add_kwargs = None + sub_add_kwargs: dict = {'fail_untyped': fail_untyped, 'sub_configs': sub_configs} if is_subclass_typehint: prefix = name + '.init_args.' subclass_skip = {s[len(prefix):] for s in skip or [] if s.startswith(prefix)} - sub_add_kwargs = {'fail_untyped': fail_untyped, 'skip': subclass_skip} + sub_add_kwargs['skip'] = subclass_skip args = ActionTypeHint.prepare_add_argument( args=args, kwargs=kwargs, diff --git a/jsonargparse/typehints.py b/jsonargparse/typehints.py index 10193700..32d4fcd5 100644 --- a/jsonargparse/typehints.py +++ b/jsonargparse/typehints.py @@ -339,8 +339,15 @@ def sub_defaults_context(): @staticmethod def add_sub_defaults(parser, cfg): + def skip_sub_defaults_apply(v): + return not ( + isinstance(v, (str, Namespace)) or + is_subclass_spec(v) or + (isinstance(v, list) and any(is_subclass_spec(e) for e in v)) or + (isinstance(v, dict) and any(is_subclass_spec(e) for e in v.values())) + ) with ActionTypeHint.sub_defaults_context(): - parser._apply_actions(cfg) + parser._apply_actions(cfg, skip_fn=skip_sub_defaults_apply) @staticmethod diff --git a/jsonargparse_tests/test_signatures.py b/jsonargparse_tests/test_signatures.py index 214b2da7..c3e84901 100755 --- a/jsonargparse_tests/test_signatures.py +++ b/jsonargparse_tests/test_signatures.py @@ -966,6 +966,24 @@ def func(a1, a2=None): self.assertEqual(Namespace(a1=None, a2=None), parser.parse_args([])) + def test_fail_untyped_false_subclass_help(self): + class Class1: + def __init__(self, a1, a2=None): + self.a1 = a1 + + def func(c1: Union[int, Class1]): + return c1 + + with mock_module(Class1) as module: + parser = ArgumentParser(error_handler=None) + parser.add_function_arguments(func, fail_untyped=False) + + help_str = StringIO() + with redirect_stdout(help_str), self.assertRaises(SystemExit): + parser.parse_args([f'--c1.help={module}.Class1']) + self.assertIn('--c1.init_args.a1 A1', help_str.getvalue()) + + @unittest.skipIf(not docstring_parser_support, 'docstring-parser package is required') def test_docstring_parse_fail(self): diff --git a/jsonargparse_tests/test_stubs_resolver.py b/jsonargparse_tests/test_stubs_resolver.py index ebcb5ad2..6633bcb3 100755 --- a/jsonargparse_tests/test_stubs_resolver.py +++ b/jsonargparse_tests/test_stubs_resolver.py @@ -11,15 +11,19 @@ from ipaddress import ip_network from random import Random, SystemRandom, uniform from tarfile import TarFile +from typing import Any from unittest.mock import patch from uuid import UUID, uuid5 +import yaml + from jsonargparse import ArgumentParser from jsonargparse._stubs_resolver import get_mro_method_parent, get_stubs_resolver from jsonargparse.parameter_resolvers import get_signature_parameters as get_params from jsonargparse_tests.base import get_debug_level_logger logger = get_debug_level_logger(__name__) +torch_available = find_spec('torch') @contextmanager @@ -86,12 +90,34 @@ def test_get_params_object_instance_method(self): params = get_params(uniform) self.assertEqual([('a', inspect._empty), ('b', inspect._empty)], get_param_types(params)) - @unittest.skipIf(sys.version_info[:2] < (3, 10), 'new union syntax introduced in python 3.10') + def test_get_params_conditional_python_version(self): + params = get_params(Random, 'seed') + self.assertEqual(['a', 'version'], get_param_names(params)) + if sys.version_info >= (3, 10): + self.assertEqual('int | float | str | bytes | bytearray | None', str(params[0].annotation)) + else: + expected = Any if sys.version_info < (3, 9) else inspect._empty + self.assertEqual(expected, params[0].annotation) + self.assertEqual(int, params[1].annotation) + with mock_typeshed_client_unavailable(): + params = get_params(Random, 'seed') + self.assertEqual([('a', inspect._empty), ('version', inspect._empty)], get_param_types(params)) + + @patch('jsonargparse._stubs_resolver.exec') + def test_get_params_exec_failure(self, mock_exec): + mock_exec.side_effect = NameError('failed') + params = get_params(Random, 'seed') + self.assertEqual([('a', inspect._empty), ('version', inspect._empty)], get_param_types(params)) + def test_get_params_classmethod(self): params = get_params(TarFile, 'open') - self.assertTrue(all(p.annotation != inspect._empty for p in params)) + expected = ['name', 'mode', 'fileobj', 'bufsize', 'format', 'tarinfo', 'dereference', 'ignore_zeros', 'encoding', 'errors', 'pax_headers', 'debug', 'errorlevel'] + self.assertEqual(expected, get_param_names(params)[:len(expected)]) + if sys.version_info >= (3, 10): + self.assertTrue(all(p.annotation != inspect._empty for p in params)) with mock_typeshed_client_unavailable(): params = get_params(TarFile, 'open') + self.assertEqual(expected, get_param_names(params)[:len(expected)]) self.assertTrue(all(p.annotation == inspect._empty for p in params)) def test_get_params_staticmethod(self): @@ -104,13 +130,25 @@ def test_get_params_staticmethod(self): def test_get_params_function(self): params = get_params(ip_network) self.assertEqual(['address', 'strict'], get_param_names(params)) - if sys.version_info[:2] >= (3, 10): + if sys.version_info >= (3, 10): self.assertIn('int | str | bytes | ipaddress.IPv4Address | ', str(params[0].annotation)) self.assertEqual(bool, params[1].annotation) with mock_typeshed_client_unavailable(): params = get_params(ip_network) self.assertEqual([('address', inspect._empty), ('strict', inspect._empty)], get_param_types(params)) + def test_get_param_relative_import_from_init(self): + params = get_params(yaml.safe_load) + self.assertEqual(['stream'], get_param_names(params)) + if sys.version_info >= (3, 10): + self.assertNotEqual(params[0].annotation, inspect._empty) + else: + self.assertEqual(params[0].annotation, inspect._empty) + with mock_typeshed_client_unavailable(): + params = get_params(yaml.safe_load) + self.assertEqual(['stream'], get_param_names(params)) + self.assertEqual(params[0].annotation, inspect._empty) + def test_get_params_non_unique_alias(self): params = get_params(uuid5) self.assertEqual([('namespace', UUID), ('name', str)], get_param_types(params)) @@ -127,7 +165,6 @@ def alias_is_unique(aliases, name, source, value): self.assertIn("non-unique alias 'UUID': problem (module)", log.output[0]) @unittest.skipIf(not find_spec('requests'), 'requests package is required') - @unittest.skipIf(sys.version_info[:2] < (3, 10), 'new union syntax introduced in python 3.10') def test_get_params_complex_function_requests_get(self): from requests import get with mock_typeshed_client_unavailable(): @@ -139,16 +176,17 @@ def test_get_params_complex_function_requests_get(self): params = get_params(get) expected += ['data', 'headers', 'cookies', 'files', 'auth', 'timeout', 'allow_redirects', 'proxies', 'hooks', 'stream', 'verify', 'cert', 'json'] self.assertEqual(expected, get_param_names(params)) - self.assertTrue(all(p.annotation != inspect._empty for p in params)) + if sys.version_info >= (3, 10): + self.assertTrue(all(p.annotation != inspect._empty for p in params)) parser = ArgumentParser(error_handler=None) - parser.add_function_arguments(get) + parser.add_function_arguments(get, fail_untyped=False) self.assertEqual(['url', 'params'], list(parser.get_defaults().keys())) help_str = StringIO() parser.print_help(help_str) self.assertIn('default: Unknown', help_str.getvalue()) - @unittest.skipIf(not find_spec('torch'), 'torch package is required') + @unittest.skipIf(not torch_available, 'torch package is required') def test_get_params_torch_optimizer(self): import torch.optim # pylint: disable=import-error @@ -181,7 +219,7 @@ def skip_stub_inconsistencies(cls, params): params = get_params(cls) self.assertTrue(any(p.annotation == inspect._empty for p in params)) - @unittest.skipIf(not find_spec('torch'), 'torch package is required') + @unittest.skipIf(not torch_available, 'torch package is required') def test_get_params_torch_lr_scheduler(self): import torch.optim.lr_scheduler # pylint: disable=import-error diff --git a/setup.cfg b/setup.cfg index e9f3c443..bc332675 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,6 +45,7 @@ reconplogger = reconplogger>=4.4.0 test = %(test_no_urls)s + %(types_pyyaml)s responses>=0.12.0 types-requests>=2.28.9 test_no_urls = @@ -60,7 +61,9 @@ dev = pylint = pylint>=2.15.6 mypy = + %(types_pyyaml)s mypy>=0.701 +types_pyyaml = types-PyYAML>=6.0.11 doc = Sphinx>=1.7.9