Skip to content

Commit

Permalink
Make TypeQuery more general, handling nonboolean queries.
Browse files Browse the repository at this point in the history
Instead of TypeQuery always returning a boolean and having the strategy be an
enum, the strategy is now a Callable describing how to combine partial results,
and the two default strategies are plain old funcitons.

To preserve the short-circuiting behavior of the previous code, this PR uses an
exception.

This is a pure refactor that I am using in my experimentation regarding fixing
python#1551.  It should result in exactly no
change to current behavior. It's separable from the other things I'm
experimenting with, so I'm filing it as a separate pull request now. It enables
me to rewrite the code that pulls type variables out of types as a TypeQuery.

Consider waiting to merge this PR until I have some code that uses it ready for
review.  Or merge it now, if you think it's a pleasant cleanup instead of an
ugly complication.  I'm of two minds on that particular question.
  • Loading branch information
sixolet committed Mar 29, 2017
1 parent 5434979 commit 41d8a66
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 59 deletions.
6 changes: 3 additions & 3 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2332,7 +2332,7 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl
return c.copy_modified(ret_type=new_ret_type)


class ArgInferSecondPassQuery(types.TypeQuery):
class ArgInferSecondPassQuery(types.TypeQuery[bool]):
"""Query whether an argument type should be inferred in the second pass.
The result is True if the type has a type variable in a callable return
Expand All @@ -2346,7 +2346,7 @@ def visit_callable_type(self, t: CallableType) -> bool:
return self.query_types(t.arg_types) or t.accept(HasTypeVarQuery())


class HasTypeVarQuery(types.TypeQuery):
class HasTypeVarQuery(types.TypeQuery[bool]):
"""Visitor for querying whether a type has a type variable component."""
def __init__(self) -> None:
super().__init__(False, types.ANY_TYPE_STRATEGY)
Expand All @@ -2359,7 +2359,7 @@ def has_erased_component(t: Type) -> bool:
return t is not None and t.accept(HasErasedComponentsQuery())


class HasErasedComponentsQuery(types.TypeQuery):
class HasErasedComponentsQuery(types.TypeQuery[bool]):
"""Visitor for querying whether a type has an erased component."""
def __init__(self) -> None:
super().__init__(False, types.ANY_TYPE_STRATEGY)
Expand Down
2 changes: 1 addition & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def is_complete_type(typ: Type) -> bool:
return typ.accept(CompleteTypeVisitor())


class CompleteTypeVisitor(TypeQuery):
class CompleteTypeVisitor(TypeQuery[bool]):
def __init__(self) -> None:
super().__init__(default=True, strategy=ALL_TYPES_STRATEGY)

Expand Down
2 changes: 1 addition & 1 deletion mypy/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def is_imprecise(t: Type) -> bool:
return t.accept(HasAnyQuery())


class HasAnyQuery(TypeQuery):
class HasAnyQuery(TypeQuery[bool]):
def __init__(self) -> None:
super().__init__(False, ANY_TYPE_STRATEGY)

Expand Down
108 changes: 54 additions & 54 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import OrderedDict
from typing import (
Any, TypeVar, Dict, List, Tuple, cast, Generic, Set, Sequence, Optional, Union, Iterable,
NamedTuple,
NamedTuple, Callable,
)

import mypy.nodes
Expand Down Expand Up @@ -1486,112 +1486,112 @@ def keywords_str(self, a: Iterable[Tuple[str, Type]]) -> str:
])


# These constants define the method used by TypeQuery to combine multiple
# query results, e.g. for tuple types. The strategy is not used for empty
# result lists; in that case the default value takes precedence.
ANY_TYPE_STRATEGY = 0 # Return True if any of the results are True.
ALL_TYPES_STRATEGY = 1 # Return True if all of the results are True.
# Combination strategies for boolean type queries
def ANY_TYPE_STRATEGY(current: bool, accumulated: bool) -> bool:
"""True if any type's result is True"""
if accumulated:
raise ShortCircuitQuery()
return current


class TypeQuery(TypeVisitor[bool]):
"""Visitor for performing simple boolean queries of types.
def ALL_TYPES_STRATEGY(current: bool, accumulated: bool) -> bool:
"""True if all types' results are True"""
if not accumulated:
raise ShortCircuitQuery()
return current

This class allows defining the default value for leafs to simplify the
implementation of many queries.
"""

default = False # Default result
strategy = 0 # Strategy for combining multiple values (ANY_TYPE_STRATEGY or ALL_TYPES_...).
class ShortCircuitQuery(Exception):
pass

def __init__(self, default: bool, strategy: int) -> None:
"""Construct a query visitor.

Use the given default result and strategy for combining
multiple results. The strategy must be either
ANY_TYPE_STRATEGY or ALL_TYPES_STRATEGY.
"""
class TypeQuery(Generic[T], TypeVisitor[T]):
"""Visitor for performing queries of types.
default is used as the query result unless a method for that type is
overridden.
strategy is used to combine a partial result with a result for a particular
type in a series of types.
Common use cases involve a boolean query using ANY_TYPE_STRATEGY and a
default of False or ALL_TYPES_STRATEGY and a default of True.
"""

def __init__(self, default: T, strategy: Callable[[T, T], T]) -> None:
self.default = default
self.strategy = strategy

def visit_unbound_type(self, t: UnboundType) -> bool:
def visit_unbound_type(self, t: UnboundType) -> T:
return self.default

def visit_type_list(self, t: TypeList) -> bool:
def visit_type_list(self, t: TypeList) -> T:
return self.default

def visit_error_type(self, t: ErrorType) -> bool:
def visit_error_type(self, t: ErrorType) -> T:
return self.default

def visit_any(self, t: AnyType) -> bool:
def visit_any(self, t: AnyType) -> T:
return self.default

def visit_uninhabited_type(self, t: UninhabitedType) -> bool:
def visit_uninhabited_type(self, t: UninhabitedType) -> T:
return self.default

def visit_none_type(self, t: NoneTyp) -> bool:
def visit_none_type(self, t: NoneTyp) -> T:
return self.default

def visit_erased_type(self, t: ErasedType) -> bool:
def visit_erased_type(self, t: ErasedType) -> T:
return self.default

def visit_deleted_type(self, t: DeletedType) -> bool:
def visit_deleted_type(self, t: DeletedType) -> T:
return self.default

def visit_type_var(self, t: TypeVarType) -> bool:
def visit_type_var(self, t: TypeVarType) -> T:
return self.default

def visit_partial_type(self, t: PartialType) -> bool:
def visit_partial_type(self, t: PartialType) -> T:
return self.default

def visit_instance(self, t: Instance) -> bool:
def visit_instance(self, t: Instance) -> T:
return self.query_types(t.args)

def visit_callable_type(self, t: CallableType) -> bool:
def visit_callable_type(self, t: CallableType) -> T:
# FIX generics
return self.query_types(t.arg_types + [t.ret_type])

def visit_tuple_type(self, t: TupleType) -> bool:
def visit_tuple_type(self, t: TupleType) -> T:
return self.query_types(t.items)

def visit_typeddict_type(self, t: TypedDictType) -> bool:
def visit_typeddict_type(self, t: TypedDictType) -> T:
return self.query_types(t.items.values())

def visit_star_type(self, t: StarType) -> bool:
def visit_star_type(self, t: StarType) -> T:
return t.type.accept(self)

def visit_union_type(self, t: UnionType) -> bool:
def visit_union_type(self, t: UnionType) -> T:
return self.query_types(t.items)

def visit_overloaded(self, t: Overloaded) -> bool:
def visit_overloaded(self, t: Overloaded) -> T:
return self.query_types(t.items())

def visit_type_type(self, t: TypeType) -> bool:
def visit_type_type(self, t: TypeType) -> T:
return t.item.accept(self)

def query_types(self, types: Iterable[Type]) -> bool:
def query_types(self, types: Iterable[Type]) -> T:
"""Perform a query for a list of types.
Use the strategy constant to combine the results.
Use the strategy to combine the results.
"""
if not types:
# Use default result for empty list.
return self.default
if self.strategy == ANY_TYPE_STRATEGY:
# Return True if at least one component is true.
res = False
res = self.default
try:
for t in types:
res = res or t.accept(self)
if res:
break
return res
else:
# Return True if all components are true.
res = True
for t in types:
res = res and t.accept(self)
if not res:
break
return res
res = self.strategy(t.accept(self), res)
except ShortCircuitQuery:
pass
return res


def strip_type(typ: Type) -> Type:
Expand Down

0 comments on commit 41d8a66

Please sign in to comment.