Skip to content

Commit

Permalink
- Stubs resolver now supports sys.version_info conditionals.
Browse files Browse the repository at this point in the history
- When stub type fails to parse, keep it with empty type if not among known parameters.
- Avoid issues with invalid value defaults, e.g. torch.optim.optimizer.required.
- Fixed fail_untyped=False not propagated to subclass --*.help actions.
  • Loading branch information
mauvilsa committed Dec 20, 2022
1 parent 02a4ff5 commit 535ce73
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 45 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Fixed
- Pure dataclass instance default being considered as a subclass type.
- Discard ``init_args`` after ``class_path`` change causes error `#205
<https://github.com/omni-us/jsonargparse/issues/205>`__.
- ``fail_untyped=False`` not propagated to subclass ``--*.help`` actions.
- Issues reported by CodeQL.

Changed
Expand Down
111 changes: 79 additions & 32 deletions jsonargparse/_stubs_resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import ast
import inspect
import sys
from copy import deepcopy
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

Expand All @@ -9,7 +11,7 @@
typeshed_client_support,
typing_extensions_import,
)
from .util import import_object, unique
from .util import unique

if TYPE_CHECKING: # pragma: no cover
import typeshed_client as tc
Expand All @@ -20,12 +22,12 @@
kinds = inspect._ParameterKind


def import_module(name: str):
def import_module_or_none(path: str):
if path.endswith('.__init__'):
path = path[:-9]
try:
if '.' in name:
return import_object(name)
return __import__(name)
except Exception:
return import_module(path)
except ModuleNotFoundError:
return None


Expand All @@ -48,6 +50,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
module_path = self.module_path[:-node.level]
if node.module:
module_path.append(node.module)
node = deepcopy(node)
node.module = '.'.join(module_path)
node.level = 0
for alias in node.names:
Expand All @@ -61,7 +64,13 @@ def find(self, node: ast.AST, module_path: str) -> Dict[str, Tuple[Optional[str]


def ast_annassign_to_assign(node: ast.AnnAssign) -> ast.Assign:
return ast.Assign(targets=[node.target], value=node.value, type_ignores=[], lineno=1, end_lineno=1)
return ast.Assign(
targets=[node.target],
value=node.value,
type_ignores=[],
lineno=node.lineno,
end_lineno=node.lineno,
)


class AssignsVisitor(ast.NodeVisitor):
Expand All @@ -81,6 +90,36 @@ def find(self, node: ast.AST) -> Dict[str, ast.Assign]:
return self.assigns_found


class MethodsVisitor(ast.NodeVisitor):

method_found: Optional[ast.FunctionDef]

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
if not self.method_found and node.name == self.method_name:
self.method_found = node

def visit_If(self, node: ast.If) -> None:
test_ast = ast.parse('___test___ = 0')
test_ast.body[0].value = node.test # type: ignore
exec_vars = {'sys': sys}
try:
exec(compile(test_ast, filename="<ast>", mode="exec"), exec_vars, exec_vars)
except Exception:
pass
else:
if exec_vars['___test___']:
node.orelse = []
else:
node.body = []
self.generic_visit(node)

def find(self, node: ast.AST, method_name: str) -> Optional[ast.FunctionDef]:
self.method_name = method_name
self.method_found = None
self.visit(node)
return self.method_found


stubs_resolver = None


Expand All @@ -93,16 +132,6 @@ def get_stubs_resolver():
return stubs_resolver


def ast_get_class_method(node: ast.AST, method_name: str) -> Optional[ast.FunctionDef]:
method_ast = None
if isinstance(node, ast.ClassDef):
for elem in node.body:
if isinstance(elem, ast.FunctionDef) and elem.name == method_name:
method_ast = elem
break
return method_ast


def get_mro_method_parent(parent, method_name):
while hasattr(parent, '__dict__') and method_name not in parent.__dict__:
try:
Expand All @@ -112,6 +141,19 @@ def get_mro_method_parent(parent, method_name):
return None if parent is object else parent


def get_source_module(path: str, component) -> tc.ModulePath:
if component is None:
module_path, name = path.rsplit('.', 1)
component = getattr(import_module_or_none(module_path), name, None)
if component is not None:
module = inspect.getmodule(component)
assert module is not None
module_path = module.__name__
if getattr(module, '__file__', '').endswith('__init__.py'):
module_path += '.__init__'
return tc.ModulePath(tuple(module_path.split('.')))


class StubsResolver(tc.Resolver):

def __init__(self, search_context = None) -> None:
Expand All @@ -120,13 +162,13 @@ def __init__(self, search_context = None) -> None:
self._module_assigns_cache: Dict[str, Dict[str, ast.Assign]] = {}
self._module_imports_cache: Dict[str, Dict[str, Tuple[Optional[str], str]]] = {}

def get_imported_info(self, name: str) -> Optional[tc.ImportedInfo]:
resolved = super().get_fully_qualified_name(name)
def get_imported_info(self, path: str, component=None) -> Optional[tc.ImportedInfo]:
resolved = self.get_fully_qualified_name(path)
imported_info = None
if isinstance(resolved, tc.ImportedInfo):
imported_info = resolved
elif isinstance(resolved, tc.NameInfo):
source_module = tc.ModulePath(tuple(name.split('.')[:-1]))
resolved = resolved.info
if isinstance(resolved, tc.NameInfo):
source_module = get_source_module(path, component)
imported_info = tc.ImportedInfo(source_module=source_module, info=resolved)
return imported_info

Expand All @@ -135,12 +177,14 @@ def get_component_imported_info(self, component, parent) -> Optional[tc.Imported
parent = type(component.__self__)
component = getattr(parent, component.__name__)
if not parent:
return self.get_imported_info(f'{component.__module__}.{component.__name__}')
return self.get_imported_info(f'{component.__module__}.{component.__name__}', component)
parent = get_mro_method_parent(parent, component.__name__)
stub_import = parent and self.get_imported_info(f'{parent.__module__}.{parent.__name__}')
stub_import = parent and self.get_imported_info(f'{parent.__module__}.{parent.__name__}', component)
if stub_import and isinstance(stub_import.info.ast, ast.AST):
method_ast = ast_get_class_method(stub_import.info.ast, component.__name__)
if method_ast is not None:
method_ast = MethodsVisitor().find(stub_import.info.ast, component.__name__)
if method_ast is None:
stub_import = None
else:
name_info = tc.NameInfo(name=component.__qualname__, is_exported=False, ast=method_ast)
stub_import = tc.ImportedInfo(source_module=stub_import.source_module, info=name_info)
return stub_import
Expand Down Expand Up @@ -169,7 +213,7 @@ def get_module_stub_imports(self, module_path: str):

def add_import_aliases(self, aliases, stub_import: tc.ImportedInfo):
module_path = '.'.join(stub_import.source_module)
module = import_module(module_path)
module = import_module_or_none(module_path)
stub_ast: Optional[ast.AST] = None
if isinstance(stub_import.info.ast, (ast.Assign, ast.AnnAssign)):
stub_ast = stub_import.info.ast.value
Expand All @@ -196,7 +240,7 @@ def add_module_aliases(self, aliases, module_path, module, node):
self.add_module_aliases(aliases, module_path, module, value.value)
elif name in self.get_module_stub_imports(module_path):
imported_module_path, imported_name = self.get_module_stub_imports(module_path)[name]
imported_module = import_module(imported_module_path)
imported_module = import_module_or_none(imported_module_path)
if hasattr(imported_module, imported_name):
source = imported_module_path
value = getattr(imported_module, imported_name)
Expand Down Expand Up @@ -255,10 +299,12 @@ def get_arg_type(arg_ast, aliases):
try:
exec(compile(type_ast, filename="<ast>", mode="exec"), exec_vars, exec_vars)
except NameError as ex:
ex_from = None
for name, alias_exception in bad_aliases.items():
if str(ex) == f"name '{name}' is not defined":
raise NameError(str(alias_exception)) from ex
raise ex
ex_from = alias_exception
break
raise ex from ex_from
return exec_vars['___arg_type___']


Expand All @@ -285,6 +331,7 @@ def get_stub_types(params, component, parent, logger) -> Optional[Dict[str, Any]
try:
types[name] = get_arg_type(arg_ast, aliases)
except Exception as ex:
logger.debug(f'Failed to use type stub for parameter {name}', exc_info=ex)
continue
logger.debug(f'Failed to parse type stub for {component.__qualname__!r} parameter {name!r}', exc_info=ex)
if name not in known_params:
types[name] = inspect._empty
return types
3 changes: 3 additions & 0 deletions jsonargparse/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,7 @@ def _apply_actions(
cfg: Union[Namespace, Dict[str, Any]],
parent_key: str = '',
prev_cfg: Optional[Namespace] = None,
skip_fn: Optional[Callable[[Any], bool]] = None,
) -> Namespace:
"""Runs _check_value_key on actions present in config."""
if isinstance(cfg, dict):
Expand Down Expand Up @@ -1262,6 +1263,8 @@ def _apply_actions(

action_dest = action.dest if subcommand is None else subcommand+'.'+action.dest
value = cfg[action_dest]
if skip_fn and skip_fn(value):
continue
with lenient_check_context(), load_value_context(self.parser_mode):
value = self._check_value_key(action, value, action_dest, prev_cfg)
if isinstance(action, _ActionConfigLoad):
Expand Down
2 changes: 1 addition & 1 deletion jsonargparse/optionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def import_typeshed_client():
import typeshed_client
return typeshed_client
else:
return __import__('argparse').Namespace(ImportedInfo=object, Resolver=object)
return __import__('argparse').Namespace(ImportedInfo=object, ModulePath=object, Resolver=object)


class UndefinedException(Exception):
Expand Down
6 changes: 3 additions & 3 deletions jsonargparse/signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,13 +330,13 @@ def _add_signature_parameter(
kwargs['type'] = annotation
elif annotation != inspect_empty:
try:
is_subclass_typehint = ActionTypeHint.is_subclass_typehint(annotation)
is_subclass_typehint = ActionTypeHint.is_subclass_typehint(annotation, all_subtypes=False)
kwargs['type'] = annotation
sub_add_kwargs = None
sub_add_kwargs: dict = {'fail_untyped': fail_untyped, 'sub_configs': sub_configs}
if is_subclass_typehint:
prefix = name + '.init_args.'
subclass_skip = {s[len(prefix):] for s in skip or [] if s.startswith(prefix)}
sub_add_kwargs = {'fail_untyped': fail_untyped, 'skip': subclass_skip}
sub_add_kwargs['skip'] = subclass_skip
args = ActionTypeHint.prepare_add_argument(
args=args,
kwargs=kwargs,
Expand Down
9 changes: 8 additions & 1 deletion jsonargparse/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,15 @@ def sub_defaults_context():

@staticmethod
def add_sub_defaults(parser, cfg):
def skip_sub_defaults_apply(v):
return not (
isinstance(v, (str, Namespace)) or
is_subclass_spec(v) or
(isinstance(v, list) and any(is_subclass_spec(e) for e in v)) or
(isinstance(v, dict) and any(is_subclass_spec(e) for e in v.values()))
)
with ActionTypeHint.sub_defaults_context():
parser._apply_actions(cfg)
parser._apply_actions(cfg, skip_fn=skip_sub_defaults_apply)


@staticmethod
Expand Down
18 changes: 18 additions & 0 deletions jsonargparse_tests/test_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,24 @@ def func(a1, a2=None):
self.assertEqual(Namespace(a1=None, a2=None), parser.parse_args([]))


def test_fail_untyped_false_subclass_help(self):
class Class1:
def __init__(self, a1, a2=None):
self.a1 = a1

def func(c1: Union[int, Class1]):
return c1

with mock_module(Class1) as module:
parser = ArgumentParser(error_handler=None)
parser.add_function_arguments(func, fail_untyped=False)

help_str = StringIO()
with redirect_stdout(help_str), self.assertRaises(SystemExit):
parser.parse_args([f'--c1.help={module}.Class1'])
self.assertIn('--c1.init_args.a1 A1', help_str.getvalue())


@unittest.skipIf(not docstring_parser_support, 'docstring-parser package is required')
def test_docstring_parse_fail(self):

Expand Down
Loading

0 comments on commit 535ce73

Please sign in to comment.