Skip to content

Commit

Permalink
Fix dataclass/protocol crash on joining types (python#15629)
Browse files Browse the repository at this point in the history
The root cause is hacky creation of incomplete symbols; instead
switching to `add_method_to_class` which does the necessary
housekeeping.

Fixes python#15618.
  • Loading branch information
ikonst authored Jul 14, 2023
1 parent 2ebd51e commit 7a94183
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 87 deletions.
7 changes: 4 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,9 +1200,10 @@ def check_func_def(
elif isinstance(arg_type, TypeVarType):
# Refuse covariant parameter type variables
# TODO: check recursively for inner type variables
if arg_type.variance == COVARIANT and defn.name not in (
"__init__",
"__new__",
if (
arg_type.variance == COVARIANT
and defn.name not in ("__init__", "__new__", "__post_init__")
and not is_private(defn.name) # private methods are not inherited
):
ctx: Context = arg_type
if ctx.line < 0:
Expand Down
137 changes: 55 additions & 82 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Final, Iterator
from typing import TYPE_CHECKING, Final, Iterator, Literal

from mypy import errorcodes, message_registry
from mypy.expandtype import expand_type, expand_type_by_instance
Expand Down Expand Up @@ -86,7 +86,7 @@
field_specifiers=("dataclasses.Field", "dataclasses.field"),
)
_INTERNAL_REPLACE_SYM_NAME: Final = "__mypy-replace"
_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-__post_init__"
_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-post_init"


class DataclassAttribute:
Expand Down Expand Up @@ -118,14 +118,33 @@ def __init__(
self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen
self._api = api

def to_argument(self, current_info: TypeInfo) -> Argument:
arg_kind = ARG_POS
if self.kw_only and self.has_default:
arg_kind = ARG_NAMED_OPT
elif self.kw_only and not self.has_default:
arg_kind = ARG_NAMED
elif not self.kw_only and self.has_default:
arg_kind = ARG_OPT
def to_argument(
self, current_info: TypeInfo, *, of: Literal["__init__", "replace", "__post_init__"]
) -> Argument:
if of == "__init__":
arg_kind = ARG_POS
if self.kw_only and self.has_default:
arg_kind = ARG_NAMED_OPT
elif self.kw_only and not self.has_default:
arg_kind = ARG_NAMED
elif not self.kw_only and self.has_default:
arg_kind = ARG_OPT
elif of == "replace":
arg_kind = ARG_NAMED if self.is_init_var and not self.has_default else ARG_NAMED_OPT
elif of == "__post_init__":
# We always use `ARG_POS` without a default value, because it is practical.
# Consider this case:
#
# @dataclass
# class My:
# y: dataclasses.InitVar[str] = 'a'
# def __post_init__(self, y: str) -> None: ...
#
# We would be *required* to specify `y: str = ...` if default is added here.
# But, most people won't care about adding default values to `__post_init__`,
# because it is not designed to be called directly, and duplicating default values
# for the sake of type-checking is unpleasant.
arg_kind = ARG_POS
return Argument(
variable=self.to_var(current_info),
type_annotation=self.expand_type(current_info),
Expand Down Expand Up @@ -236,7 +255,7 @@ def transform(self) -> bool:
and attributes
):
args = [
attr.to_argument(info)
attr.to_argument(info, of="__init__")
for attr in attributes
if attr.is_in_init and not self._is_kw_only_type(attr.type)
]
Expand Down Expand Up @@ -375,70 +394,26 @@ def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) ->
Stashes the signature of 'dataclasses.replace(...)' for this specific dataclass
to be used later whenever 'dataclasses.replace' is called for this dataclass.
"""
arg_types: list[Type] = []
arg_kinds = []
arg_names: list[str | None] = []

info = self._cls.info
for attr in attributes:
attr_type = attr.expand_type(info)
assert attr_type is not None
arg_types.append(attr_type)
arg_kinds.append(
ARG_NAMED if attr.is_init_var and not attr.has_default else ARG_NAMED_OPT
)
arg_names.append(attr.name)

signature = CallableType(
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
ret_type=NoneType(),
fallback=self._api.named_type("builtins.function"),
)

info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
add_method_to_class(
self._api,
self._cls,
_INTERNAL_REPLACE_SYM_NAME,
args=[attr.to_argument(self._cls.info, of="replace") for attr in attributes],
return_type=NoneType(),
is_staticmethod=True,
)

def _add_internal_post_init_method(self, attributes: list[DataclassAttribute]) -> None:
arg_types: list[Type] = [fill_typevars(self._cls.info)]
arg_kinds = [ARG_POS]
arg_names: list[str | None] = ["self"]

info = self._cls.info
for attr in attributes:
if not attr.is_init_var:
continue
attr_type = attr.expand_type(info)
assert attr_type is not None
arg_types.append(attr_type)
# We always use `ARG_POS` without a default value, because it is practical.
# Consider this case:
#
# @dataclass
# class My:
# y: dataclasses.InitVar[str] = 'a'
# def __post_init__(self, y: str) -> None: ...
#
# We would be *required* to specify `y: str = ...` if default is added here.
# But, most people won't care about adding default values to `__post_init__`,
# because it is not designed to be called directly, and duplicating default values
# for the sake of type-checking is unpleasant.
arg_kinds.append(ARG_POS)
arg_names.append(attr.name)

signature = CallableType(
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
ret_type=NoneType(),
fallback=self._api.named_type("builtins.function"),
name="__post_init__",
)

info.names[_INTERNAL_POST_INIT_SYM_NAME] = SymbolTableNode(
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
add_method_to_class(
self._api,
self._cls,
_INTERNAL_POST_INIT_SYM_NAME,
args=[
attr.to_argument(self._cls.info, of="__post_init__")
for attr in attributes
if attr.is_init_var
],
return_type=NoneType(),
)

def add_slots(
Expand Down Expand Up @@ -1120,20 +1095,18 @@ def is_processed_dataclass(info: TypeInfo | None) -> bool:
def check_post_init(api: TypeChecker, defn: FuncItem, info: TypeInfo) -> None:
if defn.type is None:
return

ideal_sig = info.get_method(_INTERNAL_POST_INIT_SYM_NAME)
if ideal_sig is None or ideal_sig.type is None:
return

# We set it ourself, so it is always fine:
assert isinstance(ideal_sig.type, ProperType)
assert isinstance(ideal_sig.type, FunctionLike)
# Type of `FuncItem` is always `FunctionLike`:
assert isinstance(defn.type, FunctionLike)

ideal_sig_method = info.get_method(_INTERNAL_POST_INIT_SYM_NAME)
assert ideal_sig_method is not None and ideal_sig_method.type is not None
ideal_sig = ideal_sig_method.type
assert isinstance(ideal_sig, ProperType) # we set it ourselves
assert isinstance(ideal_sig, CallableType)
ideal_sig = ideal_sig.copy_modified(name="__post_init__")

api.check_override(
override=defn.type,
original=ideal_sig.type,
original=ideal_sig,
name="__post_init__",
name_in_super="__post_init__",
supertype="dataclass",
Expand Down
23 changes: 23 additions & 0 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,17 @@ s: str = a.bar() # E: Incompatible types in assignment (expression has type "in

[builtins fixtures/dataclasses.pyi]

[case testDataclassGenericCovariant]
from dataclasses import dataclass
from typing import Generic, TypeVar

T_co = TypeVar("T_co", covariant=True)

@dataclass
class MyDataclass(Generic[T_co]):
a: T_co

[builtins fixtures/dataclasses.pyi]

[case testDataclassUntypedGenericInheritance]
# flags: --python-version 3.7
Expand Down Expand Up @@ -2449,3 +2460,15 @@ class Test(Protocol):
def reset(self) -> None:
self.x = DEFAULT
[builtins fixtures/dataclasses.pyi]

[case testProtocolNoCrashOnJoining]
from dataclasses import dataclass
from typing import Protocol

@dataclass
class MyDataclass(Protocol): ...

a: MyDataclass
b = [a, a] # trigger joining the types

[builtins fixtures/dataclasses.pyi]
4 changes: 2 additions & 2 deletions test-data/unit/deps.test
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ class B(A):
<m.A.(abstract)> -> <m.B.__init__>, m
<m.A.__dataclass_fields__> -> <m.B.__dataclass_fields__>
<m.A.__init__> -> <m.B.__init__>, m.B.__init__
<m.A.__mypy-replace> -> <m.B.__mypy-replace>
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m.B.__mypy-replace
<m.A.__new__> -> <m.B.__new__>
<m.A.x> -> <m.B.x>
<m.A.y> -> <m.B.y>
Expand Down Expand Up @@ -1420,7 +1420,7 @@ class B(A):
<m.A.__dataclass_fields__> -> <m.B.__dataclass_fields__>
<m.A.__init__> -> <m.B.__init__>, m.B.__init__
<m.A.__match_args__> -> <m.B.__match_args__>
<m.A.__mypy-replace> -> <m.B.__mypy-replace>
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m.B.__mypy-replace
<m.A.__new__> -> <m.B.__new__>
<m.A.x> -> <m.B.x>
<m.A.y> -> <m.B.y>
Expand Down

0 comments on commit 7a94183

Please sign in to comment.