Skip to content

Commit

Permalink
Improve error messages in DefRecursionCheck, Eval example
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Sep 17, 2023
1 parent 7890b7b commit 2f3f5e2
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 15 deletions.
72 changes: 57 additions & 15 deletions core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.typelevel.paiges.Doc
import cats.implicits._

import Identifier.Bindable
import cats.Applicative

/**
* Recursion in bosatsu is only allowed on a substructural match
Expand Down Expand Up @@ -264,6 +265,47 @@ object DefRecursionCheck {
*/
type St[A] = StateT[Either[NonEmptyList[RecursionError], *], State, A]

implicit val parallelSt: cats.Parallel[St] = {
val m = cats.Monad[St]

new cats.Parallel[St] {
type F[A] = St[A]
def parallel = new cats.arrow.FunctionK[St, F] {
def apply[A](st: St[A]) = st
}
def sequential = new cats.arrow.FunctionK[F, St] {
def apply[A](st: St[A]) = st
}
def monad = m
def applicative: Applicative[F] =
new Applicative[F] {
def pure[A](a: A) = m.pure(a)
def ap[A, B](ff: F[A => B])(fa: F[A]): F[B] =
map(product(ff, fa)) { case (fn, a) => fn(a) }

override def product[A, B](fa: F[A], fb: F[B]): F[(A, B)] = {
type E[+X] = Either[NonEmptyList[RecursionError], X]
val fna: E[State => E[(State, A)]] = fa.runF
val fnb: E[State => E[(State, B)]] = fb.runF

new cats.data.IndexedStateT((fna, fnb).parMapN { (fn1, fn2) =>
{ (state: State) =>
fn1(state) match {
case Right((s2, a)) => fn2(s2).map { case (st, b) => (st, (a, b)) }
case Left(nel1) =>
// just skip and merge
fn2(state) match {
case Right(_) => Left(nel1)
case Left(nel2) => Left(nel1 ::: nel2)
}
}
}
})
}
override def map[A, B](fa: F[A])(fn: A => B) = m.map(fa)(fn)
}
}
}
// Scala has trouble infering types like St, we we make these typed
// helper functions to use below
def failSt[A](err: RecursionError): St[A] =
Expand Down Expand Up @@ -331,7 +373,7 @@ object DefRecursionCheck {
getSt.flatMap {
case TopLevel =>
// without any recursion, normal typechecking will detect bad states:
checkDecl(fn) *> args.traverse_(checkDecl)
checkDecl(fn) *> args.parTraverse_(checkDecl)
case irb@InRecurBranch(inrec, branch, names) =>

argsOnDefName(fn, NonEmptyList.one(args)) match {
Expand All @@ -353,7 +395,7 @@ object DefRecursionCheck {
}
else if (names.contains(nm)) {
// we are calling a reachable function. Any lambda args are new names:
args.traverse_[St, Unit] {
args.parTraverse_[St, Unit] {
case Declaration.Lambda(args, body) =>
val names1 = args.toList.flatMap(_.names)
unionNames(names1)(checkDecl(body))
Expand All @@ -367,18 +409,18 @@ object DefRecursionCheck {
else {
// traverse converting Var(name) to the lambda version to use the above check
// not a recursive call
args.traverse_(checkDecl)
args.parTraverse_(checkDecl)
}
case None =>
// this isn't a recursive call
checkDecl(fn) *> args.traverse_(checkDecl)
checkDecl(fn) *> args.parTraverse_(checkDecl)
}
case ir: InDefState =>
// we have either not yet, or already done the recursion
argsOnDefName(fn, NonEmptyList.one(args)) match {
case Some((nm, _)) if ir.defNamesContain(nm) => failSt(InvalidRecursion(nm, region))
case _ =>
checkDecl(fn) *> args.traverse_(checkDecl)
checkDecl(fn) *> args.parTraverse_(checkDecl)
}
}
/*
Expand Down Expand Up @@ -409,7 +451,7 @@ object DefRecursionCheck {
defn *> nextRes
}
case IfElse(ifCases, elseCase) =>
val ifs = ifCases.traverse_ { case (d, od) =>
val ifs = ifCases.parTraverse_ { case (d, od) =>
checkDecl(d) *> checkDecl(od.get)
}
val e = checkDecl(elseCase.get)
Expand All @@ -428,7 +470,7 @@ object DefRecursionCheck {
case Match(RecursionKind.NonRecursive, arg, cases) =>
// the arg can't use state, but cases introduce new bindings:
val argRes = checkDecl(arg)
val optRes = cases.get.traverse_ { case (pat, next) =>
val optRes = cases.get.parTraverse_ { case (pat, next) =>
checkForIllegalBindsSt(pat.names, decl) *>
filterNames(pat.names)(checkDecl(next.get))
}
Expand Down Expand Up @@ -464,7 +506,7 @@ object DefRecursionCheck {
// $COVERAGE-ON$
}

cases.get.traverse_ { case (pat, next) =>
cases.get.parTraverse_ { case (pat, next) =>
for {
_ <- checkForIllegalBindsSt(pat.names, decl)
_ <- beginBranch(pat)
Expand All @@ -480,7 +522,7 @@ object DefRecursionCheck {
case Parens(p) =>
checkDecl(p)
case TupleCons(tups) =>
tups.traverse_(checkDecl)
tups.parTraverse_(checkDecl)
case Var(Identifier.Constructor(_)) =>
unitSt
case Var(v: Bindable) =>
Expand All @@ -494,28 +536,28 @@ object DefRecursionCheck {
else unitSt
}
case StringDecl(parts) =>
parts.traverse_ {
parts.parTraverse_ {
case Left(nb) => checkDecl(nb)
case Right(_) => unitSt
}
case ListDecl(ll) =>
ll match {
case ListLang.Cons(items) =>
items.traverse_ { s => checkDecl(s.value) }
items.parTraverse_ { s => checkDecl(s.value) }
case ListLang.Comprehension(e, _, i, f) =>
checkDecl(e.value) *>
checkDecl(i) *>
(f.traverse_(checkDecl))
(f.parTraverse_(checkDecl))
}
case DictDecl(ll) =>
ll match {
case ListLang.Cons(items) =>
items.traverse_ { s => checkDecl(s.key) *> checkDecl(s.value) }
items.parTraverse_ { s => checkDecl(s.key) *> checkDecl(s.value) }
case ListLang.Comprehension(e, _, i, f) =>
checkDecl(e.key) *>
checkDecl(e.value) *>
checkDecl(i) *>
(f.traverse_(checkDecl))
(f.parTraverse_(checkDecl))
}

case RecordConstructor(_, args) =>
Expand All @@ -526,7 +568,7 @@ object DefRecursionCheck {
case RecordArg.Pair(_, v) =>
checkDecl(v)
}
args.traverse_(checkArg)
args.parTraverse_(checkArg)
}
}

Expand Down
51 changes: 51 additions & 0 deletions test_workspace/Eval.bosatsu
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package Eval

from Bosatsu/Nat import Nat, Zero, Succ, to_Nat

export Eval, done, map, flat_map, bind, eval
# port of cats.Eval to bosatsu

enum Eval[a]:
Pure(a: a)
FlatMap(use: forall x. (forall y. (Eval[y], y -> Eval[a]) -> x) -> x)

def done(a: a) -> Eval[a]: Pure(a)

def flat_map[a, b](e: Eval[a], fn: a -> Eval[b]) -> Eval[b]:
FlatMap(cb -> cb(e, fn))

def bind(e)(fn): flat_map(e, fn)

def map[a, b](e: Eval[a], fn: a -> b) -> Eval[b]:
x <- e.bind()
Pure(fn(x))

enum Stack[a, b]:
Done(fn: a -> b)
More(use: forall x. (forall y. (a -> Eval[y], Stack[y, b]) -> x) -> x)

def push_stack[a, b, c](fn: a -> Eval[b], stack: Stack[b, c]) -> Stack[a, c]:
More(use -> use(fn, stack))

enum Loop[a, b]:
RunStack(a: a, stack: Stack[a, b])
RunEval(e: Eval[a], stack: Stack[a, b])

def run[a, b](budget: Nat, arg: Loop[a, b]) -> Option[b]:
recur budget:
case Zero: None
case Succ(balance):
match arg:
case RunStack(a, Done(fn)): Some(fn(a))
case RunEval(Pure(a), stack):
run(balance, RunStack(a, stack))
case RunEval(FlatMap(use), stack):
use((prev, fn) -> run(balance, RunEval(prev, push_stack(fn, stack))))
case RunStack(a, More(use)):
use((fn, stack) -> (
evalb = fn(a)
run(balance, RunEval(evalb, stack))
))

def eval[a](budget: Int, ea: Eval[a]) -> Option[a]:
run(to_Nat(budget), RunEval(ea, Done(a -> a)))

0 comments on commit 2f3f5e2

Please sign in to comment.