diff --git a/packages/safe-ds-lang/src/language/partialEvaluation/model.ts b/packages/safe-ds-lang/src/language/partialEvaluation/model.ts index 122befc25..70ab9e587 100644 --- a/packages/safe-ds-lang/src/language/partialEvaluation/model.ts +++ b/packages/safe-ds-lang/src/language/partialEvaluation/model.ts @@ -1,17 +1,20 @@ -import { stream } from 'langium'; +import { type NamedAstNode, stream } from 'langium'; import { isEmpty } from '../../helpers/collectionUtils.js'; import { isSdsAbstractResult, - SdsAbstractResult, - SdsBlockLambdaResult, + type SdsAbstractResult, + type SdsBlockLambda, + type SdsBlockLambdaResult, + type SdsCallable, type SdsDeclaration, - SdsEnumVariant, - SdsExpression, - SdsParameter, - SdsReference, - SdsResult, + type SdsEnumVariant, + type SdsExpression, + type SdsExpressionLambda, + type SdsLambda, + type SdsParameter, + type SdsReference, } from '../generated/ast.js'; -import { getParameters } from '../helpers/nodeProperties.js'; +import { getParameters, streamBlockLambdaResults } from '../helpers/nodeProperties.js'; export type ParameterSubstitutions = Map; export type ResultSubstitutions = Map; @@ -139,20 +142,27 @@ export const isConstant = (node: EvaluatedNode): node is Constant => { }; // ------------------------------------------------------------------------------------------------- -// Closures +// Callables // ------------------------------------------------------------------------------------------------- -export abstract class Closure extends EvaluatedNode { +export abstract class EvaluatedCallable extends EvaluatedNode { + abstract readonly callable: T; override readonly isFullyEvaluated: boolean = false; +} + +export abstract class Closure extends EvaluatedCallable { abstract readonly substitutionsOnCreation: ParameterSubstitutions; } -export class BlockLambdaClosure extends Closure { +export class BlockLambdaClosure extends Closure { + readonly results: SdsBlockLambdaResult[]; + constructor( + override readonly callable: SdsBlockLambda, override readonly substitutionsOnCreation: ParameterSubstitutions, - readonly results: SdsBlockLambdaResult[], ) { super(); + this.results = streamBlockLambdaResults(callable).toArray(); } override equals(other: unknown): boolean { @@ -163,9 +173,8 @@ export class BlockLambdaClosure extends Closure { } return ( - this.results.length === other.results.length && - substitutionsAreEqual(this.substitutionsOnCreation, other.substitutionsOnCreation) && - this.results.every((thisResult, i) => thisResult === other.results[i]) + this.callable === other.callable && + substitutionsAreEqual(this.substitutionsOnCreation, other.substitutionsOnCreation) ); } @@ -174,12 +183,15 @@ export class BlockLambdaClosure extends Closure { } } -export class ExpressionLambdaClosure extends Closure { +export class ExpressionLambdaClosure extends Closure { + readonly result: SdsExpression; + constructor( + override readonly callable: SdsExpressionLambda, override readonly substitutionsOnCreation: ParameterSubstitutions, - readonly result: SdsExpression, ) { super(); + this.result = callable.result; } override equals(other: unknown): boolean { @@ -190,7 +202,7 @@ export class ExpressionLambdaClosure extends Closure { } return ( - this.result === other.result && + this.callable === other.callable && substitutionsAreEqual(this.substitutionsOnCreation, other.substitutionsOnCreation) ); } @@ -200,28 +212,19 @@ export class ExpressionLambdaClosure extends Closure { } } -export class SegmentClosure extends Closure { - override readonly substitutionsOnCreation = new Map(); +export class NamedCallable extends EvaluatedCallable { + override readonly isFullyEvaluated: boolean = false; - constructor(readonly results: SdsResult[]) { + constructor(override readonly callable: T) { super(); } override equals(other: unknown): boolean { - if (other === this) { - return true; - } else if (!(other instanceof SegmentClosure)) { - return false; - } - - return ( - this.results.length === other.results.length && - this.results.every((thisResult, i) => thisResult === other.results[i]) - ); + return other instanceof NamedCallable && this.callable === other.callable; } override toString(): string { - return `$SegmentClosure`; + return this.callable.name; } } diff --git a/packages/safe-ds-lang/src/language/partialEvaluation/safe-ds-partial-evaluator.ts b/packages/safe-ds-lang/src/language/partialEvaluation/safe-ds-partial-evaluator.ts index 99a006141..1c63bb134 100644 --- a/packages/safe-ds-lang/src/language/partialEvaluation/safe-ds-partial-evaluator.ts +++ b/packages/safe-ds-lang/src/language/partialEvaluation/safe-ds-partial-evaluator.ts @@ -19,15 +19,14 @@ import { isSdsParenthesizedExpression, isSdsPrefixOperation, isSdsReference, + isSdsSegment, isSdsString, isSdsTemplateString, isSdsTemplateStringEnd, isSdsTemplateStringInner, isSdsTemplateStringStart, - SdsBlockLambda, SdsCall, SdsExpression, - SdsExpressionLambda, SdsIndexedAccess, SdsInfixOperation, SdsList, @@ -41,15 +40,18 @@ import { getArguments, getParameters } from '../helpers/nodeProperties.js'; import { SafeDsNodeMapper } from '../helpers/safe-ds-node-mapper.js'; import { SafeDsServices } from '../safe-ds-module.js'; import { + BlockLambdaClosure, BooleanConstant, EvaluatedEnumVariant, EvaluatedList, EvaluatedMap, EvaluatedMapEntry, EvaluatedNode, + ExpressionLambdaClosure, FloatConstant, IntConstant, isConstant, + NamedCallable, NullConstant, NumberConstant, ParameterSubstitutions, @@ -71,12 +73,11 @@ export class SafeDsPartialEvaluator { } evaluate(node: AstNode | undefined): EvaluatedNode { - return this.cachedDoEvaluate(node, NO_SUBSTITUTIONS)?.unwrap(); + return this.evaluateWithSubstitutions(node, NO_SUBSTITUTIONS)?.unwrap(); } - private cachedDoEvaluate(node: AstNode | undefined, substitutions: ParameterSubstitutions): EvaluatedNode { - // Only expressions can be evaluated at the moment - if (!isSdsExpression(node)) { + private evaluateWithSubstitutions(node: AstNode | undefined, substitutions: ParameterSubstitutions): EvaluatedNode { + if (!node) { return UnknownEvaluatedNode; } @@ -84,16 +85,31 @@ export class SafeDsPartialEvaluator { const documentUri = getDocument(node).uri.toString(); const nodePath = this.astNodeLocator.getAstNodePath(node); const key = `${documentUri}~${nodePath}`; - const resultWithoutSubstitutions = this.cache.get(key, () => this.doEvaluate(node, NO_SUBSTITUTIONS)); + const resultWithoutSubstitutions = this.cache.get(key, () => + this.doEvaluateWithSubstitutions(node, NO_SUBSTITUTIONS), + ); if (resultWithoutSubstitutions.isFullyEvaluated || isEmpty(substitutions)) { return resultWithoutSubstitutions; } /* c8 ignore start */ else { // Try again with parameter substitutions but don't cache the result - return this.doEvaluate(node, substitutions); + return this.doEvaluateWithSubstitutions(node, substitutions); } /* c8 ignore stop */ } - private doEvaluate(node: SdsExpression, substitutions: ParameterSubstitutions): EvaluatedNode { + private doEvaluateWithSubstitutions( + node: AstNode | undefined, + substitutions: ParameterSubstitutions, + ): EvaluatedNode { + if (isSdsExpression(node)) { + return this.evaluateExpression(node, substitutions); + } else if (isSdsSegment(node)) { + return new NamedCallable(node); + } else { + return UnknownEvaluatedNode; + } + } + + private evaluateExpression(node: SdsExpression, substitutions: ParameterSubstitutions): EvaluatedNode { // Base cases if (isSdsBoolean(node)) { return new BooleanConstant(node.value); @@ -112,14 +128,14 @@ export class SafeDsPartialEvaluator { } else if (isSdsTemplateStringEnd(node)) { return new StringConstant(node.value); } else if (isSdsBlockLambda(node)) { - return this.evaluateBlockLambda(node, substitutions); + return new BlockLambdaClosure(node, substitutions); } else if (isSdsExpressionLambda(node)) { - return this.evaluateExpressionLambda(node, substitutions); + return new ExpressionLambdaClosure(node, substitutions); } // Recursive cases else if (isSdsArgument(node)) { - return this.cachedDoEvaluate(node.value, substitutions); + return this.evaluateWithSubstitutions(node.value, substitutions); } else if (isSdsCall(node)) { return this.evaluateCall(node, substitutions); } else if (isSdsIndexedAccess(node)) { @@ -133,7 +149,7 @@ export class SafeDsPartialEvaluator { } else if (isSdsMemberAccess(node)) { return this.evaluateMemberAccess(node, substitutions); } else if (isSdsParenthesizedExpression(node)) { - return this.cachedDoEvaluate(node.expression, substitutions); + return this.evaluateWithSubstitutions(node.expression, substitutions); } else if (isSdsPrefixOperation(node)) { return this.evaluatePrefixOperation(node, substitutions); } else if (isSdsReference(node)) { @@ -141,45 +157,18 @@ export class SafeDsPartialEvaluator { } else if (isSdsTemplateString(node)) { return this.evaluateTemplateString(node, substitutions); } /* c8 ignore start */ else { - throw new Error(`Unexpected node type: ${node.$type}`); + throw new Error(`Unexpected expression type: ${node.$type}`); } /* c8 ignore stop */ } - private evaluateBlockLambda(_node: SdsBlockLambda, _substitutions: ParameterSubstitutions): EvaluatedNode { - // return when { - // callableHasNoSideEffects(resultIfUnknown = true) -> SdsIntermediateBlockLambda( - // parameters = parametersOrEmpty(), - // results = blockLambdaResultsOrEmpty(), - // substitutionsOnCreation = substitutions - // ) - // else -> undefined - // } - return UnknownEvaluatedNode; - } - - private evaluateExpressionLambda( - _node: SdsExpressionLambda, - _substitutions: ParameterSubstitutions, - ): EvaluatedNode { - // return when { - // callableHasNoSideEffects(resultIfUnknown = true) -> SdsIntermediateExpressionLambda( - // parameters = parametersOrEmpty(), - // result = result, - // substitutionsOnCreation = substitutions - // ) - // else -> undefined - // } - return UnknownEvaluatedNode; - } - private evaluateInfixOperation(node: SdsInfixOperation, substitutions: ParameterSubstitutions): EvaluatedNode { // By design none of the operators are short-circuited - const evaluatedLeft = this.cachedDoEvaluate(node.leftOperand, substitutions); + const evaluatedLeft = this.evaluateWithSubstitutions(node.leftOperand, substitutions); if (evaluatedLeft === UnknownEvaluatedNode) { return UnknownEvaluatedNode; } - const evaluatedRight = this.cachedDoEvaluate(node.rightOperand, substitutions); + const evaluatedRight = this.evaluateWithSubstitutions(node.rightOperand, substitutions); if (evaluatedRight === UnknownEvaluatedNode) { return UnknownEvaluatedNode; } @@ -314,22 +303,22 @@ export class SafeDsPartialEvaluator { private evaluateList(node: SdsList, substitutions: ParameterSubstitutions): EvaluatedNode { // TODO: if any entry has side effects, return UnknownEvaluatedNode - return new EvaluatedList(node.elements.map((it) => this.cachedDoEvaluate(it, substitutions))); + return new EvaluatedList(node.elements.map((it) => this.evaluateWithSubstitutions(it, substitutions))); } private evaluateMap(node: SdsMap, substitutions: ParameterSubstitutions): EvaluatedNode { // TODO: if any entry has side effects, return UnknownEvaluatedNode return new EvaluatedMap( node.entries.map((it) => { - const key = this.cachedDoEvaluate(it.key, substitutions); - const value = this.cachedDoEvaluate(it.value, substitutions); + const key = this.evaluateWithSubstitutions(it.key, substitutions); + const value = this.evaluateWithSubstitutions(it.value, substitutions); return new EvaluatedMapEntry(key, value); }), ); } private evaluatePrefixOperation(node: SdsPrefixOperation, substitutions: ParameterSubstitutions): EvaluatedNode { - const evaluatedOperand = this.cachedDoEvaluate(node.operand, substitutions); + const evaluatedOperand = this.evaluateWithSubstitutions(node.operand, substitutions); if (evaluatedOperand === UnknownEvaluatedNode) { return UnknownEvaluatedNode; } @@ -350,7 +339,7 @@ export class SafeDsPartialEvaluator { } private evaluateTemplateString(node: SdsTemplateString, substitutions: ParameterSubstitutions): EvaluatedNode { - const expressions = node.expressions.map((it) => this.cachedDoEvaluate(it, substitutions)); + const expressions = node.expressions.map((it) => this.evaluateWithSubstitutions(it, substitutions)); if (expressions.every(isConstant)) { return new StringConstant(expressions.map((it) => it.toInterpolationString()).join('')); } @@ -359,7 +348,7 @@ export class SafeDsPartialEvaluator { } private evaluateCall(node: SdsCall, substitutions: ParameterSubstitutions): EvaluatedNode { - const receiver = this.cachedDoEvaluate(node.receiver, substitutions).unwrap(); + const receiver = this.evaluateWithSubstitutions(node.receiver, substitutions).unwrap(); if (receiver instanceof EvaluatedEnumVariant) { // The enum variant has already been instantiated @@ -370,15 +359,17 @@ export class SafeDsPartialEvaluator { // Store default values for all parameters const args = new Map( getParameters(receiver.variant).map((it) => { - return [it, this.cachedDoEvaluate(it.defaultValue, NO_SUBSTITUTIONS)]; + // TODO We may refer to other parameters in the default value, so we must pass substitutions + return [it, this.evaluateWithSubstitutions(it.defaultValue, NO_SUBSTITUTIONS)]; }), ); // Override default values with the actual arguments + // TODO: If any argument has side effects, return UnknownEvaluatedNode getArguments(node).forEach((it) => { const parameter = this.nodeMapper.argumentToParameter(it); if (parameter && args.has(parameter)) { - args.set(parameter, this.cachedDoEvaluate(it.value, substitutions)); + args.set(parameter, this.evaluateWithSubstitutions(it.value, substitutions)); } }); @@ -387,7 +378,7 @@ export class SafeDsPartialEvaluator { // val simpleReceiver = evaluateReceiver(substitutions) ?: return undefined // val newSubstitutions = buildNewSubstitutions(simpleReceiver, substitutions) - // + // // TODO Also check whether the callable has side effects // return when (simpleReceiver) { // is SdsIntermediateBlockLambda -> { // SdsIntermediateRecord( @@ -442,15 +433,15 @@ export class SafeDsPartialEvaluator { // } private evaluateIndexedAccess(node: SdsIndexedAccess, substitutions: ParameterSubstitutions): EvaluatedNode { - const receiver = this.cachedDoEvaluate(node.receiver, substitutions).unwrap(); + const receiver = this.evaluateWithSubstitutions(node.receiver, substitutions).unwrap(); if (receiver instanceof EvaluatedList) { - const index = this.cachedDoEvaluate(node.index, substitutions).unwrap(); + const index = this.evaluateWithSubstitutions(node.index, substitutions).unwrap(); if (index instanceof IntConstant) { return receiver.getElementByIndex(Number(index.value)); } } else if (receiver instanceof EvaluatedMap) { - const key = this.cachedDoEvaluate(node.index, substitutions).unwrap(); + const key = this.evaluateWithSubstitutions(node.index, substitutions).unwrap(); return receiver.getLastValueForKey(key); } @@ -475,11 +466,11 @@ export class SafeDsPartialEvaluator { } private evaluateReference(_node: SdsReference, _substitutions: ParameterSubstitutions): EvaluatedNode { + // TODO: always call evaluateWithSubstitutions so caching works // const target = node.target.ref; // is SdsPlaceholder -> declaration.evaluateAssignee(substitutions) // is SdsParameter -> declaration.evaluateParameter(substitutions) - // is SdsStep -> declaration.evaluateStep() // else -> undefined // } return UnknownEvaluatedNode; @@ -507,16 +498,6 @@ export class SafeDsPartialEvaluator { // else -> undefined // } // } - // - // private fun SdsStep.evaluateStep(): SdsIntermediateStep? { - // return when { - // callableHasNoSideEffects(resultIfUnknown = true) -> SdsIntermediateStep( - // parameters = parametersOrEmpty(), - // results = resultsOrEmpty() - // ) - // else -> undefined - // } - // } } const NO_SUBSTITUTIONS: ParameterSubstitutions = new Map(); diff --git a/packages/safe-ds-lang/tests/language/partialEvaluation/model.test.ts b/packages/safe-ds-lang/tests/language/partialEvaluation/model.test.ts index 32a77ade8..9f91824d7 100644 --- a/packages/safe-ds-lang/tests/language/partialEvaluation/model.test.ts +++ b/packages/safe-ds-lang/tests/language/partialEvaluation/model.test.ts @@ -1,10 +1,12 @@ import { EmptyFileSystem } from 'langium'; import { describe, expect, it } from 'vitest'; import { + isSdsBlockLambda, isSdsEnumVariant, isSdsExpressionLambda, isSdsReference, isSdsResult, + isSdsSegment, type SdsBlockLambdaResult, } from '../../../src/language/generated/ast.js'; import { getParameters } from '../../../src/language/helpers/nodeProperties.js'; @@ -22,8 +24,8 @@ import { ExpressionLambdaClosure, FloatConstant, IntConstant, + NamedCallable, NullConstant, - SegmentClosure, StringConstant, UnknownEvaluatedNode, } from '../../../src/language/partialEvaluation/model.js'; @@ -42,18 +44,28 @@ segment mySegment() -> (result1: Int, result2: Int) { (() { yield a; })().a; MyEnum; } + +segment mySegment2() {} `; const enumVariantWithoutParameters = await getNodeOfType(services, code, isSdsEnumVariant, 0); const enumVariantWithParameters = await getNodeOfType(services, code, isSdsEnumVariant, 1); +const enumVariantParameter = getParameters(enumVariantWithParameters)[0]!; + const result1 = await getNodeOfType(services, code, isSdsResult, 0); const result2 = await getNodeOfType(services, code, isSdsResult, 0); -const enumVariantParameter = getParameters(enumVariantWithParameters)[0]!; -const expressionLambdaResult1 = (await getNodeOfType(services, code, isSdsExpressionLambda, 0)).result; -const expressionLambdaResult2 = (await getNodeOfType(services, code, isSdsExpressionLambda, 1)).result; + +const expressionLambda1 = await getNodeOfType(services, code, isSdsExpressionLambda, 0); +const expressionLambda2 = await getNodeOfType(services, code, isSdsExpressionLambda, 1); + const reference1 = await getNodeOfType(services, code, isSdsReference, 0); const reference2 = await getNodeOfType(services, code, isSdsReference, 1); + +const blockLambda1 = await getNodeOfType(services, code, isSdsBlockLambda, 0); const blockLambdaResult1 = reference1.target.ref as SdsBlockLambdaResult; +const segment1 = await getNodeOfType(services, code, isSdsSegment, 0); +const segment2 = await getNodeOfType(services, code, isSdsSegment, 1); + describe('partial evaluation model', async () => { const equalsTests: EqualsTest[] = [ { @@ -81,19 +93,18 @@ describe('partial evaluation model', async () => { nodeOfOtherType: () => NullConstant, }, { - node: () => new BlockLambdaClosure(new Map([[enumVariantParameter, NullConstant]]), []), - unequalNodeOfSameType: () => new BlockLambdaClosure(new Map(), []), + node: () => new BlockLambdaClosure(blockLambda1, new Map([[enumVariantParameter, NullConstant]])), + unequalNodeOfSameType: () => new BlockLambdaClosure(blockLambda1, new Map()), nodeOfOtherType: () => NullConstant, }, { - node: () => - new ExpressionLambdaClosure(new Map([[enumVariantParameter, NullConstant]]), expressionLambdaResult1), - unequalNodeOfSameType: () => new ExpressionLambdaClosure(new Map(), expressionLambdaResult2), + node: () => new ExpressionLambdaClosure(expressionLambda1, new Map([[enumVariantParameter, NullConstant]])), + unequalNodeOfSameType: () => new ExpressionLambdaClosure(expressionLambda2, new Map()), nodeOfOtherType: () => NullConstant, }, { - node: () => new SegmentClosure([result1]), - unequalNodeOfSameType: () => new SegmentClosure([]), + node: () => new NamedCallable(segment1), + unequalNodeOfSameType: () => new NamedCallable(segment2), nodeOfOtherType: () => NullConstant, }, { @@ -171,16 +182,16 @@ describe('partial evaluation model', async () => { expectedString: '"foo"', }, { - node: new BlockLambdaClosure(new Map(), []), + node: new BlockLambdaClosure(blockLambda1, new Map()), expectedString: '$BlockLambdaClosure', }, { - node: new ExpressionLambdaClosure(new Map(), expressionLambdaResult1), + node: new ExpressionLambdaClosure(expressionLambda1, new Map()), expectedString: '$ExpressionLambdaClosure', }, { - node: new SegmentClosure([]), - expectedString: '$SegmentClosure', + node: new NamedCallable(segment1), + expectedString: 'mySegment', }, { node: new EvaluatedEnumVariant(enumVariantWithoutParameters, undefined), @@ -267,15 +278,15 @@ describe('partial evaluation model', async () => { expectedValue: true, }, { - node: new BlockLambdaClosure(new Map(), []), + node: new BlockLambdaClosure(blockLambda1, new Map()), expectedValue: false, }, { - node: new ExpressionLambdaClosure(new Map(), expressionLambdaResult1), + node: new ExpressionLambdaClosure(expressionLambda1, new Map()), expectedValue: false, }, { - node: new SegmentClosure([]), + node: new NamedCallable(segment1), expectedValue: false, }, { diff --git a/packages/safe-ds-lang/tests/resources/partial evaluation/base cases/block lambdas/main.sdstest b/packages/safe-ds-lang/tests/resources/partial evaluation/base cases/block lambdas/main.sdstest new file mode 100644 index 000000000..e00137974 --- /dev/null +++ b/packages/safe-ds-lang/tests/resources/partial evaluation/base cases/block lambdas/main.sdstest @@ -0,0 +1,6 @@ +package tests.partialValidation.baseCases.expressionLambdas + +pipeline test { + // $TEST$ serialization $ExpressionLambdaClosure + »() -> 1«; +} diff --git a/packages/safe-ds-lang/tests/resources/partial evaluation/base cases/expression lambdas/main.sdstest b/packages/safe-ds-lang/tests/resources/partial evaluation/base cases/expression lambdas/main.sdstest new file mode 100644 index 000000000..8b3f1fdf6 --- /dev/null +++ b/packages/safe-ds-lang/tests/resources/partial evaluation/base cases/expression lambdas/main.sdstest @@ -0,0 +1,6 @@ +package tests.partialValidation.baseCases.blockLambdas + +pipeline test { + // $TEST$ serialization $BlockLambdaClosure + »() {}«; +} diff --git a/packages/safe-ds-lang/tests/resources/partial evaluation/base cases/segments/main.sdstest b/packages/safe-ds-lang/tests/resources/partial evaluation/base cases/segments/main.sdstest new file mode 100644 index 000000000..fc5b47348 --- /dev/null +++ b/packages/safe-ds-lang/tests/resources/partial evaluation/base cases/segments/main.sdstest @@ -0,0 +1,4 @@ +package tests.partialValidation.baseCases.segments + +// $TEST$ serialization mySegment +»segment mySegment() {}« diff --git a/packages/safe-ds-lang/tests/resources/partial evaluation/invalid nodes/main.sdstest b/packages/safe-ds-lang/tests/resources/partial evaluation/invalid nodes/main.sdstest new file mode 100644 index 000000000..009268ccc --- /dev/null +++ b/packages/safe-ds-lang/tests/resources/partial evaluation/invalid nodes/main.sdstest @@ -0,0 +1,4 @@ +package tests.partialValidation.invalidNodes + +// $TEST$ serialization ? +»class C«