Skip to content

Commit

Permalink
Treat NewTypes like normal subclasses
Browse files Browse the repository at this point in the history
NewTypes are assumed not to inherit any members from their base classes.
This results in incorrect inference results. Avoid this by changing the
transformation for NewTypes to treat them like any other subclass.

pylint-dev/pylint#3162
pylint-dev/pylint#2296
  • Loading branch information
colatkinson committed Dec 23, 2021
1 parent 39c37c1 commit 0233147
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 11 deletions.
5 changes: 5 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ What's New in astroid 2.9.1?
============================
Release date: TBA

* Treat ``typing.NewType()`` values as normal subclasses.

Closes PyCQA/pylint#2296
Closes PyCQA/pylint#3162

* Prefer the module loader get_source() method in AstroidBuilder's
module_build() when possible to avoid assumptions about source
code being available on a filesystem. Otherwise the source cannot
Expand Down
81 changes: 70 additions & 11 deletions astroid/brain/brain_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import typing
from functools import partial

from astroid import context, extract_node, inference_tip
from astroid import context, extract_node, inference_tip, nodes
from astroid.const import PY37_PLUS, PY38_PLUS, PY39_PLUS
from astroid.exceptions import (
AttributeInferenceError,
Expand All @@ -38,8 +38,6 @@
from astroid.util import Uninferable

TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"}
TYPING_TYPEVARS = {"TypeVar", "NewType"}
TYPING_TYPEVARS_QUALIFIED = {"typing.TypeVar", "typing.NewType"}
TYPING_TYPE_TEMPLATE = """
class Meta(type):
def __getitem__(self, item):
Expand All @@ -52,6 +50,13 @@ def __args__(self):
class {0}(metaclass=Meta):
pass
"""
# PEP484 suggests NewType is equivalent to this for typing purposes
# https://www.python.org/dev/peps/pep-0484/#newtype-helper-function
TYPING_NEWTYPE_TEMPLATE = """
class {derived}({base}):
def __init__(self, val: {base}) -> None:
...
"""
TYPING_MEMBERS = set(getattr(typing, "__all__", []))

TYPING_ALIAS = frozenset(
Expand Down Expand Up @@ -106,23 +111,34 @@ def __class_getitem__(cls, item):
"""


def looks_like_typing_typevar_or_newtype(node):
def looks_like_typing_typevar(node: nodes.Call) -> bool:
func = node.func
if isinstance(func, Attribute):
return func.attrname == "TypeVar"
if isinstance(func, Name):
return func.name == "TypeVar"
return False


def looks_like_typing_newtype(node: nodes.Call) -> bool:
func = node.func
if isinstance(func, Attribute):
return func.attrname in TYPING_TYPEVARS
return func.attrname == "NewType"
if isinstance(func, Name):
return func.name in TYPING_TYPEVARS
return func.name == "NewType"
return False


def infer_typing_typevar_or_newtype(node, context_itton=None):
"""Infer a typing.TypeVar(...) or typing.NewType(...) call"""
def infer_typing_typevar(
node: nodes.Call, context_itton: typing.Optional[context.InferenceContext] = None
) -> typing.Iterator[nodes.ClassDef]:
"""Infer a typing.TypeVar(...) call"""
try:
func = next(node.func.infer(context=context_itton))
except (InferenceError, StopIteration) as exc:
raise UseInferenceDefault from exc

if func.qname() not in TYPING_TYPEVARS_QUALIFIED:
if func.qname() != "typing.TypeVar":
raise UseInferenceDefault
if not node.args:
raise UseInferenceDefault
Expand All @@ -132,6 +148,44 @@ def infer_typing_typevar_or_newtype(node, context_itton=None):
return node.infer(context=context_itton)


def infer_typing_newtype(
node: nodes.Call, context_itton: typing.Optional[context.InferenceContext] = None
) -> typing.Iterator[nodes.ClassDef]:
"""Infer a typing.NewType(...) call"""
try:
func = next(node.func.infer(context=context_itton))
except (InferenceError, StopIteration) as exc:
raise UseInferenceDefault from exc

if func.qname() != "typing.NewType":
raise UseInferenceDefault
if len(node.args) != 2:
raise UseInferenceDefault

derived, base = node.args
derived_name = derived.as_string().strip("'")
base_name = base.as_string().strip("'")

new_node: ClassDef = extract_node(
TYPING_NEWTYPE_TEMPLATE.format(derived=derived_name, base=base_name)
)
new_node.parent = node.parent

# Base type arg is a normal reference, so no need to do special lookups
if not isinstance(base, nodes.Const):
new_node.bases = [base]

# If the base type is given as a string (e.g. for a forward reference),
# make a naive attempt to find the corresponding node.
# Note that this will not work with imported types.
if isinstance(base, nodes.Const) and isinstance(base.value, str):
_, resolved_base = node.frame().lookup(base_name)
if resolved_base:
new_node.bases = [resolved_base[0]]

return new_node.infer(context=context_itton)


def _looks_like_typing_subscript(node):
"""Try to figure out if a Subscript node *might* be a typing-related subscript"""
if isinstance(node, Name):
Expand Down Expand Up @@ -409,8 +463,13 @@ def infer_typing_cast(

AstroidManager().register_transform(
Call,
inference_tip(infer_typing_typevar_or_newtype),
looks_like_typing_typevar_or_newtype,
inference_tip(infer_typing_typevar),
looks_like_typing_typevar,
)
AstroidManager().register_transform(
Call,
inference_tip(infer_typing_newtype),
looks_like_typing_newtype,
)
AstroidManager().register_transform(
Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript
Expand Down
122 changes: 122 additions & 0 deletions tests/unittest_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,6 +1659,128 @@ def test_typing_types(self) -> None:
inferred = next(node.infer())
self.assertIsInstance(inferred, nodes.ClassDef, node.as_string())

def test_typing_newtype_attrs(self) -> None:
ast_nodes = builder.extract_node(
"""
from typing import NewType
import decimal
from decimal import Decimal
NewType("Foo", str) #@
NewType("Bar", "int") #@
NewType("Baz", Decimal) #@
NewType("Qux", decimal.Decimal) #@
"""
)
assert isinstance(ast_nodes, list)

# Base type given by reference
foo_node = ast_nodes[0]
foo_inferred = next(foo_node.infer())
self.assertIsInstance(foo_inferred, astroid.ClassDef)

# Check base type method is inferred by accessing one of its methods
foo_base_class_method = foo_inferred.getattr("endswith")[0]
self.assertIsInstance(foo_base_class_method, astroid.FunctionDef)
self.assertEqual("builtins.str.endswith", foo_base_class_method.qname())

# Base type given by string (i.e. "int")
bar_node = ast_nodes[1]
bar_inferred = next(bar_node.infer())
self.assertIsInstance(bar_inferred, astroid.ClassDef)

bar_base_class_method = bar_inferred.getattr("bit_length")[0]
self.assertIsInstance(bar_base_class_method, astroid.FunctionDef)
self.assertEqual("builtins.int.bit_length", bar_base_class_method.qname())

# Decimal may be reexported from an implementation-defined module. For
# example, in CPython 3.10 this is _decimal, but in PyPy 7.3 it's
# _pydecimal. So the expected qname needs to be grabbed dynamically.
decimal_quant_node = builder.extract_node(
"""
from decimal import Decimal
Decimal.quantize #@
"""
)
assert isinstance(decimal_quant_node, nodes.NodeNG)
decimal_quant_qname = next(decimal_quant_node.infer()).qname()

# Base type is from an "import from"
baz_node = ast_nodes[2]
baz_inferred = next(baz_node.infer())
self.assertIsInstance(baz_inferred, astroid.ClassDef)

baz_base_class_method = baz_inferred.getattr("quantize")[0]
self.assertIsInstance(baz_base_class_method, astroid.FunctionDef)
self.assertEqual(decimal_quant_qname, baz_base_class_method.qname())

# Base type is from an import
qux_node = ast_nodes[3]
qux_inferred = next(qux_node.infer())
self.assertIsInstance(qux_inferred, astroid.ClassDef)

qux_base_class_method = qux_inferred.getattr("quantize")[0]
self.assertIsInstance(qux_base_class_method, astroid.FunctionDef)
self.assertEqual(decimal_quant_qname, qux_base_class_method.qname())

def test_typing_newtype_user_defined(self) -> None:
ast_nodes = builder.extract_node(
"""
from typing import NewType
class A:
def __init__(self, value: int):
self.value = value
a = A(5)
a #@
B = NewType("B", A)
b = B(5)
b #@
"""
)
assert isinstance(ast_nodes, list)

for node in ast_nodes:
self._verify_node_has_expected_attr(node)

def test_typing_newtype_forward_reference(self) -> None:
# Similar to the test above, but using a forward reference for "A"
ast_nodes = builder.extract_node(
"""
from typing import NewType
B = NewType("B", "A")
class A:
def __init__(self, value: int):
self.value = value
a = A(5)
a #@
b = B(5)
b #@
"""
)
assert isinstance(ast_nodes, list)

for node in ast_nodes:
self._verify_node_has_expected_attr(node)

def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None:
inferred = next(node.infer())
self.assertIsInstance(inferred, astroid.Instance)

# Should be able to infer that the "value" attr is present on both types
val = inferred.getattr("value")[0]
self.assertIsInstance(val, astroid.AssignAttr)

# Sanity check: nonexistent attr is not inferred
with self.assertRaises(AttributeInferenceError):
inferred.getattr("bad_attr")

def test_namedtuple_nested_class(self):
result = builder.extract_node(
"""
Expand Down

0 comments on commit 0233147

Please sign in to comment.