Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delay conversion of Truffle function body nodes until the function is invoked #3429

Merged
merged 13 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package org.enso.interpreter.node;

import com.oracle.truffle.api.dsl.ReportPolymorphism;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.NodeInfo;
import com.oracle.truffle.api.source.SourceSection;
import java.util.function.Supplier;
import org.enso.interpreter.Language;
import org.enso.interpreter.runtime.callable.atom.AtomConstructor;
import org.enso.interpreter.runtime.scope.LocalScope;
Expand Down Expand Up @@ -44,12 +46,22 @@ private static String shortName(String atomName, String methodName) {
* @param language the language identifier
* @param localScope a description of the local scope
* @param moduleScope a description of the module scope
* @param body the program body to be executed
* @param section a mapping from {@code body} to the program source
* @param body the program provider to be executed
* @param section a mapping from {@code provider} to the program source
* @param atomConstructor the constructor this method is defined for
* @param methodName the name of this method
* @return a node representing the specified closure
*/
public static MethodRootNode build(
Language language,
LocalScope localScope,
ModuleScope moduleScope,
Supplier<ExpressionNode> body,
SourceSection section,
AtomConstructor atomConstructor,
String methodName) {
return build(language, localScope, moduleScope, new LazyBodyNode(body), section, atomConstructor, methodName);
}
public static MethodRootNode build(
Language language,
LocalScope localScope,
Expand Down Expand Up @@ -87,4 +99,20 @@ public AtomConstructor getAtomConstructor() {
public String getMethodName() {
return methodName;
}

private static class LazyBodyNode extends ExpressionNode {
private final Supplier<ExpressionNode> provider;

LazyBodyNode(Supplier<ExpressionNode> body) {
this.provider = body;
}


@Override
public Object executeGeneric(VirtualFrame frame) {
ExpressionNode newNode = replace(provider.get());
notifyInserted(newNode);
return newNode.executeGeneric(frame);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ public static BlockNode build(ExpressionNode[] expressions, ExpressionNode retur
return new BlockNode(expressions, returnExpr);
}

public static BlockNode buildSilent(ExpressionNode[] expressions, ExpressionNode returnExpr) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this exist?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is used from AtomConstructor.buildConstructorFunction. Those functions are usually quite small - not sure it makes sense to make them lazy.

return new BlockNode(expressions, returnExpr);
}

/**
* Executes the body of the function.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ private Function buildConstructorFunction(
ArgumentDefinition[] args) {

ExpressionNode instantiateNode = InstantiateNode.build(this, varReads);
BlockNode instantiateBlock = BlockNode.build(assignments, instantiateNode);
BlockNode instantiateBlock = BlockNode.buildSilent(assignments, instantiateNode);
RootNode rootNode =
ClosureRootNode.build(
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,18 +290,18 @@ class IrToTruffle(

val function = methodDef.body match {
case fn: IR.Function =>
val (body, arguments) =
expressionProcessor.buildFunctionBody(fn.arguments, fn.body)
val bodyBuilder = new expressionProcessor.BuildFunctionBody(fn.arguments, fn.body)
val rootNode = MethodRootNode.build(
language,
expressionProcessor.scope,
moduleScope,
body,
() => bodyBuilder.bodyNode(),
makeSection(methodDef.location),
cons,
methodDef.methodName.name
)
val callTarget = Truffle.getRuntime.createCallTarget(rootNode)
val arguments = bodyBuilder.args()
new RuntimeFunction(
callTarget,
null,
Expand Down Expand Up @@ -348,18 +348,18 @@ class IrToTruffle(

val function = methodDef.body match {
case fn: IR.Function =>
val (body, arguments) =
expressionProcessor.buildFunctionBody(fn.arguments, fn.body)
val bodyBuilder = new expressionProcessor.BuildFunctionBody(fn.arguments, fn.body)
val rootNode = MethodRootNode.build(
language,
expressionProcessor.scope,
moduleScope,
body,
() => bodyBuilder.bodyNode(),
makeSection(methodDef.location),
toType,
methodDef.methodName.name
)
val callTarget = Truffle.getRuntime.createCallTarget(rootNode)
val arguments = bodyBuilder.args()
new RuntimeFunction(
callTarget,
null,
Expand Down Expand Up @@ -1187,59 +1187,74 @@ class IrToTruffle(
* @return a node for the final shape of function body and pre-processed
* argument definitions.
*/
def buildFunctionBody(
arguments: List[IR.DefinitionArgument],
body: IR.Expression
): (BlockNode, Array[ArgumentDefinition]) = {
class BuildFunctionBody(
val arguments: List[IR.DefinitionArgument],
val body: IR.Expression
) {
val argFactory = new DefinitionArgumentProcessor(scopeName, scope)
private var argDefinitions: Array[ArgumentDefinition] = null
JaroslavTulach marked this conversation as resolved.
Show resolved Hide resolved
private var argSlots: List[FrameSlot] = null
private var argExpressions: ArrayBuffer[RuntimeExpression] = null

val argDefinitions = new Array[ArgumentDefinition](arguments.size)
val argExpressions = new ArrayBuffer[RuntimeExpression]
val seenArgNames = mutable.Set[String]()
def args(): Array[ArgumentDefinition] = {
val (_, args, _) = slots()
args
}

// Note [Rewriting Arguments]
val argSlots =
arguments.zipWithIndex.map { case (unprocessedArg, idx) =>
val arg = argFactory.run(unprocessedArg, idx)
argDefinitions(idx) = arg
private def slots(): (List[FrameSlot],Array[ArgumentDefinition], ArrayBuffer[RuntimeExpression]) = {
if (argSlots == null) {
val seenArgNames = mutable.Set[String]()
argDefinitions = new Array[ArgumentDefinition](arguments.size)
argExpressions = new ArrayBuffer[RuntimeExpression]
// Note [Rewriting Arguments]
argSlots = arguments.zipWithIndex.map { case (unprocessedArg, idx) =>
val arg = argFactory.run(unprocessedArg, idx)
argDefinitions(idx) = arg

val occInfo = unprocessedArg
.unsafeGetMetadata(
AliasAnalysis,
"No occurrence on an argument definition."
)
.unsafeAs[AliasAnalysis.Info.Occurrence]

val occInfo = unprocessedArg
.unsafeGetMetadata(
AliasAnalysis,
"No occurrence on an argument definition."
)
.unsafeAs[AliasAnalysis.Info.Occurrence]
val slot = scope.createVarSlot(occInfo.id)
val readArg =
ReadArgumentNode.build(idx, arg.getDefaultValue.orElse(null))
val assignArg = AssignmentNode.build(readArg, slot)

val slot = scope.createVarSlot(occInfo.id)
val readArg =
ReadArgumentNode.build(idx, arg.getDefaultValue.orElse(null))
val assignArg = AssignmentNode.build(readArg, slot)
argExpressions.append(assignArg)

argExpressions.append(assignArg)
val argName = arg.getName

val argName = arg.getName
if (seenArgNames contains argName) {
throw new IllegalStateException(
s"A duplicate argument name, $argName, was found during codegen."
)
} else seenArgNames.add(argName)
slot
}
}
(argSlots, argDefinitions, argExpressions)
}

if (seenArgNames contains argName) {
throw new IllegalStateException(
s"A duplicate argument name, $argName, was found during codegen."
def bodyNode(): BlockNode = {
val (argSlots, _, argExpressions) = slots()

val bodyExpr = body match {
case IR.Foreign.Definition(lang, code, _, _, _) =>
buildForeignBody(
lang,
code,
arguments.map(_.name.name),
argSlots
)
} else seenArgNames.add(argName)
slot
case _ => ExpressionProcessor.this.run(body)
}

val bodyExpr = body match {
case IR.Foreign.Definition(lang, code, _, _, _) =>
buildForeignBody(
lang,
code,
arguments.map(_.name.name),
argSlots
)
case _ => this.run(body)
}
BlockNode.build(argExpressions.toArray, bodyExpr)

val fnBodyNode = BlockNode.build(argExpressions.toArray, bodyExpr)
(fnBodyNode, argDefinitions)
}
}

private def buildForeignBody(
Expand Down Expand Up @@ -1269,18 +1284,18 @@ class IrToTruffle(
body: IR.Expression,
location: Option[IdentifiedLocation]
): CreateFunctionNode = {
val (fnBodyNode, argDefinitions) = buildFunctionBody(arguments, body)
val bodyBuilder = new BuildFunctionBody(arguments, body)
val fnRootNode = ClosureRootNode.build(
language,
scope,
moduleScope,
fnBodyNode,
bodyBuilder.bodyNode(),
makeSection(location),
scopeName
)
val callTarget = Truffle.getRuntime.createCallTarget(fnRootNode)

val expr = CreateFunctionNode.build(callTarget, argDefinitions)
val expr = CreateFunctionNode.build(callTarget, bodyBuilder.args())

setLocation(expr, location)
}
Expand Down