Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic end-to-end support for Literal types #5947

29 changes: 23 additions & 6 deletions mypy/exprtotype.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Translate an Expression to a Type value."""

from mypy.nodes import (
Expression, NameExpr, MemberExpr, IndexExpr, TupleExpr,
Expression, NameExpr, MemberExpr, IndexExpr, TupleExpr, IntExpr, UnaryExpr,
ListExpr, StrExpr, BytesExpr, UnicodeExpr, EllipsisExpr, CallExpr,
get_member_expr_fullname
)
from mypy.fastparse import parse_type_comment
from mypy.types import (
Type, UnboundType, TypeList, EllipsisType, AnyType, Optional, CallableArgument, TypeOfAny
Type, UnboundType, TypeList, EllipsisType, AnyType, Optional, CallableArgument, TypeOfAny,
LiteralType, RawLiteralType,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two large scale questions:

  • Would it be better to call this just RawLiteral?
  • Does it need to be a subclass of Type?

It is OK to still put it in types.py, but wouldn't it be simpler to update UnboundType.args to be List[Union[Type, RawLiteral]]? I understand this may need some work, but currently we must implement several visitors and this type can sneak into type checking phase.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can look into doing this, but I believe we currently reference UnboundType.args in approximately 60 difference places throughout our codebase. I'm worried that propagating this change won't be cheap: we'd likely end up having to change a bunch of type signatures so we can pass through Union[Type, RawLiteral] and/or add a bunch of runtime checks to filter out RawLiteral.

(We'd also need to modify both fastparse.TypeConverter and exprtotype so they return a Union[Type, RawLiteral], but I think this is a more tractable problem: we could just implement some sort of wrapper around them that reports an error and returns AnyType if the top-level type is RawLiteral)

If we're worried about RawLiterals and similar types sneaking through, I'd rather just add another layer to semantic analysis that does a final scan of all types to make sure that RawLiteral, UnboundType, and any other "bad" types are fully removed. That way, would need to scan and check each type just once, instead of in potentially multiple places whenever we manipulate UnboundType.args.

Or even better, we could add a new base type named TypeLike that Type, UnboundType, and RawLiteral all inherit from (along with any other fake types we don't want sneaking into the typechecking phase). UnboundType would then be modified so that its args contain instances of TypeLike rather then Type.

The semantic analysis layer would always work with TypeLike, and we'd modify TypeAnalyzer so it traverses over TypeLike and only ever returns Type -- this would probably let us skip the extra traversal while still giving us the guarantees we want in a typesafe way.

The last solution would probably be just as much work, if not more, as changing UnboundType.args to be of type Union[Type, RawLiteral], but I think it would be more principled and help us eliminate the root problem in a more fundamental way. (I'd also be willing to tackle making this change if you + the others think it's a good idea.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can look into doing this, but I believe we currently reference UnboundType.args in approximately 60 difference places throughout our codebase. I'm worried that propagating this change won't be cheap

I would recommend just trying this, if it will be soon obvious that it is too hard, then don't spend time on this. The current approach is also OK.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, changing the type of UnboundType.args seems a bit ad hoc to me. We already have some Type subclasses such as TypeList that are only meaningful during semantic analysis, and I think that adding another one doesn't make things significantly worse if we are careful. Most visitors will crash if they encounter the new type subclass, so issues will likely be found soon.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most visitors will crash if they encounter the new type subclass, so issues will likely be found soon.

Fair point.

)


Expand Down Expand Up @@ -37,7 +38,12 @@ def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = No
name = None # type: Optional[str]
if isinstance(expr, NameExpr):
name = expr.name
return UnboundType(name, line=expr.line, column=expr.column)
if name == 'True':
return RawLiteralType(True, 'builtins.bool', line=expr.line, column=expr.column)
elif name == 'False':
return RawLiteralType(False, 'builtins.bool', line=expr.line, column=expr.column)
else:
return UnboundType(name, line=expr.line, column=expr.column)
elif isinstance(expr, MemberExpr):
fullname = get_member_expr_fullname(expr)
if fullname:
Expand Down Expand Up @@ -108,11 +114,22 @@ def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = No
elif isinstance(expr, (StrExpr, BytesExpr, UnicodeExpr)):
# Parse string literal type.
try:
result = parse_type_comment(expr.value, expr.line, None)
assert result is not None
node = parse_type_comment(expr.value, expr.line, None)
assert node is not None
if isinstance(node, UnboundType) and node.original_str_expr is None:
node.original_str_expr = expr.value
return node
except SyntaxError:
return RawLiteralType(expr.value, 'builtins.str', line=expr.line, column=expr.column)
elif isinstance(expr, UnaryExpr):
typ = expr_to_unanalyzed_type(expr.expr)
if isinstance(typ, RawLiteralType) and isinstance(typ.value, int) and expr.op == '-':
typ.value *= -1
return typ
else:
raise TypeTranslationError()
return result
elif isinstance(expr, IntExpr):
return RawLiteralType(expr.value, 'builtins.int', line=expr.line, column=expr.column)
elif isinstance(expr, EllipsisExpr):
return EllipsisType(expr.line)
else:
Expand Down
41 changes: 37 additions & 4 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from mypy.types import (
Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType, CallableArgument,
TypeOfAny, Instance,
TypeOfAny, Instance, RawLiteralType,
)
from mypy import defaults
from mypy import messages
Expand All @@ -53,6 +53,9 @@
Expression as ast3_Expression,
Str,
Index,
Num,
UnaryOp,
USub,
)
except ImportError:
if sys.version_info.minor > 2:
Expand Down Expand Up @@ -1138,12 +1141,42 @@ def visit_Name(self, n: Name) -> Type:
return UnboundType(n.id, line=self.line)

def visit_NameConstant(self, n: NameConstant) -> Type:
return UnboundType(str(n.value))
if isinstance(n.value, bool):
return RawLiteralType(n.value, 'builtins.bool', line=self.line)
else:
return UnboundType(str(n.value), line=self.line)

# UnaryOp(op, operand)
def visit_UnaryOp(self, n: UnaryOp) -> Type:
# We support specifically Literal[-4] and nothing else.
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved
# For example, Literal[+4] or Literal[~6] is not supported.
typ = self.visit(n.operand)
if isinstance(typ, RawLiteralType) and isinstance(n.op, USub):
if isinstance(typ.value, int):
typ.value *= -1
return typ
self.fail(TYPE_COMMENT_AST_ERROR, self.line, getattr(n, 'col_offset', -1))
return AnyType(TypeOfAny.from_error)

# Num(number n)
def visit_Num(self, n: Num) -> Type:
# Could be either float or int
numeric_value = n.n
if isinstance(numeric_value, int):
return RawLiteralType(numeric_value, 'builtins.int', line=self.line)
else:
self.fail(TYPE_COMMENT_AST_ERROR, self.line, getattr(n, 'col_offset', -1))
return AnyType(TypeOfAny.from_error)

# Str(string s)
def visit_Str(self, n: Str) -> Type:
return (parse_type_comment(n.s.strip(), self.line, self.errors) or
AnyType(TypeOfAny.from_error))
try:
node = parse_type_comment(n.s.strip(), self.line, errors=None)
if isinstance(node, UnboundType) and node.original_str_expr is None:
node.original_str_expr = n.s
return node or AnyType(TypeOfAny.from_error)
except SyntaxError:
return RawLiteralType(n.s, 'builtins.str', line=self.line)

# Subscript(expr value, slice slice, expr_context ctx)
def visit_Subscript(self, n: ast3.Subscript) -> Type:
Expand Down
3 changes: 3 additions & 0 deletions mypy/indirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def visit_tuple_type(self, t: types.TupleType) -> Set[str]:
def visit_typeddict_type(self, t: types.TypedDictType) -> Set[str]:
return self._visit(t.items.values()) | self._visit(t.fallback)

def visit_raw_literal_type(self, t: types.RawLiteralType) -> Set[str]:
return set()
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved

def visit_literal_type(self, t: types.LiteralType) -> Set[str]:
return self._visit(t.fallback)

Expand Down
7 changes: 7 additions & 0 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,13 @@ def is_none_typevar_overlap(t1: Type, t2: Type) -> bool:
elif isinstance(right, CallableType):
right = right.fallback

if isinstance(left, LiteralType) and isinstance(right, LiteralType):
return left == right
elif isinstance(left, LiteralType):
left = left.fallback
elif isinstance(right, LiteralType):
right = right.fallback

# Finally, we handle the case where left and right are instances.

if isinstance(left, Instance) and isinstance(right, Instance):
Expand Down
4 changes: 3 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mypy.erasetype import erase_type
from mypy.errors import Errors
from mypy.types import (
Type, CallableType, Instance, TypeVarType, TupleType, TypedDictType,
Type, CallableType, Instance, TypeVarType, TupleType, TypedDictType, LiteralType,
UnionType, NoneTyp, AnyType, Overloaded, FunctionLike, DeletedType, TypeType,
UninhabitedType, TypeOfAny, ForwardRef, UnboundType
)
Expand Down Expand Up @@ -297,6 +297,8 @@ def format_bare(self, typ: Type, verbosity: int = 0) -> str:
self.format_bare(item_type)))
s = 'TypedDict({{{}}})'.format(', '.join(items))
return s
elif isinstance(typ, LiteralType):
return str(typ)
elif isinstance(typ, UnionType):
# Only print Unions as Optionals if the Optional wouldn't have to contain another Union
print_as_optional = (len(typ.items) -
Expand Down
9 changes: 8 additions & 1 deletion mypy/semanal_pass3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from mypy.types import (
Type, Instance, AnyType, TypeOfAny, CallableType, TupleType, TypeVarType, TypedDictType,
UnionType, TypeType, Overloaded, ForwardRef, TypeTranslator, function_type
UnionType, TypeType, Overloaded, ForwardRef, TypeTranslator, function_type, LiteralType,
)
from mypy.errors import Errors, report_internal_error
from mypy.options import Options
Expand Down Expand Up @@ -704,6 +704,13 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
assert isinstance(fallback, Instance)
return TypedDictType(items, t.required_keys, fallback, t.line, t.column)

def visit_literal_type(self, t: LiteralType) -> Type:
if self.check_recursion(t):
return AnyType(TypeOfAny.from_error)
fallback = self.visit_instance(t.fallback, from_fallback=True)
assert isinstance(fallback, Instance)
return LiteralType(t.value, fallback, t.line, t.column)

def visit_union_type(self, t: UnionType) -> Type:
if self.check_recursion(t):
return AnyType(TypeOfAny.from_error)
Expand Down
4 changes: 4 additions & 0 deletions mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
Type, SyntheticTypeVisitor, Instance, AnyType, NoneTyp, CallableType, DeletedType, PartialType,
TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType,
Overloaded, TypeVarDef, TypeList, CallableArgument, EllipsisType, StarType, LiteralType,
RawLiteralType,
)
from mypy.util import get_prefix, replace_object_state
from mypy.typestate import TypeState
Expand Down Expand Up @@ -391,6 +392,9 @@ def visit_typeddict_type(self, typ: TypedDictType) -> None:
value_type.accept(self)
typ.fallback.accept(self)

def visit_raw_literal_type(self, t: RawLiteralType) -> None:
pass
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved

def visit_literal_type(self, typ: LiteralType) -> None:
typ.fallback.accept(self)

Expand Down
12 changes: 9 additions & 3 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,11 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
else:
return False

def visit_literal_type(self, t: LiteralType) -> bool:
raise NotImplementedError()
def visit_literal_type(self, left: LiteralType) -> bool:
if isinstance(self.right, LiteralType):
return left == self.right
else:
return self._is_subtype(left.fallback, self.right)
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved

def visit_overloaded(self, left: Overloaded) -> bool:
right = self.right
Expand Down Expand Up @@ -1172,7 +1175,10 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
return self._is_proper_subtype(left.fallback, right)

def visit_literal_type(self, left: LiteralType) -> bool:
raise NotImplementedError()
if isinstance(self.right, LiteralType):
return left == self.right
else:
return self._is_proper_subtype(left.fallback, self.right)
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved

def visit_overloaded(self, left: Overloaded) -> bool:
# TODO: What's the right thing to do here?
Expand Down
11 changes: 10 additions & 1 deletion mypy/test/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def parse_test_case(case: 'DataDrivenTestCase') -> None:
join = posixpath.join # type: ignore

out_section_missing = case.suite.required_out_section
normalize_output = True

files = [] # type: List[Tuple[str, str]] # path and contents
output_files = [] # type: List[Tuple[str, str]] # path and contents for output files
Expand Down Expand Up @@ -98,8 +99,11 @@ def parse_test_case(case: 'DataDrivenTestCase') -> None:
full = join(base_path, m.group(1))
deleted_paths.setdefault(num, set()).add(full)
elif re.match(r'out[0-9]*$', item.id):
if item.arg == 'skip-path-normalization':
normalize_output = False

tmp_output = [expand_variables(line) for line in item.data]
if os.path.sep == '\\':
if os.path.sep == '\\' and normalize_output:
tmp_output = [fix_win_path(line) for line in tmp_output]
if item.id == 'out' or item.id == 'out1':
output = tmp_output
Expand Down Expand Up @@ -147,6 +151,7 @@ def parse_test_case(case: 'DataDrivenTestCase') -> None:
case.expected_rechecked_modules = rechecked_modules
case.deleted_paths = deleted_paths
case.triggered = triggered or []
case.normalize_output = normalize_output


class DataDrivenTestCase(pytest.Item): # type: ignore # inheriting from Any
Expand All @@ -168,6 +173,10 @@ class DataDrivenTestCase(pytest.Item): # type: ignore # inheriting from Any
# Files/directories to clean up after test case; (is directory, path) tuples
clean_up = None # type: List[Tuple[bool, str]]

# Whether or not we should normalize the output to standardize things like
# forward vs backward slashes in file paths for Windows vs Linux.
normalize_output = True

def __init__(self,
parent: 'DataSuiteCollector',
suite: 'DataSuite',
Expand Down
4 changes: 3 additions & 1 deletion mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
'check-ctypes.test',
'check-dataclasses.test',
'check-final.test',
'check-literal.test',
]


Expand Down Expand Up @@ -177,7 +178,8 @@ def run_case_once(self, testcase: DataDrivenTestCase,
assert sys.path[0] == plugin_dir
del sys.path[0]

a = normalize_error_messages(a)
if testcase.normalize_output:
a = normalize_error_messages(a)

# Make sure error messages match
if incremental_step == 0:
Expand Down
13 changes: 9 additions & 4 deletions mypy/test/testcmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,19 @@ def test_python_cmdline(testcase: DataDrivenTestCase) -> None:
actual_output_content = output_file.read().splitlines()
normalized_output = normalize_file_output(actual_output_content,
os.path.abspath(test_temp_dir))
if testcase.suite.native_sep and os.path.sep == '\\':
normalized_output = [fix_cobertura_filename(line) for line in normalized_output]
normalized_output = normalize_error_messages(normalized_output)
# We always normalize things like timestamp, but only handle operating-system
# specific things if requested.
if testcase.normalize_output:
if testcase.suite.native_sep and os.path.sep == '\\':
normalized_output = [fix_cobertura_filename(line)
for line in normalized_output]
normalized_output = normalize_error_messages(normalized_output)
assert_string_arrays_equal(expected_content.splitlines(), normalized_output,
'Output file {} did not match its expected output'.format(
path))
else:
out = normalize_error_messages(err + out)
if testcase.normalize_output:
out = normalize_error_messages(err + out)
obvious_result = 1 if out else 0
if obvious_result != result:
out.append('== Return code: {}'.format(result))
Expand Down
3 changes: 2 additions & 1 deletion mypy/test/testmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
# Verify that old AST nodes are removed from the expression type map.
assert expr not in new_types

a = normalize_error_messages(a)
if testcase.normalize_output:
a = normalize_error_messages(a)

assert_string_arrays_equal(
testcase.output, a,
Expand Down
9 changes: 7 additions & 2 deletions mypy/test/testsemanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
'semanal-abstractclasses.test',
'semanal-namedtuple.test',
'semanal-typeddict.test',
'semenal-literal.test',
'semanal-classvar.test',
'semanal-python2.test']

Expand Down Expand Up @@ -78,6 +79,7 @@ def test_semanal(testcase: DataDrivenTestCase) -> None:
if (not f.path.endswith((os.sep + 'builtins.pyi',
'typing.pyi',
'mypy_extensions.pyi',
'typing_extensions.pyi',
'abc.pyi',
'collections.pyi'))
and not os.path.basename(f.path).startswith('_')
Expand All @@ -86,7 +88,8 @@ def test_semanal(testcase: DataDrivenTestCase) -> None:
a += str(f).split('\n')
except CompileError as e:
a = e.messages
a = normalize_error_messages(a)
if testcase.normalize_output:
a = normalize_error_messages(a)
assert_string_arrays_equal(
testcase.output, a,
'Invalid semantic analyzer output ({}, line {})'.format(testcase.file,
Expand Down Expand Up @@ -116,8 +119,10 @@ def test_semanal_error(testcase: DataDrivenTestCase) -> None:
# Verify that there was a compile error and that the error messages
# are equivalent.
a = e.messages
if testcase.normalize_output:
a = normalize_error_messages(a)
assert_string_arrays_equal(
testcase.output, normalize_error_messages(a),
testcase.output, a,
'Invalid compiler output ({}, line {})'.format(testcase.file, testcase.line))


Expand Down
3 changes: 2 additions & 1 deletion mypy/test/testtransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def test_transform(testcase: DataDrivenTestCase) -> None:
a += str(f).split('\n')
except CompileError as e:
a = e.messages
a = normalize_error_messages(a)
if testcase.normalize_output:
a = normalize_error_messages(a)
assert_string_arrays_equal(
testcase.output, a,
'Invalid semantic analyzer output ({}, line {})'.format(testcase.file,
Expand Down
Loading