Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Treat NewTypes like normal subclasses #1301

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ Release date: TBA
* Fix test for Python ``3.11``. In some instances ``err.__traceback__`` will
be uninferable now.

* Treat ``typing.NewType()`` values as normal subclasses.

Closes PyCQA/pylint#2296
Closes PyCQA/pylint#3162

What's New in astroid 2.11.6?
=============================
Release date: TBA
Expand Down
176 changes: 162 additions & 14 deletions astroid/brain/brain_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from collections.abc import Iterator
from functools import partial

from astroid import context, extract_node, inference_tip
from astroid import context, extract_node, inference_tip, nodes
from astroid.builder import _extract_single_node
from astroid.const import PY38_PLUS, PY39_PLUS
from astroid.exceptions import (
AstroidImportError,
AttributeInferenceError,
InferenceError,
UseInferenceDefault,
Expand All @@ -35,8 +36,6 @@
from astroid.util import Uninferable

TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"}
TYPING_TYPEVARS = {"TypeVar", "NewType"}
TYPING_TYPEVARS_QUALIFIED = {"typing.TypeVar", "typing.NewType"}
TYPING_TYPE_TEMPLATE = """
class Meta(type):
def __getitem__(self, item):
Expand All @@ -49,6 +48,13 @@ def __args__(self):
class {0}(metaclass=Meta):
pass
"""
# PEP484 suggests NewType is equivalent to this for typing purposes
DanielNoord marked this conversation as resolved.
Show resolved Hide resolved
# https://www.python.org/dev/peps/pep-0484/#newtype-helper-function
TYPING_NEWTYPE_TEMPLATE = """
class {derived}({base}):
def __init__(self, val: {base}) -> None:
...
"""
TYPING_MEMBERS = set(getattr(typing, "__all__", []))

TYPING_ALIAS = frozenset(
Expand Down Expand Up @@ -103,24 +109,33 @@ def __class_getitem__(cls, item):
"""


def looks_like_typing_typevar_or_newtype(node):
def looks_like_typing_typevar(node: nodes.Call) -> bool:
func = node.func
if isinstance(func, Attribute):
return func.attrname in TYPING_TYPEVARS
return func.attrname == "TypeVar"
if isinstance(func, Name):
return func.name in TYPING_TYPEVARS
return func.name == "TypeVar"
return False


def infer_typing_typevar_or_newtype(node, context_itton=None):
"""Infer a typing.TypeVar(...) or typing.NewType(...) call"""
def looks_like_typing_newtype(node: nodes.Call) -> bool:
func = node.func
if isinstance(func, Attribute):
return func.attrname == "NewType"
if isinstance(func, Name):
return func.name == "NewType"
return False


def infer_typing_typevar(
node: nodes.Call, ctx: context.InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
"""Infer a typing.TypeVar(...) call"""
try:
func = next(node.func.infer(context=context_itton))
next(node.func.infer(context=ctx))
except (InferenceError, StopIteration) as exc:
raise UseInferenceDefault from exc

if func.qname() not in TYPING_TYPEVARS_QUALIFIED:
raise UseInferenceDefault
if not node.args:
raise UseInferenceDefault
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L127-128 has a drop in coverage. Could you re-create a test for it?

# Cannot infer from a dynamic class name (f-string)
Expand All @@ -129,7 +144,135 @@ def infer_typing_typevar_or_newtype(node, context_itton=None):

typename = node.args[0].as_string().strip("'")
node = extract_node(TYPING_TYPE_TEMPLATE.format(typename))
return node.infer(context=context_itton)
return node.infer(context=ctx)


def infer_typing_newtype(
node: nodes.Call, ctx: context.InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
"""Infer a typing.NewType(...) call"""
try:
next(node.func.infer(context=ctx))
except (InferenceError, StopIteration) as exc:
raise UseInferenceDefault from exc

if len(node.args) != 2:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you create a test for this? It is currently uncovered.

raise UseInferenceDefault

# Cannot infer from a dynamic class name (f-string)
if isinstance(node.args[0], JoinedStr) or isinstance(node.args[1], JoinedStr):
raise UseInferenceDefault

derived, base = node.args
derived_name = derived.as_string().strip("'")
base_name = base.as_string().strip("'")

new_node: ClassDef = extract_node(
TYPING_NEWTYPE_TEMPLATE.format(derived=derived_name, base=base_name)
)
new_node.parent = node.parent

new_bases: list[NodeNG] = []

if not isinstance(base, nodes.Const):
# Base type arg is a normal reference, so no need to do special lookups
new_bases = [base]
elif isinstance(base, nodes.Const) and isinstance(base.value, str):
# If the base type is given as a string (e.g. for a forward reference),
# make a naive attempt to find the corresponding node.
_, resolved_base = node.frame().lookup(base_name)
if resolved_base:
base_node = resolved_base[0]

# If the value is from an "import from" statement, follow the import chain
if isinstance(base_node, nodes.ImportFrom):
ctx = ctx.clone() if ctx else context.InferenceContext()
ctx.lookupname = base_name
base_node = next(base_node.infer(context=ctx))

new_bases = [base_node]
elif "." in base.value:
possible_base = _try_find_imported_object_from_str(node, base.value, ctx)
if possible_base:
new_bases = [possible_base]

if new_bases:
new_node.postinit(
bases=new_bases, body=new_node.body, decorators=new_node.decorators
)
Comment on lines +200 to +202
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new_node is fully constructed already. It's enough to set bases manually.

Suggested change
new_node.postinit(
bases=new_bases, body=new_node.body, decorators=new_node.decorators
)
new_node.bases = new_bases


return new_node.infer(context=ctx)


def _try_find_imported_object_from_str(
node: nodes.Call,
name: str,
ctx: context.InferenceContext | None,
) -> nodes.NodeNG | None:
for statement_mod_name, _ in _possible_module_object_splits(name):
# Find import statements that may pull in the appropriate modules
# The name used to find this statement may not correspond to the name of the module actually being imported
# For example, "import email.charset" is found by lookup("email")
_, resolved_bases = node.frame().lookup(statement_mod_name)
if not resolved_bases:
continue

resolved_base = resolved_bases[0]
if isinstance(resolved_base, nodes.Import):
# Extract the names of the module as they are accessed from actual code
scope_names = {(alias or name) for (name, alias) in resolved_base.names}
aliases = {alias: name for (name, alias) in resolved_base.names if alias}

# Find potential mod_name, obj_name splits that work with the available names
# for the module in this scope
import_targets = [
(mod_name, obj_name)
for (mod_name, obj_name) in _possible_module_object_splits(name)
if mod_name in scope_names
]
if not import_targets:
continue

import_target, name_in_mod = import_targets[0]
import_target = aliases.get(import_target, import_target)

# Try to import the module and find the object in it
try:
resolved_mod: nodes.Module = resolved_base.do_import_module(
import_target
)
except AstroidImportError:
# If the module doesn't actually exist, try the next option
continue

# Try to find the appropriate ClassDef or other such node in the target module
_, object_results_in_mod = resolved_mod.lookup(name_in_mod)
if not object_results_in_mod:
continue

base_node = object_results_in_mod[0]

# If the value is from an "import from" statement, follow the import chain
if isinstance(base_node, nodes.ImportFrom):
ctx = ctx.clone() if ctx else context.InferenceContext()
ctx.lookupname = name_in_mod
base_node = next(base_node.infer(context=ctx))

return base_node

return None


def _possible_module_object_splits(
dot_str: str,
) -> Iterator[tuple[str, str]]:
components = dot_str.split(".")
popped = []

while components:
popped.append(components.pop())

yield ".".join(components), ".".join(reversed(popped))


def _looks_like_typing_subscript(node):
Expand Down Expand Up @@ -403,8 +546,13 @@ def infer_typing_cast(

AstroidManager().register_transform(
Call,
inference_tip(infer_typing_typevar_or_newtype),
looks_like_typing_typevar_or_newtype,
inference_tip(infer_typing_typevar),
looks_like_typing_typevar,
)
AstroidManager().register_transform(
Call,
inference_tip(infer_typing_newtype),
looks_like_typing_newtype,
)
AstroidManager().register_transform(
Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript
Expand Down
Loading