From 1c1d3959c3e4c929c84cb563bf3d9bb9dd56cfb8 Mon Sep 17 00:00:00 2001 From: Amar1729 Date: Sun, 28 Jul 2024 14:22:18 -0400 Subject: [PATCH] fix: check for parent ExceptHandler of Raise statements Introduces a `walk.walk_dfs` for DFS traversal of a node. This allows us to more easily keep track of direct parents of nested children, so that we can find an ExceptHandler that may be containing any Raise statement. Update tests for walk_dfs and some more complex edge cases for try/except blocks. --- pydoclint/utils/return_yield_raise.py | 31 ++++++++++++--- pydoclint/utils/walk.py | 7 ++++ tests/utils/test_returns_yields_raise.py | 50 +++++++++++++++++++++++- tests/utils/test_walk.py | 48 ++++++++++++++++++++++- 4 files changed, 128 insertions(+), 8 deletions(-) diff --git a/pydoclint/utils/return_yield_raise.py b/pydoclint/utils/return_yield_raise.py index 020f97a..e8343de 100644 --- a/pydoclint/utils/return_yield_raise.py +++ b/pydoclint/utils/return_yield_raise.py @@ -1,5 +1,5 @@ import ast -from typing import Callable, Dict, Generator, List, Tuple, Type +from typing import Callable, Dict, Generator, List, Optional, Tuple, Type from pydoclint.utils import walk from pydoclint.utils.annotation import unparseAnnotation @@ -107,9 +107,15 @@ def _getRaisedExceptions( # key: child lineno, value: (parent lineno, is parent a function?) familyTree: Dict[int, Tuple[int, bool]] = {} - for child, parent in walk.walk(node): + currentParentExceptHandler: Optional[ast.ExceptHandler] = None + + # depth-first guarantees the last-seen exception handler is a parent of child. + for child, parent in walk.walk_dfs(node): childLineNum = _updateFamilyTree(child, parent, familyTree) + if isinstance(parent, ast.ExceptHandler): + currentParentExceptHandler = parent + if ( isinstance(child, ast.Raise) and isinstance( @@ -128,10 +134,23 @@ def _getRaisedExceptions( yield subnode.id break else: - # if "raise" statement was alone, generally parent is ast.ExceptHandler. - for subnode, _ in walk.walk(parent): - if isinstance(subnode, ast.Name): - yield subnode.id + # if "raise" statement was alone, it must be inside an "except" + if currentParentExceptHandler: + yield from _extractExceptionsFromExcept( + currentParentExceptHandler, + ) + + +def _extractExceptionsFromExcept( + node: ast.ExceptHandler, +) -> Generator[str, None, None]: + if isinstance(node.type, ast.Name): + yield node.type.id + + if isinstance(node.type, ast.Tuple): + for child, _ in walk.walk(node.type): + if isinstance(child, ast.Name): + yield child.id def _hasExpectedStatements( diff --git a/pydoclint/utils/walk.py b/pydoclint/utils/walk.py index 3c16d0a..a2c07b6 100644 --- a/pydoclint/utils/walk.py +++ b/pydoclint/utils/walk.py @@ -32,6 +32,13 @@ def walk(node): yield node, parent +def walk_dfs(node): + """Depth-first traversal of AST. Modified from walk.walk, above.""" + for child, parent in iter_child_nodes(node): + yield (child, parent) + yield from walk_dfs(child) + + def iter_child_nodes(node): """ Yield all direct child nodes of *node*, that is, all fields that are nodes diff --git a/tests/utils/test_returns_yields_raise.py b/tests/utils/test_returns_yields_raise.py index 7387d2e..a49a9c4 100644 --- a/tests/utils/test_returns_yields_raise.py +++ b/tests/utils/test_returns_yields_raise.py @@ -345,6 +345,43 @@ def func7(arg0): 1 / 0 except ZeroDivisionError: raise RuntimeError("a different error") + + try: + pass + except OSError as e: + if e.args[0] == 2 and e.filename: + fp = None + else: + raise + +def func8(d): + try: + d[0][0] + except (KeyError, TypeError): + raise + finally: + pass + +def func9(d): + try: + d[0] + except IndexError: + try: + d[0][0] + except KeyError: + pass + except Exception: + pass + if True: + raise + +def func10(): + # no variable resolution is done. this function looks like it throws "GError". + GError = ZeroDivisionError + try: + 1 / 0 + except GError: + raise """ @@ -363,6 +400,9 @@ def testHasRaiseStatements() -> None: (26, 0, 'func6'): True, (21, 4, 'func5_child1'): True, (32, 0, 'func7'): True, + (54, 0, 'func8'): True, + (62, 0, 'func9'): True, + (75, 0, 'func10'): True, } assert result == expected @@ -382,7 +422,15 @@ def testWhichRaiseStatements() -> None: (20, 0, 'func5'): [], (26, 0, 'func6'): ['TypeError'], (21, 4, 'func5_child1'): ['ValueError'], - (32, 0, 'func7'): ['IndexError', 'RuntimeError', 'TypeError'], + (32, 0, 'func7'): [ + 'IndexError', + 'OSError', + 'RuntimeError', + 'TypeError', + ], + (54, 0, 'func8'): ['KeyError', 'TypeError'], + (62, 0, 'func9'): ['IndexError'], + (75, 0, 'func10'): ['GError'], } assert result == expected diff --git a/tests/utils/test_walk.py b/tests/utils/test_walk.py index 069c205..3d99040 100644 --- a/tests/utils/test_walk.py +++ b/tests/utils/test_walk.py @@ -3,7 +3,7 @@ import pytest -from pydoclint.utils.walk import walk +from pydoclint.utils.walk import walk, walk_dfs src1 = """ def func1(): @@ -96,3 +96,49 @@ def testWalk(src: str, expected: List[Tuple[str, str]]) -> None: result.append((node.name, parent_repr)) assert result == expected + + +@pytest.mark.parametrize( + 'src, expected', + [ + ( + src1, + [ + ('func1', 'ast.Module'), + ('func2', 'ast.Module'), + ('func3', 'ast.Module'), + ('func3_child1', 'func3'), + ('func3_child1_grandchild1', 'func3_child1'), + ('func4', 'ast.Module'), + ('func4_child1', 'func4'), + ('func4_child2', 'func4'), + ('func4_child2_grandchild1', 'func4_child2'), + ('func5', 'ast.Module'), + ('MyClass', 'ast.Module'), + ('__init__', 'MyClass'), + ('method1', 'MyClass'), + ('method1_child1', 'method1'), + ('classmethod1', 'MyClass'), + ('classmethod1_child1', 'classmethod1'), + ], + ), + ], +) +def testWalkDfs(src: str, expected: List[Tuple[str, str]]) -> None: + result: List[Tuple[str, str]] = [] + tree = ast.parse(src) + for node, parent in walk_dfs(tree): + if 'name' in node.__dict__: + parent_repr: str + if isinstance(parent, ast.Module): + parent_repr = 'ast.Module' + elif isinstance( + parent, (ast.AsyncFunctionDef, ast.FunctionDef, ast.ClassDef) + ): + parent_repr = parent.name + else: + parent_repr = str(type(parent)) + + result.append((node.name, parent_repr)) + + assert result == expected