Skip to content

Commit

Permalink
Improved error handling for enum classes created with the one-line fo…
Browse files Browse the repository at this point in the history
…rm `x = Enum("x", "a b c")`. The code now handles other whitespaces and commas and properly checks for format strings.
  • Loading branch information
msfterictraut committed Apr 15, 2023
1 parent b49da70 commit fcdb119
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 54 deletions.
120 changes: 67 additions & 53 deletions packages/pyright-internal/src/analyzer/enums.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,23 @@ export function createEnumType(
argList: FunctionArgument[]
): ClassType | undefined {
const fileInfo = getFileInfo(errorNode);
let className = 'enum';

if (argList.length === 0) {
return undefined;
} else {
const nameArg = argList[0];
if (
nameArg.argumentCategory === ArgumentCategory.Simple &&
nameArg.valueExpression &&
nameArg.valueExpression.nodeType === ParseNodeType.StringList
) {
className = nameArg.valueExpression.strings.map((s) => s.value).join('');
} else {
return undefined;
}
}

const nameArg = argList[0];
if (
nameArg.argumentCategory !== ArgumentCategory.Simple ||
!nameArg.valueExpression ||
nameArg.valueExpression.nodeType !== ParseNodeType.StringList ||
nameArg.valueExpression.strings.length !== 1 ||
nameArg.valueExpression.strings[0].nodeType !== ParseNodeType.String
) {
return undefined;
}

const className = nameArg.valueExpression.strings.map((s) => s.value).join('');
const classType = ClassType.createInstantiable(
className,
getClassFullName(errorNode, fileInfo.moduleName, className),
Expand All @@ -77,49 +78,62 @@ export function createEnumType(

if (argList.length < 2) {
return undefined;
} else {
const entriesArg = argList[1];
if (
entriesArg.argumentCategory !== ArgumentCategory.Simple ||
!entriesArg.valueExpression ||
entriesArg.valueExpression.nodeType !== ParseNodeType.StringList
) {
// Technically, the Enum constructor supports a bunch of different
// ways to specify the items: space-delimited string, a string
// iterator, an iterator of name/value tuples, and a dictionary
// of name/value pairs. We support only the simple space-delimited
// string here. For users who are interested in type checking, we
// recommend using the more standard class declaration syntax.
}

const initArg = argList[1];
if (initArg.argumentCategory !== ArgumentCategory.Simple || !initArg.valueExpression) {
return undefined;
}

// The Enum factory call supports various forms of arguments:
// Enum('name', 'a b c')
// Enum('name', 'a,b,c')
// Enum('name', ['a', 'b', 'c'])
// Enum('name', ('a', 'b', 'c'))
// Enum('name', (('a', 1), ('b', 2), ('c', 3)))
// Enum('name', [('a', 1), ('b', 2), ('c', 3))]
// Enum('name', {'a': 1, 'b': 2, 'c': 3}
if (initArg.valueExpression.nodeType === ParseNodeType.StringList) {
// Don't allow format strings in the init arg.
if (!initArg.valueExpression.strings.every((str) => str.nodeType === ParseNodeType.String)) {
return undefined;
} else {
const classInstanceType = ClassType.cloneAsInstance(classType);
const intType = evaluator.getBuiltInType(errorNode, 'int');
const entries = entriesArg.valueExpression.strings
.map((s) => s.value)
.join('')
.split(' ');
entries.forEach((entryName, index) => {
entryName = entryName.trim();
if (entryName) {
const valueType =
intType && isInstantiableClass(intType)
? ClassType.cloneWithLiteral(ClassType.cloneAsInstance(intType), index + 1)
: UnknownType.create();

const enumLiteral = new EnumLiteral(
classType.details.fullName,
classType.details.name,
entryName,
valueType
);
const newSymbol = Symbol.createWithType(
SymbolFlags.ClassMember,
ClassType.cloneWithLiteral(classInstanceType, enumLiteral)
);
classFields.set(entryName, newSymbol);
}
});
}

const classInstanceType = ClassType.cloneAsInstance(classType);
const intClassType = evaluator.getBuiltInType(errorNode, 'int');

const initStr = initArg.valueExpression.strings
.map((s) => s.value)
.join('')
.trim();

// Split by comma or whitespace.
const entryNames = initStr.split(/[\s,]+/);

for (const [index, entryName] of entryNames.entries()) {
if (!entryName) {
return undefined;
}

const valueType =
intClassType && isInstantiableClass(intClassType)
? ClassType.cloneWithLiteral(ClassType.cloneAsInstance(intClassType), index + 1)
: UnknownType.create();

const enumLiteral = new EnumLiteral(
classType.details.fullName,
classType.details.name,
entryName,
valueType
);
const newSymbol = Symbol.createWithType(
SymbolFlags.ClassMember,
ClassType.cloneWithLiteral(classInstanceType, enumLiteral)
);
classFields.set(entryName, newSymbol);
}

return classType;
}

return classType;
Expand Down
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/samples/enums1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import Enum, IntEnum


TestEnum1 = Enum("TestEnum1", "A B C D")
TestEnum1 = Enum("TestEnum1", " A B,C , \t D\t")
TestEnum2 = IntEnum("TestEnum2", "AA BB CC DD")


Expand Down

0 comments on commit fcdb119

Please sign in to comment.