Skip to content

Commit

Permalink
Fixed false negative when a literal and non-literal are assigned to t…
Browse files Browse the repository at this point in the history
…he same TypeVar in an invariant context. This addresses microsoft#5321. (microsoft#5323)

Co-authored-by: Eric Traut <[email protected]>
  • Loading branch information
erictraut and msfterictraut authored Jun 17, 2023
1 parent 2829429 commit 0559382
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 20 deletions.
83 changes: 69 additions & 14 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12377,6 +12377,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
keyTypes,
valueTypes,
/* forceStrictInference */ true,
/* isValueTypeInvariant */ true,
/* expectedKeyType */ undefined,
/* expectedValueType */ undefined,
expectedTypedDictEntries,
Expand Down Expand Up @@ -12437,13 +12438,22 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
const expectedKeyType = specializedDict.typeArguments[0];
const expectedValueType = specializedDict.typeArguments[1];

// Dict and MutableMapping types have invariant value types, so they
// cannot be narrowed further. Other super-types like Mapping, Collection,
// and Iterable use covariant value types, so they can be narrowed.
const isValueTypeInvariant =
isClassInstance(inferenceContext.expectedType) &&
(ClassType.isBuiltIn(inferenceContext.expectedType, 'dict') ||
ClassType.isBuiltIn(inferenceContext.expectedType, 'MutableMapping'));

// Infer the key and value types if possible.
if (
getKeyAndValueTypesFromDictionary(
node,
keyTypes,
valueTypes,
/* forceStrictInference */ true,
isValueTypeInvariant,
expectedKeyType,
expectedValueType,
undefined,
Expand All @@ -12453,14 +12463,6 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
isIncomplete = true;
}

// Dict and MutableMapping types have invariant value types, so they
// cannot be narrowed further. Other super-types like Mapping, Collection,
// and Iterable use covariant value types, so they can be narrowed.
const isValueTypeInvariant =
isClassInstance(inferenceContext.expectedType) &&
(ClassType.isBuiltIn(inferenceContext.expectedType, 'dict') ||
ClassType.isBuiltIn(inferenceContext.expectedType, 'MutableMapping'));

const specializedKeyType = inferTypeArgFromExpectedEntryType(
makeInferenceContext(expectedKeyType),
keyTypes.map((result) => result.type),
Expand Down Expand Up @@ -12498,7 +12500,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
node,
keyTypeResults,
valueTypeResults,
/* forceStrictInference */ hasExpectedType
/* forceStrictInference */ hasExpectedType,
/* isValueTypeInvariant */ false
)
) {
isIncomplete = true;
Expand Down Expand Up @@ -12554,6 +12557,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
keyTypes: TypeResultWithNode[],
valueTypes: TypeResultWithNode[],
forceStrictInference: boolean,
isValueTypeInvariant: boolean,
expectedKeyType?: Type,
expectedValueType?: Type,
expectedTypedDictEntries?: Map<string, TypedDictEntry>,
Expand Down Expand Up @@ -12589,6 +12593,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}

let valueTypeResult: TypeResult;
let entryInferenceContext: InferenceContext | undefined;

if (
expectedTypedDictEntries &&
Expand All @@ -12598,21 +12603,32 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
expectedTypedDictEntries.has(keyType.literalValue as string)
) {
const effectiveValueType = expectedTypedDictEntries.get(keyType.literalValue as string)!.valueType;
entryInferenceContext = makeInferenceContext(effectiveValueType);
valueTypeResult = getTypeOfExpression(
entryNode.valueExpression,
/* flags */ undefined,
makeInferenceContext(effectiveValueType)
entryInferenceContext
);
} else {
const effectiveValueType =
expectedValueType ?? (forceStrictInference ? NeverType.createNever() : undefined);
entryInferenceContext = makeInferenceContext(effectiveValueType);
valueTypeResult = getTypeOfExpression(
entryNode.valueExpression,
/* flags */ undefined,
makeInferenceContext(effectiveValueType)
entryInferenceContext
);
}

if (entryInferenceContext && !valueTypeResult.typeErrors) {
valueTypeResult.type =
inferTypeArgFromExpectedEntryType(
entryInferenceContext,
[valueTypeResult.type],
!isValueTypeInvariant
) ?? valueTypeResult.type;
}

if (expectedDiagAddendum && valueTypeResult.expectedTypeDiagAddendum) {
expectedDiagAddendum.addAddendum(valueTypeResult.expectedTypeDiagAddendum);
}
Expand Down Expand Up @@ -12642,12 +12658,22 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}
}

const entryInferenceContext = makeInferenceContext(expectedType);
const unexpandedTypeResult = getTypeOfExpression(
entryNode.expandExpression,
/* flags */ undefined,
makeInferenceContext(expectedType)
entryInferenceContext
);

if (entryInferenceContext && !unexpandedTypeResult.typeErrors) {
unexpandedTypeResult.type =
inferTypeArgFromExpectedEntryType(
entryInferenceContext,
[unexpandedTypeResult.type],
!isValueTypeInvariant
) ?? unexpandedTypeResult.type;
}

if (unexpandedTypeResult.isIncomplete) {
isIncomplete = true;
}
Expand Down Expand Up @@ -13001,7 +13027,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
return inferenceContext.expectedType;
}

const typeVarContext = new TypeVarContext();
const typeVarContext = new TypeVarContext(getTypeVarScopeId(inferenceContext.expectedType));
const expectedType = inferenceContext.expectedType;
let isCompatible = true;

Expand All @@ -13022,7 +13048,27 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
: stripLiteralValue(combinedTypes);
}

return applySolvedTypeVars(inferenceContext.expectedType, typeVarContext, { applyInScopePlaceholders: true });
return mapSubtypes(
applySolvedTypeVars(inferenceContext.expectedType, typeVarContext, { applyInScopePlaceholders: true }),
(subtype) => {
if (entryTypes.length !== 1) {
return subtype;
}
const entryType = entryTypes[0];

// If the entry type is a TypedDict instance, clone it with additional information.
if (
isTypeSame(subtype, entryType, { ignoreTypedDictNarrowEntries: true }) &&
isClass(subtype) &&
isClass(entryType) &&
ClassType.isTypedDictClass(entryType)
) {
return ClassType.cloneForNarrowedTypedDictEntries(subtype, entryType.typedDictNarrowedEntries);
}

return subtype;
}
);
}

function getTypeOfYield(node: YieldNode): TypeResult {
Expand Down Expand Up @@ -21519,6 +21565,15 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
prevSrcType = curSrcType;
}

// If we're enforcing invariance, literal types must match as well.
if ((flags & AssignTypeFlags.EnforceInvariance) !== 0) {
const srcIsLiteral = srcType.literalValue !== undefined;
const destIsLiteral = destType.literalValue !== undefined;
if (srcIsLiteral !== destIsLiteral) {
return false;
}
}

if (destType.typeArguments) {
// If the dest type is specialized, make sure the specialized source
// type arguments are assignable to the dest type arguments.
Expand Down
3 changes: 2 additions & 1 deletion packages/pyright-internal/src/analyzer/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ export type InheritanceChain = (ClassType | UnknownType)[];
export interface TypeSameOptions {
ignorePseudoGeneric?: boolean;
ignoreTypeFlags?: boolean;
ignoreTypedDictNarrowEntries?: boolean;
treatAnySameAsUnknown?: boolean;
}

Expand Down Expand Up @@ -2688,7 +2689,7 @@ export function isTypeSame(type1: Type, type2: Type, options: TypeSameOptions =
return false;
}

if (!ClassType.isTypedDictNarrowedEntriesSame(type1, classType2)) {
if (!options.ignoreTypedDictNarrowEntries && !ClassType.isTypedDictNarrowedEntriesSame(type1, classType2)) {
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def s17():

def s18():
a1: Mapping[object, object] = {"a": 3, "b": 5.6}
reveal_type(a1, expected_text="dict[object, float | int]")
reveal_type(a1, expected_text="dict[object, int | float]")

a2: Dict[object, object] = {"a": 3, "b": 5.6}
reveal_type(a2, expected_text="dict[object, object]")
Expand Down
21 changes: 21 additions & 0 deletions packages/pyright-internal/src/tests/samples/genericTypes117.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# This sample validates that a literal and a non-literal are not considered
# compatible types when in an invariant context.

from typing import Literal, TypeVar

T = TypeVar("T")


def func1(a: T, b: T) -> T:
return a


def func2() -> None:
foo_list: list[Literal["foo"]] = ["foo"]
x = func1(foo_list, [""])
reveal_type(x, expected_text="list[Literal['foo']] | list[str]")

# This should generate an error.
x.append("not foo")
print(foo_list)

9 changes: 7 additions & 2 deletions packages/pyright-internal/src/tests/samples/protocol6.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,26 @@ class Cow:
# doesn't provide an attributes.
c: Mammal[str] = Tapir()

# This should generate an error because "species"
# is incompatible.
d: Ungulate[bytes] = Camel()

e: Ungulate[str] = Cow()
f: CamelLike = Camel()



class CallTreeProto(Protocol):
subcalls: list["CallTreeProto"]


class MyCallTree:
subcalls: list["MyCallTree"]



class OtherCallTree:
subcalls: list["CallTreeProto"]


# This should generate an error.
x1: CallTreeProto = MyCallTree()

Expand Down
4 changes: 3 additions & 1 deletion packages/pyright-internal/src/tests/samples/typedDict18.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class TD4(TD3, Generic[_T1]):


v4: TD4[str] = {"a": 3, "b": ""}
v5: TD4[tuple[str]] = {"a": 3, "b": ("",)}


def func1(x: TD1[_T1, _T2]) -> dict[_T1, _T2]:
Expand Down Expand Up @@ -108,6 +109,7 @@ def func6(a: TD8) -> Literal[1]:

reveal_type(func6({"x": 1, "y": 1, "z": "a"}))


class TD9(TypedDict, Generic[_T1]):
x: _T1

Expand All @@ -123,4 +125,4 @@ def __init__(self, **attrs: Unpack[TD9[_T1]]) -> None:
f6 = ClassA[str](x=1)

f7 = ClassA(x=1)
reveal_type(f7, expected_text='ClassA[int]')
reveal_type(f7, expected_text="ClassA[int]")
8 changes: 7 additions & 1 deletion packages/pyright-internal/src/tests/typeEvaluator2.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,12 @@ test('GenericTypes116', () => {
TestUtils.validateResults(analysisResults, 0);
});

test('GenericTypes117', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['genericTypes117.py']);

TestUtils.validateResults(analysisResults, 1);
});

test('Protocol1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['protocol1.py']);

Expand Down Expand Up @@ -1154,7 +1160,7 @@ test('Protocol5', () => {
test('Protocol6', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['protocol6.py']);

TestUtils.validateResults(analysisResults, 3);
TestUtils.validateResults(analysisResults, 4);
});

test('Protocol7', () => {
Expand Down

0 comments on commit 0559382

Please sign in to comment.