Skip to content

Commit

Permalink
Further improve typing of builtins brain (#2225)
Browse files Browse the repository at this point in the history
Resolves 12 mypy errors

Co-authored-by: Daniël van Noord <[email protected]>
  • Loading branch information
jacobtylerwalls and DanielNoord authored Jun 27, 2023
1 parent 2f8b636 commit 8d57ce2
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 25 deletions.
2 changes: 1 addition & 1 deletion astroid/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _unpack_keywords(
keywords: list[tuple[str | None, nodes.NodeNG]],
context: InferenceContext | None = None,
):
values = {}
values: dict[str | None, InferenceResult] = {}
context = context or InferenceContext()
context.extra_context = self.argument_context_map
for name, value in keywords:
Expand Down
52 changes: 38 additions & 14 deletions astroid/brain/brain_builtin_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from __future__ import annotations

import itertools
from collections.abc import Callable, Iterator
from collections.abc import Callable, Iterable
from functools import partial
from typing import Any, Type, Union, cast
from typing import TYPE_CHECKING, Any, Iterator, NoReturn, Type, Union, cast

from astroid import arguments, helpers, inference_tip, nodes, objects, util
from astroid.builder import AstroidBuilder
Expand All @@ -29,6 +29,9 @@
SuccessfulInferenceResult,
)

if TYPE_CHECKING:
from astroid.bases import Instance

ContainerObjects = Union[
objects.FrozenSet,
objects.DictItems,
Expand All @@ -43,6 +46,13 @@
Type[frozenset],
]

CopyResult = Union[
nodes.Dict,
nodes.List,
nodes.Set,
objects.FrozenSet,
]

OBJECT_DUNDER_NEW = "object.__new__"

STR_CLASS = """
Expand Down Expand Up @@ -127,6 +137,10 @@ def ljust(self, width, fillchar=None):
"""


def _use_default() -> NoReturn: # pragma: no cover
raise UseInferenceDefault()


def _extend_string_class(class_node, code, rvalue):
"""Function to extend builtin str/unicode class."""
code = code.format(rvalue=rvalue)
Expand Down Expand Up @@ -193,7 +207,9 @@ def register_builtin_transform(transform, builtin_name) -> None:
an optional context.
"""

def _transform_wrapper(node, context: InferenceContext | None = None):
def _transform_wrapper(
node: nodes.Call, context: InferenceContext | None = None, **kwargs: Any
) -> Iterator:
result = transform(node, context=context)
if result:
if not result.parent:
Expand Down Expand Up @@ -257,10 +273,12 @@ def _container_generic_transform(
iterables: tuple[type[nodes.BaseContainer] | type[ContainerObjects], ...],
build_elts: BuiltContainers,
) -> nodes.BaseContainer | None:
elts: Iterable | str | bytes

if isinstance(arg, klass):
return arg
if isinstance(arg, iterables):
arg = cast(ContainerObjects, arg)
arg = cast(Union[nodes.BaseContainer, ContainerObjects], arg)
if all(isinstance(elt, nodes.Const) for elt in arg.elts):
elts = [cast(nodes.Const, elt).value for elt in arg.elts]
else:
Expand All @@ -277,9 +295,10 @@ def _container_generic_transform(
elts.append(evaluated_object)
elif isinstance(arg, nodes.Dict):
# Dicts need to have consts as strings already.
if not all(isinstance(elt[0], nodes.Const) for elt in arg.items):
raise UseInferenceDefault()
elts = [item[0].value for item in arg.items]
elts = [
item[0].value if isinstance(item[0], nodes.Const) else _use_default()
for item in arg.items
]
elif isinstance(arg, nodes.Const) and isinstance(arg.value, (str, bytes)):
elts = arg.value
else:
Expand Down Expand Up @@ -399,6 +418,7 @@ def infer_dict(node: nodes.Call, context: InferenceContext | None = None) -> nod
args = call.positional_arguments
kwargs = list(call.keyword_arguments.items())

items: list[tuple[InferenceResult, InferenceResult]]
if not args and not kwargs:
# dict()
return nodes.Dict(
Expand Down Expand Up @@ -695,7 +715,9 @@ def infer_slice(node, context: InferenceContext | None = None):
return slice_node


def _infer_object__new__decorator(node, context: InferenceContext | None = None):
def _infer_object__new__decorator(
node: nodes.ClassDef, context: InferenceContext | None = None, **kwargs: Any
) -> Iterator[Instance]:
# Instantiate class immediately
# since that's what @object.__new__ does
return iter((node.instantiate_class(),))
Expand Down Expand Up @@ -944,10 +966,10 @@ def _build_dict_with_elements(elements):
if isinstance(inferred_values, nodes.Const) and isinstance(
inferred_values.value, (str, bytes)
):
elements = [
elements_with_value = [
(nodes.Const(element), default) for element in inferred_values.value
]
return _build_dict_with_elements(elements)
return _build_dict_with_elements(elements_with_value)
if isinstance(inferred_values, nodes.Dict):
keys = inferred_values.itered()
for key in keys:
Expand All @@ -964,7 +986,7 @@ def _build_dict_with_elements(elements):

def _infer_copy_method(
node: nodes.Call, context: InferenceContext | None = None, **kwargs: Any
) -> Iterator[InferenceResult]:
) -> Iterator[CopyResult]:
assert isinstance(node.func, nodes.Attribute)
inferred_orig, inferred_copy = itertools.tee(node.func.expr.infer(context=context))
if all(
Expand All @@ -973,9 +995,9 @@ def _infer_copy_method(
)
for inferred_node in inferred_orig
):
return inferred_copy
return cast(Iterator[CopyResult], inferred_copy)

raise UseInferenceDefault()
raise UseInferenceDefault


def _is_str_format_call(node: nodes.Call) -> bool:
Expand Down Expand Up @@ -1081,5 +1103,7 @@ def _infer_str_format_call(
)

AstroidManager().register_transform(
nodes.Call, inference_tip(_infer_str_format_call), _is_str_format_call
nodes.Call,
inference_tip(_infer_str_format_call),
_is_str_format_call,
)
12 changes: 6 additions & 6 deletions astroid/nodes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1864,9 +1864,7 @@ def __init__(
parent=parent,
)

def postinit(
self, items: list[tuple[SuccessfulInferenceResult, SuccessfulInferenceResult]]
) -> None:
def postinit(self, items: list[tuple[InferenceResult, InferenceResult]]) -> None:
"""Do some setup after initialisation.
:param items: The key-value pairs contained in the dictionary.
Expand Down Expand Up @@ -4058,11 +4056,13 @@ class EvaluatedObject(NodeNG):
_astroid_fields = ("original",)
_other_fields = ("value",)

def __init__(self, original: NodeNG, value: NodeNG | util.UninferableBase) -> None:
self.original: NodeNG = original
def __init__(
self, original: SuccessfulInferenceResult, value: InferenceResult
) -> None:
self.original: SuccessfulInferenceResult = original
"""The original node that has already been evaluated"""

self.value: NodeNG | util.UninferableBase = value
self.value: InferenceResult = value
"""The inferred value"""

super().__init__(
Expand Down
4 changes: 2 additions & 2 deletions astroid/rebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from astroid.manager import AstroidManager
from astroid.nodes import NodeNG
from astroid.nodes.utils import Position
from astroid.typing import SuccessfulInferenceResult
from astroid.typing import InferenceResult

REDIRECT: Final[dict[str, str]] = {
"arguments": "Arguments",
Expand Down Expand Up @@ -1019,7 +1019,7 @@ def visit_dict(self, node: ast.Dict, parent: NodeNG) -> nodes.Dict:
end_col_offset=node.end_col_offset,
parent=parent,
)
items: list[tuple[SuccessfulInferenceResult, SuccessfulInferenceResult]] = list(
items: list[tuple[InferenceResult, InferenceResult]] = list(
self._visit_dict_items(node, parent, newnode)
)
newnode.postinit(items)
Expand Down
4 changes: 3 additions & 1 deletion astroid/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
)

if TYPE_CHECKING:
from collections.abc import Iterator

from astroid import bases, exceptions, nodes, transforms, util
from astroid.context import InferenceContext
from astroid.interpreter._import import spec
Expand Down Expand Up @@ -84,7 +86,7 @@ def __call__(
node: _SuccessfulInferenceResultT_contra,
context: InferenceContext | None = None,
**kwargs: Any,
) -> Generator[InferenceResult, None, None]:
) -> Iterator[InferenceResult]:
... # pragma: no cover


Expand Down
5 changes: 4 additions & 1 deletion tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1931,7 +1931,10 @@ def test_str_repr_no_warnings(node):

if "int" in param_type.annotation:
args[name] = random.randint(0, 50)
elif "NodeNG" in param_type.annotation:
elif (
"NodeNG" in param_type.annotation
or "SuccessfulInferenceResult" in param_type.annotation
):
args[name] = nodes.Unknown()
elif "str" in param_type.annotation:
args[name] = ""
Expand Down

0 comments on commit 8d57ce2

Please sign in to comment.