Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove SearchList from Matchless #1300

Merged
merged 4 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 169 additions & 53 deletions core/src/main/scala/org/bykn/bosatsu/Matchless.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,6 @@ object Matchless {
size: Int,
famArities: List[Int]
) extends BoolExpr
// handle list matching, this is a while loop, that is evaluting
// lst is initialized to init, leftAcc is initialized to empty
// tail until it is true while mutating lst => lst.tail
// this has the side-effect of mutating lst and leftAcc as well as any side effects that check has
// which could have nested searches of its own
case class SearchList(
lst: LocalAnonMut,
init: CheapExpr,
check: BoolExpr,
leftAcc: Option[LocalAnonMut]
) extends BoolExpr
// set the mutable variable to the given expr and return true
// string matching is complex done at a lower level
case class MatchString(
Expand All @@ -128,6 +117,17 @@ object Matchless {
// set the mutable variable to the given expr and return true
case class SetMut(target: LocalAnonMut, expr: Expr) extends BoolExpr
case object TrueConst extends BoolExpr
case class LetBool(
arg: Either[LocalAnon, Bindable],
expr: Expr,
in: BoolExpr
) extends BoolExpr

case class LetMutBool(name: LocalAnonMut, span: BoolExpr) extends BoolExpr
object LetMutBool {
def apply(lst: List[LocalAnonMut], span: BoolExpr): BoolExpr =
lst.foldRight(span)(LetMutBool(_, _))
}

def hasSideEffect(bx: BoolExpr): Boolean =
bx match {
Expand All @@ -137,8 +137,31 @@ object Matchless {
false
case MatchString(_, _, b, _) => b.nonEmpty
case And(b1, b2) => hasSideEffect(b1) || hasSideEffect(b2)
case SearchList(_, _, b, l) =>
l.nonEmpty || hasSideEffect(b)
case LetBool(_, x, b) =>
hasSideEffect(b) || hasSideEffect(x)
case LetMutBool(_, b) => hasSideEffect(b)
}

def hasSideEffect(bx: Expr): Boolean =
bx match {
case _: CheapExpr => false
case Always(b, x) => hasSideEffect(b) || hasSideEffect(x)
case App(f, as) =>
(f :: as).exists(hasSideEffect(_))
case If(c, t, f) =>
hasSideEffect(c) || hasSideEffect(t) || hasSideEffect(f)
case Let(_, x, b) =>
hasSideEffect(b) || hasSideEffect(x)
case LetMut(_, in) => hasSideEffect(in)
case PrevNat(n) => hasSideEffect(n)
case MakeEnum(_, _, _) | MakeStruct(_) | SuccNat | ZeroNat | Lambda(_, _, _, _) =>
// making a lambda or const is a pure function
false
case WhileExpr(_, _, _) =>
// not all while loops have side effects technically, but we assume yes
// for now. We could have a list of all the known control mutables here
// but it seems hard
true
}

case class If(cond: BoolExpr, thenExpr: Expr, elseExpr: Expr) extends Expr {
Expand Down Expand Up @@ -191,13 +214,30 @@ object Matchless {
extends ConsExpr

private val boolFamArities = 0 :: 0 :: Nil
private val listFamArities = 0 :: 2 :: Nil
val FalseExpr: Expr = MakeEnum(0, 0, boolFamArities)
val TrueExpr: Expr = MakeEnum(1, 0, boolFamArities)
val UnitExpr: Expr = MakeStruct(0)

def isTrueExpr(e: CheapExpr): BoolExpr =
CheckVariant(e, 1, 0, boolFamArities)


object ListExpr {
val Nil: Expr = MakeEnum(0, 0, listFamArities)
private val consFn = MakeEnum(1, 2, listFamArities)

def cons(h: Expr, t: Expr): Expr =
App(consFn, NonEmptyList(h, t :: List.empty))

def notNil(e: CheapExpr): BoolExpr =
CheckVariant(e, 1, 2, listFamArities)

def head(arg: CheapExpr): CheapExpr =
GetEnumElement(arg, 1, 0, 2)

def tail(arg: CheapExpr): CheapExpr =
GetEnumElement(arg, 1, 1, 2)
}
case class MakeStruct(arity: Int) extends ConsExpr
case object ZeroNat extends ConsExpr {
def arity = 0
Expand Down Expand Up @@ -311,67 +351,69 @@ object Matchless {
def inLet(b: Bindable): LambdaState = copy(name = Some(b))
}

def translateLocalsBool(m: Map[Bindable, LocalAnonMut], e: BoolExpr): BoolExpr =
def substituteLocalsBool(m: Map[Bindable, CheapExpr], e: BoolExpr): BoolExpr =
e match {
case SetMut(mut, e) => SetMut(mut, translateLocals(m, e))
case SetMut(mut, e) => SetMut(mut, substituteLocals(m, e))
case And(b1, b2) =>
And(translateLocalsBool(m, b1), translateLocalsBool(m, b2))
And(substituteLocalsBool(m, b1), substituteLocalsBool(m, b2))
case EqualsLit(x, l) =>
EqualsLit(translateLocalsCheap(m, x), l)
EqualsLit(substituteLocalsCheap(m, x), l)
case EqualsNat(x, n) =>
EqualsNat(translateLocalsCheap(m, x), n)
EqualsNat(substituteLocalsCheap(m, x), n)
case TrueConst => TrueConst
case CheckVariant(expr, expect, sz, fam) =>
CheckVariant(translateLocalsCheap(m, expr), expect, sz, fam)
CheckVariant(substituteLocalsCheap(m, expr), expect, sz, fam)
case ms: MatchString =>
ms.copy(arg = translateLocalsCheap(m, ms.arg))
case sl: SearchList =>
sl.copy(
init = translateLocalsCheap(m, sl.init),
check = translateLocalsBool(m, sl.check)
)
ms.copy(arg = substituteLocalsCheap(m, ms.arg))
case LetBool(b, a, in) =>
val m1 = b match {
case Right(b) => m - b
case _ => m
}
LetBool(b, substituteLocals(m, a), substituteLocalsBool(m1, in))
case LetMutBool(b, in) =>
LetMutBool(b, substituteLocalsBool(m, in))
}

def translateLocals(m: Map[Bindable, LocalAnonMut], e: Expr): Expr =
def substituteLocals(m: Map[Bindable, CheapExpr], e: Expr): Expr =
e match {
case App(fn, appArgs) =>
App(translateLocals(m, fn), appArgs.map(translateLocals(m, _)))
App(substituteLocals(m, fn), appArgs.map(substituteLocals(m, _)))
case If(c, tcase, fcase) =>
If(translateLocalsBool(m, c), translateLocals(m, tcase), translateLocals(m, fcase))
If(substituteLocalsBool(m, c), substituteLocals(m, tcase), substituteLocals(m, fcase))
case Always(c, e) =>
Always(translateLocalsBool(m, c), translateLocals(m, e))
Always(substituteLocalsBool(m, c), substituteLocals(m, e))
case LetMut(mut, e) =>
LetMut(mut, translateLocals(m, e))
LetMut(mut, substituteLocals(m, e))
case Let(n, v, in) =>
val m1 = n match {
case Right(b) => m - b
case _ => m
}
Let(n, translateLocals(m, v), translateLocals(m1, in))
// the rest cannot have a call in tail position
Let(n, substituteLocals(m, v), substituteLocals(m1, in))
case Local(n) =>
m.get(n) match {
case Some(mut) => mut
case None => e
}
case PrevNat(n) => PrevNat(translateLocals(m, n))
case PrevNat(n) => PrevNat(substituteLocals(m, n))
case ge: GetEnumElement =>
ge.copy(arg = translateLocalsCheap(m, ge.arg))
ge.copy(arg = substituteLocalsCheap(m, ge.arg))
case gs: GetStructElement =>
gs.copy(arg = translateLocalsCheap(m, gs.arg))
gs.copy(arg = substituteLocalsCheap(m, gs.arg))
case Lambda(c, r, as, b) =>
val m1 = m -- as.toList
val b1 = translateLocals(m1, b)
val b1 = substituteLocals(m1, b)
Lambda(c, r, as, b1)
case WhileExpr(c, ef, r) =>
WhileExpr(translateLocalsBool(m, c), translateLocals(m, ef), r)
WhileExpr(substituteLocalsBool(m, c), substituteLocals(m, ef), r)
case ClosureSlot(_) | Global(_, _) | LocalAnon(_) | LocalAnonMut(_) |
MakeEnum(_, _, _) | MakeStruct(_) | SuccNat | Literal(_) | ZeroNat => e
}
def translateLocalsCheap(m: Map[Bindable, LocalAnonMut], e: CheapExpr): CheapExpr =
translateLocals(m, e) match {
def substituteLocalsCheap(m: Map[Bindable, CheapExpr], e: CheapExpr): CheapExpr =
substituteLocals(m, e) match {
case ch: CheapExpr => ch
case notCheap => sys.error(s"invariant violation: translation didn't maintain cheap: $e => $notCheap")
case notCheap => sys.error(s"invariant violation: substitution didn't maintain cheap: $e => $notCheap")
}

def loopFn(
Expand All @@ -380,10 +422,6 @@ object Matchless {
args: NonEmptyList[Bindable],
body: Expr): F[Expr] = {

def setAll(ls: List[(LocalAnonMut, Expr)], ret: Expr): Expr =
ls.foldRight(ret) { case ((l, e), r) =>
Always(SetMut(l, e), r)
}
// assign any results to result and set the condition to false
// and replace any tail calls to nm(args) with assigning args to those values
case class ArgRecord(name: Bindable, tmp: LocalAnon, loopVar: LocalAnonMut)
Expand Down Expand Up @@ -445,7 +483,7 @@ object Matchless {
MakeEnum(_, _, _) | MakeStruct(_) | PrevNat(_) | SuccNat | WhileExpr(_, _, _) | ZeroNat => None
}

val bodyTrans = translateLocals(
val bodyTrans = substituteLocals(
args.toList.map(a => (a.name, a.loopVar)).toMap,
body)

Expand Down Expand Up @@ -631,6 +669,82 @@ object Matchless {
}
}

// handle list matching, this is a while loop, that is evaluting
// lst is initialized to init, leftAcc is initialized to empty
// tail until it is true while mutating lst => lst.tail
// this has the side-effect of mutating lst and leftAcc as well as any side effects that check has
// which could have nested searches of its own
def searchList(
lst: LocalAnonMut,
init: CheapExpr,
check: BoolExpr,
leftAcc: Option[LocalAnonMut]
): F[BoolExpr] = {
(
makeAnon.map(LocalAnonMut(_)),
makeAnon.map(LocalAnon(_)),
makeAnon.map(LocalAnonMut(_))
)
.mapN { (resMut, letBind, currentList) =>
val initSets =
(resMut, FalseExpr) ::
(currentList, init) ::
(leftAcc.toList.map { left =>
(left, ListExpr.Nil)
})

val whileCheck = ListExpr.notNil(currentList)
val effect: Expr = {
setAll((lst, currentList) :: Nil,
If(check, {
setAll(
(currentList, ListExpr.Nil) ::
(resMut, TrueExpr) ::
Nil,
UnitExpr
)
}, {
setAll(
(currentList, ListExpr.tail(currentList)) ::
leftAcc.toList.map { left =>
(left, ListExpr.cons(ListExpr.head(currentList), left))
},
UnitExpr
)
}))
}
val searchLoop = setAll(initSets, WhileExpr(whileCheck, effect, resMut))

LetMutBool(resMut :: currentList :: Nil,
LetBool(Left(letBind), searchLoop, isTrueExpr(resMut)))
}
/*
Dynamic { (scope: Scope) =>
var res = false
var currentList = initF(scope)
var leftList = VList.VNil
scope.updateMut(left, leftList)
while (currentList ne null) {
currentList match {
case nonempty @ VList.Cons(head, tail) =>
scope.updateMut(mutV, nonempty)
res = checkF(scope)
if (res) { currentList = null }
else {
currentList = tail
leftList = VList.Cons(head, leftList)
scope.updateMut(left, leftList)
}
case _ =>
currentList = null
// we don't match empty lists
}
}
res
}
*/
}

// return the check expression for the check we need to do, and the list of bindings
// if must match is true, we know that the pattern must match, so we can potentially remove some checks
def doesMatch(
Expand Down Expand Up @@ -708,8 +822,8 @@ object Matchless {
val anonList = LocalAnonMut(tmpList)

doesMatch(anonList, Pattern.ListPat(right.toList), false)
.map { cases =>
cases.map {
.flatMap { cases =>
cases.traverse {
case (_, TrueConst, _) =>
// $COVERAGE-OFF$

Expand Down Expand Up @@ -737,11 +851,8 @@ object Matchless {
(letTail, None, binds)
}

(
resLet,
SearchList(anonList, arg, expr, leftOpt),
resBind
)
searchList(anonList, arg, expr, leftOpt)
.map { s => (resLet, s, resBind) }
}
}
}
Expand Down Expand Up @@ -969,6 +1080,11 @@ object Matchless {
LetMut(anon, rest)
}

def setAll(ls: List[(LocalAnonMut, Expr)], ret: Expr): Expr =
ls.foldRight(ret) { case ((l, e), r) =>
Always(SetMut(l, e), r)
}

def matchExpr(
arg: Expr,
tmp: F[Long],
Expand Down
Loading
Loading