Skip to content

Commit

Permalink
Simplify compilation of nested patterns (#4005)
Browse files Browse the repository at this point in the history
`NestedPatternMatch` pass desugared complex patterns in a very inefficient way resulting in an exponential generation of the number of `case` IR (and Truffle) nodes. Every failed nested pattern would copy all the remaining patterns of the original case expression, in a desugared form. While the execution itself of such deeply nested `case` expression might not have many problems, the time spent in compilation phases certainly was a blocker.

This change desugars deeply nested into individual cases with a fallthrough logic. However the fallthrough logic is implemented directly in Truffle nodes, rather than via IR. That way we can generate much simpler IR for nested patterns.

Consider a simple case of
```
case x of
Cons (Cons a b) Nil -> a + b
Cons a Nil -> a
_ -> 0
```

Before the change, the compiler would generate rather large IR even for those two patterns:
```
case x of
Cons w y -> case w of
Cons a b -> case y of
Nil -> a + b
_ -> case x of
Cons a z -> case z of
Nil -> a
_ -> case x of
_ -> 0
_ -> 0
_ -> case x of
Cons a z -> case z of
Nil -> a
_ -> case x of
_ -> 0
_ -> 0
Cons a z -> case z of
Nil -> a
_ -> case x of
_ -> 0
_ -> 0
```

Now we generate simple patterns with fallthrough semantics and no catch-all branches:
```
case x of
Cons w y -> case w of
Cons a b -> case y of   ## fallthrough on failed match ##
Nil -> a + b                ## fallthrough on failed match ##
Cons a z -> case z of
Nil -> a                          ## fallthrough on failed match ##
_ -> 0
```

# Important Notes
If you wonder how much does it improve, then @radeusgd's example in https://www.pivotaltracker.com/story/show/183971366/comments/234688327 used to take at least 8 minutes to compile and run.
Now it takes 5 seconds from cold start.

Also, the example in the benchmark includes compilation time on purpose (that was the main culprit of the slowdown).
For the old implementation I had to kill it after 15 minutes as it still wouldn't finish a single compilation.
Now it runs 2 seconds or less.

Bonus points: This PR will also fix problem reported in https://www.pivotaltracker.com/story/show/184071954 (duplicate errors for nested patterns)
  • Loading branch information
hubertp authored Dec 30, 2022
1 parent af57d14 commit d15bd8a
Show file tree
Hide file tree
Showing 28 changed files with 443 additions and 251 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@
- [Add executionContext/interrupt API command][3952]
- [Any.== is a builtin method][3956]
- [Simplify exception handling for polyglot exceptions][3981]
- [Simplify compilation of nested patterns][4005]
- [IGV can jump to JMH sources & more][4008]

[3227]: https://github.com/enso-org/enso/pull/3227
Expand Down Expand Up @@ -572,6 +573,7 @@
[3952]: https://github.com/enso-org/enso/pull/3952
[3956]: https://github.com/enso-org/enso/pull/3956
[3981]: https://github.com/enso-org/enso/pull/3981
[4005]: https://github.com/enso-org/enso/pull/4005
[4008]: https://github.com/enso-org/enso/pull/4008

# Enso 2.0.0-alpha.18 (2021-10-12)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package org.enso.interpreter.bench.benchmarks.semantic;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.AbstractList;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Value;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.BenchmarkParams;
import org.openjdk.jmh.infra.Blackhole;


@BenchmarkMode(Mode.AverageTime)
@Fork(1)
@Warmup(iterations = 3)
@Measurement(iterations = 5)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
public class NestedPatternCompilationBenchmarks {
private Value self;
private String benchmarkName;
private String code;
private Context ctx;

@Setup
public void initializeBenchmark(BenchmarkParams params) throws Exception {
ctx = Context.newBuilder()
.allowExperimentalOptions(true)
.allowIO(true)
.allowAllAccess(true)
.logHandler(new ByteArrayOutputStream())
.option(
"enso.languageHomeOverride",
Paths.get("../../distribution/component").toFile().getAbsolutePath()
).build();

benchmarkName = params.getBenchmark().replaceFirst(".*\\.", "");
code = """
type List
Cons a b
Nil
test x =
case x of
List.Nil -> 0
List.Cons a List.Nil -> a
List.Cons a (List.Cons b List.Nil) -> a+b
List.Cons a (List.Cons b (List.Cons c List.Nil)) -> a+b+c
List.Cons a (List.Cons b (List.Cons c (List.Cons d List.Nil))) -> a+b+c+d
List.Cons a (List.Cons b (List.Cons c (List.Cons d (List.Cons e List.Nil)))) -> a+b+c+d+e
List.Cons a (List.Cons b (List.Cons c (List.Cons d (List.Cons e (List.Cons f List.Nil))))) -> a+b+c+d+e+f
list_of_6 =
List.Cons 1 (List.Cons 2 (List.Cons 3 (List.Cons 4 (List.Cons 5 (List.Cons 6 List.Nil)))))
""";
}

@Benchmark
public void sumList(Blackhole hole) throws IOException {
// Compilation is included in the benchmark on purpose
var module = ctx.eval(SrcUtil.source(benchmarkName, code));

this.self = module.invokeMember("get_associated_type");
Function<String,Value> getMethod = (name) -> module.invokeMember("get_method", self, name);

var list = getMethod.apply("list_of_6").execute(self);
var result = getMethod.apply("test").execute(self, list);

if (!result.fitsInDouble()) {
throw new AssertionError("Shall be a double: " + result);
}
var calculated = (long) result.asDouble();
var expected = 21;
if (calculated != expected) {
throw new AssertionError("Expected " + expected + " from sum but got " + calculated);
}
hole.consume(result);
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ public abstract class BooleanBranchNode extends BranchNode {
private final boolean matched;
private final ConditionProfile profile = ConditionProfile.createCountingProfile();

BooleanBranchNode(boolean matched, RootCallTarget branch) {
super(branch);
BooleanBranchNode(boolean matched, RootCallTarget branch, boolean terminalBranch) {
super(branch, terminalBranch);
this.matched = matched;
}

Expand All @@ -25,8 +25,9 @@ public abstract class BooleanBranchNode extends BranchNode {
* @param branch the expression to be executed if (@code matcher} matches
* @return a node for matching in a case expression
*/
public static BooleanBranchNode build(boolean matched, RootCallTarget branch) {
return BooleanBranchNodeGen.create(matched, branch);
public static BooleanBranchNode build(
boolean matched, RootCallTarget branch, boolean terminalBranch) {
return BooleanBranchNodeGen.create(matched, branch, terminalBranch);
}

@Specialization
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,23 @@
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.DirectCallNode;
import com.oracle.truffle.api.nodes.NodeInfo;
import com.oracle.truffle.api.profiles.ConditionProfile;
import org.enso.interpreter.node.BaseNode;
import org.enso.interpreter.runtime.callable.function.Function;

/** An abstract representation of a case branch. */
@NodeInfo(shortName = "case_branch", description = "Represents a case branch at runtime.")
public abstract class BranchNode extends BaseNode {
private @Child DirectCallNode callNode;
private final boolean terminalBranch;

BranchNode(RootCallTarget branch) {
private final ConditionProfile finalBranchProfiler = ConditionProfile.createCountingProfile();
private final ConditionProfile propgateResultProfiler = ConditionProfile.createCountingProfile();
private final ConditionProfile ensureWrappedProfiler = ConditionProfile.createCountingProfile();

BranchNode(RootCallTarget branch, boolean terminalBranch) {
this.callNode = DirectCallNode.create(branch);
this.terminalBranch = terminalBranch;
}

/**
Expand All @@ -36,7 +43,22 @@ protected void accept(VirtualFrame frame, Object state, Object[] args) {
// Note [Caller Info For Case Branches]
var result =
callNode.call(Function.ArgumentsHelper.buildArguments(frame.materialize(), state, args));
throw new BranchSelectedException(result);

if (finalBranchProfiler.profile(terminalBranch)) {
throw new BranchSelectedException(ensureWrapped(result));
}
// Note [Guaranteed BranchResult instance in non-terminal branches]
BranchResult result1 = (BranchResult) result;
if (propgateResultProfiler.profile(result1.isMatched()))
throw new BranchSelectedException(result1);
}

private BranchResult ensureWrapped(Object result) {
if (ensureWrappedProfiler.profile(result instanceof BranchResult)) {
return (BranchResult) result;
} else {
return BranchResult.success(result);
}
}

/* Note [Caller Info For Case Branches]
Expand All @@ -46,6 +68,14 @@ protected void accept(VirtualFrame frame, Object state, Object[] args) {
* have no way of accessing the caller frame and can safely be passed null.
*/

/* Note [Guaranteed BranchResult instance in non-terminal branches]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* The NestedPatternMatch phase desugars complex patterns into individual. simple, patterns.
* An intermediate branch either propagates a failed case or a result of executing a
* successful and complete pattern. In both cases the result is wrapped in BranchResult to
* encapsulate that information.
*/

/* Note [Safe Casting to Function in Catch All Branches]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* The syntactic nature of a catch all node guarantees that it has _only one_
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.enso.interpreter.node.controlflow.caseexpr;

import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.runtime.EnsoContext;

public record BranchResult(boolean isMatched, Object result) {

public static BranchResult failure(Node node) {
return new BranchResult(false, EnsoContext.get(node).getBuiltins().nothing());
}

public static BranchResult success(Object result) {
return new BranchResult(true, result);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
/** This exception is used to signal when a certain branch in a case expression has been taken. */
@NodeInfo(shortName = "BranchSelect", description = "Signals that a case branch has been selected")
public class BranchSelectedException extends ControlFlowException {
private final Object result;
private final BranchResult result;

/**
* Creates a new exception instance.
* Creates a new exception instance. The result is wrapped in `CaseResult` to indiciate if the
* result represents a failed nested branch or actually a complete and successful execution of a
* (potentially nested) branch.
*
* @param result the result of executing the branch this is thrown from
*/
public BranchSelectedException(Object result) {
public BranchSelectedException(BranchResult result) {
this.result = result;
}

Expand All @@ -22,7 +24,7 @@ public BranchSelectedException(Object result) {
*
* @return the result of executing the case branch from which this is thrown
*/
public Object getResult() {
public BranchResult getBranchResult() {
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import com.oracle.truffle.api.nodes.NodeInfo;
import com.oracle.truffle.api.profiles.ConditionProfile;
import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.data.ArrayRope;
import org.enso.interpreter.runtime.error.*;
import org.enso.interpreter.runtime.state.State;
import org.enso.interpreter.runtime.type.TypesGen;
Expand All @@ -28,20 +28,26 @@
public abstract class CaseNode extends ExpressionNode {

@Children private final BranchNode[] cases;
private final boolean isNested;

CaseNode(BranchNode[] cases) {
private final ConditionProfile fallthroughProfile = ConditionProfile.createCountingProfile();

CaseNode(boolean isNested, BranchNode[] cases) {
this.cases = cases;
this.isNested = isNested;
}

/**
* Creates an instance of this node.
*
* @param scrutinee the value being scrutinised
* @param cases the case branches
* @param isNested if true, the flag indicates that the case node represents a nested pattern. If
* false, the case node represents a top-level case involving potentially nested patterns.
* @return a node representing a pattern match
*/
public static CaseNode build(ExpressionNode scrutinee, BranchNode[] cases) {
return CaseNodeGen.create(cases, scrutinee);
public static CaseNode build(ExpressionNode scrutinee, BranchNode[] cases, boolean isNested) {
return CaseNodeGen.create(isNested, cases, scrutinee);
}

/**
Expand Down Expand Up @@ -104,12 +110,16 @@ public Object doMatch(
for (BranchNode branchNode : cases) {
branchNode.execute(frame, state, object);
}
CompilerDirectives.transferToInterpreter();
throw new PanicException(
EnsoContext.get(this).getBuiltins().error().makeInexhaustivePatternMatch(object), this);
if (fallthroughProfile.profile(isNested)) {
return BranchResult.failure(this);
} else {
CompilerDirectives.transferToInterpreter();
throw new PanicException(
EnsoContext.get(this).getBuiltins().error().makeInexhaustivePatternMatch(object), this);
}
} catch (BranchSelectedException e) {
// Note [Branch Selection Control Flow]
return e.getResult();
return isNested ? e.getBranchResult() : e.getBranchResult().result();
}
}

Expand All @@ -132,5 +142,13 @@ boolean isPanicSentinel(Object sentinel) {
*
* The main alternative to this was desugaring to a nested-if, which would've been significantly
* harder to maintain, and also resulted in significantly higher code complexity.
*
* Note that the CaseNode may return either a BranchResult or it's underlying value.
* This depends on whether the current CaseNode has been constructed as part of the desugaring phase
* for nested patterns.
* Case expressions that are synthetic, correspond to nested patterns and must propagate additional
* information about the state of the match. That way, in the case of a failure in a deeply
* nested case, other branches of the original case expression are tried.
* `isNested` check ensures that `CaseResult` never leaks outside the CaseNode/BranchNode hierarchy.
*/
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
description = "An explicit catch-all branch in a case expression")
public class CatchAllBranchNode extends BranchNode {

private CatchAllBranchNode(RootCallTarget functionNode) {
super(functionNode);
private CatchAllBranchNode(RootCallTarget functionNode, boolean terminalBranch) {
super(functionNode, terminalBranch);
}

/**
Expand All @@ -23,8 +23,8 @@ private CatchAllBranchNode(RootCallTarget functionNode) {
* @param functionNode the function to execute in this case
* @return a catch-all node
*/
public static CatchAllBranchNode build(RootCallTarget functionNode) {
return new CatchAllBranchNode(functionNode);
public static CatchAllBranchNode build(RootCallTarget functionNode, boolean terminalBranch) {
return new CatchAllBranchNode(functionNode, terminalBranch);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ public class CatchTypeBranchNode extends BranchNode {
private @Child IsValueOfTypeNode isValueOfTypeNode = IsValueOfTypeNode.build();
private final ConditionProfile profile = ConditionProfile.createCountingProfile();

CatchTypeBranchNode(Type tpe, RootCallTarget functionNode) {
super(functionNode);
CatchTypeBranchNode(Type tpe, RootCallTarget functionNode, boolean terminalBranch) {
super(functionNode, terminalBranch);
this.expectedType = tpe;
}

Expand All @@ -27,8 +27,9 @@ public class CatchTypeBranchNode extends BranchNode {
* @param functionNode the function to execute in this case
* @return a catch-all node
*/
public static CatchTypeBranchNode build(Type tpe, RootCallTarget functionNode) {
return new CatchTypeBranchNode(tpe, functionNode);
public static CatchTypeBranchNode build(
Type tpe, RootCallTarget functionNode, boolean terminalBranch) {
return new CatchTypeBranchNode(tpe, functionNode, terminalBranch);
}

public void execute(VirtualFrame frame, Object state, Object value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ public abstract class ConstructorBranchNode extends BranchNode {
private final AtomConstructor matcher;
private final ConditionProfile profile = ConditionProfile.createCountingProfile();

ConstructorBranchNode(AtomConstructor matcher, RootCallTarget branch) {
super(branch);
ConstructorBranchNode(AtomConstructor matcher, RootCallTarget branch, boolean terminalBranch) {
super(branch, terminalBranch);
this.matcher = matcher;
}

Expand All @@ -27,8 +27,9 @@ public abstract class ConstructorBranchNode extends BranchNode {
* @param branch the expression to be executed if (@code matcher} matches
* @return a node for matching in a case expression
*/
public static ConstructorBranchNode build(AtomConstructor matcher, RootCallTarget branch) {
return ConstructorBranchNodeGen.create(matcher, branch);
public static ConstructorBranchNode build(
AtomConstructor matcher, RootCallTarget branch, boolean terminalBranch) {
return ConstructorBranchNodeGen.create(matcher, branch, terminalBranch);
}

@Specialization
Expand Down
Loading

0 comments on commit d15bd8a

Please sign in to comment.