Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing @BuiltinMethod.inlineable and InlineableNode #6442

Merged
merged 16 commits into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ lazy val enso = (project in file("."))
.settings(version := "0.1")
.aggregate(
`interpreter-dsl`,
`interpreter-dsl-test`,
`json-rpc-server-test`,
`json-rpc-server`,
`language-server`,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@ final class InliningBuiltinsInNode extends Node {
long execute(long a, long b) {
return a + b;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.enso.interpreter.dsl.test;

import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.dsl.BuiltinMethod;
import static org.junit.Assert.assertNotNull;

@BuiltinMethod(type = "InliningBuiltins", name = "need_not", inlineable = true)
final class InliningBuiltinsNeedNotNode extends Node {

long execute(VirtualFrame frame, long a, long b) {
assertNotNull("Some frame is still provided", frame);
return a + b;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.enso.interpreter.dsl.test;

import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.dsl.BuiltinMethod;

@BuiltinMethod(type = "InliningBuiltins", name = "needs", inlineable = false)
final class InliningBuiltinsNeedsNode extends Node {

long execute(long a, long b) {
return a + b;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ long execute(VirtualFrame frame, long a, long b) {
Assert.assertNotNull("VirtualFrame is always provided", frame);
return a + b;
}

}
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
package org.enso.interpreter.dsl.test;

import org.enso.interpreter.node.InlineableRootNode;
JaroslavTulach marked this conversation as resolved.
Show resolved Hide resolved
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.DirectCallNode;
import com.oracle.truffle.api.nodes.RootNode;
import org.enso.interpreter.runtime.callable.function.Function;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import org.junit.Test;
import org.enso.interpreter.node.InlineableNode;

public class InliningBuiltinsTest {

/** @see InliningBuiltinsInNode#execute(long, long) */
@Test
public void executeWithoutVirtualFrame() {
var fn = InliningBuiltinsInMethodGen.makeFunction(null);
if (fn.getCallTarget().getRootNode() instanceof InlineableRootNode root) {
var call = root.createDirectCallNode();
var clazz = call.getClass().getSuperclass();
assertEquals("InlinedCallNode", clazz.getSimpleName());
assertEquals("BuiltinRootNode", clazz.getEnclosingClass().getSimpleName());
if (fn.getCallTarget().getRootNode() instanceof InlineableNode.Root root) {
var call = root.createInlineableNode();
var clazz = call.getClass();
assertEquals("InlineableNode", clazz.getSuperclass().getSimpleName());
assertEquals("org.enso.interpreter.node.InlineableNode$Root", clazz.getEnclosingClass().getInterfaces()[0].getName());

var res = call.call(Function.ArgumentsHelper.buildArguments(null, null, new Object[] { null, 5L, 7L }));
var res = WithFrame.invoke((frame) -> {
return call.call(frame, Function.ArgumentsHelper.buildArguments(null, null, new Object[] { null, 5L, 7L }));
});
assertEquals(12L, res);
} else {
fail("It is inlineable: " + fn.getCallTarget().getRootNode());
Expand All @@ -29,15 +34,73 @@ public void executeWithoutVirtualFrame() {
@Test
public void executeWithVirtualFrame() {
var fn = InliningBuiltinsOutMethodGen.makeFunction(null);
if (fn.getCallTarget().getRootNode() instanceof InlineableRootNode root) {
var call = root.createDirectCallNode();
if (fn.getCallTarget().getRootNode() instanceof InlineableNode.Root root) {
fail("The node isn't inlineable: " + fn.getCallTarget().getRootNode());
} else {
var call = DirectCallNode.create(fn.getCallTarget());
var clazz = call.getClass().getSuperclass();
assertEquals("com.oracle.truffle.api.nodes.DirectCallNode", clazz.getName());

var res = call.call(Function.ArgumentsHelper.buildArguments(null, null, new Object[] { null, 3L, 9L }));
var res = WithFrame.invoke((frame) -> {
return call.call(Function.ArgumentsHelper.buildArguments(null, null, new Object[] { null, 3L, 9L }));
});
assertEquals(12L, res);
}
}

/** @see InliningBuiltinsNeedsNode#execute(long, long) */
@Test
public void executeWhenNeedsVirtualFrame() {
var fn = InliningBuiltinsNeedsMethodGen.makeFunction(null);
if (fn.getCallTarget().getRootNode() instanceof InlineableNode.Root root) {
fail("The node isn't inlineable: " + fn.getCallTarget().getRootNode());
} else {
var call = DirectCallNode.create(fn.getCallTarget());
var clazz = call.getClass().getSuperclass();
assertEquals("com.oracle.truffle.api.nodes.DirectCallNode", clazz.getName());

var res = WithFrame.invoke((frame) -> {
return call.call(Function.ArgumentsHelper.buildArguments(null, null, new Object[] { null, 3L, 9L }));
});
assertEquals(12L, res);
}
}

/** @see InliningBuiltinsNeedNotNode#execute(com.oracle.truffle.api.frame.VirtualFrame, long, long) */
@Test
public void executeWhenNeedNotVirtualFrame() {
var fn = InliningBuiltinsNeedNotMethodGen.makeFunction(null);
if (fn.getCallTarget().getRootNode() instanceof InlineableNode.Root root) {
var call = root.createInlineableNode();
var clazz = call.getClass();
assertEquals("InlineableNode", clazz.getSuperclass().getSimpleName());
assertEquals("org.enso.interpreter.node.InlineableNode$Root", clazz.getEnclosingClass().getInterfaces()[0].getName());

var res = WithFrame.invoke((frame) -> {
return call.call(frame, Function.ArgumentsHelper.buildArguments(null, null, new Object[] { null, 5L, 7L }));
});
assertEquals(12L, res);
} else {
fail("It is inlineable: " + fn.getCallTarget().getRootNode());
}
}

private static final class WithFrame<T> extends RootNode {
private final java.util.function.Function<VirtualFrame, T> fn;

private WithFrame(java.util.function.Function<VirtualFrame, T> fn) {
super(null);
this.fn = fn;
}

@Override
public Object execute(VirtualFrame frame) {
return fn.apply(frame);
}

@SuppressWarnings("unchecked")
static <T> T invoke(java.util.function.Function<VirtualFrame, T> fn, Object... args) {
return (T) new WithFrame<>(fn).getCallTarget().call(args);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class ArrayProxyBenchmarks {
private final long length = 100000;

@Setup
public void initializeBenchmark(BenchmarkParams params) {
public void initializeBenchmark(BenchmarkParams params) throws Exception {
Engine eng =
Engine.newBuilder()
.allowExperimentalOptions(true)
Expand Down Expand Up @@ -59,13 +59,15 @@ Array_Proxy.new n (i -> 3 + 5*i)
make_delegating_vector n =
Vector.from_polyglot_array (make_delegating_proxy n)
""";
var module = ctx.eval("enso", code);
var benchmarkName = SrcUtil.findName(params);
var src = SrcUtil.source(benchmarkName, code);
var module = ctx.eval(src);

this.self = module.invokeMember("get_associated_type");
Function<String, Value> getMethod = (name) -> module.invokeMember("get_method", self, name);

String test_builder;
switch (params.getBenchmark().replaceFirst(".*\\.", "")) {
switch (benchmarkName) {
case "sumOverVector":
test_builder = "make_vector";
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public void initializeBenchmark(BenchmarkParams params) throws Exception {
Paths.get("../../distribution/component").toFile().getAbsolutePath()
).build();

var benchmarkName = params.getBenchmark().replaceFirst(".*\\.", "");
var benchmarkName = SrcUtil.findName(params);
var code = """
avg fn len =
sum acc i = if i == len then acc else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public void initializeBenchmark(BenchmarkParams params) throws Exception {
Paths.get("../../distribution/component").toFile().getAbsolutePath()
).build();

var benchmarkName = params.getBenchmark().replaceFirst(".*\\.", "");
var benchmarkName = SrcUtil.findName(params);
var codeBuilder = new StringBuilder("""
import Standard.Base.Data.Range.Extensions

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,8 @@ public void initializeBench(BenchmarkParams params) throws IOException {

""";

var file = File.createTempFile("if_case", ".enso");
try (var w = new FileWriter(file)) {
w.write(code);
}
var src = Source.newBuilder("enso", file).build();
var benchmarkName = SrcUtil.findName(params);
var src = SrcUtil.source(benchmarkName, code);
Value module = ctx.eval(src);
ifBench3 = Objects.requireNonNull(module.invokeMember(Module.EVAL_EXPRESSION, "if_bench_3"));
caseBench3 = Objects.requireNonNull(module.invokeMember(Module.EVAL_EXPRESSION, "case_bench_3"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void initializeBenchmark(BenchmarkParams params) throws Exception {
Paths.get("../../distribution/component").toFile().getAbsolutePath()
).build();

var benchmarkName = params.getBenchmark().replaceFirst(".*\\.", "");
var benchmarkName = SrcUtil.findName(params);
var code = """
from Standard.Base.Data.List.List import Cons, Nil
import Standard.Base.IO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void initializeBenchmark(BenchmarkParams params) throws Exception {
Paths.get("../../distribution/component").toFile().getAbsolutePath()
).build();

benchmarkName = params.getBenchmark().replaceFirst(".*\\.", "");
benchmarkName = SrcUtil.findName(params);
code = """
type List
Cons a b
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@
import java.io.FileWriter;
import java.io.IOException;
import org.graalvm.polyglot.Source;
import org.openjdk.jmh.infra.BenchmarkParams;

final class SrcUtil {
private SrcUtil() {
}

static String findName(BenchmarkParams params) {
return params.getBenchmark().replaceFirst(".*\\.", "");
}

static Source source(String benchmarkName, String code) throws IOException {
var d = new File(new File(new File("."), "target"), "bench-data");
d.mkdirs();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class StringBenchmarks {
private Value allLength;

@Setup
public void initializeBenchmark(BenchmarkParams params) {
public void initializeBenchmark(BenchmarkParams params) throws Exception {
var ctx = Context.newBuilder()
.allowExperimentalOptions(true)
.allowIO(true)
Expand All @@ -42,7 +42,8 @@ public void initializeBenchmark(BenchmarkParams params) {
"enso.languageHomeOverride",
Paths.get("../../distribution/component").toFile().getAbsolutePath()
).build();
var module = ctx.eval("enso", """

var code ="""
from Standard.Base import all

all_length v = v.fold 0 (sum -> str -> sum + str.length)
Expand All @@ -51,7 +52,10 @@ public void initializeBenchmark(BenchmarkParams params) {
s = "Long string".repeat rep
v = Vector.new len (_ -> s)
v
""");
""";
var benchmarkName = SrcUtil.findName(params);
var src = SrcUtil.source(benchmarkName, code);
var module = ctx.eval(src);

this.self = module.invokeMember("get_associated_type");
Function<String,Value> getMethod = (name) -> module.invokeMember("get_method", self, name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class TypePatternBenchmarks {
private Value self;

@Setup
public void initializeBenchmark(BenchmarkParams params) {
public void initializeBenchmark(BenchmarkParams params) throws Exception {
var ctx = Context.newBuilder()
.allowExperimentalOptions(true)
.allowIO(true)
Expand All @@ -35,7 +35,7 @@ public void initializeBenchmark(BenchmarkParams params) {
"enso.languageHomeOverride",
Paths.get("../../distribution/component").toFile().getAbsolutePath()
).build();
var module = ctx.eval("enso", """
var code ="""
from Standard.Base import Integer, Vector, Any, Decimal

avg arr =
Expand All @@ -60,14 +60,17 @@ public void initializeBenchmark(BenchmarkParams params) {

match_dec = v -> case v of
n : Decimal -> n + 1
""");
""";
var benchmarkName = SrcUtil.findName(params);
var src = SrcUtil.source(benchmarkName, code);
var module = ctx.eval(src);

this.self = module.invokeMember("get_associated_type");
Function<String,Value> getMethod = (name) -> module.invokeMember("get_method", self, name);

var length = 100;
this.vec = getMethod.apply("gen_vec").execute(self, length, 1.1);
switch (params.getBenchmark().replaceFirst(".*\\.", "")) {
switch (SrcUtil.findName(params)) {
case "matchOverAny":
this.patternMatch = getMethod.apply("match_any");
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public void initializeBenchmark(BenchmarkParams params) throws Exception {
Paths.get("../../distribution/component").toFile().getAbsolutePath()
).build();

var benchmarkName = params.getBenchmark().replaceFirst(".*\\.", "");
var benchmarkName = SrcUtil.findName(params);
var code = """
import Standard.Base.Data.Vector.Vector
import Standard.Base.Data.Array_Proxy.Array_Proxy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public class WarningBenchmarks extends TestBase {
public void initializeBench(BenchmarkParams params) throws IOException {
ctx = createDefaultContext();

benchmarkName = params.getBenchmark().replaceFirst(".*\\.", "");
benchmarkName = SrcUtil.findName(params);

var code = """
from Standard.Base import all
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package org.enso.interpreter.node;

import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.DirectCallNode;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.RootNode;
import org.enso.interpreter.node.callable.ExecuteCallNode;

/**
* More effective {@link DirectCallNode} alternative. Supports more aggressive inlining needed by
* {@link ExecuteCallNode}.
*/
public abstract class InlineableNode extends Node {
/**
* Invokes the computation represented by the node.
*
* @param frame current frame of the caller
* @param arguments arguments for the functionality
* @return result of the computation
*/
public abstract Object call(VirtualFrame frame, Object[] arguments);

/**
* Special interface that allows various {@link RootNode} subclasses to provide more effective
* implementation of {@link DirectCallNode} alternative. Used by for example by {@code
* BuiltinRootNode}.
*/
public interface Root {
/**
* Provides access to {@link RootNode}. Usually the object shall inherit from {link RootNode} as
* well as implement the {@link InlineableNode} interface. This method thus usually returns
* {@code this}.
*
* @return {@code this} types as {link RootNode}
*/
public RootNode getRootNode();
JaroslavTulach marked this conversation as resolved.
Show resolved Hide resolved

/**
* Name of the {@link RootNode}.
*
* @return root node name
*/
public String getName();

/**
* Override to provide more effective implementation of {@link DirectCallNode} alternative.
* Suited more for Enso aggressive inlining.
*
* @return a node to call the associated {@link RootNode} - may return {@code null}
*/
public InlineableNode createInlineableNode();
}
}
Loading