From d15bd8ab3b6cd044d2371ccee73830bd153e5d29 Mon Sep 17 00:00:00 2001 From: Hubert Plociniczak Date: Fri, 30 Dec 2022 11:56:27 +0100 Subject: [PATCH] Simplify compilation of nested patterns (#4005) `NestedPatternMatch` pass desugared complex patterns in a very inefficient way resulting in an exponential generation of the number of `case` IR (and Truffle) nodes. Every failed nested pattern would copy all the remaining patterns of the original case expression, in a desugared form. While the execution itself of such deeply nested `case` expression might not have many problems, the time spent in compilation phases certainly was a blocker. This change desugars deeply nested into individual cases with a fallthrough logic. However the fallthrough logic is implemented directly in Truffle nodes, rather than via IR. That way we can generate much simpler IR for nested patterns. Consider a simple case of ``` case x of Cons (Cons a b) Nil -> a + b Cons a Nil -> a _ -> 0 ``` Before the change, the compiler would generate rather large IR even for those two patterns: ``` case x of Cons w y -> case w of Cons a b -> case y of Nil -> a + b _ -> case x of Cons a z -> case z of Nil -> a _ -> case x of _ -> 0 _ -> 0 _ -> case x of Cons a z -> case z of Nil -> a _ -> case x of _ -> 0 _ -> 0 Cons a z -> case z of Nil -> a _ -> case x of _ -> 0 _ -> 0 ``` Now we generate simple patterns with fallthrough semantics and no catch-all branches: ``` case x of Cons w y -> case w of Cons a b -> case y of ## fallthrough on failed match ## Nil -> a + b ## fallthrough on failed match ## Cons a z -> case z of Nil -> a ## fallthrough on failed match ## _ -> 0 ``` # Important Notes If you wonder how much does it improve, then @radeusgd's example in https://www.pivotaltracker.com/story/show/183971366/comments/234688327 used to take at least 8 minutes to compile and run. Now it takes 5 seconds from cold start. Also, the example in the benchmark includes compilation time on purpose (that was the main culprit of the slowdown). For the old implementation I had to kill it after 15 minutes as it still wouldn't finish a single compilation. Now it runs 2 seconds or less. Bonus points: This PR will also fix problem reported in https://www.pivotaltracker.com/story/show/184071954 (duplicate errors for nested patterns) --- CHANGELOG.md | 2 + .../NestedPatternCompilationBenchmarks.java | 93 ++++++++++++ .../caseexpr/BooleanBranchNode.java | 9 +- .../node/controlflow/caseexpr/BranchNode.java | 34 ++++- .../controlflow/caseexpr/BranchResult.java | 16 +++ .../caseexpr/BranchSelectedException.java | 10 +- .../node/controlflow/caseexpr/CaseNode.java | 34 +++-- .../caseexpr/CatchAllBranchNode.java | 8 +- .../caseexpr/CatchTypeBranchNode.java | 9 +- .../caseexpr/ConstructorBranchNode.java | 9 +- .../caseexpr/NumericLiteralBranchNode.java | 19 +-- .../caseexpr/ObjectEqualityBranchNode.java | 10 +- .../caseexpr/PolyglotBranchNode.java | 9 +- .../PolyglotSymbolTypeBranchNode.java | 9 +- .../caseexpr/StringLiteralBranchNode.java | 9 +- .../enso/compiler/codegen/IrToTruffle.scala | 67 ++++++--- .../scala/org/enso/compiler/core/IR.scala | 91 +++++++++++- .../compiler/pass/analyse/AliasAnalysis.scala | 2 +- .../pass/analyse/DataflowAnalysis.scala | 10 +- .../pass/analyse/DemandAnalysis.scala | 2 +- .../enso/compiler/pass/analyse/TailCall.scala | 2 +- .../pass/desugar/NestedPatternMatch.scala | 95 ++++--------- .../pass/lint/ShadowedPatternFields.scala | 2 +- .../compiler/pass/lint/UnusedBindings.scala | 2 +- .../optimise/UnreachableMatchBranches.scala | 2 +- .../pass/resolve/DocumentationComments.scala | 4 +- .../pass/resolve/IgnoredBindings.scala | 2 +- .../pass/desugar/NestedPatternMatchTest.scala | 133 ++++++------------ 28 files changed, 443 insertions(+), 251 deletions(-) create mode 100644 engine/runtime/src/bench/java/org/enso/interpreter/bench/benchmarks/semantic/NestedPatternCompilationBenchmarks.java create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BranchResult.java diff --git a/CHANGELOG.md b/CHANGELOG.md index e998420b84fc..cf4833f975cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -494,6 +494,7 @@ - [Add executionContext/interrupt API command][3952] - [Any.== is a builtin method][3956] - [Simplify exception handling for polyglot exceptions][3981] +- [Simplify compilation of nested patterns][4005] - [IGV can jump to JMH sources & more][4008] [3227]: https://github.com/enso-org/enso/pull/3227 @@ -572,6 +573,7 @@ [3952]: https://github.com/enso-org/enso/pull/3952 [3956]: https://github.com/enso-org/enso/pull/3956 [3981]: https://github.com/enso-org/enso/pull/3981 +[4005]: https://github.com/enso-org/enso/pull/4005 [4008]: https://github.com/enso-org/enso/pull/4008 # Enso 2.0.0-alpha.18 (2021-10-12) diff --git a/engine/runtime/src/bench/java/org/enso/interpreter/bench/benchmarks/semantic/NestedPatternCompilationBenchmarks.java b/engine/runtime/src/bench/java/org/enso/interpreter/bench/benchmarks/semantic/NestedPatternCompilationBenchmarks.java new file mode 100644 index 000000000000..7d2eca3b2e4a --- /dev/null +++ b/engine/runtime/src/bench/java/org/enso/interpreter/bench/benchmarks/semantic/NestedPatternCompilationBenchmarks.java @@ -0,0 +1,93 @@ +package org.enso.interpreter.bench.benchmarks.semantic; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.file.Paths; +import java.util.AbstractList; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import org.graalvm.polyglot.Context; +import org.graalvm.polyglot.Value; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.BenchmarkParams; +import org.openjdk.jmh.infra.Blackhole; + + +@BenchmarkMode(Mode.AverageTime) +@Fork(1) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +public class NestedPatternCompilationBenchmarks { + private Value self; + private String benchmarkName; + private String code; + private Context ctx; + + @Setup + public void initializeBenchmark(BenchmarkParams params) throws Exception { + ctx = Context.newBuilder() + .allowExperimentalOptions(true) + .allowIO(true) + .allowAllAccess(true) + .logHandler(new ByteArrayOutputStream()) + .option( + "enso.languageHomeOverride", + Paths.get("../../distribution/component").toFile().getAbsolutePath() + ).build(); + + benchmarkName = params.getBenchmark().replaceFirst(".*\\.", ""); + code = """ + type List + Cons a b + Nil + + test x = + case x of + List.Nil -> 0 + List.Cons a List.Nil -> a + List.Cons a (List.Cons b List.Nil) -> a+b + List.Cons a (List.Cons b (List.Cons c List.Nil)) -> a+b+c + List.Cons a (List.Cons b (List.Cons c (List.Cons d List.Nil))) -> a+b+c+d + List.Cons a (List.Cons b (List.Cons c (List.Cons d (List.Cons e List.Nil)))) -> a+b+c+d+e + List.Cons a (List.Cons b (List.Cons c (List.Cons d (List.Cons e (List.Cons f List.Nil))))) -> a+b+c+d+e+f + + list_of_6 = + List.Cons 1 (List.Cons 2 (List.Cons 3 (List.Cons 4 (List.Cons 5 (List.Cons 6 List.Nil))))) + """; + } + + @Benchmark + public void sumList(Blackhole hole) throws IOException { + // Compilation is included in the benchmark on purpose + var module = ctx.eval(SrcUtil.source(benchmarkName, code)); + + this.self = module.invokeMember("get_associated_type"); + Function getMethod = (name) -> module.invokeMember("get_method", self, name); + + var list = getMethod.apply("list_of_6").execute(self); + var result = getMethod.apply("test").execute(self, list); + + if (!result.fitsInDouble()) { + throw new AssertionError("Shall be a double: " + result); + } + var calculated = (long) result.asDouble(); + var expected = 21; + if (calculated != expected) { + throw new AssertionError("Expected " + expected + " from sum but got " + calculated); + } + hole.consume(result); + } + +} + diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BooleanBranchNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BooleanBranchNode.java index 957f556b365f..4c086150c285 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BooleanBranchNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BooleanBranchNode.java @@ -13,8 +13,8 @@ public abstract class BooleanBranchNode extends BranchNode { private final boolean matched; private final ConditionProfile profile = ConditionProfile.createCountingProfile(); - BooleanBranchNode(boolean matched, RootCallTarget branch) { - super(branch); + BooleanBranchNode(boolean matched, RootCallTarget branch, boolean terminalBranch) { + super(branch, terminalBranch); this.matched = matched; } @@ -25,8 +25,9 @@ public abstract class BooleanBranchNode extends BranchNode { * @param branch the expression to be executed if (@code matcher} matches * @return a node for matching in a case expression */ - public static BooleanBranchNode build(boolean matched, RootCallTarget branch) { - return BooleanBranchNodeGen.create(matched, branch); + public static BooleanBranchNode build( + boolean matched, RootCallTarget branch, boolean terminalBranch) { + return BooleanBranchNodeGen.create(matched, branch, terminalBranch); } @Specialization diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BranchNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BranchNode.java index db3e3fc5c545..73200a813bc4 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BranchNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BranchNode.java @@ -4,6 +4,7 @@ import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.nodes.DirectCallNode; import com.oracle.truffle.api.nodes.NodeInfo; +import com.oracle.truffle.api.profiles.ConditionProfile; import org.enso.interpreter.node.BaseNode; import org.enso.interpreter.runtime.callable.function.Function; @@ -11,9 +12,15 @@ @NodeInfo(shortName = "case_branch", description = "Represents a case branch at runtime.") public abstract class BranchNode extends BaseNode { private @Child DirectCallNode callNode; + private final boolean terminalBranch; - BranchNode(RootCallTarget branch) { + private final ConditionProfile finalBranchProfiler = ConditionProfile.createCountingProfile(); + private final ConditionProfile propgateResultProfiler = ConditionProfile.createCountingProfile(); + private final ConditionProfile ensureWrappedProfiler = ConditionProfile.createCountingProfile(); + + BranchNode(RootCallTarget branch, boolean terminalBranch) { this.callNode = DirectCallNode.create(branch); + this.terminalBranch = terminalBranch; } /** @@ -36,7 +43,22 @@ protected void accept(VirtualFrame frame, Object state, Object[] args) { // Note [Caller Info For Case Branches] var result = callNode.call(Function.ArgumentsHelper.buildArguments(frame.materialize(), state, args)); - throw new BranchSelectedException(result); + + if (finalBranchProfiler.profile(terminalBranch)) { + throw new BranchSelectedException(ensureWrapped(result)); + } + // Note [Guaranteed BranchResult instance in non-terminal branches] + BranchResult result1 = (BranchResult) result; + if (propgateResultProfiler.profile(result1.isMatched())) + throw new BranchSelectedException(result1); + } + + private BranchResult ensureWrapped(Object result) { + if (ensureWrappedProfiler.profile(result instanceof BranchResult)) { + return (BranchResult) result; + } else { + return BranchResult.success(result); + } } /* Note [Caller Info For Case Branches] @@ -46,6 +68,14 @@ protected void accept(VirtualFrame frame, Object state, Object[] args) { * have no way of accessing the caller frame and can safely be passed null. */ + /* Note [Guaranteed BranchResult instance in non-terminal branches] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * The NestedPatternMatch phase desugars complex patterns into individual. simple, patterns. + * An intermediate branch either propagates a failed case or a result of executing a + * successful and complete pattern. In both cases the result is wrapped in BranchResult to + * encapsulate that information. + */ + /* Note [Safe Casting to Function in Catch All Branches] * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ * The syntactic nature of a catch all node guarantees that it has _only one_ diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BranchResult.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BranchResult.java new file mode 100644 index 000000000000..e3798ef8b15b --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BranchResult.java @@ -0,0 +1,16 @@ +package org.enso.interpreter.node.controlflow.caseexpr; + +import com.oracle.truffle.api.nodes.Node; +import org.enso.interpreter.runtime.EnsoContext; + +public record BranchResult(boolean isMatched, Object result) { + + public static BranchResult failure(Node node) { + return new BranchResult(false, EnsoContext.get(node).getBuiltins().nothing()); + } + + public static BranchResult success(Object result) { + return new BranchResult(true, result); + } + +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BranchSelectedException.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BranchSelectedException.java index eec9355e87f0..8be24aca749a 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BranchSelectedException.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/BranchSelectedException.java @@ -6,14 +6,16 @@ /** This exception is used to signal when a certain branch in a case expression has been taken. */ @NodeInfo(shortName = "BranchSelect", description = "Signals that a case branch has been selected") public class BranchSelectedException extends ControlFlowException { - private final Object result; + private final BranchResult result; /** - * Creates a new exception instance. + * Creates a new exception instance. The result is wrapped in `CaseResult` to indiciate if the + * result represents a failed nested branch or actually a complete and successful execution of a + * (potentially nested) branch. * * @param result the result of executing the branch this is thrown from */ - public BranchSelectedException(Object result) { + public BranchSelectedException(BranchResult result) { this.result = result; } @@ -22,7 +24,7 @@ public BranchSelectedException(Object result) { * * @return the result of executing the case branch from which this is thrown */ - public Object getResult() { + public BranchResult getBranchResult() { return result; } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/CaseNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/CaseNode.java index afd95e8827a9..c9f14e8e881e 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/CaseNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/CaseNode.java @@ -8,10 +8,10 @@ import com.oracle.truffle.api.library.CachedLibrary; import com.oracle.truffle.api.nodes.ExplodeLoop; import com.oracle.truffle.api.nodes.NodeInfo; +import com.oracle.truffle.api.profiles.ConditionProfile; import org.enso.interpreter.node.ExpressionNode; import org.enso.interpreter.runtime.EnsoContext; import org.enso.interpreter.runtime.callable.function.Function; -import org.enso.interpreter.runtime.data.ArrayRope; import org.enso.interpreter.runtime.error.*; import org.enso.interpreter.runtime.state.State; import org.enso.interpreter.runtime.type.TypesGen; @@ -28,9 +28,13 @@ public abstract class CaseNode extends ExpressionNode { @Children private final BranchNode[] cases; + private final boolean isNested; - CaseNode(BranchNode[] cases) { + private final ConditionProfile fallthroughProfile = ConditionProfile.createCountingProfile(); + + CaseNode(boolean isNested, BranchNode[] cases) { this.cases = cases; + this.isNested = isNested; } /** @@ -38,10 +42,12 @@ public abstract class CaseNode extends ExpressionNode { * * @param scrutinee the value being scrutinised * @param cases the case branches + * @param isNested if true, the flag indicates that the case node represents a nested pattern. If + * false, the case node represents a top-level case involving potentially nested patterns. * @return a node representing a pattern match */ - public static CaseNode build(ExpressionNode scrutinee, BranchNode[] cases) { - return CaseNodeGen.create(cases, scrutinee); + public static CaseNode build(ExpressionNode scrutinee, BranchNode[] cases, boolean isNested) { + return CaseNodeGen.create(isNested, cases, scrutinee); } /** @@ -104,12 +110,16 @@ public Object doMatch( for (BranchNode branchNode : cases) { branchNode.execute(frame, state, object); } - CompilerDirectives.transferToInterpreter(); - throw new PanicException( - EnsoContext.get(this).getBuiltins().error().makeInexhaustivePatternMatch(object), this); + if (fallthroughProfile.profile(isNested)) { + return BranchResult.failure(this); + } else { + CompilerDirectives.transferToInterpreter(); + throw new PanicException( + EnsoContext.get(this).getBuiltins().error().makeInexhaustivePatternMatch(object), this); + } } catch (BranchSelectedException e) { // Note [Branch Selection Control Flow] - return e.getResult(); + return isNested ? e.getBranchResult() : e.getBranchResult().result(); } } @@ -132,5 +142,13 @@ boolean isPanicSentinel(Object sentinel) { * * The main alternative to this was desugaring to a nested-if, which would've been significantly * harder to maintain, and also resulted in significantly higher code complexity. + * + * Note that the CaseNode may return either a BranchResult or it's underlying value. + * This depends on whether the current CaseNode has been constructed as part of the desugaring phase + * for nested patterns. + * Case expressions that are synthetic, correspond to nested patterns and must propagate additional + * information about the state of the match. That way, in the case of a failure in a deeply + * nested case, other branches of the original case expression are tried. + * `isNested` check ensures that `CaseResult` never leaks outside the CaseNode/BranchNode hierarchy. */ } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/CatchAllBranchNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/CatchAllBranchNode.java index af27982dd8df..99ea6309ceb6 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/CatchAllBranchNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/CatchAllBranchNode.java @@ -13,8 +13,8 @@ description = "An explicit catch-all branch in a case expression") public class CatchAllBranchNode extends BranchNode { - private CatchAllBranchNode(RootCallTarget functionNode) { - super(functionNode); + private CatchAllBranchNode(RootCallTarget functionNode, boolean terminalBranch) { + super(functionNode, terminalBranch); } /** @@ -23,8 +23,8 @@ private CatchAllBranchNode(RootCallTarget functionNode) { * @param functionNode the function to execute in this case * @return a catch-all node */ - public static CatchAllBranchNode build(RootCallTarget functionNode) { - return new CatchAllBranchNode(functionNode); + public static CatchAllBranchNode build(RootCallTarget functionNode, boolean terminalBranch) { + return new CatchAllBranchNode(functionNode, terminalBranch); } /** diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/CatchTypeBranchNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/CatchTypeBranchNode.java index e27cad043460..4e3f6c1d2c84 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/CatchTypeBranchNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/CatchTypeBranchNode.java @@ -15,8 +15,8 @@ public class CatchTypeBranchNode extends BranchNode { private @Child IsValueOfTypeNode isValueOfTypeNode = IsValueOfTypeNode.build(); private final ConditionProfile profile = ConditionProfile.createCountingProfile(); - CatchTypeBranchNode(Type tpe, RootCallTarget functionNode) { - super(functionNode); + CatchTypeBranchNode(Type tpe, RootCallTarget functionNode, boolean terminalBranch) { + super(functionNode, terminalBranch); this.expectedType = tpe; } @@ -27,8 +27,9 @@ public class CatchTypeBranchNode extends BranchNode { * @param functionNode the function to execute in this case * @return a catch-all node */ - public static CatchTypeBranchNode build(Type tpe, RootCallTarget functionNode) { - return new CatchTypeBranchNode(tpe, functionNode); + public static CatchTypeBranchNode build( + Type tpe, RootCallTarget functionNode, boolean terminalBranch) { + return new CatchTypeBranchNode(tpe, functionNode, terminalBranch); } public void execute(VirtualFrame frame, Object state, Object value) { diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/ConstructorBranchNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/ConstructorBranchNode.java index a1a459282808..1d227fd817b8 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/ConstructorBranchNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/ConstructorBranchNode.java @@ -15,8 +15,8 @@ public abstract class ConstructorBranchNode extends BranchNode { private final AtomConstructor matcher; private final ConditionProfile profile = ConditionProfile.createCountingProfile(); - ConstructorBranchNode(AtomConstructor matcher, RootCallTarget branch) { - super(branch); + ConstructorBranchNode(AtomConstructor matcher, RootCallTarget branch, boolean terminalBranch) { + super(branch, terminalBranch); this.matcher = matcher; } @@ -27,8 +27,9 @@ public abstract class ConstructorBranchNode extends BranchNode { * @param branch the expression to be executed if (@code matcher} matches * @return a node for matching in a case expression */ - public static ConstructorBranchNode build(AtomConstructor matcher, RootCallTarget branch) { - return ConstructorBranchNodeGen.create(matcher, branch); + public static ConstructorBranchNode build( + AtomConstructor matcher, RootCallTarget branch, boolean terminalBranch) { + return ConstructorBranchNodeGen.create(matcher, branch, terminalBranch); } @Specialization diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/NumericLiteralBranchNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/NumericLiteralBranchNode.java index dd52e6326958..1831dfb08cf9 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/NumericLiteralBranchNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/NumericLiteralBranchNode.java @@ -17,21 +17,24 @@ public abstract class NumericLiteralBranchNode extends BranchNode { private final ConditionProfile numProfile = ConditionProfile.createCountingProfile(); - NumericLiteralBranchNode(Object literal, RootCallTarget branch) { - super(branch); + NumericLiteralBranchNode(Object literal, RootCallTarget branch, boolean terminalBranch) { + super(branch, terminalBranch); this.literal = literal; } - public static NumericLiteralBranchNode build(long literal, RootCallTarget branch) { - return NumericLiteralBranchNodeGen.create(literal, branch); + public static NumericLiteralBranchNode build( + long literal, RootCallTarget branch, boolean terminalBranch) { + return NumericLiteralBranchNodeGen.create(literal, branch, terminalBranch); } - public static NumericLiteralBranchNode build(double literal, RootCallTarget branch) { - return NumericLiteralBranchNodeGen.create(literal, branch); + public static NumericLiteralBranchNode build( + double literal, RootCallTarget branch, boolean terminalBranch) { + return NumericLiteralBranchNodeGen.create(literal, branch, terminalBranch); } - public static NumericLiteralBranchNode build(BigInteger literal, RootCallTarget branch) { - return NumericLiteralBranchNodeGen.create(literal, branch); + public static NumericLiteralBranchNode build( + BigInteger literal, RootCallTarget branch, boolean terminalBranch) { + return NumericLiteralBranchNodeGen.create(literal, branch, terminalBranch); } @Specialization(guards = "interop.isNumber(target)") diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/ObjectEqualityBranchNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/ObjectEqualityBranchNode.java index 5c996cf273f6..91c1470fed9b 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/ObjectEqualityBranchNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/ObjectEqualityBranchNode.java @@ -10,13 +10,13 @@ public class ObjectEqualityBranchNode extends BranchNode { private @Child IsSameObjectNode isSameObject = IsSameObjectNode.build(); private final ConditionProfile profile = ConditionProfile.createCountingProfile(); - public static BranchNode build(RootCallTarget branch, Object expected) { - return new ObjectEqualityBranchNode(branch, expected); + private ObjectEqualityBranchNode(RootCallTarget branch, Object expected, boolean terminalBranch) { + super(branch, terminalBranch); + this.expected = expected; } - private ObjectEqualityBranchNode(RootCallTarget branch, Object expected) { - super(branch); - this.expected = expected; + public static BranchNode build(RootCallTarget branch, Object expected, boolean terminalBranch) { + return new ObjectEqualityBranchNode(branch, expected, terminalBranch); } @Override diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/PolyglotBranchNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/PolyglotBranchNode.java index 3f335708d235..946faf6a7fe0 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/PolyglotBranchNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/PolyglotBranchNode.java @@ -15,8 +15,8 @@ public abstract class PolyglotBranchNode extends BranchNode { private final ConditionProfile constructorProfile = ConditionProfile.createCountingProfile(); private final ConditionProfile polyglotProfile = ConditionProfile.createCountingProfile(); - PolyglotBranchNode(Type polyglot, RootCallTarget branch) { - super(branch); + PolyglotBranchNode(Type polyglot, RootCallTarget branch, boolean terminalBranch) { + super(branch, terminalBranch); this.polyglot = polyglot; } @@ -27,8 +27,9 @@ public abstract class PolyglotBranchNode extends BranchNode { * @param branch the code to execute * @return an integer branch node */ - public static PolyglotBranchNode build(Type polyglot, RootCallTarget branch) { - return PolyglotBranchNodeGen.create(polyglot, branch); + public static PolyglotBranchNode build( + Type polyglot, RootCallTarget branch, boolean terminalBranch) { + return PolyglotBranchNodeGen.create(polyglot, branch, terminalBranch); } @Specialization diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/PolyglotSymbolTypeBranchNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/PolyglotSymbolTypeBranchNode.java index c97599c51099..f40e4c80659e 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/PolyglotSymbolTypeBranchNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/PolyglotSymbolTypeBranchNode.java @@ -26,8 +26,9 @@ public abstract class PolyglotSymbolTypeBranchNode extends BranchNode { private final ConditionProfile profile = ConditionProfile.createCountingProfile(); private final ConditionProfile subtypeProfile = ConditionProfile.createCountingProfile(); - PolyglotSymbolTypeBranchNode(Object polyglotSymbol, RootCallTarget functionNode) { - super(functionNode); + PolyglotSymbolTypeBranchNode( + Object polyglotSymbol, RootCallTarget functionNode, boolean terminalBranch) { + super(functionNode, terminalBranch); this.polyglotSymbol = polyglotSymbol; } @@ -39,8 +40,8 @@ public abstract class PolyglotSymbolTypeBranchNode extends BranchNode { * @return a catch-all node */ public static PolyglotSymbolTypeBranchNode build( - Object polyglotSymbol, RootCallTarget functionNode) { - return PolyglotSymbolTypeBranchNodeGen.create(polyglotSymbol, functionNode); + Object polyglotSymbol, RootCallTarget functionNode, boolean terminalBranch) { + return PolyglotSymbolTypeBranchNodeGen.create(polyglotSymbol, functionNode, terminalBranch); } @Specialization diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/StringLiteralBranchNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/StringLiteralBranchNode.java index 95fdceecac1a..f6517c810af1 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/StringLiteralBranchNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/caseexpr/StringLiteralBranchNode.java @@ -19,13 +19,14 @@ public abstract class StringLiteralBranchNode extends BranchNode { private final ConditionProfile textProfile = ConditionProfile.createCountingProfile(); - StringLiteralBranchNode(String literal, RootCallTarget branch) { - super(branch); + StringLiteralBranchNode(String literal, RootCallTarget branch, boolean terminalBranch) { + super(branch, terminalBranch); this.literal = literal; } - public static StringLiteralBranchNode build(String literal, RootCallTarget branch) { - return StringLiteralBranchNodeGen.create(literal, branch); + public static StringLiteralBranchNode build( + String literal, RootCallTarget branch, boolean terminalBranch) { + return StringLiteralBranchNodeGen.create(literal, branch, terminalBranch); } @Specialization diff --git a/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala b/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala index 6f824c37d8b4..52cd1aa781d9 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala @@ -882,7 +882,7 @@ class IrToTruffle( */ def processCase(caseExpr: IR.Case): RuntimeExpression = caseExpr match { - case IR.Case.Expr(scrutinee, branches, location, _, _) => + case IR.Case.Expr(scrutinee, branches, isNested, location, _, _) => val scrutineeNode = this.run(scrutinee) val maybeCases = branches.map(processCaseBranch) @@ -898,7 +898,8 @@ class IrToTruffle( // Note [Pattern Match Fallbacks] val matchExpr = CaseNode.build( scrutineeNode, - cases + cases, + isNested ) setLocation(matchExpr, location) } else { @@ -914,7 +915,7 @@ class IrToTruffle( setLocation(ErrorNode.build(error), caseExpr.location) } - case IR.Case.Branch(_, _, _, _, _) => + case _: IR.Case.Branch => throw new CompilerError("A CaseBranch should never occur here.") } @@ -947,7 +948,7 @@ class IrToTruffle( ) val branchNode = - CatchAllBranchNode.build(branchCodeNode.getCallTarget) + CatchAllBranchNode.build(branchCodeNode.getCallTarget, true) Right(branchNode) case cons @ Pattern.Constructor(constructor, _, _, _, _) => @@ -980,7 +981,8 @@ class IrToTruffle( Right( ObjectEqualityBranchNode.build( branchCodeNode.getCallTarget, - mod.unsafeAsModule().getScope.getAssociatedType + mod.unsafeAsModule().getScope.getAssociatedType, + branch.terminalBranch ) ) case Some( @@ -991,13 +993,22 @@ class IrToTruffle( val atomCons = tp.unsafeToRuntimeType().getConstructors.get(cons.name) val r = if (atomCons == context.getBuiltins.bool().getTrue) { - BooleanBranchNode.build(true, branchCodeNode.getCallTarget) + BooleanBranchNode.build( + true, + branchCodeNode.getCallTarget, + branch.terminalBranch + ) } else if (atomCons == context.getBuiltins.bool().getFalse) { - BooleanBranchNode.build(false, branchCodeNode.getCallTarget) + BooleanBranchNode.build( + false, + branchCodeNode.getCallTarget, + branch.terminalBranch + ) } else { ConstructorBranchNode.build( atomCons, - branchCodeNode.getCallTarget + branchCodeNode.getCallTarget, + branch.terminalBranch ) } Right(r) @@ -1007,15 +1018,20 @@ class IrToTruffle( val tpe = mod.unsafeAsModule().getScope.getTypes.get(tp.name) val polyglot = context.getBuiltins.polyglot - val branch = if (tpe == polyglot) { - PolyglotBranchNode.build(tpe, branchCodeNode.getCallTarget) + val branchNode = if (tpe == polyglot) { + PolyglotBranchNode.build( + tpe, + branchCodeNode.getCallTarget, + branch.terminalBranch + ) } else { ObjectEqualityBranchNode.build( branchCodeNode.getCallTarget, - tpe + tpe, + branch.terminalBranch ) } - Right(branch) + Right(branchNode) case Some( BindingsMap.Resolution( BindingsMap.ResolvedPolyglotSymbol(mod, symbol) @@ -1029,7 +1045,11 @@ class IrToTruffle( Either.cond( polyglotSymbol != null, ObjectEqualityBranchNode - .build(branchCodeNode.getCallTarget, polyglotSymbol), + .build( + branchCodeNode.getCallTarget, + polyglotSymbol, + branch.terminalBranch + ), BadPatternMatch.NonVisiblePolyglotSymbol(symbol.name) ) case Some( @@ -1056,21 +1076,24 @@ class IrToTruffle( Right( NumericLiteralBranchNode.build( doubleVal, - branchCodeNode.getCallTarget + branchCodeNode.getCallTarget, + branch.terminalBranch ) ) case longVal: Long => Right( NumericLiteralBranchNode.build( longVal, - branchCodeNode.getCallTarget + branchCodeNode.getCallTarget, + branch.terminalBranch ) ) case bigIntVal: BigInteger => Right( NumericLiteralBranchNode.build( bigIntVal, - branchCodeNode.getCallTarget + branchCodeNode.getCallTarget, + branch.terminalBranch ) ) case _ => @@ -1082,7 +1105,8 @@ class IrToTruffle( Right( StringLiteralBranchNode.build( text.text, - branchCodeNode.getCallTarget + branchCodeNode.getCallTarget, + branch.terminalBranch ) ) } @@ -1116,7 +1140,11 @@ class IrToTruffle( branch.location ) Right( - CatchTypeBranchNode.build(tpe, branchCodeNode.getCallTarget) + CatchTypeBranchNode.build( + tpe, + branchCodeNode.getCallTarget, + branch.terminalBranch + ) ) case None => Left(BadPatternMatch.NonVisibleType(tpeName.name)) } @@ -1152,7 +1180,8 @@ class IrToTruffle( Right( PolyglotSymbolTypeBranchNode.build( polySymbol, - branchCodeNode.getCallTarget + branchCodeNode.getCallTarget, + branch.terminalBranch ) ) } else { diff --git a/engine/runtime/src/main/scala/org/enso/compiler/core/IR.scala b/engine/runtime/src/main/scala/org/enso/compiler/core/IR.scala index 37b925601e90..8f6b4713b26b 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/core/IR.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/core/IR.scala @@ -5709,6 +5709,7 @@ object IR { * * @param scrutinee the expression whose value is being matched on * @param branches the branches of the case expression + * @param isNested if true, the flag indicates that the expr represents a desugared nested case * @param location the source location that the node corresponds to * @param passData the pass metadata associated with this node * @param diagnostics compiler diagnostics for this node @@ -5716,6 +5717,7 @@ object IR { sealed case class Expr( scrutinee: Expression, branches: Seq[Branch], + isNested: Boolean, override val location: Option[IdentifiedLocation], override val passData: MetadataStorage = MetadataStorage(), override val diagnostics: DiagnosticStorage = DiagnosticStorage() @@ -5723,10 +5725,21 @@ object IR { with IRKind.Primitive { override protected var id: Identifier = randomId + def this( + scrutinee: Expression, + branches: Seq[Branch], + location: Option[IdentifiedLocation], + passData: MetadataStorage, + diagnostics: DiagnosticStorage + ) = { + this(scrutinee, branches, false, location, passData, diagnostics) + } + /** Creates a copy of `this`. * * @param scrutinee the expression whose value is being matched on * @param branches the branches of the case expression + * @param isNested if true, the flag indicates that the expr represents a desugared nested case * @param location the source location that the node corresponds to * @param passData the pass metadata associated with this node * @param diagnostics compiler diagnostics for this node @@ -5736,13 +5749,14 @@ object IR { def copy( scrutinee: Expression = scrutinee, branches: Seq[Branch] = branches, + isNested: Boolean = isNested, location: Option[IdentifiedLocation] = location, passData: MetadataStorage = passData, diagnostics: DiagnosticStorage = diagnostics, id: Identifier = id ): Expr = { val res = - Expr(scrutinee, branches, location, passData, diagnostics) + Expr(scrutinee, branches, isNested, location, passData, diagnostics) res.id = id res } @@ -5769,6 +5783,7 @@ object IR { keepIdentifiers ) ), + isNested = isNested, location = if (keepLocations) location else None, passData = if (keepMetadata) passData.duplicate else MetadataStorage(), @@ -5795,6 +5810,7 @@ object IR { |IR.Case.Expr( |scrutinee = $scrutinee, |branches = $branches, + |isNested = $isNested, |location = $location, |passData = ${this.showPassData}, |diagnostics = $diagnostics, @@ -5817,10 +5833,34 @@ object IR { } } + object Expr { + def apply( + scrutinee: Expression, + branches: Seq[Branch], + location: Option[IdentifiedLocation] + ): Expr = + apply( + scrutinee, + branches, + location, + new MetadataStorage(), + new DiagnosticStorage() + ) + + def apply( + scrutinee: Expression, + branches: Seq[Branch], + location: Option[IdentifiedLocation], + passData: MetadataStorage, + diagnostics: DiagnosticStorage + ): Expr = new Expr(scrutinee, branches, location, passData, diagnostics) + } + /** A branch in a case statement. * * @param pattern the pattern that attempts to match against the scrutinee * @param expression the expression that is executed if the pattern matches + * @param terminalBranch the flag indicating whether the branch represents the final pattern to be checked * @param location the source location that the node corresponds to * @param passData the pass metadata associated with this node * @param diagnostics compiler diagnostics for this node @@ -5828,6 +5868,7 @@ object IR { sealed case class Branch( pattern: Pattern, expression: Expression, + terminalBranch: Boolean, override val location: Option[IdentifiedLocation], override val passData: MetadataStorage = MetadataStorage(), override val diagnostics: DiagnosticStorage = DiagnosticStorage() @@ -5835,6 +5876,16 @@ object IR { with IRKind.Primitive { override protected var id: Identifier = randomId + def this( + pattern: Pattern, + expression: Expression, + location: Option[IdentifiedLocation], + passData: MetadataStorage, + diagnostics: DiagnosticStorage + ) = { + this(pattern, expression, true, location, passData, diagnostics) + } + /** Creates a copy of `this`. * * @param pattern the pattern that attempts to match against the scrutinee @@ -5848,12 +5899,20 @@ object IR { def copy( pattern: Pattern = pattern, expression: Expression = expression, + terminalBranch: Boolean = terminalBranch, location: Option[IdentifiedLocation] = location, passData: MetadataStorage = passData, diagnostics: DiagnosticStorage = diagnostics, id: Identifier = id ): Branch = { - val res = Branch(pattern, expression, location, passData, diagnostics) + val res = Branch( + pattern, + expression, + terminalBranch, + location, + passData, + diagnostics + ) res.id = id res } @@ -5878,7 +5937,8 @@ object IR { keepDiagnostics, keepIdentifiers ), - location = if (keepLocations) location else None, + terminalBranch = terminalBranch, + location = if (keepLocations) location else None, passData = if (keepMetadata) passData.duplicate else MetadataStorage(), diagnostics = @@ -5901,6 +5961,7 @@ object IR { |IR.Case.Branch( |pattern = $pattern, |expression = $expression, + |terminalBranch = $terminalBranch, |location = $location, |passData = ${this.showPassData}, |diagnostics = $diagnostics, @@ -5922,6 +5983,30 @@ object IR { s"${pattern.showCode(indent)} -> $bodyStr" } } + + object Branch { + def apply( + pattern: Pattern, + expression: Expression, + location: Option[IdentifiedLocation] + ): Branch = + apply( + pattern, + expression, + location, + new MetadataStorage(), + new DiagnosticStorage() + ) + + def apply( + pattern: Pattern, + expression: Expression, + location: Option[IdentifiedLocation], + passData: MetadataStorage, + diagnostics: DiagnosticStorage + ): Branch = + new Branch(pattern, expression, location, passData, diagnostics) + } } // === Patterns ============================================================= diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/AliasAnalysis.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/AliasAnalysis.scala index 6f906221f6e6..d8f690be7f16 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/AliasAnalysis.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/AliasAnalysis.scala @@ -665,7 +665,7 @@ case object AliasAnalysis extends IRPass { parentScope: Scope ): IR.Case = { ir match { - case caseExpr @ IR.Case.Expr(scrutinee, branches, _, _, _) => + case caseExpr @ IR.Case.Expr(scrutinee, branches, _, _, _, _) => caseExpr .copy( scrutinee = analyseExpression(scrutinee, graph, parentScope), diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/DataflowAnalysis.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/DataflowAnalysis.scala index 49914c91cc82..ec0842fbe751 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/DataflowAnalysis.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/DataflowAnalysis.scala @@ -574,12 +574,12 @@ case object DataflowAnalysis extends IRPass { */ def analyseCase(cse: IR.Case, info: DependencyInfo): IR.Case = { cse match { - case expr @ IR.Case.Expr(scrutinee, branches, _, _, _) => + case expr: IR.Case.Expr => val exprDep = asStatic(expr) - val scrutDep = asStatic(scrutinee) + val scrutDep = asStatic(expr.scrutinee) info.dependents.updateAt(scrutDep, Set(exprDep)) info.dependencies.updateAt(exprDep, Set(scrutDep)) - branches.foreach(branch => { + expr.branches.foreach(branch => { val branchDep = asStatic(branch) info.dependents.updateAt(branchDep, Set(exprDep)) info.dependencies.updateAt(exprDep, Set(branchDep)) @@ -587,8 +587,8 @@ case object DataflowAnalysis extends IRPass { expr .copy( - scrutinee = analyseExpression(scrutinee, info), - branches = branches.map(analyseCaseBranch(_, info)) + scrutinee = analyseExpression(expr.scrutinee, info), + branches = expr.branches.map(analyseCaseBranch(_, info)) ) .updateMetadata(this -->> info) case _: IR.Case.Branch => diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/DemandAnalysis.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/DemandAnalysis.scala index 83403bc0f32c..06e4a01441d4 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/DemandAnalysis.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/DemandAnalysis.scala @@ -309,7 +309,7 @@ case object DemandAnalysis extends IRPass { isInsideCallArgument: Boolean ): IR.Case = cse match { - case expr @ IR.Case.Expr(scrutinee, branches, _, _, _) => + case expr @ IR.Case.Expr(scrutinee, branches, _, _, _, _) => expr.copy( scrutinee = analyseExpression( scrutinee, diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala index 627ae9852d44..4f2b3627ce1d 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala @@ -323,7 +323,7 @@ case object TailCall extends IRPass { */ def analyseCase(caseExpr: IR.Case, isInTailPosition: Boolean): IR.Case = { caseExpr match { - case caseExpr @ IR.Case.Expr(scrutinee, branches, _, _, _) => + case caseExpr @ IR.Case.Expr(scrutinee, branches, _, _, _, _) => caseExpr .copy( scrutinee = analyseExpression(scrutinee, isInTailPosition = false), diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/NestedPatternMatch.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/NestedPatternMatch.scala index bb1b9853a9b1..45e79adf2e47 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/NestedPatternMatch.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/NestedPatternMatch.scala @@ -30,51 +30,33 @@ import scala.annotation.unused * # Desugar Nil in first branch * case x of * Cons (Cons a b) y -> case y of - * Nil -> a + b - * _ -> case x of - * Cons a Nil -> a - * _ -> 0 + * Nil -> a + b ## fallthrough on failed match ## * Cons a Nil -> a * _ -> 0 * * # Desuar `Cons a b` in the first branch * case x of * Cons w y -> case w of - * Cons a b -> case y of - * Nil -> a + b - * _ -> case x of - * Cons a Nil -> a - * _ -> 0 - * _ -> case x of - * Cons a Nil -> a - * _ -> 0 + * Cons a b -> case y of ## fallthrough on failed match ## + * Nil -> a + b ## fallthrough on failed match ## * Cons a Nil -> a * _ -> 0 * * # Desugar `Cons a Nil` in the second branch * case x of * Cons w y -> case w of - * Cons a b -> case y of - * Nil -> a + b - * _ -> case x of - * Cons a z -> case z of - * Nil -> a - * _ -> case x of - * _ -> 0 - * _ -> 0 - * _ -> case x of - * Cons a z -> case z of - * Nil -> a - * _ -> case x of - * _ -> 0 - * _ -> 0 + * Cons a b -> case y of ## fallthrough on failed match ## + * Nil -> a + b ## fallthrough on failed match ## * Cons a z -> case z of - * Nil -> a - * _ -> case x of - * _ -> 0 + * Nil -> a ## fallthrough on failed match ## * _ -> 0 * }}} * + * Note how the desugaring discards unmatched branches for nested cases. + * This is done on purpose to simplify the constructed IR. Rather than + * implementing the fallthrough logic using IR, it is done in CaseNode/BranchNode + * Truffle nodes directly. + * * This pass requires no configuration. * * This pass requires the context to provide: @@ -178,7 +160,7 @@ case object NestedPatternMatch extends IRPass { freshNameSupply: FreshNameSupply ): IR.Expression = { expr match { - case expr @ IR.Case.Expr(scrutinee, branches, _, _, _) => + case expr @ IR.Case.Expr(scrutinee, branches, _, _, _, _) => val scrutineeBindingName = freshNameSupply.newName() val scrutineeExpression = desugarExpression(scrutinee, freshNameSupply) val scrutineeBinding = @@ -186,14 +168,10 @@ case object NestedPatternMatch extends IRPass { val caseExprScrutinee = scrutineeBindingName.duplicate() - val processedBranches = branches.zipWithIndex.map { case (branch, ix) => - val remainingBranches = branches.drop(ix + 1).toList - + val processedBranches = branches.zipWithIndex.map { case (branch, _) => desugarCaseBranch( branch, - caseExprScrutinee, branch.location, - remainingBranches, freshNameSupply ) } @@ -214,20 +192,15 @@ case object NestedPatternMatch extends IRPass { /** Desugars a case branch. * * @param branch the branch to desugar - * @param originalScrutinee the original scrutinee of the pattern match * @param topBranchLocation the location of the source branch that is being * desugared - * @param remainingBranches all subsequent branches at the current pattern - * match level * @param freshNameSupply the compiler's supply of fresh names * @return `branch`, with any nested patterns desugared */ @scala.annotation.tailrec def desugarCaseBranch( branch: IR.Case.Branch, - originalScrutinee: IR.Expression, topBranchLocation: Option[IR.IdentifiedLocation], - remainingBranches: List[IR.Case.Branch], freshNameSupply: FreshNameSupply ): IR.Case.Branch = { if (containsNestedPatterns(branch.pattern)) { @@ -241,7 +214,6 @@ case object NestedPatternMatch extends IRPass { val newField = Pattern.Name(newName, None) val nestedScrutinee = newName.duplicate() - newName.duplicate() val newFields = fields.take(nestedPosition) ++ (newField :: fields.drop( @@ -255,22 +227,20 @@ case object NestedPatternMatch extends IRPass { val newExpression = generateNestedCase( lastNestedPattern, nestedScrutinee, - originalScrutinee, - branch.expression, - remainingBranches + branch.expression ) + val newPattern1 = newPattern.duplicate() val partDesugaredBranch = IR.Case.Branch( - pattern = newPattern.duplicate(), - expression = newExpression.duplicate(), + pattern = newPattern1, + expression = newExpression.duplicate(), + terminalBranch = false, None ) desugarCaseBranch( partDesugaredBranch, - originalScrutinee, topBranchLocation, - remainingBranches, freshNameSupply ) case _: Pattern.Literal => @@ -321,41 +291,30 @@ case object NestedPatternMatch extends IRPass { * @param pattern the pattern being replaced in the desugaring * @param nestedScrutinee the name of the variable replacing `pattern` in the * branch - * @param topLevelScrutineeExpr the scrutinee of the original case expression * @param currentBranchExpr the expression executed in the current branch on * a success - * @param remainingBranches the branches to check against on a failure * @return a nested case expression of the form above */ def generateNestedCase( pattern: Pattern, nestedScrutinee: IR.Expression, - topLevelScrutineeExpr: IR.Expression, - currentBranchExpr: IR.Expression, - remainingBranches: List[IR.Case.Branch] + currentBranchExpr: IR.Expression ): IR.Expression = { - val fallbackCase = IR.Case.Expr( - topLevelScrutineeExpr.duplicate(), - remainingBranches.duplicate(), - None - ) - + val patternDuplicate = pattern.duplicate() + val finalTest = containsNestedPatterns(patternDuplicate) val patternBranch = IR.Case.Branch( - pattern.duplicate(), + patternDuplicate, currentBranchExpr.duplicate(), - None + terminalBranch = !finalTest, + location = None ) - val fallbackBranch = IR.Case.Branch( - IR.Pattern.Name(IR.Name.Blank(None), None), - fallbackCase, - None - ) IR.Case.Expr( nestedScrutinee.duplicate(), - List(patternBranch, fallbackBranch), - None + List(patternBranch), + isNested = true, + location = None ) } diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/lint/ShadowedPatternFields.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/lint/ShadowedPatternFields.scala index bc01f11ef4f1..d496a4bd4e46 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/lint/ShadowedPatternFields.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/lint/ShadowedPatternFields.scala @@ -104,7 +104,7 @@ case object ShadowedPatternFields extends IRPass { */ def lintCase(cse: IR.Case): IR.Case = { cse match { - case expr @ IR.Case.Expr(scrutinee, branches, _, _, _) => + case expr @ IR.Case.Expr(scrutinee, branches, _, _, _, _) => expr.copy( scrutinee = lintExpression(scrutinee), branches = branches.map(lintCaseBranch) diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/lint/UnusedBindings.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/lint/UnusedBindings.scala index d11ef574da81..6f30ba55a1ad 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/lint/UnusedBindings.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/lint/UnusedBindings.scala @@ -217,7 +217,7 @@ case object UnusedBindings extends IRPass { */ def lintCase(cse: IR.Case, context: InlineContext): IR.Case = { cse match { - case expr @ Case.Expr(scrutinee, branches, _, _, _) => + case expr @ Case.Expr(scrutinee, branches, _, _, _, _) => expr.copy( scrutinee = runExpression(scrutinee, context), branches = branches.map(lintCaseBranch(_, context)) diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/optimise/UnreachableMatchBranches.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/optimise/UnreachableMatchBranches.scala index 210d5e64ab6e..c2c196aef836 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/optimise/UnreachableMatchBranches.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/optimise/UnreachableMatchBranches.scala @@ -120,7 +120,7 @@ case object UnreachableMatchBranches extends IRPass { //noinspection DuplicatedCode def optimizeCase(cse: IR.Case): IR.Case = { cse match { - case expr @ IR.Case.Expr(scrutinee, branches, _, _, _) => + case expr @ IR.Case.Expr(scrutinee, branches, _, _, _, _) => val reachableNonCatchAllBranches = branches.takeWhile(!isCatchAll(_)) val firstCatchAll = branches.find(isCatchAll) val unreachableBranches = diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/resolve/DocumentationComments.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/resolve/DocumentationComments.scala index b53630a441ce..7109fcf0941e 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/resolve/DocumentationComments.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/resolve/DocumentationComments.scala @@ -116,10 +116,10 @@ case object DocumentationComments extends IRPass { private def resolveBranches(items: Seq[Branch]): Seq[Branch] = { var lastDoc: Option[String] = None items.flatMap { - case Branch(IR.Pattern.Documentation(doc, _, _, _), _, _, _, _) => + case Branch(IR.Pattern.Documentation(doc, _, _, _), _, _, _, _, _) => lastDoc = Some(doc) None - case branch @ Branch(pattern, expression, _, _, _) => + case branch @ Branch(pattern, expression, _, _, _, _) => val resolved = branch.copy( pattern = pattern.mapExpressions(resolveExpression), diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/resolve/IgnoredBindings.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/resolve/IgnoredBindings.scala index d9435ad5574a..74c1cd47b800 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/resolve/IgnoredBindings.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/resolve/IgnoredBindings.scala @@ -286,7 +286,7 @@ case object IgnoredBindings extends IRPass { */ def resolveCase(cse: IR.Case, supply: FreshNameSupply): IR.Case = { cse match { - case expr @ Case.Expr(scrutinee, branches, _, _, _) => + case expr @ Case.Expr(scrutinee, branches, _, _, _, _) => expr.copy( scrutinee = resolveExpression(scrutinee, supply), branches = branches.map(resolveCaseBranch(_, supply)) diff --git a/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/NestedPatternMatchTest.scala b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/NestedPatternMatchTest.scala index 8fe47a37994d..7de04f2c5aa2 100644 --- a/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/NestedPatternMatchTest.scala +++ b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/NestedPatternMatchTest.scala @@ -152,10 +152,13 @@ class NestedPatternMatchTest extends CompilerTest { val catchAllBranch = ir.branches(4) "desugar nested constructors to simple patterns" in { + ir.isNested shouldBe false + consANilBranch.expression shouldBe an[IR.Expression.Block] consANilBranch.pattern shouldBe an[IR.Pattern.Constructor] NestedPatternMatch .containsNestedPatterns(consANilBranch.pattern) shouldEqual false + consANilBranch.terminalBranch shouldBe false val nestedCase = consANilBranch.expression .asInstanceOf[IR.Expression.Block] @@ -163,10 +166,10 @@ class NestedPatternMatchTest extends CompilerTest { .asInstanceOf[IR.Case.Expr] nestedCase.scrutinee shouldBe an[IR.Name.Literal] - nestedCase.branches.length shouldEqual 2 + nestedCase.branches.length shouldEqual 1 + nestedCase.isNested shouldBe true - val nilBranch = nestedCase.branches(0) - val fallbackBranch = nestedCase.branches(1) + val nilBranch = nestedCase.branches(0) nilBranch.pattern shouldBe a[Pattern.Constructor] nilBranch.pattern @@ -175,19 +178,7 @@ class NestedPatternMatchTest extends CompilerTest { .name shouldEqual "Nil" nilBranch.expression shouldBe an[IR.Name.Literal] nilBranch.expression.asInstanceOf[IR.Name].name shouldEqual "a" - - fallbackBranch.pattern shouldBe a[Pattern.Name] - fallbackBranch.pattern - .asInstanceOf[Pattern.Name] - .name shouldBe an[IR.Name.Blank] - - fallbackBranch.expression shouldBe an[IR.Expression.Block] - fallbackBranch.expression - .asInstanceOf[IR.Expression.Block] - .returnValue - .asInstanceOf[IR.Case.Expr] - .branches - .length shouldEqual 1 + nilBranch.terminalBranch shouldBe true } "desugar deeply nested patterns to simple patterns" in { @@ -195,6 +186,7 @@ class NestedPatternMatchTest extends CompilerTest { consConsNilBranch.pattern shouldBe an[IR.Pattern.Constructor] NestedPatternMatch .containsNestedPatterns(consConsNilBranch.pattern) shouldEqual false + consConsNilBranch.terminalBranch shouldBe false val nestedCase = consConsNilBranch.expression .asInstanceOf[IR.Expression.Block] @@ -202,25 +194,19 @@ class NestedPatternMatchTest extends CompilerTest { .asInstanceOf[IR.Case.Expr] nestedCase.scrutinee shouldBe an[IR.Name.Literal] - nestedCase.branches.length shouldEqual 2 + nestedCase.branches.length shouldEqual 1 + nestedCase.isNested shouldBe true - val consBranch = nestedCase.branches(0) - val fallbackBranch1 = nestedCase.branches(1) + val consBranch = nestedCase.branches(0) consBranch.expression shouldBe an[IR.Expression.Block] - fallbackBranch1.expression shouldBe an[IR.Expression.Block] val consBranchBody = consBranch.expression .asInstanceOf[IR.Expression.Block] .returnValue .asInstanceOf[IR.Case.Expr] - val fallbackBranch1Body = - fallbackBranch1.expression - .asInstanceOf[IR.Expression.Block] - .returnValue - .asInstanceOf[IR.Case.Expr] - consBranchBody.branches.length shouldEqual 2 + consBranchBody.branches.length shouldEqual 1 consBranchBody.branches.head.expression shouldBe an[IR.Expression.Block] consBranchBody.branches.head.pattern .asInstanceOf[Pattern.Constructor] @@ -229,16 +215,6 @@ class NestedPatternMatchTest extends CompilerTest { NestedPatternMatch.containsNestedPatterns( consBranchBody.branches.head.pattern ) shouldEqual false - - fallbackBranch1Body.branches.length shouldEqual 4 - fallbackBranch1Body.branches.head.pattern shouldBe a[Pattern.Constructor] - fallbackBranch1Body.branches.head.pattern - .asInstanceOf[Pattern.Constructor] - .constructor - .name shouldEqual "Cons" - NestedPatternMatch.containsNestedPatterns( - fallbackBranch1Body.branches.head.pattern - ) shouldEqual false } "desugar deeply nested patterns with literals to simple patterns" in { @@ -246,6 +222,7 @@ class NestedPatternMatchTest extends CompilerTest { consConsOneNilBranch.pattern shouldBe an[IR.Pattern.Constructor] NestedPatternMatch .containsNestedPatterns(consConsOneNilBranch.pattern) shouldEqual false + consConsOneNilBranch.terminalBranch shouldBe false val nestedCase = consConsOneNilBranch.expression .asInstanceOf[IR.Expression.Block] @@ -253,25 +230,20 @@ class NestedPatternMatchTest extends CompilerTest { .asInstanceOf[IR.Case.Expr] nestedCase.scrutinee shouldBe an[IR.Name.Literal] - nestedCase.branches.length shouldEqual 2 + nestedCase.branches.length shouldEqual 1 + nestedCase.isNested shouldBe true - val consBranch = nestedCase.branches(0) - val fallbackBranch1 = nestedCase.branches(1) + val consBranch = nestedCase.branches(0) consBranch.expression shouldBe an[IR.Expression.Block] - fallbackBranch1.expression shouldBe an[IR.Expression.Block] + consBranch.terminalBranch shouldBe false val consBranchBody = consBranch.expression .asInstanceOf[IR.Expression.Block] .returnValue .asInstanceOf[IR.Case.Expr] - val fallbackBranch1Body = - fallbackBranch1.expression - .asInstanceOf[IR.Expression.Block] - .returnValue - .asInstanceOf[IR.Case.Expr] - consBranchBody.branches.length shouldEqual 2 + consBranchBody.branches.length shouldEqual 1 consBranchBody.branches.head.expression shouldBe an[IR.Expression.Block] consBranchBody.branches.head.pattern .asInstanceOf[Pattern.Literal] @@ -281,16 +253,8 @@ class NestedPatternMatchTest extends CompilerTest { NestedPatternMatch.containsNestedPatterns( consBranchBody.branches.head.pattern ) shouldEqual false - - fallbackBranch1Body.branches.length shouldEqual 3 - fallbackBranch1Body.branches.head.pattern shouldBe a[Pattern.Constructor] - fallbackBranch1Body.branches.head.pattern - .asInstanceOf[Pattern.Constructor] - .constructor - .name shouldEqual "Cons" - NestedPatternMatch.containsNestedPatterns( - fallbackBranch1Body.branches.head.pattern - ) shouldEqual false + consBranchBody.isNested shouldBe true + consBranchBody.branches.head.terminalBranch shouldBe true } "desugar deeply nested patterns with type pattern to simple patterns" in { @@ -300,6 +264,7 @@ class NestedPatternMatchTest extends CompilerTest { .containsNestedPatterns( consConsIntegerNilBranch.pattern ) shouldEqual false + consConsIntegerNilBranch.terminalBranch shouldBe false val nestedCase = consConsIntegerNilBranch.expression .asInstanceOf[IR.Expression.Block] @@ -307,28 +272,24 @@ class NestedPatternMatchTest extends CompilerTest { .asInstanceOf[IR.Case.Expr] nestedCase.scrutinee shouldBe an[IR.Name.Literal] - nestedCase.branches.length shouldEqual 2 + nestedCase.branches.length shouldEqual 1 + nestedCase.isNested shouldBe true - val consBranch = nestedCase.branches(0) - val fallbackBranch1 = nestedCase.branches(1) + val consBranch = nestedCase.branches(0) consBranch.expression shouldBe an[IR.Expression.Block] - fallbackBranch1.expression shouldBe an[IR.Expression.Block] + consBranch.terminalBranch shouldBe false val consBranchBody = consBranch.expression .asInstanceOf[IR.Expression.Block] .returnValue .asInstanceOf[IR.Case.Expr] - val fallbackBranch1Body = - fallbackBranch1.expression - .asInstanceOf[IR.Expression.Block] - .returnValue - .asInstanceOf[IR.Case.Expr] - consBranchBody.branches.length shouldEqual 2 + consBranchBody.branches.length shouldEqual 1 consBranchBody.branches.head.expression shouldBe an[IR.Expression.Block] val tpePattern = consBranchBody.branches.head.pattern .asInstanceOf[Pattern.Type] + consBranchBody.branches.head.terminalBranch shouldBe true tpePattern.name .asInstanceOf[IR.Name.Literal] @@ -339,46 +300,32 @@ class NestedPatternMatchTest extends CompilerTest { consBranchBody.branches.head.pattern ) shouldEqual false - consBranchBody.branches(1).pattern shouldBe an[Pattern.Name] - consBranchBody - .branches(1) - .pattern - .asInstanceOf[Pattern.Name] - .name shouldBe an[IR.Name.Blank] - val consTpeBranchBody = consBranchBody.branches.head.expression .asInstanceOf[IR.Expression.Block] .returnValue .asInstanceOf[IR.Case.Expr] - consTpeBranchBody.branches.length shouldEqual 2 + consTpeBranchBody.branches.length shouldEqual 1 consTpeBranchBody.branches.head.pattern shouldBe an[Pattern.Constructor] - consTpeBranchBody.branches(1).pattern shouldBe an[Pattern.Name] - - fallbackBranch1Body.branches.length shouldEqual 2 - fallbackBranch1Body.branches.head.pattern shouldBe a[Pattern.Constructor] - fallbackBranch1Body.branches.head.pattern - .asInstanceOf[Pattern.Constructor] - .constructor - .name shouldEqual "Cons" - NestedPatternMatch.containsNestedPatterns( - fallbackBranch1Body.branches.head.pattern - ) shouldEqual false } "work recursively" in { catchAllBranch.expression shouldBe an[IR.Expression.Block] + catchAllBranch.terminalBranch shouldBe true + val consANilCase = catchAllBranch.expression + .asInstanceOf[IR.Expression.Block] + .returnValue + .asInstanceOf[IR.Case.Expr] + + consANilCase.isNested shouldBe false + val consANilBranch2 = - catchAllBranch.expression - .asInstanceOf[IR.Expression.Block] - .returnValue - .asInstanceOf[IR.Case.Expr] - .branches - .head + consANilCase.branches.head NestedPatternMatch.containsNestedPatterns( consANilBranch2.pattern ) shouldEqual false + consANilBranch2.terminalBranch shouldBe false consANilBranch2.expression shouldBe an[IR.Expression.Block] val consANilBranch2Expr = consANilBranch2.expression @@ -386,11 +333,13 @@ class NestedPatternMatchTest extends CompilerTest { .returnValue .asInstanceOf[IR.Case.Expr] - consANilBranch2Expr.branches.length shouldEqual 2 + consANilBranch2Expr.isNested shouldBe true + consANilBranch2Expr.branches.length shouldEqual 1 consANilBranch2Expr.branches.head.pattern .asInstanceOf[Pattern.Constructor] .constructor .name shouldEqual "Nil" + consANilBranch2Expr.branches.head.terminalBranch shouldBe true } } }