Skip to content

Commit

Permalink
refactor: Simplify AST imports, stop using deprecated code from ast
Browse files Browse the repository at this point in the history
Issue #179: #179
  • Loading branch information
pawamoy committed Jul 15, 2023
1 parent 48a7162 commit 21d5832
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 434 deletions.
41 changes: 16 additions & 25 deletions src/griffe/agents/nodes/_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,7 @@

from __future__ import annotations

from ast import AST
from ast import Assign as NodeAssign
from ast import AugAssign as NodeAugAssign
from ast import BinOp as NodeBinOp
from ast import Constant as NodeConstant
from ast import List as NodeList
from ast import Name as NodeName
from ast import Set as NodeSet
from ast import Starred as NodeStarred
from ast import Tuple as NodeTuple
import ast
from contextlib import suppress
from functools import partial
from typing import TYPE_CHECKING, Any, Callable
Expand All @@ -27,47 +18,47 @@
logger = get_logger(__name__)


def _extract_constant(node: NodeConstant, parent: Module) -> list[str | Name]:
def _extract_constant(node: ast.Constant, parent: Module) -> list[str | Name]:
return [node.value]


def _extract_name(node: NodeName, parent: Module) -> list[str | Name]:
def _extract_name(node: ast.Name, parent: Module) -> list[str | Name]:
return [Name(node.id, partial(parent.resolve, node.id))]


def _extract_starred(node: NodeStarred, parent: Module) -> list[str | Name]:
def _extract_starred(node: ast.Starred, parent: Module) -> list[str | Name]:
return _extract(node.value, parent)


def _extract_sequence(node: NodeList | NodeSet | NodeTuple, parent: Module) -> list[str | Name]:
def _extract_sequence(node: ast.List | ast.Set | ast.Tuple, parent: Module) -> list[str | Name]:
sequence = []
for elt in node.elts:
sequence.extend(_extract(elt, parent))
return sequence


def _extract_binop(node: NodeBinOp, parent: Module) -> list[str | Name]:
def _extract_binop(node: ast.BinOp, parent: Module) -> list[str | Name]:
left = _extract(node.left, parent)
right = _extract(node.right, parent)
return left + right


_node_map: dict[type, Callable[[Any, Module], list[str | Name]]] = {
NodeConstant: _extract_constant,
NodeName: _extract_name,
NodeStarred: _extract_starred,
NodeList: _extract_sequence,
NodeSet: _extract_sequence,
NodeTuple: _extract_sequence,
NodeBinOp: _extract_binop,
ast.Constant: _extract_constant,
ast.Name: _extract_name,
ast.Starred: _extract_starred,
ast.List: _extract_sequence,
ast.Set: _extract_sequence,
ast.Tuple: _extract_sequence,
ast.BinOp: _extract_binop,
}


def _extract(node: AST, parent: Module) -> list[str | Name]:
def _extract(node: ast.AST, parent: Module) -> list[str | Name]:
return _node_map[type(node)](node, parent)


def get__all__(node: NodeAssign | NodeAugAssign, parent: Module) -> list[str | Name]:
def get__all__(node: ast.Assign | ast.AugAssign, parent: Module) -> list[str | Name]:
"""Get the values declared in `__all__`.
Parameters:
Expand All @@ -83,7 +74,7 @@ def get__all__(node: NodeAssign | NodeAugAssign, parent: Module) -> list[str | N


def safe_get__all__(
node: NodeAssign | NodeAugAssign,
node: ast.Assign | ast.AugAssign,
parent: Module,
log_level: LogLevel = LogLevel.debug, # TODO: set to error when we handle more things
) -> list[str | Name]:
Expand Down
16 changes: 5 additions & 11 deletions src/griffe/agents/nodes/_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,15 @@

from __future__ import annotations

from ast import AST
from ast import Constant as NodeConstant
from ast import Expr as NodeExpr
from ast import Str as NodeStr
import ast

from griffe.logger import get_logger

logger = get_logger(__name__)


def get_docstring(
node: AST,
node: ast.AST,
*,
strict: bool = False,
) -> tuple[str | None, int | None, int | None]:
Expand All @@ -27,17 +24,14 @@ def get_docstring(
A tuple with the value and line numbers of the docstring.
"""
# TODO: possible optimization using a type map
if isinstance(node, NodeExpr):
if isinstance(node, ast.Expr):
doc = node.value
elif node.body and isinstance(node.body[0], NodeExpr) and not strict: # type: ignore[attr-defined]
elif node.body and isinstance(node.body[0], ast.Expr) and not strict: # type: ignore[attr-defined]
doc = node.body[0].value # type: ignore[attr-defined]
else:
return None, None, None
if isinstance(doc, NodeConstant) and isinstance(doc.value, str):
if isinstance(doc, ast.Constant) and isinstance(doc.value, str):
return doc.value, doc.lineno, doc.end_lineno
if isinstance(doc, NodeStr):
lineno = doc.lineno
return doc.s, lineno, doc.end_lineno
return None, None, None


Expand Down
Loading

0 comments on commit 21d5832

Please sign in to comment.