Skip to content

Commit

Permalink
change is_name_attr to accept multiple modules
Browse files Browse the repository at this point in the history
  • Loading branch information
asottile committed Apr 6, 2022
1 parent 24f1e90 commit 643db55
Show file tree
Hide file tree
Showing 11 changed files with 61 additions and 50 deletions.
6 changes: 3 additions & 3 deletions pyupgrade/_ast_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ def ast_to_offset(node: ast.expr | ast.stmt) -> Offset:
def is_name_attr(
node: ast.AST,
imports: dict[str, set[str]],
mod: str,
mods: tuple[str, ...],
names: Container[str],
) -> bool:
return (
isinstance(node, ast.Name) and
node.id in names and
node.id in imports[mod]
any(node.id in imports[mod] for mod in mods)
) or (
isinstance(node, ast.Attribute) and
isinstance(node.value, ast.Name) and
node.value.id == mod and
node.value.id in mods and
node.attr in names
)

Expand Down
37 changes: 14 additions & 23 deletions pyupgrade/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._ast_helpers import contains_await
from pyupgrade._ast_helpers import has_starargs
from pyupgrade._ast_helpers import is_name_attr
from pyupgrade._data import FUNCS
from pyupgrade._data import Settings
from pyupgrade._data import Version
Expand Down Expand Up @@ -550,21 +551,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self._from_imports[node.module].add(name.name)
self.generic_visit(node)

def _is_attr(self, node: ast.AST, mods: set[str], name: str) -> bool:
return (
(
isinstance(node, ast.Name) and
node.id == name and
any(name in self._from_imports[mod] for mod in mods)
) or
(
isinstance(node, ast.Attribute) and
node.attr == name and
isinstance(node.value, ast.Name) and
node.value.id in mods
)
)

def _parse(self, node: ast.Call) -> tuple[DotFormatPart, ...] | None:
if not (
isinstance(node.func, ast.Attribute) and
Expand Down Expand Up @@ -623,8 +609,11 @@ def visit_Assign(self, node: ast.Assign) -> None:
not has_starargs(node.value)
):
if (
self._is_attr(
node.value.func, {'typing'}, 'NamedTuple',
is_name_attr(
node.value.func,
self._from_imports,
('typing',),
('NamedTuple',),
) and
len(node.value.args) == 2 and
not node.value.keywords and
Expand All @@ -641,10 +630,11 @@ def visit_Assign(self, node: ast.Assign) -> None:
):
self.named_tuples[ast_to_offset(node)] = node.value
elif (
self._is_attr(
is_name_attr(
node.value.func,
{'typing', 'typing_extensions'},
'TypedDict',
self._from_imports,
('typing', 'typing_extensions'),
('TypedDict',),
) and
len(node.value.args) == 1 and
len(node.value.keywords) > 0 and
Expand All @@ -655,10 +645,11 @@ def visit_Assign(self, node: ast.Assign) -> None:
):
self.kw_typed_dicts[ast_to_offset(node)] = node.value
elif (
self._is_attr(
is_name_attr(
node.value.func,
{'typing', 'typing_extensions'},
'TypedDict',
self._from_imports,
('typing', 'typing_extensions'),
('TypedDict',),
) and
len(node.value.args) == 2 and
(
Expand Down
2 changes: 1 addition & 1 deletion pyupgrade/_plugins/lru_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def visit_Call(
is_name_attr(
node.func,
state.from_imports,
'functools',
('functools',),
('lru_cache',),
)
):
Expand Down
2 changes: 1 addition & 1 deletion pyupgrade/_plugins/native_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def is_a_native_literal_call(
) -> bool:
return (
(
is_name_attr(node.func, from_imports, 'six', SIX_NATIVE_STR) or
is_name_attr(node.func, from_imports, ('six',), SIX_NATIVE_STR) or
isinstance(node.func, ast.Name) and node.func.id == 'str'
) and
not node.keywords and
Expand Down
2 changes: 1 addition & 1 deletion pyupgrade/_plugins/six_base_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ def visit_ClassDef(
) -> Iterable[tuple[Offset, TokenFunc]]:
if state.settings.min_version >= (3,):
for base in node.bases:
if is_name_attr(base, state.from_imports, 'six', ('Iterator',)):
if is_name_attr(base, state.from_imports, ('six',), ('Iterator',)):
yield ast_to_offset(base), remove_base_class
19 changes: 12 additions & 7 deletions pyupgrade/_plugins/six_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def visit_Call(
is_name_attr(
node.func,
state.from_imports,
'six',
('six',),
('iteritems', 'iterkeys', 'itervalues'),
) and
node.args and
Expand All @@ -102,7 +102,12 @@ def visit_Call(
)
yield ast_to_offset(node), func
elif (
is_name_attr(node.func, state.from_imports, 'six', SIX_CALLS) and
is_name_attr(
node.func,
state.from_imports,
('six',),
SIX_CALLS,
) and
node.args and
not has_starargs(node)
):
Expand All @@ -120,7 +125,7 @@ def visit_Call(
is_name_attr(
node.func,
state.from_imports,
'six',
('six',),
('int2byte',),
) and
node.args and
Expand All @@ -136,7 +141,7 @@ def visit_Call(
is_name_attr(
node.func,
state.from_imports,
'six',
('six',),
('b', 'ensure_binary'),
) and
not node.keywords and
Expand All @@ -150,7 +155,7 @@ def visit_Call(
is_name_attr(
node.func,
state.from_imports,
'six',
('six',),
('raise_from',),
) and
node.args and
Expand All @@ -166,7 +171,7 @@ def visit_Call(
is_name_attr(
node.func,
state.from_imports,
'six',
('six',),
('reraise',),
)
):
Expand Down Expand Up @@ -198,7 +203,7 @@ def visit_Call(
is_name_attr(
node.args[0].value.func,
state.from_imports,
'sys',
('sys',),
('exc_info',),
)
):
Expand Down
4 changes: 2 additions & 2 deletions pyupgrade/_plugins/six_metaclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def visit_ClassDef(
is_name_attr(
decorator.func,
state.from_imports,
'six',
('six',),
('add_metaclass',),
) and
not has_starargs(decorator)
Expand All @@ -98,7 +98,7 @@ def visit_ClassDef(
is_name_attr(
node.bases[0].func,
state.from_imports,
'six',
('six',),
('with_metaclass',),
) and
not has_starargs(node.bases[0])
Expand Down
2 changes: 1 addition & 1 deletion pyupgrade/_plugins/six_remove_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def visit_ClassDef(
if is_name_attr(
decorator,
state.from_imports,
'six',
('six',),
('python_2_unicode_compatible',),
):
yield ast_to_offset(decorator), remove_decorator
6 changes: 3 additions & 3 deletions pyupgrade/_plugins/subprocess_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def visit_Call(
is_name_attr(
node.func,
state.from_imports,
'subprocess',
('subprocess',),
('run',),
)
):
Expand All @@ -84,14 +84,14 @@ def visit_Call(
if keyword.arg == 'stdout' and is_name_attr(
keyword.value,
state.from_imports,
'subprocess',
('subprocess',),
('PIPE',),
):
stdout_idx = n
elif keyword.arg == 'stderr' and is_name_attr(
keyword.value,
state.from_imports,
'subprocess',
('subprocess',),
('PIPE',),
):
stderr_idx = n
Expand Down
9 changes: 7 additions & 2 deletions pyupgrade/_plugins/typing_pep604.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,14 @@ def visit_Subscript(
):
return

if is_name_attr(node.value, state.from_imports, 'typing', ('Optional',)):
if is_name_attr(
node.value,
state.from_imports,
('typing',),
('Optional',),
):
yield ast_to_offset(node), _fix_optional
elif is_name_attr(node.value, state.from_imports, 'typing', ('Union',)):
elif is_name_attr(node.value, state.from_imports, ('typing',), ('Union',)):
if sys.version_info >= (3, 9): # pragma: >=3.9 cover
node_slice = node.slice
elif isinstance(node.slice, ast.Index): # pragma: <3.9 cover
Expand Down
22 changes: 16 additions & 6 deletions pyupgrade/_plugins/versioned_branches.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,20 @@ def visit_If(
if (
min_version >= (3,) and (
# if six.PY2:
is_name_attr(node.test, state.from_imports, 'six', ('PY2',)) or
is_name_attr(
node.test,
state.from_imports,
('six',),
('PY2',),
) or
# if not six.PY3:
(
isinstance(node.test, ast.UnaryOp) and
isinstance(node.test.op, ast.Not) and
is_name_attr(
node.test.operand,
state.from_imports,
'six',
('six',),
('PY3',),
)
) or
Expand All @@ -131,7 +136,7 @@ def visit_If(
is_name_attr(
node.test.left,
state.from_imports,
'sys',
('sys',),
('version_info',),
) and
len(node.test.ops) == 1 and (
Expand All @@ -150,15 +155,20 @@ def visit_If(
elif (
min_version >= (3,) and (
# if six.PY3:
is_name_attr(node.test, state.from_imports, 'six', ('PY3',)) or
is_name_attr(
node.test,
state.from_imports,
('six',),
('PY3',),
) or
# if not six.PY2:
(
isinstance(node.test, ast.UnaryOp) and
isinstance(node.test.op, ast.Not) and
is_name_attr(
node.test.operand,
state.from_imports,
'six',
('six',),
('PY2',),
)
) or
Expand All @@ -170,7 +180,7 @@ def visit_If(
is_name_attr(
node.test.left,
state.from_imports,
'sys',
('sys',),
('version_info',),
) and
len(node.test.ops) == 1 and (
Expand Down

0 comments on commit 643db55

Please sign in to comment.