Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Expression and Statement proper subclasses of Node #2122

Closed
wants to merge 11 commits into from
87 changes: 42 additions & 45 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import os.path

from typing import (
Any, Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple
Any, Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Mapping
)

from mypy.errors import Errors, report_internal_error
from mypy.nodes import (
SymbolTable, Node, MypyFile, Var, Expression,
SymbolTable, Node, Statement, MypyFile, Var, Expression,
OverloadedFuncDef, FuncDef, FuncItem, FuncBase, TypeInfo,
ClassDef, GDEF, Block, AssignmentStmt, NameExpr, MemberExpr, IndexExpr,
TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, IfStmt,
Expand Down Expand Up @@ -229,7 +229,7 @@ def accept(self, node: Node, type_context: Type = None) -> Type:
else:
return typ

def accept_loop(self, body: Node, else_body: Node = None) -> Type:
def accept_loop(self, body: Statement, else_body: Statement = None) -> Type:
"""Repeatedly type check a loop body until the frame doesn't change.

Then check the else_body.
Expand Down Expand Up @@ -1039,7 +1039,8 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> Type:
for lv in s.lvalues[:-1]:
self.check_assignment(lv, rvalue, s.type is None)

def check_assignment(self, lvalue: Node, rvalue: Node, infer_lvalue_type: bool = True) -> None:
def check_assignment(self, lvalue: Expression, rvalue: Expression,
infer_lvalue_type: bool = True) -> None:
"""Type check a single assignment: lvalue = rvalue."""
if isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr):
self.check_assignment_to_multiple_lvalues(lvalue.items, rvalue, lvalue,
Expand Down Expand Up @@ -1094,7 +1095,7 @@ def check_assignment(self, lvalue: Node, rvalue: Node, infer_lvalue_type: bool =
self.infer_variable_type(inferred, lvalue, self.accept(rvalue),
rvalue)

def check_assignment_to_multiple_lvalues(self, lvalues: List[Node], rvalue: Node,
def check_assignment_to_multiple_lvalues(self, lvalues: List[Expression], rvalue: Expression,
context: Context,
infer_lvalue_type: bool = True) -> None:
if isinstance(rvalue, TupleExpr) or isinstance(rvalue, ListExpr):
Expand Down Expand Up @@ -1129,7 +1130,7 @@ def check_assignment_to_multiple_lvalues(self, lvalues: List[Node], rvalue: Node
else:
self.check_multi_assignment(lvalues, rvalue, context, infer_lvalue_type)

def check_rvalue_count_in_assignment(self, lvalues: List[Node], rvalue_count: int,
def check_rvalue_count_in_assignment(self, lvalues: List[Expression], rvalue_count: int,
context: Context) -> bool:
if any(isinstance(lvalue, StarExpr) for lvalue in lvalues):
if len(lvalues) - 1 > rvalue_count:
Expand All @@ -1142,8 +1143,8 @@ def check_rvalue_count_in_assignment(self, lvalues: List[Node], rvalue_count: in
return False
return True

def check_multi_assignment(self, lvalues: List[Node],
rvalue: Node,
def check_multi_assignment(self, lvalues: List[Expression],
rvalue: Expression,
context: Context,
infer_lvalue_type: bool = True,
msg: str = None) -> None:
Expand All @@ -1165,7 +1166,7 @@ def check_multi_assignment(self, lvalues: List[Node],
self.check_multi_assignment_from_iterable(lvalues, rvalue_type,
context, infer_lvalue_type)

def check_multi_assignment_from_tuple(self, lvalues: List[Node], rvalue: Node,
def check_multi_assignment_from_tuple(self, lvalues: List[Expression], rvalue: Expression,
rvalue_type: TupleType, context: Context,
undefined_rvalue: bool,
infer_lvalue_type: bool = True) -> None:
Expand Down Expand Up @@ -1195,7 +1196,7 @@ def check_multi_assignment_from_tuple(self, lvalues: List[Node], rvalue: Node,
for lv, rv_type in zip(right_lvs, right_rv_types):
self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type)

def lvalue_type_for_inference(self, lvalues: List[Node], rvalue_type: TupleType) -> Type:
def lvalue_type_for_inference(self, lvalues: List[Expression], rvalue_type: TupleType) -> Type:
star_index = next((i for i, lv in enumerate(lvalues)
if isinstance(lv, StarExpr)), len(lvalues))
left_lvs = lvalues[:star_index]
Expand All @@ -1206,7 +1207,7 @@ def lvalue_type_for_inference(self, lvalues: List[Node], rvalue_type: TupleType)

type_parameters = [] # type: List[Type]

def append_types_for_inference(lvs: List[Node], rv_types: List[Type]) -> None:
def append_types_for_inference(lvs: List[Expression], rv_types: List[Type]) -> None:
for lv, rv_type in zip(lvs, rv_types):
sub_lvalue_type, index_expr, inferred = self.check_lvalue(lv)
if sub_lvalue_type:
Expand Down Expand Up @@ -1251,7 +1252,7 @@ def type_is_iterable(self, type: Type) -> bool:
[AnyType()])) and
isinstance(type, Instance))

def check_multi_assignment_from_iterable(self, lvalues: List[Node], rvalue_type: Type,
def check_multi_assignment_from_iterable(self, lvalues: List[Expression], rvalue_type: Type,
context: Context,
infer_lvalue_type: bool = True) -> None:
if self.type_is_iterable(rvalue_type):
Expand All @@ -1266,7 +1267,7 @@ def check_multi_assignment_from_iterable(self, lvalues: List[Node], rvalue_type:
else:
self.msg.type_not_iterable(rvalue_type, context)

def check_lvalue(self, lvalue: Node) -> Tuple[Type, IndexExpr, Var]:
def check_lvalue(self, lvalue: Expression) -> Tuple[Type, IndexExpr, Var]:
lvalue_type = None # type: Type
index_lvalue = None # type: IndexExpr
inferred = None # type: Var
Expand Down Expand Up @@ -1296,7 +1297,7 @@ def check_lvalue(self, lvalue: Node) -> Tuple[Type, IndexExpr, Var]:

return lvalue_type, index_lvalue, inferred

def is_definition(self, s: Node) -> bool:
def is_definition(self, s: Expression) -> bool:
if isinstance(s, NameExpr):
if s.is_def:
return True
Expand All @@ -1312,7 +1313,7 @@ def is_definition(self, s: Node) -> bool:
return s.is_def
return False

def infer_variable_type(self, name: Var, lvalue: Node,
def infer_variable_type(self, name: Var, lvalue: Expression,
init_type: Type, context: Context) -> None:
"""Infer the type of initialized variables from initializer type."""
if self.typing_mode_weak():
Expand All @@ -1339,7 +1340,7 @@ def infer_variable_type(self, name: Var, lvalue: Node,

self.set_inferred_type(name, lvalue, init_type)

def infer_partial_type(self, name: Var, lvalue: Node, init_type: Type) -> bool:
def infer_partial_type(self, name: Var, lvalue: Expression, init_type: Type) -> bool:
if isinstance(init_type, (NoneTyp, UninhabitedType)):
partial_type = PartialType(None, name, [init_type])
elif isinstance(init_type, Instance):
Expand All @@ -1358,7 +1359,7 @@ def infer_partial_type(self, name: Var, lvalue: Node, init_type: Type) -> bool:
self.partial_types[-1][name] = lvalue
return True

def set_inferred_type(self, var: Var, lvalue: Node, type: Type) -> None:
def set_inferred_type(self, var: Var, lvalue: Expression, type: Type) -> None:
"""Store inferred variable type.

Store the type to both the variable node and the expression node that
Expand All @@ -1368,7 +1369,7 @@ def set_inferred_type(self, var: Var, lvalue: Node, type: Type) -> None:
var.type = type
self.store_type(lvalue, type)

def set_inference_error_fallback_type(self, var: Var, lvalue: Node, type: Type,
def set_inference_error_fallback_type(self, var: Var, lvalue: Expression, type: Type,
context: Context) -> None:
"""If errors on context line are ignored, store dummy type for variable.

Expand All @@ -1383,16 +1384,16 @@ def set_inference_error_fallback_type(self, var: Var, lvalue: Node, type: Type,
if context.get_line() in self.errors.ignored_lines[self.errors.file]:
self.set_inferred_type(var, lvalue, AnyType())

def narrow_type_from_binder(self, expr: Node, known_type: Type) -> Type:
def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type:
if expr.literal >= LITERAL_TYPE:
restriction = self.binder.get(expr)
if restriction:
ans = meet_simple(known_type, restriction)
return ans
return known_type

def check_simple_assignment(self, lvalue_type: Type, rvalue: Node,
context: Node,
def check_simple_assignment(self, lvalue_type: Type, rvalue: Expression,
context: Union[Expression, ImportBase],
msg: str = messages.INCOMPATIBLE_TYPES_IN_ASSIGNMENT,
lvalue_name: str = 'variable',
rvalue_name: str = 'expression') -> Type:
Expand All @@ -1414,7 +1415,7 @@ def check_simple_assignment(self, lvalue_type: Type, rvalue: Node,
return rvalue_type

def check_indexed_assignment(self, lvalue: IndexExpr,
rvalue: Node, context: Context) -> None:
rvalue: Expression, context: Context) -> None:
"""Type check indexed assignment base[index] = rvalue.

The lvalue argument is the base[index] expression.
Expand All @@ -1429,7 +1430,7 @@ def check_indexed_assignment(self, lvalue: IndexExpr,
context)

def try_infer_partial_type_from_indexed_assignment(
self, lvalue: IndexExpr, rvalue: Node) -> None:
self, lvalue: IndexExpr, rvalue: Expression) -> None:
# TODO: Should we share some of this with try_infer_partial_type?
if isinstance(lvalue.base, RefExpr) and isinstance(lvalue.base.node, Var):
var = lvalue.base.node
Expand Down Expand Up @@ -1610,7 +1611,7 @@ def visit_raise_stmt(self, s: RaiseStmt) -> Type:
if s.from_expr:
self.type_check_raise(s.from_expr, s)

def type_check_raise(self, e: Node, s: RaiseStmt) -> None:
def type_check_raise(self, e: Expression, s: RaiseStmt) -> None:
typ = self.accept(e)
if isinstance(typ, FunctionLike):
if typ.is_type_obj():
Expand Down Expand Up @@ -1701,7 +1702,7 @@ def visit_try_without_finally(self, s: TryStmt) -> bool:
breaking_out = breaking_out and self.binder.last_pop_breaking_out
return breaking_out

def visit_except_handler_test(self, n: Node) -> Type:
def visit_except_handler_test(self, n: Expression) -> Type:
"""Type check an exception handler test clause."""
type = self.accept(n)

Expand Down Expand Up @@ -1737,7 +1738,7 @@ def visit_for_stmt(self, s: ForStmt) -> Type:
self.analyze_index_variables(s.index, item_type, s)
self.accept_loop(s.body, s.else_body)

def analyze_async_iterable_item_type(self, expr: Node) -> Type:
def analyze_async_iterable_item_type(self, expr: Expression) -> Type:
"""Analyse async iterable expression and return iterator item type."""
iterable = self.accept(expr)

Expand All @@ -1756,7 +1757,7 @@ def analyze_async_iterable_item_type(self, expr: Node) -> Type:
return self.check_awaitable_expr(awaitable, expr,
messages.INCOMPATIBLE_TYPES_IN_ASYNC_FOR)

def analyze_iterable_item_type(self, expr: Node) -> Type:
def analyze_iterable_item_type(self, expr: Expression) -> Type:
"""Analyse iterable expression and return iterator item type."""
iterable = self.accept(expr)

Expand Down Expand Up @@ -1791,7 +1792,7 @@ def analyze_iterable_item_type(self, expr: Node) -> Type:
expr)
return echk.check_call(method, [], [], expr)[0]

def analyze_index_variables(self, index: Node, item_type: Type,
def analyze_index_variables(self, index: Expression, item_type: Type,
context: Context) -> None:
"""Type check or infer for loop or list comprehension index vars."""
self.check_assignment(index, self.temp_node(item_type, context))
Expand All @@ -1805,7 +1806,7 @@ def visit_del_stmt(self, s: DelStmt) -> Type:
c.line = s.line
return c.accept(self)
else:
def flatten(t: Node) -> List[Node]:
def flatten(t: Expression) -> List[Expression]:
"""Flatten a nested sequence of tuples/lists into one list of nodes."""
if isinstance(t, TupleExpr) or isinstance(t, ListExpr):
return [b for a in t.items for b in flatten(a)]
Expand Down Expand Up @@ -2330,7 +2331,7 @@ def check_usable_type(self, typ: Type, context: Context) -> None:
if self.is_unusable_type(typ):
self.msg.does_not_return_value(typ, context)

def temp_node(self, t: Type, context: Context = None) -> Node:
def temp_node(self, t: Type, context: Context = None) -> Expression:
"""Create a temporary node with the given, fixed type."""
temp = TempNode(t)
if context:
Expand Down Expand Up @@ -2368,12 +2369,10 @@ def method_type(self, func: FuncBase) -> FunctionLike:
# probably be better to have the dict keyed by the nodes' literal_hash
# field instead.

# NB: This should be `TypeMap = Optional[Dict[Node, Type]]`!
# But see https://github.com/python/mypy/issues/1637
TypeMap = Dict[Node, Type]
TypeMap = Optional[Dict[Expression, Type]]


def conditional_type_map(expr: Node,
def conditional_type_map(expr: Expression,
current_type: Optional[Type],
proposed_type: Optional[Type],
*,
Expand Down Expand Up @@ -2405,7 +2404,7 @@ def conditional_type_map(expr: Node,
return {}, {}


def is_literal_none(n: Node) -> bool:
def is_literal_none(n: Expression) -> bool:
return isinstance(n, NameExpr) and n.fullname == 'builtins.None'


Expand Down Expand Up @@ -2453,8 +2452,8 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap:
return result


def find_isinstance_check(node: Node,
type_map: Dict[Node, Type],
def find_isinstance_check(node: Expression,
type_map: Mapping[Node, Type],
weak: bool=False
) -> Tuple[TypeMap, TypeMap]:
"""Find any isinstance checks (within a chain of ands). Includes
Expand All @@ -2473,15 +2472,15 @@ def find_isinstance_check(node: Node,
expr = node.args[0]
if expr.literal == LITERAL_TYPE:
vartype = type_map[expr]
type = get_isinstance_type(node.args[1], type_map)
type = get_isinstance_type(node.args[1], type_map[node.args[1]])
return conditional_type_map(expr, vartype, type, weak=weak)
elif (isinstance(node, ComparisonExpr) and any(is_literal_none(n) for n in node.operands) and
experiments.STRICT_OPTIONAL):
# Check for `x is None` and `x is not None`.
is_not = node.operators == ['is not']
if is_not or node.operators == ['is']:
if_vars = {} # type: Dict[Node, Type]
else_vars = {} # type: Dict[Node, Type]
if_vars = {} # type: Dict[Expression, Type]
else_vars = {} # type: Dict[Expression, Type]
for expr in node.operands:
if expr.literal == LITERAL_TYPE and not is_literal_none(expr) and expr in type_map:
# This should only be true at most once: there should be
Expand All @@ -2500,7 +2499,7 @@ def find_isinstance_check(node: Node,
vartype = type_map[node]
if_type = true_only(vartype)
else_type = false_only(vartype)
ref = node # type: Node
ref = node # type: Expression
if_map = {ref: if_type} if not isinstance(if_type, UninhabitedType) else None
else_map = {ref: else_type} if not isinstance(else_type, UninhabitedType) else None
return if_map, else_map
Expand Down Expand Up @@ -2546,9 +2545,7 @@ def find_isinstance_check(node: Node,
return {}, {}


def get_isinstance_type(node: Node, type_map: Dict[Node, Type]) -> Type:
type = type_map[node]

def get_isinstance_type(node: Expression, type: Type) -> Type:
if isinstance(type, TupleType):
all_types = type.items
else:
Expand All @@ -2573,7 +2570,7 @@ def get_isinstance_type(node: Node, type_map: Dict[Node, Type]) -> Type:
return UnionType(types)


def expand_node(defn: Node, map: Dict[TypeVarId, Type]) -> Node:
def expand_node(defn: FuncItem, map: Dict[TypeVarId, Type]) -> Node:
visitor = TypeTransformVisitor(map)
return defn.accept(visitor)

Expand Down
Loading