diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index c6529b7ed477..bb94f9069c00 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -16884,6 +16884,15 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions functionNode.parameters[paramIndex].category ); + // If the parameter type is generic, specialize it in the context + // of the child class. + if (requiresSpecialization(inferredParamType) && isClass(baseClassMemberInfo.classType)) { + const typeVarContext = buildTypeVarContextFromSpecializedClass( + baseClassMemberInfo.classType + ); + inferredParamType = applySolvedTypeVars(inferredParamType, typeVarContext); + } + const fileInfo = AnalyzerNodeInfo.getFileInfo(functionNode); if (fileInfo.isInPyTypedPackage && !fileInfo.isStubFile) { inferredParamType = TypeBase.cloneForAmbiguousType(inferredParamType); diff --git a/packages/pyright-internal/src/tests/samples/paramInference2.py b/packages/pyright-internal/src/tests/samples/paramInference2.py new file mode 100644 index 000000000000..3ff981a8dfd6 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/paramInference2.py @@ -0,0 +1,20 @@ +# This sample tests the logic that infers parameter types based on +# annotated base class methods when the base class is generic. + + +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class Parent(Generic[T]): + def method1(self, a: T, b: list[T]) -> None: + ... + + +class Child(Parent[float]): + def method1(self, a, b): + reveal_type(self, expected_text="Self@Child") + reveal_type(a, expected_text="float") + reveal_type(b, expected_text="list[float]") + return a diff --git a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts index 87bb6aa11c22..2b3031b61bea 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts @@ -1571,6 +1571,12 @@ test('ParamInference1', () => { TestUtils.validateResults(analysisResults, 0); }); +test('ParamInference2', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['paramInference2.py']); + + TestUtils.validateResults(analysisResults, 0); +}); + test('Dictionary1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dictionary1.py']);