Skip to content

Commit

Permalink
Refactor PT017 into separate visitor (#325)
Browse files Browse the repository at this point in the history
  • Loading branch information
m-burst authored Jan 10, 2025
1 parent e1f1ab8 commit a5de985
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 34 deletions.
2 changes: 2 additions & 0 deletions flake8_pytest_style/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PatchVisitor,
RaisesVisitor,
TFunctionsVisitor,
TryExceptVisitor,
UnittestAssertionVisitor,
)

Expand All @@ -41,6 +42,7 @@ class PytestStylePlugin(Plugin[Config]):
ParametrizeVisitor,
RaisesVisitor,
TFunctionsVisitor,
TryExceptVisitor,
UnittestAssertionVisitor,
]

Expand Down
2 changes: 2 additions & 0 deletions flake8_pytest_style/visitors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .patch import PatchVisitor
from .raises import RaisesVisitor
from .t_functions import TFunctionsVisitor
from .try_except import TryExceptVisitor

__all__ = (
'AssertionVisitor',
Expand All @@ -18,5 +19,6 @@
'PatchVisitor',
'RaisesVisitor',
'TFunctionsVisitor',
'TryExceptVisitor',
'UnittestAssertionVisitor',
)
31 changes: 0 additions & 31 deletions flake8_pytest_style/visitors/raises.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import ast
from typing import List, Optional

from flake8_plugin_utils import Visitor, is_none

from flake8_pytest_style.config import Config
from flake8_pytest_style.errors import (
AssertInExcept,
RaisesTooBroad,
RaisesWithMultipleStatements,
RaisesWithoutException,
Expand All @@ -21,11 +19,6 @@


class RaisesVisitor(Visitor[Config]):
def __init__(self, config: Optional[Config] = None) -> None:
super().__init__(config=config)
self._exception_names: List[str] = []
self._current_assert: Optional[ast.Assert] = None

def _check_raises_call(self, node: ast.Call) -> None:
"""
Checks for violations regarding `pytest.raises` call args (PT010 and PT011).
Expand Down Expand Up @@ -76,27 +69,3 @@ def visit_With(self, node: ast.With) -> None:
self._check_raises_with(node)

self.generic_visit(node)

def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
if node.name:
self._exception_names.append(node.name)
try:
self.generic_visit(node)
finally:
if node.name:
self._exception_names.pop()

def visit_Assert(self, node: ast.Assert) -> None:
self._current_assert = node
try:
self.visit(node.test)
finally:
self._current_assert = None

if node.msg:
self.visit(node.msg)

def visit_Name(self, node: ast.Name) -> None:
if self._current_assert:
if node.id in self._exception_names:
self.error_from_node(AssertInExcept, self._current_assert, name=node.id)
38 changes: 38 additions & 0 deletions flake8_pytest_style/visitors/try_except.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import ast
from typing import List, Optional

from flake8_plugin_utils import Visitor

from flake8_pytest_style.config import Config
from flake8_pytest_style.errors import AssertInExcept


class TryExceptVisitor(Visitor[Config]):
def __init__(self, config: Optional[Config] = None) -> None:
super().__init__(config=config)
self._exception_names: List[str] = []
self._current_assert: Optional[ast.Assert] = None

def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
if node.name:
self._exception_names.append(node.name)
try:
self.generic_visit(node)
finally:
if node.name:
self._exception_names.pop()

def visit_Assert(self, node: ast.Assert) -> None:
self._current_assert = node
try:
self.visit(node.test)
finally:
self._current_assert = None

if node.msg:
self.visit(node.msg)

def visit_Name(self, node: ast.Name) -> None:
if self._current_assert:
if node.id in self._exception_names:
self.error_from_node(AssertInExcept, self._current_assert, name=node.id)
8 changes: 5 additions & 3 deletions tests/test_PT017_assert_in_except.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from flake8_pytest_style.config import DEFAULT_CONFIG
from flake8_pytest_style.errors import AssertInExcept
from flake8_pytest_style.visitors import RaisesVisitor
from flake8_pytest_style.visitors import TryExceptVisitor


def test_ok():
Expand All @@ -17,7 +17,7 @@ def test_xxx():
1 / 0
assert e.value.message
"""
assert_not_error(RaisesVisitor, code, config=DEFAULT_CONFIG)
assert_not_error(TryExceptVisitor, code, config=DEFAULT_CONFIG)


def test_error():
Expand All @@ -28,4 +28,6 @@ def test_xxx():
except Exception as e:
assert e.message, 'blah blah'
"""
assert_error(RaisesVisitor, code, AssertInExcept, name='e', config=DEFAULT_CONFIG)
assert_error(
TryExceptVisitor, code, AssertInExcept, name='e', config=DEFAULT_CONFIG
)

0 comments on commit a5de985

Please sign in to comment.