diff --git a/packages/pyright-internal/src/analyzer/patternMatching.ts b/packages/pyright-internal/src/analyzer/patternMatching.ts index f00f0f8ea78f..05f2859903fa 100644 --- a/packages/pyright-internal/src/analyzer/patternMatching.ts +++ b/packages/pyright-internal/src/analyzer/patternMatching.ts @@ -1145,29 +1145,32 @@ function narrowTypeBasedOnValuePattern( (subjectSubtypeExpanded) => { // If this is a negative test, see if it's an enum value. if (!isPositiveTest) { - if ( - isClassInstance(subjectSubtypeExpanded) && - ClassType.isEnumClass(subjectSubtypeExpanded) && - !isLiteralType(subjectSubtypeExpanded) && - isClassInstance(valueSubtypeExpanded) && - isSameWithoutLiteralValue(subjectSubtypeExpanded, valueSubtypeExpanded) && - isLiteralType(valueSubtypeExpanded) - ) { - const allEnumTypes = enumerateLiteralsForType(evaluator, subjectSubtypeExpanded); - if (allEnumTypes) { - return combineTypes( - allEnumTypes.filter( - (enumType) => !ClassType.isLiteralValueSame(valueSubtypeExpanded, enumType) - ) + if (isClassInstance(subjectSubtypeExpanded) && isClassInstance(valueSubtypeExpanded)) { + if ( + !isLiteralType(subjectSubtypeExpanded) && + isSameWithoutLiteralValue(subjectSubtypeExpanded, valueSubtypeExpanded) && + isLiteralType(valueSubtypeExpanded) + ) { + const expandedLiterals = enumerateLiteralsForType( + evaluator, + subjectSubtypeExpanded ); + if (expandedLiterals) { + return combineTypes( + expandedLiterals.filter( + (enumType) => + !ClassType.isLiteralValueSame(valueSubtypeExpanded, enumType) + ) + ); + } + } + + if ( + isLiteralType(subjectSubtypeExpanded) && + ClassType.isLiteralValueSame(valueSubtypeExpanded, subjectSubtypeExpanded) + ) { + return undefined; } - } else if ( - isClassInstance(subjectSubtypeExpanded) && - isClassInstance(valueSubtypeExpanded) && - isLiteralType(subjectSubtypeExpanded) && - ClassType.isLiteralValueSame(valueSubtypeExpanded, subjectSubtypeExpanded) - ) { - return undefined; } return subjectSubtypeExpanded; diff --git a/packages/pyright-internal/src/tests/samples/matchValue1.py b/packages/pyright-internal/src/tests/samples/matchValue1.py index 0144be986440..bfd5d8ba672a 100644 --- a/packages/pyright-internal/src/tests/samples/matchValue1.py +++ b/packages/pyright-internal/src/tests/samples/matchValue1.py @@ -3,8 +3,8 @@ from dataclasses import dataclass from enum import Enum, auto -from typing import Annotated, TypeVar from http import HTTPStatus +from typing import Annotated, Literal, TypeVar # pyright: reportIncompatibleMethodOverride=false @@ -141,3 +141,17 @@ def test_enum_narrowing_with_inf(subj: float): reveal_type(subj, expected_text="float") case f: reveal_type(subj, expected_text="float") + + +@dataclass +class DC2: + a: Literal[False] + + +def test_bool_expansion(subj: bool): + match subj: + case DC2.a: + reveal_type(subj, expected_text="Literal[False]") + + case x: + reveal_type(subj, expected_text="Literal[True]")