Skip to content

Commit

Permalink
Add typing to brain_dataclasses (#1292)
Browse files Browse the repository at this point in the history
Co-authored-by: Pierre Sassoulas <[email protected]>
Co-authored-by: Marc Mueller <[email protected]>
  • Loading branch information
3 people authored Dec 29, 2021
1 parent aa4f5be commit 0d12115
Showing 1 changed file with 44 additions and 28 deletions.
72 changes: 44 additions & 28 deletions astroid/brain/brain_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
- https://lovasoa.github.io/marshmallow_dataclass/
"""
from typing import FrozenSet, Generator, List, Optional, Tuple
import sys
from typing import FrozenSet, Generator, List, Optional, Tuple, Union

from astroid import context, inference_tip
from astroid.builder import parse
Expand All @@ -36,6 +37,15 @@
from astroid.nodes.scoped_nodes import ClassDef, FunctionDef
from astroid.util import Uninferable

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

_FieldDefaultReturn = Union[
None, Tuple[Literal["default"], NodeNG], Tuple[Literal["default_factory"], Call]
]

DATACLASSES_DECORATORS = frozenset(("dataclass",))
FIELD_NAME = "field"
DATACLASS_MODULES = frozenset(
Expand Down Expand Up @@ -115,7 +125,7 @@ def _get_dataclass_attributes(node: ClassDef, init: bool = False) -> Generator:
):
continue

if _is_class_var(assign_node.annotation):
if _is_class_var(assign_node.annotation): # type: ignore[arg-type] # annotation is never None
continue

if init:
Expand All @@ -124,12 +134,13 @@ def _get_dataclass_attributes(node: ClassDef, init: bool = False) -> Generator:
isinstance(value, Call)
and _looks_like_dataclass_field_call(value, check_scope=False)
and any(
keyword.arg == "init" and not keyword.value.bool_value()
keyword.arg == "init"
and not keyword.value.bool_value() # type: ignore[union-attr] # value is never None
for keyword in value.keywords
)
):
continue
elif _is_init_var(assign_node.annotation):
elif _is_init_var(assign_node.annotation): # type: ignore[arg-type] # annotation is never None
continue

yield assign_node
Expand Down Expand Up @@ -159,7 +170,8 @@ def _check_generate_dataclass_init(node: ClassDef) -> bool:

# Check for keyword arguments of the form init=False
return all(
keyword.arg != "init" or keyword.value.bool_value()
keyword.arg != "init"
and keyword.value.bool_value() # type: ignore[union-attr] # value is never None
for keyword in found.keywords
)

Expand All @@ -174,7 +186,7 @@ def _generate_dataclass_init(assigns: List[AnnAssign]) -> str:
name, annotation, value = assign.target.name, assign.annotation, assign.value
target_names.append(name)

if _is_init_var(annotation):
if _is_init_var(annotation): # type: ignore[arg-type] # annotation is never None
init_var = True
if isinstance(annotation, Subscript):
annotation = annotation.slice
Expand All @@ -196,16 +208,16 @@ def _generate_dataclass_init(assigns: List[AnnAssign]) -> str:
value, check_scope=False
):
result = _get_field_default(value)

default_type, default_node = result
if default_type == "default":
param_str += f" = {default_node.as_string()}"
elif default_type == "default_factory":
param_str += f" = {DEFAULT_FACTORY}"
assignment_str = (
f"self.{name} = {default_node.as_string()} "
f"if {name} is {DEFAULT_FACTORY} else {name}"
)
if result:
default_type, default_node = result
if default_type == "default":
param_str += f" = {default_node.as_string()}"
elif default_type == "default_factory":
param_str += f" = {DEFAULT_FACTORY}"
assignment_str = (
f"self.{name} = {default_node.as_string()} "
f"if {name} is {DEFAULT_FACTORY} else {name}"
)
else:
param_str += f" = {value.as_string()}"

Expand All @@ -219,7 +231,7 @@ def _generate_dataclass_init(assigns: List[AnnAssign]) -> str:


def infer_dataclass_attribute(
node: Unknown, ctx: context.InferenceContext = None
node: Unknown, ctx: Optional[context.InferenceContext] = None
) -> Generator:
"""Inference tip for an Unknown node that was dynamically generated to
represent a dataclass attribute.
Expand Down Expand Up @@ -247,16 +259,17 @@ def infer_dataclass_field_call(
"""Inference tip for dataclass field calls."""
if not isinstance(node.parent, (AnnAssign, Assign)):
raise UseInferenceDefault
field_call = node.parent.value
default_type, default = _get_field_default(field_call)
if not default_type:
result = _get_field_default(node)
if not result:
yield Uninferable
elif default_type == "default":
yield from default.infer(context=ctx)
else:
new_call = parse(default.as_string()).body[0].value
new_call.parent = field_call.parent
yield from new_call.infer(context=ctx)
default_type, default = result
if default_type == "default":
yield from default.infer(context=ctx)
else:
new_call = parse(default.as_string()).body[0].value
new_call.parent = node.parent
yield from new_call.infer(context=ctx)


def _looks_like_dataclass_decorator(
Expand Down Expand Up @@ -294,6 +307,9 @@ def _looks_like_dataclass_attribute(node: Unknown) -> bool:
statement.
"""
parent = node.parent
if not parent:
return False

scope = parent.scope()
return (
isinstance(parent, AnnAssign)
Expand Down Expand Up @@ -330,7 +346,7 @@ def _looks_like_dataclass_field_call(node: Call, check_scope: bool = True) -> bo
return inferred.name == FIELD_NAME and inferred.root().name in DATACLASS_MODULES


def _get_field_default(field_call: Call) -> Tuple[str, Optional[NodeNG]]:
def _get_field_default(field_call: Call) -> _FieldDefaultReturn:
"""Return a the default value of a field call, and the corresponding keyword argument name.
field(default=...) results in the ... node
Expand Down Expand Up @@ -358,7 +374,7 @@ def _get_field_default(field_call: Call) -> Tuple[str, Optional[NodeNG]]:
new_call.postinit(func=default_factory)
return "default_factory", new_call

return "", None
return None


def _is_class_var(node: NodeNG) -> bool:
Expand Down Expand Up @@ -404,7 +420,7 @@ def _is_init_var(node: NodeNG) -> bool:


def _infer_instance_from_annotation(
node: NodeNG, ctx: context.InferenceContext = None
node: NodeNG, ctx: Optional[context.InferenceContext] = None
) -> Generator:
"""Infer an instance corresponding to the type annotation represented by node.
Expand Down

0 comments on commit 0d12115

Please sign in to comment.