Skip to content

Commit

Permalink
Add test skip decorator fixer (#498)
Browse files Browse the repository at this point in the history
Start of #364.
  • Loading branch information
adamchainz authored Oct 10, 2024
1 parent 3b1421f commit 666a6ce
Show file tree
Hide file tree
Showing 6 changed files with 294 additions and 37 deletions.
20 changes: 20 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,26 @@ Changelog

Thanks to Tobias Funke for the report in `Issue #495 <https://github.com/adamchainz/django-upgrade/issues/495>`__.

* Add all-version fixer to remove outdated test skip decorators.
For example:

.. code-block:: diff
import django
from django.test import TestCase
class ExampleTests(TestCase):
- @unittest.skipIf(django.VERSION < (5, 1), "Django 5.1+")
def test_one(self):
...
- @unittest.skipUnless(django.VERSION >= (5, 1), "Django 5.1+")
def test_two(self):
...
`Issue #364 <https://github.com/adamchainz/django-upgrade/issues/364>`__.

* Drop Python 3.8 support.

* Support Python 3.13.
Expand Down
40 changes: 40 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,46 @@ A single ``else`` block may be present, but ``elif`` is not supported.
See also `pyupgrade’s similar feature <https://github.com/asottile/pyupgrade/#python2-and-old-python3x-blocks>`__ that removes outdated code from checks on the Python version.

Versioned test skip decorators
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

**Name:** ``versioned_test_skip_decorators``

Removes outdated test skip decorators that compare to ``django.VERSION``.
Like the above, it requires comparisons of the form:

.. code-block:: text
django.VERSION <comparator> (<X>, <Y>)
Supports these test skip decorators:

* |unittest.skipIf|__:

.. |unittest.skipIf| replace:: ``@unittest.skipIf``
__ https://docs.python.org/3/library/unittest.html#unittest.skipIf

* |unittest.skipUnless|__:

.. |unittest.skipUnless| replace:: ``@unittest.skipUnless``
__ https://docs.python.org/3/library/unittest.html#unittest.skipUnless

For example:

.. code-block:: diff
import django
from django.test import TestCase
class ExampleTests(TestCase):
- @unittest.skipIf(django.VERSION < (5, 1), "Django 5.1+")
def test_one(self):
...
- @unittest.skipUnless(django.VERSION >= (5, 1), "Django 5.1+")
def test_two(self):
...
Django 5.1
----------

Expand Down
42 changes: 42 additions & 0 deletions src/django_upgrade/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import ast
import warnings
from typing import Literal
from typing import cast

from tokenize_rt import Offset

from django_upgrade.data import State


def ast_parse(contents_text: str) -> ast.Module:
# intentionally ignore warnings, we can't do anything about them
Expand Down Expand Up @@ -50,3 +53,42 @@ def looks_like_test_client_call(
and isinstance(node.func.value.value, ast.Name)
and node.func.value.value.id == "self"
)


def is_passing_comparison(
test: ast.Compare, state: State
) -> Literal["pass", "fail", None]:
"""
Return whether the given ast.Compare node compares a version tuple with
django.VERSION and would pass or fail for the current target version, or
None if no match or cannot determine.
"""
if not (
isinstance(left := test.left, ast.Attribute)
and isinstance(left.value, ast.Name)
and left.value.id == "django"
and left.attr == "VERSION"
and len(test.ops) == 1
and isinstance(test.ops[0], (ast.Gt, ast.GtE, ast.Lt, ast.LtE))
and len(test.comparators) == 1
and isinstance((comparator := test.comparators[0]), ast.Tuple)
and len(comparator.elts) == 2
and all(isinstance(e, ast.Constant) for e in comparator.elts)
and all(isinstance(cast(ast.Constant, e).value, int) for e in comparator.elts)
):
return None

min_version = tuple(cast(ast.Constant, e).value for e in comparator.elts)
if isinstance(test.ops[0], ast.Gt):
if state.settings.target_version > min_version:
return "pass"
elif isinstance(test.ops[0], ast.GtE):
if state.settings.target_version >= min_version:
return "pass"
elif isinstance(test.ops[0], ast.Lt):
if state.settings.target_version >= min_version:
return "fail"
else: # ast.LtE
if state.settings.target_version > min_version:
return "fail"
return None
42 changes: 5 additions & 37 deletions src/django_upgrade/fixers/versioned_branches.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from collections.abc import Iterable
from functools import partial
from typing import Literal
from typing import cast

from tokenize_rt import Offset
from tokenize_rt import Token

from django_upgrade.ast import ast_start_offset
from django_upgrade.ast import is_passing_comparison
from django_upgrade.data import Fixer
from django_upgrade.data import State
from django_upgrade.data import TokenFunc
Expand All @@ -37,52 +37,20 @@ def visit_If(
) -> Iterable[tuple[Offset, TokenFunc]]:
if (
isinstance(node.test, ast.Compare)
and isinstance(left := node.test.left, ast.Attribute)
and isinstance(left.value, ast.Name)
and left.value.id == "django"
and left.attr == "VERSION"
and (keep_branch := _is_passing_comparison(node.test, state)) is not None
and (pass_fail := is_passing_comparison(node.test, state)) is not None
and (
# do not handle 'if ... elif ...'
not node.orelse
or not isinstance(node.orelse[0], ast.If)
)
):
yield ast_start_offset(node), partial(
_fix_block, node=node, keep_branch=keep_branch
_fix_block,
node=node,
keep_branch=("first" if pass_fail == "pass" else "second"),
)


def _is_passing_comparison(
test: ast.Compare, state: State
) -> Literal["first", "second", None]:
if not (
len(test.ops) == 1
and isinstance(test.ops[0], (ast.Gt, ast.GtE, ast.Lt, ast.LtE))
and len(test.comparators) == 1
and isinstance((comparator := test.comparators[0]), ast.Tuple)
and len(comparator.elts) == 2
and all(isinstance(e, ast.Constant) for e in comparator.elts)
and all(isinstance(cast(ast.Constant, e).value, int) for e in comparator.elts)
):
return None

min_version = tuple(cast(ast.Constant, e).value for e in comparator.elts)
if isinstance(test.ops[0], ast.Gt):
if state.settings.target_version > min_version:
return "first"
elif isinstance(test.ops[0], ast.GtE):
if state.settings.target_version >= min_version:
return "first"
elif isinstance(test.ops[0], ast.Lt):
if state.settings.target_version >= min_version:
return "second"
else: # ast.LtE
if state.settings.target_version > min_version:
return "second"
return None


def _fix_block(
tokens: list[Token],
i: int,
Expand Down
60 changes: 60 additions & 0 deletions src/django_upgrade/fixers/versioned_test_skip_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

import ast
from collections.abc import Iterable
from functools import partial

from tokenize_rt import Offset
from tokenize_rt import Token

from django_upgrade.ast import ast_start_offset
from django_upgrade.ast import is_passing_comparison
from django_upgrade.data import Fixer
from django_upgrade.data import State
from django_upgrade.data import TokenFunc
from django_upgrade.tokens import OP
from django_upgrade.tokens import erase_node
from django_upgrade.tokens import reverse_find

fixer = Fixer(
__name__,
min_version=(0, 0),
)


@fixer.register(ast.FunctionDef)
def visit_FunctionDef(
state: State,
node: ast.FunctionDef,
parents: tuple[ast.AST, ...],
) -> Iterable[tuple[Offset, TokenFunc]]:
for decorator in node.decorator_list:
if (
isinstance(decorator, ast.Call)
and isinstance(decorator.func, ast.Attribute)
and isinstance(decorator.func.value, ast.Name)
and decorator.func.value.id == "unittest"
and decorator.func.attr in ("skipIf", "skipUnless")
and len(decorator.args) == 2
and isinstance(decorator.args[0], ast.Compare)
and (
(pass_fail := is_passing_comparison(decorator.args[0], state))
is not None
)
and (
(decorator.func.attr == "skipIf" and pass_fail == "fail")
or (decorator.func.attr == "skipUnless" and pass_fail == "pass")
)
):
yield ast_start_offset(decorator), partial(erase_decorator, node=decorator)


def erase_decorator(
tokens: list[Token],
i: int,
*,
node: ast.Call,
) -> None:
erase_node(tokens, i, node=node)
at_j = reverse_find(tokens, i, name=OP, src="@")
del tokens[at_j:i]
127 changes: 127 additions & 0 deletions tests/fixers/test_versioned_test_skip_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

from functools import partial

from django_upgrade.data import Settings
from tests.fixers import tools

settings = Settings(target_version=(4, 1))
check_noop = partial(tools.check_noop, settings=settings)
check_transformed = partial(tools.check_transformed, settings=settings)


def test_unittest_skip_left():
check_noop(
"""\
import unittest
@unittest.skip("Always skipped")
def test_thing(self):
pass
""",
)


def test_unittest_skipIf_too_few_args():
check_noop(
"""\
import unittest
import django
@unittest.skipIf(django.VERSION < (4, 1))
def test_thing(self):
pass
""",
)


def test_unittest_skipIf_too_many_args():
check_noop(
"""\
import unittest
import django
@unittest.skipIf(django.VERSION < (4, 1), "Django 4.1+", "what is this arg?")
def test_thing(self):
pass
""",
)


def test_unittest_skipIf_passing_comparison():
check_noop(
"""\
import unittest
import django
@unittest.skipIf(django.VERSION < (4, 2), "Django 4.2+")
def test_thing(self):
pass
""",
)


def test_unittest_skipIf_unknown_comparison():
check_noop(
"""\
import unittest
import django
@unittest.skipIf(django.VERSION < (4, 1, 1), "Django 4.1.1+")
def test_thing(self):
pass
""",
)


def test_unittest_skipUnless_failing_comparison():
check_noop(
"""\
import unittest
import django
@unittest.skipUnless(django.VERSION >= (4, 2), "Django 4.2+")
def test_thing(self):
pass
""",
)


def test_unittest_skipIf_removed():
check_transformed(
"""\
import unittest
import django
@unittest.skipIf(django.VERSION < (4, 1), "Django 4.1+")
def test_thing(self):
pass
""",
"""\
import unittest
import django
def test_thing(self):
pass
""",
)


def test_skipUnless_removed():
check_transformed(
"""\
import unittest
import django
@unittest.skipUnless(django.VERSION >= (4, 1), "Django 4.1+")
def test_thing(self):
pass
""",
"""\
import unittest
import django
def test_thing(self):
pass
""",
)

0 comments on commit 666a6ce

Please sign in to comment.