From 9bf823160cde536758e547f1dfee7c6591dacdc2 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Sat, 17 Jun 2023 22:48:18 -0700 Subject: [PATCH] Fixed a bug that led to incorrect type evaluation when passing a generic class (with a constructor that includes class-scoped TypeVars) as an argument for a callable parameter. The class was being specialized prematurely (with type arguments set to `Unknown`) before the constraint solver was able to solve the higher-order function's type variables. This addresses https://github.com/microsoft/pyright/issues/5324. (#5328) Co-authored-by: Eric Traut --- .../pyright-internal/src/analyzer/typeEvaluator.ts | 4 ++-- .../src/tests/samples/genericTypes118.py | 12 ++++++++++++ .../src/tests/typeEvaluator2.test.ts | 6 ++++++ 3 files changed, 20 insertions(+), 2 deletions(-) create mode 100644 packages/pyright-internal/src/tests/samples/genericTypes118.py diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index bd070ab1635b..31d4511c5fa8 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -11214,7 +11214,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions EvaluatorFlags.EvaluateStringLiteralAsType | EvaluatorFlags.DisallowParamSpec | EvaluatorFlags.DisallowTypeVarTuple - : EvaluatorFlags.None; + : EvaluatorFlags.DoNotSpecialize; const exprTypeResult = getTypeOfExpression( argParam.argument.valueExpression, flags, @@ -22542,7 +22542,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions concreteSrcType, diag?.createAddendum(), destTypeVarContext ?? new TypeVarContext(getTypeVarScopeId(destType)), - srcTypeVarContext ?? new TypeVarContext(getTypeVarScopeId(concreteSrcType)), + srcTypeVarContext ?? new TypeVarContext(getTypeVarScopeIds(concreteSrcType)), flags, recursionCount ) diff --git a/packages/pyright-internal/src/tests/samples/genericTypes118.py b/packages/pyright-internal/src/tests/samples/genericTypes118.py new file mode 100644 index 000000000000..19cf05d47493 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/genericTypes118.py @@ -0,0 +1,12 @@ +# This sample tests the case where a generic class is passed as an +# argument to a function that accepts a generic callable parameter. +# The class-scoped TypeVars for the class must be preserved when +# solving the higher-order TypeVars. + +from itertools import compress +from typing import Any, Iterable + + +def func1(a: Iterable[Iterable[tuple[str, int]]], b: Any) -> None: + c = map(compress, a, b) + reveal_type(c, expected_text="map[compress[tuple[str, int]]]") diff --git a/packages/pyright-internal/src/tests/typeEvaluator2.test.ts b/packages/pyright-internal/src/tests/typeEvaluator2.test.ts index 9ba22d8d1f17..9e07d3d32491 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator2.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator2.test.ts @@ -1127,6 +1127,12 @@ test('GenericTypes117', () => { TestUtils.validateResults(analysisResults, 1); }); +test('GenericTypes118', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['genericTypes118.py']); + + TestUtils.validateResults(analysisResults, 0); +}); + test('Protocol1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['protocol1.py']);