diff --git a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala index e27f0a639..5efaeacb8 100644 --- a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala @@ -6,6 +6,7 @@ import org.typelevel.paiges.Doc import cats.implicits._ import Identifier.Bindable +import cats.Applicative /** * Recursion in bosatsu is only allowed on a substructural match @@ -264,6 +265,47 @@ object DefRecursionCheck { */ type St[A] = StateT[Either[NonEmptyList[RecursionError], *], State, A] + implicit val parallelSt: cats.Parallel[St] = { + val m = cats.Monad[St] + + new cats.Parallel[St] { + type F[A] = St[A] + def parallel = new cats.arrow.FunctionK[St, F] { + def apply[A](st: St[A]) = st + } + def sequential = new cats.arrow.FunctionK[F, St] { + def apply[A](st: St[A]) = st + } + def monad = m + def applicative: Applicative[F] = + new Applicative[F] { + def pure[A](a: A) = m.pure(a) + def ap[A, B](ff: F[A => B])(fa: F[A]): F[B] = + map(product(ff, fa)) { case (fn, a) => fn(a) } + + override def product[A, B](fa: F[A], fb: F[B]): F[(A, B)] = { + type E[+X] = Either[NonEmptyList[RecursionError], X] + val fna: E[State => E[(State, A)]] = fa.runF + val fnb: E[State => E[(State, B)]] = fb.runF + + new cats.data.IndexedStateT((fna, fnb).parMapN { (fn1, fn2) => + { (state: State) => + fn1(state) match { + case Right((s2, a)) => fn2(s2).map { case (st, b) => (st, (a, b)) } + case Left(nel1) => + // just skip and merge + fn2(state) match { + case Right(_) => Left(nel1) + case Left(nel2) => Left(nel1 ::: nel2) + } + } + } + }) + } + override def map[A, B](fa: F[A])(fn: A => B) = m.map(fa)(fn) + } + } + } // Scala has trouble infering types like St, we we make these typed // helper functions to use below def failSt[A](err: RecursionError): St[A] = @@ -331,7 +373,7 @@ object DefRecursionCheck { getSt.flatMap { case TopLevel => // without any recursion, normal typechecking will detect bad states: - checkDecl(fn) *> args.traverse_(checkDecl) + checkDecl(fn) *> args.parTraverse_(checkDecl) case irb@InRecurBranch(inrec, branch, names) => argsOnDefName(fn, NonEmptyList.one(args)) match { @@ -353,7 +395,7 @@ object DefRecursionCheck { } else if (names.contains(nm)) { // we are calling a reachable function. Any lambda args are new names: - args.traverse_[St, Unit] { + args.parTraverse_[St, Unit] { case Declaration.Lambda(args, body) => val names1 = args.toList.flatMap(_.names) unionNames(names1)(checkDecl(body)) @@ -367,18 +409,18 @@ object DefRecursionCheck { else { // traverse converting Var(name) to the lambda version to use the above check // not a recursive call - args.traverse_(checkDecl) + args.parTraverse_(checkDecl) } case None => // this isn't a recursive call - checkDecl(fn) *> args.traverse_(checkDecl) + checkDecl(fn) *> args.parTraverse_(checkDecl) } case ir: InDefState => // we have either not yet, or already done the recursion argsOnDefName(fn, NonEmptyList.one(args)) match { case Some((nm, _)) if ir.defNamesContain(nm) => failSt(InvalidRecursion(nm, region)) case _ => - checkDecl(fn) *> args.traverse_(checkDecl) + checkDecl(fn) *> args.parTraverse_(checkDecl) } } /* @@ -409,7 +451,7 @@ object DefRecursionCheck { defn *> nextRes } case IfElse(ifCases, elseCase) => - val ifs = ifCases.traverse_ { case (d, od) => + val ifs = ifCases.parTraverse_ { case (d, od) => checkDecl(d) *> checkDecl(od.get) } val e = checkDecl(elseCase.get) @@ -428,7 +470,7 @@ object DefRecursionCheck { case Match(RecursionKind.NonRecursive, arg, cases) => // the arg can't use state, but cases introduce new bindings: val argRes = checkDecl(arg) - val optRes = cases.get.traverse_ { case (pat, next) => + val optRes = cases.get.parTraverse_ { case (pat, next) => checkForIllegalBindsSt(pat.names, decl) *> filterNames(pat.names)(checkDecl(next.get)) } @@ -464,7 +506,7 @@ object DefRecursionCheck { // $COVERAGE-ON$ } - cases.get.traverse_ { case (pat, next) => + cases.get.parTraverse_ { case (pat, next) => for { _ <- checkForIllegalBindsSt(pat.names, decl) _ <- beginBranch(pat) @@ -480,7 +522,7 @@ object DefRecursionCheck { case Parens(p) => checkDecl(p) case TupleCons(tups) => - tups.traverse_(checkDecl) + tups.parTraverse_(checkDecl) case Var(Identifier.Constructor(_)) => unitSt case Var(v: Bindable) => @@ -494,28 +536,28 @@ object DefRecursionCheck { else unitSt } case StringDecl(parts) => - parts.traverse_ { + parts.parTraverse_ { case Left(nb) => checkDecl(nb) case Right(_) => unitSt } case ListDecl(ll) => ll match { case ListLang.Cons(items) => - items.traverse_ { s => checkDecl(s.value) } + items.parTraverse_ { s => checkDecl(s.value) } case ListLang.Comprehension(e, _, i, f) => checkDecl(e.value) *> checkDecl(i) *> - (f.traverse_(checkDecl)) + (f.parTraverse_(checkDecl)) } case DictDecl(ll) => ll match { case ListLang.Cons(items) => - items.traverse_ { s => checkDecl(s.key) *> checkDecl(s.value) } + items.parTraverse_ { s => checkDecl(s.key) *> checkDecl(s.value) } case ListLang.Comprehension(e, _, i, f) => checkDecl(e.key) *> checkDecl(e.value) *> checkDecl(i) *> - (f.traverse_(checkDecl)) + (f.parTraverse_(checkDecl)) } case RecordConstructor(_, args) => @@ -526,7 +568,7 @@ object DefRecursionCheck { case RecordArg.Pair(_, v) => checkDecl(v) } - args.traverse_(checkArg) + args.parTraverse_(checkArg) } } diff --git a/test_workspace/Eval.bosatsu b/test_workspace/Eval.bosatsu new file mode 100644 index 000000000..9f733c4aa --- /dev/null +++ b/test_workspace/Eval.bosatsu @@ -0,0 +1,51 @@ +package Eval + +from Bosatsu/Nat import Nat, Zero, Succ, to_Nat + +export Eval, done, map, flat_map, bind, eval +# port of cats.Eval to bosatsu + +enum Eval[a]: + Pure(a: a) + FlatMap(use: forall x. (forall y. (Eval[y], y -> Eval[a]) -> x) -> x) + +def done(a: a) -> Eval[a]: Pure(a) + +def flat_map[a, b](e: Eval[a], fn: a -> Eval[b]) -> Eval[b]: + FlatMap(cb -> cb(e, fn)) + +def bind(e)(fn): flat_map(e, fn) + +def map[a, b](e: Eval[a], fn: a -> b) -> Eval[b]: + x <- e.bind() + Pure(fn(x)) + +enum Stack[a, b]: + Done(fn: a -> b) + More(use: forall x. (forall y. (a -> Eval[y], Stack[y, b]) -> x) -> x) + +def push_stack[a, b, c](fn: a -> Eval[b], stack: Stack[b, c]) -> Stack[a, c]: + More(use -> use(fn, stack)) + +enum Loop[a, b]: + RunStack(a: a, stack: Stack[a, b]) + RunEval(e: Eval[a], stack: Stack[a, b]) + +def run[a, b](budget: Nat, arg: Loop[a, b]) -> Option[b]: + recur budget: + case Zero: None + case Succ(balance): + match arg: + case RunStack(a, Done(fn)): Some(fn(a)) + case RunEval(Pure(a), stack): + run(balance, RunStack(a, stack)) + case RunEval(FlatMap(use), stack): + use((prev, fn) -> run(balance, RunEval(prev, push_stack(fn, stack)))) + case RunStack(a, More(use)): + use((fn, stack) -> ( + evalb = fn(a) + run(balance, RunEval(evalb, stack)) + )) + +def eval[a](budget: Int, ea: Eval[a]) -> Option[a]: + run(to_Nat(budget), RunEval(ea, Done(a -> a))) \ No newline at end of file