diff --git a/packages/safe-ds-lang/src/language/typing/safe-ds-type-checker.ts b/packages/safe-ds-lang/src/language/typing/safe-ds-type-checker.ts index 3e065a724..7aea9d873 100644 --- a/packages/safe-ds-lang/src/language/typing/safe-ds-type-checker.ts +++ b/packages/safe-ds-lang/src/language/typing/safe-ds-type-checker.ts @@ -1,6 +1,7 @@ import { getContainerOfType } from 'langium'; import type { SafeDsClasses } from '../builtins/safe-ds-classes.js'; -import { isSdsEnum, SdsDeclaration } from '../generated/ast.js'; +import { isSdsEnum, type SdsAbstractResult, SdsDeclaration } from '../generated/ast.js'; +import { getParameters } from '../helpers/nodeProperties.js'; import { BooleanConstant, Constant, @@ -16,6 +17,7 @@ import { EnumType, EnumVariantType, LiteralType, + NamedTupleEntry, NamedTupleType, StaticType, Type, @@ -24,16 +26,19 @@ import { } from './model.js'; import { SafeDsClassHierarchy } from './safe-ds-class-hierarchy.js'; import { SafeDsCoreTypes } from './safe-ds-core-types.js'; +import type { SafeDsTypeComputer } from './safe-ds-type-computer.js'; export class SafeDsTypeChecker { private readonly builtinClasses: SafeDsClasses; private readonly classHierarchy: SafeDsClassHierarchy; private readonly coreTypes: SafeDsCoreTypes; + private readonly typeComputer: () => SafeDsTypeComputer; constructor(services: SafeDsServices) { this.builtinClasses = services.builtins.Classes; this.classHierarchy = services.types.ClassHierarchy; this.coreTypes = services.types.CoreTypes; + this.typeComputer = () => services.types.TypeComputer; } /** @@ -218,8 +223,48 @@ export class SafeDsTypeChecker { } } - private staticTypeIsAssignableTo(type: Type, other: Type): boolean { - return type.equals(other); + private staticTypeIsAssignableTo(type: StaticType, other: Type): boolean { + if (other instanceof CallableType) { + return this.isAssignableTo(this.associatedCallableTypeForStaticType(type), other); + } else { + return type.equals(other); + } + } + + private associatedCallableTypeForStaticType(type: StaticType): Type { + const instanceType = type.instanceType; + if (instanceType instanceof ClassType) { + const declaration = instanceType.declaration; + if (!declaration.parameterList) { + return UnknownType; + } + + const parameterEntries = new NamedTupleType( + ...getParameters(declaration).map( + (it) => new NamedTupleEntry(it, it.name, this.typeComputer().computeType(it)), + ), + ); + const resultEntries = new NamedTupleType( + new NamedTupleEntry(undefined, 'instance', instanceType), + ); + + return new CallableType(declaration, parameterEntries, resultEntries); + } else if (instanceType instanceof EnumVariantType) { + const declaration = instanceType.declaration; + + const parameterEntries = new NamedTupleType( + ...getParameters(declaration).map( + (it) => new NamedTupleEntry(it, it.name, this.typeComputer().computeType(it)), + ), + ); + const resultEntries = new NamedTupleType( + new NamedTupleEntry(undefined, 'instance', instanceType), + ); + + return new CallableType(declaration, parameterEntries, resultEntries); + } else { + return UnknownType; + } } private unionTypeIsAssignableTo(type: UnionType, other: Type): boolean { diff --git a/packages/safe-ds-lang/tests/language/typing/safe-ds-type-checker.test.ts b/packages/safe-ds-lang/tests/language/typing/safe-ds-type-checker.test.ts index d1e8e5b6b..8fd522e2b 100644 --- a/packages/safe-ds-lang/tests/language/typing/safe-ds-type-checker.test.ts +++ b/packages/safe-ds-lang/tests/language/typing/safe-ds-type-checker.test.ts @@ -19,6 +19,8 @@ import { } from '../../../src/language/partialEvaluation/model.js'; import { ClassType, + EnumType, + EnumVariantType, LiteralType, NamedTupleEntry, NamedTupleType, @@ -45,13 +47,15 @@ const code = ` fun func8() -> (s: Int) fun func9() -> (r: Any) fun func10() -> (r: String) + fun func11() -> (r: Class1) + fun func12() -> (r: Enum1) - class Class1 - class Class2 sub Class1 + class Class1(p: Int) + class Class2() sub Class1 class Class3 enum Enum1 { - Variant1 + Variant1(p: Int) Variant2 } enum Enum2 @@ -68,6 +72,8 @@ const callableType7 = typeComputer.computeType(functions[6]); const callableType8 = typeComputer.computeType(functions[7]); const callableType9 = typeComputer.computeType(functions[8]); const callableType10 = typeComputer.computeType(functions[9]); +const callableType11 = typeComputer.computeType(functions[10]); +const callableType12 = typeComputer.computeType(functions[11]); const classes = getModuleMembers(module).filter(isSdsClass); const class1 = classes[0]; @@ -80,14 +86,14 @@ const classType3 = typeComputer.computeType(class3) as ClassType; const enums = getModuleMembers(module).filter(isSdsEnum); const enum1 = enums[0]; const enum2 = enums[1]; -const enumType1 = typeComputer.computeType(enum1); -const enumType2 = typeComputer.computeType(enum2); +const enumType1 = typeComputer.computeType(enum1) as EnumType; +const enumType2 = typeComputer.computeType(enum2) as EnumType; const enumVariants = streamAllContents(module).filter(isSdsEnumVariant).toArray(); const enumVariant1 = enumVariants[0]; const enumVariant2 = enumVariants[1]; -const enumVariantType1 = typeComputer.computeType(enumVariant1); -const enumVariantType2 = typeComputer.computeType(enumVariant2); +const enumVariantType1 = typeComputer.computeType(enumVariant1) as EnumVariantType; +const enumVariantType2 = typeComputer.computeType(enumVariant2) as EnumVariantType; describe('SafeDsTypeChecker', async () => { const testCases: IsAssignableToTest[] = [ @@ -521,6 +527,47 @@ describe('SafeDsTypeChecker', async () => { type2: enumType1, expected: false, }, + // Static type to callable type + { + type1: new StaticType(classType1), + type2: callableType1, + expected: false, + }, + { + type1: new StaticType(classType1), + type2: callableType3, + expected: true, + }, + { + type1: new StaticType(classType2), + type2: callableType11, + expected: true, + }, + { + type1: new StaticType(classType3), + type2: callableType1, + expected: false, + }, + { + type1: new StaticType(enumType1), + type2: callableType1, + expected: false, + }, + { + type1: new StaticType(enumVariantType1), + type2: callableType1, + expected: false, + }, + { + type1: new StaticType(enumVariantType1), + type2: callableType3, + expected: true, + }, + { + type1: new StaticType(enumVariantType2), + type2: callableType12, + expected: true, + }, // Static type to static type { type1: new StaticType(classType1),