Skip to content

Commit

Permalink
Speed cascade of if statements up (#6255)
Browse files Browse the repository at this point in the history
Fixes #5709. We have a test and a generic fix that improves inlining of every builtin. Everything seems to be faster.
  • Loading branch information
JaroslavTulach authored Apr 14, 2023
1 parent 9ebda56 commit a74933d
Show file tree
Hide file tree
Showing 14 changed files with 386 additions and 37 deletions.
2 changes: 1 addition & 1 deletion distribution/lib/Standard/Base/0.0.0-dev/src/Meta.enso
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ is_unresolved_symbol value = @Builtin_Method "Meta.is_unresolved_symbol"
used carefully.
get_source_location : Integer -> Text
get_source_location skip_frames =
get_source_location_builtin skip_frames+1
get_source_location_builtin skip_frames

## PRIVATE

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3058,19 +3058,17 @@ class RuntimeServerTest
)
)
context.receiveN(4) should contain theSameElementsAs Seq(
Api.Response(Api.BackgroundJobsStartedNotification()),
Api.Response(requestId, Api.PushContextResponse(contextId)),
Api.Response(
Api.ExecutionUpdate(
contextId,
Seq(
Api.ExecutionResult.Diagnostic.error(
"Type error: expected `str` to be Text, but got 2 (Integer).",
None,
None,
Some(mainFile),
Some(model.Range(model.Position(2, 10), model.Position(2, 15))),
None,
Vector(
Api.StackTraceElement("Text.+", None, None, None),
Api.StackTraceElement(
"Main.bar",
Some(mainFile),
Expand All @@ -3092,6 +3090,7 @@ class RuntimeServerTest
)
)
),
Api.Response(Api.BackgroundJobsStartedNotification()),
context.executionComplete(contextId)
)
}
Expand Down Expand Up @@ -3216,19 +3215,17 @@ class RuntimeServerTest
)
)
context.receiveN(4) should contain theSameElementsAs Seq(
Api.Response(Api.BackgroundJobsStartedNotification()),
Api.Response(requestId, Api.PushContextResponse(contextId)),
Api.Response(
Api.ExecutionUpdate(
contextId,
Seq(
Api.ExecutionResult.Diagnostic.error(
"Type error: expected `that` to be Number, but got quux (Unresolved_Symbol).",
None,
None,
Some(mainFile),
Some(model.Range(model.Position(10, 8), model.Position(10, 17))),
None,
Vector(
Api.StackTraceElement("Small_Integer.+", None, None, None),
Api.StackTraceElement(
"Main.baz",
Some(mainFile),
Expand Down Expand Up @@ -3266,6 +3263,7 @@ class RuntimeServerTest
)
)
),
Api.Response(Api.BackgroundJobsStartedNotification()),
context.executionComplete(contextId)
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package org.enso.interpreter.bench.benchmarks.semantic;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import org.enso.interpreter.test.TestBase;
import org.enso.polyglot.MethodNames.Module;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Source;
import org.graalvm.polyglot.Value;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.BenchmarkParams;

@BenchmarkMode(Mode.AverageTime)
@Fork(1)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 3, time = 3)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
public class IfVsCaseBenchmarks extends TestBase {
private static final int INPUT_VEC_SIZE = 100_000;
private Context ctx;
private Value ifBench3;
private Value caseBench3;
private Value ifBench6;
private Value caseBench6;
private Value createVec;
private Value inputVec;
private OutputStream out = new ByteArrayOutputStream();

@Setup
public void initializeBench(BenchmarkParams params) throws IOException {
ctx = Context.newBuilder("enso")
.allowAllAccess(true)
.logHandler(out)
.out(out)
.err(out)
.allowIO(true)
.allowExperimentalOptions(true)
.option(
"enso.languageHomeOverride",
Paths.get("../../distribution/component").toFile().getAbsolutePath()
)
.option("engine.MultiTier", "true")
.option("engine.BackgroundCompilation", "true")
.build();

var code = """
from Standard.Base import all
type My_Type
Value f1 f2 f3 f4 f5 f6
if_bench_3 : Vector My_Type -> Integer
if_bench_3 vec =
vec.fold 0 acc-> curr->
if curr.f1.not then acc else
if curr.f2.not then acc else
if curr.f3.not then acc else
acc + 1
case_bench_3 : Vector My_Type -> Integer
case_bench_3 vec =
vec.fold 0 acc-> curr->
case curr.f1 of
False -> acc
True -> case curr.f2 of
False -> acc
True -> case curr.f3 of
False -> acc
True -> acc + 1
if_bench_6 : Vector My_Type -> Integer
if_bench_6 vec =
vec.fold 0 acc-> curr->
if curr.f1.not then acc else
if curr.f2.not then acc else
if curr.f3.not then acc else
if curr.f4.not then acc else
if curr.f5.not then acc else
if curr.f6.not then acc else
acc + 1
case_bench_6 : Vector My_Type -> Integer
case_bench_6 vec =
vec.fold 0 acc-> curr->
case curr.f1 of
False -> acc
True -> case curr.f2 of
False -> acc
True -> case curr.f3 of
False -> acc
True -> case curr.f4 of
False -> acc
True -> case curr.f5 of
False -> acc
True -> case curr.f6 of
False -> acc
True -> acc + 1
create_vec polyglot_vec =
Vector.from_polyglot_array polyglot_vec . map elem->
My_Type.Value (elem.at 0) (elem.at 1) (elem.at 2) (elem.at 3) (elem.at 4) (elem.at 5)
""";

var file = File.createTempFile("if_case", ".enso");
try (var w = new FileWriter(file)) {
w.write(code);
}
var src = Source.newBuilder("enso", file).build();
Value module = ctx.eval(src);
ifBench3 = Objects.requireNonNull(module.invokeMember(Module.EVAL_EXPRESSION, "if_bench_3"));
caseBench3 = Objects.requireNonNull(module.invokeMember(Module.EVAL_EXPRESSION, "case_bench_3"));
ifBench6 = Objects.requireNonNull(module.invokeMember(Module.EVAL_EXPRESSION, "if_bench_6"));
caseBench6 = Objects.requireNonNull(module.invokeMember(Module.EVAL_EXPRESSION, "case_bench_6"));
createVec = Objects.requireNonNull(module.invokeMember(Module.EVAL_EXPRESSION, "create_vec"));
// So far, input is a vector of My_Type.Value with all fields set to True
inputVec = createMyTypeAllTrue(INPUT_VEC_SIZE);
}

@TearDown
public void tearDown() {
ctx.close();
}

/**
* Iterates over a vector of {@code My_Type} values with True only fields.
*/
@Benchmark
public void ifBench3() {
Value res = ifBench3.execute(inputVec);
checkResult(res);
}

@Benchmark
public void ifBench6() {
Value res = ifBench6.execute(inputVec);
checkResult(res);
}

@Benchmark
public void caseBench3() {
Value res = caseBench3.execute(inputVec);
checkResult(res);
}

@Benchmark
public void caseBench6() {
Value res = caseBench6.execute(inputVec);
checkResult(res);
}

private static void checkResult(Value res) {
if (res.asInt() != INPUT_VEC_SIZE) {
throw new AssertionError("Expected result: " + INPUT_VEC_SIZE + ", got: " + res.asInt());
}
}

/**
* Creates a vector of {@code My_Type} with all True fields
*/
private Value createMyTypeAllTrue(int size) {
List<List<Boolean>> inputPolyVec = new ArrayList<>();
for (int i = 0; i < size; i++) {
inputPolyVec.add(List.of(true, true, true, true, true, true));
}
return createVec.execute(inputPolyVec);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package org.enso.interpreter.node;

import com.oracle.truffle.api.CallTarget;
import com.oracle.truffle.api.RootCallTarget;
import com.oracle.truffle.api.nodes.DirectCallNode;
import com.oracle.truffle.api.nodes.RootNode;

/**
* Special interface that allows various {@link RootNode} subclasses to provide
* more effective implementation of {@link DirectCallNode}. Used by for example
* by {@code BuiltinRootNode}.
*/
public interface InlineableRootNode {
/**
* Provides access to {@link RootNode}. Usually the object shall inherit from
* {link RootNode} as well as implement the {@link InlineableRootNode}
* interface. This method thus usually returns {@code this}.
*
* @return {@code this} types as {link RootNode}
*/
public RootNode getRootNode();

/**
* Name of the {@link RootNode}.
*
* @return root node name
*/
public String getName();

/**
* Override to provide more effective implementation of {@link DirectCallNode}
* suited more for Enso aggressive inlining.
*
* @return a node to {@link DirectCallNode#call(java.lang.Object...) call} the
* associated {@link RootNode} - may return {@code null}
*/
public DirectCallNode createDirectCallNode();

/**
* * Obtain a {@link DirectCallNode} for given {@link CallTarget}.Either
* delegates to {@link #createDirectCallNode} or uses regular
* {@link DirectCallNode#create(com.oracle.truffle.api.CallTarget)} method.
* Use for example by {@code ExecuteCallNode}.
*
* @param target call target with regular or
* {@link InlineableRootNode} {@link RootCallTarget#getRootNode()}
* @return instance of {@link DirectCallNode} to use to invoke the
* {@link RootNode#execute(com.oracle.truffle.api.frame.VirtualFrame)}.
*/
public static DirectCallNode create(RootCallTarget target) {
if (target.getRootNode() instanceof InlineableRootNode inRoot && inRoot.createDirectCallNode() instanceof DirectCallNode node) {
return node;
}
return DirectCallNode.create(target);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.oracle.truffle.api.nodes.IndirectCallNode;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.NodeInfo;
import org.enso.interpreter.node.InlineableRootNode;
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.function.Function;

Expand Down Expand Up @@ -52,11 +53,15 @@ protected Object callDirect(
Object state,
Object[] arguments,
@Cached("function.getCallTarget()") RootCallTarget cachedTarget,
@Cached("create(cachedTarget)") DirectCallNode callNode) {
@Cached("createCallNode(cachedTarget)") DirectCallNode callNode) {
return callNode.call(
Function.ArgumentsHelper.buildArguments(function, callerInfo, state, arguments));
}

static DirectCallNode createCallNode(RootCallTarget t) {
return InlineableRootNode.create(t);
}

/**
* Calls the function with a lookup.
*
Expand Down
Loading

0 comments on commit a74933d

Please sign in to comment.