From fb6f791c70dad41dae575fb55bca122ef5b5bb1f Mon Sep 17 00:00:00 2001 From: Nate McMaster Date: Sun, 30 Aug 2020 10:00:50 -0700 Subject: [PATCH] fix: resolve typevar before generating dataclasses init method Resolves python/mypy#7520 --- mypy/plugins/dataclasses.py | 15 ++++- test-data/unit/check-dataclasses.test | 96 +++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 1 deletion(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index b5c825394d13..5765e0599759 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -12,6 +12,7 @@ from mypy.plugins.common import ( add_method, _get_decorator_bool_argument, deserialize_and_fixup_type, ) +from mypy.typeops import map_type_from_supertype from mypy.types import Type, Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type from mypy.server.trigger import make_wildcard_trigger @@ -34,6 +35,7 @@ def __init__( line: int, column: int, type: Optional[Type], + info: TypeInfo, ) -> None: self.name = name self.is_in_init = is_in_init @@ -42,6 +44,7 @@ def __init__( self.line = line self.column = column self.type = type + self.info = info def to_argument(self) -> Argument: return Argument( @@ -72,7 +75,15 @@ def deserialize( ) -> 'DataclassAttribute': data = data.copy() typ = deserialize_and_fixup_type(data.pop('type'), api) - return cls(type=typ, **data) + return cls(type=typ, info=info, **data) + + def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: + """Expands type vars in the context of a subtype when an attribute is inherited + from a generic super type.""" + if not isinstance(self.type, TypeVarType): + return + + self.type = map_type_from_supertype(self.type, sub_type, self.info) class DataclassTransformer: @@ -267,6 +278,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: line=stmt.line, column=stmt.column, type=sym.type, + info=cls.info, )) # Next, collect attributes belonging to any class in the MRO @@ -287,6 +299,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: name = data['name'] # type: str if name not in known_attrs: attr = DataclassAttribute.deserialize(info, data, ctx.api) + attr.expand_typevar_from_subtype(ctx.cls.info) known_attrs.add(name) super_attrs.append(attr) elif all_attrs: diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test index f965ac54bff5..3954df72db71 100644 --- a/test-data/unit/check-dataclasses.test +++ b/test-data/unit/check-dataclasses.test @@ -480,6 +480,102 @@ s: str = a.bar() # E: Incompatible types in assignment (expression has type "in [builtins fixtures/list.pyi] + +[case testDataclassUntypedGenericInheritance] +from dataclasses import dataclass +from typing import Generic, TypeVar + +T = TypeVar("T") + +@dataclass +class Base(Generic[T]): + attr: T + +@dataclass +class Sub(Base): + pass + +sub = Sub(attr=1) +reveal_type(sub) # N: Revealed type is '__main__.Sub' +reveal_type(sub.attr) # N: Revealed type is 'Any' + + +[case testDataclassGenericSubtype] +from dataclasses import dataclass +from typing import Generic, TypeVar + +T = TypeVar("T") + +@dataclass +class Base(Generic[T]): + attr: T + +S = TypeVar("S") + +@dataclass +class Sub(Base[S]): + pass + +sub_int = Sub[int](attr=1) +reveal_type(sub_int) # N: Revealed type is '__main__.Sub[builtins.int*]' +reveal_type(sub_int.attr) # N: Revealed type is 'builtins.int*' + +sub_str = Sub[str](attr='ok') +reveal_type(sub_str) # N: Revealed type is '__main__.Sub[builtins.str*]' +reveal_type(sub_str.attr) # N: Revealed type is 'builtins.str*' + + +[case testDataclassGenericInheritance] +from dataclasses import dataclass +from typing import Generic, TypeVar + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") + +@dataclass +class Base(Generic[T1, T2, T3]): + one: T1 + two: T2 + three: T3 + +@dataclass +class Sub(Base[int, str, float]): + pass + +sub = Sub(one=1, two='ok', three=3.14) +reveal_type(sub) # N: Revealed type is '__main__.Sub' +reveal_type(sub.one) # N: Revealed type is 'builtins.int*' +reveal_type(sub.two) # N: Revealed type is 'builtins.str*' +reveal_type(sub.three) # N: Revealed type is 'builtins.float*' + + +[case testDataclassMultiGenericInheritance] +from dataclasses import dataclass +from typing import Generic, TypeVar + +T = TypeVar("T") + +@dataclass +class Base(Generic[T]): + base_attr: T + +S = TypeVar("S") + +@dataclass +class Middle(Base[int], Generic[S]): + middle_attr: S + +@dataclass +class Sub(Middle[str]): + pass + +sub = Sub(base_attr=1, middle_attr='ok') +reveal_type(sub) # N: Revealed type is '__main__.Sub' +reveal_type(sub.base_attr) # N: Revealed type is 'builtins.int*' +reveal_type(sub.middle_attr) # N: Revealed type is 'builtins.str*' + + [case testDataclassGenericsClassmethod] # flags: --python-version 3.6 from dataclasses import dataclass