Skip to content

Commit

Permalink
s/ExecutionEventListener/ExecutionEventNodeFactory
Browse files Browse the repository at this point in the history
Replacing ExecutionEventListener with ExecutionEventNodeFactory allows
us to enhance EventNode with an additional child node, TypeOfNode. The
latter gets adopted to the RootNode, meaning it won't blow up assertions
anymore.

Additionally reduced the scope of @TruffleBoundary so that the child can
possibly get PE, as suggested by Jaroslav.
  • Loading branch information
hubertp committed Feb 3, 2023
1 parent 11661cd commit 121e463
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,68 @@ protected void onCreate(Env env) {
this.env = env;
}

/** The listener class used by this instrument. */
private static class IdExecutionEventListener implements ExecutionEventListener {
/** Factory for creating new id event nodes **/
private static class IdEventNodeFactory implements ExecutionEventNodeFactory {

private final CallTarget entryCallTarget;
private final Consumer<ExpressionCall> functionCallCallback;
private final Consumer<ExpressionValue> onComputedCallback;
private final Consumer<ExpressionValue> onCachedCallback;
private final Consumer<Exception> onExceptionalCallback;
private final RuntimeCache cache;
private final MethodCallsCache methodCallsCache;
private final UpdatesSynchronizationState syncState;
private final UUID nextExecutionItem;
private final Map<UUID, FunctionCallInfo> calls = new HashMap<>();
private final Timer timer;

/**
* Creates a new event node factory.
*
* @param entryCallTarget the call target being observed.
* @param cache the precomputed expression values.
* @param methodCallsCache the storage tracking the executed method calls.
* @param syncState the synchronization state of runtime updates.
* @param nextExecutionItem the next item scheduled for execution.
* @param functionCallCallback the consumer of function call events.
* @param onComputedCallback the consumer of the computed value events.
* @param onCachedCallback the consumer of the cached value events.
* @param onExceptionalCallback the consumer of the exceptional events.
* @param timer the timer for timing execution
*/
public IdEventNodeFactory(
CallTarget entryCallTarget,
RuntimeCache cache,
MethodCallsCache methodCallsCache,
UpdatesSynchronizationState syncState,
UUID nextExecutionItem, // The expression ID
Consumer<ExpressionCall> functionCallCallback,
Consumer<ExpressionValue> onComputedCallback,
Consumer<ExpressionValue> onCachedCallback,
Consumer<Exception> onExceptionalCallback,
Timer timer) {
this.entryCallTarget = entryCallTarget;
this.cache = cache;
this.methodCallsCache = methodCallsCache;
this.syncState = syncState;
this.nextExecutionItem = nextExecutionItem;
this.functionCallCallback = functionCallCallback;
this.onComputedCallback = onComputedCallback;
this.onCachedCallback = onCachedCallback;
this.onExceptionalCallback = onExceptionalCallback;
this.timer = timer;
}

@Override
public ExecutionEventNode create(EventContext context) {
return new IdExecutionEventNode(context, entryCallTarget, cache, methodCallsCache, syncState,
nextExecutionItem, calls, functionCallCallback, onComputedCallback, onCachedCallback, onExceptionalCallback, timer);
}
}

/** The execution event node class used by this instrument. */
private static class IdExecutionEventNode extends ExecutionEventNode {
private final EventContext context;
private final CallTarget entryCallTarget;
private final Consumer<ExpressionCall> functionCallCallback;
private final Consumer<ExpressionValue> onComputedCallback;
Expand All @@ -64,13 +124,13 @@ private static class IdExecutionEventListener implements ExecutionEventListener
private final MethodCallsCache callsCache;
private final UpdatesSynchronizationState syncState;
private final UUID nextExecutionItem;
private final Map<UUID, FunctionCallInfo> calls = new HashMap<>();
private final Map<UUID, FunctionCallInfo> calls;
private final Timer timer;
private final TypeOfNode typeOfNode;
private long nanoTimeElapsed = 0;
private @Child TypeOfNode typeOfNode = TypeOfNode.build();

/**
* Creates a new listener.
* Creates a new event node.
*
* @param entryCallTarget the call target being observed.
* @param cache the precomputed expression values.
Expand All @@ -83,19 +143,23 @@ private static class IdExecutionEventListener implements ExecutionEventListener
* @param onExceptionalCallback the consumer of the exceptional events.
* @param timer the timer for timing execution
*/
public IdExecutionEventListener(
public IdExecutionEventNode(
EventContext context,
CallTarget entryCallTarget,
RuntimeCache cache,
MethodCallsCache methodCallsCache,
UpdatesSynchronizationState syncState,
UUID nextExecutionItem, // The expression ID
Map<UUID, FunctionCallInfo> calls,
Consumer<ExpressionCall> functionCallCallback,
Consumer<ExpressionValue> onComputedCallback,
Consumer<ExpressionValue> onCachedCallback,
Consumer<Exception> onExceptionalCallback,
Timer timer) {
this.context = context;
this.entryCallTarget = entryCallTarget;
this.cache = cache;
this.calls = calls;
this.callsCache = methodCallsCache;
this.syncState = syncState;
this.nextExecutionItem = nextExecutionItem;
Expand All @@ -104,24 +168,23 @@ public IdExecutionEventListener(
this.onCachedCallback = onCachedCallback;
this.onExceptionalCallback = onExceptionalCallback;
this.timer = timer;
this.typeOfNode = TypeOfNode.build();
}

@Override
public Object onUnwind(EventContext context, VirtualFrame frame, Object info) {
public Object onUnwind(VirtualFrame frame, Object info) {
return info;
}

@Override
public void onEnter(EventContext context, VirtualFrame frame) {
public void onEnter(VirtualFrame frame) {
if (!isTopFrame(entryCallTarget)) {
return;
}
onEnterImpl(context);
onEnterImpl();
}

@CompilerDirectives.TruffleBoundary
private void onEnterImpl(EventContext context) {
private void onEnterImpl() {
UUID nodeId = getNodeId(context.getInstrumentedNode());

// Add a flag to say it was cached.
Expand Down Expand Up @@ -152,12 +215,11 @@ private void onEnterImpl(EventContext context) {
* Triggered when a node (either a function call sentry or an identified expression) finishes
* execution.
*
* @param context the event context.
* @param frame the current execution frame.
* @param result the result of executing the node this method was triggered for.
*/
@Override
public void onReturnValue(EventContext context, VirtualFrame frame, Object result) {
public void onReturnValue(VirtualFrame frame, Object result) {
nanoTimeElapsed = timer.getTime() - nanoTimeElapsed;
if (!isTopFrame(entryCallTarget)) {
return;
Expand All @@ -174,19 +236,18 @@ public void onReturnValue(EventContext context, VirtualFrame frame, Object resul
}

@Override
public void onReturnExceptional(EventContext context, VirtualFrame frame, Throwable exception) {
public void onReturnExceptional(VirtualFrame frame, Throwable exception) {
if (exception instanceof TailCallException) {
onTailCallReturn(exception, Function.ArgumentsHelper.getState(frame.getArguments()), context);
onTailCallReturn(exception, Function.ArgumentsHelper.getState(frame.getArguments()));
} else if (exception instanceof PanicException) {
PanicException panicException = (PanicException) exception;
onReturnValue(
context, frame, new PanicSentinel(panicException, context.getInstrumentedNode()));
onReturnValue(frame, new PanicSentinel(panicException, context.getInstrumentedNode()));
} else if (exception instanceof PanicSentinel) {
onReturnValue(context, frame, exception);
onReturnValue(frame, exception);
}
}

@CompilerDirectives.TruffleBoundary
//@CompilerDirectives.TruffleBoundary
private void onExpressionReturn(Object result, Node node, EventContext context) throws ThreadDeath {
boolean isPanic = result instanceof PanicSentinel;
UUID nodeId = ((ExpressionNode) node).getId();
Expand All @@ -200,7 +261,7 @@ private void onExpressionReturn(Object result, Node node, EventContext context)
}

String cachedType = cache.getType(nodeId);
FunctionCallInfo call = calls.get(nodeId);
FunctionCallInfo call = functionCallInfoById(nodeId);
FunctionCallInfo cachedCall = cache.getCall(nodeId);
ProfilingInfo[] profilingInfo = new ProfilingInfo[] {new ExecutionTime(nanoTimeElapsed)};

Expand All @@ -219,12 +280,21 @@ private void onExpressionReturn(Object result, Node node, EventContext context)
cache.putType(nodeId, resultType);
cache.putCall(nodeId, call);

onComputedCallback.accept(expressionValue);
passExpressionValueToCallback(expressionValue);
if (isPanic) {
throw context.createUnwind(result);
}
}

@CompilerDirectives.TruffleBoundary
private void passExpressionValueToCallback(ExpressionValue expressionValue) {
onComputedCallback.accept(expressionValue);
}

@CompilerDirectives.TruffleBoundary
private FunctionCallInfo functionCallInfoById(UUID nodeId) {
return calls.get(nodeId);
}

@CompilerDirectives.TruffleBoundary
private void onFunctionReturn(UUID nodeId, Object result, EventContext context) throws ThreadDeath {
Expand All @@ -241,7 +311,7 @@ private void onFunctionReturn(UUID nodeId, Object result, EventContext context)
}

@CompilerDirectives.TruffleBoundary
private void onTailCallReturn(Throwable exception, State state, EventContext context) {
private void onTailCallReturn(Throwable exception, State state) {
try {
TailCallException tailCallException = (TailCallException) exception;
FunctionCallInstrumentationNode.FunctionCall functionCall =
Expand All @@ -250,7 +320,7 @@ private void onTailCallReturn(Throwable exception, State state, EventContext con
state,
tailCallException.getArguments());
Object result = InteropLibrary.getFactory().getUncached().execute(functionCall);
onReturnValue(context, null, result);
onReturnValue(null, result);
} catch (InteropException e) {
onExceptionalCallback.accept(e);
}
Expand Down Expand Up @@ -298,7 +368,7 @@ private UUID getNodeId(Node node) {
}

/**
* Attach a new listener to observe identified nodes within given function.
* Attach a new event node factory to observe identified nodes within given function.
*
* @param module module that contains the code
* @param entryCallTarget the call target being observed.
Expand All @@ -311,10 +381,10 @@ private UUID getNodeId(Node node) {
* @param onComputedCallback the consumer of the computed value events.
* @param onCachedCallback the consumer of the cached value events.
* @param onExceptionalCallback the consumer of the exceptional events.
* @return a reference to the attached event listener.
* @return a reference to the attached event node factory.
*/
@Override
public EventBinding<ExecutionEventListener> bind(
public EventBinding<ExecutionEventNodeFactory> bind(
Module module,
CallTarget entryCallTarget,
RuntimeCache cache,
Expand All @@ -340,9 +410,9 @@ public EventBinding<ExecutionEventListener> bind(
SourceSectionFilter filter = builder.build();

return env.getInstrumenter()
.attachExecutionEventListener(
.attachExecutionEventFactory(
filter,
new IdExecutionEventListener(
new IdEventNodeFactory(
entryCallTarget,
cache,
methodCallsCache,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ class RuntimeInstrumentTest

// Open the new file
context.send(
Api.Request(Api.SetModuleSourcesNotification(mainFile, contents))
Api.Request(Api.OpenFileNotification(mainFile, contents))
)
context.receiveNone shouldEqual None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import com.oracle.truffle.api.CallTarget;
import com.oracle.truffle.api.instrumentation.EventBinding;
import com.oracle.truffle.api.instrumentation.ExecutionEventListener;
import com.oracle.truffle.api.instrumentation.ExecutionEventNode;
import com.oracle.truffle.api.instrumentation.ExecutionEventNodeFactory;
import com.oracle.truffle.api.nodes.RootNode;
import java.util.Arrays;
import java.util.Objects;
Expand All @@ -21,7 +23,7 @@ public interface IdExecutionService {
public static final String INSTRUMENT_ID = "id-value-extractor";

/**
* Attach a new listener to observe identified nodes within given function.
* Attach a new event node factory to observe identified nodes within given function.
*
* @param module module that contains the code
* @param entryCallTarget the call target being observed.
Expand All @@ -34,9 +36,9 @@ public interface IdExecutionService {
* @param onComputedCallback the consumer of the computed value events.
* @param onCachedCallback the consumer of the cached value events.
* @param onExceptionalCallback the consumer of the exceptional events.
* @return a reference to the attached event listener.
* @return a reference to the attached event node factory.
*/
public EventBinding<ExecutionEventListener> bind(
EventBinding<ExecutionEventNodeFactory> bind(
Module module,
CallTarget entryCallTarget,
RuntimeCache cache,
Expand All @@ -50,7 +52,7 @@ public EventBinding<ExecutionEventListener> bind(
Consumer<Exception> onExceptionalCallback);

/** A class for notifications about functions being called in the course of execution. */
public static class ExpressionCall {
class ExpressionCall {
private final UUID expressionId;
private final FunctionCallInstrumentationNode.FunctionCall call;

Expand All @@ -77,7 +79,7 @@ public FunctionCallInstrumentationNode.FunctionCall getCall() {
}

/** A class for notifications about identified expressions' values being computed. */
public static class ExpressionValue {
class ExpressionValue {
private final UUID expressionId;
private final Object value;
private final String type;
Expand Down Expand Up @@ -195,7 +197,7 @@ public boolean isFunctionCallChanged() {
}

/** Information about the function call. */
public static class FunctionCallInfo {
class FunctionCallInfo {

private final QualifiedName moduleName;
private final QualifiedName typeName;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.enso.interpreter.instrument;

import com.oracle.truffle.api.CompilerDirectives;

import java.lang.ref.SoftReference;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -21,6 +23,7 @@ public final class RuntimeCache {
* @param value the added value.
* @return {@code true} if the value was added to the cache.
*/
@CompilerDirectives.TruffleBoundary
public boolean offer(UUID key, Object value) {
Double weight = weights.get(key);
if (weight != null && weight > 0) {
Expand Down Expand Up @@ -57,11 +60,13 @@ public void clear() {
*
* @return the previously cached type.
*/
@CompilerDirectives.TruffleBoundary
public String putType(UUID key, String typeName) {
return types.put(key, typeName);
}

/** @return the cached type of the expression */
@CompilerDirectives.TruffleBoundary
public String getType(UUID key) {
return types.get(key);
}
Expand All @@ -73,6 +78,7 @@ public String getType(UUID key) {
* @param call the function call.
* @return the function call that was previously associated with this expression.
*/
@CompilerDirectives.TruffleBoundary
public IdExecutionService.FunctionCallInfo putCall(
UUID key, IdExecutionService.FunctionCallInfo call) {
if (call == null) {
Expand All @@ -82,6 +88,7 @@ public IdExecutionService.FunctionCallInfo putCall(
}

/** @return the cached function call associated with the expression. */
@CompilerDirectives.TruffleBoundary
public IdExecutionService.FunctionCallInfo getCall(UUID key) {
return calls.get(key);
}
Expand Down
Loading

0 comments on commit 121e463

Please sign in to comment.