diff --git a/pytype/abstract/CMakeLists.txt b/pytype/abstract/CMakeLists.txt index e51c216b7..308839133 100644 --- a/pytype/abstract/CMakeLists.txt +++ b/pytype/abstract/CMakeLists.txt @@ -144,6 +144,7 @@ py_library( .mixin pytype.utils pytype.pytd.pytd + pytype.typegraph.cfg ) py_library( diff --git a/pytype/abstract/_function_base.py b/pytype/abstract/_function_base.py index 322efdf3a..9f96449c9 100644 --- a/pytype/abstract/_function_base.py +++ b/pytype/abstract/_function_base.py @@ -89,30 +89,6 @@ def match_args(self, node, args, alias_map=None, match_all_views=False): return self._match_args_sequentially(node, args, alias_map, match_all_views) def _is_complex_generic_call(self, args): - if _isinstance(self, "SignedFunction"): - return False - for sig in function.get_signatures(self): - parameter_typevar_count = 0 - for name, t in sig.annotations.items(): - stack = [t] - seen = set() - while stack: - cur = stack.pop() - if cur in seen: - continue - seen.add(cur) - if cur.template: - return True - parameter_typevar_count += (name != "return" and cur.formal) - if _isinstance(cur, "Union"): - stack.extend(cur.options) - if parameter_typevar_count > 1: - return True - if self.is_attribute_of_class and args.posargs: - for self_val in args.posargs[0].data: - for cls in self_val.cls.mro: - if cls.template: - return True return False def _match_views(self, node, args, alias_map, match_all_views): diff --git a/pytype/abstract/_interpreter_function.py b/pytype/abstract/_interpreter_function.py index 962d1fe75..bdd94494c 100644 --- a/pytype/abstract/_interpreter_function.py +++ b/pytype/abstract/_interpreter_function.py @@ -228,7 +228,7 @@ def _match_args_sequentially(self, node, args, alias_map, match_all_views): matcher = self.ctx.matcher(node) try: matches = matcher.compute_matches( - args_to_match, match_all_views, alias_map) + args_to_match, match_all_views, alias_map=alias_map) except matcher.MatchError as e: raise function.WrongArgTypes(self.signature, args, self.ctx, e.bad_type) return [m.subst for m in matches] diff --git a/pytype/abstract/_pytd_function.py b/pytype/abstract/_pytd_function.py index bb83efa91..ce6f3e48d 100644 --- a/pytype/abstract/_pytd_function.py +++ b/pytype/abstract/_pytd_function.py @@ -1,9 +1,10 @@ """Abstract representation of a function loaded from a type stub.""" import collections +import contextlib import itertools import logging -from typing import Optional +from typing import Any, Dict, List, Optional, Tuple from pytype import datatypes from pytype import utils @@ -20,10 +21,14 @@ from pytype.pytd import pytd from pytype.pytd import pytd_utils from pytype.pytd import visitors +from pytype.typegraph import cfg log = logging.getLogger(__name__) _isinstance = abstract_utils._isinstance # pylint: disable=protected-access +# pytype.matcher.GoodMatch, which can't be imported due to a circular dep +_GoodMatchType = Any + def _is_literal(annot: Optional[_base.BaseValue]): if isinstance(annot, _typing.Union): @@ -31,6 +36,52 @@ def _is_literal(annot: Optional[_base.BaseValue]): return isinstance(annot, _classes.LiteralClass) +class _MatchedSignatures: + """Function call matches.""" + + def __init__(self, args, can_match_multiple): + self._args_vars = set(args.get_variables()) + self._can_match_multiple = can_match_multiple + self._data: List[List[ + Tuple[PyTDSignature, Dict[str, cfg.Variable], _GoodMatchType]]] = [] + self._sig = self._cur_data = None + + def __bool__(self): + return bool(self._data) + + @contextlib.contextmanager + def with_signature(self, sig): + """Sets the signature that we are collecting matches for.""" + assert self._sig is self._cur_data is None + self._sig = sig + # We collect data for the current signature separately and merge it in at + # the end so that add() does not wastefully iterate over the new data. + self._cur_data = [] + try: + yield + finally: + self._data.extend(self._cur_data) + self._sig = self._cur_data = None + + def add(self, arg_dict, match): + """Adds a new match.""" + for sigs in self._data: + if sigs[-1][0] == self._sig: + continue + new_view = match.view.accessed_subset + old_view = sigs[0][2].view.accessed_subset + if all(new_view[k] == old_view[k] for k in new_view if k in old_view): + if self._can_match_multiple: + sigs.append((self._sig, arg_dict, match)) + break + else: + self._cur_data.append([(self._sig, arg_dict, match)]) + + def get(self): + """Gets the matches.""" + return self._data + + class PyTDFunction(_function_base.Function): """A PyTD function (name + list of signatures). @@ -151,29 +202,16 @@ def call(self, node, func, args, alias_map=None): # through the normal matching process. Thus, we create a combined view that # is guaranteed to contain an entry for every variable in every view for use # by the match_var_against_type() call in 'compatible_with' below. - combined_view = {} - - def uses_variables(arg_dict): - # TODO(b/228241343): Currently, arg_dict is a name->Binding mapping when - # the old matching implementation is used and a name->Variable mapping - # when the new one is used. - if arg_dict: - try: - next(iter(arg_dict.values())).bindings - except AttributeError: - return False - return True - - for view, signatures in possible_calls: + combined_view = datatypes.AccessTrackingDict() + for signatures in possible_calls: + view = datatypes.AccessTrackingDict() + for _, _, match in signatures: + view.update(match.view) if len(signatures) > 1: - variable = uses_variables(signatures[0][1]) - ret = self._call_with_signatures( - node, func, args, view, signatures, variable) + ret = self._call_with_signatures(node, func, args, view, signatures) else: - (sig, arg_dict, substs_or_matches), = signatures - variable = uses_variables(arg_dict) - ret = sig.call_with_args( - node, func, arg_dict, substs_or_matches, ret_map, variable) + (sig, arg_dict, match), = signatures + ret = sig.call_with_args(node, func, arg_dict, match, ret_map) node, result, mutations = ret retvar.PasteVariable(result, node) for mutation in mutations: @@ -231,7 +269,8 @@ def compatible_with(new, existing, view): if not should_check(b.data) or b.data in ps: filtered_values.PasteBinding(b) continue - new_view = {**combined_view, **view, values: b} + new_view = datatypes.AccessTrackingDict.merge( + combined_view, view, {values: b}) if not compatible_with(values, ps, new_view): combination = [b] bad_param = b.data.get_instance_type_parameter(short_name) @@ -289,20 +328,15 @@ def _get_mutation_to_unknown(self, node, values): node, action="type_param_" + name))) return mutations - def _can_match_multiple(self, args, view, variable): + def _can_match_multiple(self, args): # If we're calling an overloaded pytd function with an unknown as a # parameter, we can't tell whether it matched or not. Hence, if multiple # signatures are possible matches, we don't know which got called. Check # if this is the case. if len(self.signatures) <= 1: return False - if variable: - for var in view: - if any(_isinstance(v, "AMBIGUOUS_OR_EMPTY") for v in var.data): - return True - else: - if any(_isinstance(view[arg].data, "AMBIGUOUS_OR_EMPTY") - for arg in args.get_variables()): + for var in args.get_variables(): + if any(_isinstance(v, "AMBIGUOUS_OR_EMPTY") for v in var.data): return True for arg in (args.starargs, args.starstarargs): # An opaque *args or **kwargs behaves like an unknown. @@ -311,7 +345,7 @@ def _can_match_multiple(self, args, view, variable): return False def _match_view(self, node, args, view, alias_map=None): - if self._can_match_multiple(args, view, False): + if self._can_match_multiple(args): signatures = tuple(self._yield_matching_signatures( node, args, view, alias_map)) else: @@ -327,7 +361,7 @@ def _match_view(self, node, args, view, alias_map=None): signatures = (sig,) return (view, signatures) - def _call_with_signatures(self, node, func, args, view, signatures, variable): + def _call_with_signatures(self, node, func, args, view, signatures): """Perform a function call that involves multiple signatures.""" ret_type = self._combine_multiple_returns(signatures) if (self.ctx.options.protocols and isinstance(ret_type, pytd.AnythingType)): @@ -339,11 +373,7 @@ def _call_with_signatures(self, node, func, args, view, signatures, variable): result = self.ctx.convert.constant_to_var( abstract_utils.AsReturnValue(ret_type), {}, node) for i, arg in enumerate(args.posargs): - if variable: - unknown = any(isinstance(v, _singletons.Unknown) for v in arg.data) - else: - unknown = isinstance(view[arg].data, _singletons.Unknown) - if unknown: + if arg in view and isinstance(view[arg].data, _singletons.Unknown): for sig, _, _ in signatures: if (len(sig.param_types) > i and isinstance(sig.param_types[i], _typing.TypeParameter)): @@ -353,24 +383,20 @@ def _call_with_signatures(self, node, func, args, view, signatures, variable): # def f(x: T) -> T # def f(x: int) -> T # the type of x should be Any, not int. - b = arg.AddBinding(self.ctx.convert.unsolvable, [], node) - if not variable: - view[arg] = b + view[arg] = arg.AddBinding(self.ctx.convert.unsolvable, [], node) break if self._has_mutable: # TODO(b/159055015): We only need to whack the type params that appear in # a mutable parameter. - assert not variable mutations = self._get_mutation_to_unknown( - node, (view[p].data for p in itertools.chain( - args.posargs, args.namedargs.values()))) + node, + (view[p].data if p in view else self.ctx.convert.unsolvable + for p in itertools.chain(args.posargs, args.namedargs.values()))) else: mutations = [] self.ctx.vm.trace_call( - node, func, tuple(sig[0] for sig in signatures), - [view[arg] for arg in args.posargs], - {name: view[arg] for name, arg in args.namedargs.items()}, result, - variable) + node, func, tuple(sig[0] for sig in signatures), args.posargs, + args.namedargs, result) return node, result, mutations def _combine_multiple_returns(self, signatures): @@ -433,17 +459,25 @@ def _yield_matching_signatures(self, node, args, view, alias_map): raise error # pylint: disable=raising-bad-type def _match_args_sequentially(self, node, args, alias_map, match_all_views): - arg_variables = args.get_variables() - # TODO(b/228241343): The notion of a view will no longer be necessary once - # we transition fully to arg-by-arg matching. - variable_view = {var: var for var in arg_variables} error = None - matched_signatures = [] - can_match_multiple = self._can_match_multiple(args, variable_view, True) + matched_signatures = _MatchedSignatures( + args, self._can_match_multiple(args)) + # Once a constant has matched a literal type, it should no longer be able to + # match non-literal types. For example, with: + # @overload + # def f(x: Literal['r']): ... + # @overload + # def f(x: str): ... + # f('r') should match only the first signature. + literal_matches = set() for sig in self.signatures: + if any(not _is_literal(sig.signature.annotations.get(name)) + for name in literal_matches): + continue try: arg_dict, matches = sig.substitute_formal_args( - node, args, variable_view, match_all_views) + node, args, match_all_views, + keep_all_views=sig is not self.signatures[-1]) except function.FailedFunctionCall as e: if e > error: # Add the name of the caller if possible. @@ -451,12 +485,16 @@ def _match_args_sequentially(self, node, args, alias_map, match_all_views): e.name = f"{self.parent.name}.{e.name}" error = e else: - matched_signatures.append((sig, arg_dict, matches)) - if not can_match_multiple: - break + with matched_signatures.with_signature(sig): + for match in matches: + matched_signatures.add(arg_dict, match) + for name, var in arg_dict.items(): + if (any(isinstance(v, mixin.PythonConstant) for v in var.data) and + _is_literal(sig.signature.annotations.get(name))): + literal_matches.add(name) if not matched_signatures: raise error - return [(variable_view, matched_signatures)] + return matched_signatures.get() def set_function_defaults(self, node, defaults_var): """Attempts to set default arguments for a function's signatures. @@ -510,18 +548,17 @@ def __init__(self, name, pytd_sig, ctx): ] self.signature = function.Signature.from_pytd(ctx, name, pytd_sig) - def _map_args(self, node, args, view): + def _map_args(self, node, args): """Map the passed arguments to a name->binding dictionary. Args: node: The current node. args: The passed arguments. - view: A variable->binding dictionary. Returns: A tuple of: a list of formal arguments, each a (name, abstract value) pair; - a name->binding dictionary of the passed arguments. + a name->variable dictionary of the passed arguments. Raises: InvalidParameters: If the passed arguments don't match this signature. @@ -532,7 +569,7 @@ def _map_args(self, node, args, view): # positional args for name, arg in zip(self.signature.param_names, args.posargs): - arg_dict[name] = view[arg] + arg_dict[name] = arg num_expected_posargs = len(self.signature.param_names) if len(args.posargs) > num_expected_posargs and not self.pytd_sig.starargs: raise function.WrongArgCount(self.signature, args, self.ctx) @@ -541,7 +578,7 @@ def _map_args(self, node, args, view): if isinstance(varargs_type, _classes.ParameterizedClass): for (i, vararg) in enumerate(args.posargs[num_expected_posargs:]): name = function.argname(num_expected_posargs + i) - arg_dict[name] = view[vararg] + arg_dict[name] = vararg formal_args.append( (name, varargs_type.get_formal_type_parameter(abstract_utils.T))) @@ -550,7 +587,7 @@ def _map_args(self, node, args, view): for name, arg in args.namedargs.items(): if name in arg_dict and name not in posonly_names: raise function.DuplicateKeyword(self.signature, args, self.ctx, name) - arg_dict[name] = view[arg] + arg_dict[name] = arg kws = set(args.namedargs) extra_kwargs = kws - {p.name for p in self.pytd_sig.params} if extra_kwargs and not self.pytd_sig.starstarargs: @@ -580,7 +617,7 @@ def _map_args(self, node, args, view): actual = getattr(args, arg_type) pytd_val = getattr(self.pytd_sig, arg_type) if actual and pytd_val: - arg_dict[name] = view[actual] + arg_dict[name] = actual # The annotation is Tuple or Dict, but the passed arg only has to be # Iterable or Mapping. typ = self.ctx.convert.widen_type(self.signature.annotations[name]) @@ -588,7 +625,7 @@ def _map_args(self, node, args, view): return formal_args, arg_dict - def _fill_in_missing_parameters(self, node, args, arg_dict, variable): + def _fill_in_missing_parameters(self, node, args, arg_dict): for p in self.pytd_sig.params: if p.name not in arg_dict: if (not p.optional and args.starargs is None and @@ -596,18 +633,12 @@ def _fill_in_missing_parameters(self, node, args, arg_dict, variable): raise function.MissingParameter( self.signature, args, self.ctx, p.name) # Assume the missing parameter is filled in by *args or **kwargs. - # Unfortunately, we can't easily use *args or **kwargs to fill in - # something more precise, since we need a Value, not a Variable. - if variable: - param = self.ctx.new_unsolvable(node) - else: - param = self.ctx.convert.unsolvable.to_binding(node) - arg_dict[p.name] = param + arg_dict[p.name] = self.ctx.new_unsolvable(node) def substitute_formal_args_old(self, node, args, view, alias_map): """Substitute matching args into this signature. Used by PyTDFunction.""" - formal_args, arg_dict = self._map_args(node, args, view) - self._fill_in_missing_parameters(node, args, arg_dict, False) + formal_args, arg_dict = self._map_args(node, args) + self._fill_in_missing_parameters(node, args, arg_dict) subst, bad_arg = self.ctx.matcher(node).compute_subst( formal_args, arg_dict, view, alias_map) if subst is None: @@ -626,15 +657,16 @@ def substitute_formal_args_old(self, node, args, view, alias_map): return arg_dict, [subst] - def substitute_formal_args(self, node, args, variable_view, match_all_views): + def substitute_formal_args(self, node, args, match_all_views, keep_all_views): """Substitute matching args into this signature. Used by PyTDFunction.""" - formal_args, arg_dict = self._map_args(node, args, variable_view) - self._fill_in_missing_parameters(node, args, arg_dict, True) + formal_args, arg_dict = self._map_args(node, args) + self._fill_in_missing_parameters(node, args, arg_dict) args_to_match = [function.Arg(name, arg_dict[name], formal) for name, formal in formal_args] matcher = self.ctx.matcher(node) try: - matches = matcher.compute_matches(args_to_match, match_all_views) + matches = matcher.compute_matches( + args_to_match, match_all_views, keep_all_views) except matcher.MatchError as e: if self.signature.has_param(e.bad_type.name): signature = self.signature @@ -677,41 +709,31 @@ def instantiate_return(self, node, subst, sources): ret.AddBinding(self.ctx.convert.empty, [], node) return node, ret - def call_with_args( - self, node, func, arg_dict, substs_or_matches, ret_map, variable): + def call_with_args(self, node, func, arg_dict, match, ret_map): """Call this signature. Used by PyTDFunction.""" - ret = self.ctx.program.NewVariable() - mutations = [] - for subst_or_match in substs_or_matches: - subst = getattr(subst_or_match, "subst", subst_or_match) - t = (self.pytd_sig.return_type, subst) - sources = [func] - if variable: - for v in arg_dict.values(): - # For the argument that 'subst' was generated from, we need to add the - # corresponding binding. For the rest, it does not appear to matter - # which binding we add to the sources, as long as we add one from - # every variable. - sources.append(subst_or_match.view.get(v, v.bindings[0])) - else: - sources.extend(arg_dict.values()) - visible = node.CanHaveCombination(sources) - if visible and t in ret_map: - # add the new sources - for data in ret_map[t].data: - ret_map[t].AddBinding(data, sources, node) - elif visible: - node, ret_map[t] = self.instantiate_return(node, subst, sources) - elif t not in ret_map: - ret_map[t] = self.ctx.program.NewVariable() - ret.PasteVariable(ret_map[t]) - mutations.extend(self._get_mutation( - node, arg_dict, subst, ret_map[t], variable)) + subst = match.subst + t = (self.pytd_sig.return_type, subst) + sources = [func] + for v in arg_dict.values(): + # For the argument that 'subst' was generated from, we need to add the + # corresponding binding. For the rest, it does not appear to matter + # which binding we add to the sources, as long as we add one from + # every variable. + sources.append(match.view.get(v, v.bindings[0])) + visible = node.CanHaveCombination(sources) + if visible and t in ret_map: + # add the new sources + for data in ret_map[t].data: + ret_map[t].AddBinding(data, sources, node) + elif visible: + node, ret_map[t] = self.instantiate_return(node, subst, sources) + elif t not in ret_map: + ret_map[t] = self.ctx.program.NewVariable() + mutations = self._get_mutation(node, arg_dict, subst, ret_map[t]) self.ctx.vm.trace_call( node, func, (self,), - tuple(arg_dict[p.name] for p in self.pytd_sig.params), {}, ret, - variable) - return node, ret, mutations + tuple(arg_dict[p.name] for p in self.pytd_sig.params), {}, ret_map[t]) + return node, ret_map[t], mutations @classmethod def _collect_mutated_parameters(cls, typ, mutated_type): @@ -733,7 +755,7 @@ def _collect_mutated_parameters(cls, typ, mutated_type): raise ValueError(f"Unsupported mutation:\n{typ!r} ->\n{mutated_type!r}") return [zip(mutated_type.base_type.cls.template, mutated_type.parameters)] - def _get_mutation(self, node, arg_dict, subst, retvar, variable): + def _get_mutation(self, node, arg_dict, subst, retvar): """Mutation for changing the type parameters of mutable arguments. This will adjust the type parameters as needed for pytd functions like: @@ -744,10 +766,9 @@ def append_float(x: list[int]): Args: node: The current CFG node. - arg_dict: A map of strings to cfg.Bindings instances. + arg_dict: A map of strings to cfg.Variable instances. subst: Current type parameters. retvar: A variable of the return value. - variable: If True, arg_dict maps to Variables rather than Bindings. Returns: A list of Mutation instances. Raises: @@ -764,13 +785,9 @@ def append_float(x: list[int]): subst, self.pytd_sig, node, self.ctx) for formal in self.pytd_sig.params: actual = arg_dict[formal.name] - arg = actual.data if formal.mutated_type is None: continue - if variable: - args = actual.data - else: - args = [actual.data] + args = actual.data for arg in args: if isinstance(arg, _instance_base.SimpleValue): try: diff --git a/pytype/abstract/abstract_utils.py b/pytype/abstract/abstract_utils.py index ac77836a1..e80dcbb3f 100644 --- a/pytype/abstract/abstract_utils.py +++ b/pytype/abstract/abstract_utils.py @@ -1,5 +1,6 @@ """Utilities for abstract.py.""" +import collections import dataclasses import logging from typing import Any, Collection, Dict, Iterable, Mapping, Optional, Sequence, Set, Tuple, Union @@ -938,3 +939,18 @@ def get_dict_fullhash_component( vardict = {name: vardict[name] for name in names.intersection(vardict)} return tuple(sorted((k, get_var_fullhash_component(v, seen)) for k, v in vardict.items())) + + +def simplify_variable(var, node, ctx): + """Deduplicates identical data in `var`.""" + if not var: + return var + bindings_by_hash = collections.defaultdict(list) + for b in var.bindings: + bindings_by_hash[b.data.get_fullhash()].append(b) + if len(bindings_by_hash) == len(var.bindings): + return var + new_var = ctx.program.NewVariable() + for bindings in bindings_by_hash.values(): + new_var.AddBinding(bindings[0].data, bindings, node) + return new_var diff --git a/pytype/abstract/function.py b/pytype/abstract/function.py index cab9c6e2f..0acad6ca9 100644 --- a/pytype/abstract/function.py +++ b/pytype/abstract/function.py @@ -404,21 +404,6 @@ def _convert_namedargs(namedargs): return {} if namedargs is None else namedargs -def _simplify_variable(var, node, ctx): - """Deduplicates identical data in `var`.""" - if not var: - return var - bindings_by_hash = collections.defaultdict(list) - for b in var.bindings: - bindings_by_hash[b.data.get_fullhash()].append(b) - if len(bindings_by_hash) == len(var.bindings): - return var - new_var = ctx.program.NewVariable() - for bindings in bindings_by_hash.values(): - new_var.AddBinding(bindings[0].data, bindings, node) - return new_var - - @attrs.frozen(eq=True) class Args: """Represents the parameters of a function call. @@ -615,7 +600,7 @@ def simplify(self, node, ctx, match_signature=None): # have a signature to match. Just set all splats to Any. posargs = self.posargs + _splats_to_any(starargs_as_tuple, ctx) starargs = None - simplify = lambda var: _simplify_variable(var, node, ctx) + simplify = lambda var: abstract_utils.simplify_variable(var, node, ctx) return Args(tuple(simplify(posarg) for posarg in posargs), {k: simplify(namedarg) for k, namedarg in namedargs.items()}, simplify(starargs), simplify(starstarargs)) @@ -801,12 +786,9 @@ def _splats_to_any(seq, ctx): for v in seq) -def call_function(ctx, - node, - func_var, - args, - fallback_to_unsolvable=True, - allow_noreturn=False): +def call_function( + ctx, node, func_var, args, fallback_to_unsolvable=True, + allow_noreturn=False, strict_filter=True): """Call a function. Args: @@ -816,6 +798,7 @@ def call_function(ctx, args: The arguments to pass. See function.Args. fallback_to_unsolvable: If the function call fails, create an unknown. allow_noreturn: Whether typing.NoReturn is allowed in the return type. + strict_filter: Whether function bindings should be strictly filtered. Returns: A tuple (CFGNode, Variable). The Variable is the return value. Raises: @@ -834,7 +817,8 @@ def call_function(ctx, try: new_node, one_result = func.call(node, funcb, args) except (DictKeyMissing, FailedFunctionCall) as e: - if e > error and funcb.IsVisible(node): + if e > error and ((not strict_filter and len(func_var.bindings) == 1) or + funcb.IsVisible(node)): error = e else: if ctx.convert.no_return in one_result.data: diff --git a/pytype/datatypes.py b/pytype/datatypes.py index ced7cc5f1..44c7310cf 100644 --- a/pytype/datatypes.py +++ b/pytype/datatypes.py @@ -116,7 +116,7 @@ def __repr__(self): class AccessTrackingDict(Dict[_K, _V]): """A dict that tracks access of its original items.""" - def __init__(self, d): + def __init__(self, d=()): super().__init__(d) self.accessed_subset = {} @@ -137,6 +137,19 @@ def __delitem__(self, k): _ = self[k] return super().__delitem__(k) + def update(self, *args, **kwargs): + super().update(*args, **kwargs) + for d in args: + if isinstance(d, AccessTrackingDict): + self.accessed_subset.update(d.accessed_subset) + + @classmethod + def merge(cls, *dicts): + self = cls() + for d in dicts: + self.update(d) + return self + class MonitorDict(Dict[_K, _V]): """A dictionary that monitors changes to its cfg.Variable values. diff --git a/pytype/matcher.py b/pytype/matcher.py index 678e89a87..e4016d21e 100644 --- a/pytype/matcher.py +++ b/pytype/matcher.py @@ -3,7 +3,7 @@ import contextlib import dataclasses import logging -from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar +from typing import Any, Dict, Iterable, List, Optional, Tuple from pytype import datatypes from pytype import special_builtins @@ -28,8 +28,11 @@ ] _SubstType = datatypes.AliasingDict[str, cfg.Variable] -_ViewType = Dict[cfg.Variable, cfg.Binding] -_T = TypeVar("_T") +_ViewType = datatypes.AccessTrackingDict[cfg.Variable, cfg.Binding] + +# For _UniqueMatches +_ViewKeyType = Tuple[Tuple[int, Any], ...] +_SubstKeyType = Dict[cfg.Variable, Any] def _is_callback_protocol(typ): @@ -95,12 +98,12 @@ class GoodMatch: @classmethod def default(cls): - return cls({}, datatypes.HashableDict()) + return cls(datatypes.AccessTrackingDict(), datatypes.HashableDict()) @classmethod def merge(cls, old_match, new_match, combined_subst): - return cls({**old_match.view, **new_match.view}, - datatypes.HashableDict(combined_subst)) + view = datatypes.AccessTrackingDict.merge(old_match.view, new_match.view) + return cls(view, datatypes.HashableDict(combined_subst)) @dataclasses.dataclass(eq=True, frozen=True) @@ -129,38 +132,52 @@ class MatchResult: bad_matches: List[BadMatch] -class _UniqueSubsts(Generic[_T]): - """A collection of substs that discards duplicates. - - Each subst has a bit of data associated to it, and calling unique() returns - unique substs and associated data. - """ +class _UniqueMatches: + """A collection of matches that discards duplicates.""" - def __init__(self): - self._data: List[Tuple[Dict[str, Any], _SubstType, _T]] = [] + def __init__(self, node, keep_all_views): + self._node = node + self._keep_all_views = keep_all_views + self._data: Dict[ + _ViewKeyType, List[Tuple[_SubstKeyType, _ViewType, _SubstType]] + ] = collections.defaultdict(list) - def insert(self, subst, data: _T): + def insert(self, view, subst): """Insert a subst with associated data.""" + if self._keep_all_views: + view_key = tuple(sorted((k.id, v.data.get_type_key()) + for k, v in view.accessed_subset.items())) + else: + view_key = () subst_key = {k: {v.get_type_key() for v in var.data} for k, var in subst.items()} - data_item = (subst_key, subst, data) - for i, (prev_subst_key, *_) in enumerate(self._data): + data_item = (subst_key, view, subst) + for i, prev_data_item in enumerate(self._data[view_key]): + prev_subst_key, prev_view, prev_subst = prev_data_item if all(k in prev_subst_key and subst_key[k] <= prev_subst_key[k] for k in subst_key): # A previous substitution is a superset of this one, so we do not need - # to keep this one. + # to keep this one. We do copy over the view and origins. + prev_view.update(view) + for k, v in subst.items(): + prev_subst[k].PasteVariable(v) break if all(k in subst_key and prev_subst_key[k] <= subst_key[k] for k in prev_subst_key): # This substitution is a superset of a previous one, so we replace the - # previous subst with this one. - self._data[i] = data_item + # previous subst with this one. We do copy over the view and origins. + self._data[view_key][i] = data_item + view.update(prev_view) + for k, v in prev_subst.items(): + subst[k].PasteVariable(v) break else: - self._data.append(data_item) + self._data[view_key].append(data_item) - def unique(self) -> List[Tuple[_SubstType, _T]]: - return [d[1:] for d in self._data] + def unique(self) -> Iterable[Tuple[_ViewType, _SubstType]]: + for values in self._data.values(): + for _, view, subst in values: + yield (view, subst) class _TypeParams: @@ -296,8 +313,12 @@ def compute_subst(self, formal_args, arg_dict, view, alias_map=None): subst[name] = value return datatypes.HashableDict(subst), None + # TODO(b/63407497): We were previously enforcing --strict_parameter_checks + # in compute_one_match, which didn't play nicely with overloads. Instead, + # enforcement should be pushed to callers of compute_matches. def compute_matches( self, args: List[function.Arg], match_all_views: bool, + keep_all_views: bool = False, alias_map: Optional[datatypes.UnionFind] = None) -> List[GoodMatch]: """Compute information about type parameters using one-way unification. @@ -309,6 +330,7 @@ def compute_matches( match_all_views: If True, every possible match must succeed for the overall match to be considered a success. Otherwise, the overall match succeeds as long as at least one possible match succeeds. + keep_all_views: If True, avoid optimizations that discard views. alias_map: Optionally, a datatypes.UnionFind, which stores all the type renaming information, mapping of type parameter name to its representative. @@ -321,26 +343,23 @@ def compute_matches( has_self = args and args[0].name == "self" for arg in args: match_result = self.compute_one_match( - arg.value, arg.typ, arg.name, match_all_views, alias_map) + arg.value, arg.typ, arg.name, match_all_views, + keep_all_views, alias_map) if not match_result.success: if matches: - if self._error_subst: - self._error_subst = self._merge_substs( - matches[0].subst, [self._error_subst]) - else: - self._error_subst = matches[0].subst + self._error_subst = matches[0].subst bad_param = self._get_bad_type(arg.name, arg.typ) else: bad_param = match_result.bad_matches[0].expected raise self.MatchError(bad_param) - if any(m.subst for m in match_result.good_matches): + if keep_all_views or any(m.subst for m in match_result.good_matches): matches = self._merge_matches( arg.name, arg.typ, matches, match_result.good_matches, has_self) return matches if matches else [GoodMatch.default()] def compute_one_match( - self, var, other_type, name=None, match_all_views=True, alias_map=None - ) -> MatchResult: + self, var, other_type, name=None, match_all_views=True, + keep_all_views=False, alias_map=None) -> MatchResult: """Match a Variable against a type. Args: @@ -350,6 +369,7 @@ def compute_one_match( match_all_views: If True, every possible match must succeed for the overall match to be considered a success. Otherwise, the overall match succeeds as long as at least one possible match succeeds. + keep_all_views: If True, avoid optimizations that discard views. alias_map: Optionally, a datatypes.UnionFind, which stores all the type renaming information, mapping of type parameter name to its representative. @@ -357,7 +377,7 @@ def compute_one_match( The match result. """ bad_matches = [] - good_matches = _UniqueSubsts[_ViewType]() + good_matches = _UniqueMatches(self._node, keep_all_views) views = abstract_utils.get_views([var], self._node) skip_future = None while True: @@ -368,7 +388,7 @@ def compute_one_match( subst = datatypes.AliasingDict(aliases=alias_map) subst = self.match_var_against_type(var, other_type, subst, view) if subst is None: - if self._node.HasCombination(list(view.values())): + if self._node.CanHaveCombination(list(view.values())): bad_matches.append(BadMatch( view=view, expected=self._get_bad_type(name, other_type), @@ -378,15 +398,19 @@ def compute_one_match( skip_future = False else: skip_future = True - good_matches.insert(subst, view) + good_matches.insert(view, subst) good_matches = [GoodMatch(view, datatypes.HashableDict(subst)) - for subst, view in good_matches.unique()] - if not bad_matches: + for view, subst in good_matches.unique()] + if (good_matches and not match_all_views) or not bad_matches: success = True - elif match_all_views or self.ctx.options.strict_parameter_checks: - success = False + elif good_matches: + # Use HasCombination, which is much more expensive than + # CanHaveCombination, to re-filter bad matches. + bad_matches = [m for m in bad_matches + if self._node.HasCombination(list(m.view.values()))] + success = not bad_matches else: - success = bool(good_matches) + success = False return MatchResult( success=success, good_matches=good_matches, bad_matches=bad_matches) @@ -700,8 +724,9 @@ def _check_type_param_consistency( has_error = False elif old_value.cls.is_protocol: with self._track_partially_matched_protocols(): + protocol_subst = datatypes.AliasingDict(subst) has_error = self._match_against_protocol( - new_value, old_value.cls, subst, {}) is None + new_value, old_value.cls, protocol_subst, {}) is None if not has_error: break else: @@ -958,8 +983,10 @@ def _merge_matches( self, name: str, formal: abstract.BaseValue, old_matches: Optional[List[GoodMatch]], new_matches: List[GoodMatch], has_self: bool) -> List[GoodMatch]: - if old_matches is None: + if not old_matches: return new_matches + if not new_matches: + return old_matches combined_matches = [] matched = False bad_param = None diff --git a/pytype/overlays/subprocess_overlay.py b/pytype/overlays/subprocess_overlay.py index 3ea7eb58c..6c912344c 100644 --- a/pytype/overlays/subprocess_overlay.py +++ b/pytype/overlays/subprocess_overlay.py @@ -51,43 +51,23 @@ def get_own_new(self, node, value): class PopenNew(abstract.PyTDFunction): """Custom implementation of subprocess.Popen.__new__.""" - def _match_bytes_mode(self, args, view): - """Returns the matching signature if bytes mode was definitely requested.""" - for kw, val in [("encoding", self.ctx.convert.none), - ("errors", self.ctx.convert.none), - ("universal_newlines", self.ctx.convert.false), - ("text", self.ctx.convert.false)]: - if kw in args.namedargs and view[args.namedargs[kw]].data != val: - return None - return self.signatures[-2] - - def _match_text_mode(self, args, view): - """Returns the matching signature if text mode was definitely requested.""" - for i, (kw, typ) in enumerate([("encoding", self.ctx.convert.str_type), - ("errors", self.ctx.convert.str_type)]): - if kw in args.namedargs and view[args.namedargs[kw]].data.cls == typ: - return self.signatures[i] - for i, (kw, - val) in enumerate([("universal_newlines", self.ctx.convert.true), - ("text", self.ctx.convert.true)], 2): - if kw in args.namedargs and view[args.namedargs[kw]].data == val: - return self.signatures[i] - return None - - def _yield_matching_signatures(self, node, args, view, alias_map): - # In Python 3, we need to distinguish between Popen[bytes] and Popen[str]. - # This requires an overlay because: - # (1) the stub uses typing.Literal, which pytype doesn't support yet, and - # (2) bytes/text can be distinguished definitely based on only a few of - # the parameters, but pytype will fall back to less precise matching - # if any of the parameters has an unknown type. - sig = self._match_text_mode(args, view) - if sig is None: - sig = self._match_bytes_mode(args, view) - if sig is None: - yield from super()._yield_matching_signatures( - node, args, view, alias_map) - return - arg_dict, subst = sig.substitute_formal_args_old( - node, args, view, alias_map) - yield sig, arg_dict, subst + def _can_match_multiple(self, args): + # We need to distinguish between Popen[bytes] and Popen[str]. This requires + # an overlay because bytes/str can be distinguished definitely based on only + # a few of the parameters, but pytype will fall back to less precise + # matching if any of the parameters has an unknown type. + found_ambiguous_arg = False + for kw, literal in [("encoding", False), ("errors", False), + ("universal_newlines", True), ("text", True)]: + if kw not in args.namedargs: + continue + if literal: + ambiguous = any(not isinstance(v, abstract.ConcreteValue) + for v in args.namedargs[kw].data) + else: + ambiguous = any(isinstance(v, abstract.AMBIGUOUS_OR_EMPTY) + for v in args.namedargs[kw].data) + if not ambiguous: + return False + found_ambiguous_arg = True + return super()._can_match_multiple(args) if found_ambiguous_arg else False diff --git a/pytype/tests/test_annotations.py b/pytype/tests/test_annotations.py index 60b8173e4..b19b0e741 100644 --- a/pytype/tests/test_annotations.py +++ b/pytype/tests/test_annotations.py @@ -67,7 +67,7 @@ def foo(x: int): self.assertErrorRegexes(errors, {"e": r"x: int.*x: float"}) def test_ambiguous_arg(self): - errors = self.CheckWithErrors(""" + self.Check(""" def f(x: int): return x def g(y, z): @@ -77,11 +77,10 @@ def g(y, z): x = 3j else: x = "foo" - f(x) # wrong-arg-types[e] + f(x) # TODO(b/63407497): should be wrong-arg-types """) - self.assertErrorSequences(errors, { - "e": ["Expected: (x: int)", - "Actually passed: (x: Union[complex, int, str]"]}) + # The error should be ["Expected: (x: int)", + # "Actually passed: (x: Union[complex, int, str])"] def test_inner_error(self): _, errors = self.InferWithErrors(""" diff --git a/pytype/tests/test_containers.py b/pytype/tests/test_containers.py index d2d5c71cc..f5342b296 100644 --- a/pytype/tests/test_containers.py +++ b/pytype/tests/test_containers.py @@ -598,7 +598,7 @@ def read(path): self.assertTypesMatchPytd(ty, """ from typing import Any, Dict, Union cache = ... # type: Dict[str, Union[Dict[nothing, nothing], list]] - def read(path) -> Any: ... + def read(path) -> None: ... """) def test_recursive_definition_and_conflict(self): diff --git a/pytype/tests/test_dict2.py b/pytype/tests/test_dict2.py index d177958c8..8a0b524ab 100644 --- a/pytype/tests/test_dict2.py +++ b/pytype/tests/test_dict2.py @@ -16,9 +16,9 @@ def foo(x: Union[int, None]): return MAP[x] """) self.assertTypesMatchPytd(ty, """ - from typing import Any, Dict, Union + from typing import Dict, Optional, Union MAP = ... # type: Dict[int, str] - def foo(x: Union[int, None]) -> Any: ... + def foo(x: Union[int, None]) -> Optional[str]: ... """) def test_object_in_dict(self): diff --git a/pytype/tests/test_list2.py b/pytype/tests/test_list2.py index 295567650..b2f5a30b1 100644 --- a/pytype/tests/test_list2.py +++ b/pytype/tests/test_list2.py @@ -122,6 +122,22 @@ def test_getitem_slice(self): p = ... # type: List[str] """) + def test_appends(self): + # Regression test for a crash involving list appends and accesses for a + # variable with multiple bindings. + self.Check(""" + from typing import List + def f(): + lst1: List[List[str]] = [] + lst2: List[List[str]] = [] + if __random__: + x, lst1 = __any_object__ + else: + x = lst2[-1] + lst1.append(x) + lst2.append(lst1[-1]) + """) + if __name__ == "__main__": test_base.main() diff --git a/pytype/tests/test_match2.py b/pytype/tests/test_match2.py index 1cad9eb7d..ad4cdee0d 100644 --- a/pytype/tests/test_match2.py +++ b/pytype/tests/test_match2.py @@ -409,6 +409,13 @@ def f(x: T1, y: Union[T1, T2]) -> T2: assert_type(f(0, None), None) """) + def test_append_tuple(self): + self.Check(""" + from typing import List, Tuple + x: List[Tuple[str, int]] + x.append(('', 0)) + """) + class MatchTestPy3(test_base.BaseTest): """Tests for matching types.""" @@ -585,7 +592,7 @@ def bar(s: str): foo(s) def baz(os: Optional[str]): - foo(os) # wrong-arg-types + foo(os) # TODO(b/63407497): should be wrong-arg-types """) def test_str_against_plain_collection(self): diff --git a/pytype/tests/test_overload.py b/pytype/tests/test_overload.py index 89919e11c..ac86d4f52 100644 --- a/pytype/tests/test_overload.py +++ b/pytype/tests/test_overload.py @@ -279,6 +279,23 @@ def f(self, x): return __any_object__ """) + def test_multiple_matches_pyi(self): + with self.DepTree([("foo.pyi", """ + from typing import overload + @overload + def f(x: str) -> str: ... + @overload + def f(x: bytes) -> bytes: ... + """)]): + self.Check(""" + import foo + from typing import Tuple + def f(arg) -> Tuple[str, str]: + x = 'hello world' if __random__ else arg + y = arg if __random__ else 'goodbye world' + return foo.f(x), foo.f(y) + """) + class OverloadTestPy3(test_base.BaseTest): """Python 3 tests for typing.overload.""" diff --git a/pytype/tests/test_stdlib2.py b/pytype/tests/test_stdlib2.py index a0ec977ab..28a9b99da 100644 --- a/pytype/tests/test_stdlib2.py +++ b/pytype/tests/test_stdlib2.py @@ -486,6 +486,26 @@ def run(cmd): def run(cmd) -> str: ... """) + def test_popen_ambiguous_universal_newlines(self): + ty = self.Infer(""" + import subprocess + from typing import Any + def run1(value: bool): + proc = subprocess.Popen(['ls'], universal_newlines=value) + stdout, _ = proc.communicate() + return stdout + def run2(value: Any): + proc = subprocess.Popen(['ls'], universal_newlines=value) + stdout, _ = proc.communicate() + return stdout + """) + self.assertTypesMatchPytd(ty, """ + import subprocess + from typing import Any, Union + def run1(value: bool) -> Any: ... + def run2(value: Any) -> Union[bytes, str]: ... + """) + def test_enum(self): self.Check(""" import enum diff --git a/pytype/tracer_vm.py b/pytype/tracer_vm.py index b4271cfad..8ce45eb5a 100644 --- a/pytype/tracer_vm.py +++ b/pytype/tracer_vm.py @@ -42,7 +42,6 @@ class _CallRecord: positional_arguments: Tuple[Union[cfg.Binding, cfg.Variable], ...] keyword_arguments: Tuple[Tuple[str, Union[cfg.Binding, cfg.Variable]], ...] return_value: cfg.Variable - variable: bool class _InitClassState(enum.Enum): @@ -503,23 +502,22 @@ def analyze(self, node, defs, maximum_depth): def trace_unknown(self, name, unknown_binding): self._unknowns[name] = unknown_binding - def trace_call(self, node, func, sigs, posargs, namedargs, result, variable): + def trace_call(self, node, func, sigs, posargs, namedargs, result): """Add an entry into the call trace. Args: node: The CFG node right after this function call. func: A cfg.Binding of a function that was called. sigs: The signatures that the function might have been called with. - posargs: The positional arguments, an iterable over cfg.Binding. - namedargs: The keyword arguments, a dict mapping str to cfg.Binding. + posargs: The positional arguments, an iterable over cfg.Variable. + namedargs: The keyword arguments, a dict mapping str to cfg.Variable. result: A Variable of the possible result values. - variable: If True, posargs and namedargs are Variables, not Bindings. """ log.debug("Logging call to %r with %d args, return %r", func, len(posargs), result) args = tuple(posargs) kwargs = tuple((namedargs or {}).items()) - record = _CallRecord(node, func, sigs, args, kwargs, result, variable) + record = _CallRecord(node, func, sigs, args, kwargs, result) if isinstance(func.data, abstract.BoundPyTDFunction): self._method_calls.add(record) elif isinstance(func.data, abstract.PyTDFunction): @@ -604,11 +602,8 @@ def pytd_for_types(self, defs): def _call_traces_to_function(call_traces, name_transform=lambda x: x): funcs = collections.defaultdict(pytd_utils.OrderedSet) - def to_type(node, arg, variable): - if variable: - return pytd_utils.JoinTypes(a.to_type(node) for a in arg.data) - else: - return arg.data.to_type(node) + def to_type(node, arg): + return pytd_utils.JoinTypes(a.to_type(node) for a in arg.data) for ct in call_traces: log.info("Generating pytd function for call trace: %r", @@ -621,10 +616,10 @@ def to_type(node, arg, variable): arg_names[i] = function.argname(i) arg_types = [] for arg in ct.positional_arguments: - arg_types.append(to_type(ct.node, arg, ct.variable)) + arg_types.append(to_type(ct.node, arg)) kw_types = [] for name, arg in ct.keyword_arguments: - kw_types.append((name, to_type(ct.node, arg, ct.variable))) + kw_types.append((name, to_type(ct.node, arg))) ret = pytd_utils.JoinTypes(t.to_type(ct.node) for t in ct.return_value.data) starargs = None @@ -661,21 +656,15 @@ def pytd_classes_for_call_traces(self): class_to_records = collections.defaultdict(list) for call_record in self._method_calls: args = call_record.positional_arguments - if call_record.variable: - unknown = False - for arg in args: - if any(isinstance(a, abstract.Unknown) for a in arg.data): - unknown = True - else: - unknown = any(isinstance(arg.data, abstract.Unknown) for arg in args) + unknown = False + for arg in args: + if any(isinstance(a, abstract.Unknown) for a in arg.data): + unknown = True if not unknown: # We don't need to record call signatures that don't involve # unknowns - there's nothing to solve for. continue - if call_record.variable: - classes = args[0].data - else: - classes = [args[0].data] + classes = args[0].data for cls in classes: if isinstance(cls.cls, abstract.PyTDClass): class_to_records[cls].append(call_record) diff --git a/pytype/vm_utils.py b/pytype/vm_utils.py index 285e2af04..214977362 100644 --- a/pytype/vm_utils.py +++ b/pytype/vm_utils.py @@ -701,7 +701,8 @@ def _call_binop_on_bindings(node, name, xval, yval, ctx): args = function.Args(posargs=(right_val.AssignToNewVariable(),)) try: return function.call_function( - ctx, node, attr_var, args, fallback_to_unsolvable=False) + ctx, node, attr_var, args, fallback_to_unsolvable=False, + strict_filter=len(attr_var.bindings) > 1) except (function.DictKeyMissing, function.FailedFunctionCall) as e: # It's possible that this call failed because the function returned # NotImplemented. See, e.g., @@ -753,12 +754,15 @@ def call_binary_operator(state, name, x, y, report_errors, ctx): log.debug("Calling binary operator %s", name) nodes = [] error = None + x = abstract_utils.simplify_variable(x, state.node, ctx) + y = abstract_utils.simplify_variable(y, state.node, ctx) for xval in x.bindings: for yval in y.bindings: try: node, ret = _call_binop_on_bindings(state.node, name, xval, yval, ctx) except (function.DictKeyMissing, function.FailedFunctionCall) as e: - if e > error: + if (report_errors and e > error and + state.node.HasCombination([xval, yval])): error = e else: if ret: @@ -779,14 +783,13 @@ def call_binary_operator(state, name, x, y, report_errors, ctx): log.debug("Result: %r %r", result, result.data) log.debug("Error: %r", error) log.debug("Report Errors: %r", report_errors) - if report_errors and ( - not result.bindings or ctx.options.strict_parameter_checks): + if report_errors: if error is None: if not result.bindings: if ctx.options.report_errors: ctx.errorlog.unsupported_operands(ctx.vm.frames, name, x, y) result = ctx.new_unsolvable(state.node) - else: + elif not result.bindings or ctx.options.strict_parameter_checks: if ctx.options.report_errors: ctx.errorlog.invalid_function_call(ctx.vm.frames, error) state, result = error.get_return(state)