diff --git a/pyupgrade/_ast_helpers.py b/pyupgrade/_ast_helpers.py index d5f8a2d5..9bcd9816 100644 --- a/pyupgrade/_ast_helpers.py +++ b/pyupgrade/_ast_helpers.py @@ -21,17 +21,17 @@ def ast_to_offset(node: ast.expr | ast.stmt) -> Offset: def is_name_attr( node: ast.AST, imports: dict[str, set[str]], - mod: str, + mods: tuple[str, ...], names: Container[str], ) -> bool: return ( isinstance(node, ast.Name) and node.id in names and - node.id in imports[mod] + any(node.id in imports[mod] for mod in mods) ) or ( isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name) and - node.value.id == mod and + node.value.id in mods and node.attr in names ) diff --git a/pyupgrade/_main.py b/pyupgrade/_main.py index c0b03c3a..b1fc43b1 100644 --- a/pyupgrade/_main.py +++ b/pyupgrade/_main.py @@ -26,6 +26,7 @@ from pyupgrade._ast_helpers import ast_to_offset from pyupgrade._ast_helpers import contains_await from pyupgrade._ast_helpers import has_starargs +from pyupgrade._ast_helpers import is_name_attr from pyupgrade._data import FUNCS from pyupgrade._data import Settings from pyupgrade._data import Version @@ -550,21 +551,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: self._from_imports[node.module].add(name.name) self.generic_visit(node) - def _is_attr(self, node: ast.AST, mods: set[str], name: str) -> bool: - return ( - ( - isinstance(node, ast.Name) and - node.id == name and - any(name in self._from_imports[mod] for mod in mods) - ) or - ( - isinstance(node, ast.Attribute) and - node.attr == name and - isinstance(node.value, ast.Name) and - node.value.id in mods - ) - ) - def _parse(self, node: ast.Call) -> tuple[DotFormatPart, ...] | None: if not ( isinstance(node.func, ast.Attribute) and @@ -623,8 +609,11 @@ def visit_Assign(self, node: ast.Assign) -> None: not has_starargs(node.value) ): if ( - self._is_attr( - node.value.func, {'typing'}, 'NamedTuple', + is_name_attr( + node.value.func, + self._from_imports, + ('typing',), + ('NamedTuple',), ) and len(node.value.args) == 2 and not node.value.keywords and @@ -641,10 +630,11 @@ def visit_Assign(self, node: ast.Assign) -> None: ): self.named_tuples[ast_to_offset(node)] = node.value elif ( - self._is_attr( + is_name_attr( node.value.func, - {'typing', 'typing_extensions'}, - 'TypedDict', + self._from_imports, + ('typing', 'typing_extensions'), + ('TypedDict',), ) and len(node.value.args) == 1 and len(node.value.keywords) > 0 and @@ -655,10 +645,11 @@ def visit_Assign(self, node: ast.Assign) -> None: ): self.kw_typed_dicts[ast_to_offset(node)] = node.value elif ( - self._is_attr( + is_name_attr( node.value.func, - {'typing', 'typing_extensions'}, - 'TypedDict', + self._from_imports, + ('typing', 'typing_extensions'), + ('TypedDict',), ) and len(node.value.args) == 2 and ( diff --git a/pyupgrade/_plugins/lru_cache.py b/pyupgrade/_plugins/lru_cache.py index b0fff959..bccdf701 100644 --- a/pyupgrade/_plugins/lru_cache.py +++ b/pyupgrade/_plugins/lru_cache.py @@ -64,7 +64,7 @@ def visit_Call( is_name_attr( node.func, state.from_imports, - 'functools', + ('functools',), ('lru_cache',), ) ): diff --git a/pyupgrade/_plugins/native_literals.py b/pyupgrade/_plugins/native_literals.py index 21f2ac6d..fce35cbe 100644 --- a/pyupgrade/_plugins/native_literals.py +++ b/pyupgrade/_plugins/native_literals.py @@ -36,7 +36,7 @@ def is_a_native_literal_call( ) -> bool: return ( ( - is_name_attr(node.func, from_imports, 'six', SIX_NATIVE_STR) or + is_name_attr(node.func, from_imports, ('six',), SIX_NATIVE_STR) or isinstance(node.func, ast.Name) and node.func.id == 'str' ) and not node.keywords and diff --git a/pyupgrade/_plugins/six_base_classes.py b/pyupgrade/_plugins/six_base_classes.py index 9a7a0e65..f24f3704 100644 --- a/pyupgrade/_plugins/six_base_classes.py +++ b/pyupgrade/_plugins/six_base_classes.py @@ -21,5 +21,5 @@ def visit_ClassDef( ) -> Iterable[tuple[Offset, TokenFunc]]: if state.settings.min_version >= (3,): for base in node.bases: - if is_name_attr(base, state.from_imports, 'six', ('Iterator',)): + if is_name_attr(base, state.from_imports, ('six',), ('Iterator',)): yield ast_to_offset(base), remove_base_class diff --git a/pyupgrade/_plugins/six_calls.py b/pyupgrade/_plugins/six_calls.py index 79b2a66d..4d149da8 100644 --- a/pyupgrade/_plugins/six_calls.py +++ b/pyupgrade/_plugins/six_calls.py @@ -86,7 +86,7 @@ def visit_Call( is_name_attr( node.func, state.from_imports, - 'six', + ('six',), ('iteritems', 'iterkeys', 'itervalues'), ) and node.args and @@ -102,7 +102,12 @@ def visit_Call( ) yield ast_to_offset(node), func elif ( - is_name_attr(node.func, state.from_imports, 'six', SIX_CALLS) and + is_name_attr( + node.func, + state.from_imports, + ('six',), + SIX_CALLS, + ) and node.args and not has_starargs(node) ): @@ -120,7 +125,7 @@ def visit_Call( is_name_attr( node.func, state.from_imports, - 'six', + ('six',), ('int2byte',), ) and node.args and @@ -136,7 +141,7 @@ def visit_Call( is_name_attr( node.func, state.from_imports, - 'six', + ('six',), ('b', 'ensure_binary'), ) and not node.keywords and @@ -150,7 +155,7 @@ def visit_Call( is_name_attr( node.func, state.from_imports, - 'six', + ('six',), ('raise_from',), ) and node.args and @@ -166,7 +171,7 @@ def visit_Call( is_name_attr( node.func, state.from_imports, - 'six', + ('six',), ('reraise',), ) ): @@ -198,7 +203,7 @@ def visit_Call( is_name_attr( node.args[0].value.func, state.from_imports, - 'sys', + ('sys',), ('exc_info',), ) ): diff --git a/pyupgrade/_plugins/six_metaclasses.py b/pyupgrade/_plugins/six_metaclasses.py index 7692df6a..ea7f164f 100644 --- a/pyupgrade/_plugins/six_metaclasses.py +++ b/pyupgrade/_plugins/six_metaclasses.py @@ -85,7 +85,7 @@ def visit_ClassDef( is_name_attr( decorator.func, state.from_imports, - 'six', + ('six',), ('add_metaclass',), ) and not has_starargs(decorator) @@ -98,7 +98,7 @@ def visit_ClassDef( is_name_attr( node.bases[0].func, state.from_imports, - 'six', + ('six',), ('with_metaclass',), ) and not has_starargs(node.bases[0]) diff --git a/pyupgrade/_plugins/six_remove_decorators.py b/pyupgrade/_plugins/six_remove_decorators.py index 6b75102d..f1ac1483 100644 --- a/pyupgrade/_plugins/six_remove_decorators.py +++ b/pyupgrade/_plugins/six_remove_decorators.py @@ -24,7 +24,7 @@ def visit_ClassDef( if is_name_attr( decorator, state.from_imports, - 'six', + ('six',), ('python_2_unicode_compatible',), ): yield ast_to_offset(decorator), remove_decorator diff --git a/pyupgrade/_plugins/subprocess_run.py b/pyupgrade/_plugins/subprocess_run.py index 53cc03d3..9d1a3f2c 100644 --- a/pyupgrade/_plugins/subprocess_run.py +++ b/pyupgrade/_plugins/subprocess_run.py @@ -72,7 +72,7 @@ def visit_Call( is_name_attr( node.func, state.from_imports, - 'subprocess', + ('subprocess',), ('run',), ) ): @@ -84,14 +84,14 @@ def visit_Call( if keyword.arg == 'stdout' and is_name_attr( keyword.value, state.from_imports, - 'subprocess', + ('subprocess',), ('PIPE',), ): stdout_idx = n elif keyword.arg == 'stderr' and is_name_attr( keyword.value, state.from_imports, - 'subprocess', + ('subprocess',), ('PIPE',), ): stderr_idx = n diff --git a/pyupgrade/_plugins/typing_pep604.py b/pyupgrade/_plugins/typing_pep604.py index bb33d141..1a2ec3e6 100644 --- a/pyupgrade/_plugins/typing_pep604.py +++ b/pyupgrade/_plugins/typing_pep604.py @@ -155,9 +155,14 @@ def visit_Subscript( ): return - if is_name_attr(node.value, state.from_imports, 'typing', ('Optional',)): + if is_name_attr( + node.value, + state.from_imports, + ('typing',), + ('Optional',), + ): yield ast_to_offset(node), _fix_optional - elif is_name_attr(node.value, state.from_imports, 'typing', ('Union',)): + elif is_name_attr(node.value, state.from_imports, ('typing',), ('Union',)): if sys.version_info >= (3, 9): # pragma: >=3.9 cover node_slice = node.slice elif isinstance(node.slice, ast.Index): # pragma: <3.9 cover diff --git a/pyupgrade/_plugins/versioned_branches.py b/pyupgrade/_plugins/versioned_branches.py index cacf2f2d..eead85b6 100644 --- a/pyupgrade/_plugins/versioned_branches.py +++ b/pyupgrade/_plugins/versioned_branches.py @@ -112,7 +112,12 @@ def visit_If( if ( min_version >= (3,) and ( # if six.PY2: - is_name_attr(node.test, state.from_imports, 'six', ('PY2',)) or + is_name_attr( + node.test, + state.from_imports, + ('six',), + ('PY2',), + ) or # if not six.PY3: ( isinstance(node.test, ast.UnaryOp) and @@ -120,7 +125,7 @@ def visit_If( is_name_attr( node.test.operand, state.from_imports, - 'six', + ('six',), ('PY3',), ) ) or @@ -131,7 +136,7 @@ def visit_If( is_name_attr( node.test.left, state.from_imports, - 'sys', + ('sys',), ('version_info',), ) and len(node.test.ops) == 1 and ( @@ -150,7 +155,12 @@ def visit_If( elif ( min_version >= (3,) and ( # if six.PY3: - is_name_attr(node.test, state.from_imports, 'six', ('PY3',)) or + is_name_attr( + node.test, + state.from_imports, + ('six',), + ('PY3',), + ) or # if not six.PY2: ( isinstance(node.test, ast.UnaryOp) and @@ -158,7 +168,7 @@ def visit_If( is_name_attr( node.test.operand, state.from_imports, - 'six', + ('six',), ('PY2',), ) ) or @@ -170,7 +180,7 @@ def visit_If( is_name_attr( node.test.left, state.from_imports, - 'sys', + ('sys',), ('version_info',), ) and len(node.test.ops) == 1 and (