Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve TCO in the presence of warnings #7116

Merged
merged 3 commits into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.enso.interpreter.runtime.callable.atom.Atom;
import org.enso.interpreter.runtime.callable.atom.AtomConstructor;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.error.DataflowError;
import org.enso.interpreter.runtime.error.PanicException;
import org.enso.interpreter.runtime.error.PanicSentinel;
Expand Down Expand Up @@ -264,6 +265,15 @@ public Object invokeWarnings(
State state,
Object[] arguments,
@CachedLibrary(limit = "3") WarningsLibrary warnings) {

Warning[] extracted;
Object callable;
try {
extracted = warnings.getWarnings(warning, null);
callable = warnings.removeWarnings(warning);
} catch (UnsupportedMessageException e) {
throw CompilerDirectives.shouldNotReachHere(e);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a side note: Eventually, we should probably align the exceptions we are throwing from the nodes in the engine. In most places it is IllegalStateException, but sometimes it is also CompilerDirectives.shouldNotReachHere. We should provide a proper wrapper for these exceptions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I've had the same observation a while ago ;)

}
try {
if (childDispatch == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
Expand All @@ -277,7 +287,7 @@ public Object invokeWarnings(
invokeFunctionNode.getSchema(),
invokeFunctionNode.getDefaultsExecutionMode(),
invokeFunctionNode.getArgumentsExecutionMode()));
childDispatch.setTailStatus(TailStatus.NOT_TAIL);
childDispatch.setTailStatus(getTailStatus());
Copy link
Member

@JaroslavTulach JaroslavTulach Jun 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the test below... what node behavior this change affects?

The warning in the test is attached to an argument, not to self, so I don't see how this change can be triggered and make a difference?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh you mean why I change it in InvokeCallableNode when the fix was only necessary in InvokeMethodNode, right?
I'm reverting the previous change for all three cases and replacing it with the same approach i.e. warnings propagation via TailCallException.
It does not appear that we have any test case exercising this case specifically.

childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch);
}
Expand All @@ -287,21 +297,21 @@ public Object invokeWarnings(
}

var result = childDispatch.execute(
warnings.removeWarnings(warning),
callable,
callerFrame,
state,
arguments);

Warning[] extracted = warnings.getWarnings(warning, null);

if (result instanceof DataflowError) {
return result;
} else if (result instanceof WithWarnings withWarnings) {
return withWarnings.append(EnsoContext.get(this), extracted);
} else {
return WithWarnings.wrap(EnsoContext.get(this), result, extracted);
}
} catch (UnsupportedMessageException e) {
throw CompilerDirectives.shouldNotReachHere(e);
} catch (TailCallException e) {
throw new TailCallException(e, extracted);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.enso.interpreter.runtime.callable.UnresolvedConversion;
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.data.ArrayRope;
import org.enso.interpreter.runtime.data.Type;
import org.enso.interpreter.runtime.data.text.Text;
Expand Down Expand Up @@ -162,19 +163,23 @@ Object doWarning(
invokeFunctionNode.getDefaultsExecutionMode(),
invokeFunctionNode.getArgumentsExecutionMode(),
thatArgumentPosition));
childDispatch.setTailStatus(TailStatus.NOT_TAIL);
childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch);
}
} finally {
lock.unlock();
}
}
arguments[thatArgumentPosition] = that.getValue();
Object value = that.getValue();
arguments[thatArgumentPosition] = value;
ArrayRope<Warning> warnings = that.getReassignedWarningsAsRope(this);
Object result =
childDispatch.execute(frame, state, conversion, self, that.getValue(), arguments);
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
try {
Object result = childDispatch.execute(frame, state, conversion, self, value, arguments);
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
} catch (TailCallException e) {
throw new TailCallException(e, warnings.toArray(Warning[]::new));
}
}

@Specialization(guards = "interop.isString(that)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.callable.function.FunctionSchema;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.data.ArrayRope;
import org.enso.interpreter.runtime.data.EnsoDate;
import org.enso.interpreter.runtime.data.EnsoDateTime;
Expand Down Expand Up @@ -328,7 +329,7 @@ Object doWarning(
invokeFunctionNode.getDefaultsExecutionMode(),
invokeFunctionNode.getArgumentsExecutionMode(),
thisArgumentPosition));
childDispatch.setTailStatus(TailStatus.NOT_TAIL);
childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch);
}
Expand All @@ -339,8 +340,12 @@ Object doWarning(

arguments[thisArgumentPosition] = selfWithoutWarnings;

Object result = childDispatch.execute(frame, state, symbol, selfWithoutWarnings, arguments);
return WithWarnings.appendTo(EnsoContext.get(this), result, arrOfWarnings);
try {
Object result = childDispatch.execute(frame, state, symbol, selfWithoutWarnings, arguments);
return WithWarnings.appendTo(EnsoContext.get(this), result, arrOfWarnings);
} catch (TailCallException e) {
throw new TailCallException(e, arrOfWarnings);
}
}

@ExplodeLoop
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.oracle.truffle.api.nodes.NodeInfo;
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.error.Warning;
import org.enso.interpreter.runtime.state.State;

/**
Expand Down Expand Up @@ -33,12 +34,14 @@ public static CallOptimiserNode build() {
* @param callerInfo the caller info to pass to the function
* @param state the state to pass to the function
* @param arguments the arguments to {@code callable}
* @param warnings warnings associated with the callable, null if empty
* @return the result of executing {@code callable} using {@code arguments}
*/
public abstract Object executeDispatch(
VirtualFrame frame,
Function callable,
CallerInfo callerInfo,
State state,
Object[] arguments);
Object[] arguments,
Warning[] warnings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public Object execute(
return value;
}
} else {
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);

return this.oversaturatedCallableNode.execute(
evaluatedVal, frame, state, oversaturatedArguments);
Expand All @@ -154,7 +154,7 @@ private Object doCall(
return switch (getTailStatus()) {
case TAIL_DIRECT -> directCall.executeCall(frame, function, callerInfo, state, arguments);
case TAIL_LOOP -> throw new TailCallException(function, callerInfo, arguments);
default -> loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
default -> loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ Object doCurry(
return value;
}
} else {
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);

return oversaturatedCallableNode.execute(
evaluatedVal,
Expand Down Expand Up @@ -129,7 +129,7 @@ private Object doCall(
case TAIL_LOOP:
throw new TailCallException(function, callerInfo, arguments);
default:
return loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
return loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
import com.oracle.truffle.api.nodes.RepeatingNode;
import org.enso.interpreter.node.callable.ExecuteCallNode;
import org.enso.interpreter.node.callable.ExecuteCallNodeGen;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.error.Warning;
import org.enso.interpreter.runtime.error.WithWarnings;
import org.enso.interpreter.runtime.state.State;

/**
Expand Down Expand Up @@ -54,31 +57,78 @@ public static LoopingCallOptimiserNode build() {
* @param loopNode a cached instance of the loop node used by this node
* @return the result of executing {@code function} using {@code arguments}
*/
@Specialization
public Object dispatch(
@Specialization(guards = "warnings == null")
public Object cachedDispatch(
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
Warning[] warnings,
@Cached(value = "createLoopNode()") LoopNode loopNode) {
return dispatch(function, callerInfo, state, arguments, loopNode);
}

@Specialization(guards = "warnings != null")
public Object cachedDispatchWarnings(
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
Warning[] warnings,
@Cached(value = "createLoopNode()") LoopNode loopNode) {
Object result = dispatch(function, callerInfo, state, arguments, loopNode);
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks correct. If there are warnings, remove them, handle all the tail calls and attach them again. Not sure how that can be triggered by the test however?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how that can be triggered by the test however?

I'm not sure how to convince you other than running it with debugger attached and putting the breakpoint yourself.

}

private Object dispatch(
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
LoopNode loopNode) {
RepeatedCallNode repeatedCallNode = (RepeatedCallNode) loopNode.getRepeatingNode();
VirtualFrame frame = repeatedCallNode.createFrame();
repeatedCallNode.setNextCall(frame, function, callerInfo, arguments);
repeatedCallNode.setState(frame, state);
loopNode.execute(frame);

return repeatedCallNode.getResult(frame);
}

@Specialization(replaces = "dispatch")
@Specialization(replaces = "cachedDispatch", guards = "warnings == null")
@CompilerDirectives.TruffleBoundary
public Object uncachedDispatch(
MaterializedFrame frame,
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
Warning[] warnings,
@Cached ExecuteCallNode executeCallNode) {
return loopUntilCompletion(frame, function, callerInfo, state, arguments, executeCallNode);
}

@Specialization(replaces = "cachedDispatchWarnings", guards = "warnings != null")
@CompilerDirectives.TruffleBoundary
public Object uncachedDispatchWarnings(
MaterializedFrame frame,
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
Warning[] warnings,
@Cached ExecuteCallNode executeCallNode) {
Object result =
loopUntilCompletion(frame, function, callerInfo, state, arguments, executeCallNode);
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
}

private Object loopUntilCompletion(
MaterializedFrame frame,
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
ExecuteCallNode executeCallNode) {
while (true) {
try {
return executeCallNode.executeCall(frame, function, callerInfo, state, arguments);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.error.Warning;
import org.enso.interpreter.runtime.state.State;

/**
Expand Down Expand Up @@ -40,6 +41,7 @@ public static SimpleCallOptimiserNode build() {
* @param callerInfo the caller info to pass to the function
* @param state the state to pass to the function
* @param arguments the arguments to {@code function}
* @param warnings warnings associated with the callable, null if empty
* @return the result of executing {@code function} using {@code arguments}
*/
@Override
Expand All @@ -48,7 +50,8 @@ public Object executeDispatch(
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments) {
Object[] arguments,
Warning[] warnings) {
try {
return executeCallNode.executeCall(frame, function, callerInfo, state, arguments);
} catch (TailCallException e) {
Expand All @@ -65,7 +68,7 @@ public Object executeDispatch(
}
}
return next.executeDispatch(
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments());
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Object doCached(
return callNode.call(Function.ArgumentsHelper.buildArguments(function, state));
} catch (TailCallException e) {
return loopingCallOptimiserNode.executeDispatch(
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments());
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings());
}
}
}
Expand All @@ -89,7 +89,7 @@ Object doUncached(
function.getCallTarget(), Function.ArgumentsHelper.buildArguments(function, state));
} catch (TailCallException e) {
return loopingCallOptimiserNode.executeDispatch(
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments());
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,8 @@ public int compare(Object x, Object y) {
Object yConverted;
if (hasCustomOnFunc) {
// onFunc cannot have `self` argument, we assume it has just one argument.
xConverted = callNode.executeDispatch(null, onFunc.get(x), null, state, new Object[]{x});
yConverted = callNode.executeDispatch(null, onFunc.get(y), null, state, new Object[]{y});
xConverted = callNode.executeDispatch(null, onFunc.get(x), null, state, new Object[]{x}, null);
yConverted = callNode.executeDispatch(null, onFunc.get(y), null, state, new Object[]{y}, null);
} else {
xConverted = x;
yConverted = y;
Expand All @@ -823,7 +823,7 @@ public int compare(Object x, Object y) {
} else {
args = new Object[] {xConverted, yConverted};
}
Object res = callNode.executeDispatch(null, compareFunc.get(xConverted), null, state, args);
Object res = callNode.executeDispatch(null, compareFunc.get(xConverted), null, state, args, null);
if (res == less) {
return ascending ? -1 : 1;
} else if (res == equal) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,8 @@ private static Object evalExpression(
eval.getFunction(),
callerInfo,
context.emptyState(),
new Object[] {builtins.debug(), Text.create(expr)});
new Object[] {builtins.debug(), Text.create(expr)},
null);
}

private static Object generateDocs(Module module, EnsoContext context) {
Expand Down
Loading