Skip to content

Commit

Permalink
Fix bug in python if/else chains (#1183)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek authored Mar 23, 2024
1 parent 5191197 commit 0bb465d
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 17 deletions.
15 changes: 14 additions & 1 deletion core/src/main/scala/org/bykn/bosatsu/Matchless.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,20 @@ object Matchless {
l.nonEmpty || hasSideEffect(b)
}

case class If(cond: BoolExpr, thenExpr: Expr, elseExpr: Expr) extends Expr
case class If(cond: BoolExpr, thenExpr: Expr, elseExpr: Expr) extends Expr {
def flatten: (NonEmptyList[(BoolExpr, Expr)], Expr) = {
def combine(expr: Expr): (List[(BoolExpr, Expr)], Expr) =
expr match {
case If(c1, t1, e1) =>
val (ifs, e2) = combine(e1)
(((c1, t1)) :: ifs, e2)
case last => (Nil, last)
}

val (rest, last) = combine(elseExpr)
(NonEmptyList((cond, thenExpr), rest), last)
}
}
case class Always(cond: BoolExpr, thenExpr: Expr) extends Expr
def always(cond: BoolExpr, thenExpr: Expr): Expr =
if (hasSideEffect(cond)) Always(cond, thenExpr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1878,17 +1878,8 @@ object PythonGen {
// there is no need to
loop(in, slotName)
case Literal(lit) => Env.pure(Code.litToExpr(lit))
case If(cond, thenExpr, elseExpr) =>
def combine(expr: Expr): (List[(BoolExpr, Expr)], Expr) =
expr match {
case If(c1, t1, e1) =>
val (ifs, e2) = combine(e1)
(ifs :+ ((c1, t1)), e2)
case last => (Nil, last)
}

val (rest, last) = combine(elseExpr)
val ifs = NonEmptyList((cond, thenExpr), rest)
case ifExpr @ If(_, _, _) =>
val (ifs, last) = ifExpr.flatten

val ifsV = ifs.traverse { case (c, t) =>
(boolExpr(c, slotName), loop(t, slotName)).tupled
Expand Down
24 changes: 23 additions & 1 deletion core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import org.scalatest.funsuite.AnyFunSuite
class MatchlessTest extends AnyFunSuite {
implicit val generatorDrivenConfig: PropertyCheckConfiguration =
PropertyCheckConfiguration(minSuccessful =
if (Platform.isScalaJvm) 1000 else 20
if (Platform.isScalaJvm) 5000 else 20
)

type Fn = (PackageName, Constructor) => Option[DataRepr]
Expand Down Expand Up @@ -59,6 +59,11 @@ class MatchlessTest extends AnyFunSuite {
}
}

val genMatchlessExpr: Gen[Matchless.Expr] =
genInputs.map { case (b, r, t, fn) =>
Matchless.fromLet(b, r, t)(fn)
}

test("regressions") {
// this is illegal code, but it shouldn't throw a match error:
val name = Identifier.Name("foo")
Expand Down Expand Up @@ -151,4 +156,21 @@ class MatchlessTest extends AnyFunSuite {
assert(matchlessRes == matchRes)
}
}

test("If.flatten can be unflattened") {
forAll(genMatchlessExpr) {
case ifexpr @ Matchless.If(_, _, _) =>
val (chain, rest) = ifexpr.flatten
def unflatten(ifs: NonEmptyList[(Matchless.BoolExpr, Matchless.Expr)], elseX: Matchless.Expr): Matchless.If =
ifs.tail match {
case Nil => Matchless.If(ifs.head._1, ifs.head._2, elseX)
case head :: next =>
val end = unflatten(NonEmptyList(head, next), elseX)
Matchless.If(ifs.head._1, ifs.head._2, end)
}

assert(unflatten(chain, rest) == ifexpr)
case _ => ()
}
}
}
8 changes: 4 additions & 4 deletions test_workspace/BinNat.bosatsu
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ def toBinNat(n: Int) -> BinNat:

def cmp_BinNat(a: BinNat, b: BinNat) -> Comparison:
recur a:
case Zero:
match b:
case Odd(_) | Even(_): LT
case Zero: EQ
case Odd(a1):
match b:
case Odd(b1): cmp_BinNat(a1, b1)
Expand All @@ -71,6 +67,10 @@ def cmp_BinNat(a: BinNat, b: BinNat) -> Comparison:
case GT | EQ: GT
case LT: LT
case Zero: GT
case Zero:
match b:
case Odd(_) | Even(_): LT
case Zero: EQ

# this is more efficient potentially than cmp_BinNat
# because at the first difference we can stop. In the worst
Expand Down

0 comments on commit 0bb465d

Please sign in to comment.