diff --git a/stubdefaulter.py b/stubdefaulter.py index e4f3c0e..030549b 100644 --- a/stubdefaulter.py +++ b/stubdefaulter.py @@ -69,6 +69,14 @@ def leave_Param( return updated_node +def get_end_lineno(node: ast.FunctionDef | ast.AsyncFunctionDef) -> int: + if sys.version_info >= (3, 8): + assert node.end_lineno is not None + return node.end_lineno + else: + return max(child.lineno for child in ast.iter_child_nodes(node)) + + def replace_defaults_in_func( stub_lines: list[str], node: ast.FunctionDef | ast.AsyncFunctionDef, @@ -78,8 +86,8 @@ def replace_defaults_in_func( sig = inspect.signature(runtime_func) except Exception: return {} - assert node.end_lineno is not None - lines = stub_lines[node.lineno - 1 : node.end_lineno] + end_lineno = get_end_lineno(node) + lines = stub_lines[node.lineno - 1 : end_lineno] indentation = len(lines[0]) - len(lines[0].lstrip()) cst = libcst.parse_statement( textwrap.dedent("".join(line + "\n" for line in lines)) @@ -88,7 +96,7 @@ def replace_defaults_in_func( assert isinstance(modified, libcst.FunctionDef) new_code = textwrap.indent(libcst.Module(body=[modified]).code, " " * indentation) output_dict = {node.lineno - 1: new_code.splitlines()} - for i in range(node.lineno, node.end_lineno): + for i in range(node.lineno, end_lineno): output_dict[i] = [] return output_dict