Further improve typing of builtins brain #2225

Merged
2 changes: 1 addition & 1 deletion astroid/arguments.py
@@ -91,7 +91,7 @@ def _unpack_keywords(
         keywords: list[tuple[str | None, nodes.NodeNG]],
         context: InferenceContext | None = None,
     ):
-        values = {}
+        values: dict[str | None, InferenceResult] = {}
         context = context or InferenceContext()
         context.extra_context = self.argument_context_map
         for name, value in keywords:
49 changes: 35 additions & 14 deletions astroid/brain/brain_builtin_inference.py
@@ -7,9 +7,9 @@
 from __future__ import annotations

 import itertools
-from collections.abc import Callable, Iterator
+from collections.abc import Callable, Iterable, Iterator
 from functools import partial
-from typing import Any, Type, Union, cast
+from typing import Any, NoReturn, Type, Union, cast

 from astroid import arguments, helpers, inference_tip, nodes, objects, util
 from astroid.builder import AstroidBuilder
@@ -43,6 +43,13 @@
     Type[frozenset],
 ]

+CopyResult = Union[
+    nodes.Dict,
+    nodes.List,
+    nodes.Set,
+    objects.FrozenSet,
+]
+
 OBJECT_DUNDER_NEW = "object.__new__"

 STR_CLASS = """
@@ -127,6 +134,10 @@ def ljust(self, width, fillchar=None):
 """


+def _use_default() -> NoReturn:
+    raise UseInferenceDefault()
+
+
 def _extend_string_class(class_node, code, rvalue):
     """Function to extend builtin str/unicode class."""
     code = code.format(rvalue=rvalue)
@@ -193,7 +204,9 @@ def register_builtin_transform(transform, builtin_name) -> None:
     an optional context.
     """

-    def _transform_wrapper(node, context: InferenceContext | None = None):
+    def _transform_wrapper(
+        node: nodes.Call, context: InferenceContext | None = None, **kwargs: Any
+    ):
         result = transform(node, context=context)
         if result:
             if not result.parent:
@@ -257,10 +270,12 @@ def _container_generic_transform(
     iterables: tuple[type[nodes.BaseContainer] | type[ContainerObjects], ...],
     build_elts: BuiltContainers,
 ) -> nodes.BaseContainer | None:
+    elts: Iterable
Collaborator: This contradicts line 301.

Member Author: Line 301 is the klass argument? Having trouble following.

Collaborator: Oops, sorry! New code line 300.

Member Author: Got it now, thanks!


     if isinstance(arg, klass):
         return arg
     if isinstance(arg, iterables):
-        arg = cast(ContainerObjects, arg)
+        arg = cast(nodes.BaseContainer, arg)
         if all(isinstance(elt, nodes.Const) for elt in arg.elts):
             elts = [cast(nodes.Const, elt).value for elt in arg.elts]
         else:
@@ -277,9 +292,10 @@ def _container_generic_transform(
                     elts.append(evaluated_object)
     elif isinstance(arg, nodes.Dict):
         # Dicts need to have consts as strings already.
-        if not all(isinstance(elt[0], nodes.Const) for elt in arg.items):
-            raise UseInferenceDefault()
-        elts = [item[0].value for item in arg.items]
+        elts = [
+            item[0].value if isinstance(item[0], nodes.Const) else _use_default()
+            for item in arg.items
+        ]
     elif isinstance(arg, nodes.Const) and isinstance(arg.value, (str, bytes)):
         elts = arg.value
     else:
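An aside on the `_use_default()` helper introduced above: because it is annotated `-> NoReturn`, it can be called in the `else` branch of the conditional expression inside the comprehension. A type checker treats that branch as never producing a value, so the list's element type stays narrow, while at runtime the first non-`Const` key still raises `UseInferenceDefault`, matching the removed pre-check. A minimal self-contained sketch of the pattern (the names `FallBack`, `_bail`, and `first_chars` are illustrative, not astroid's):

    from typing import NoReturn


    class FallBack(Exception):
        """Hypothetical stand-in for UseInferenceDefault."""


    def _bail() -> NoReturn:
        # Raising means this branch never yields a value, so the
        # comprehension below keeps the element type `str` for mypy.
        raise FallBack()


    def first_chars(items: list[object]) -> list[str]:
        # Bail out of the whole comprehension on the first non-str item.
        return [item[0] if isinstance(item, str) else _bail() for item in items]


    print(first_chars(["abc", "xyz"]))  # prints ['a', 'x']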
@@ -399,6 +415,7 @@ def infer_dict(node: nodes.Call, context: InferenceContext | None = None) -> nod
     args = call.positional_arguments
     kwargs = list(call.keyword_arguments.items())

+    items: list[tuple[SuccessfulInferenceResult, SuccessfulInferenceResult]]
Collaborator (@DanielNoord, Jun 27, 2023): Line 377 of the new file contradicts this, although that might be too broad.

Member Author: Nice catch! We should probably broaden it out. That means this is probably also too narrow:

    def postinit(self, elts: list[SuccessfulInferenceResult]) -> None:
        self.elts = elts

     if not args and not kwargs:
         # dict()
         return nodes.Dict(
@@ -695,7 +712,9 @@ def infer_slice(node, context: InferenceContext | None = None):
     return slice_node


-def _infer_object__new__decorator(node, context: InferenceContext | None = None):
+def _infer_object__new__decorator(
+    node: nodes.ClassDef, context: InferenceContext | None = None, **kwargs: Any
+):
     # Instantiate class immediately
     # since that's what @object.__new__ does
     return iter((node.instantiate_class(),))
@@ -944,10 +963,10 @@ def _build_dict_with_elements(elements):
     if isinstance(inferred_values, nodes.Const) and isinstance(
         inferred_values.value, (str, bytes)
     ):
-        elements = [
+        elements_with_value = [
             (nodes.Const(element), default) for element in inferred_values.value
         ]
-        return _build_dict_with_elements(elements)
+        return _build_dict_with_elements(elements_with_value)
     if isinstance(inferred_values, nodes.Dict):
         keys = inferred_values.itered()
         for key in keys:
@@ -964,7 +983,7 @@ def _build_dict_with_elements(elements):

 def _infer_copy_method(
     node: nodes.Call, context: InferenceContext | None = None, **kwargs: Any
-) -> Iterator[InferenceResult]:
+) -> Iterator[CopyResult]:
     assert isinstance(node.func, nodes.Attribute)
     inferred_orig, inferred_copy = itertools.tee(node.func.expr.infer(context=context))
     if all(
@@ -975,7 +994,7 @@ def _build_dict_with_elements(elements):
     ):
         return inferred_copy

-    raise UseInferenceDefault()
+    raise UseInferenceDefault


 def _is_str_format_call(node: nodes.Call) -> bool:
@@ -1075,11 +1094,13 @@ def _infer_str_format_call(

 AstroidManager().register_transform(
     nodes.Call,
-    inference_tip(_infer_copy_method),
+    inference_tip(_infer_copy_method),  # type: ignore[arg-type]
Collaborator: This can be avoided by using a Generator instead of an Iterator in the signature of this function.

Same goes for the other one, but that then creates other issues... Can we yield from in that function? That would make it an iterator.

Member Author: Sorry, one more that I'm having trouble following :-)

Elsewhere we return iterators just fine, and they are typed this way (astroid/astroid/bases.py, lines 449 to 453 in e3ba1ca):

    def igetattr(
        self, name: str, context: InferenceContext | None = None
    ) -> Iterator[InferenceResult]:
        if name in self.special_attributes:
            return iter((self.special_attributes.lookup(name),))

So why would we change to Generator?

Collaborator: Because InferFn expects a Generator. I think because we define it so explicitly we need to do so here as well... I think we might also get away with changing InferFn to be an Iterator, but since we don't actually run mypy that might create issues somewhere else that we don't know about...

Member Author: Ah, got it now. Okay, I'll audit the uses of InferFn. At first glance it looks like it should be Iterator.

> but since we don't actually run mypy that might create issues somewhere else that we don't know about...

That's why I'm trying to focus by file and make sure that with each change the total number goes down and no new issues appear in the single file I'm focusing on (but yes... whack-a-mole could ensue...)

Collaborator: Yeah, that does make the PRs much easier to review.

What I have also noticed is that doing PRs per concept (such as InferFn) is also a good way of grouping things. It might make the big number bigger as the improved typing exposes more issues, but it makes review and reasoning about the issues easier, as it is all related.
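To make the Iterator-versus-Generator point above concrete, here is a minimal sketch (`InferFnLike` is a hypothetical stand-in, not astroid's real `InferFn` definition): a callable type that promises a `Generator` return rejects a function that merely returns an `Iterator`, because `Generator` is the narrower type, while the reverse direction is fine.

    from collections.abc import Callable, Generator, Iterator

    # Hypothetical stand-in for the InferFn alias discussed above.
    InferFnLike = Callable[[int], Generator[int, None, None]]


    def yields_values(n: int) -> Generator[int, None, None]:
        yield n  # a real generator function


    def returns_iterator(n: int) -> Iterator[int]:
        return iter((n,))  # valid as an Iterator, but not a Generator


    ok: InferFnLike = yields_values        # accepted by mypy
    bad: InferFnLike = returns_iterator    # rejected: Iterator is not a Generator

Returning `iter((...,))` from a function annotated `Iterator[...]`, as `igetattr` does, is fine on its own; the friction only appears when such a function has to satisfy a `Generator`-returning callable type, which is what the `# type: ignore[arg-type]` comments in this hunk work around.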

     lambda node: isinstance(node.func, nodes.Attribute)
     and node.func.attrname == "copy",
 )

 AstroidManager().register_transform(
-    nodes.Call, inference_tip(_infer_str_format_call), _is_str_format_call
+    nodes.Call,
+    inference_tip(_infer_str_format_call),  # type: ignore[arg-type]
+    _is_str_format_call,
 )
8 changes: 5 additions & 3 deletions astroid/nodes/node_classes.py
@@ -3911,11 +3911,13 @@ class EvaluatedObject(NodeNG):
     _astroid_fields = ("original",)
     _other_fields = ("value",)

-    def __init__(self, original: NodeNG, value: NodeNG | util.UninferableBase) -> None:
-        self.original: NodeNG = original
+    def __init__(
+        self, original: SuccessfulInferenceResult, value: InferenceResult
+    ) -> None:
+        self.original: SuccessfulInferenceResult = original
         """The original node that has already been evaluated"""

-        self.value: NodeNG | util.UninferableBase = value
+        self.value: InferenceResult = value
         """The inferred value"""

         super().__init__(
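A small usage sketch of the widened `EvaluatedObject` signature (written against astroid's public API as I understand it, so treat it as illustrative rather than authoritative): the inferred value of a call is an `Instance`, i.e. a proxy object rather than a plain `NodeNG`, and the broader `InferenceResult` annotation is what lets it be stored on `value` without a type error.

    from astroid import extract_node, nodes

    call = extract_node("object()")
    inferred = next(call.infer())  # an Instance (a proxy), not a plain NodeNG
    wrapped = nodes.EvaluatedObject(original=call, value=inferred)
    print(wrapped.value)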
5 changes: 4 additions & 1 deletion tests/test_nodes.py
@@ -1904,7 +1904,10 @@ def test_str_repr_no_warnings(node):

         if "int" in param_type.annotation:
             args[name] = random.randint(0, 50)
-        elif "NodeNG" in param_type.annotation:
+        elif (
+            "NodeNG" in param_type.annotation
+            or "SuccessfulInferenceResult" in param_type.annotation
+        ):
             args[name] = nodes.Unknown()
         elif "str" in param_type.annotation:
             args[name] = ""