Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for hasattr() checks #13544

Merged
merged 16 commits into from
Aug 29, 2022
Merged
94 changes: 92 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
true_only,
try_expanding_sum_type_to_union,
try_getting_int_literals_from_type,
try_getting_str_literals,
try_getting_str_literals_from_type,
tuple_fallback,
)
Expand Down Expand Up @@ -4701,7 +4702,7 @@ def _make_fake_typeinfo_and_full_name(
return None

curr_module.names[full_name] = SymbolTableNode(GDEF, info)
return Instance(info, [])
return Instance(info, [], extra_attrs=instances[0].extra_attrs or instances[1].extra_attrs)

def intersect_instance_callable(self, typ: Instance, callable_type: CallableType) -> Instance:
"""Creates a fake type that represents the intersection of an Instance and a CallableType.
Expand All @@ -4728,7 +4729,7 @@ def intersect_instance_callable(self, typ: Instance, callable_type: CallableType

cur_module.names[gen_name] = SymbolTableNode(GDEF, info)

return Instance(info, [])
return Instance(info, [], extra_attrs=typ.extra_attrs)

def make_fake_callable(self, typ: Instance) -> Instance:
"""Produce a new type that makes type Callable with a generic callable type."""
Expand Down Expand Up @@ -5032,6 +5033,12 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
if literal(expr) == LITERAL_TYPE:
vartype = self.lookup_type(expr)
return self.conditional_callable_type_map(expr, vartype)
elif refers_to_fullname(node.callee, "builtins.hasattr"):
if len(node.args) != 2: # the error will be reported elsewhere
return {}, {}
attr = try_getting_str_literals(node.args[1], self.lookup_type(node.args[1]))
if literal(expr) == LITERAL_TYPE and attr and len(attr) == 1:
return self.hasattr_type_maps(expr, self.lookup_type(expr), attr[0])
elif isinstance(node.callee, RefExpr):
if node.callee.type_guard is not None:
# TODO: Follow keyword args or *args, **kwargs
Expand Down Expand Up @@ -6211,6 +6218,89 @@ class Foo(Enum):
and member_type.fallback.type == parent_type.type_object()
)

def add_any_attribute_to_type(self, typ: Type, name: str) -> Type:
"""Inject an extra attribute with Any type using fallbacks."""
orig_typ = typ
typ = get_proper_type(typ)
any_type = AnyType(TypeOfAny.unannotated)
if isinstance(typ, Instance):
return typ.copy_with_extra_attr(name, any_type)
if isinstance(typ, TupleType):
fallback = typ.partial_fallback.copy_with_extra_attr(name, any_type)
return typ.copy_modified(fallback=fallback)
if isinstance(typ, CallableType):
fallback = typ.fallback.copy_with_extra_attr(name, any_type)
return typ.copy_modified(fallback=fallback)
if isinstance(typ, TypeType) and isinstance(typ.item, Instance):
return TypeType.make_normalized(self.add_any_attribute_to_type(typ.item, name))
if isinstance(typ, TypeVarType):
return typ.copy_modified(
upper_bound=self.add_any_attribute_to_type(typ.upper_bound, name),
values=[self.add_any_attribute_to_type(v, name) for v in typ.values],
)
if isinstance(typ, UnionType):
with_attr, without_attr = self.partition_union_by_attr(typ, name)
return make_simplified_union(
with_attr + [self.add_any_attribute_to_type(typ, name) for typ in without_attr]
)
return orig_typ

def hasattr_type_maps(
self, expr: Expression, source_type: Type, name: str
) -> tuple[TypeMap, TypeMap]:
"""Simple support for hasattr() checks.

Essentially the logic is following:
* In the if branch, keep types that already has a valid attribute as is,
for other inject an attribute with `Any` type.
* In the else branch, remove types that already have a valid attribute,
while keeping the rest.
"""
if self.has_valid_attribute(source_type, name):
return {expr: source_type}, {}

source_type = get_proper_type(source_type)
if isinstance(source_type, UnionType):
_, without_attr = self.partition_union_by_attr(source_type, name)
yes_map = {expr: self.add_any_attribute_to_type(source_type, name)}
return yes_map, {expr: make_simplified_union(without_attr)}

type_with_attr = self.add_any_attribute_to_type(source_type, name)
if type_with_attr != source_type:
return {expr: type_with_attr}, {}
return {}, {}

def partition_union_by_attr(
self, source_type: UnionType, name: str
) -> tuple[list[Type], list[Type]]:
with_attr = []
without_attr = []
for item in source_type.items:
if self.has_valid_attribute(item, name):
with_attr.append(item)
else:
without_attr.append(item)
return with_attr, without_attr

def has_valid_attribute(self, typ: Type, name: str) -> bool:
if isinstance(get_proper_type(typ), AnyType):
return False
with self.msg.filter_errors() as watcher:
analyze_member_access(
name,
typ,
TempNode(AnyType(TypeOfAny.special_form)),
False,
False,
False,
self.msg,
original_type=typ,
chk=self,
# This is not a real attribute lookup so don't mess with deferring nodes.
no_deferral=True,
)
return not watcher.has_new_errors()


class CollectArgTypes(TypeTraverserVisitor):
"""Collects the non-nested argument types in a set."""
Expand Down
16 changes: 15 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
chk: mypy.checker.TypeChecker,
self_type: Type | None,
module_symbol_table: SymbolTable | None = None,
no_deferral: bool = False,
) -> None:
self.is_lvalue = is_lvalue
self.is_super = is_super
Expand All @@ -100,6 +101,7 @@ def __init__(
self.msg = msg
self.chk = chk
self.module_symbol_table = module_symbol_table
self.no_deferral = no_deferral

def named_type(self, name: str) -> Instance:
return self.chk.named_type(name)
Expand All @@ -124,6 +126,7 @@ def copy_modified(
self.chk,
self.self_type,
self.module_symbol_table,
self.no_deferral,
)
if messages is not None:
mx.msg = messages
Expand All @@ -149,6 +152,7 @@ def analyze_member_access(
in_literal_context: bool = False,
self_type: Type | None = None,
module_symbol_table: SymbolTable | None = None,
no_deferral: bool = False,
) -> Type:
"""Return the type of attribute 'name' of 'typ'.

Expand Down Expand Up @@ -183,6 +187,7 @@ def analyze_member_access(
chk=chk,
self_type=self_type,
module_symbol_table=module_symbol_table,
no_deferral=no_deferral,
)
result = _analyze_member_access(name, typ, mx, override_info)
possible_literal = get_proper_type(result)
Expand Down Expand Up @@ -540,6 +545,11 @@ def analyze_member_var_access(
return AnyType(TypeOfAny.special_form)

# Could not find the member.
if itype.extra_attrs and name in itype.extra_attrs.attrs:
# For modules use direct symbol table lookup.
if not itype.extra_attrs.mod_name:
return itype.extra_attrs.attrs[name]

if mx.is_super:
mx.msg.undefined_in_superclass(name, mx.context)
return AnyType(TypeOfAny.from_error)
Expand Down Expand Up @@ -744,7 +754,7 @@ def analyze_var(
else:
result = expanded_signature
else:
if not var.is_ready:
if not var.is_ready and not mx.no_deferral:
mx.not_ready_callback(var.name, mx.context)
# Implicit 'Any' type.
result = AnyType(TypeOfAny.special_form)
Expand Down Expand Up @@ -858,6 +868,10 @@ def analyze_class_attribute_access(

node = info.get(name)
if not node:
if itype.extra_attrs and name in itype.extra_attrs.attrs:
# For modules use direct symbol table lookup.
if not itype.extra_attrs.mod_name:
return itype.extra_attrs.attrs[name]
if info.fallback_to_any:
return apply_class_attr_hook(mx, hook, AnyType(TypeOfAny.special_form))
return None
Expand Down
19 changes: 17 additions & 2 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
from mypy.erasetype import erase_type
from mypy.maptype import map_instance_to_supertype
from mypy.state import state
from mypy.subtypes import is_callable_compatible, is_equivalent, is_proper_subtype, is_subtype
from mypy.subtypes import (
is_callable_compatible,
is_equivalent,
is_proper_subtype,
is_same_type,
is_subtype,
)
from mypy.typeops import is_recursive_pair, make_simplified_union, tuple_fallback
from mypy.types import (
AnyType,
Expand Down Expand Up @@ -61,11 +67,20 @@ def meet_types(s: Type, t: Type) -> ProperType:
"""Return the greatest lower bound of two types."""
if is_recursive_pair(s, t):
# This case can trigger an infinite recursion, general support for this will be
# tricky so we use a trivial meet (like for protocols).
# tricky, so we use a trivial meet (like for protocols).
return trivial_meet(s, t)
s = get_proper_type(s)
t = get_proper_type(t)

if isinstance(s, Instance) and isinstance(t, Instance) and s.type == t.type:
# Code in checker.py should merge any extra_items where possible, so we
# should have only one instance with extra_items here. We check this before
# the below subtype check, so that extra_attrs will not get erased.
if is_same_type(s, t) and (s.extra_attrs or t.extra_attrs):
if s.extra_attrs:
return s
return t

if not isinstance(s, UnboundType) and not isinstance(t, UnboundType):
if is_proper_subtype(s, t, ignore_promotions=True):
return s
Expand Down
6 changes: 3 additions & 3 deletions mypy/server/objgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def get_edges(o: object) -> Iterator[tuple[object, object]]:
# in closures and self pointers to other objects

if hasattr(e, "__closure__"):
yield (s, "__closure__"), e.__closure__ # type: ignore[union-attr]
yield (s, "__closure__"), e.__closure__
if hasattr(e, "__self__"):
se = e.__self__ # type: ignore[union-attr]
se = e.__self__
if se is not o and se is not type(o) and hasattr(s, "__self__"):
yield s.__self__, se # type: ignore[attr-defined]
yield s.__self__, se
else:
if not type(e) in TYPE_BLACKLIST:
yield s, e
Expand Down
36 changes: 34 additions & 2 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
TupleType,
Type,
TypeAliasType,
TypedDictType,
TypeOfAny,
TypeQuery,
TypeType,
Expand Down Expand Up @@ -104,7 +105,7 @@ def tuple_fallback(typ: TupleType) -> Instance:
raise NotImplementedError
else:
items.append(item)
return Instance(info, [join_type_list(items)])
return Instance(info, [join_type_list(items)], extra_attrs=typ.partial_fallback.extra_attrs)


def get_self_type(func: CallableType, default_self: Instance | TupleType) -> Type | None:
Expand Down Expand Up @@ -462,7 +463,20 @@ def make_simplified_union(
):
simplified_set = try_contracting_literals_in_union(simplified_set)

return get_proper_type(UnionType.make_union(simplified_set, line, column))
result = get_proper_type(UnionType.make_union(simplified_set, line, column))

# Step 4: At last, we erase any (inconsistent) extra attributes on instances.
extra_attrs_set = set()
for item in items:
instance = try_getting_instance_fallback(item)
if instance and instance.extra_attrs:
extra_attrs_set.add(instance.extra_attrs)

fallback = try_getting_instance_fallback(result)
if len(extra_attrs_set) > 1 and fallback:
fallback.extra_attrs = None

return result


def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[Type]:
Expand Down Expand Up @@ -984,3 +998,21 @@ def separate_union_literals(t: UnionType) -> tuple[Sequence[LiteralType], Sequen
union_items.append(item)

return literal_items, union_items


def try_getting_instance_fallback(typ: Type) -> Instance | None:
"""Returns the Instance fallback for this type if one exists or None."""
typ = get_proper_type(typ)
if isinstance(typ, Instance):
return typ
elif isinstance(typ, TupleType):
return typ.partial_fallback
elif isinstance(typ, TypedDictType):
return typ.fallback
elif isinstance(typ, FunctionLike):
return typ.fallback
elif isinstance(typ, LiteralType):
return typ.fallback
elif isinstance(typ, TypeVarType):
return try_getting_instance_fallback(typ.upper_bound)
return None
Loading