From 114fee68b57e12371d574d0a5f8944def7e61813 Mon Sep 17 00:00:00 2001 From: WinPlay02 Date: Thu, 7 Mar 2024 21:25:51 +0100 Subject: [PATCH] fix: do not memoize calls containing lambdas calling segments (#944) - do not memoize calls, if lambdas are referencing segment code to ensure correctness --------- Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com> Co-authored-by: Lars Reimann --- .../generation/safe-ds-python-generator.ts | 34 +++++++++++++++++-- .../callWithRunnerIntegration/gen_input.py | 23 +++++++++++++ .../gen_input.py.map | 2 +- .../expressions/call/input.sdstest | 18 ++++++++++ 4 files changed, 74 insertions(+), 3 deletions(-) diff --git a/packages/safe-ds-lang/src/language/generation/safe-ds-python-generator.ts b/packages/safe-ds-lang/src/language/generation/safe-ds-python-generator.ts index b283b12e3..774b4364d 100644 --- a/packages/safe-ds-lang/src/language/generation/safe-ds-python-generator.ts +++ b/packages/safe-ds-lang/src/language/generation/safe-ds-python-generator.ts @@ -1,4 +1,4 @@ -import { AstUtils, LangiumDocument, TreeStreamImpl, URI } from 'langium'; +import { AstNode, AstUtils, LangiumDocument, TreeStreamImpl, URI } from 'langium'; import { CompositeGeneratorNode, expandToNode, @@ -16,6 +16,7 @@ import { TextDocument } from 'vscode-languageserver-textdocument'; import { groupBy, isEmpty } from '../../helpers/collections.js'; import { SafeDsAnnotations } from '../builtins/safe-ds-annotations.js'; import { + isSdsAbstractCall, isSdsAbstractResult, isSdsAssignment, isSdsBlockLambda, @@ -968,7 +969,36 @@ export class SafeDsPythonGenerator { private isMemoizableCall(expression: SdsCall): boolean { const impurityReasons = this.purityComputer.getImpurityReasonsForExpression(expression); // If the file is not known, the call is not memoizable - return !impurityReasons.some((reason) => !(reason instanceof FileRead) || reason.path === undefined); + return ( + !impurityReasons.some((reason) => !(reason instanceof FileRead) || reason.path === undefined) && + !this.doesCallContainLambdaReferencingSegment(expression) + ); + } + + private doesCallContainLambdaReferencingSegment(expression: SdsCall): boolean { + return getArguments(expression).some((arg) => { + if (isSdsExpressionLambda(arg.value)) { + return this.containsSegmentCall(arg.value.result); + } else if (isSdsBlockLambda(arg.value)) { + return this.containsSegmentCall(arg.value.body); + } else { + /* c8 ignore next 2 */ + return false; + } + }); + } + + private containsSegmentCall(node: AstNode | undefined): boolean { + if (!node) { + /* c8 ignore next 2 */ + return false; + } + return AstUtils.streamAst(node) + .filter(isSdsAbstractCall) + .some((call) => { + const callable = this.nodeMapper.callToCallable(call); + return isSdsSegment(callable); + }); } private generateMemoizedCall( diff --git a/packages/safe-ds-lang/tests/resources/generation/runner integration/expressions/call/generated/tests/generator/callWithRunnerIntegration/gen_input.py b/packages/safe-ds-lang/tests/resources/generation/runner integration/expressions/call/generated/tests/generator/callWithRunnerIntegration/gen_input.py index 7e2cfd90b..d8d51c43a 100644 --- a/packages/safe-ds-lang/tests/resources/generation/runner integration/expressions/call/generated/tests/generator/callWithRunnerIntegration/gen_input.py +++ b/packages/safe-ds-lang/tests/resources/generation/runner integration/expressions/call/generated/tests/generator/callWithRunnerIntegration/gen_input.py @@ -12,6 +12,12 @@ def __gen_null_safe_call(receiver: Any, callable: Callable[[], __gen_T]) -> __gen_T | None: return callable() if receiver is not None else None +# Segments --------------------------------------------------------------------- + +def segment_a(a): + __gen_yield_result = (a) * (2) + return __gen_yield_result + # Pipelines -------------------------------------------------------------------- def test(): @@ -32,3 +38,20 @@ def test(): __gen_null_safe_call(j, lambda: 'abc'.j(123)) __gen_null_safe_call(k, lambda: k(456, 1.23)) f(safeds_runner.memoized_call("tests.generator.callWithRunnerIntegration.readFile", readFile, [], [safeds_runner.file_mtime('a.txt')])) + f(l(lambda a: segment_a(a))) + f(l(lambda a: (3) * (segment_a(a)))) + f(l(lambda a: safeds_runner.memoized_call("tests.generator.callWithRunnerIntegration.m", m, [(3) * (segment_a(a))], []))) + f(l(lambda a: (3) * (safeds_runner.memoized_call("tests.generator.callWithRunnerIntegration.m", m, [safeds_runner.memoized_call("tests.generator.callWithRunnerIntegration.m", m, [(3) * (segment_a(a))], [])], [])))) + def __gen_block_lambda_0(a): + __gen_block_lambda_result_result = segment_a(a) + return __gen_block_lambda_result_result + f(l(__gen_block_lambda_0)) + def __gen_block_lambda_1(a): + __gen_block_lambda_result_result = safeds_runner.memoized_call("tests.generator.callWithRunnerIntegration.m", m, [segment_a(a)], []) + return __gen_block_lambda_result_result + f(l(__gen_block_lambda_1)) + f(safeds_runner.memoized_call("tests.generator.callWithRunnerIntegration.l", l, [lambda a: (3) * (a)], [])) + def __gen_block_lambda_2(a): + __gen_block_lambda_result_result = (3) * (safeds_runner.memoized_call("tests.generator.callWithRunnerIntegration.m", m, [a], [])) + return __gen_block_lambda_result_result + f(safeds_runner.memoized_call("tests.generator.callWithRunnerIntegration.l", l, [__gen_block_lambda_2], [])) diff --git a/packages/safe-ds-lang/tests/resources/generation/runner integration/expressions/call/generated/tests/generator/callWithRunnerIntegration/gen_input.py.map b/packages/safe-ds-lang/tests/resources/generation/runner integration/expressions/call/generated/tests/generator/callWithRunnerIntegration/gen_input.py.map index 0ad85b9dc..bf0e8287c 100644 --- a/packages/safe-ds-lang/tests/resources/generation/runner integration/expressions/call/generated/tests/generator/callWithRunnerIntegration/gen_input.py.map +++ b/packages/safe-ds-lang/tests/resources/generation/runner integration/expressions/call/generated/tests/generator/callWithRunnerIntegration/gen_input.py.map @@ -1 +1 @@ -{"version":3,"sources":["input.sdstest"],"names":["test","f","g","param2","h","i","j","k","readfile"],"mappings":"AAAA;;;;;;;;;;;;;;;;AA6BA,IAASA,IAAI;IACTC,CAAC,CAAE,uFAAAC,CAAC,CAAC,CAAC,EAxBNC,MAAM,CAwBE,CAAC,IAAJ,CAAC,EAAE,CAAC;IACTF,CAAC,CAAE,uFAAAC,CAAC,CAAsB,CAAC,EAzB3BC,MAAM,CAyBQ,CAAC,IAAW,CAAC,EAAb,CAAC;IACfF,CAAC,CAAE,uFAAAG,CAAC,CAAC,CAAC,EArBiBD,OAAM,CAqBrB,CAAC,IAAJ,CAAC,EAAE,CAAC;IACTF,CAAC,CAAE,uFAAAG,CAAC,CAAsB,CAAC,EAtBJD,OAAM,CAsBf,CAAC,IAAW,CAAC,EAAb,CAAC;IACfF,CAAC,CAAE,2EAAAG,CAAC,GAAU,CAAC,EAvBsB,CAAC;IAwBpC,KAAK;IACL,KAAK,GAAE,GAAG;IACZ,EAAQ,GAAG,EAAT,IAAI;IAEL,oBAAC,CAAFH,CAAC,UAADA,CAAC,CAAG,uFAAAC,CAAC,CAAC,CAAC,EAjCPC,MAAM,CAiCG,CAAC,IAAJ,CAAC,EAAE,CAAC;IACT,oBAAC,CAAFF,CAAC,UAADA,CAAC,CAAG,uFAAAC,CAAC,CAAsB,CAAC,EAlC5BC,MAAM,CAkCS,CAAC,IAAW,CAAC,EAAb,CAAC;IACf,oBAAC,CAAFF,CAAC,UAADA,CAAC,CAAG,uFAAAG,CAAC,CAAC,CAAC,EA9BgBD,OAAM,CA8BpB,CAAC,IAAJ,CAAC,EAAE,CAAC;IACT,oBAAC,CAAFF,CAAC,UAADA,CAAC,CAAG,uFAAAG,CAAC,CAAsB,CAAC,EA/BLD,OAAM,CA+Bd,CAAC,IAAW,CAAC,EAAb,CAAC;IACf,oBAAC,CAAFF,CAAC,UAADA,CAAC,CAAG,2EAAAG,CAAC,GAAU,CAAC,EAhCqB,CAAC;IAiCrC,oBAAC,CAAFC,CAAC,UAAE,KAAK;IACP,oBAAC,CAAFC,CAAC,UAAE,KAAK,GAAE,GAAG;IACZ,oBAAC,CAAFC,CAAC,UAAD,EAAS,GAAG,EAAT,IAAI;IAEPN,CAAC,CAAC,kFAAAO,QAAQ,OAAR,iCAAU","file":"gen_input.py"} \ No newline at end of file +{"version":3,"sources":["input.sdstest"],"names":["segment_a","a","result","test","f","g","param2","h","i","j","k","readfile","l","m"],"mappings":"AAAA;;;;;;;;;;;;;;;;AAiEA,IAAQA,SAAS,CAACC,CAAC;IACf,YAAMC,MAAM,GAAG,CAAAD,CAAC,EAAC,CAAC,EAAC,CAAC;IADE,OAAG;;;;AA9B7B,IAASE,IAAI;IACTC,CAAC,CAAE,uFAAAC,CAAC,CAAC,CAAC,EA9BNC,MAAM,CA8BE,CAAC,IAAJ,CAAC,EAAE,CAAC;IACTF,CAAC,CAAE,uFAAAC,CAAC,CAAsB,CAAC,EA/B3BC,MAAM,CA+BQ,CAAC,IAAW,CAAC,EAAb,CAAC;IACfF,CAAC,CAAE,uFAAAG,CAAC,CAAC,CAAC,EA3BiBD,OAAM,CA2BrB,CAAC,IAAJ,CAAC,EAAE,CAAC;IACTF,CAAC,CAAE,uFAAAG,CAAC,CAAsB,CAAC,EA5BJD,OAAM,CA4Bf,CAAC,IAAW,CAAC,EAAb,CAAC;IACfF,CAAC,CAAE,2EAAAG,CAAC,GAAU,CAAC,EA7BsB,CAAC;IA8BpC,KAAK;IACL,KAAK,GAAE,GAAG;IACZ,EAAQ,GAAG,EAAT,IAAI;IAEL,oBAAC,CAAFH,CAAC,UAADA,CAAC,CAAG,uFAAAC,CAAC,CAAC,CAAC,EAvCPC,MAAM,CAuCG,CAAC,IAAJ,CAAC,EAAE,CAAC;IACT,oBAAC,CAAFF,CAAC,UAADA,CAAC,CAAG,uFAAAC,CAAC,CAAsB,CAAC,EAxC5BC,MAAM,CAwCS,CAAC,IAAW,CAAC,EAAb,CAAC;IACf,oBAAC,CAAFF,CAAC,UAADA,CAAC,CAAG,uFAAAG,CAAC,CAAC,CAAC,EApCgBD,OAAM,CAoCpB,CAAC,IAAJ,CAAC,EAAE,CAAC;IACT,oBAAC,CAAFF,CAAC,UAADA,CAAC,CAAG,uFAAAG,CAAC,CAAsB,CAAC,EArCLD,OAAM,CAqCd,CAAC,IAAW,CAAC,EAAb,CAAC;IACf,oBAAC,CAAFF,CAAC,UAADA,CAAC,CAAG,2EAAAG,CAAC,GAAU,CAAC,EAtCqB,CAAC;IAuCrC,oBAAC,CAAFC,CAAC,UAAE,KAAK;IACP,oBAAC,CAAFC,CAAC,UAAE,KAAK,GAAE,GAAG;IACZ,oBAAC,CAAFC,CAAC,UAAD,EAAS,GAAG,EAAT,IAAI;IAEPN,CAAC,CAAC,kFAAAO,QAAQ,OAAR,iCAAU;IACZP,CAAC,CAACQ,CAAC,CAAC,OAACX,CAAC,EAAKD,SAAS,CAACC,CAAC;IACtBG,CAAC,CAACQ,CAAC,CAAC,OAACX,CAAC,EAAK,CAAA,CAAC,EAAC,CAAC,EAACD,SAAS,CAACC,CAAC;IAC1BG,CAAC,CAACQ,CAAC,CAAC,OAACX,CAAC,EAAK,2EAAAY,CAAC,GAAC,CAAA,CAAC,EAAC,CAAC,EAACb,SAAS,CAACC,CAAC;IAC5BG,CAAC,CAACQ,CAAC,CAAC,OAACX,CAAC,EAAK,CAAA,CAAC,EAAC,CAAC,EAAC,2EAAAY,CAAC,GAAC,2EAAAA,CAAC,GAAC,CAAA,CAAC,EAAC,CAAC,EAACb,SAAS,CAACC,CAAC;IAC9B,yBAACA,CAAC;QAAG,0BAAMC,MAAM,GAAGF,SAAS,CAACC,CAAC;QAA/B,OAAK,0BAAMC,MAAM;IAArBE,CAAC,CAACQ,CAAC,CAAC;IACA,yBAACX,CAAC;QAAG,0BAAMC,MAAM,GAAG,2EAAAW,CAAC,GAACb,SAAS,CAACC,CAAC;QAAjC,OAAK,0BAAMC,MAAM;IAArBE,CAAC,CAACQ,CAAC,CAAC;IACJR,CAAC,CAAC,2EAAAQ,CAAC,GAAC,OAACX,CAAC,EAAK,CAAA,CAAC,EAAC,CAAC,EAACA,CAAC;IACZ,yBAACA,CAAC;QAAG,0BAAMC,MAAM,GAAG,CAAA,CAAC,EAAC,CAAC,EAAC,2EAAAW,CAAC,GAACZ,CAAC;QAA3B,OAAK,0BAAMC,MAAM;IAArBE,CAAC,CAAC,2EAAAQ,CAAC,GAAC","file":"gen_input.py"} \ No newline at end of file diff --git a/packages/safe-ds-lang/tests/resources/generation/runner integration/expressions/call/input.sdstest b/packages/safe-ds-lang/tests/resources/generation/runner integration/expressions/call/input.sdstest index e020dca03..8bd19d796 100644 --- a/packages/safe-ds-lang/tests/resources/generation/runner integration/expressions/call/input.sdstest +++ b/packages/safe-ds-lang/tests/resources/generation/runner integration/expressions/call/input.sdstest @@ -24,6 +24,12 @@ fun j(param: Any?, param2: Any?) @PythonCall("k($param2, $param)") fun k(param: Any?, param2: Any?) +@Pure +fun l(param: (a: Int) -> result: Int) -> result: Int + +@Pure +fun m(param: Int) -> result: Int + @Impure([ImpurityReason.FileReadFromConstantPath("a.txt")]) fun readFile() -> content: String @@ -47,4 +53,16 @@ pipeline test { k?(1.23, 456); f(readFile()); + f(l((a) -> segment_a(a))); + f(l((a) -> 3 * segment_a(a))); + f(l((a) -> m(3 * segment_a(a)))); + f(l((a) -> 3 * m(m(3 * segment_a(a))))); + f(l((a) {yield result = segment_a(a); })); + f(l((a) {yield result = m(segment_a(a)); })); + f(l((a) -> 3 * a)); + f(l((a) {yield result = 3 * m(a); })); +} + +segment segment_a(a: Int) -> result: Int { + yield result = a * 2; }