Skip to content
This repository has been archived by the owner on Dec 27, 2021. It is now read-only.

Commit

Permalink
fix: resolve typevar before generating dataclasses init method
Browse files Browse the repository at this point in the history
Resolves python#7520
  • Loading branch information
natemcmaster committed Aug 30, 2020
1 parent 6e25daa commit fb6f791
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 1 deletion.
15 changes: 14 additions & 1 deletion mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mypy.plugins.common import (
add_method, _get_decorator_bool_argument, deserialize_and_fixup_type,
)
from mypy.typeops import map_type_from_supertype
from mypy.types import Type, Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type
from mypy.server.trigger import make_wildcard_trigger

Expand All @@ -34,6 +35,7 @@ def __init__(
line: int,
column: int,
type: Optional[Type],
info: TypeInfo,
) -> None:
self.name = name
self.is_in_init = is_in_init
Expand All @@ -42,6 +44,7 @@ def __init__(
self.line = line
self.column = column
self.type = type
self.info = info

def to_argument(self) -> Argument:
return Argument(
Expand Down Expand Up @@ -72,7 +75,15 @@ def deserialize(
) -> 'DataclassAttribute':
data = data.copy()
typ = deserialize_and_fixup_type(data.pop('type'), api)
return cls(type=typ, **data)
return cls(type=typ, info=info, **data)

def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
"""Expands type vars in the context of a subtype when an attribute is inherited
from a generic super type."""
if not isinstance(self.type, TypeVarType):
return

self.type = map_type_from_supertype(self.type, sub_type, self.info)


class DataclassTransformer:
Expand Down Expand Up @@ -267,6 +278,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
line=stmt.line,
column=stmt.column,
type=sym.type,
info=cls.info,
))

# Next, collect attributes belonging to any class in the MRO
Expand All @@ -287,6 +299,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
name = data['name'] # type: str
if name not in known_attrs:
attr = DataclassAttribute.deserialize(info, data, ctx.api)
attr.expand_typevar_from_subtype(ctx.cls.info)
known_attrs.add(name)
super_attrs.append(attr)
elif all_attrs:
Expand Down
96 changes: 96 additions & 0 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,102 @@ s: str = a.bar() # E: Incompatible types in assignment (expression has type "in

[builtins fixtures/list.pyi]


[case testDataclassUntypedGenericInheritance]
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")

@dataclass
class Base(Generic[T]):
attr: T

@dataclass
class Sub(Base):
pass

sub = Sub(attr=1)
reveal_type(sub) # N: Revealed type is '__main__.Sub'
reveal_type(sub.attr) # N: Revealed type is 'Any'


[case testDataclassGenericSubtype]
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")

@dataclass
class Base(Generic[T]):
attr: T

S = TypeVar("S")

@dataclass
class Sub(Base[S]):
pass

sub_int = Sub[int](attr=1)
reveal_type(sub_int) # N: Revealed type is '__main__.Sub[builtins.int*]'
reveal_type(sub_int.attr) # N: Revealed type is 'builtins.int*'

sub_str = Sub[str](attr='ok')
reveal_type(sub_str) # N: Revealed type is '__main__.Sub[builtins.str*]'
reveal_type(sub_str.attr) # N: Revealed type is 'builtins.str*'


[case testDataclassGenericInheritance]
from dataclasses import dataclass
from typing import Generic, TypeVar

T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")

@dataclass
class Base(Generic[T1, T2, T3]):
one: T1
two: T2
three: T3

@dataclass
class Sub(Base[int, str, float]):
pass

sub = Sub(one=1, two='ok', three=3.14)
reveal_type(sub) # N: Revealed type is '__main__.Sub'
reveal_type(sub.one) # N: Revealed type is 'builtins.int*'
reveal_type(sub.two) # N: Revealed type is 'builtins.str*'
reveal_type(sub.three) # N: Revealed type is 'builtins.float*'


[case testDataclassMultiGenericInheritance]
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")

@dataclass
class Base(Generic[T]):
base_attr: T

S = TypeVar("S")

@dataclass
class Middle(Base[int], Generic[S]):
middle_attr: S

@dataclass
class Sub(Middle[str]):
pass

sub = Sub(base_attr=1, middle_attr='ok')
reveal_type(sub) # N: Revealed type is '__main__.Sub'
reveal_type(sub.base_attr) # N: Revealed type is 'builtins.int*'
reveal_type(sub.middle_attr) # N: Revealed type is 'builtins.str*'


[case testDataclassGenericsClassmethod]
# flags: --python-version 3.6
from dataclasses import dataclass
Expand Down

0 comments on commit fb6f791

Please sign in to comment.