Skip to content

Commit

Permalink
Fix mypy crash on dataclasses.field(**unpack) (#11137)
Browse files Browse the repository at this point in the history
  • Loading branch information
sobolevn authored Sep 21, 2021
1 parent fab534b commit a7d6e68
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 5 deletions.
15 changes: 12 additions & 3 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
if self._is_kw_only_type(node_type):
kw_only = True

has_field_call, field_args = _collect_field_args(stmt.rvalue)
has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx)

is_in_init_param = field_args.get('init')
if is_in_init_param is None:
Expand Down Expand Up @@ -447,7 +447,8 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> None:
transformer.transform()


def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]:
def _collect_field_args(expr: Expression,
ctx: ClassDefContext) -> Tuple[bool, Dict[str, Expression]]:
"""Returns a tuple where the first value represents whether or not
the expression is a call to dataclass.field and the second is a
dictionary of the keyword arguments that field() was called with.
Expand All @@ -460,7 +461,15 @@ def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]:
# field() only takes keyword arguments.
args = {}
for name, arg in zip(expr.arg_names, expr.args):
assert name is not None
if name is None:
# This means that `field` is used with `**` unpacking,
# the best we can do for now is not to fail.
# TODO: we can infer what's inside `**` and try to collect it.
ctx.api.fail(
'Unpacking **kwargs in "field()" is not supported',
expr,
)
return True, {}
args[name] = arg
return True, args
return False, {}
36 changes: 36 additions & 0 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -1300,3 +1300,39 @@ a.x = x
a.x = x2 # E: Incompatible types in assignment (expression has type "Callable[[str], str]", variable has type "Callable[[int], int]")

[builtins fixtures/dataclasses.pyi]


[case testDataclassFieldDoesNotFailOnKwargsUnpacking]
# flags: --python-version 3.7
# https://github.com/python/mypy/issues/10879
from dataclasses import dataclass, field

@dataclass
class Foo:
bar: float = field(**{"repr": False})
[out]
main:7: error: Unpacking **kwargs in "field()" is not supported
main:7: error: No overload variant of "field" matches argument type "Dict[str, bool]"
main:7: note: Possible overload variants:
main:7: note: def [_T] field(*, default: _T, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> _T
main:7: note: def [_T] field(*, default_factory: Callable[[], _T], init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> _T
main:7: note: def field(*, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> Any
[builtins fixtures/dataclasses.pyi]


[case testDataclassFieldWithTypedDictUnpacking]
# flags: --python-version 3.7
from dataclasses import dataclass, field
from typing_extensions import TypedDict

class FieldKwargs(TypedDict):
repr: bool

field_kwargs: FieldKwargs = {"repr": False}

@dataclass
class Foo:
bar: float = field(**field_kwargs) # E: Unpacking **kwargs in "field()" is not supported

reveal_type(Foo(bar=1.5)) # N: Revealed type is "__main__.Foo"
[builtins fixtures/dataclasses.pyi]
25 changes: 23 additions & 2 deletions test-data/unit/fixtures/dataclasses.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import Generic, Sequence, TypeVar
from typing import (
Generic, Iterator, Iterable, Mapping, Optional, Sequence, Tuple,
TypeVar, Union, overload,
)

_T = TypeVar('_T')
_U = TypeVar('_U')
KT = TypeVar('KT')
VT = TypeVar('VT')

class object:
def __init__(self) -> None: pass
Expand All @@ -15,7 +20,23 @@ class int: pass
class float: pass
class str: pass
class bool(int): pass
class dict(Generic[_T, _U]): pass

class dict(Mapping[KT, VT]):
@overload
def __init__(self, **kwargs: VT) -> None: pass
@overload
def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass
def __getitem__(self, key: KT) -> VT: pass
def __setitem__(self, k: KT, v: VT) -> None: pass
def __iter__(self) -> Iterator[KT]: pass
def __contains__(self, item: object) -> int: pass
def update(self, a: Mapping[KT, VT]) -> None: pass
@overload
def get(self, k: KT) -> Optional[VT]: pass
@overload
def get(self, k: KT, default: Union[KT, _T]) -> Union[VT, _T]: pass
def __len__(self) -> int: ...

class list(Generic[_T], Sequence[_T]): pass
class function: pass
class classmethod: pass
Expand Down

0 comments on commit a7d6e68

Please sign in to comment.