Skip to content

Commit

Permalink
Automatically force all-defaulted functions (#3414)
Browse files Browse the repository at this point in the history
This changes the interpreter to treat functions with all-defaulted args as thunks. Seems to have no performance impact in compiled code.
  • Loading branch information
kustosz authored Apr 27, 2022
1 parent bb6a5ba commit 96a0c92
Show file tree
Hide file tree
Showing 38 changed files with 236 additions and 461 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@
- [Fixed compiler issue related to module cache.][3367]
- [Fixed execution of defaulted arguments of Atom Constructors][3358]
- [Converting Enso Date to java.time.LocalDate and back][3374]
- [Functions with all-defaulted arguments now execute automatically][3414]

[3227]: https://github.com/enso-org/enso/pull/3227
[3248]: https://github.com/enso-org/enso/pull/3248
Expand All @@ -195,6 +196,7 @@
[3367]: https://github.com/enso-org/enso/pull/3367
[3374]: https://github.com/enso-org/enso/pull/3374
[3412]: https://github.com/enso-org/enso/pull/3412
[3414]: https://github.com/enso-org/enso/pull/3414
[3417]: https://github.com/enso-org/enso/pull/3417

# Enso 2.0.0-alpha.18 (2021-10-12)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ public enum TailStatus {
/** Node is in a tail position and marked as a tail call. */
TAIL_LOOP,
/** Node is not in a tail position. */
NOT_TAIL
NOT_TAIL;

private static final int NUMBER_OF_VALUES = values().length;

public static int numberOfValues() {
return NUMBER_OF_VALUES;
}
}

private @CompilationFinal TailStatus tailStatus = TailStatus.NOT_TAIL;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.api.source.SourceSection;
import java.util.UUID;
import java.util.concurrent.locks.Lock;
Expand All @@ -17,7 +16,6 @@
import org.enso.interpreter.runtime.callable.UnresolvedConversion;
import org.enso.interpreter.runtime.callable.UnresolvedSymbol;
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.argument.Thunk;
import org.enso.interpreter.runtime.callable.atom.Atom;
import org.enso.interpreter.runtime.callable.atom.AtomConstructor;
import org.enso.interpreter.runtime.callable.function.Function;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,9 @@ private Object executeArguments(
Object state,
ThunkExecutorNode thunkExecutorNode) {
for (int i = 0; i < mapping.getArgumentShouldExecute().length; i++) {
if (TypesGen.isThunk(arguments[i]) && mapping.getArgumentShouldExecute()[i]) {
if (mapping.getArgumentShouldExecute()[i]) {
Stateful result =
thunkExecutorNode.executeThunk(
TypesGen.asThunk(arguments[i]), state, BaseNode.TailStatus.NOT_TAIL);
thunkExecutorNode.executeThunk(arguments[i], state, BaseNode.TailStatus.NOT_TAIL);
arguments[i] = result.getValue();
state = result.getState();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ public static ReadArgumentNode build(int position, ExpressionNode defaultValue)
*/
@Override
public Object executeGeneric(VirtualFrame frame) {
Object argument = Function.ArgumentsHelper.getPositionalArguments(frame.getArguments())[index];
Object arguments[] = Function.ArgumentsHelper.getPositionalArguments(frame.getArguments());

if (defaultValue == null) {
return argument;
return arguments[index];
}

// Note [Handling Argument Defaults]
if (defaultingProfile.profile(argument == null)) {
if (defaultingProfile.profile(arguments.length <= index || arguments[index] == null)) {
return defaultValue.executeGeneric(frame);
} else {
return argument;
return arguments[index];
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.NodeInfo;
import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.runtime.callable.argument.Thunk;
import org.enso.interpreter.runtime.callable.function.Function;

/** This node is responsible for wrapping a call target in a {@link Thunk} at execution time. */
@NodeInfo(shortName = "CreateThunk", description = "Wraps a call target in a thunk at runtime")
Expand Down Expand Up @@ -34,6 +34,6 @@ public static CreateThunkNode build(RootCallTarget callTarget) {
*/
@Override
public Object executeGeneric(VirtualFrame frame) {
return new Thunk(this.callTarget, frame.materialize());
return Function.thunk(this.callTarget, frame.materialize());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.NodeInfo;
import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.runtime.callable.argument.Thunk;
import org.enso.interpreter.runtime.state.Stateful;

/** Node responsible for handling user-requested thunks forcing. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.Constants;
import org.enso.interpreter.node.BaseNode;
import org.enso.interpreter.node.callable.InvokeCallableNode;
import org.enso.interpreter.node.callable.dispatch.IndirectInvokeFunctionNode;
import org.enso.interpreter.node.callable.dispatch.InvokeFunctionNode;
import org.enso.interpreter.node.callable.dispatch.LoopingCallOptimiserNode;
import org.enso.interpreter.runtime.callable.argument.Thunk;
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.state.Stateful;
import org.enso.interpreter.runtime.type.TypesGen;

/** Node responsible for executing (forcing) thunks passed to it as runtime values. */
@GenerateUncached
Expand Down Expand Up @@ -40,57 +42,102 @@ public static ThunkExecutorNode build() {
*/
public abstract Stateful executeThunk(Object thunk, Object state, BaseNode.TailStatus isTail);

static boolean isThunk(Object th) {
return TypesGen.isThunk(th);
}

@Specialization(guards = "!isThunk(thunk)")
Stateful doOther(Object thunk, Object state, BaseNode.TailStatus isTail) {
return new Stateful(state, thunk);
boolean sameCallTarget(DirectCallNode callNode, Function function) {
return function.getCallTarget() == callNode.getCallTarget();
}

@Specialization(
guards = "callNode.getCallTarget() == thunk.getCallTarget()",
guards = {"function.isThunk()", "sameCallTarget(callNode, function)"},
limit = Constants.CacheSizes.THUNK_EXECUTOR_NODE)
Stateful doCached(
Thunk thunk,
Function function,
Object state,
BaseNode.TailStatus isTail,
@Cached("create(thunk.getCallTarget())") DirectCallNode callNode,
@Cached("create(function.getCallTarget())") DirectCallNode callNode,
@Cached LoopingCallOptimiserNode loopingCallOptimiserNode) {
CompilerAsserts.partialEvaluationConstant(isTail);
if (isTail != BaseNode.TailStatus.NOT_TAIL) {
return (Stateful) callNode.call(Function.ArgumentsHelper.buildArguments(thunk, state));
return (Stateful) callNode.call(Function.ArgumentsHelper.buildArguments(function, state));
} else {
try {
return (Stateful) callNode.call(Function.ArgumentsHelper.buildArguments(thunk, state));
return (Stateful) callNode.call(Function.ArgumentsHelper.buildArguments(function, state));
} catch (TailCallException e) {
return loopingCallOptimiserNode.executeDispatch(
e.getFunction(), e.getCallerInfo(), e.getState(), e.getArguments());
}
}
}

@Specialization(replaces = "doCached")
@Specialization(replaces = "doCached", guards = "function.isThunk()")
Stateful doUncached(
Thunk thunk,
Function function,
Object state,
BaseNode.TailStatus isTail,
@Cached IndirectCallNode callNode,
@Cached LoopingCallOptimiserNode loopingCallOptimiserNode) {
if (isTail != BaseNode.TailStatus.NOT_TAIL) {
return (Stateful)
callNode.call(
thunk.getCallTarget(), Function.ArgumentsHelper.buildArguments(thunk, state));
function.getCallTarget(), Function.ArgumentsHelper.buildArguments(function, state));
} else {
try {
return (Stateful)
callNode.call(
thunk.getCallTarget(), Function.ArgumentsHelper.buildArguments(thunk, state));
function.getCallTarget(), Function.ArgumentsHelper.buildArguments(function, state));
} catch (TailCallException e) {
return loopingCallOptimiserNode.executeDispatch(
e.getFunction(), e.getCallerInfo(), e.getState(), e.getArguments());
}
}
}

static InvokeFunctionNode buildInvokeFunctionNode(BaseNode.TailStatus tailStatus) {
var node =
InvokeFunctionNode.build(
new CallArgumentInfo[0],
InvokeCallableNode.DefaultsExecutionMode.EXECUTE,
InvokeCallableNode.ArgumentsExecutionMode.EXECUTE);
node.setTailStatus(tailStatus);
return node;
}

static int numberOfTailStatuses() {
return BaseNode.TailStatus.numberOfValues();
}

@Specialization(
guards = {"!fn.isThunk()", "fn.isFullyApplied()", "isTail == cachedIsTail"},
limit = "numberOfTailStatuses()")
Stateful doCachedFn(
Function fn,
Object state,
BaseNode.TailStatus isTail,
@Cached("isTail") BaseNode.TailStatus cachedIsTail,
@Cached("buildInvokeFunctionNode(cachedIsTail)") InvokeFunctionNode invokeFunctionNode) {
return invokeFunctionNode.execute(fn, null, state, new Object[0]);
}

@Specialization(
guards = {"!fn.isThunk()", "fn.isFullyApplied()"},
replaces = {"doCachedFn"})
Stateful doUncachedFn(
Function fn,
Object state,
BaseNode.TailStatus isTail,
@Cached IndirectInvokeFunctionNode invokeFunctionNode) {
return invokeFunctionNode.execute(
fn,
null,
state,
new Object[0],
new CallArgumentInfo[0],
InvokeCallableNode.DefaultsExecutionMode.EXECUTE,
InvokeCallableNode.ArgumentsExecutionMode.EXECUTE,
isTail);
}

@Fallback
Stateful doOther(Object thunk, Object state, BaseNode.TailStatus isTail) {
return new Stateful(state, thunk);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import org.enso.interpreter.dsl.Suspend;
import org.enso.interpreter.node.BaseNode;
import org.enso.interpreter.node.callable.thunk.ThunkExecutorNode;
import org.enso.interpreter.runtime.callable.argument.Thunk;
import org.enso.interpreter.runtime.state.Stateful;

@BuiltinMethod(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import org.enso.interpreter.node.BaseNode;
import org.enso.interpreter.node.callable.InvokeCallableNode;
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.argument.Thunk;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.state.Stateful;

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import org.enso.interpreter.runtime.callable.atom.Atom;
import org.enso.interpreter.runtime.callable.atom.AtomConstructor;
import org.enso.interpreter.runtime.data.Array;
import org.enso.interpreter.runtime.type.TypesGen;

@BuiltinMethod(
type = "Meta",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import org.enso.interpreter.dsl.Suspend;
import org.enso.interpreter.node.BaseNode;
import org.enso.interpreter.node.callable.thunk.ThunkExecutorNode;
import org.enso.interpreter.runtime.callable.argument.Thunk;
import org.enso.interpreter.runtime.state.Stateful;

@BuiltinMethod(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import org.enso.interpreter.dsl.Suspend;
import org.enso.interpreter.node.BaseNode;
import org.enso.interpreter.node.callable.thunk.ThunkExecutorNode;
import org.enso.interpreter.runtime.callable.argument.Thunk;
import org.enso.interpreter.runtime.control.ThreadInterruptedException;
import org.enso.interpreter.runtime.state.Stateful;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import org.enso.interpreter.node.expression.builtin.text.util.ToJavaStringNode;
import org.enso.interpreter.runtime.Context;
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.argument.Thunk;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.data.text.Text;
import org.enso.interpreter.runtime.scope.LocalScope;
import org.enso.interpreter.runtime.scope.ModuleScope;
Expand Down Expand Up @@ -98,7 +98,7 @@ Stateful doCached(
"parseExpression(callerInfo.getLocalScope(), callerInfo.getModuleScope(), expressionStr)")
RootCallTarget cachedCallTarget,
@Cached("build()") ThunkExecutorNode thunkExecutorNode) {
Thunk thunk = new Thunk(cachedCallTarget, callerInfo.getFrame());
Function thunk = Function.thunk(cachedCallTarget, callerInfo.getFrame());
return thunkExecutorNode.executeThunk(thunk, state, getTailStatus());
}

Expand All @@ -114,7 +114,7 @@ Stateful doUncached(
callerInfo.getLocalScope(),
callerInfo.getModuleScope(),
toJavaStringNode.execute(expression));
Thunk thunk = new Thunk(callTarget, callerInfo.getFrame());
Function thunk = Function.thunk(callTarget, callerInfo.getFrame());
return thunkExecutorNode.executeThunk(thunk, state, getTailStatus());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.enso.interpreter.node.expression.builtin.error.CaughtPanicConvertToDataflowErrorMethodGen;
import org.enso.interpreter.node.expression.builtin.error.GetAttachedStackTraceMethodGen;
import org.enso.interpreter.node.expression.builtin.error.ThrowPanicMethodGen;
import org.enso.interpreter.node.expression.builtin.function.ExplicitCallFunctionMethodGen;
import org.enso.interpreter.node.expression.builtin.interop.java.AddToClassPathMethodGen;
import org.enso.interpreter.node.expression.builtin.interop.java.LookupClassMethodGen;
import org.enso.interpreter.node.expression.builtin.io.GetCwdMethodGen;
Expand Down Expand Up @@ -191,8 +190,6 @@ public Builtins(Context context) {
scope.registerMethod(debug, MethodNames.Debug.EVAL, DebugEvalMethodGen.makeFunction(language));
scope.registerMethod(debug, "breakpoint", DebugBreakpointMethodGen.makeFunction(language));

scope.registerMethod(function, "call", ExplicitCallFunctionMethodGen.makeFunction(language));

scope.registerMethod(any, "to_text", AnyToTextMethodGen.makeFunction(language));
scope.registerMethod(any, "to_display_text", AnyToDisplayTextMethodGen.makeFunction(language));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import org.enso.interpreter.node.ExpressionNode;

/** Tracks the specifics about how arguments are defined at the callable definition site. */
public class ArgumentDefinition {
public final class ArgumentDefinition {

/** Represents the mode of passing this argument to the function. */
public enum ExecutionMode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import org.enso.interpreter.node.ExpressionNode;

/** Tracks the specifics about how arguments are specified at a call site. */
public class CallArgument {
public final class CallArgument {
private final String name;
private final ExpressionNode expression;

Expand Down
Loading

0 comments on commit 96a0c92

Please sign in to comment.