Skip to content

Commit

Permalink
fix: check for parent ExceptHandler of Raise statements
Browse files Browse the repository at this point in the history
Introduces a `walk.walk_dfs` for DFS traversal of a node. This allows us
to more easily keep track of direct parents of nested children, so that
we can find an ExceptHandler that may be containing any Raise statement.

Update tests for walk_dfs and some more complex edge cases for
try/except blocks.
  • Loading branch information
Amar1729 committed Jul 28, 2024
1 parent 3d62c7d commit 1c1d395
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 8 deletions.
31 changes: 25 additions & 6 deletions pydoclint/utils/return_yield_raise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ast
from typing import Callable, Dict, Generator, List, Tuple, Type
from typing import Callable, Dict, Generator, List, Optional, Tuple, Type

from pydoclint.utils import walk
from pydoclint.utils.annotation import unparseAnnotation
Expand Down Expand Up @@ -107,9 +107,15 @@ def _getRaisedExceptions(
# key: child lineno, value: (parent lineno, is parent a function?)
familyTree: Dict[int, Tuple[int, bool]] = {}

for child, parent in walk.walk(node):
currentParentExceptHandler: Optional[ast.ExceptHandler] = None

# depth-first guarantees the last-seen exception handler is a parent of child.
for child, parent in walk.walk_dfs(node):
childLineNum = _updateFamilyTree(child, parent, familyTree)

if isinstance(parent, ast.ExceptHandler):
currentParentExceptHandler = parent

if (
isinstance(child, ast.Raise)
and isinstance(
Expand All @@ -128,10 +134,23 @@ def _getRaisedExceptions(
yield subnode.id
break
else:
# if "raise" statement was alone, generally parent is ast.ExceptHandler.
for subnode, _ in walk.walk(parent):
if isinstance(subnode, ast.Name):
yield subnode.id
# if "raise" statement was alone, it must be inside an "except"
if currentParentExceptHandler:
yield from _extractExceptionsFromExcept(
currentParentExceptHandler,
)


def _extractExceptionsFromExcept(
node: ast.ExceptHandler,
) -> Generator[str, None, None]:
if isinstance(node.type, ast.Name):
yield node.type.id

if isinstance(node.type, ast.Tuple):
for child, _ in walk.walk(node.type):
if isinstance(child, ast.Name):
yield child.id


def _hasExpectedStatements(
Expand Down
7 changes: 7 additions & 0 deletions pydoclint/utils/walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ def walk(node):
yield node, parent


def walk_dfs(node):
"""Depth-first traversal of AST. Modified from walk.walk, above."""
for child, parent in iter_child_nodes(node):
yield (child, parent)
yield from walk_dfs(child)


def iter_child_nodes(node):
"""
Yield all direct child nodes of *node*, that is, all fields that are nodes
Expand Down
50 changes: 49 additions & 1 deletion tests/utils/test_returns_yields_raise.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,43 @@ def func7(arg0):
1 / 0
except ZeroDivisionError:
raise RuntimeError("a different error")
try:
pass
except OSError as e:
if e.args[0] == 2 and e.filename:
fp = None
else:
raise
def func8(d):
try:
d[0][0]
except (KeyError, TypeError):
raise
finally:
pass
def func9(d):
try:
d[0]
except IndexError:
try:
d[0][0]
except KeyError:
pass
except Exception:
pass
if True:
raise
def func10():
# no variable resolution is done. this function looks like it throws "GError".
GError = ZeroDivisionError
try:
1 / 0
except GError:
raise
"""


Expand All @@ -363,6 +400,9 @@ def testHasRaiseStatements() -> None:
(26, 0, 'func6'): True,
(21, 4, 'func5_child1'): True,
(32, 0, 'func7'): True,
(54, 0, 'func8'): True,
(62, 0, 'func9'): True,
(75, 0, 'func10'): True,
}

assert result == expected
Expand All @@ -382,7 +422,15 @@ def testWhichRaiseStatements() -> None:
(20, 0, 'func5'): [],
(26, 0, 'func6'): ['TypeError'],
(21, 4, 'func5_child1'): ['ValueError'],
(32, 0, 'func7'): ['IndexError', 'RuntimeError', 'TypeError'],
(32, 0, 'func7'): [
'IndexError',
'OSError',
'RuntimeError',
'TypeError',
],
(54, 0, 'func8'): ['KeyError', 'TypeError'],
(62, 0, 'func9'): ['IndexError'],
(75, 0, 'func10'): ['GError'],
}

assert result == expected
48 changes: 47 additions & 1 deletion tests/utils/test_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from pydoclint.utils.walk import walk
from pydoclint.utils.walk import walk, walk_dfs

src1 = """
def func1():
Expand Down Expand Up @@ -96,3 +96,49 @@ def testWalk(src: str, expected: List[Tuple[str, str]]) -> None:
result.append((node.name, parent_repr))

assert result == expected


@pytest.mark.parametrize(
'src, expected',
[
(
src1,
[
('func1', 'ast.Module'),
('func2', 'ast.Module'),
('func3', 'ast.Module'),
('func3_child1', 'func3'),
('func3_child1_grandchild1', 'func3_child1'),
('func4', 'ast.Module'),
('func4_child1', 'func4'),
('func4_child2', 'func4'),
('func4_child2_grandchild1', 'func4_child2'),
('func5', 'ast.Module'),
('MyClass', 'ast.Module'),
('__init__', 'MyClass'),
('method1', 'MyClass'),
('method1_child1', 'method1'),
('classmethod1', 'MyClass'),
('classmethod1_child1', 'classmethod1'),
],
),
],
)
def testWalkDfs(src: str, expected: List[Tuple[str, str]]) -> None:
result: List[Tuple[str, str]] = []
tree = ast.parse(src)
for node, parent in walk_dfs(tree):
if 'name' in node.__dict__:
parent_repr: str
if isinstance(parent, ast.Module):
parent_repr = 'ast.Module'
elif isinstance(
parent, (ast.AsyncFunctionDef, ast.FunctionDef, ast.ClassDef)
):
parent_repr = parent.name
else:
parent_repr = str(type(parent))

result.append((node.name, parent_repr))

assert result == expected

0 comments on commit 1c1d395

Please sign in to comment.