Skip to content

Commit

Permalink
Another round of simplification for TypeVar solving. Add support for …
Browse files Browse the repository at this point in the history
…recursive solutions (i.e. a TypeVar that has an upper or lower bound that depends on another TypeVar). (#8663)
  • Loading branch information
erictraut authored Aug 5, 2024
1 parent 4d471f1 commit 35ab773
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 110 deletions.
13 changes: 10 additions & 3 deletions packages/pyright-internal/src/analyzer/constraintSolution.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { FunctionType, ParamSpecType, Type, TypeVarType } from './types';
// Records the types associated with a set of type variables.
export class ConstraintSolutionSet {
// Indexed by TypeVar ID.
private _typeVarMap: Map<string, Type>;
private _typeVarMap: Map<string, Type | undefined>;

// See the comment in constraintTracker for details about identifying
// solution sets by scope ID.
Expand All @@ -39,14 +39,21 @@ export class ConstraintSolutionSet {
return this._typeVarMap.get(key);
}

setType(typeVar: TypeVarType, type: Type) {
setType(typeVar: TypeVarType, type: Type | undefined) {
const key = TypeVarType.getNameWithScope(typeVar);
return this._typeVarMap.set(key, type);
}

hasType(typeVar: TypeVarType): boolean {
const key = TypeVarType.getNameWithScope(typeVar);
return this._typeVarMap.has(key);
}

doForEachTypeVar(callback: (type: Type, typeVarId: string) => void) {
this._typeVarMap.forEach((type, key) => {
callback(type, key);
if (type) {
callback(type, key);
}
});
}
}
Expand Down
115 changes: 66 additions & 49 deletions packages/pyright-internal/src/analyzer/constraintSolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import { DiagnosticAddendum } from '../common/diagnostic';
import { LocAddendum } from '../localization/localize';
import { ConstraintSolution, ConstraintSolutionSet } from './constraintSolution';
import { ConstraintSet, ConstraintTracker } from './constraintTracker';
import { ConstraintSet, ConstraintTracker, TypeVarConstraints } from './constraintTracker';
import { maxSubtypesForInferredType, SolveConstraintsOptions, TypeEvaluator } from './typeEvaluatorTypes';
import {
ClassType,
Expand Down Expand Up @@ -52,6 +52,7 @@ import {
convertToInstantiable,
convertTypeToParamSpecValue,
getTypeCondition,
getTypeVarArgsRecursive,
getTypeVarScopeId,
isEffectivelyInstantiable,
isLiteralTypeOrUnion,
Expand All @@ -63,7 +64,6 @@ import {
specializeWithDefaultTypeArgs,
transformExpectedType,
transformPossibleRecursiveTypeAlias,
TypeVarTransformer,
} from './typeUtils';

// As we widen the lower bound of a type variable, we may end up with
Expand Down Expand Up @@ -239,15 +239,75 @@ export function solveConstraintSet(
const solutionSet = new ConstraintSolutionSet(constraintSet.getScopeIds());

constraintSet.doForEachTypeVar((entry) => {
const value = getTypeVarType(evaluator, constraintSet, entry.typeVar, options?.useLowerBoundOnly);
if (value) {
solutionSet.setType(entry.typeVar, value);
}
solveTypeVarRecursive(evaluator, constraintSet, options, solutionSet, entry);
});

return solutionSet;
}

function solveTypeVarRecursive(
evaluator: TypeEvaluator,
constraintSet: ConstraintSet,
options: SolveConstraintsOptions | undefined,
solutionSet: ConstraintSolutionSet,
entry: TypeVarConstraints
): Type | undefined {
// If this TypeVar already has a solution, don't attempt to re-solve it.
if (solutionSet.hasType(entry.typeVar)) {
return solutionSet.getType(entry.typeVar);
}

// Protect against infinite recursion by setting the initial value to undefined.
solutionSet.setType(entry.typeVar, undefined);
let value = getTypeVarType(evaluator, constraintSet, entry.typeVar, options?.useLowerBoundOnly);

if (value) {
// Are there any unsolved TypeVars in this type?
const typeVars = getTypeVarArgsRecursive(value);

if (typeVars.length > 0) {
const dependentSolution = new ConstraintSolution();

for (const typeVar of typeVars) {
// Don't attempt to replace a TypeVar with itself.
if (isTypeSame(typeVar, entry.typeVar, { ignoreTypeFlags: true })) {
continue;
}

// Don't attempt to solve or replace bound TypeVars.
if (TypeVarType.isBound(typeVar)) {
continue;
}

const dependentEntry = constraintSet.getTypeVar(typeVar);
if (!dependentEntry) {
continue;
}

const dependentType = solveTypeVarRecursive(
evaluator,
constraintSet,
options,
solutionSet,
dependentEntry
);

if (dependentType) {
dependentSolution.setType(typeVar, dependentType);
}
}

// Apply the dependent TypeVar values to the current TypeVar value.
if (!dependentSolution.isEmpty()) {
value = applySolvedTypeVars(value, dependentSolution);
}
}
}

solutionSet.setType(entry.typeVar, value);
return value;
}

// In cases where the expected type is a specialized base class of the
// source type, we need to determine which type arguments in the derived
// class will make it compatible with the specialized base class. This method
Expand Down Expand Up @@ -441,30 +501,6 @@ export function addConstraintsForExpectedType(
return false;
}

// If the constraint tracker contains any type variables whose types depend on
// unification vars used for bidirectional type inference, replace those
// with the solved type associated with those unification vars.
export function applyUnificationVars(evaluator: TypeEvaluator, constraints: ConstraintTracker) {
constraints.doForEachConstraintSet((constraintSet) => {
if (!constraintSet.hasUnificationVars()) {
return;
}

constraintSet.getTypeVars().forEach((entry) => {
if (!TypeVarType.isUnification(entry.typeVar)) {
const newLowerBound = entry.lowerBound
? applyUnificationVarsToType(evaluator, entry.lowerBound, constraintSet)
: undefined;
const newUpperBound = entry.upperBound
? applyUnificationVarsToType(evaluator, entry.upperBound, constraintSet)
: undefined;

constraintSet.setBounds(entry.typeVar, newLowerBound, newUpperBound, entry.retainLiterals);
}
});
});
}

function stripLiteralsForLowerBound(evaluator: TypeEvaluator, typeVar: TypeVarType, lowerBound: Type) {
return isTypeVarTuple(typeVar)
? stripLiteralValueForUnpackedTuple(evaluator, lowerBound)
Expand Down Expand Up @@ -1307,25 +1343,6 @@ function assignParamSpec(
return isAssignable;
}

class UnificationVarTransformer extends TypeVarTransformer {
constructor(private _evaluator: TypeEvaluator, private _constraintSet: ConstraintSet) {
super();
}

override transformTypeVar(typeVar: TypeVarType) {
if (TypeVarType.isUnification(typeVar)) {
return getTypeVarType(this._evaluator, this._constraintSet, typeVar) ?? typeVar;
}

return undefined;
}
}

function applyUnificationVarsToType(evaluator: TypeEvaluator, type: Type, constraintSet: ConstraintSet): Type {
const transformer = new UnificationVarTransformer(evaluator, constraintSet);
return transformer.apply(type, 0);
}

// For normal TypeVars, the constraint solver can widen a type by combining
// two otherwise incompatible types into a union. For TypeVarTuples, we need
// to do the equivalent operation for unpacked tuples.
Expand Down
30 changes: 10 additions & 20 deletions packages/pyright-internal/src/analyzer/constructors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -597,17 +597,12 @@ function applyExpectedSubtypeForConstructor(
expectedSubtype: Type,
constraints: ConstraintTracker
): Type | undefined {
const specializedType = evaluator.solveAndApplyConstraints(
ClassType.cloneAsInstance(type),
constraints,
{
replaceUnsolved: {
scopeIds: [],
tupleClassType: evaluator.getTupleClassType(),
},
const specializedType = evaluator.solveAndApplyConstraints(ClassType.cloneAsInstance(type), constraints, {
replaceUnsolved: {
scopeIds: [],
tupleClassType: evaluator.getTupleClassType(),
},
{ applyUnificationVars: true }
);
});

if (!evaluator.assignType(expectedSubtype, specializedType)) {
return undefined;
Expand All @@ -634,17 +629,12 @@ function applyExpectedTypeForConstructor(
// If this isn't a generic type or it's a type that has already been
// explicitly specialized, the expected type isn't applicable.
if (type.shared.typeParams.length === 0 || type.priv.typeArgs) {
return evaluator.solveAndApplyConstraints(
ClassType.cloneAsInstance(type),
constraints,
{
replaceUnsolved: {
scopeIds: [],
tupleClassType: evaluator.getTupleClassType(),
},
return evaluator.solveAndApplyConstraints(ClassType.cloneAsInstance(type), constraints, {
replaceUnsolved: {
scopeIds: [],
tupleClassType: evaluator.getTupleClassType(),
},
{ applyUnificationVars: true }
);
});
}

if (inferenceContext) {
Expand Down
54 changes: 17 additions & 37 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ import {
import {
addConstraintsForExpectedType,
applySourceSolutionToConstraints,
applyUnificationVars,
assignTypeVar,
solveConstraintSet,
solveConstraints,
Expand Down Expand Up @@ -2095,10 +2094,6 @@ export function createTypeEvaluator(
applyOptions?: ApplyTypeVarOptions,
solveOptions?: SolveConstraintsOptions
): Type {
if (solveOptions?.applyUnificationVars) {
applyUnificationVars(evaluatorInterface, constraints);
}

const solution = solveConstraints(evaluatorInterface, constraints, solveOptions);
return applySolvedTypeVars(type, solution, applyOptions);
}
Expand Down Expand Up @@ -11605,19 +11600,14 @@ export function createTypeEvaluator(
eliminateUnsolvedInUnions = false;
}

let specializedReturnType = solveAndApplyConstraints(
returnType,
constraints,
{
replaceUnsolved: {
scopeIds: getTypeVarScopeIds(type),
unsolvedExemptTypeVars: getUnknownExemptTypeVarsForReturnType(type, returnType),
tupleClassType: getTupleClassType(),
eliminateUnsolvedInUnions,
},
let specializedReturnType = solveAndApplyConstraints(returnType, constraints, {
replaceUnsolved: {
scopeIds: getTypeVarScopeIds(type),
unsolvedExemptTypeVars: getUnknownExemptTypeVarsForReturnType(type, returnType),
tupleClassType: getTupleClassType(),
eliminateUnsolvedInUnions,
},
{ applyUnificationVars: true }
);
});
specializedReturnType = addConditionToType(specializedReturnType, typeCondition);

// If the function includes a ParamSpec and the captured signature(s) includes
Expand Down Expand Up @@ -14080,17 +14070,12 @@ export function createTypeEvaluator(
}

return mapSubtypes(
solveAndApplyConstraints(
inferenceContext.expectedType,
constraints,
{
replaceUnsolved: {
scopeIds: [],
tupleClassType: getTupleClassType(),
},
solveAndApplyConstraints(inferenceContext.expectedType, constraints, {
replaceUnsolved: {
scopeIds: [],
tupleClassType: getTupleClassType(),
},
{ applyUnificationVars: true }
),
}),
(subtype) => {
if (entryTypes.length !== 1) {
return subtype;
Expand Down Expand Up @@ -14384,17 +14369,12 @@ export function createTypeEvaluator(
if (
assignType(expectedReturnType, returnTypeResult.type, /* diag */ undefined, constraints)
) {
functionType = solveAndApplyConstraints(
functionType,
constraints,
{
replaceUnsolved: {
scopeIds: [],
tupleClassType: getTupleClassType(),
},
functionType = solveAndApplyConstraints(functionType, constraints, {
replaceUnsolved: {
scopeIds: [],
tupleClassType: getTupleClassType(),
},
{ applyUnificationVars: true }
) as FunctionType;
}) as FunctionType;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,6 @@ export interface ClassMemberLookup {
}

export interface SolveConstraintsOptions {
applyUnificationVars?: boolean;
useLowerBoundOnly?: boolean;
}

Expand Down

0 comments on commit 35ab773

Please sign in to comment.