Skip to content

Commit

Permalink
Improved type inference for enum values in an enum class created with…
Browse files Browse the repository at this point in the history
… the one-line form `x = Enum("x", "a b c")`.
  • Loading branch information
msfterictraut committed Apr 15, 2023
1 parent 101b040 commit b49da70
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 33 deletions.
52 changes: 20 additions & 32 deletions packages/pyright-internal/src/analyzer/enums.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,9 @@
* Provides special-case logic for the Enum class.
*/

import { assert } from '../common/debug';
import { convertOffsetsToRange } from '../common/positionUtils';
import { TextRange } from '../common/textRange';
import {
ArgumentCategory,
ExpressionNode,
NameNode,
ParseNode,
ParseNodeType,
StringListNode,
} from '../parser/parseNodes';
import { ArgumentCategory, ExpressionNode, NameNode, ParseNode, ParseNodeType } from '../parser/parseNodes';
import { getFileInfo } from './analyzerNodeInfo';
import { DeclarationType, VariableDeclaration } from './declaration';
import { VariableDeclaration } from './declaration';
import { getClassFullName, getEnclosingClass, getTypeSourceId } from './parseTreeUtils';
import { Symbol, SymbolFlags } from './symbol';
import { isSingleDunderName } from './symbolNameUtils';
Expand All @@ -44,6 +34,7 @@ export function isKnownEnumType(className: string) {

// Creates a new custom enum class with named values.
export function createEnumType(
evaluator: TypeEvaluator,
errorNode: ExpressionNode,
enumClass: ClassType,
argList: FunctionArgument[]
Expand Down Expand Up @@ -101,33 +92,30 @@ export function createEnumType(
// recommend using the more standard class declaration syntax.
return undefined;
} else {
const classInstanceType = ClassType.cloneAsInstance(classType);
const intType = evaluator.getBuiltInType(errorNode, 'int');
const entries = entriesArg.valueExpression.strings
.map((s) => s.value)
.join('')
.split(' ');
entries.forEach((entryName) => {
entries.forEach((entryName, index) => {
entryName = entryName.trim();
if (entryName) {
const entryType = UnknownType.create();
const newSymbol = Symbol.createWithType(SymbolFlags.ClassMember, entryType);
const valueType =
intType && isInstantiableClass(intType)
? ClassType.cloneWithLiteral(ClassType.cloneAsInstance(intType), index + 1)
: UnknownType.create();

// We need to associate the declaration with a parse node.
// In this case it's just part of a string literal value.
// The definition provider won't necessarily take the
// user to the exact spot in the string, but it's close enough.
const stringNode = entriesArg.valueExpression!;
assert(stringNode.nodeType === ParseNodeType.StringList);
const fileInfo = getFileInfo(errorNode);
const declaration: VariableDeclaration = {
type: DeclarationType.Variable,
node: stringNode as StringListNode,
isRuntimeTypeExpression: true,
path: fileInfo.filePath,
range: convertOffsetsToRange(stringNode.start, TextRange.getEnd(stringNode), fileInfo.lines),
moduleName: fileInfo.moduleName,
isInExceptSuite: false,
};
newSymbol.addDeclaration(declaration);
const enumLiteral = new EnumLiteral(
classType.details.fullName,
classType.details.name,
entryName,
valueType
);
const newSymbol = Symbol.createWithType(
SymbolFlags.ClassMember,
ClassType.cloneWithLiteral(classInstanceType, enumLiteral)
);
classFields.set(entryName, newSymbol);
}
});
Expand Down
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9405,7 +9405,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}

if (isClass(unexpandedSubtype) && isKnownEnumType(className)) {
return createEnumType(errorNode, expandedSubtype, argList) ?? UnknownType.create();
return createEnumType(evaluatorInterface, errorNode, expandedSubtype, argList) ?? UnknownType.create();
}

if (className === 'TypedDict') {
Expand Down
8 changes: 8 additions & 0 deletions packages/pyright-internal/src/tests/samples/enums1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ class TestEnum3(Enum):

a = TestEnum1["A"]
aa = TestEnum1.A
reveal_type(aa.name, expected_text="Literal['A']")
reveal_type(aa._name_, expected_text="Literal['A']")
reveal_type(aa.value, expected_text="Literal[1]")
reveal_type(aa._value_, expected_text="Literal[1]")
reveal_type(TestEnum1.D.name, expected_text="Literal['D']")
reveal_type(TestEnum1.D._name_, expected_text="Literal['D']")
reveal_type(TestEnum1.D.value, expected_text="Literal[4]")
reveal_type(TestEnum1.D._value_, expected_text="Literal[4]")

# This should generate an error because "Z" isn't
# a valid member.
Expand Down

0 comments on commit b49da70

Please sign in to comment.