diff --git a/flake8_trio.py b/flake8_trio.py index 2088ed8..12c5018 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -19,15 +19,30 @@ Error_codes = { "TRIO100": "{} context contains no checkpoints, add `await trio.sleep(0)`", - "TRIO101": "yield inside a nursery or cancel scope is only safe when implementing a context manager - otherwise, it breaks exception handling", - "TRIO102": "await inside {0.name} on line {0.lineno} must have shielded cancel scope with a timeout", + "TRIO101": ( + "yield inside a nursery or cancel scope is only safe when implementing " + "a context manager - otherwise, it breaks exception handling" + ), + "TRIO102": ( + "await inside {0.name} on line {0.lineno} must have shielded cancel " + "scope with a timeout" + ), "TRIO103": "{} block with a code path that doesn't re-raise the error", "TRIO104": "Cancelled (and therefore BaseException) must be re-raised", "TRIO105": "trio async function {} must be immediately awaited", "TRIO106": "trio must be imported with `import trio` for the linter to work", - "TRIO107": "{0} from async function with no guaranteed checkpoint or exception since function definition on line {1.lineno}", - "TRIO108": "{0} from async iterable with no guaranteed checkpoint since {1.name} on line {1.lineno}", - "TRIO109": "Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead", + "TRIO107": ( + "{0} from async function with no guaranteed checkpoint or exception " + "since function definition on line {1.lineno}" + ), + "TRIO108": ( + "{0} from async iterable with no guaranteed checkpoint since {1.name} " + "on line {1.lineno}" + ), + "TRIO109": ( + "Async function definition with a `timeout` parameter - use " + "`trio.[fail/move_on]_[after/at]` instead" + ), "TRIO110": "`while : await trio.sleep()` should be replaced by a `trio.Event`.", } @@ -35,11 +50,17 @@ class Statement(NamedTuple): name: str lineno: int - col_offset: int = 0 + col_offset: int = -1 - # ignore col offset since many tests don't supply that def __eq__(self, other: Any) -> bool: - return isinstance(other, Statement) and self[:2] == other[:2] + return ( + isinstance(other, Statement) + and self[:2] == other[:2] + and ( + self.col_offset == other.col_offset + or -1 in (self.col_offset, other.col_offset) + ) + ) HasLineInfo = Union[ast.expr, ast.stmt, ast.arg, ast.excepthandler, Statement] @@ -51,9 +72,12 @@ def __init__(self, node: ast.Call, funcname: str): self.funcname = funcname self.variable_name: Optional[str] = None self.shielded: bool = False - self.has_timeout: bool = False + self.has_timeout: bool = True + + # scope.shield is assigned to in visit_Assign if self.funcname == "CancelScope": + self.has_timeout = False for kw in node.keywords: # Only accepts constant values if kw.arg == "shield" and isinstance(kw.value, ast.Constant): @@ -61,8 +85,6 @@ def __init__(self, node: ast.Call, funcname: str): # sets to True even if timeout is explicitly set to inf if kw.arg == "deadline": self.has_timeout = True - else: - self.has_timeout = True def __str__(self): # Not supporting other ways of importing trio, per TRIO106 @@ -298,16 +320,15 @@ def has_exception(node: Optional[ast.expr]) -> str: if node.type is None: return Statement("bare except", node.lineno, node.col_offset) # several exceptions - elif isinstance(node.type, ast.Tuple): + if isinstance(node.type, ast.Tuple): for element in node.type.elts: name = has_exception(element) if name: return Statement(name, element.lineno, element.col_offset) # single exception, either a Name or an Attribute - else: - name = has_exception(node.type) - if name: - return Statement(name, node.type.lineno, node.type.col_offset) + name = has_exception(node.type) + if name: + return Statement(name, node.type.lineno, node.type.col_offset) return None diff --git a/tests/test_changelog_and_version.py b/tests/test_changelog_and_version.py index 9ce7962..be88c4b 100644 --- a/tests/test_changelog_and_version.py +++ b/tests/test_changelog_and_version.py @@ -20,7 +20,7 @@ def from_string(cls, string): def get_releases() -> Iterable[Version]: valid_pattern = re.compile(r"^## (\d\d\.\d?\d\.\d?\d)$") - with open(Path(__file__).parent.parent / "CHANGELOG.md") as f: + with open(Path(__file__).parent.parent / "CHANGELOG.md", encoding="utf-8") as f: lines = f.readlines() for aline in lines: version_match = valid_pattern.match(aline) @@ -54,7 +54,7 @@ def runTest(self): "CHANGELOG.md", "README.md", ): - with open(Path(__file__).parent.parent / filename) as f: + with open(Path(__file__).parent.parent / filename, encoding="utf-8") as f: lines = f.readlines() documented_errors[filename] = set() for line in lines: diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index 2165fdd..7fa00df 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -28,9 +28,8 @@ class ParseError(Exception): ... -@pytest.mark.parametrize("test, path", test_files) -def test_eval(test: str, path: str): - # version check +# check for presence of _pyXX, skip if version is later, and prune parameter +def check_version(test: str) -> str: python_version = re.search(r"(?<=_PY)\d*", test) if python_version: version_str = python_version.group() @@ -38,13 +37,20 @@ def test_eval(test: str, path: str): v_i = sys.version_info if (v_i.major, v_i.minor) < (int(major), int(minor)): raise unittest.SkipTest("v_i, major, minor") - test = test.split("_")[0] + return test.split("_")[0] + return test + + +@pytest.mark.parametrize("test, path", test_files) +def test_eval(test: str, path: str): + # version check + test = check_version(test) assert test in Error_codes.keys(), "error code not defined in flake8_trio.py" include = [test] expected: List[Error] = [] - with open(os.path.join("tests", path)) as file: + with open(os.path.join("tests", path), encoding="utf-8") as file: lines = file.readlines() for lineno, line in enumerate(lines, start=1): @@ -191,7 +197,7 @@ def assert_correct_lines_and_codes(errors: Iterable[Error], expected: Iterable[E for line in all_lines: if error_dict[line] == expected_dict[line]: continue - for code in {*error_dict[line], *expected_dict[line]}: + for code in sorted({*error_dict[line], *expected_dict[line]}): if not any_error: print( "Lines with different # of errors:", diff --git a/tests/trio103.py b/tests/trio103.py index c304635..a09aefd 100644 --- a/tests/trio103.py +++ b/tests/trio103.py @@ -15,15 +15,6 @@ except trio.Cancelled: # error: 7, "trio.Cancelled" pass -# raise different exception -except BaseException: - raise ValueError() # TRIO104 -except trio.Cancelled as e: - raise ValueError() from e # TRIO104 -except trio.Cancelled as e: - # see https://github.com/Zac-HD/flake8-trio/pull/8#discussion_r932737341 - raise BaseException() from e # TRIO104 - # if except BaseException as e: # error: 7, "BaseException" if True: