Skip to content

Commit

Permalink
Changed auto-variance algorithm to ignore __new__ and __init__ me…
Browse files Browse the repository at this point in the history
…thods for purposes of calculating the variance of a TypeVar. This mirrors the behavior of mypy. (microsoft#5327)

Co-authored-by: Eric Traut <[email protected]>
  • Loading branch information
erictraut and msfterictraut authored Jun 18, 2023
1 parent 3021b9c commit 47cd514
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 54 deletions.
103 changes: 55 additions & 48 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21207,58 +21207,65 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
let isAssignable = true;

destType.details.fields.forEach((symbol, name) => {
if (isAssignable && symbol.isClassMember() && !symbol.isIgnoredForProtocolMatch()) {
const memberInfo = lookUpClassMember(srcType, name);
assert(memberInfo !== undefined);
if (!isAssignable || !symbol.isClassMember() || symbol.isIgnoredForProtocolMatch()) {
return;
}

// Constructor methods are exempt from variance calculations.
if (name === '__new__' || name === '__init__') {
return;
}

const memberInfo = lookUpClassMember(srcType, name);
assert(memberInfo !== undefined);

let destMemberType = getDeclaredTypeOfSymbol(symbol)?.type;
if (destMemberType) {
const srcMemberType = getTypeOfMember(memberInfo!);
destMemberType = partiallySpecializeType(destMemberType, destType);
let destMemberType = getDeclaredTypeOfSymbol(symbol)?.type;
if (destMemberType) {
const srcMemberType = getTypeOfMember(memberInfo!);
destMemberType = partiallySpecializeType(destMemberType, destType);

// Properties require special processing.
// Properties require special processing.
if (
isClassInstance(destMemberType) &&
ClassType.isPropertyClass(destMemberType) &&
isClassInstance(srcMemberType) &&
ClassType.isPropertyClass(srcMemberType)
) {
if (
isClassInstance(destMemberType) &&
ClassType.isPropertyClass(destMemberType) &&
isClassInstance(srcMemberType) &&
ClassType.isPropertyClass(srcMemberType)
!assignProperty(
evaluatorInterface,
ClassType.cloneAsInstantiable(destMemberType),
ClassType.cloneAsInstantiable(srcMemberType),
destType,
srcType,
diag,
typeVarContext,
/* selfTypeVarContext */ undefined,
recursionCount
)
) {
if (
!assignProperty(
evaluatorInterface,
ClassType.cloneAsInstantiable(destMemberType),
ClassType.cloneAsInstantiable(srcMemberType),
destType,
srcType,
diag,
typeVarContext,
/* selfTypeVarContext */ undefined,
recursionCount
)
) {
isAssignable = false;
}
} else {
const primaryDecl = symbol.getDeclarations()[0];
// Class and instance variables that are mutable need to
// enforce invariance.
const flags =
primaryDecl?.type === DeclarationType.Variable && !isFinalVariableDeclaration(primaryDecl)
? AssignTypeFlags.EnforceInvariance
: AssignTypeFlags.Default;
if (
!assignType(
destMemberType,
srcMemberType,
diag,
typeVarContext,
/* srcTypeVarContext */ undefined,
flags,
recursionCount
)
) {
isAssignable = false;
}
isAssignable = false;
}
} else {
const primaryDecl = symbol.getDeclarations()[0];
// Class and instance variables that are mutable need to
// enforce invariance.
const flags =
primaryDecl?.type === DeclarationType.Variable && !isFinalVariableDeclaration(primaryDecl)
? AssignTypeFlags.EnforceInvariance
: AssignTypeFlags.Default;
if (
!assignType(
destMemberType,
srcMemberType,
diag,
typeVarContext,
/* srcTypeVarContext */ undefined,
flags,
recursionCount
)
) {
isAssignable = false;
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions packages/pyright-internal/src/tests/samples/autoVariance1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def __init__(self, value: T) -> None:
def value(self):
return self._value

@value.setter
def value(self, value: T):
self._value = value

# This should generate an error based on variance
vinv1_1: ShouldBeInvariant1[float] = ShouldBeInvariant1[int](1)

Expand All @@ -47,6 +51,9 @@ def __init__(self, value: T) -> None:
def get_value(self) ->T:
return self._value

def set_value(self, value: T):
self._value = value

# This should generate an error based on variance
vinv2_1: ShouldBeInvariant2[float] = ShouldBeInvariant2[int](1)

Expand All @@ -73,6 +80,9 @@ class ShouldBeInvariant3[K, V](dict[K, V]):
class ShouldBeContravariant1[T]:
def __init__(self, value: T) -> None:
self._value = value

def set_value(self, value: T) -> None:
self._value = value


# This should generate an error based on variance
Expand Down
24 changes: 20 additions & 4 deletions packages/pyright-internal/src/tests/samples/autoVariance3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@


class ShouldBeCovariant1(Generic[T]):
def __getitem__(self, index: int) -> T: ...
def __iter__(self) -> Iterator[T]: ...
def __getitem__(self, index: int) -> T:
...

def __iter__(self) -> Iterator[T]:
...


vco1_1: ShouldBeCovariant1[float] = ShouldBeCovariant1[int]()
Expand All @@ -30,13 +33,13 @@ def __iter__(self) -> Iterator[T]: ...
class ShouldBeCovariant2(Sequence[T]):
pass


vco2_1: ShouldBeCovariant2[float] = ShouldBeCovariant2[int]()

# This should generate an error based on variance
vco2_2: ShouldBeCovariant2[int] = ShouldBeCovariant2[float]()



class ShouldBeInvariant1(Generic[T]):
def __init__(self, value: T) -> None:
self._value = value
Expand All @@ -45,6 +48,11 @@ def __init__(self, value: T) -> None:
def value(self):
return self._value

@value.setter
def value(self, value: T):
self._value = value


# This should generate an error based on variance
vinv1_1: ShouldBeInvariant1[float] = ShouldBeInvariant1[int](1)

Expand All @@ -56,9 +64,13 @@ class ShouldBeInvariant2(Generic[T]):
def __init__(self, value: T) -> None:
self._value = value

def get_value(self) ->T:
def get_value(self) -> T:
return self._value

def set_value(self, value: T):
self._value = value


# This should generate an error based on variance
vinv2_1: ShouldBeInvariant2[float] = ShouldBeInvariant2[int](1)

Expand All @@ -69,6 +81,7 @@ def get_value(self) ->T:
class ShouldBeInvariant3(dict[K, V]):
pass


# This should generate an error based on variance
vinv3_1: ShouldBeInvariant3[float, str] = ShouldBeInvariant3[int, str]()

Expand All @@ -86,6 +99,9 @@ class ShouldBeContravariant1(Generic[T]):
def __init__(self, value: T) -> None:
self._value = value

def set_value(self, value: T):
self._value = value


# This should generate an error based on variance
vcontra1_1: ShouldBeContravariant1[float] = ShouldBeContravariant1[int](1)
Expand Down
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/samples/partial1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Protocol, Self, TypeVar

_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2", contravariant=True)
_T2 = TypeVar("_T2", covariant=True)


def func1():
Expand Down
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/samples/protocol29.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import Protocol, Self
from typing import Any, Callable, Type, TypeVar

_T = TypeVar("_T")
_T = TypeVar("_T", covariant=True)


class Partial(Protocol[_T]):
Expand Down

0 comments on commit 47cd514

Please sign in to comment.