diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index ac954d2fc..5acc80a85 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -30,7 +30,9 @@ ] -def _get_import_alias_names(import_aliases: Sequence[cst.ImportAlias]) -> Set[str]: +def _get_import_alias_names( + import_aliases: Sequence[cst.ImportAlias], +) -> Set[str]: import_names = set() for imported_name in import_aliases: asname = imported_name.asname @@ -41,7 +43,9 @@ def _get_import_alias_names(import_aliases: Sequence[cst.ImportAlias]) -> Set[st return import_names -def _get_import_names(imports: Sequence[Union[cst.Import, cst.ImportFrom]]) -> Set[str]: +def _get_import_names( + imports: Sequence[Union[cst.Import, cst.ImportFrom]], +) -> Set[str]: import_names = set() for _import in imports: if isinstance(_import, cst.Import): @@ -53,17 +57,23 @@ def _get_import_names(imports: Sequence[Union[cst.Import, cst.ImportFrom]]) -> S return import_names -def _is_set(x: Union[None, cst.CSTNode, cst.MaybeSentinel]) -> bool: +def _is_set( + x: Union[None, cst.CSTNode, cst.MaybeSentinel], +) -> bool: return x is not None and x != cst.MaybeSentinel.DEFAULT -def _get_string_value(node: cst.SimpleString) -> str: +def _get_string_value( + node: cst.SimpleString, +) -> str: s = node.value c = s[-1] return s[s.index(c) : -1] -def _find_generic_base(node: cst.ClassDef) -> Optional[cst.Arg]: +def _find_generic_base( + node: cst.ClassDef, +) -> Optional[cst.Arg]: for b in node.bases: if m.matches(b.value, m.Subscript(value=m.Name("Generic"))): return b @@ -79,13 +89,24 @@ class FunctionKey: star_kwarg: bool @classmethod - def make(cls, name: str, params: cst.Parameters) -> "FunctionKey": + def make( + cls, + name: str, + params: cst.Parameters, + ) -> "FunctionKey": pos = len(params.params) kwonly = ",".join(sorted(x.name.value for x in params.kwonly_params)) posonly = len(params.posonly_params) star_arg = _is_set(params.star_arg) star_kwarg = _is_set(params.star_kwarg) - return cls(name, pos, kwonly, posonly, star_arg, star_kwarg) + return cls( + name, + pos, + kwonly, + posonly, + star_arg, + star_kwarg, + ) @dataclass(frozen=True) @@ -104,7 +125,11 @@ class TypeCollector(m.MatcherDecoratableVisitor): QualifiedNameProvider, ) - def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None: + def __init__( + self, + existing_imports: Set[str], + context: CodemodContext, + ) -> None: super().__init__() # Qualifier for storing the canonical name of the current function. self.qualifier: List[str] = [] @@ -118,7 +143,10 @@ def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None: self.typevars: Dict[str, cst.Assign] = {} self.annotation_names: Set[str] = set() - def visit_ClassDef(self, node: cst.ClassDef) -> None: + def visit_ClassDef( + self, + node: cst.ClassDef, + ) -> None: self.qualifier.append(node.name.value) new_bases = [] for base in node.bases: @@ -138,10 +166,16 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: self.class_definitions[node.name.value] = node.with_changes(bases=new_bases) - def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + def leave_ClassDef( + self, + original_node: cst.ClassDef, + ) -> None: self.qualifier.pop() - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + def visit_FunctionDef( + self, + node: cst.FunctionDef, + ) -> bool: self.qualifier.append(node.name.value) returns = node.returns return_annotation = ( @@ -157,10 +191,16 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # pyi files don't support inner functions, return False to stop the traversal. return False - def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + def leave_FunctionDef( + self, + original_node: cst.FunctionDef, + ) -> None: self.qualifier.pop() - def visit_AnnAssign(self, node: cst.AnnAssign) -> bool: + def visit_AnnAssign( + self, + node: cst.AnnAssign, + ) -> bool: name = get_full_name_for_node(node.target) if name is not None: self.qualifier.append(name) @@ -168,18 +208,30 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool: self.attribute_annotations[".".join(self.qualifier)] = annotation_value return True - def leave_AnnAssign(self, original_node: cst.AnnAssign) -> None: + def leave_AnnAssign( + self, + original_node: cst.AnnAssign, + ) -> None: self.qualifier.pop() - def visit_Assign(self, node: cst.Assign) -> None: + def visit_Assign( + self, + node: cst.Assign, + ) -> None: self.current_assign = node - def leave_Assign(self, original_node: cst.Assign) -> None: + def leave_Assign( + self, + original_node: cst.Assign, + ) -> None: self.current_assign = None @m.call_if_inside(m.Assign()) @m.visit(m.Call(func=m.Name("TypeVar"))) - def record_typevar(self, node: cst.Call) -> None: + def record_typevar( + self, + node: cst.Call, + ) -> None: # pyre-ignore current_assign is never None here name = get_full_name_for_node(self.current_assign.targets[0].target) if name: @@ -188,13 +240,19 @@ def record_typevar(self, node: cst.Call) -> None: self._handle_qualification_and_should_qualify("typing.TypeVar") self.current_assign = None - def leave_Module(self, original_node: cst.Module) -> None: + def leave_Module( + self, + original_node: cst.Module, + ) -> None: # Filter out unused typevars self.typevars = { k: v for k, v in self.typevars.items() if k in self.annotation_names } - def _get_unique_qualified_name(self, node: cst.CSTNode) -> str: + def _get_unique_qualified_name( + self, + node: cst.CSTNode, + ) -> str: name = None names = [q.name for q in self.get_metadata(QualifiedNameProvider, node)] if len(names) == 0: @@ -221,7 +279,10 @@ def _get_qualified_name_and_dequalified_node( dequalified_node = node.attr if isinstance(node, cst.Attribute) else node return qualified_name, dequalified_node - def _module_and_target(self, qualified_name: str) -> Tuple[str, str]: + def _module_and_target( + self, + qualified_name: str, + ) -> Tuple[str, str]: relative_prefix = "" while qualified_name.startswith("."): relative_prefix += "." @@ -233,7 +294,10 @@ def _module_and_target(self, qualified_name: str) -> Tuple[str, str]: qualifier, target = split return (relative_prefix + qualifier, target) - def _handle_qualification_and_should_qualify(self, qualified_name: str) -> bool: + def _handle_qualification_and_should_qualify( + self, + qualified_name: str, + ) -> bool: """ Based on a qualified name and the existing module imports, record that we need to add an import if necessary and return whether or not we @@ -248,7 +312,11 @@ def _handle_qualification_and_should_qualify(self, qualified_name: str) -> bool: elif module in self.existing_imports: return True else: - AddImportsVisitor.add_needed_import(self.context, module, target) + AddImportsVisitor.add_needed_import( + self.context, + module, + target, + ) return False return False @@ -274,7 +342,10 @@ def _handle_NameOrAttribute( else: return dequalified_node - def _handle_Index(self, slice: cst.Index) -> cst.Index: + def _handle_Index( + self, + slice: cst.Index, + ) -> cst.Index: value = slice.value if isinstance(value, cst.Subscript): return slice.with_changes(value=self._handle_Subscript(value)) @@ -285,7 +356,10 @@ def _handle_Index(self, slice: cst.Index) -> cst.Index: self.annotation_names.add(_get_string_value(value)) return slice - def _handle_Subscript(self, node: cst.Subscript) -> cst.Subscript: + def _handle_Subscript( + self, + node: cst.Subscript, + ) -> cst.Subscript: value = node.value if isinstance(value, NAME_OR_ATTRIBUTE): new_node = node.with_changes(value=self._handle_NameOrAttribute(value)) @@ -320,7 +394,10 @@ def _handle_Subscript(self, node: cst.Subscript) -> cst.Subscript: else: return new_node - def _handle_Annotation(self, annotation: cst.Annotation) -> cst.Annotation: + def _handle_Annotation( + self, + annotation: cst.Annotation, + ) -> cst.Annotation: node = annotation.annotation if isinstance(node, cst.SimpleString): self.annotation_names.add(_get_string_value(node)) @@ -332,8 +409,13 @@ def _handle_Annotation(self, annotation: cst.Annotation) -> cst.Annotation: else: raise ValueError(f"Unexpected annotation node: {node}") - def _handle_Parameters(self, parameters: cst.Parameters) -> cst.Parameters: - def update_annotations(parameters: Sequence[cst.Param]) -> List[cst.Param]: + def _handle_Parameters( + self, + parameters: cst.Parameters, + ) -> cst.Parameters: + def update_annotations( + parameters: Sequence[cst.Param], + ) -> List[cst.Param]: updated_parameters = [] for parameter in list(parameters): annotation = parameter.annotation @@ -477,7 +559,10 @@ def store_stub_in_context( strict_annotation_matching, ) - def transform_module_impl(self, tree: cst.Module) -> cst.Module: + def transform_module_impl( + self, + tree: cst.Module, + ) -> cst.Module: """ Collect type annotations from all stubs and apply them to ``tree``. @@ -520,7 +605,13 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: tree_with_imports = AddImportsVisitor( context=self.context, imports=( - [ImportItem("__future__", "annotations", None)] + [ + ImportItem( + "__future__", + "annotations", + None, + ) + ] if self.use_future_annotations else () ), @@ -545,7 +636,11 @@ def _apply_annotation_to_attribute_or_global( self.annotation_counts.global_annotations += 1 else: self.annotation_counts.attribute_annotations += 1 - return cst.AnnAssign(cst.Name(name), annotation, value) + return cst.AnnAssign( + cst.Name(name), + annotation, + value, + ) def _apply_annotation_to_parameter( self, @@ -571,7 +666,9 @@ def _qualifier_name(self) -> str: return ".".join(self.qualifier) def _annotate_single_target( - self, node: cst.Assign, updated_node: cst.Assign + self, + node: cst.Assign, + updated_node: cst.Assign, ) -> Union[cst.Assign, cst.AnnAssign]: only_target = node.targets[0].target if isinstance(only_target, (cst.Tuple, cst.List)): @@ -604,7 +701,9 @@ def _annotate_single_target( return updated_node def _split_module( - self, module: cst.Module, updated_module: cst.Module + self, + module: cst.Module, + updated_module: cst.Module, ) -> Tuple[ List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]], List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]], @@ -627,7 +726,10 @@ def _split_module( list(updated_module.body[import_add_location:]), ) - def _add_to_toplevel_annotations(self, name: str) -> None: + def _add_to_toplevel_annotations( + self, + name: str, + ) -> None: self.qualifier.append(name) if self._qualifier_name() in self.annotations.attribute_annotations: annotation = self.annotations.attribute_annotations[self._qualifier_name()] @@ -635,7 +737,9 @@ def _add_to_toplevel_annotations(self, name: str) -> None: self.qualifier.pop() def _update_parameters( - self, annotations: FunctionAnnotation, updated_node: cst.FunctionDef + self, + annotations: FunctionAnnotation, + updated_node: cst.FunctionDef, ) -> cst.Parameters: # Update params and default params with annotations # Don't override existing annotations or default values unless asked @@ -716,7 +820,8 @@ def _match_signatures( # noqa: C901: Too complex """Check that function annotations on both signatures are compatible.""" def compatible( - p: Optional[cst.Annotation], q: Optional[cst.Annotation] + p: Optional[cst.Annotation], + q: Optional[cst.Annotation], ) -> bool: if self.overwrite_existing_annotations or not _is_set(p) or not _is_set(q): return True @@ -726,7 +831,10 @@ def compatible( return True return p.annotation.deep_equals(q.annotation) # pyre-ignore[16] - def match_posargs(ps: Sequence[cst.Param], qs: Sequence[cst.Param]) -> bool: + def match_posargs( + ps: Sequence[cst.Param], + qs: Sequence[cst.Param], + ) -> bool: if len(ps) != len(qs): return False for p, q in zip(ps, qs): @@ -736,7 +844,10 @@ def match_posargs(ps: Sequence[cst.Param], qs: Sequence[cst.Param]) -> bool: return False return True - def match_kwargs(ps: Sequence[cst.Param], qs: Sequence[cst.Param]) -> bool: + def match_kwargs( + ps: Sequence[cst.Param], + qs: Sequence[cst.Param], + ) -> bool: ps_dict = {x.name.value: x for x in ps} qs_dict = {x.name.value: x for x in qs} if set(ps_dict.keys()) != set(qs_dict.keys()): @@ -746,10 +857,16 @@ def match_kwargs(ps: Sequence[cst.Param], qs: Sequence[cst.Param]) -> bool: return False return True - def match_star(p: StarParamType, q: StarParamType) -> bool: + def match_star( + p: StarParamType, + q: StarParamType, + ) -> bool: return _is_set(p) == _is_set(q) - def match_params(f: cst.FunctionDef, g: FunctionAnnotation) -> bool: + def match_params( + f: cst.FunctionDef, + g: FunctionAnnotation, + ) -> bool: p, q = f.params, g.parameters return ( match_posargs(p.params, q.params) @@ -759,7 +876,10 @@ def match_params(f: cst.FunctionDef, g: FunctionAnnotation) -> bool: and match_star(p.star_kwarg, q.star_kwarg) ) - def match_return(f: cst.FunctionDef, g: FunctionAnnotation) -> bool: + def match_return( + f: cst.FunctionDef, + g: FunctionAnnotation, + ) -> bool: return compatible(f.returns, g.returns) return match_params(function, annotations) and match_return( @@ -768,12 +888,17 @@ def match_return(f: cst.FunctionDef, g: FunctionAnnotation) -> bool: # transform API methods - def visit_ClassDef(self, node: cst.ClassDef) -> None: + def visit_ClassDef( + self, + node: cst.ClassDef, + ) -> None: self.qualifier.append(node.name.value) self.visited_classes.add(node.name.value) def leave_ClassDef( - self, original_node: cst.ClassDef, updated_node: cst.ClassDef + self, + original_node: cst.ClassDef, + updated_node: cst.ClassDef, ) -> cst.ClassDef: cls_name = ".".join(self.qualifier) self.qualifier.pop() @@ -787,13 +912,18 @@ def leave_ClassDef( return updated_node.with_changes(bases=new_bases) return updated_node - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + def visit_FunctionDef( + self, + node: cst.FunctionDef, + ) -> bool: self.qualifier.append(node.name.value) # pyi files don't support inner functions, return False to stop the traversal. return False def leave_FunctionDef( - self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + self, + original_node: cst.FunctionDef, + updated_node: cst.FunctionDef, ) -> cst.FunctionDef: key = FunctionKey.make(self._qualifier_name(), updated_node.params) self.qualifier.pop() @@ -818,12 +948,18 @@ def leave_FunctionDef( return updated_node.with_changes(params=new_parameters) return updated_node - def visit_Assign(self, node: cst.Assign) -> None: + def visit_Assign( + self, + node: cst.Assign, + ) -> None: self.current_assign = node @m.call_if_inside(m.Assign()) @m.visit(m.Call(func=m.Name("TypeVar"))) - def record_typevar(self, node: cst.Call) -> None: + def record_typevar( + self, + node: cst.Call, + ) -> None: # pyre-ignore current_assign is never None here name = get_full_name_for_node(self.current_assign.targets[0].target) if name: @@ -836,7 +972,9 @@ def record_typevar(self, node: cst.Call) -> None: self.current_assign = None def leave_Assign( - self, original_node: cst.Assign, updated_node: cst.Assign + self, + original_node: cst.Assign, + updated_node: cst.Assign, ) -> Union[cst.Assign, cst.AnnAssign]: self.current_assign = None @@ -855,13 +993,17 @@ def leave_Assign( return self._annotate_single_target(original_node, updated_node) def leave_ImportFrom( - self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom + self, + original_node: cst.ImportFrom, + updated_node: cst.ImportFrom, ) -> cst.ImportFrom: self.import_statements.append(original_node) return updated_node def leave_Module( - self, original_node: cst.Module, updated_node: cst.Module + self, + original_node: cst.Module, + updated_node: cst.Module, ) -> cst.Module: fresh_class_definitions = [ definition